// gemini_adk_rs/live/session_signals.rs

//! Auto-tracked session-level state signals.
//!
//! [`SessionSignals::on_event`] is called by the telemetry lane for every
//! [`SessionEvent`] and transparently updates keys under the `session:`
//! prefix in the shared [`State`].
//!
//! Hot-path timestamps use [`AtomicU64`] (nanoseconds since the signals were
//! created) instead of `Mutex<Instant>`, eliminating per-event mutex
//! contention. Derived timing signals (`silence_ms`, `elapsed_ms`,
//! `remaining_budget_ms`) are flushed periodically via `flush_timing()`
//! rather than on every event.

12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::time::Instant;
14
15use gemini_genai_rs::prelude::{SessionEvent, SessionPhase};
16use parking_lot::Mutex;
17
18use crate::state::State;

// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------

/// Kind of live session, which selects the server-side duration budget.
///
/// The active variant is decided by whether video input has been sent.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionType {
    /// Audio only — server budget of roughly 15 minutes.
    AudioOnly,
    /// Audio plus video — server budget of roughly 2 minutes.
    AudioVideo,
}

// ---------------------------------------------------------------------------
// SessionSignals
// ---------------------------------------------------------------------------

37/// Tracks session-level signals automatically from events.
38///
39/// Every call to [`on_event`](SessionSignals::on_event) updates the
40/// corresponding keys under `session:` in the shared [`State`], making
41/// them available to instruction templates, watchers, and computed vars.
42///
43/// **Performance**: Timestamps use `AtomicU64` (nanos since session start)
44/// instead of `Mutex<Instant>`. Derived timing signals are flushed
45/// periodically via `flush_timing()` (100ms interval) rather than per-event.
46pub struct SessionSignals {
47    state: State,
48    /// Session start time — used as epoch for all atomic timestamps.
49    start: Instant,
50    /// Nanos since start when connected.
51    connected_at_ns: AtomicU64,
52    /// Whether currently connected.
53    is_connected: AtomicBool,
54    /// Nanos since start of last activity.
55    last_activity_ns: AtomicU64,
56    /// Whether the session includes video input.
57    has_video: AtomicBool,
58    /// Server-sent GoAway timestamp, if received.
59    go_away_at: Mutex<Option<Instant>>,
60    /// Latest resumption handle from server (persisted for reconnection).
61    latest_resume_handle: Mutex<Option<String>>,
62}
64impl SessionSignals {
65    /// Create a new `SessionSignals` backed by the given [`State`].
66    pub fn new(state: State) -> Self {
67        Self {
68            state,
69            start: Instant::now(),
70            connected_at_ns: AtomicU64::new(0),
71            is_connected: AtomicBool::new(false),
72            last_activity_ns: AtomicU64::new(0),
73            has_video: AtomicBool::new(false),
74            go_away_at: Mutex::new(None),
75            latest_resume_handle: Mutex::new(None),
76        }
77    }
78
79    /// Process an event — updates state keys and atomic timestamps.
80    ///
81    /// This is the per-event handler. It updates boolean flags, counters,
82    /// and atomic timestamps. **Derived timing** (silence_ms, elapsed_ms,
83    /// remaining_budget_ms) is NOT computed here — call `flush_timing()`
84    /// periodically instead.
85    pub fn on_event(&self, event: &SessionEvent) {
86        match event {
87            SessionEvent::Connected => {
88                let now_ns = self.elapsed_ns();
89                self.connected_at_ns.store(now_ns, Ordering::Relaxed);
90                self.is_connected.store(true, Ordering::Relaxed);
91                self.last_activity_ns.store(now_ns, Ordering::Relaxed);
92                let _ = self.state.session().set("connected_at_ms", 0u64);
93                let _ = self.state.session().set("interrupt_count", 0u64);
94                let _ = self.state.session().set("error_count", 0u64);
95                let _ = self.state.session().set("is_user_speaking", false);
96                let _ = self.state.session().set("is_model_speaking", false);
97                let _ = self.state.session().set("go_away_received", false);
98                let _ = self.state.session().set("resumable", false);
99                let _ = self.state.session().set("session_type", "audio_only");
100            }
101
102            SessionEvent::VoiceActivityStart => {
103                let _ = self.state.session().set("is_user_speaking", true);
104                self.touch_activity();
105            }
106
107            SessionEvent::VoiceActivityEnd => {
108                let _ = self.state.session().set("is_user_speaking", false);
109                self.touch_activity();
110            }
111
112            SessionEvent::Interrupted => {
113                let count: u64 = self.state.session().get("interrupt_count").unwrap_or(0);
114                let _ = self.state.session().set("interrupt_count", count + 1);
115                self.touch_activity();
116            }
117
118            SessionEvent::Error(msg) => {
119                let count: u64 = self.state.session().get("error_count").unwrap_or(0);
120                let _ = self.state.session().set("error_count", count + 1);
121                let _ = self.state.session().set("last_error", msg.clone());
122            }
123
124            SessionEvent::PhaseChanged(phase) => {
125                let _ = self
126                    .state
127                    .session()
128                    .set("is_model_speaking", *phase == SessionPhase::ModelSpeaking);
129                let _ = self.state.session().set("phase", phase.to_string());
130                self.touch_activity();
131            }
132
133            SessionEvent::GoAway(time_left) => {
134                let _ = self.state.session().set("go_away_received", true);
135                if let Some(ref tl) = time_left {
136                    let _ = self.state.session().set("go_away_time_left", tl.clone());
137                    if let Ok(secs) = tl.trim_end_matches('s').parse::<u64>() {
138                        let deadline = Instant::now() + std::time::Duration::from_secs(secs);
139                        *self.go_away_at.lock() = Some(deadline);
140                        let _ = self
141                            .state
142                            .session()
143                            .set("go_away_time_left_ms", secs * 1000);
144                    }
145                }
146            }
147
148            SessionEvent::SessionResumeUpdate(info) => {
149                *self.latest_resume_handle.lock() = Some(info.handle.clone());
150                let _ = self.state.session().set("resumable", info.resumable);
151                if let Some(ref idx) = info.last_consumed_index {
152                    let _ = self
153                        .state
154                        .session()
155                        .set("last_consumed_client_index", idx.clone());
156                }
157            }
158
159            SessionEvent::Usage(usage) => {
160                if let Some(total) = usage.total_token_count {
161                    let _ = self.state.session().set("total_token_count", total);
162                }
163                if let Some(prompt) = usage.prompt_token_count {
164                    let _ = self.state.session().set("prompt_token_count", prompt);
165                }
166                if let Some(response) = usage.response_token_count {
167                    let _ = self.state.session().set("response_token_count", response);
168                }
169                if let Some(cached) = usage.cached_content_token_count {
170                    let _ = self
171                        .state
172                        .session()
173                        .set("cached_content_token_count", cached);
174                }
175                if let Some(thoughts) = usage.thoughts_token_count {
176                    let _ = self.state.session().set("thoughts_token_count", thoughts);
177                }
178            }
179
180            SessionEvent::GenerationComplete => {
181                // No-op for signals — generation complete is handled by control lane
182            }
183
184            SessionEvent::InputTranscription(text) => {
185                let _ = self
186                    .state
187                    .session()
188                    .set("last_input_transcription", text.clone());
189                self.touch_activity();
190            }
191
192            SessionEvent::OutputTranscription(text) => {
193                let _ = self
194                    .state
195                    .session()
196                    .set("last_output_transcription", text.clone());
197                self.touch_activity();
198            }
199
200            SessionEvent::AudioData(_)
201            | SessionEvent::TextDelta(_)
202            | SessionEvent::TextComplete(_) => {
203                // High-frequency events: only touch the atomic timestamp.
204                // No DashMap writes, no mutex locks.
205                self.touch_activity();
206            }
207
208            SessionEvent::TurnComplete => {
209                self.touch_activity();
210            }
211
212            SessionEvent::Disconnected(_reason) => {
213                self.is_connected.store(false, Ordering::Relaxed);
214                let _ = self.state.session().set("disconnected", true);
215            }
216
217            _ => {}
218        }
219    }
220
221    /// Flush derived timing signals to state.
222    ///
223    /// Call this periodically (e.g., every 100ms) from the telemetry lane.
224    /// Computes `silence_ms`, `elapsed_ms`, and `remaining_budget_ms` from
225    /// atomic timestamps without any mutex locks.
226    pub fn flush_timing(&self) {
227        let last_activity = self.last_activity_ns.load(Ordering::Relaxed);
228        if last_activity > 0 {
229            let now_ns = self.elapsed_ns();
230            let silence_ms = now_ns.saturating_sub(last_activity) / 1_000_000;
231            let _ = self.state.session().set("silence_ms", silence_ms);
232        }
233
234        if self.is_connected.load(Ordering::Relaxed) {
235            let connected_ns = self.connected_at_ns.load(Ordering::Relaxed);
236            let now_ns = self.elapsed_ns();
237            let elapsed_ms = now_ns.saturating_sub(connected_ns) / 1_000_000;
238            let _ = self.state.session().set("elapsed_ms", elapsed_ms);
239
240            let limit_ms: u64 = match self.session_type() {
241                SessionType::AudioOnly => 15 * 60 * 1000,
242                SessionType::AudioVideo => 2 * 60 * 1000,
243            };
244            let remaining = limit_ms.saturating_sub(elapsed_ms);
245            let _ = self.state.session().set("remaining_budget_ms", remaining);
246        }
247    }
248
249    #[inline]
250    fn touch_activity(&self) {
251        self.last_activity_ns
252            .store(self.elapsed_ns(), Ordering::Relaxed);
253    }
254
255    #[inline]
256    fn elapsed_ns(&self) -> u64 {
257        self.start.elapsed().as_nanos() as u64
258    }
259
260    /// Returns the current session type based on whether video has been sent.
261    pub fn session_type(&self) -> SessionType {
262        if self.has_video.load(Ordering::Relaxed) {
263            SessionType::AudioVideo
264        } else {
265            SessionType::AudioOnly
266        }
267    }
268
269    /// Returns the latest resumption handle for reconnection.
270    pub fn latest_resume_handle(&self) -> Option<String> {
271        self.latest_resume_handle.lock().clone()
272    }
273
274    /// Mark that video has been sent (changes session type to `AudioVideo`).
275    pub fn mark_video_sent(&self) {
276        if !self.has_video.swap(true, Ordering::Relaxed) {
277            let _ = self.state.session().set("session_type", "audio_video");
278        }
279    }
280}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use gemini_genai_rs::prelude::SessionEvent;

    // -- fixtures & accessors -------------------------------------------------

    /// Signals over a fresh, empty [`State`].
    fn signals() -> SessionSignals {
        SessionSignals::new(State::new())
    }

    /// Signals that have already processed a `Connected` event.
    fn connected() -> SessionSignals {
        let s = signals();
        s.on_event(&SessionEvent::Connected);
        s
    }

    /// Read a boolean `session:` key.
    fn flag(s: &SessionSignals, key: &'static str) -> Option<bool> {
        s.state.session().get::<bool>(key)
    }

    /// Read an unsigned integer `session:` key.
    fn number(s: &SessionSignals, key: &'static str) -> Option<u64> {
        s.state.session().get::<u64>(key)
    }

    /// Read a string `session:` key.
    fn text(s: &SessionSignals, key: &'static str) -> Option<String> {
        s.state.session().get::<String>(key)
    }

    /// Build a resumable `SessionResumeUpdate` event.
    fn resume_update(
        handle: &'static str,
        last_consumed_index: Option<&'static str>,
    ) -> SessionEvent {
        SessionEvent::SessionResumeUpdate(gemini_genai_rs::session::ResumeInfo {
            handle: handle.into(),
            resumable: true,
            last_consumed_index: last_consumed_index.map(|idx| idx.into()),
        })
    }

    /// Flush timing signals and return `silence_ms` (`u64::MAX` if unset).
    fn silence_after_flush(s: &SessionSignals) -> u64 {
        s.flush_timing();
        number(s, "silence_ms").unwrap_or(u64::MAX)
    }

    // -- per-event state ------------------------------------------------------

    #[test]
    fn connected_initializes_state() {
        let s = connected();

        for key in ["connected_at_ms", "interrupt_count", "error_count"] {
            assert_eq!(number(&s, key), Some(0), "{key}");
        }
        for key in [
            "is_user_speaking",
            "is_model_speaking",
            "go_away_received",
            "resumable",
        ] {
            assert_eq!(flag(&s, key), Some(false), "{key}");
        }
        assert_eq!(text(&s, "session_type").as_deref(), Some("audio_only"));
        assert!(s.is_connected.load(Ordering::Relaxed));
    }

    #[test]
    fn voice_activity_toggles_user_speaking() {
        let s = connected();
        s.on_event(&SessionEvent::VoiceActivityStart);
        assert_eq!(flag(&s, "is_user_speaking"), Some(true));
        s.on_event(&SessionEvent::VoiceActivityEnd);
        assert_eq!(flag(&s, "is_user_speaking"), Some(false));
    }

    #[test]
    fn interrupted_increments_count() {
        let s = connected();
        for expected in 1..=3u64 {
            s.on_event(&SessionEvent::Interrupted);
            assert_eq!(number(&s, "interrupt_count"), Some(expected));
        }
    }

    #[test]
    fn error_increments_count() {
        let s = connected();
        for (expected, msg) in [(1u64, "oops"), (2, "oops2")] {
            s.on_event(&SessionEvent::Error(msg.into()));
            assert_eq!(number(&s, "error_count"), Some(expected));
            assert_eq!(text(&s, "last_error").as_deref(), Some(msg));
        }
    }

    #[test]
    fn phase_changed_sets_model_speaking() {
        let s = connected();
        let cases = [
            (SessionPhase::ModelSpeaking, true, "ModelSpeaking"),
            (SessionPhase::Active, false, "Active"),
        ];
        for (phase, speaking, label) in cases {
            s.on_event(&SessionEvent::PhaseChanged(phase));
            assert_eq!(flag(&s, "is_model_speaking"), Some(speaking));
            assert_eq!(text(&s, "phase").as_deref(), Some(label));
        }
    }

    #[test]
    fn go_away_sets_state() {
        let s = connected();
        s.on_event(&SessionEvent::GoAway(Some("60s".into())));
        assert_eq!(flag(&s, "go_away_received"), Some(true));
        assert_eq!(text(&s, "go_away_time_left").as_deref(), Some("60s"));
        assert_eq!(number(&s, "go_away_time_left_ms"), Some(60_000));
        assert!(s.go_away_at.lock().is_some());
    }

    #[test]
    fn go_away_without_time_left() {
        let s = connected();
        s.on_event(&SessionEvent::GoAway(None));
        assert_eq!(flag(&s, "go_away_received"), Some(true));
        assert_eq!(text(&s, "go_away_time_left"), None);
        assert!(s.go_away_at.lock().is_none());
    }

    #[test]
    fn session_resume_handle_stored() {
        let s = connected();
        s.on_event(&resume_update("handle-abc", None));
        assert_eq!(flag(&s, "resumable"), Some(true));
        assert_eq!(s.latest_resume_handle().as_deref(), Some("handle-abc"));
    }

    #[test]
    fn transcription_stores_last() {
        let s = connected();

        s.on_event(&SessionEvent::InputTranscription("hello".into()));
        assert_eq!(
            text(&s, "last_input_transcription").as_deref(),
            Some("hello")
        );

        s.on_event(&SessionEvent::OutputTranscription("hi there".into()));
        assert_eq!(
            text(&s, "last_output_transcription").as_deref(),
            Some("hi there")
        );

        s.on_event(&SessionEvent::InputTranscription("bye".into()));
        assert_eq!(
            text(&s, "last_input_transcription").as_deref(),
            Some("bye")
        );
    }

    // -- session type ---------------------------------------------------------

    #[test]
    fn session_type_defaults_to_audio_only() {
        assert_eq!(signals().session_type(), SessionType::AudioOnly);
    }

    #[test]
    fn mark_video_sent_changes_session_type() {
        let s = connected();
        assert_eq!(s.session_type(), SessionType::AudioOnly);

        s.mark_video_sent();

        assert_eq!(s.session_type(), SessionType::AudioVideo);
        assert_eq!(text(&s, "session_type").as_deref(), Some("audio_video"));
    }

    #[test]
    fn mark_video_sent_idempotent() {
        let s = connected();
        (0..2).for_each(|_| s.mark_video_sent());
        assert_eq!(s.session_type(), SessionType::AudioVideo);
    }

    // -- timing ---------------------------------------------------------------

    #[test]
    fn flush_timing_after_connected() {
        let s = connected();
        s.flush_timing();

        let elapsed = number(&s, "elapsed_ms").unwrap_or(0);
        assert!(elapsed < 100, "elapsed should be near zero, got {elapsed}");

        let remaining = number(&s, "remaining_budget_ms").unwrap();
        let limit = 15 * 60 * 1000u64;
        assert!(
            remaining > limit - 1000,
            "remaining should be near limit, got {remaining}"
        );
    }

    #[test]
    fn flush_timing_respects_video_budget() {
        let s = connected();

        s.flush_timing();
        let remaining_audio = number(&s, "remaining_budget_ms").unwrap();
        assert!(remaining_audio > 14 * 60 * 1000);

        s.mark_video_sent();
        s.flush_timing();
        let remaining_video = number(&s, "remaining_budget_ms").unwrap();
        assert!(
            remaining_video <= 2 * 60 * 1000,
            "video remaining should be <= 120_000, got {remaining_video}"
        );
    }

    #[test]
    fn silence_ms_tracked() {
        let silence = silence_after_flush(&connected());
        assert!(silence < 100, "silence should be near zero, got {silence}");
    }

    #[test]
    fn audio_data_updates_activity() {
        let s = connected();
        s.on_event(&SessionEvent::AudioData(Bytes::from_static(b"pcm")));
        assert!(silence_after_flush(&s) < 100);
    }

    #[test]
    fn turn_complete_updates_activity() {
        let s = connected();
        s.on_event(&SessionEvent::TurnComplete);
        assert!(silence_after_flush(&s) < 100);
    }

    #[test]
    fn text_complete_updates_activity() {
        let s = connected();
        s.on_event(&SessionEvent::TextComplete("done".into()));
        assert!(silence_after_flush(&s) < 100);
    }

    // -- resumption handle ----------------------------------------------------

    #[test]
    fn latest_resume_handle_initially_none() {
        assert_eq!(signals().latest_resume_handle(), None);
    }

    #[test]
    fn latest_resume_handle_updates() {
        let s = signals();

        s.on_event(&resume_update("h1", None));
        assert_eq!(s.latest_resume_handle().as_deref(), Some("h1"));

        s.on_event(&resume_update("h2", Some("5")));
        assert_eq!(s.latest_resume_handle().as_deref(), Some("h2"));
    }

    // -- disconnection --------------------------------------------------------

    #[test]
    fn disconnected_clears_connected_and_sets_flag() {
        let s = connected();
        assert!(s.is_connected.load(Ordering::Relaxed));

        s.on_event(&SessionEvent::Disconnected(Some("server closed".into())));

        assert!(!s.is_connected.load(Ordering::Relaxed));
        assert_eq!(flag(&s, "disconnected"), Some(true));
    }

    #[test]
    fn disconnected_without_reason() {
        let s = connected();
        s.on_event(&SessionEvent::Disconnected(None));
        assert!(!s.is_connected.load(Ordering::Relaxed));
        assert_eq!(flag(&s, "disconnected"), Some(true));
    }
}