1use super::errors::SessionError;
7use super::events::{SessionCommand, SessionEvent};
8use super::state::{SessionPhase, SessionState};
9use super::traits::{SessionReader, SessionWriter};
10use crate::protocol::{Content, FunctionResponse};
11use async_trait::async_trait;
12use std::sync::Arc;
13use tokio::sync::{broadcast, mpsc, watch};
14use tokio::task::JoinHandle;
15
16#[derive(Clone)]
21pub struct SessionHandle {
22 pub command_tx: mpsc::Sender<SessionCommand>,
24 event_tx: broadcast::Sender<SessionEvent>,
26 pub state: Arc<SessionState>,
28 phase_rx: watch::Receiver<SessionPhase>,
30 task: Arc<tokio::sync::Mutex<Option<JoinHandle<()>>>>,
36}
37
38impl SessionHandle {
39 pub fn new(
41 command_tx: mpsc::Sender<SessionCommand>,
42 event_tx: broadcast::Sender<SessionEvent>,
43 state: Arc<SessionState>,
44 phase_rx: watch::Receiver<SessionPhase>,
45 ) -> Self {
46 Self {
47 command_tx,
48 event_tx,
49 state,
50 phase_rx,
51 task: Arc::new(tokio::sync::Mutex::new(None)),
52 }
53 }
54
55 pub fn set_task(&self, handle: JoinHandle<()>) {
59 if let Ok(mut guard) = self.task.try_lock() {
61 *guard = Some(handle);
62 }
63 }
64
65 pub async fn join(&self) -> Result<(), tokio::task::JoinError> {
73 let task = self.task.lock().await.take();
74 if let Some(handle) = task {
75 handle.await
76 } else {
77 Ok(())
78 }
79 }
80
81 pub fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
83 self.event_tx.subscribe()
84 }
85
86 pub fn event_sender(&self) -> &broadcast::Sender<SessionEvent> {
88 &self.event_tx
89 }
90
91 pub fn phase(&self) -> SessionPhase {
93 self.state.phase()
94 }
95
96 pub fn session_id(&self) -> &str {
98 &self.state.session_id
99 }
100
101 pub async fn wait_for_phase(&self, target: SessionPhase) {
103 let mut rx = self.phase_rx.clone();
104 while *rx.borrow_and_update() != target {
105 if rx.changed().await.is_err() {
106 break;
107 }
108 }
109 }
110
111 pub async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
113 self.send_command(SessionCommand::SendAudio(data)).await
114 }
115
116 pub async fn send_text(&self, text: impl Into<String>) -> Result<(), SessionError> {
118 self.send_command(SessionCommand::SendText(text.into()))
119 .await
120 }
121
122 pub async fn send_tool_response(
124 &self,
125 responses: Vec<FunctionResponse>,
126 ) -> Result<(), SessionError> {
127 self.send_command(SessionCommand::SendToolResponse(responses))
128 .await
129 }
130
131 pub async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
133 self.send_command(SessionCommand::SendVideo(jpeg_data))
134 .await
135 }
136
137 pub async fn update_instruction(
139 &self,
140 instruction: impl Into<String>,
141 ) -> Result<(), SessionError> {
142 self.send_command(SessionCommand::UpdateInstruction(instruction.into()))
143 .await
144 }
145
146 pub async fn signal_activity_start(&self) -> Result<(), SessionError> {
148 self.send_command(SessionCommand::ActivityStart).await
149 }
150
151 pub async fn signal_activity_end(&self) -> Result<(), SessionError> {
153 self.send_command(SessionCommand::ActivityEnd).await
154 }
155
156 pub async fn send_client_content(
159 &self,
160 turns: Vec<Content>,
161 turn_complete: bool,
162 ) -> Result<(), SessionError> {
163 self.send_command(SessionCommand::SendClientContent {
164 turns,
165 turn_complete,
166 })
167 .await
168 }
169
170 pub async fn disconnect(&self) -> Result<(), SessionError> {
172 self.send_command(SessionCommand::Disconnect).await
173 }
174
175 async fn send_command(&self, cmd: SessionCommand) -> Result<(), SessionError> {
177 self.command_tx
178 .send(cmd)
179 .await
180 .map_err(|_| SessionError::ChannelClosed)
181 }
182}
183
184impl std::fmt::Debug for SessionHandle {
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186 f.debug_struct("SessionHandle")
187 .field("session_id", &self.state.session_id)
188 .field("phase", &self.state.phase())
189 .finish()
190 }
191}
192
193#[async_trait]
198impl SessionWriter for SessionHandle {
199 async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
200 self.send_command(SessionCommand::SendAudio(data)).await
201 }
202
203 async fn send_text(&self, text: String) -> Result<(), SessionError> {
204 self.send_command(SessionCommand::SendText(text)).await
205 }
206
207 async fn send_tool_response(
208 &self,
209 responses: Vec<FunctionResponse>,
210 ) -> Result<(), SessionError> {
211 self.send_command(SessionCommand::SendToolResponse(responses))
212 .await
213 }
214
215 async fn send_client_content(
216 &self,
217 turns: Vec<Content>,
218 turn_complete: bool,
219 ) -> Result<(), SessionError> {
220 self.send_command(SessionCommand::SendClientContent {
221 turns,
222 turn_complete,
223 })
224 .await
225 }
226
227 async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
228 self.send_command(SessionCommand::SendVideo(jpeg_data))
229 .await
230 }
231
232 async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
233 self.send_command(SessionCommand::UpdateInstruction(instruction))
234 .await
235 }
236
237 async fn signal_activity_start(&self) -> Result<(), SessionError> {
238 self.send_command(SessionCommand::ActivityStart).await
239 }
240
241 async fn signal_activity_end(&self) -> Result<(), SessionError> {
242 self.send_command(SessionCommand::ActivityEnd).await
243 }
244
245 async fn disconnect(&self) -> Result<(), SessionError> {
246 self.send_command(SessionCommand::Disconnect).await
247 }
248}
249
250impl SessionReader for SessionHandle {
251 fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
252 self.event_tx.subscribe()
253 }
254
255 fn phase(&self) -> SessionPhase {
256 self.state.phase()
257 }
258
259 fn session_id(&self) -> &str {
260 &self.state.session_id
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Builds a handle in the `Disconnected` phase with no task attached.
    ///
    /// The command receiver is returned so the test can keep the command channel open.
    fn new_handle() -> (SessionHandle, mpsc::Receiver<SessionCommand>) {
        let (command_tx, command_rx) = mpsc::channel(8);
        let (event_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
        (
            SessionHandle::new(command_tx, event_tx, state, phase_rx),
            command_rx,
        )
    }

    #[tokio::test]
    async fn session_handle_join_returns_ok_after_task_completes() {
        let (handle, _command_rx) = new_handle();
        handle.set_task(tokio::spawn(async {}));

        let result = handle.join().await;
        assert!(
            result.is_ok(),
            "join() should return Ok after task completes"
        );
    }

    /// `join()` must actually await the registered task, not return early.
    #[tokio::test]
    async fn session_handle_join_waits_for_task() {
        let (handle, _command_rx) = new_handle();
        let finished = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&finished);
        handle.set_task(tokio::spawn(async move {
            // Yield at least once so the task is still pending when join() starts awaiting it.
            tokio::task::yield_now().await;
            flag.store(true, Ordering::SeqCst);
        }));

        handle
            .join()
            .await
            .expect("task should complete without panicking");
        assert!(
            finished.load(Ordering::SeqCst),
            "join() must not return before the task ends"
        );
    }

    /// A panic inside the session task is surfaced as a `JoinError`.
    #[tokio::test]
    async fn session_handle_join_propagates_task_panic() {
        let (handle, _command_rx) = new_handle();
        handle.set_task(tokio::spawn(async {
            panic!("simulated session task failure");
        }));

        let err = handle
            .join()
            .await
            .expect_err("join() should surface the task panic");
        assert!(err.is_panic());
    }

    #[tokio::test]
    async fn session_handle_join_without_task_returns_ok() {
        let (handle, _command_rx) = new_handle();

        let result = handle.join().await;
        assert!(result.is_ok(), "join() without task should return Ok");
    }

    #[tokio::test]
    async fn session_handle_join_idempotent() {
        let (handle, _command_rx) = new_handle();
        handle.set_task(tokio::spawn(async {}));

        assert!(handle.join().await.is_ok());
        // The first join took the task out of the slot, so this one is a no-op.
        assert!(handle.join().await.is_ok());
    }

    #[tokio::test]
    async fn session_handle_join_works_on_clone() {
        let (handle, _command_rx) = new_handle();
        let handle_clone = handle.clone();

        // Registered through the original; the shared slot makes it joinable via the clone.
        handle.set_task(tokio::spawn(async {}));

        let result = handle_clone.join().await;
        assert!(result.is_ok(), "join() on clone should work");

        assert!(handle.join().await.is_ok());
    }

    #[tokio::test]
    async fn phase_changed_event_emitted_on_transition() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        let (event_tx, mut event_rx) = broadcast::channel(16);
        let state = SessionState::with_events(phase_tx, event_tx);

        state.transition_to(SessionPhase::Connecting).unwrap();

        match event_rx.try_recv() {
            Ok(SessionEvent::PhaseChanged(SessionPhase::Connecting)) => {}
            other => panic!("expected PhaseChanged(Connecting), got {:?}", other),
        }
    }

    /// Without an event sender there is no channel to observe, so this only checks that a
    /// transition still succeeds for a state built via `SessionState::new`.
    #[test]
    fn phase_changed_not_emitted_without_event_tx() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = SessionState::new(phase_tx);
        state.transition_to(SessionPhase::Connecting).unwrap();
    }
}