1use super::errors::SessionError;
7use super::events::{SessionCommand, SessionEvent};
8use super::state::{SessionPhase, SessionState};
9use super::traits::{SessionReader, SessionWriter};
10use crate::protocol::{Content, FunctionResponse};
11use async_trait::async_trait;
12use std::sync::Arc;
13use tokio::sync::{broadcast, mpsc, watch};
14use tokio::task::JoinHandle;
15
16#[derive(Clone)]
21pub struct SessionHandle {
22 pub command_tx: mpsc::Sender<SessionCommand>,
24 event_tx: broadcast::Sender<SessionEvent>,
26 pub state: Arc<SessionState>,
28 phase_rx: watch::Receiver<SessionPhase>,
30 task: Arc<tokio::sync::Mutex<Option<JoinHandle<()>>>>,
36 audio_pacer: Option<Arc<tokio::sync::Mutex<crate::transport::TokenBucket>>>,
42}
43
44impl SessionHandle {
45 pub fn new(
47 command_tx: mpsc::Sender<SessionCommand>,
48 event_tx: broadcast::Sender<SessionEvent>,
49 state: Arc<SessionState>,
50 phase_rx: watch::Receiver<SessionPhase>,
51 ) -> Self {
52 Self {
53 command_tx,
54 event_tx,
55 state,
56 phase_rx,
57 task: Arc::new(tokio::sync::Mutex::new(None)),
58 audio_pacer: None,
59 }
60 }
61
62 pub fn with_audio_pacing(mut self, config: crate::transport::BackpressureConfig) -> Self {
67 self.audio_pacer = Some(Arc::new(tokio::sync::Mutex::new(
68 crate::transport::TokenBucket::new(config),
69 )));
70 self
71 }
72
73 pub fn set_task(&self, handle: JoinHandle<()>) {
77 if let Ok(mut guard) = self.task.try_lock() {
79 *guard = Some(handle);
80 }
81 }
82
83 pub async fn join(&self) -> Result<(), tokio::task::JoinError> {
91 let task = self.task.lock().await.take();
92 if let Some(handle) = task {
93 handle.await
94 } else {
95 Ok(())
96 }
97 }
98
99 pub fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
101 self.event_tx.subscribe()
102 }
103
104 pub fn event_sender(&self) -> &broadcast::Sender<SessionEvent> {
106 &self.event_tx
107 }
108
109 pub fn phase(&self) -> SessionPhase {
111 self.state.phase()
112 }
113
114 pub fn session_id(&self) -> &str {
116 &self.state.session_id
117 }
118
119 pub async fn wait_for_phase(&self, target: SessionPhase) {
121 let mut rx = self.phase_rx.clone();
122 while *rx.borrow_and_update() != target {
123 if rx.changed().await.is_err() {
124 break;
125 }
126 }
127 }
128
129 pub async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
135 if let Some(pacer) = &self.audio_pacer {
136 pacer.lock().await.consume(data.len()).await;
137 }
138 self.send_command(SessionCommand::SendAudio(data)).await
139 }
140
141 pub async fn send_text(&self, text: impl Into<String>) -> Result<(), SessionError> {
143 self.send_command(SessionCommand::SendText(text.into()))
144 .await
145 }
146
147 pub async fn send_tool_response(
149 &self,
150 responses: Vec<FunctionResponse>,
151 ) -> Result<(), SessionError> {
152 self.send_command(SessionCommand::SendToolResponse(responses))
153 .await
154 }
155
156 pub async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
158 self.send_command(SessionCommand::SendVideo(jpeg_data))
159 .await
160 }
161
162 pub async fn update_instruction(
164 &self,
165 instruction: impl Into<String>,
166 ) -> Result<(), SessionError> {
167 self.send_command(SessionCommand::UpdateInstruction(instruction.into()))
168 .await
169 }
170
171 pub async fn signal_activity_start(&self) -> Result<(), SessionError> {
173 self.send_command(SessionCommand::ActivityStart).await
174 }
175
176 pub async fn signal_activity_end(&self) -> Result<(), SessionError> {
178 self.send_command(SessionCommand::ActivityEnd).await
179 }
180
181 pub async fn send_client_content(
184 &self,
185 turns: Vec<Content>,
186 turn_complete: bool,
187 ) -> Result<(), SessionError> {
188 self.send_command(SessionCommand::SendClientContent {
189 turns,
190 turn_complete,
191 })
192 .await
193 }
194
195 pub async fn disconnect(&self) -> Result<(), SessionError> {
197 self.send_command(SessionCommand::Disconnect).await
198 }
199
200 async fn send_command(&self, cmd: SessionCommand) -> Result<(), SessionError> {
202 self.command_tx
203 .send(cmd)
204 .await
205 .map_err(|_| SessionError::ChannelClosed)
206 }
207}
208
209impl std::fmt::Debug for SessionHandle {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("SessionHandle")
212 .field("session_id", &self.state.session_id)
213 .field("phase", &self.state.phase())
214 .finish()
215 }
216}
217
218#[async_trait]
223impl SessionWriter for SessionHandle {
224 async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
225 SessionHandle::send_audio(self, data).await
226 }
227
228 async fn send_text(&self, text: String) -> Result<(), SessionError> {
229 self.send_command(SessionCommand::SendText(text)).await
230 }
231
232 async fn send_tool_response(
233 &self,
234 responses: Vec<FunctionResponse>,
235 ) -> Result<(), SessionError> {
236 self.send_command(SessionCommand::SendToolResponse(responses))
237 .await
238 }
239
240 async fn send_client_content(
241 &self,
242 turns: Vec<Content>,
243 turn_complete: bool,
244 ) -> Result<(), SessionError> {
245 self.send_command(SessionCommand::SendClientContent {
246 turns,
247 turn_complete,
248 })
249 .await
250 }
251
252 async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
253 self.send_command(SessionCommand::SendVideo(jpeg_data))
254 .await
255 }
256
257 async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
258 self.send_command(SessionCommand::UpdateInstruction(instruction))
259 .await
260 }
261
262 async fn signal_activity_start(&self) -> Result<(), SessionError> {
263 self.send_command(SessionCommand::ActivityStart).await
264 }
265
266 async fn signal_activity_end(&self) -> Result<(), SessionError> {
267 self.send_command(SessionCommand::ActivityEnd).await
268 }
269
270 async fn disconnect(&self) -> Result<(), SessionError> {
271 self.send_command(SessionCommand::Disconnect).await
272 }
273}
274
275impl SessionReader for SessionHandle {
276 fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
277 self.event_tx.subscribe()
278 }
279
280 fn phase(&self) -> SessionPhase {
281 self.state.phase()
282 }
283
284 fn session_id(&self) -> &str {
285 &self.state.session_id
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handle whose state and phase watch both start at `phase`.
    ///
    /// The command receiver is returned so the caller can keep the command channel open (or
    /// drain it) for the duration of the test.
    fn make_handle(
        phase: SessionPhase,
        command_capacity: usize,
    ) -> (SessionHandle, mpsc::Receiver<SessionCommand>) {
        let (command_tx, command_rx) = mpsc::channel(command_capacity);
        let (event_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(phase);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
        let handle = SessionHandle::new(command_tx, event_tx, state, phase_rx);
        (handle, command_rx)
    }

    /// Asserts that the next queued command is `SendAudio` carrying `expected_len` bytes.
    async fn expect_audio(command_rx: &mut mpsc::Receiver<SessionCommand>, expected_len: usize) {
        match command_rx.recv().await {
            Some(SessionCommand::SendAudio(bytes)) => assert_eq!(bytes.len(), expected_len),
            Some(_) => panic!("expected a SendAudio command"),
            None => panic!("command channel closed before SendAudio arrived"),
        }
    }

    #[tokio::test(start_paused = true)]
    async fn audio_pacing_throttles_producer_to_sustained_rate() {
        let (handle, mut command_rx) = make_handle(SessionPhase::Active, 64);
        let handle = handle.with_audio_pacing(crate::transport::BackpressureConfig {
            bucket_capacity: 1000,
            refill_rate_bps: 1000,
        });

        let start = tokio::time::Instant::now();
        // The first chunk is expected to fit in the bucket's initial capacity.
        handle.send_audio(vec![0u8; 1000]).await.unwrap();
        let after_burst = start.elapsed();
        // The second chunk must wait for roughly one second of refill at 1000 B/s.
        handle.send_audio(vec![0u8; 1000]).await.unwrap();
        let after_paced = start.elapsed();

        assert!(
            after_burst < std::time::Duration::from_millis(50),
            "first send should not be paced, was {after_burst:?}"
        );
        assert!(
            after_paced >= std::time::Duration::from_millis(900),
            "second send should be paced ~1s, was {after_paced:?}"
        );
        expect_audio(&mut command_rx, 1000).await;
        expect_audio(&mut command_rx, 1000).await;
    }

    #[tokio::test]
    async fn session_handle_join_returns_ok_after_task_completes() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);

        handle.set_task(tokio::spawn(async {}));

        let result = handle.join().await;
        assert!(
            result.is_ok(),
            "join() should return Ok after task completes"
        );
    }

    #[tokio::test]
    async fn session_handle_join_without_task_returns_ok() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);

        let result = handle.join().await;
        assert!(result.is_ok(), "join() without task should return Ok");
    }

    #[tokio::test]
    async fn session_handle_join_idempotent() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);

        handle.set_task(tokio::spawn(async {}));

        assert!(handle.join().await.is_ok());
        // The handle was taken by the first join; the second must still succeed.
        assert!(handle.join().await.is_ok());
    }

    #[tokio::test]
    async fn session_handle_join_works_on_clone() {
        let (handle, _command_rx) = make_handle(SessionPhase::Disconnected, 8);
        let handle_clone = handle.clone();

        // The task slot is shared, so a task attached via `handle` is visible to the clone.
        handle.set_task(tokio::spawn(async {}));

        let result = handle_clone.join().await;
        assert!(result.is_ok(), "join() on clone should work");

        assert!(handle.join().await.is_ok());
    }

    #[tokio::test]
    async fn phase_changed_event_emitted_on_transition() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        let (event_tx, mut event_rx) = broadcast::channel(16);
        let state = SessionState::with_events(phase_tx, event_tx);

        state.transition_to(SessionPhase::Connecting).unwrap();

        match event_rx.try_recv() {
            Ok(SessionEvent::PhaseChanged(SessionPhase::Connecting)) => {}
            other => panic!("expected PhaseChanged(Connecting), got {:?}", other),
        }
    }

    #[test]
    fn phase_changed_not_emitted_without_event_tx() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = SessionState::new(phase_tx);

        // Without an event channel there is nothing to observe an emission on. Verify instead
        // that the transition succeeds and is applied.
        state.transition_to(SessionPhase::Connecting).unwrap();
        assert_eq!(state.phase(), SessionPhase::Connecting);
    }
}