1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::time::Instant;
14
15use gemini_genai_rs::prelude::{SessionEvent, SessionPhase};
16use parking_lot::Mutex;
17
18use crate::state::State;
19
/// Kind of media a live session has carried so far.
///
/// `SessionSignals::session_type` derives this value from whether video has
/// been marked as sent, and `SessionSignals::flush_timing` uses it to pick the
/// connection time budget.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionType {
    /// No video has been marked as sent on this session.
    AudioOnly,
    /// Video has been marked as sent at least once.
    AudioVideo,
}
32
/// Derives session-level signals from `SessionEvent`s and mirrors them into
/// the session-scoped portion of a [`State`].
///
/// Timestamps are kept as nanoseconds elapsed since `start` in atomics; the
/// GoAway deadline and the resume handle sit behind mutexes.
pub struct SessionSignals {
    /// State store; every signal is written through `state.session()`.
    state: State,
    /// Instant captured in `new`; reference point for all `*_ns` fields.
    start: Instant,
    /// Nanoseconds since `start` at which the latest `Connected` event was handled.
    connected_at_ns: AtomicU64,
    /// Set on `Connected`, cleared on `Disconnected`.
    is_connected: AtomicBool,
    /// Nanoseconds since `start` of the latest activity; `0` means none recorded yet.
    last_activity_ns: AtomicU64,
    /// Set once `mark_video_sent` has been called.
    has_video: AtomicBool,
    /// Deadline computed from a `GoAway` event that carried a parseable time-left value.
    go_away_at: Mutex<Option<Instant>>,
    /// Handle taken from the most recent `SessionResumeUpdate` event.
    latest_resume_handle: Mutex<Option<String>>,
}
63
64impl SessionSignals {
65 pub fn new(state: State) -> Self {
67 Self {
68 state,
69 start: Instant::now(),
70 connected_at_ns: AtomicU64::new(0),
71 is_connected: AtomicBool::new(false),
72 last_activity_ns: AtomicU64::new(0),
73 has_video: AtomicBool::new(false),
74 go_away_at: Mutex::new(None),
75 latest_resume_handle: Mutex::new(None),
76 }
77 }
78
79 pub fn on_event(&self, event: &SessionEvent) {
86 match event {
87 SessionEvent::Connected => {
88 let now_ns = self.elapsed_ns();
89 self.connected_at_ns.store(now_ns, Ordering::Relaxed);
90 self.is_connected.store(true, Ordering::Relaxed);
91 self.last_activity_ns.store(now_ns, Ordering::Relaxed);
92 self.state.session().set("connected_at_ms", 0u64);
93 self.state.session().set("interrupt_count", 0u64);
94 self.state.session().set("error_count", 0u64);
95 self.state.session().set("is_user_speaking", false);
96 self.state.session().set("is_model_speaking", false);
97 self.state.session().set("go_away_received", false);
98 self.state.session().set("resumable", false);
99 self.state.session().set("session_type", "audio_only");
100 }
101
102 SessionEvent::VoiceActivityStart => {
103 self.state.session().set("is_user_speaking", true);
104 self.touch_activity();
105 }
106
107 SessionEvent::VoiceActivityEnd => {
108 self.state.session().set("is_user_speaking", false);
109 self.touch_activity();
110 }
111
112 SessionEvent::Interrupted => {
113 let count: u64 = self.state.session().get("interrupt_count").unwrap_or(0);
114 self.state.session().set("interrupt_count", count + 1);
115 self.touch_activity();
116 }
117
118 SessionEvent::Error(msg) => {
119 let count: u64 = self.state.session().get("error_count").unwrap_or(0);
120 self.state.session().set("error_count", count + 1);
121 self.state.session().set("last_error", msg.clone());
122 }
123
124 SessionEvent::PhaseChanged(phase) => {
125 self.state
126 .session()
127 .set("is_model_speaking", *phase == SessionPhase::ModelSpeaking);
128 self.state.session().set("phase", phase.to_string());
129 self.touch_activity();
130 }
131
132 SessionEvent::GoAway(time_left) => {
133 self.state.session().set("go_away_received", true);
134 if let Some(ref tl) = time_left {
135 self.state.session().set("go_away_time_left", tl.clone());
136 if let Ok(secs) = tl.trim_end_matches('s').parse::<u64>() {
137 let deadline = Instant::now() + std::time::Duration::from_secs(secs);
138 *self.go_away_at.lock() = Some(deadline);
139 self.state
140 .session()
141 .set("go_away_time_left_ms", secs * 1000);
142 }
143 }
144 }
145
146 SessionEvent::SessionResumeUpdate(info) => {
147 *self.latest_resume_handle.lock() = Some(info.handle.clone());
148 self.state.session().set("resumable", info.resumable);
149 if let Some(ref idx) = info.last_consumed_index {
150 self.state
151 .session()
152 .set("last_consumed_client_index", idx.clone());
153 }
154 }
155
156 SessionEvent::Usage(usage) => {
157 if let Some(total) = usage.total_token_count {
158 self.state.session().set("total_token_count", total);
159 }
160 if let Some(prompt) = usage.prompt_token_count {
161 self.state.session().set("prompt_token_count", prompt);
162 }
163 if let Some(response) = usage.response_token_count {
164 self.state.session().set("response_token_count", response);
165 }
166 if let Some(cached) = usage.cached_content_token_count {
167 self.state
168 .session()
169 .set("cached_content_token_count", cached);
170 }
171 if let Some(thoughts) = usage.thoughts_token_count {
172 self.state.session().set("thoughts_token_count", thoughts);
173 }
174 }
175
176 SessionEvent::GenerationComplete => {
177 }
179
180 SessionEvent::InputTranscription(text) => {
181 self.state
182 .session()
183 .set("last_input_transcription", text.clone());
184 self.touch_activity();
185 }
186
187 SessionEvent::OutputTranscription(text) => {
188 self.state
189 .session()
190 .set("last_output_transcription", text.clone());
191 self.touch_activity();
192 }
193
194 SessionEvent::AudioData(_)
195 | SessionEvent::TextDelta(_)
196 | SessionEvent::TextComplete(_) => {
197 self.touch_activity();
200 }
201
202 SessionEvent::TurnComplete => {
203 self.touch_activity();
204 }
205
206 SessionEvent::Disconnected(_reason) => {
207 self.is_connected.store(false, Ordering::Relaxed);
208 self.state.session().set("disconnected", true);
209 }
210
211 _ => {}
212 }
213 }
214
215 pub fn flush_timing(&self) {
221 let last_activity = self.last_activity_ns.load(Ordering::Relaxed);
222 if last_activity > 0 {
223 let now_ns = self.elapsed_ns();
224 let silence_ms = now_ns.saturating_sub(last_activity) / 1_000_000;
225 self.state.session().set("silence_ms", silence_ms);
226 }
227
228 if self.is_connected.load(Ordering::Relaxed) {
229 let connected_ns = self.connected_at_ns.load(Ordering::Relaxed);
230 let now_ns = self.elapsed_ns();
231 let elapsed_ms = now_ns.saturating_sub(connected_ns) / 1_000_000;
232 self.state.session().set("elapsed_ms", elapsed_ms);
233
234 let limit_ms: u64 = match self.session_type() {
235 SessionType::AudioOnly => 15 * 60 * 1000,
236 SessionType::AudioVideo => 2 * 60 * 1000,
237 };
238 let remaining = limit_ms.saturating_sub(elapsed_ms);
239 self.state.session().set("remaining_budget_ms", remaining);
240 }
241 }
242
243 #[inline]
244 fn touch_activity(&self) {
245 self.last_activity_ns
246 .store(self.elapsed_ns(), Ordering::Relaxed);
247 }
248
249 #[inline]
250 fn elapsed_ns(&self) -> u64 {
251 self.start.elapsed().as_nanos() as u64
252 }
253
254 pub fn session_type(&self) -> SessionType {
256 if self.has_video.load(Ordering::Relaxed) {
257 SessionType::AudioVideo
258 } else {
259 SessionType::AudioOnly
260 }
261 }
262
263 pub fn latest_resume_handle(&self) -> Option<String> {
265 self.latest_resume_handle.lock().clone()
266 }
267
268 pub fn mark_video_sent(&self) {
270 if !self.has_video.swap(true, Ordering::Relaxed) {
271 self.state.session().set("session_type", "audio_video");
272 }
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use gemini_genai_rs::prelude::SessionEvent;

    /// Fresh tracker over an empty state; no events delivered yet.
    fn signals() -> SessionSignals {
        SessionSignals::new(State::new())
    }

    /// Tracker that has already handled a `Connected` event.
    fn connected() -> SessionSignals {
        let tracker = signals();
        tracker.on_event(&SessionEvent::Connected);
        tracker
    }

    /// Reads a boolean session key.
    fn flag(tracker: &SessionSignals, key: &'static str) -> Option<bool> {
        tracker.state.session().get::<bool>(key)
    }

    /// Reads a `u64` session key.
    fn counter(tracker: &SessionSignals, key: &'static str) -> Option<u64> {
        tracker.state.session().get::<u64>(key)
    }

    /// Reads a string session key.
    fn text(tracker: &SessionSignals, key: &'static str) -> Option<String> {
        tracker.state.session().get::<String>(key)
    }

    /// Flushes timing and returns `silence_ms`, or `u64::MAX` when it is unset.
    fn flushed_silence_ms(tracker: &SessionSignals) -> u64 {
        tracker.flush_timing();
        tracker
            .state
            .session()
            .get("silence_ms")
            .unwrap_or(u64::MAX)
    }

    /// Builds a resumable `SessionResumeUpdate` event.
    fn resume_update(handle: &str, last_consumed_index: Option<&str>) -> SessionEvent {
        SessionEvent::SessionResumeUpdate(gemini_genai_rs::session::ResumeInfo {
            handle: handle.into(),
            resumable: true,
            last_consumed_index: last_consumed_index.map(Into::into),
        })
    }

    #[test]
    fn connected_initializes_state() {
        let tracker = connected();

        assert_eq!(counter(&tracker, "connected_at_ms"), Some(0));
        assert_eq!(counter(&tracker, "interrupt_count"), Some(0));
        assert_eq!(counter(&tracker, "error_count"), Some(0));
        assert_eq!(flag(&tracker, "is_user_speaking"), Some(false));
        assert_eq!(flag(&tracker, "is_model_speaking"), Some(false));
        assert_eq!(flag(&tracker, "go_away_received"), Some(false));
        assert_eq!(flag(&tracker, "resumable"), Some(false));
        assert_eq!(
            text(&tracker, "session_type"),
            Some("audio_only".to_string())
        );
        assert!(tracker.is_connected.load(Ordering::Relaxed));
    }

    #[test]
    fn voice_activity_toggles_user_speaking() {
        let tracker = connected();

        tracker.on_event(&SessionEvent::VoiceActivityStart);
        assert_eq!(flag(&tracker, "is_user_speaking"), Some(true));

        tracker.on_event(&SessionEvent::VoiceActivityEnd);
        assert_eq!(flag(&tracker, "is_user_speaking"), Some(false));
    }

    #[test]
    fn interrupted_increments_count() {
        let tracker = connected();

        for expected in 1..=3u64 {
            tracker.on_event(&SessionEvent::Interrupted);
            assert_eq!(counter(&tracker, "interrupt_count"), Some(expected));
        }
    }

    #[test]
    fn error_increments_count() {
        let tracker = connected();

        tracker.on_event(&SessionEvent::Error("oops".into()));
        assert_eq!(counter(&tracker, "error_count"), Some(1));
        assert_eq!(text(&tracker, "last_error"), Some("oops".into()));

        tracker.on_event(&SessionEvent::Error("oops2".into()));
        assert_eq!(counter(&tracker, "error_count"), Some(2));
        assert_eq!(text(&tracker, "last_error"), Some("oops2".into()));
    }

    #[test]
    fn phase_changed_sets_model_speaking() {
        let tracker = connected();

        tracker.on_event(&SessionEvent::PhaseChanged(SessionPhase::ModelSpeaking));
        assert_eq!(flag(&tracker, "is_model_speaking"), Some(true));
        assert_eq!(text(&tracker, "phase"), Some("ModelSpeaking".into()));

        tracker.on_event(&SessionEvent::PhaseChanged(SessionPhase::Active));
        assert_eq!(flag(&tracker, "is_model_speaking"), Some(false));
        assert_eq!(text(&tracker, "phase"), Some("Active".into()));
    }

    #[test]
    fn go_away_sets_state() {
        let tracker = connected();
        tracker.on_event(&SessionEvent::GoAway(Some("60s".into())));

        assert_eq!(flag(&tracker, "go_away_received"), Some(true));
        assert_eq!(text(&tracker, "go_away_time_left"), Some("60s".into()));
        assert_eq!(counter(&tracker, "go_away_time_left_ms"), Some(60_000));
        assert!(tracker.go_away_at.lock().is_some());
    }

    #[test]
    fn go_away_without_time_left() {
        let tracker = connected();
        tracker.on_event(&SessionEvent::GoAway(None));

        assert_eq!(flag(&tracker, "go_away_received"), Some(true));
        assert_eq!(text(&tracker, "go_away_time_left"), None);
        assert!(tracker.go_away_at.lock().is_none());
    }

    #[test]
    fn session_resume_handle_stored() {
        let tracker = connected();
        tracker.on_event(&resume_update("handle-abc", None));

        assert_eq!(flag(&tracker, "resumable"), Some(true));
        assert_eq!(
            tracker.latest_resume_handle(),
            Some("handle-abc".to_string())
        );
    }

    #[test]
    fn transcription_stores_last() {
        let tracker = connected();

        tracker.on_event(&SessionEvent::InputTranscription("hello".into()));
        assert_eq!(
            text(&tracker, "last_input_transcription"),
            Some("hello".into())
        );

        tracker.on_event(&SessionEvent::OutputTranscription("hi there".into()));
        assert_eq!(
            text(&tracker, "last_output_transcription"),
            Some("hi there".into())
        );

        tracker.on_event(&SessionEvent::InputTranscription("bye".into()));
        assert_eq!(
            text(&tracker, "last_input_transcription"),
            Some("bye".into())
        );
    }

    #[test]
    fn session_type_defaults_to_audio_only() {
        assert_eq!(signals().session_type(), SessionType::AudioOnly);
    }

    #[test]
    fn mark_video_sent_changes_session_type() {
        let tracker = connected();
        assert_eq!(tracker.session_type(), SessionType::AudioOnly);

        tracker.mark_video_sent();
        assert_eq!(tracker.session_type(), SessionType::AudioVideo);
        assert_eq!(text(&tracker, "session_type"), Some("audio_video".into()));
    }

    #[test]
    fn mark_video_sent_idempotent() {
        let tracker = connected();
        tracker.mark_video_sent();
        tracker.mark_video_sent();
        assert_eq!(tracker.session_type(), SessionType::AudioVideo);
    }

    #[test]
    fn flush_timing_after_connected() {
        let tracker = connected();
        tracker.flush_timing();

        let elapsed = counter(&tracker, "elapsed_ms").unwrap_or(0);
        assert!(elapsed < 100, "elapsed should be near zero, got {elapsed}");

        let limit = 15 * 60 * 1000u64;
        let remaining = counter(&tracker, "remaining_budget_ms").unwrap();
        assert!(
            remaining > limit - 1000,
            "remaining should be near limit, got {remaining}"
        );
    }

    #[test]
    fn flush_timing_respects_video_budget() {
        let tracker = connected();

        tracker.flush_timing();
        let remaining_audio = counter(&tracker, "remaining_budget_ms").unwrap();
        assert!(remaining_audio > 14 * 60 * 1000);

        tracker.mark_video_sent();
        tracker.flush_timing();
        let remaining_video = counter(&tracker, "remaining_budget_ms").unwrap();
        assert!(
            remaining_video <= 2 * 60 * 1000,
            "video remaining should be <= 120_000, got {remaining_video}"
        );
    }

    #[test]
    fn latest_resume_handle_initially_none() {
        assert_eq!(signals().latest_resume_handle(), None);
    }

    #[test]
    fn latest_resume_handle_updates() {
        let tracker = signals();

        tracker.on_event(&resume_update("h1", None));
        assert_eq!(tracker.latest_resume_handle(), Some("h1".to_string()));

        tracker.on_event(&resume_update("h2", Some("5")));
        assert_eq!(tracker.latest_resume_handle(), Some("h2".to_string()));
    }

    #[test]
    fn silence_ms_tracked() {
        let tracker = connected();
        let silence = flushed_silence_ms(&tracker);
        assert!(silence < 100, "silence should be near zero, got {silence}");
    }

    #[test]
    fn audio_data_updates_activity() {
        let tracker = connected();
        tracker.on_event(&SessionEvent::AudioData(Bytes::from_static(b"pcm")));
        assert!(flushed_silence_ms(&tracker) < 100);
    }

    #[test]
    fn turn_complete_updates_activity() {
        let tracker = connected();
        tracker.on_event(&SessionEvent::TurnComplete);
        assert!(flushed_silence_ms(&tracker) < 100);
    }

    #[test]
    fn text_complete_updates_activity() {
        let tracker = connected();
        tracker.on_event(&SessionEvent::TextComplete("done".into()));
        assert!(flushed_silence_ms(&tracker) < 100);
    }

    #[test]
    fn disconnected_clears_connected_and_sets_flag() {
        let tracker = connected();
        assert!(tracker.is_connected.load(Ordering::Relaxed));

        tracker.on_event(&SessionEvent::Disconnected(Some("server closed".into())));
        assert!(!tracker.is_connected.load(Ordering::Relaxed));
        assert_eq!(flag(&tracker, "disconnected"), Some(true));
    }

    #[test]
    fn disconnected_without_reason() {
        let tracker = connected();

        tracker.on_event(&SessionEvent::Disconnected(None));
        assert!(!tracker.is_connected.load(Ordering::Relaxed));
        assert_eq!(flag(&tracker, "disconnected"), Some(true));
    }
}