// gemini_adk_rs/live/session_signals.rs

//! Auto-tracked session-level state signals.
//!
//! The telemetry lane feeds every [`SessionEvent`] to
//! [`SessionSignals::on_event`], which transparently updates keys under the
//! `session:` prefix in the shared [`State`].
//!
//! Hot-path timestamps use [`AtomicU64`] (nanoseconds since the signals were
//! created) instead of `Mutex<Instant>`, eliminating per-event mutex
//! contention. Derived timing signals (`silence_ms`, `elapsed_ms`,
//! `remaining_budget_ms`) are written only when `flush_timing()` is called
//! periodically, rather than on every event.
11
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::time::Instant;
14
15use gemini_genai_rs::prelude::{SessionEvent, SessionPhase};
16use parking_lot::Mutex;
17
18use crate::state::State;
19
// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------
23
/// Kind of media a live session carries, which decides the duration limit
/// the server applies to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionType {
    /// No video has been sent; limited to roughly 15 minutes.
    AudioOnly,
    /// Video has been sent alongside audio; limited to roughly 2 minutes.
    AudioVideo,
}
32
// ---------------------------------------------------------------------------
// SessionSignals
// ---------------------------------------------------------------------------
36
37/// Tracks session-level signals automatically from events.
38///
39/// Every call to [`on_event`](SessionSignals::on_event) updates the
40/// corresponding keys under `session:` in the shared [`State`], making
41/// them available to instruction templates, watchers, and computed vars.
42///
43/// **Performance**: Timestamps use `AtomicU64` (nanos since session start)
44/// instead of `Mutex<Instant>`. Derived timing signals are flushed
45/// periodically via `flush_timing()` (100ms interval) rather than per-event.
46pub struct SessionSignals {
47    state: State,
48    /// Session start time — used as epoch for all atomic timestamps.
49    start: Instant,
50    /// Nanos since start when connected.
51    connected_at_ns: AtomicU64,
52    /// Whether currently connected.
53    is_connected: AtomicBool,
54    /// Nanos since start of last activity.
55    last_activity_ns: AtomicU64,
56    /// Whether the session includes video input.
57    has_video: AtomicBool,
58    /// Server-sent GoAway timestamp, if received.
59    go_away_at: Mutex<Option<Instant>>,
60    /// Latest resumption handle from server (persisted for reconnection).
61    latest_resume_handle: Mutex<Option<String>>,
62}
63
64impl SessionSignals {
65    /// Create a new `SessionSignals` backed by the given [`State`].
66    pub fn new(state: State) -> Self {
67        Self {
68            state,
69            start: Instant::now(),
70            connected_at_ns: AtomicU64::new(0),
71            is_connected: AtomicBool::new(false),
72            last_activity_ns: AtomicU64::new(0),
73            has_video: AtomicBool::new(false),
74            go_away_at: Mutex::new(None),
75            latest_resume_handle: Mutex::new(None),
76        }
77    }
78
79    /// Process an event — updates state keys and atomic timestamps.
80    ///
81    /// This is the per-event handler. It updates boolean flags, counters,
82    /// and atomic timestamps. **Derived timing** (silence_ms, elapsed_ms,
83    /// remaining_budget_ms) is NOT computed here — call `flush_timing()`
84    /// periodically instead.
85    pub fn on_event(&self, event: &SessionEvent) {
86        match event {
87            SessionEvent::Connected => {
88                let now_ns = self.elapsed_ns();
89                self.connected_at_ns.store(now_ns, Ordering::Relaxed);
90                self.is_connected.store(true, Ordering::Relaxed);
91                self.last_activity_ns.store(now_ns, Ordering::Relaxed);
92                self.state.session().set("connected_at_ms", 0u64);
93                self.state.session().set("interrupt_count", 0u64);
94                self.state.session().set("error_count", 0u64);
95                self.state.session().set("is_user_speaking", false);
96                self.state.session().set("is_model_speaking", false);
97                self.state.session().set("go_away_received", false);
98                self.state.session().set("resumable", false);
99                self.state.session().set("session_type", "audio_only");
100            }
101
102            SessionEvent::VoiceActivityStart => {
103                self.state.session().set("is_user_speaking", true);
104                self.touch_activity();
105            }
106
107            SessionEvent::VoiceActivityEnd => {
108                self.state.session().set("is_user_speaking", false);
109                self.touch_activity();
110            }
111
112            SessionEvent::Interrupted => {
113                let count: u64 = self.state.session().get("interrupt_count").unwrap_or(0);
114                self.state.session().set("interrupt_count", count + 1);
115                self.touch_activity();
116            }
117
118            SessionEvent::Error(msg) => {
119                let count: u64 = self.state.session().get("error_count").unwrap_or(0);
120                self.state.session().set("error_count", count + 1);
121                self.state.session().set("last_error", msg.clone());
122            }
123
124            SessionEvent::PhaseChanged(phase) => {
125                self.state
126                    .session()
127                    .set("is_model_speaking", *phase == SessionPhase::ModelSpeaking);
128                self.state.session().set("phase", phase.to_string());
129                self.touch_activity();
130            }
131
132            SessionEvent::GoAway(time_left) => {
133                self.state.session().set("go_away_received", true);
134                if let Some(ref tl) = time_left {
135                    self.state.session().set("go_away_time_left", tl.clone());
136                    if let Ok(secs) = tl.trim_end_matches('s').parse::<u64>() {
137                        let deadline = Instant::now() + std::time::Duration::from_secs(secs);
138                        *self.go_away_at.lock() = Some(deadline);
139                        self.state
140                            .session()
141                            .set("go_away_time_left_ms", secs * 1000);
142                    }
143                }
144            }
145
146            SessionEvent::SessionResumeUpdate(info) => {
147                *self.latest_resume_handle.lock() = Some(info.handle.clone());
148                self.state.session().set("resumable", info.resumable);
149                if let Some(ref idx) = info.last_consumed_index {
150                    self.state
151                        .session()
152                        .set("last_consumed_client_index", idx.clone());
153                }
154            }
155
156            SessionEvent::Usage(usage) => {
157                if let Some(total) = usage.total_token_count {
158                    self.state.session().set("total_token_count", total);
159                }
160                if let Some(prompt) = usage.prompt_token_count {
161                    self.state.session().set("prompt_token_count", prompt);
162                }
163                if let Some(response) = usage.response_token_count {
164                    self.state.session().set("response_token_count", response);
165                }
166                if let Some(cached) = usage.cached_content_token_count {
167                    self.state
168                        .session()
169                        .set("cached_content_token_count", cached);
170                }
171                if let Some(thoughts) = usage.thoughts_token_count {
172                    self.state.session().set("thoughts_token_count", thoughts);
173                }
174            }
175
176            SessionEvent::GenerationComplete => {
177                // No-op for signals — generation complete is handled by control lane
178            }
179
180            SessionEvent::InputTranscription(text) => {
181                self.state
182                    .session()
183                    .set("last_input_transcription", text.clone());
184                self.touch_activity();
185            }
186
187            SessionEvent::OutputTranscription(text) => {
188                self.state
189                    .session()
190                    .set("last_output_transcription", text.clone());
191                self.touch_activity();
192            }
193
194            SessionEvent::AudioData(_)
195            | SessionEvent::TextDelta(_)
196            | SessionEvent::TextComplete(_) => {
197                // High-frequency events: only touch the atomic timestamp.
198                // No DashMap writes, no mutex locks.
199                self.touch_activity();
200            }
201
202            SessionEvent::TurnComplete => {
203                self.touch_activity();
204            }
205
206            SessionEvent::Disconnected(_reason) => {
207                self.is_connected.store(false, Ordering::Relaxed);
208                self.state.session().set("disconnected", true);
209            }
210
211            _ => {}
212        }
213    }
214
215    /// Flush derived timing signals to state.
216    ///
217    /// Call this periodically (e.g., every 100ms) from the telemetry lane.
218    /// Computes `silence_ms`, `elapsed_ms`, and `remaining_budget_ms` from
219    /// atomic timestamps without any mutex locks.
220    pub fn flush_timing(&self) {
221        let last_activity = self.last_activity_ns.load(Ordering::Relaxed);
222        if last_activity > 0 {
223            let now_ns = self.elapsed_ns();
224            let silence_ms = now_ns.saturating_sub(last_activity) / 1_000_000;
225            self.state.session().set("silence_ms", silence_ms);
226        }
227
228        if self.is_connected.load(Ordering::Relaxed) {
229            let connected_ns = self.connected_at_ns.load(Ordering::Relaxed);
230            let now_ns = self.elapsed_ns();
231            let elapsed_ms = now_ns.saturating_sub(connected_ns) / 1_000_000;
232            self.state.session().set("elapsed_ms", elapsed_ms);
233
234            let limit_ms: u64 = match self.session_type() {
235                SessionType::AudioOnly => 15 * 60 * 1000,
236                SessionType::AudioVideo => 2 * 60 * 1000,
237            };
238            let remaining = limit_ms.saturating_sub(elapsed_ms);
239            self.state.session().set("remaining_budget_ms", remaining);
240        }
241    }
242
243    #[inline]
244    fn touch_activity(&self) {
245        self.last_activity_ns
246            .store(self.elapsed_ns(), Ordering::Relaxed);
247    }
248
249    #[inline]
250    fn elapsed_ns(&self) -> u64 {
251        self.start.elapsed().as_nanos() as u64
252    }
253
254    /// Returns the current session type based on whether video has been sent.
255    pub fn session_type(&self) -> SessionType {
256        if self.has_video.load(Ordering::Relaxed) {
257            SessionType::AudioVideo
258        } else {
259            SessionType::AudioOnly
260        }
261    }
262
263    /// Returns the latest resumption handle for reconnection.
264    pub fn latest_resume_handle(&self) -> Option<String> {
265        self.latest_resume_handle.lock().clone()
266    }
267
268    /// Mark that video has been sent (changes session type to `AudioVideo`).
269    pub fn mark_video_sent(&self) {
270        if !self.has_video.swap(true, Ordering::Relaxed) {
271            self.state.session().set("session_type", "audio_video");
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use gemini_genai_rs::prelude::SessionEvent;

    /// Fresh signals over an empty [`State`].
    fn signals() -> SessionSignals {
        SessionSignals::new(State::new())
    }

    #[test]
    fn connected_initializes_state() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);

        assert_eq!(s.state.session().get::<u64>("connected_at_ms"), Some(0));
        assert_eq!(s.state.session().get::<u64>("interrupt_count"), Some(0));
        assert_eq!(s.state.session().get::<u64>("error_count"), Some(0));
        assert_eq!(
            s.state.session().get::<bool>("is_user_speaking"),
            Some(false)
        );
        assert_eq!(
            s.state.session().get::<bool>("is_model_speaking"),
            Some(false)
        );
        assert_eq!(
            s.state.session().get::<bool>("go_away_received"),
            Some(false)
        );
        assert_eq!(s.state.session().get::<bool>("resumable"), Some(false));
        assert_eq!(
            s.state.session().get::<String>("session_type"),
            Some("audio_only".to_string())
        );
        assert!(s.is_connected.load(Ordering::Relaxed));
    }

    #[test]
    fn reconnect_resets_counters() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::Interrupted);
        s.on_event(&SessionEvent::Error("boom".into()));
        s.on_event(&SessionEvent::Disconnected(None));
        s.on_event(&SessionEvent::Connected);
        assert!(s.is_connected.load(Ordering::Relaxed));
        assert_eq!(s.state.session().get::<u64>("interrupt_count"), Some(0));
        assert_eq!(s.state.session().get::<u64>("error_count"), Some(0));
    }

    #[test]
    fn voice_activity_toggles_user_speaking() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::VoiceActivityStart);
        assert_eq!(
            s.state.session().get::<bool>("is_user_speaking"),
            Some(true)
        );
        s.on_event(&SessionEvent::VoiceActivityEnd);
        assert_eq!(
            s.state.session().get::<bool>("is_user_speaking"),
            Some(false)
        );
    }

    #[test]
    fn interrupted_increments_count() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        for expected in 1..=3u64 {
            s.on_event(&SessionEvent::Interrupted);
            assert_eq!(
                s.state.session().get::<u64>("interrupt_count"),
                Some(expected)
            );
        }
    }

    #[test]
    fn error_increments_count() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::Error("oops".into()));
        assert_eq!(s.state.session().get::<u64>("error_count"), Some(1));
        assert_eq!(
            s.state.session().get::<String>("last_error"),
            Some("oops".into())
        );
        s.on_event(&SessionEvent::Error("oops2".into()));
        assert_eq!(s.state.session().get::<u64>("error_count"), Some(2));
        assert_eq!(
            s.state.session().get::<String>("last_error"),
            Some("oops2".into())
        );
    }

    #[test]
    fn phase_changed_sets_model_speaking() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::PhaseChanged(SessionPhase::ModelSpeaking));
        assert_eq!(
            s.state.session().get::<bool>("is_model_speaking"),
            Some(true)
        );
        assert_eq!(
            s.state.session().get::<String>("phase"),
            Some("ModelSpeaking".into())
        );
        s.on_event(&SessionEvent::PhaseChanged(SessionPhase::Active));
        assert_eq!(
            s.state.session().get::<bool>("is_model_speaking"),
            Some(false)
        );
        assert_eq!(
            s.state.session().get::<String>("phase"),
            Some("Active".into())
        );
    }

    #[test]
    fn go_away_sets_state() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::GoAway(Some("60s".into())));
        assert_eq!(
            s.state.session().get::<bool>("go_away_received"),
            Some(true)
        );
        assert_eq!(
            s.state.session().get::<String>("go_away_time_left"),
            Some("60s".into())
        );
        assert_eq!(
            s.state.session().get::<u64>("go_away_time_left_ms"),
            Some(60_000)
        );
        assert!(s.go_away_at.lock().is_some());
    }

    #[test]
    fn go_away_without_time_left() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::GoAway(None));
        assert_eq!(
            s.state.session().get::<bool>("go_away_received"),
            Some(true)
        );
        assert_eq!(s.state.session().get::<String>("go_away_time_left"), None);
        assert_eq!(s.state.session().get::<u64>("go_away_time_left_ms"), None);
        assert!(s.go_away_at.lock().is_none());
    }

    #[test]
    fn session_resume_handle_stored() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::SessionResumeUpdate(
            gemini_genai_rs::session::ResumeInfo {
                handle: "handle-abc".into(),
                resumable: true,
                last_consumed_index: None,
            },
        ));
        assert_eq!(s.state.session().get::<bool>("resumable"), Some(true));
        assert_eq!(s.latest_resume_handle(), Some("handle-abc".to_string()));
    }

    #[test]
    fn transcription_stores_last() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::InputTranscription("hello".into()));
        assert_eq!(
            s.state.session().get::<String>("last_input_transcription"),
            Some("hello".into())
        );
        s.on_event(&SessionEvent::OutputTranscription("hi there".into()));
        assert_eq!(
            s.state.session().get::<String>("last_output_transcription"),
            Some("hi there".into())
        );
        s.on_event(&SessionEvent::InputTranscription("bye".into()));
        assert_eq!(
            s.state.session().get::<String>("last_input_transcription"),
            Some("bye".into())
        );
    }

    #[test]
    fn session_type_defaults_to_audio_only() {
        let s = signals();
        assert_eq!(s.session_type(), SessionType::AudioOnly);
    }

    #[test]
    fn mark_video_sent_changes_session_type() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        assert_eq!(s.session_type(), SessionType::AudioOnly);
        s.mark_video_sent();
        assert_eq!(s.session_type(), SessionType::AudioVideo);
        assert_eq!(
            s.state.session().get::<String>("session_type"),
            Some("audio_video".into())
        );
    }

    #[test]
    fn mark_video_sent_idempotent() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.mark_video_sent();
        s.mark_video_sent();
        assert_eq!(s.session_type(), SessionType::AudioVideo);
        assert_eq!(
            s.state.session().get::<String>("session_type"),
            Some("audio_video".into())
        );
    }

    #[test]
    fn flush_timing_before_connect_writes_nothing() {
        let s = signals();
        s.flush_timing();
        assert_eq!(s.state.session().get::<u64>("silence_ms"), None);
        assert_eq!(s.state.session().get::<u64>("elapsed_ms"), None);
        assert_eq!(s.state.session().get::<u64>("remaining_budget_ms"), None);
    }

    #[test]
    fn flush_timing_after_connected() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.flush_timing();
        // `expect`, not `unwrap_or(0)`: a missing key must fail the test.
        let elapsed: u64 = s
            .state
            .session()
            .get("elapsed_ms")
            .expect("elapsed_ms should be written while connected");
        assert!(elapsed < 100, "elapsed should be near zero, got {elapsed}");
        let remaining: u64 = s.state.session().get("remaining_budget_ms").unwrap();
        let limit = 15 * 60 * 1000u64;
        assert!(
            remaining > limit - 1000,
            "remaining should be near limit, got {remaining}"
        );
    }

    #[test]
    fn flush_timing_respects_video_budget() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.flush_timing();
        let remaining_audio: u64 = s.state.session().get("remaining_budget_ms").unwrap();
        assert!(remaining_audio > 14 * 60 * 1000);
        s.mark_video_sent();
        s.flush_timing();
        let remaining_video: u64 = s.state.session().get("remaining_budget_ms").unwrap();
        assert!(
            remaining_video <= 2 * 60 * 1000,
            "video remaining should be <= 120_000, got {remaining_video}"
        );
    }

    #[test]
    fn latest_resume_handle_initially_none() {
        let s = signals();
        assert_eq!(s.latest_resume_handle(), None);
    }

    #[test]
    fn latest_resume_handle_updates() {
        let s = signals();
        s.on_event(&SessionEvent::SessionResumeUpdate(
            gemini_genai_rs::session::ResumeInfo {
                handle: "h1".into(),
                resumable: true,
                last_consumed_index: None,
            },
        ));
        assert_eq!(s.latest_resume_handle(), Some("h1".to_string()));
        assert_eq!(
            s.state.session().get::<String>("last_consumed_client_index"),
            None
        );
        s.on_event(&SessionEvent::SessionResumeUpdate(
            gemini_genai_rs::session::ResumeInfo {
                handle: "h2".into(),
                resumable: true,
                last_consumed_index: Some("5".into()),
            },
        ));
        assert_eq!(s.latest_resume_handle(), Some("h2".to_string()));
        assert_eq!(
            s.state.session().get::<String>("last_consumed_client_index"),
            Some("5".into())
        );
    }

    #[test]
    fn silence_ms_tracked() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.flush_timing();
        let silence: u64 = s.state.session().get("silence_ms").unwrap_or(u64::MAX);
        assert!(silence < 100, "silence should be near zero, got {silence}");
    }

    #[test]
    fn audio_data_updates_activity() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::AudioData(Bytes::from_static(b"pcm")));
        s.flush_timing();
        let silence: u64 = s.state.session().get("silence_ms").unwrap_or(u64::MAX);
        assert!(silence < 100);
    }

    #[test]
    fn turn_complete_updates_activity() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::TurnComplete);
        s.flush_timing();
        let silence: u64 = s.state.session().get("silence_ms").unwrap_or(u64::MAX);
        assert!(silence < 100);
    }

    #[test]
    fn text_complete_updates_activity() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::TextComplete("done".into()));
        s.flush_timing();
        let silence: u64 = s.state.session().get("silence_ms").unwrap_or(u64::MAX);
        assert!(silence < 100);
    }

    #[test]
    fn disconnected_clears_connected_and_sets_flag() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        assert!(s.is_connected.load(Ordering::Relaxed));
        s.on_event(&SessionEvent::Disconnected(Some("server closed".into())));
        assert!(!s.is_connected.load(Ordering::Relaxed));
        assert_eq!(s.state.session().get::<bool>("disconnected"), Some(true));
    }

    #[test]
    fn disconnected_without_reason() {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s.on_event(&SessionEvent::Disconnected(None));
        assert!(!s.is_connected.load(Ordering::Relaxed));
        assert_eq!(s.state.session().get::<bool>("disconnected"), Some(true));
    }
}
594        assert_eq!(s.state.session().get::<bool>("disconnected"), Some(true));
595    }
596}