gemini_genai_rs/session/
handle.rs

1//! [`SessionHandle`] — the public API surface for a Gemini Live session.
2//!
3//! Cheaply cloneable (wraps `Arc`). Provides methods to send commands,
4//! subscribe to events, and observe session state.
5
6use super::errors::SessionError;
7use super::events::{SessionCommand, SessionEvent};
8use super::state::{SessionPhase, SessionState};
9use super::traits::{SessionReader, SessionWriter};
10use crate::protocol::{Content, FunctionResponse};
11use async_trait::async_trait;
12use std::sync::Arc;
13use tokio::sync::{broadcast, mpsc, watch};
14use tokio::task::JoinHandle;
15
16/// The public API surface for a Gemini Live session.
17///
18/// Cheaply cloneable (wraps `Arc`). Provides methods to send commands,
19/// subscribe to events, and observe session state.
20#[derive(Clone)]
21pub struct SessionHandle {
22    /// Channel for sending commands to the transport layer.
23    pub command_tx: mpsc::Sender<SessionCommand>,
24    /// Broadcast channel for session events.
25    event_tx: broadcast::Sender<SessionEvent>,
26    /// Shared session state.
27    pub state: Arc<SessionState>,
28    /// Phase watch receiver for async observation.
29    phase_rx: watch::Receiver<SessionPhase>,
30    /// Handle to the spawned connection loop task.
31    ///
32    /// Wrapped in `Arc<Mutex<Option<...>>>` so that `SessionHandle` remains
33    /// `Clone` (since `JoinHandle` is not `Clone`). The first call to
34    /// [`join()`](Self::join) takes the handle; subsequent calls return `Ok(())`.
35    task: Arc<tokio::sync::Mutex<Option<JoinHandle<()>>>>,
36    /// Optional producer-side audio pacing (see [`SessionConfig::audio_pacing`](crate::protocol::SessionConfig::audio_pacing)).
37    ///
38    /// Shared across handle clones so all producers draw from one bucket. The
39    /// tokio mutex is held across the pacing wait deliberately: concurrent
40    /// audio producers serialize, which is the desired backpressure semantic.
41    audio_pacer: Option<Arc<tokio::sync::Mutex<crate::transport::TokenBucket>>>,
42}
43
44impl SessionHandle {
45    /// Create a new session handle from its components.
46    pub fn new(
47        command_tx: mpsc::Sender<SessionCommand>,
48        event_tx: broadcast::Sender<SessionEvent>,
49        state: Arc<SessionState>,
50        phase_rx: watch::Receiver<SessionPhase>,
51    ) -> Self {
52        Self {
53            command_tx,
54            event_tx,
55            state,
56            phase_rx,
57            task: Arc::new(tokio::sync::Mutex::new(None)),
58            audio_pacer: None,
59        }
60    }
61
62    /// Enable producer-side audio send pacing (token bucket).
63    ///
64    /// Installed by the connection layer when
65    /// [`SessionConfig::audio_pacing`](crate::protocol::SessionConfig::audio_pacing) is set.
66    pub fn with_audio_pacing(mut self, config: crate::transport::BackpressureConfig) -> Self {
67        self.audio_pacer = Some(Arc::new(tokio::sync::Mutex::new(
68            crate::transport::TokenBucket::new(config),
69        )));
70        self
71    }
72
73    /// Store the connection loop task handle.
74    ///
75    /// Called by the transport layer after spawning the connection loop.
76    pub fn set_task(&self, handle: JoinHandle<()>) {
77        // Use try_lock to avoid blocking — this is only called once at startup.
78        if let Ok(mut guard) = self.task.try_lock() {
79            *guard = Some(handle);
80        }
81    }
82
83    /// Wait for the session connection loop to complete.
84    ///
85    /// Returns `Ok(())` when the session disconnects normally.
86    /// Returns `Err` if the connection task panicked.
87    ///
88    /// Only the first call across all clones actually awaits the task;
89    /// subsequent calls return `Ok(())` immediately.
90    pub async fn join(&self) -> Result<(), tokio::task::JoinError> {
91        let task = self.task.lock().await.take();
92        if let Some(handle) = task {
93            handle.await
94        } else {
95            Ok(())
96        }
97    }
98
99    /// Subscribe to session events.
100    pub fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
101        self.event_tx.subscribe()
102    }
103
104    /// Get the event sender (for internal use by transport).
105    pub fn event_sender(&self) -> &broadcast::Sender<SessionEvent> {
106        &self.event_tx
107    }
108
109    /// Current session phase.
110    pub fn phase(&self) -> SessionPhase {
111        self.state.phase()
112    }
113
114    /// Session ID.
115    pub fn session_id(&self) -> &str {
116        &self.state.session_id
117    }
118
119    /// Wait for the session to reach a specific phase.
120    pub async fn wait_for_phase(&self, target: SessionPhase) {
121        let mut rx = self.phase_rx.clone();
122        while *rx.borrow_and_update() != target {
123            if rx.changed().await.is_err() {
124                break;
125            }
126        }
127    }
128
129    /// Send audio data (raw PCM16 bytes).
130    ///
131    /// When [`SessionConfig::audio_pacing`](crate::protocol::SessionConfig::audio_pacing) is configured, this paces the
132    /// caller: pushing audio faster than the configured sustained rate waits
133    /// here instead of overflowing the send queue.
134    pub async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
135        if let Some(pacer) = &self.audio_pacer {
136            pacer.lock().await.consume(data.len()).await;
137        }
138        self.send_command(SessionCommand::SendAudio(data)).await
139    }
140
141    /// Send a text message.
142    pub async fn send_text(&self, text: impl Into<String>) -> Result<(), SessionError> {
143        self.send_command(SessionCommand::SendText(text.into()))
144            .await
145    }
146
147    /// Send tool responses.
148    pub async fn send_tool_response(
149        &self,
150        responses: Vec<FunctionResponse>,
151    ) -> Result<(), SessionError> {
152        self.send_command(SessionCommand::SendToolResponse(responses))
153            .await
154    }
155
156    /// Send a video/image frame (raw JPEG bytes).
157    pub async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
158        self.send_command(SessionCommand::SendVideo(jpeg_data))
159            .await
160    }
161
162    /// Update the system instruction mid-session.
163    pub async fn update_instruction(
164        &self,
165        instruction: impl Into<String>,
166    ) -> Result<(), SessionError> {
167        self.send_command(SessionCommand::UpdateInstruction(instruction.into()))
168            .await
169    }
170
171    /// Signal activity start (user started speaking).
172    pub async fn signal_activity_start(&self) -> Result<(), SessionError> {
173        self.send_command(SessionCommand::ActivityStart).await
174    }
175
176    /// Signal activity end (user stopped speaking).
177    pub async fn signal_activity_end(&self) -> Result<(), SessionError> {
178        self.send_command(SessionCommand::ActivityEnd).await
179    }
180
181    /// Send client content (turns + turn_complete flag).
182    /// Used for injecting conversation history, context, or multi-turn text.
183    pub async fn send_client_content(
184        &self,
185        turns: Vec<Content>,
186        turn_complete: bool,
187    ) -> Result<(), SessionError> {
188        self.send_command(SessionCommand::SendClientContent {
189            turns,
190            turn_complete,
191        })
192        .await
193    }
194
195    /// Gracefully disconnect the session.
196    pub async fn disconnect(&self) -> Result<(), SessionError> {
197        self.send_command(SessionCommand::Disconnect).await
198    }
199
200    /// Send a command to the transport.
201    async fn send_command(&self, cmd: SessionCommand) -> Result<(), SessionError> {
202        self.command_tx
203            .send(cmd)
204            .await
205            .map_err(|_| SessionError::ChannelClosed)
206    }
207}
208
209impl std::fmt::Debug for SessionHandle {
210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211        f.debug_struct("SessionHandle")
212            .field("session_id", &self.state.session_id)
213            .field("phase", &self.state.phase())
214            .finish()
215    }
216}
217
218// ---------------------------------------------------------------------------
219// Trait implementations for SessionHandle
220// ---------------------------------------------------------------------------
221
222#[async_trait]
223impl SessionWriter for SessionHandle {
224    async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
225        SessionHandle::send_audio(self, data).await
226    }
227
228    async fn send_text(&self, text: String) -> Result<(), SessionError> {
229        self.send_command(SessionCommand::SendText(text)).await
230    }
231
232    async fn send_tool_response(
233        &self,
234        responses: Vec<FunctionResponse>,
235    ) -> Result<(), SessionError> {
236        self.send_command(SessionCommand::SendToolResponse(responses))
237            .await
238    }
239
240    async fn send_client_content(
241        &self,
242        turns: Vec<Content>,
243        turn_complete: bool,
244    ) -> Result<(), SessionError> {
245        self.send_command(SessionCommand::SendClientContent {
246            turns,
247            turn_complete,
248        })
249        .await
250    }
251
252    async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
253        self.send_command(SessionCommand::SendVideo(jpeg_data))
254            .await
255    }
256
257    async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
258        self.send_command(SessionCommand::UpdateInstruction(instruction))
259            .await
260    }
261
262    async fn signal_activity_start(&self) -> Result<(), SessionError> {
263        self.send_command(SessionCommand::ActivityStart).await
264    }
265
266    async fn signal_activity_end(&self) -> Result<(), SessionError> {
267        self.send_command(SessionCommand::ActivityEnd).await
268    }
269
270    async fn disconnect(&self) -> Result<(), SessionError> {
271        self.send_command(SessionCommand::Disconnect).await
272    }
273}
274
275impl SessionReader for SessionHandle {
276    fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
277        self.event_tx.subscribe()
278    }
279
280    fn phase(&self) -> SessionPhase {
281        self.state.phase()
282    }
283
284    fn session_id(&self) -> &str {
285        &self.state.session_id
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Build a handle wired to fresh channels, starting in `initial`.
    ///
    /// The command receiver is returned so each test can keep it alive (the
    /// handle's sends fail once it is dropped) or drain it.
    fn make_handle(
        initial: SessionPhase,
        command_capacity: usize,
    ) -> (SessionHandle, mpsc::Receiver<SessionCommand>) {
        let (command_tx, command_rx) = mpsc::channel(command_capacity);
        let (event_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(initial);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
        let handle = SessionHandle::new(command_tx, event_tx, state, phase_rx);
        (handle, command_rx)
    }

    #[tokio::test(start_paused = true)]
    async fn audio_pacing_throttles_producer_to_sustained_rate() {
        let (unpaced, mut command_rx) = make_handle(SessionPhase::Active, 64);
        // 1000-byte burst allowance, 1000 B/s sustained.
        let handle = unpaced.with_audio_pacing(crate::transport::BackpressureConfig {
            bucket_capacity: 1000,
            refill_rate_bps: 1000,
        });

        let t0 = tokio::time::Instant::now();
        // The first 1000 bytes fit in the burst; the next 1000 must wait ~1s.
        handle.send_audio(vec![0u8; 1000]).await.unwrap();
        let burst_elapsed = t0.elapsed();
        handle.send_audio(vec![0u8; 1000]).await.unwrap();
        let after_paced = t0.elapsed();

        assert!(burst_elapsed < Duration::from_millis(50));
        assert!(
            after_paced >= Duration::from_millis(900),
            "second send should be paced ~1s, was {after_paced:?}"
        );
        // Both frames reached the command queue.
        for _ in 0..2 {
            assert!(command_rx.recv().await.is_some());
        }
    }

    #[tokio::test]
    async fn session_handle_join_returns_ok_after_task_completes() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);

        // A trivial task that finishes right away.
        handle.set_task(tokio::spawn(async {}));

        let outcome = handle.join().await;
        assert!(
            outcome.is_ok(),
            "join() should return Ok after task completes"
        );
    }

    #[tokio::test]
    async fn session_handle_join_without_task_returns_ok() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);

        // With no task ever stored, join() has nothing to wait for.
        let outcome = handle.join().await;
        assert!(outcome.is_ok(), "join() without task should return Ok");
    }

    #[tokio::test]
    async fn session_handle_join_idempotent() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);
        handle.set_task(tokio::spawn(async {}));

        // The first join takes the stored handle...
        assert!(handle.join().await.is_ok());
        // ...and a second one finds the slot empty.
        assert!(handle.join().await.is_ok());
    }

    #[tokio::test]
    async fn session_handle_join_works_on_clone() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);
        let sibling = handle.clone();

        handle.set_task(tokio::spawn(async {}));

        // The clone shares the task slot, so it can join the task.
        let outcome = sibling.join().await;
        assert!(outcome.is_ok(), "join() on clone should work");

        // The slot is now empty for the original handle as well.
        assert!(handle.join().await.is_ok());
    }

    // PhaseChanged event emission tests

    #[tokio::test]
    async fn phase_changed_event_emitted_on_transition() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        let (event_tx, mut event_rx) = broadcast::channel(16);
        let state = SessionState::with_events(phase_tx, event_tx);

        state.transition_to(SessionPhase::Connecting).unwrap();

        let received = event_rx.try_recv();
        if !matches!(
            received,
            Ok(SessionEvent::PhaseChanged(SessionPhase::Connecting))
        ) {
            panic!("expected PhaseChanged(Connecting), got {:?}", received);
        }
    }

    #[test]
    fn phase_changed_not_emitted_without_event_tx() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = SessionState::new(phase_tx);
        // Without an event sender, a transition must still succeed quietly.
        state.transition_to(SessionPhase::Connecting).unwrap();
    }
}
422}