// gemini_genai_rs/session/handle.rs

//! [`SessionHandle`] — the public API surface for a Gemini Live session.
//!
//! Cheaply cloneable (wraps `Arc`). Provides methods to send commands,
//! subscribe to events, and observe session state.

use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::{broadcast, mpsc, watch};
use tokio::task::JoinHandle;

use super::errors::SessionError;
use super::events::{SessionCommand, SessionEvent};
use super::state::{SessionPhase, SessionState};
use super::traits::{SessionReader, SessionWriter};
use crate::protocol::{Content, FunctionResponse};

16/// The public API surface for a Gemini Live session.
17///
18/// Cheaply cloneable (wraps `Arc`). Provides methods to send commands,
19/// subscribe to events, and observe session state.
20#[derive(Clone)]
21pub struct SessionHandle {
22    /// Channel for sending commands to the transport layer.
23    pub command_tx: mpsc::Sender<SessionCommand>,
24    /// Broadcast channel for session events.
25    event_tx: broadcast::Sender<SessionEvent>,
26    /// Shared session state.
27    pub state: Arc<SessionState>,
28    /// Phase watch receiver for async observation.
29    phase_rx: watch::Receiver<SessionPhase>,
30    /// Handle to the spawned connection loop task.
31    ///
32    /// Wrapped in `Arc<Mutex<Option<...>>>` so that `SessionHandle` remains
33    /// `Clone` (since `JoinHandle` is not `Clone`). The first call to
34    /// [`join()`](Self::join) takes the handle; subsequent calls return `Ok(())`.
35    task: Arc<tokio::sync::Mutex<Option<JoinHandle<()>>>>,
36}
37
38impl SessionHandle {
39    /// Create a new session handle from its components.
40    pub fn new(
41        command_tx: mpsc::Sender<SessionCommand>,
42        event_tx: broadcast::Sender<SessionEvent>,
43        state: Arc<SessionState>,
44        phase_rx: watch::Receiver<SessionPhase>,
45    ) -> Self {
46        Self {
47            command_tx,
48            event_tx,
49            state,
50            phase_rx,
51            task: Arc::new(tokio::sync::Mutex::new(None)),
52        }
53    }
54
55    /// Store the connection loop task handle.
56    ///
57    /// Called by the transport layer after spawning the connection loop.
58    pub fn set_task(&self, handle: JoinHandle<()>) {
59        // Use try_lock to avoid blocking — this is only called once at startup.
60        if let Ok(mut guard) = self.task.try_lock() {
61            *guard = Some(handle);
62        }
63    }
64
65    /// Wait for the session connection loop to complete.
66    ///
67    /// Returns `Ok(())` when the session disconnects normally.
68    /// Returns `Err` if the connection task panicked.
69    ///
70    /// Only the first call across all clones actually awaits the task;
71    /// subsequent calls return `Ok(())` immediately.
72    pub async fn join(&self) -> Result<(), tokio::task::JoinError> {
73        let task = self.task.lock().await.take();
74        if let Some(handle) = task {
75            handle.await
76        } else {
77            Ok(())
78        }
79    }
80
81    /// Subscribe to session events.
82    pub fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
83        self.event_tx.subscribe()
84    }
85
86    /// Get the event sender (for internal use by transport).
87    pub fn event_sender(&self) -> &broadcast::Sender<SessionEvent> {
88        &self.event_tx
89    }
90
91    /// Current session phase.
92    pub fn phase(&self) -> SessionPhase {
93        self.state.phase()
94    }
95
96    /// Session ID.
97    pub fn session_id(&self) -> &str {
98        &self.state.session_id
99    }
100
101    /// Wait for the session to reach a specific phase.
102    pub async fn wait_for_phase(&self, target: SessionPhase) {
103        let mut rx = self.phase_rx.clone();
104        while *rx.borrow_and_update() != target {
105            if rx.changed().await.is_err() {
106                break;
107            }
108        }
109    }
110
111    /// Send audio data (raw PCM16 bytes).
112    pub async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
113        self.send_command(SessionCommand::SendAudio(data)).await
114    }
115
116    /// Send a text message.
117    pub async fn send_text(&self, text: impl Into<String>) -> Result<(), SessionError> {
118        self.send_command(SessionCommand::SendText(text.into()))
119            .await
120    }
121
122    /// Send tool responses.
123    pub async fn send_tool_response(
124        &self,
125        responses: Vec<FunctionResponse>,
126    ) -> Result<(), SessionError> {
127        self.send_command(SessionCommand::SendToolResponse(responses))
128            .await
129    }
130
131    /// Send a video/image frame (raw JPEG bytes).
132    pub async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
133        self.send_command(SessionCommand::SendVideo(jpeg_data))
134            .await
135    }
136
137    /// Update the system instruction mid-session.
138    pub async fn update_instruction(
139        &self,
140        instruction: impl Into<String>,
141    ) -> Result<(), SessionError> {
142        self.send_command(SessionCommand::UpdateInstruction(instruction.into()))
143            .await
144    }
145
146    /// Signal activity start (user started speaking).
147    pub async fn signal_activity_start(&self) -> Result<(), SessionError> {
148        self.send_command(SessionCommand::ActivityStart).await
149    }
150
151    /// Signal activity end (user stopped speaking).
152    pub async fn signal_activity_end(&self) -> Result<(), SessionError> {
153        self.send_command(SessionCommand::ActivityEnd).await
154    }
155
156    /// Send client content (turns + turn_complete flag).
157    /// Used for injecting conversation history, context, or multi-turn text.
158    pub async fn send_client_content(
159        &self,
160        turns: Vec<Content>,
161        turn_complete: bool,
162    ) -> Result<(), SessionError> {
163        self.send_command(SessionCommand::SendClientContent {
164            turns,
165            turn_complete,
166        })
167        .await
168    }
169
170    /// Gracefully disconnect the session.
171    pub async fn disconnect(&self) -> Result<(), SessionError> {
172        self.send_command(SessionCommand::Disconnect).await
173    }
174
175    /// Send a command to the transport.
176    async fn send_command(&self, cmd: SessionCommand) -> Result<(), SessionError> {
177        self.command_tx
178            .send(cmd)
179            .await
180            .map_err(|_| SessionError::ChannelClosed)
181    }
182}
183
184impl std::fmt::Debug for SessionHandle {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        f.debug_struct("SessionHandle")
187            .field("session_id", &self.state.session_id)
188            .field("phase", &self.state.phase())
189            .finish()
190    }
191}
192
// ---------------------------------------------------------------------------
// Trait implementations for SessionHandle
// ---------------------------------------------------------------------------

