// gemini_genai_rs/flow/turn_detection.rs
use std::time::{Duration, Instant};
7
/// Tuning parameters for the turn detector.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly, e.g. to detect
/// whether a reloaded configuration differs from the active one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TurnDetectionConfig {
    /// Silence, in milliseconds, that must elapse after speech stops before the turn is
    /// reported as ended.
    pub end_of_speech_delay_ms: u64,
    /// When `false`, the detector ignores its input and never emits events.
    pub enabled: bool,
}
17
18impl Default for TurnDetectionConfig {
19 fn default() -> Self {
20 Self {
21 end_of_speech_delay_ms: 300,
22 enabled: true,
23 }
24 }
25}
26
/// Events reported by the turn detector's `update` method.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TurnDetectionEvent {
    /// The VAD input switched from not speaking to speaking.
    SpeechStarted,
    /// Speech stopped and the configured end-of-speech delay elapsed without it resuming.
    TurnEnded,
}
35
36pub struct TurnDetector {
38 config: TurnDetectionConfig,
39 is_speaking: bool,
41 speech_ended_at: Option<Instant>,
43 turn_ended_emitted: bool,
45}
46
47impl TurnDetector {
48 pub fn new(config: TurnDetectionConfig) -> Self {
50 Self {
51 config,
52 is_speaking: false,
53 speech_ended_at: None,
54 turn_ended_emitted: false,
55 }
56 }
57
58 pub fn update(&mut self, vad_is_speaking: bool) -> Option<TurnDetectionEvent> {
62 if !self.config.enabled {
63 return None;
64 }
65
66 if vad_is_speaking && !self.is_speaking {
67 self.is_speaking = true;
69 self.speech_ended_at = None;
70 self.turn_ended_emitted = false;
71 return Some(TurnDetectionEvent::SpeechStarted);
72 }
73
74 if !vad_is_speaking && self.is_speaking {
75 self.is_speaking = false;
77 self.speech_ended_at = Some(Instant::now());
78 }
79
80 if let Some(ended_at) = self.speech_ended_at {
82 if !self.turn_ended_emitted
83 && ended_at.elapsed() >= Duration::from_millis(self.config.end_of_speech_delay_ms)
84 {
85 self.turn_ended_emitted = true;
86 self.speech_ended_at = None;
87 return Some(TurnDetectionEvent::TurnEnded);
88 }
89 }
90
91 None
92 }
93
94 pub fn is_speaking(&self) -> bool {
96 self.is_speaking
97 }
98
99 pub fn is_pending_turn_end(&self) -> bool {
101 self.speech_ended_at.is_some() && !self.turn_ended_emitted
102 }
103
104 pub fn reset(&mut self) {
106 self.is_speaking = false;
107 self.speech_ended_at = None;
108 self.turn_ended_emitted = false;
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// The first speaking update reports `SpeechStarted` and flips the speaking flag.
    #[test]
    fn speech_start_detected() {
        let mut detector = TurnDetector::new(TurnDetectionConfig::default());

        let event = detector.update(true);
        assert_eq!(event, Some(TurnDetectionEvent::SpeechStarted));
        assert!(detector.is_speaking());
    }

    /// `TurnEnded` is reported once the configured silence has elapsed.
    #[test]
    fn turn_end_after_delay() {
        let mut detector = TurnDetector::new(TurnDetectionConfig {
            end_of_speech_delay_ms: 50,
            enabled: true,
        });

        detector.update(true);

        detector.update(false);
        assert!(detector.is_pending_turn_end());

        // Normally the delay has not elapsed yet. If the test thread was stalled for
        // longer than the delay, the event is emitted here instead. It is emitted only
        // once, so that case ends the test rather than waiting for a second event that
        // can never arrive.
        let early = detector.update(false);
        if early == Some(TurnDetectionEvent::TurnEnded) {
            assert!(!detector.is_pending_turn_end());
            return;
        }
        assert_eq!(early, None);

        thread::sleep(Duration::from_millis(60));
        let event = detector.update(false);
        assert_eq!(event, Some(TurnDetectionEvent::TurnEnded));
        assert!(!detector.is_pending_turn_end());
    }

    /// With a zero delay the turn ends on the call that observes silence, and only once.
    #[test]
    fn turn_end_emitted_only_once() {
        let mut detector = TurnDetector::new(TurnDetectionConfig {
            end_of_speech_delay_ms: 0,
            enabled: true,
        });

        assert_eq!(
            detector.update(true),
            Some(TurnDetectionEvent::SpeechStarted)
        );
        assert_eq!(detector.update(false), Some(TurnDetectionEvent::TurnEnded));
        assert_eq!(detector.update(false), None);
        assert!(!detector.is_pending_turn_end());
    }

    /// Speech resuming during the silence window cancels the pending turn end.
    #[test]
    fn speech_resume_cancels_turn_end() {
        let mut detector = TurnDetector::new(TurnDetectionConfig {
            end_of_speech_delay_ms: 200,
            enabled: true,
        });

        detector.update(true);
        detector.update(false);
        assert!(detector.is_pending_turn_end());

        let event = detector.update(true);
        assert_eq!(event, Some(TurnDetectionEvent::SpeechStarted));
        assert!(!detector.is_pending_turn_end());
    }

    /// `reset` drops a pending turn end and clears the speaking flag.
    #[test]
    fn reset_clears_pending_turn_end() {
        let mut detector = TurnDetector::new(TurnDetectionConfig {
            end_of_speech_delay_ms: 10_000,
            enabled: true,
        });

        detector.update(true);
        detector.update(false);
        assert!(detector.is_pending_turn_end());

        detector.reset();
        assert!(!detector.is_speaking());
        assert!(!detector.is_pending_turn_end());
        assert_eq!(detector.update(false), None);
    }

    /// A disabled detector ignores every update.
    #[test]
    fn disabled_detector_emits_nothing() {
        let mut detector = TurnDetector::new(TurnDetectionConfig {
            enabled: false,
            ..Default::default()
        });

        assert!(detector.update(true).is_none());
        assert!(detector.update(false).is_none());
    }
}