1use std::collections::VecDeque;
13use std::sync::atomic::{AtomicU32, Ordering};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use gemini_genai_rs::session::{SessionEvent, SessionWriter};
18
19use super::BoxFuture;
20use crate::state::State;
21
/// A stateful detector that decides whether a temporal pattern is currently
/// present.
///
/// Implementations take `&self` everywhere, so any bookkeeping they keep must
/// use interior mutability (the detectors in this file use mutexes and atomics).
pub trait PatternDetector: Send + Sync {
    /// Evaluates the detector against the current `state`, an optional
    /// session `event`, and the caller-supplied timestamp `now`.
    ///
    /// Returns `true` when the pattern is detected. `event` may be `None`
    /// (presumably for evaluations not triggered by a session event, such as
    /// timer ticks — confirm against the caller).
    fn check(&self, state: &State, event: Option<&SessionEvent>, now: Instant) -> bool;

    /// Clears the detector's internal tracking state.
    fn reset(&self);

    /// Reports whether this detector asks for timer-driven evaluation.
    ///
    /// Defaults to `false`; detectors override it to opt in.
    fn needs_timer(&self) -> bool {
        false
    }
}
48
/// A named detector paired with the action to run when it fires, plus an
/// optional cooldown between firings.
pub struct TemporalPattern {
    /// Label for this pattern (not read by the code in this file).
    pub name: String,
    /// Detector consulted on every evaluation.
    pub detector: Box<dyn PatternDetector>,
    /// Hook called with a cloned state and writer when the pattern fires; it
    /// returns the future that performs the action.
    pub action: super::phase::PhaseHook,
    /// Minimum time between two firings; `None` disables the cooldown.
    pub cooldown: Option<Duration>,
    /// Timestamp of the most recent firing, `None` until the first one.
    last_triggered: parking_lot::Mutex<Option<Instant>>,
}
65
66impl TemporalPattern {
67 pub fn new(
69 name: impl Into<String>,
70 detector: Box<dyn PatternDetector>,
71 action: super::phase::PhaseHook,
72 cooldown: Option<Duration>,
73 ) -> Self {
74 Self {
75 name: name.into(),
76 detector,
77 action,
78 cooldown,
79 last_triggered: parking_lot::Mutex::new(None),
80 }
81 }
82
83 fn try_fire(
85 &self,
86 state: &State,
87 event: Option<&SessionEvent>,
88 writer: &Arc<dyn SessionWriter>,
89 now: Instant,
90 ) -> Option<BoxFuture<()>> {
91 if !self.detector.check(state, event, now) {
92 return None;
93 }
94
95 let mut last = self.last_triggered.lock();
97 if let (Some(cooldown), Some(prev)) = (self.cooldown, *last) {
98 if now.duration_since(prev) < cooldown {
99 return None;
100 }
101 }
102
103 *last = Some(now);
104
105 let s = state.clone();
106 let w = writer.clone();
107 Some((self.action)(s, w))
108 }
109}
110
/// An ordered collection of [`TemporalPattern`]s that are evaluated together.
pub struct TemporalRegistry {
    /// Registered patterns, in insertion order.
    patterns: Vec<TemporalPattern>,
}
117
118impl Default for TemporalRegistry {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl TemporalRegistry {
125 pub fn new() -> Self {
127 Self {
128 patterns: Vec::new(),
129 }
130 }
131
132 pub fn add(&mut self, pattern: TemporalPattern) {
134 self.patterns.push(pattern);
135 }
136
137 pub fn check_all(
142 &self,
143 state: &State,
144 event: Option<&SessionEvent>,
145 writer: &Arc<dyn SessionWriter>,
146 ) -> Vec<BoxFuture<()>> {
147 let now = Instant::now();
148 self.patterns
149 .iter()
150 .filter_map(|p| p.try_fire(state, event, writer, now))
151 .collect()
152 }
153
154 pub fn needs_timer(&self) -> bool {
157 self.patterns.iter().any(|p| p.detector.needs_timer())
158 }
159}
160
/// Detects a condition that has stayed true for at least a given duration.
pub struct SustainedDetector {
    /// Predicate evaluated against the state on every check.
    condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
    /// How long the condition must hold before the detector reports a match.
    duration: Duration,
    /// Timestamp of the check at which the condition was first seen true in
    /// the current streak; `None` while the condition is false.
    became_true_at: parking_lot::Mutex<Option<Instant>>,
}
178
179impl SustainedDetector {
180 pub fn new(condition: Arc<dyn Fn(&State) -> bool + Send + Sync>, duration: Duration) -> Self {
185 Self {
186 condition,
187 duration,
188 became_true_at: parking_lot::Mutex::new(None),
189 }
190 }
191}
192
193impl PatternDetector for SustainedDetector {
194 fn check(&self, state: &State, _event: Option<&SessionEvent>, now: Instant) -> bool {
195 if (self.condition)(state) {
196 let mut guard = self.became_true_at.lock();
197 match *guard {
198 None => {
199 *guard = Some(now);
200 false
201 }
202 Some(t) => now.duration_since(t) >= self.duration,
203 }
204 } else {
205 *self.became_true_at.lock() = None;
206 false
207 }
208 }
209
210 fn reset(&self) {
211 *self.became_true_at.lock() = None;
212 }
213
214 fn needs_timer(&self) -> bool {
215 true
216 }
217}
218
/// Detects when enough matching events have occurred within a sliding time
/// window.
pub struct RateDetector {
    /// Selects which session events are counted.
    filter: Arc<dyn Fn(&SessionEvent) -> bool + Send + Sync>,
    /// Number of in-window events at which the detector reports a match.
    count: u32,
    /// Length of the sliding window.
    window: Duration,
    /// Timestamps of counted events, oldest first.
    timestamps: parking_lot::Mutex<VecDeque<Instant>>,
}
234
235impl RateDetector {
236 pub fn new(
242 filter: Arc<dyn Fn(&SessionEvent) -> bool + Send + Sync>,
243 count: u32,
244 window: Duration,
245 ) -> Self {
246 Self {
247 filter,
248 count,
249 window,
250 timestamps: parking_lot::Mutex::new(VecDeque::new()),
251 }
252 }
253}
254
255impl PatternDetector for RateDetector {
256 fn check(&self, _state: &State, event: Option<&SessionEvent>, now: Instant) -> bool {
257 let mut ts = self.timestamps.lock();
258
259 if let Some(evt) = event {
261 if (self.filter)(evt) {
262 ts.push_back(now);
263 }
264 }
265
266 while let Some(&front) = ts.front() {
268 if now.duration_since(front) > self.window {
269 ts.pop_front();
270 } else {
271 break;
272 }
273 }
274
275 ts.len() as u32 >= self.count
276 }
277
278 fn reset(&self) {
279 self.timestamps.lock().clear();
280 }
281
282 }
284
/// Detects a condition that has been true on a number of consecutive checks.
pub struct TurnCountDetector {
    /// Predicate evaluated against the state on every check.
    condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
    /// Number of consecutive true checks at which the detector matches.
    required: u32,
    /// Length of the current run of true checks.
    consecutive: AtomicU32,
}
297
298impl TurnCountDetector {
299 pub fn new(condition: Arc<dyn Fn(&State) -> bool + Send + Sync>, required: u32) -> Self {
304 Self {
305 condition,
306 required,
307 consecutive: AtomicU32::new(0),
308 }
309 }
310}
311
312impl PatternDetector for TurnCountDetector {
313 fn check(&self, state: &State, _event: Option<&SessionEvent>, _now: Instant) -> bool {
314 if (self.condition)(state) {
315 let prev = self.consecutive.fetch_add(1, Ordering::SeqCst);
316 prev + 1 >= self.required
317 } else {
318 self.consecutive.store(0, Ordering::SeqCst);
319 false
320 }
321 }
322
323 fn reset(&self) {
324 self.consecutive.store(0, Ordering::SeqCst);
325 }
326}
327
/// Detects a run of consecutive checks in which a tool's failure flag is set
/// in the state.
pub struct ConsecutiveFailureDetector {
    /// Tool name used to build the state key `bg:{tool_name}_failed`.
    tool_name: String,
    /// Number of consecutive failed checks at which the detector matches.
    threshold: u32,
    /// Length of the current run of failed checks.
    consecutive: AtomicU32,
}
340
341impl ConsecutiveFailureDetector {
342 pub fn new(tool_name: impl Into<String>, threshold: u32) -> Self {
347 Self {
348 tool_name: tool_name.into(),
349 threshold,
350 consecutive: AtomicU32::new(0),
351 }
352 }
353}
354
355impl PatternDetector for ConsecutiveFailureDetector {
356 fn check(&self, state: &State, _event: Option<&SessionEvent>, _now: Instant) -> bool {
357 let key = format!("bg:{}_failed", self.tool_name);
358 let failed: bool = state.get(&key).unwrap_or(false);
359
360 if failed {
361 let prev = self.consecutive.fetch_add(1, Ordering::SeqCst);
362 prev + 1 >= self.threshold
363 } else {
364 self.consecutive.store(0, Ordering::SeqCst);
365 false
366 }
367 }
368
369 fn reset(&self) {
370 self.consecutive.store(0, Ordering::SeqCst);
371 }
372}
373
// Unit tests for the detectors, the per-pattern cooldown, and the registry.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Session writer whose every method succeeds without doing anything.
    struct MockWriter;

    #[async_trait::async_trait]
    impl SessionWriter for MockWriter {
        async fn send_audio(
            &self,
            _: Vec<u8>,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_text(&self, _: String) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_tool_response(
            &self,
            _: Vec<gemini_genai_rs::protocol::FunctionResponse>,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_client_content(
            &self,
            _: Vec<gemini_genai_rs::protocol::Content>,
            _: bool,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_video(
            &self,
            _: Vec<u8>,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn update_instruction(
            &self,
            _: String,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn signal_activity_start(
            &self,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn signal_activity_end(&self) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn disconnect(&self) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
    }

    // Wraps the mock in the trait-object form the patterns expect.
    fn mock_writer() -> Arc<dyn SessionWriter> {
        Arc::new(MockWriter)
    }

    // Builds an action hook whose future increments `counter` when awaited.
    fn counting_action(counter: Arc<AtomicU32>) -> crate::live::phase::PhaseHook {
        Arc::new(move |_state, _writer| {
            let c = counter.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::SeqCst);
            })
        })
    }

    // The sustained detector matches only after the condition has held for
    // the full duration, and keeps matching afterwards.
    #[test]
    fn sustained_fires_after_duration() {
        let state = State::new();
        let _ = state.set("hot", true);

        let detector = SustainedDetector::new(
            Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
            Duration::from_secs(5),
        );

        let t0 = Instant::now();

        // First check only records the start of the streak.
        assert!(!detector.check(&state, None, t0));

        // 3s elapsed: still short of 5s.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));

        // Exactly 5s elapsed: matches.
        assert!(detector.check(&state, None, t0 + Duration::from_secs(5)));

        // Continues to match while the condition holds.
        assert!(detector.check(&state, None, t0 + Duration::from_secs(6)));
    }

    // A false check ends the streak; timing restarts from the next true check.
    #[test]
    fn sustained_resets_on_false() {
        let state = State::new();
        let _ = state.set("hot", true);

        let detector = SustainedDetector::new(
            Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
            Duration::from_secs(5),
        );

        let t0 = Instant::now();

        // Streak starts at t0.
        assert!(!detector.check(&state, None, t0));

        // Condition turns false at 2s: streak is cleared.
        let _ = state.set("hot", false);
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(2)));

        // Condition true again at 3s: a new streak starts here.
        let _ = state.set("hot", true);
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));

        // 4s into the new streak: not yet.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(7)));

        // 5s into the new streak: matches.
        assert!(detector.check(&state, None, t0 + Duration::from_secs(8)));
    }

    // `reset` discards the streak start, so timing restarts on the next check.
    #[test]
    fn sustained_reset_clears_state() {
        let state = State::new();
        let _ = state.set("hot", true);

        let detector = SustainedDetector::new(
            Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
            Duration::from_secs(5),
        );

        let t0 = Instant::now();

        // Streak starts at t0.
        assert!(!detector.check(&state, None, t0));

        detector.reset();

        // After the reset the streak starts again at 4s, so 9s is the first match.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(4)));
        assert!(detector.check(&state, None, t0 + Duration::from_secs(9)));
    }

    // Three matching events inside the window trigger the rate detector.
    #[test]
    fn rate_fires_when_count_reached() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            3,
            Duration::from_secs(10),
        );

        let t0 = Instant::now();
        let event = SessionEvent::TurnComplete;

        assert!(!detector.check(&state, Some(&event), t0));
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));
        // Third event within 10s reaches the count.
        assert!(detector.check(&state, Some(&event), t0 + Duration::from_secs(2)));
    }

    // Events older than the window are pruned and do not count.
    #[test]
    fn rate_does_not_fire_when_events_outside_window() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            3,
            Duration::from_secs(5),
        );

        let t0 = Instant::now();
        let event = SessionEvent::TurnComplete;

        // Two events at 0s and 1s.
        assert!(!detector.check(&state, Some(&event), t0));
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));

        // Third event at 10s: the first two are outside the 5s window by then.
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(10)));
    }

    // Events rejected by the filter are never counted.
    #[test]
    fn rate_filter_rejects_events() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            2,
            Duration::from_secs(10),
        );

        let t0 = Instant::now();

        // TextDelta does not pass the TurnComplete-only filter.
        let text_event = SessionEvent::TextDelta("hello".to_string());
        assert!(!detector.check(&state, Some(&text_event), t0));
        assert!(!detector.check(&state, Some(&text_event), t0 + Duration::from_secs(1)));
        assert!(!detector.check(&state, Some(&text_event), t0 + Duration::from_secs(2)));

        // A check without an event records nothing either.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));
    }

    // The turn-count detector matches on the Nth consecutive true check.
    #[test]
    fn turn_count_fires_after_n_consecutive() {
        let state = State::new();
        let _ = state.set("confused", true);

        let detector = TurnCountDetector::new(
            Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
            3,
        );

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        // Third consecutive true check.
        assert!(detector.check(&state, None, t0));
    }

    // A false check resets the run, so three more true checks are needed.
    #[test]
    fn turn_count_resets_on_false() {
        let state = State::new();
        let _ = state.set("confused", true);

        let detector = TurnCountDetector::new(
            Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
            3,
        );

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));

        // Condition false: run length goes back to zero.
        let _ = state.set("confused", false);
        assert!(!detector.check(&state, None, t0));

        // Needs three fresh consecutive true checks.
        let _ = state.set("confused", true);
        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    // The failure detector matches once the flag is set on `threshold`
    // consecutive checks.
    #[test]
    fn consecutive_failure_fires_after_threshold() {
        let state = State::new();
        let _ = state.set("bg:search_failed", true);

        let detector = ConsecutiveFailureDetector::new("search", 3);

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        // Third consecutive failure.
        assert!(detector.check(&state, None, t0));
    }

    // A check with the flag cleared resets the failure run.
    #[test]
    fn consecutive_failure_resets_on_success() {
        let state = State::new();
        let _ = state.set("bg:search_failed", true);

        let detector = ConsecutiveFailureDetector::new("search", 3);

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));

        // Flag cleared: run length goes back to zero.
        let _ = state.set("bg:search_failed", false);
        assert!(!detector.check(&state, None, t0));

        // Needs three fresh consecutive failures.
        let _ = state.set("bg:search_failed", true);
        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    // With a cooldown, a matching detector fires at most once per cooldown period.
    #[tokio::test]
    async fn pattern_cooldown_prevents_rapid_refiring() {
        let counter = Arc::new(AtomicU32::new(0));
        let state = State::new();
        let _ = state.set("active", true);
        let writer = mock_writer();

        let pattern = TemporalPattern::new(
            "test-cooldown",
            Box::new(SustainedDetector::new(
                Arc::new(|s: &State| s.get::<bool>("active").unwrap_or(false)),
                // Zero duration: matches from the second check onwards.
                Duration::from_secs(0),
            )),
            counting_action(counter.clone()),
            // 10s cooldown between firings.
            Some(Duration::from_secs(10)),
        );

        let t0 = Instant::now();

        // First check only records the streak start, so nothing fires.
        assert!(pattern.try_fire(&state, None, &writer, t0).is_none());

        // Second check matches and fires.
        let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_millis(1));
        assert!(fut.is_some());
        fut.unwrap().await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Detector still matches, but the cooldown suppresses the firing.
        assert!(pattern
            .try_fire(&state, None, &writer, t0 + Duration::from_millis(2))
            .is_none());

        // After the cooldown has elapsed the pattern fires again.
        let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_secs(11));
        assert!(fut.is_some());
        fut.unwrap().await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    // `check_all` returns one action future per pattern that fired.
    #[tokio::test]
    async fn registry_check_all_returns_actions() {
        let counter = Arc::new(AtomicU32::new(0));
        let state = State::new();
        let _ = state.set("confused", true);
        let writer = mock_writer();

        let mut registry = TemporalRegistry::new();

        // Required count of 1: matches on the very first true check.
        registry.add(TemporalPattern::new(
            "confusion",
            Box::new(TurnCountDetector::new(
                Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
                1,
            )),
            counting_action(counter.clone()),
            None,
        ));

        let actions = registry.check_all(&state, None, &writer);
        assert_eq!(actions.len(), 1);

        // The action only runs when its future is awaited.
        for fut in actions {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    // A registry containing a sustained detector reports that it needs a timer.
    #[test]
    fn needs_timer_true_with_sustained_detector() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = TemporalRegistry::new();

        registry.add(TemporalPattern::new(
            "sustained",
            Box::new(SustainedDetector::new(
                Arc::new(|_: &State| true),
                Duration::from_secs(5),
            )),
            counting_action(counter),
            None,
        ));

        assert!(registry.needs_timer());
    }

    // Turn-count and rate detectors alone do not request a timer.
    #[test]
    fn needs_timer_false_without_sustained_detector() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = TemporalRegistry::new();

        registry.add(TemporalPattern::new(
            "turn-count",
            Box::new(TurnCountDetector::new(Arc::new(|_: &State| true), 3)),
            counting_action(counter.clone()),
            None,
        ));

        registry.add(TemporalPattern::new(
            "rate",
            Box::new(RateDetector::new(
                Arc::new(|_: &SessionEvent| true),
                5,
                Duration::from_secs(10),
            )),
            counting_action(counter),
            None,
        ));

        assert!(!registry.needs_timer());
    }

    // The default registry is empty, so it needs no timer.
    #[test]
    fn default_creates_empty_registry() {
        let registry = TemporalRegistry::default();
        assert!(!registry.needs_timer());
    }

    // `reset` drops recorded events, so the count has to be reached again.
    #[test]
    fn rate_reset_clears_timestamps() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            2,
            Duration::from_secs(10),
        );

        let t0 = Instant::now();
        let event = SessionEvent::TurnComplete;

        assert!(!detector.check(&state, Some(&event), t0));
        detector.reset();
        // Only one event recorded since the reset.
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));
        // Second event since the reset reaches the count.
        assert!(detector.check(&state, Some(&event), t0 + Duration::from_secs(2)));
    }

    // `reset` zeroes the run length of the turn-count detector.
    #[test]
    fn turn_count_reset_clears_counter() {
        let state = State::new();
        let _ = state.set("confused", true);

        let detector = TurnCountDetector::new(
            Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
            3,
        );

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        // Two true checks so far; the reset discards them.
        detector.reset();

        // Three fresh true checks are needed again.
        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    // `reset` zeroes the run length of the failure detector.
    #[test]
    fn consecutive_failure_reset_clears_counter() {
        let state = State::new();
        let _ = state.set("bg:search_failed", true);

        let detector = ConsecutiveFailureDetector::new("search", 3);
        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));

        detector.reset();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    // Per-detector `needs_timer` answers.
    #[test]
    fn sustained_detector_needs_timer() {
        let detector = SustainedDetector::new(Arc::new(|_: &State| true), Duration::from_secs(5));
        assert!(detector.needs_timer());
    }

    #[test]
    fn rate_detector_does_not_need_timer() {
        let detector = RateDetector::new(
            Arc::new(|_: &SessionEvent| true),
            5,
            Duration::from_secs(10),
        );
        assert!(!detector.needs_timer());
    }

    #[test]
    fn turn_count_detector_does_not_need_timer() {
        let detector = TurnCountDetector::new(Arc::new(|_: &State| true), 3);
        assert!(!detector.needs_timer());
    }

    // Without a cooldown the pattern fires on every matching check.
    #[tokio::test]
    async fn pattern_without_cooldown_fires_every_time() {
        let counter = Arc::new(AtomicU32::new(0));
        let state = State::new();
        let _ = state.set("active", true);
        let writer = mock_writer();

        let pattern = TemporalPattern::new(
            "no-cooldown",
            Box::new(TurnCountDetector::new(
                Arc::new(|s: &State| s.get::<bool>("active").unwrap_or(false)),
                1,
            )),
            counting_action(counter.clone()),
            // No cooldown.
            None,
        );

        let t0 = Instant::now();

        for i in 0..5u32 {
            let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_millis(i as u64));
            assert!(fut.is_some(), "should fire on iteration {i}");
            fut.unwrap().await;
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
}