197#[async_trait]
198impl SessionWriter for SessionHandle {
199    async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
200        self.send_command(SessionCommand::SendAudio(data)).await
201    }
202
203    async fn send_text(&self, text: String) -> Result<(), SessionError> {
204        self.send_command(SessionCommand::SendText(text)).await
205    }
206
207    async fn send_tool_response(
208        &self,
209        responses: Vec<FunctionResponse>,
210    ) -> Result<(), SessionError> {
211        self.send_command(SessionCommand::SendToolResponse(responses))
212            .await
213    }
214
215    async fn send_client_content(
216        &self,
217        turns: Vec<Content>,
218        turn_complete: bool,
219    ) -> Result<(), SessionError> {
220        self.send_command(SessionCommand::SendClientContent {
221            turns,
222            turn_complete,
223        })
224        .await
225    }
226
227    async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
228        self.send_command(SessionCommand::SendVideo(jpeg_data))
229            .await
230    }
231
232    async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
233        self.send_command(SessionCommand::UpdateInstruction(instruction))
234            .await
235    }
236
237    async fn signal_activity_start(&self) -> Result<(), SessionError> {
238        self.send_command(SessionCommand::ActivityStart).await
239    }
240
241    async fn signal_activity_end(&self) -> Result<(), SessionError> {
242        self.send_command(SessionCommand::ActivityEnd).await
243    }
244
245    async fn disconnect(&self) -> Result<(), SessionError> {
246        self.send_command(SessionCommand::Disconnect).await
247    }
248}
249
250impl SessionReader for SessionHandle {
251    fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
252        self.event_tx.subscribe()
253    }
254
255    fn phase(&self) -> SessionPhase {
256        self.state.phase()
257    }
258
259    fn session_id(&self) -> &str {
260        &self.state.session_id
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a handle wired to fresh channels.
    ///
    /// The command receiver is handed back so the caller can keep it alive for
    /// the whole test, exactly as a local binding would.
    fn fresh_handle() -> (mpsc::Receiver<SessionCommand>, SessionHandle) {
        let (command_tx, command_rx) = mpsc::channel(8);
        let (event_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
        (command_rx, SessionHandle::new(command_tx, event_tx, state, phase_rx))
    }

    #[tokio::test]
    async fn session_handle_join_returns_ok_after_task_completes() {
        let (_command_rx, handle) = fresh_handle();

        // A task with an empty body finishes straight away.
        handle.set_task(tokio::spawn(async {}));

        let outcome = handle.join().await;
        assert!(
            outcome.is_ok(),
            "join() should return Ok after task completes"
        );
    }

    #[tokio::test]
    async fn session_handle_join_without_task_returns_ok() {
        let (_command_rx, handle) = fresh_handle();

        // Nothing was ever stored, so there is nothing to wait for.
        let outcome = handle.join().await;
        assert!(outcome.is_ok(), "join() without task should return Ok");
    }

    #[tokio::test]
    async fn session_handle_join_idempotent() {
        let (_command_rx, handle) = fresh_handle();
        handle.set_task(tokio::spawn(async {}));

        // The first call consumes the stored handle; the second finds an empty slot.
        for _ in 0..2 {
            assert!(handle.join().await.is_ok());
        }
    }

    #[tokio::test]
    async fn session_handle_join_works_on_clone() {
        let (_command_rx, handle) = fresh_handle();
        let handle_clone = handle.clone();

        handle.set_task(tokio::spawn(async {}));

        // The clone shares the task slot, so it can join the original's task.
        let outcome = handle_clone.join().await;
        assert!(outcome.is_ok(), "join() on clone should work");

        // The slot is empty now, so the original handle returns Ok at once.
        assert!(handle.join().await.is_ok());
    }

    // PhaseChanged event emission tests

    #[tokio::test]
    async fn phase_changed_event_emitted_on_transition() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        let (event_tx, mut event_rx) = broadcast::channel(16);
        let state = SessionState::with_events(phase_tx, event_tx);

        state.transition_to(SessionPhase::Connecting).unwrap();

        let received = event_rx.try_recv();
        assert!(
            matches!(
                received,
                Ok(SessionEvent::PhaseChanged(SessionPhase::Connecting))
            ),
            "expected PhaseChanged(Connecting), got {:?}",
            received
        );
    }

    #[test]
    fn phase_changed_not_emitted_without_event_tx() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        // A state built without an event sender must still transition cleanly.
        SessionState::new(phase_tx)
            .transition_to(SessionPhase::Connecting)
            .unwrap();
    }
}