1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::time::Instant;
14
15use gemini_genai_rs::prelude::{SessionEvent, SessionPhase};
16use parking_lot::Mutex;
17
18use crate::state::State;
19
/// Kind of media the live session has carried so far.
///
/// `SessionSignals` uses this to pick the connection-time budget it reports
/// from `flush_timing`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionType {
    /// No video has been sent; this is the initial type of every session.
    AudioOnly,
    /// Video has been sent (see `SessionSignals::mark_video_sent`).
    AudioVideo,
}
32
33pub struct SessionSignals {
47 state: State,
48 start: Instant,
50 connected_at_ns: AtomicU64,
52 is_connected: AtomicBool,
54 last_activity_ns: AtomicU64,
56 has_video: AtomicBool,
58 go_away_at: Mutex<Option<Instant>>,
60 latest_resume_handle: Mutex<Option<String>>,
62}
63
64impl SessionSignals {
65 pub fn new(state: State) -> Self {
67 Self {
68 state,
69 start: Instant::now(),
70 connected_at_ns: AtomicU64::new(0),
71 is_connected: AtomicBool::new(false),
72 last_activity_ns: AtomicU64::new(0),
73 has_video: AtomicBool::new(false),
74 go_away_at: Mutex::new(None),
75 latest_resume_handle: Mutex::new(None),
76 }
77 }
78
79 pub fn on_event(&self, event: &SessionEvent) {
86 match event {
87 SessionEvent::Connected => {
88 let now_ns = self.elapsed_ns();
89 self.connected_at_ns.store(now_ns, Ordering::Relaxed);
90 self.is_connected.store(true, Ordering::Relaxed);
91 self.last_activity_ns.store(now_ns, Ordering::Relaxed);
92 let _ = self.state.session().set("connected_at_ms", 0u64);
93 let _ = self.state.session().set("interrupt_count", 0u64);
94 let _ = self.state.session().set("error_count", 0u64);
95 let _ = self.state.session().set("is_user_speaking", false);
96 let _ = self.state.session().set("is_model_speaking", false);
97 let _ = self.state.session().set("go_away_received", false);
98 let _ = self.state.session().set("resumable", false);
99 let _ = self.state.session().set("session_type", "audio_only");
100 }
101
102 SessionEvent::VoiceActivityStart => {
103 let _ = self.state.session().set("is_user_speaking", true);
104 self.touch_activity();
105 }
106
107 SessionEvent::VoiceActivityEnd => {
108 let _ = self.state.session().set("is_user_speaking", false);
109 self.touch_activity();
110 }
111
112 SessionEvent::Interrupted => {
113 let count: u64 = self.state.session().get("interrupt_count").unwrap_or(0);
114 let _ = self.state.session().set("interrupt_count", count + 1);
115 self.touch_activity();
116 }
117
118 SessionEvent::Error(msg) => {
119 let count: u64 = self.state.session().get("error_count").unwrap_or(0);
120 let _ = self.state.session().set("error_count", count + 1);
121 let _ = self.state.session().set("last_error", msg.clone());
122 }
123
124 SessionEvent::PhaseChanged(phase) => {
125 let _ = self
126 .state
127 .session()
128 .set("is_model_speaking", *phase == SessionPhase::ModelSpeaking);
129 let _ = self.state.session().set("phase", phase.to_string());
130 self.touch_activity();
131 }
132
133 SessionEvent::GoAway(time_left) => {
134 let _ = self.state.session().set("go_away_received", true);
135 if let Some(ref tl) = time_left {
136 let _ = self.state.session().set("go_away_time_left", tl.clone());
137 if let Ok(secs) = tl.trim_end_matches('s').parse::<u64>() {
138 let deadline = Instant::now() + std::time::Duration::from_secs(secs);
139 *self.go_away_at.lock() = Some(deadline);
140 let _ = self
141 .state
142 .session()
143 .set("go_away_time_left_ms", secs * 1000);
144 }
145 }
146 }
147
148 SessionEvent::SessionResumeUpdate(info) => {
149 *self.latest_resume_handle.lock() = Some(info.handle.clone());
150 let _ = self.state.session().set("resumable", info.resumable);
151 if let Some(ref idx) = info.last_consumed_index {
152 let _ = self
153 .state
154 .session()
155 .set("last_consumed_client_index", idx.clone());
156 }
157 }
158
159 SessionEvent::Usage(usage) => {
160 if let Some(total) = usage.total_token_count {
161 let _ = self.state.session().set("total_token_count", total);
162 }
163 if let Some(prompt) = usage.prompt_token_count {
164 let _ = self.state.session().set("prompt_token_count", prompt);
165 }
166 if let Some(response) = usage.response_token_count {
167 let _ = self.state.session().set("response_token_count", response);
168 }
169 if let Some(cached) = usage.cached_content_token_count {
170 let _ = self
171 .state
172 .session()
173 .set("cached_content_token_count", cached);
174 }
175 if let Some(thoughts) = usage.thoughts_token_count {
176 let _ = self.state.session().set("thoughts_token_count", thoughts);
177 }
178 }
179
180 SessionEvent::GenerationComplete => {
181 }
183
184 SessionEvent::InputTranscription(text) => {
185 let _ = self
186 .state
187 .session()
188 .set("last_input_transcription", text.clone());
189 self.touch_activity();
190 }
191
192 SessionEvent::OutputTranscription(text) => {
193 let _ = self
194 .state
195 .session()
196 .set("last_output_transcription", text.clone());
197 self.touch_activity();
198 }
199
200 SessionEvent::AudioData(_)
201 | SessionEvent::TextDelta(_)
202 | SessionEvent::TextComplete(_) => {
203 self.touch_activity();
206 }
207
208 SessionEvent::TurnComplete => {
209 self.touch_activity();
210 }
211
212 SessionEvent::Disconnected(_reason) => {
213 self.is_connected.store(false, Ordering::Relaxed);
214 let _ = self.state.session().set("disconnected", true);
215 }
216
217 _ => {}
218 }
219 }
220
221 pub fn flush_timing(&self) {
227 let last_activity = self.last_activity_ns.load(Ordering::Relaxed);
228 if last_activity > 0 {
229 let now_ns = self.elapsed_ns();
230 let silence_ms = now_ns.saturating_sub(last_activity) / 1_000_000;
231 let _ = self.state.session().set("silence_ms", silence_ms);
232 }
233
234 if self.is_connected.load(Ordering::Relaxed) {
235 let connected_ns = self.connected_at_ns.load(Ordering::Relaxed);
236 let now_ns = self.elapsed_ns();
237 let elapsed_ms = now_ns.saturating_sub(connected_ns) / 1_000_000;
238 let _ = self.state.session().set("elapsed_ms", elapsed_ms);
239
240 let limit_ms: u64 = match self.session_type() {
241 SessionType::AudioOnly => 15 * 60 * 1000,
242 SessionType::AudioVideo => 2 * 60 * 1000,
243 };
244 let remaining = limit_ms.saturating_sub(elapsed_ms);
245 let _ = self.state.session().set("remaining_budget_ms", remaining);
246 }
247 }
248
249 #[inline]
250 fn touch_activity(&self) {
251 self.last_activity_ns
252 .store(self.elapsed_ns(), Ordering::Relaxed);
253 }
254
255 #[inline]
256 fn elapsed_ns(&self) -> u64 {
257 self.start.elapsed().as_nanos() as u64
258 }
259
260 pub fn session_type(&self) -> SessionType {
262 if self.has_video.load(Ordering::Relaxed) {
263 SessionType::AudioVideo
264 } else {
265 SessionType::AudioOnly
266 }
267 }
268
269 pub fn latest_resume_handle(&self) -> Option<String> {
271 self.latest_resume_handle.lock().clone()
272 }
273
274 pub fn mark_video_sent(&self) {
276 if !self.has_video.swap(true, Ordering::Relaxed) {
277 let _ = self.state.session().set("session_type", "audio_video");
278 }
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use gemini_genai_rs::prelude::SessionEvent;

    /// Upper bound (ms) for durations that should be effectively zero in a fast test.
    const NEAR_ZERO_MS: u64 = 100;

    /// A tracker over a fresh, empty `State`.
    fn signals() -> SessionSignals {
        SessionSignals::new(State::new())
    }

    /// A fresh tracker that has already observed `SessionEvent::Connected`.
    fn connected() -> SessionSignals {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s
    }

    /// Reads a session-scoped `bool` key.
    fn read_bool(s: &SessionSignals, key: &'static str) -> Option<bool> {
        s.state.session().get::<bool>(key)
    }

    /// Reads a session-scoped `u64` key.
    fn read_u64(s: &SessionSignals, key: &'static str) -> Option<u64> {
        s.state.session().get::<u64>(key)
    }

    /// Reads a session-scoped `String` key.
    fn read_string(s: &SessionSignals, key: &'static str) -> Option<String> {
        s.state.session().get::<String>(key)
    }

    /// Runs `flush_timing` and returns the published `silence_ms` (u64::MAX if absent).
    fn silence_after_flush(s: &SessionSignals) -> u64 {
        s.flush_timing();
        read_u64(s, "silence_ms").unwrap_or(u64::MAX)
    }

    /// Builds a resumable `SessionResumeUpdate` event.
    fn resumable_update(
        handle: &'static str,
        last_consumed_index: Option<&'static str>,
    ) -> SessionEvent {
        SessionEvent::SessionResumeUpdate(gemini_genai_rs::session::ResumeInfo {
            handle: handle.into(),
            resumable: true,
            last_consumed_index: last_consumed_index.map(Into::into),
        })
    }

    /// Asserts the tracker reflects a `Disconnected` event.
    fn assert_disconnected(s: &SessionSignals) {
        assert!(!s.is_connected.load(Ordering::Relaxed));
        assert_eq!(read_bool(s, "disconnected"), Some(true));
    }

    #[test]
    fn connected_initializes_state() {
        let s = connected();

        for key in ["connected_at_ms", "interrupt_count", "error_count"] {
            assert_eq!(read_u64(&s, key), Some(0), "{key}");
        }
        for key in [
            "is_user_speaking",
            "is_model_speaking",
            "go_away_received",
            "resumable",
        ] {
            assert_eq!(read_bool(&s, key), Some(false), "{key}");
        }
        assert_eq!(
            read_string(&s, "session_type").as_deref(),
            Some("audio_only")
        );
        assert!(s.is_connected.load(Ordering::Relaxed));
    }

    #[test]
    fn voice_activity_toggles_user_speaking() {
        let s = connected();
        s.on_event(&SessionEvent::VoiceActivityStart);
        assert_eq!(read_bool(&s, "is_user_speaking"), Some(true));
        s.on_event(&SessionEvent::VoiceActivityEnd);
        assert_eq!(read_bool(&s, "is_user_speaking"), Some(false));
    }

    #[test]
    fn interrupted_increments_count() {
        let s = connected();
        for expected in 1..=3u64 {
            s.on_event(&SessionEvent::Interrupted);
            assert_eq!(read_u64(&s, "interrupt_count"), Some(expected));
        }
    }

    #[test]
    fn error_increments_count() {
        let s = connected();
        for (expected, msg) in (1u64..).zip(["oops", "oops2"]) {
            s.on_event(&SessionEvent::Error(msg.into()));
            assert_eq!(read_u64(&s, "error_count"), Some(expected));
            assert_eq!(read_string(&s, "last_error").as_deref(), Some(msg));
        }
    }

    #[test]
    fn phase_changed_sets_model_speaking() {
        let s = connected();
        let cases = [
            (SessionPhase::ModelSpeaking, true, "ModelSpeaking"),
            (SessionPhase::Active, false, "Active"),
        ];
        for (phase, speaking, label) in cases {
            s.on_event(&SessionEvent::PhaseChanged(phase));
            assert_eq!(read_bool(&s, "is_model_speaking"), Some(speaking));
            assert_eq!(read_string(&s, "phase").as_deref(), Some(label));
        }
    }

    #[test]
    fn go_away_sets_state() {
        let s = connected();
        s.on_event(&SessionEvent::GoAway(Some("60s".into())));
        assert_eq!(read_bool(&s, "go_away_received"), Some(true));
        assert_eq!(
            read_string(&s, "go_away_time_left").as_deref(),
            Some("60s")
        );
        assert_eq!(read_u64(&s, "go_away_time_left_ms"), Some(60_000));
        assert!(s.go_away_at.lock().is_some());
    }

    #[test]
    fn go_away_without_time_left() {
        let s = connected();
        s.on_event(&SessionEvent::GoAway(None));
        assert_eq!(read_bool(&s, "go_away_received"), Some(true));
        assert_eq!(read_string(&s, "go_away_time_left"), None);
        assert!(s.go_away_at.lock().is_none());
    }

    #[test]
    fn session_resume_handle_stored() {
        let s = connected();
        s.on_event(&resumable_update("handle-abc", None));
        assert_eq!(read_bool(&s, "resumable"), Some(true));
        assert_eq!(s.latest_resume_handle().as_deref(), Some("handle-abc"));
    }

    #[test]
    fn transcription_stores_last() {
        let s = connected();

        s.on_event(&SessionEvent::InputTranscription("hello".into()));
        assert_eq!(
            read_string(&s, "last_input_transcription").as_deref(),
            Some("hello")
        );

        s.on_event(&SessionEvent::OutputTranscription("hi there".into()));
        assert_eq!(
            read_string(&s, "last_output_transcription").as_deref(),
            Some("hi there")
        );

        s.on_event(&SessionEvent::InputTranscription("bye".into()));
        assert_eq!(
            read_string(&s, "last_input_transcription").as_deref(),
            Some("bye")
        );
    }

    #[test]
    fn session_type_defaults_to_audio_only() {
        assert_eq!(signals().session_type(), SessionType::AudioOnly);
    }

    #[test]
    fn mark_video_sent_changes_session_type() {
        let s = connected();
        assert_eq!(s.session_type(), SessionType::AudioOnly);
        s.mark_video_sent();
        assert_eq!(s.session_type(), SessionType::AudioVideo);
        assert_eq!(
            read_string(&s, "session_type").as_deref(),
            Some("audio_video")
        );
    }

    #[test]
    fn mark_video_sent_idempotent() {
        let s = connected();
        s.mark_video_sent();
        s.mark_video_sent();
        assert_eq!(s.session_type(), SessionType::AudioVideo);
    }

    #[test]
    fn flush_timing_after_connected() {
        let s = connected();
        s.flush_timing();

        let elapsed = read_u64(&s, "elapsed_ms").unwrap_or(0);
        assert!(
            elapsed < NEAR_ZERO_MS,
            "elapsed should be near zero, got {elapsed}"
        );

        let remaining =
            read_u64(&s, "remaining_budget_ms").expect("remaining_budget_ms should be set");
        let audio_budget_ms = 15 * 60 * 1000u64;
        assert!(
            remaining > audio_budget_ms - 1000,
            "remaining should be near limit, got {remaining}"
        );
    }

    #[test]
    fn flush_timing_respects_video_budget() {
        let s = connected();

        s.flush_timing();
        let remaining_audio =
            read_u64(&s, "remaining_budget_ms").expect("remaining_budget_ms should be set");
        assert!(remaining_audio > 14 * 60 * 1000);

        s.mark_video_sent();
        s.flush_timing();
        let remaining_video =
            read_u64(&s, "remaining_budget_ms").expect("remaining_budget_ms should be set");
        assert!(
            remaining_video <= 2 * 60 * 1000,
            "video remaining should be <= 120_000, got {remaining_video}"
        );
    }

    #[test]
    fn latest_resume_handle_initially_none() {
        assert_eq!(signals().latest_resume_handle(), None);
    }

    #[test]
    fn latest_resume_handle_updates() {
        let s = signals();
        s.on_event(&resumable_update("h1", None));
        assert_eq!(s.latest_resume_handle().as_deref(), Some("h1"));
        s.on_event(&resumable_update("h2", Some("5")));
        assert_eq!(s.latest_resume_handle().as_deref(), Some("h2"));
    }

    #[test]
    fn silence_ms_tracked() {
        let silence = silence_after_flush(&connected());
        assert!(
            silence < NEAR_ZERO_MS,
            "silence should be near zero, got {silence}"
        );
    }

    #[test]
    fn audio_data_updates_activity() {
        let s = connected();
        s.on_event(&SessionEvent::AudioData(Bytes::from_static(b"pcm")));
        assert!(silence_after_flush(&s) < NEAR_ZERO_MS);
    }

    #[test]
    fn turn_complete_updates_activity() {
        let s = connected();
        s.on_event(&SessionEvent::TurnComplete);
        assert!(silence_after_flush(&s) < NEAR_ZERO_MS);
    }

    #[test]
    fn text_complete_updates_activity() {
        let s = connected();
        s.on_event(&SessionEvent::TextComplete("done".into()));
        assert!(silence_after_flush(&s) < NEAR_ZERO_MS);
    }

    #[test]
    fn disconnected_clears_connected_and_sets_flag() {
        let s = connected();
        assert!(s.is_connected.load(Ordering::Relaxed));
        s.on_event(&SessionEvent::Disconnected(Some("server closed".into())));
        assert_disconnected(&s);
    }

    #[test]
    fn disconnected_without_reason() {
        let s = connected();
        s.on_event(&SessionEvent::Disconnected(None));
        assert_disconnected(&s);
    }
}