1use std::collections::VecDeque;
13use std::sync::atomic::{AtomicU32, Ordering};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use gemini_genai_rs::session::{SessionEvent, SessionWriter};
18
19use super::BoxFuture;
20use crate::state::State;
21
22pub trait PatternDetector: Send + Sync {
30 fn check(&self, state: &State, event: Option<&SessionEvent>, now: Instant) -> bool;
36
37 fn reset(&self);
39
40 fn needs_timer(&self) -> bool {
45 false
46 }
47}
48
49pub struct TemporalPattern {
53 pub name: String,
55 pub detector: Box<dyn PatternDetector>,
57 pub action: Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>,
60 pub cooldown: Option<Duration>,
62 last_triggered: parking_lot::Mutex<Option<Instant>>,
64}
65
66impl TemporalPattern {
67 pub fn new(
69 name: impl Into<String>,
70 detector: Box<dyn PatternDetector>,
71 action: Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>,
72 cooldown: Option<Duration>,
73 ) -> Self {
74 Self {
75 name: name.into(),
76 detector,
77 action,
78 cooldown,
79 last_triggered: parking_lot::Mutex::new(None),
80 }
81 }
82
83 fn try_fire(
85 &self,
86 state: &State,
87 event: Option<&SessionEvent>,
88 writer: &Arc<dyn SessionWriter>,
89 now: Instant,
90 ) -> Option<BoxFuture<()>> {
91 if !self.detector.check(state, event, now) {
92 return None;
93 }
94
95 let mut last = self.last_triggered.lock();
97 if let (Some(cooldown), Some(prev)) = (self.cooldown, *last) {
98 if now.duration_since(prev) < cooldown {
99 return None;
100 }
101 }
102
103 *last = Some(now);
104
105 let s = state.clone();
106 let w = writer.clone();
107 Some((self.action)(s, w))
108 }
109}
110
111pub struct TemporalRegistry {
115 patterns: Vec<TemporalPattern>,
116}
117
118impl Default for TemporalRegistry {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl TemporalRegistry {
125 pub fn new() -> Self {
127 Self {
128 patterns: Vec::new(),
129 }
130 }
131
132 pub fn add(&mut self, pattern: TemporalPattern) {
134 self.patterns.push(pattern);
135 }
136
137 pub fn check_all(
142 &self,
143 state: &State,
144 event: Option<&SessionEvent>,
145 writer: &Arc<dyn SessionWriter>,
146 ) -> Vec<BoxFuture<()>> {
147 let now = Instant::now();
148 self.patterns
149 .iter()
150 .filter_map(|p| p.try_fire(state, event, writer, now))
151 .collect()
152 }
153
154 pub fn needs_timer(&self) -> bool {
157 self.patterns.iter().any(|p| p.detector.needs_timer())
158 }
159}
160
161pub struct SustainedDetector {
174 condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
175 duration: Duration,
176 became_true_at: parking_lot::Mutex<Option<Instant>>,
177}
178
179impl SustainedDetector {
180 pub fn new(condition: Arc<dyn Fn(&State) -> bool + Send + Sync>, duration: Duration) -> Self {
185 Self {
186 condition,
187 duration,
188 became_true_at: parking_lot::Mutex::new(None),
189 }
190 }
191}
192
193impl PatternDetector for SustainedDetector {
194 fn check(&self, state: &State, _event: Option<&SessionEvent>, now: Instant) -> bool {
195 if (self.condition)(state) {
196 let mut guard = self.became_true_at.lock();
197 match *guard {
198 None => {
199 *guard = Some(now);
200 false
201 }
202 Some(t) => now.duration_since(t) >= self.duration,
203 }
204 } else {
205 *self.became_true_at.lock() = None;
206 false
207 }
208 }
209
210 fn reset(&self) {
211 *self.became_true_at.lock() = None;
212 }
213
214 fn needs_timer(&self) -> bool {
215 true
216 }
217}
218
219pub struct RateDetector {
229 filter: Arc<dyn Fn(&SessionEvent) -> bool + Send + Sync>,
230 count: u32,
231 window: Duration,
232 timestamps: parking_lot::Mutex<VecDeque<Instant>>,
233}
234
235impl RateDetector {
236 pub fn new(
242 filter: Arc<dyn Fn(&SessionEvent) -> bool + Send + Sync>,
243 count: u32,
244 window: Duration,
245 ) -> Self {
246 Self {
247 filter,
248 count,
249 window,
250 timestamps: parking_lot::Mutex::new(VecDeque::new()),
251 }
252 }
253}
254
255impl PatternDetector for RateDetector {
256 fn check(&self, _state: &State, event: Option<&SessionEvent>, now: Instant) -> bool {
257 let mut ts = self.timestamps.lock();
258
259 if let Some(evt) = event {
261 if (self.filter)(evt) {
262 ts.push_back(now);
263 }
264 }
265
266 while let Some(&front) = ts.front() {
268 if now.duration_since(front) > self.window {
269 ts.pop_front();
270 } else {
271 break;
272 }
273 }
274
275 ts.len() as u32 >= self.count
276 }
277
278 fn reset(&self) {
279 self.timestamps.lock().clear();
280 }
281
282 }
284
285pub struct TurnCountDetector {
293 condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
294 required: u32,
295 consecutive: AtomicU32,
296}
297
298impl TurnCountDetector {
299 pub fn new(condition: Arc<dyn Fn(&State) -> bool + Send + Sync>, required: u32) -> Self {
304 Self {
305 condition,
306 required,
307 consecutive: AtomicU32::new(0),
308 }
309 }
310}
311
312impl PatternDetector for TurnCountDetector {
313 fn check(&self, state: &State, _event: Option<&SessionEvent>, _now: Instant) -> bool {
314 if (self.condition)(state) {
315 let prev = self.consecutive.fetch_add(1, Ordering::SeqCst);
316 prev + 1 >= self.required
317 } else {
318 self.consecutive.store(0, Ordering::SeqCst);
319 false
320 }
321 }
322
323 fn reset(&self) {
324 self.consecutive.store(0, Ordering::SeqCst);
325 }
326}
327
328pub struct ConsecutiveFailureDetector {
336 tool_name: String,
337 threshold: u32,
338 consecutive: AtomicU32,
339}
340
341impl ConsecutiveFailureDetector {
342 pub fn new(tool_name: impl Into<String>, threshold: u32) -> Self {
347 Self {
348 tool_name: tool_name.into(),
349 threshold,
350 consecutive: AtomicU32::new(0),
351 }
352 }
353}
354
355impl PatternDetector for ConsecutiveFailureDetector {
356 fn check(&self, state: &State, _event: Option<&SessionEvent>, _now: Instant) -> bool {
357 let key = format!("bg:{}_failed", self.tool_name);
358 let failed: bool = state.get(&key).unwrap_or(false);
359
360 if failed {
361 let prev = self.consecutive.fetch_add(1, Ordering::SeqCst);
362 prev + 1 >= self.threshold
363 } else {
364 self.consecutive.store(0, Ordering::SeqCst);
365 false
366 }
367 }
368
369 fn reset(&self) {
370 self.consecutive.store(0, Ordering::SeqCst);
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// No-op `SessionWriter`: every method immediately returns `Ok(())`.
    struct MockWriter;

    #[async_trait::async_trait]
    impl SessionWriter for MockWriter {
        async fn send_audio(
            &self,
            _: Vec<u8>,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_text(&self, _: String) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_tool_response(
            &self,
            _: Vec<gemini_genai_rs::protocol::FunctionResponse>,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_client_content(
            &self,
            _: Vec<gemini_genai_rs::protocol::Content>,
            _: bool,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn send_video(
            &self,
            _: Vec<u8>,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn update_instruction(
            &self,
            _: String,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn signal_activity_start(
            &self,
        ) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn signal_activity_end(&self) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
        async fn disconnect(&self) -> Result<(), gemini_genai_rs::session::SessionError> {
            Ok(())
        }
    }

    /// Returns a `MockWriter` behind the trait-object type the patterns expect.
    fn mock_writer() -> Arc<dyn SessionWriter> {
        Arc::new(MockWriter)
    }

    /// Builds an action whose future increments `counter` by one when it is awaited.
    fn counting_action(
        counter: Arc<AtomicU32>,
    ) -> Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync> {
        Arc::new(move |_state, _writer| {
            let c = counter.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::SeqCst);
            })
        })
    }

    #[test]
    /// Condition true from t0: not yet at 3s, fires at exactly 5s (>=), keeps firing at 6s.
    fn sustained_fires_after_duration() {
        let state = State::new();
        state.set("hot", true);

        let detector = SustainedDetector::new(
            Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
            Duration::from_secs(5),
        );

        let t0 = Instant::now();

        // First true observation only starts the clock.
        assert!(!detector.check(&state, None, t0));

        // 3s elapsed < 5s.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));

        // Exactly 5s elapsed satisfies the `>=` comparison.
        assert!(detector.check(&state, None, t0 + Duration::from_secs(5)));

        // Still true while the condition keeps holding.
        assert!(detector.check(&state, None, t0 + Duration::from_secs(6)));
    }

    #[test]
    /// A false observation restarts the clock; the 5s must then elapse from the new start.
    fn sustained_resets_on_false() {
        let state = State::new();
        state.set("hot", true);

        let detector = SustainedDetector::new(
            Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
            Duration::from_secs(5),
        );

        let t0 = Instant::now();

        // Clock starts at t0.
        assert!(!detector.check(&state, None, t0));

        state.set("hot", false);
        // Condition false at t0+2s: clock cleared.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(2)));

        state.set("hot", true);
        // Clock restarts at t0+3s.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));

        assert!(!detector.check(&state, None, t0 + Duration::from_secs(7)));

        // t0+8s is 5s after the restart at t0+3s.
        assert!(detector.check(&state, None, t0 + Duration::from_secs(8)));
    }

    #[test]
    /// `reset()` discards the running clock, so timing restarts at the next check.
    fn sustained_reset_clears_state() {
        let state = State::new();
        state.set("hot", true);

        let detector = SustainedDetector::new(
            Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
            Duration::from_secs(5),
        );

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));

        detector.reset();

        // Clock restarts at t0+4s; it fires 5s later at t0+9s.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(4)));
        assert!(detector.check(&state, None, t0 + Duration::from_secs(9)));
    }

    #[test]
    /// Three accepted events within a 10s window reach `count = 3` on the third one.
    fn rate_fires_when_count_reached() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            3,
            Duration::from_secs(10),
        );

        let t0 = Instant::now();
        let event = SessionEvent::TurnComplete;

        assert!(!detector.check(&state, Some(&event), t0));
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));
        assert!(detector.check(&state, Some(&event), t0 + Duration::from_secs(2)));
    }

    #[test]
    /// Events older than the 5s window are evicted, so the third event alone is below `count`.
    fn rate_does_not_fire_when_events_outside_window() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            3,
            Duration::from_secs(5),
        );

        let t0 = Instant::now();
        let event = SessionEvent::TurnComplete;

        assert!(!detector.check(&state, Some(&event), t0));
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));

        // At t0+10s both earlier timestamps are more than 5s old and get evicted.
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(10)));
    }

    #[test]
    /// Events rejected by the filter, and checks with no event, are never counted.
    fn rate_filter_rejects_events() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            2,
            Duration::from_secs(10),
        );

        let t0 = Instant::now();

        let text_event = SessionEvent::TextDelta("hello".to_string());
        // TextDelta does not match the TurnComplete filter.
        assert!(!detector.check(&state, Some(&text_event), t0));
        assert!(!detector.check(&state, Some(&text_event), t0 + Duration::from_secs(1)));
        assert!(!detector.check(&state, Some(&text_event), t0 + Duration::from_secs(2)));

        // A check without an event records nothing.
        assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));
    }

    #[test]
    /// With `required = 3`, the third consecutive true check fires.
    fn turn_count_fires_after_n_consecutive() {
        let state = State::new();
        state.set("confused", true);

        let detector = TurnCountDetector::new(
            Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
            3,
        );

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    #[test]
    /// One false check breaks the streak; three new consecutive true checks are needed.
    fn turn_count_resets_on_false() {
        let state = State::new();
        state.set("confused", true);

        let detector = TurnCountDetector::new(
            Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
            3,
        );

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));

        state.set("confused", false);
        // Streak reset to 0.
        assert!(!detector.check(&state, None, t0));

        state.set("confused", true);
        // Counting starts over from 1.
        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    #[test]
    /// Reads the `bg:search_failed` key; fires on the third consecutive failed check.
    fn consecutive_failure_fires_after_threshold() {
        let state = State::new();
        state.set("bg:search_failed", true);

        let detector = ConsecutiveFailureDetector::new("search", 3);

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    #[test]
    /// A non-failed check resets the failure streak.
    fn consecutive_failure_resets_on_success() {
        let state = State::new();
        state.set("bg:search_failed", true);

        let detector = ConsecutiveFailureDetector::new("search", 3);

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));

        state.set("bg:search_failed", false);
        // Success: streak reset to 0.
        assert!(!detector.check(&state, None, t0));

        state.set("bg:search_failed", true);
        // Three new consecutive failures are required.
        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    #[tokio::test]
    /// A 10s cooldown suppresses refiring at +2ms and allows it again at +11s.
    fn pattern_cooldown_prevents_rapid_refiring() {
        let counter = Arc::new(AtomicU32::new(0));
        let state = State::new();
        state.set("active", true);
        let writer = mock_writer();

        let pattern = TemporalPattern::new(
            "test-cooldown",
            Box::new(SustainedDetector::new(
                Arc::new(|s: &State| s.get::<bool>("active").unwrap_or(false)),
                Duration::from_secs(0),
            )),
            counting_action(counter.clone()),
            Some(Duration::from_secs(10)),
        );

        let t0 = Instant::now();

        // Even with a zero duration, the first observation only starts the detector's clock.
        assert!(pattern.try_fire(&state, None, &writer, t0).is_none());

        // Detector now matches and no previous firing exists: the action fires.
        let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_millis(1));
        assert!(fut.is_some());
        fut.unwrap().await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Detector still matches, but only 1ms has passed since the last firing.
        assert!(pattern
            .try_fire(&state, None, &writer, t0 + Duration::from_millis(2))
            .is_none());

        // More than 10s since the last firing: fires again.
        let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_secs(11));
        assert!(fut.is_some());
        fut.unwrap().await;
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    /// `check_all` returns one future per fired pattern; awaiting it runs the action.
    async fn registry_check_all_returns_actions() {
        let counter = Arc::new(AtomicU32::new(0));
        let state = State::new();
        state.set("confused", true);
        let writer = mock_writer();

        let mut registry = TemporalRegistry::new();

        registry.add(TemporalPattern::new(
            // `required = 1` makes the pattern fire on the first true check.
            "confusion",
            Box::new(TurnCountDetector::new(
                Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
                1,
            )),
            counting_action(counter.clone()),
            None,
        ));

        let actions = registry.check_all(&state, None, &writer);
        assert_eq!(actions.len(), 1);

        for fut in actions {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    /// A registry containing a `SustainedDetector` reports that it needs a timer.
    fn needs_timer_true_with_sustained_detector() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = TemporalRegistry::new();

        registry.add(TemporalPattern::new(
            "sustained",
            Box::new(SustainedDetector::new(
                Arc::new(|_: &State| true),
                Duration::from_secs(5),
            )),
            counting_action(counter),
            None,
        ));

        assert!(registry.needs_timer());
    }

    #[test]
    /// Turn-count and rate detectors use the trait default (`false`), so no timer is needed.
    fn needs_timer_false_without_sustained_detector() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = TemporalRegistry::new();

        registry.add(TemporalPattern::new(
            "turn-count",
            Box::new(TurnCountDetector::new(Arc::new(|_: &State| true), 3)),
            counting_action(counter.clone()),
            None,
        ));

        registry.add(TemporalPattern::new(
            "rate",
            Box::new(RateDetector::new(
                Arc::new(|_: &SessionEvent| true),
                5,
                Duration::from_secs(10),
            )),
            counting_action(counter),
            None,
        ));

        assert!(!registry.needs_timer());
    }

    #[test]
    /// An empty registry (via `Default`) needs no timer.
    fn default_creates_empty_registry() {
        let registry = TemporalRegistry::default();
        assert!(!registry.needs_timer());
    }

    #[test]
    /// `reset()` drops recorded timestamps, so counting restarts from zero.
    fn rate_reset_clears_timestamps() {
        let state = State::new();
        let detector = RateDetector::new(
            Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
            2,
            Duration::from_secs(10),
        );

        let t0 = Instant::now();
        let event = SessionEvent::TurnComplete;

        assert!(!detector.check(&state, Some(&event), t0));
        detector.reset();
        // Only this event is recorded after the reset.
        assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));
        // Second post-reset event reaches `count = 2`.
        assert!(detector.check(&state, Some(&event), t0 + Duration::from_secs(2)));
    }

    #[test]
    /// `reset()` zeroes the consecutive counter, so three more true checks are needed.
    fn turn_count_reset_clears_counter() {
        let state = State::new();
        state.set("confused", true);

        let detector = TurnCountDetector::new(
            Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
            3,
        );

        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        detector.reset();

        // Counter restarts at 1 here.
        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    #[test]
    /// `reset()` zeroes the failure streak, so three more failed checks are needed.
    fn consecutive_failure_reset_clears_counter() {
        let state = State::new();
        state.set("bg:search_failed", true);

        let detector = ConsecutiveFailureDetector::new("search", 3);
        let t0 = Instant::now();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));

        detector.reset();

        assert!(!detector.check(&state, None, t0));
        assert!(!detector.check(&state, None, t0));
        assert!(detector.check(&state, None, t0));
    }

    #[test]
    /// `SustainedDetector` overrides `needs_timer` to return `true`.
    fn sustained_detector_needs_timer() {
        let detector = SustainedDetector::new(Arc::new(|_: &State| true), Duration::from_secs(5));
        assert!(detector.needs_timer());
    }

    #[test]
    /// `RateDetector` keeps the trait default `needs_timer() == false`.
    fn rate_detector_does_not_need_timer() {
        let detector = RateDetector::new(
            Arc::new(|_: &SessionEvent| true),
            5,
            Duration::from_secs(10),
        );
        assert!(!detector.needs_timer());
    }

    #[test]
    /// `TurnCountDetector` keeps the trait default `needs_timer() == false`.
    fn turn_count_detector_does_not_need_timer() {
        let detector = TurnCountDetector::new(Arc::new(|_: &State| true), 3);
        assert!(!detector.needs_timer());
    }

    #[tokio::test]
    /// Without a cooldown, a detector that matches on every check fires on every check.
    async fn pattern_without_cooldown_fires_every_time() {
        let counter = Arc::new(AtomicU32::new(0));
        let state = State::new();
        state.set("active", true);
        let writer = mock_writer();

        let pattern = TemporalPattern::new(
            "no-cooldown",
            Box::new(TurnCountDetector::new(
                Arc::new(|s: &State| s.get::<bool>("active").unwrap_or(false)),
                1,
            )),
            counting_action(counter.clone()),
            None,
        );

        let t0 = Instant::now();

        for i in 0..5u32 {
            let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_millis(i as u64));
            assert!(fut.is_some(), "should fire on iteration {i}");
            fut.unwrap().await;
        }

        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }
}