gemini_genai_rs/vad/
mod.rs

//! Client-side Voice Activity Detection (VAD).
//!
//! Uses the WaveKat WebRTC backend when it is enabled and supports the configured
//! sample rate and frame duration, and otherwise falls back to the previous
//! dual-threshold energy/zero-crossing-rate detector. Complements Gemini's
//! server-side VAD:
//!
//! - **Bandwidth savings**: don't send silence over the network.
//! - **Latency reduction**: signal `activityStart` before the server detects it.
//! - **Barge-in pre-emption**: flush the jitter buffer locally before the server confirms.

/// VAD configuration parameters.
///
/// The energy thresholds, ZCR range and initial noise floor only influence
/// classification when no WaveKat decision is available for a frame; the
/// frame-count parameters apply to every backend.
#[derive(Debug, Clone)]
pub struct VadConfig {
    /// Sample rate in Hz (the WaveKat backend is only used for 8, 16, 32 or 48 kHz).
    pub sample_rate: u32,
    /// Frame duration in milliseconds (typically 10–30 ms; the WaveKat backend
    /// is only used for 10, 20 or 30 ms).
    pub frame_duration_ms: u32,
    /// Level above the noise floor, in dB (relative, not dBFS), that a frame
    /// must exceed — with its ZCR inside `speech_zcr_range` — to count as
    /// speech-like.
    pub start_threshold_db: f64,
    /// Level above the noise floor, in dB (relative, not dBFS), that a frame
    /// must exceed to keep speech active; quieter frames enter or extend the
    /// hangover period.
    pub stop_threshold_db: f64,
    /// Minimum number of consecutive speech-like frames before speech is confirmed.
    pub min_speech_frames: u32,
    /// Hangover length in frames: how long the detector stays in the speaking
    /// state after activity drops before declaring the end of speech.
    pub hangover_frames: u32,
    /// Inclusive zero-crossing-rate range `(low, high)` a frame must fall in to
    /// count as speech-like.
    pub speech_zcr_range: (f64, f64),
    /// Initial noise floor estimate in dBFS; also restored on reset.
    pub initial_noise_floor_db: f64,
    /// Number of recent frames retained while silent, so audio just before
    /// speech onset can be recovered.
    pub pre_speech_frames: usize,
}

34impl Default for VadConfig {
35    fn default() -> Self {
36        Self {
37            sample_rate: 16000,
38            frame_duration_ms: 30,
39            start_threshold_db: 15.0,
40            stop_threshold_db: 10.0,
41            min_speech_frames: 3,
42            hangover_frames: 10,
43            speech_zcr_range: (0.02, 0.5),
44            initial_noise_floor_db: -60.0,
45            pre_speech_frames: 3,
46        }
47    }
48}
49
50impl VadConfig {
51    /// Number of samples per frame.
52    pub fn frame_size(&self) -> usize {
53        (self.sample_rate * self.frame_duration_ms / 1000) as usize
54    }
55}
56
/// States of the voice activity detector's state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VadState {
    /// No speech detected.
    Silence,
    /// A speech-like frame was seen, but not yet for `min_speech_frames`
    /// consecutive frames.
    PendingSpeech,
    /// Speech confirmed.
    Speech,
    /// Activity dropped below the stop criterion; waiting out the hangover
    /// period before declaring the end of speech.
    Hangover,
}

/// State-change notifications produced while processing frames.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum VadEvent {
    /// The start of speech has been confirmed.
    SpeechStart,
    /// Speech has finished once the hangover period ran out.
    SpeechEnd,
}

79/// Voice Activity Detector with adaptive noise floor.
80pub struct VoiceActivityDetector {
81    config: VadConfig,
82    #[cfg(feature = "vad-wavekat")]
83    wavekat: Option<WaveKatWebRtcBackend>,
84    use_wavekat: bool,
85    state: VadState,
86    /// Adaptive noise floor estimate (dBFS).
87    noise_floor_db: f64,
88    /// Frames spent in current state.
89    state_frames: u32,
90    /// Number of frames used for noise adaptation.
91    noise_adapt_frames: u64,
92    /// Last backend speech probability/decision, normalized to 0.0..1.0.
93    last_probability: Option<f32>,
94    /// Circular buffer of pre-speech frames.
95    pre_speech_buf: Vec<Vec<i16>>,
96    pre_speech_idx: usize,
97}
98
99impl VoiceActivityDetector {
100    /// Create a new VAD with the given configuration.
101    pub fn new(config: VadConfig) -> Self {
102        Self::new_with_backend(config, true)
103    }
104
105    fn new_with_backend(config: VadConfig, prefer_wavekat: bool) -> Self {
106        let frame_size = config.frame_size();
107        let pre_speech_buf: Vec<Vec<i16>> = (0..config.pre_speech_frames)
108            .map(|_| vec![0i16; frame_size])
109            .collect();
110        #[cfg(feature = "vad-wavekat")]
111        let wavekat = prefer_wavekat
112            .then(|| WaveKatWebRtcBackend::new(&config))
113            .flatten();
114        #[cfg(feature = "vad-wavekat")]
115        let use_wavekat = wavekat.is_some();
116        #[cfg(not(feature = "vad-wavekat"))]
117        let use_wavekat = {
118            let _ = prefer_wavekat;
119            false
120        };
121        Self {
122            noise_floor_db: config.initial_noise_floor_db,
123            state: VadState::Silence,
124            state_frames: 0,
125            noise_adapt_frames: 0,
126            last_probability: None,
127            pre_speech_buf,
128            pre_speech_idx: 0,
129            #[cfg(feature = "vad-wavekat")]
130            wavekat,
131            use_wavekat,
132            config,
133        }
134    }
135
136    #[cfg(test)]
137    fn new_energy(config: VadConfig) -> Self {
138        Self::new_with_backend(config, false)
139    }
140
141    /// Current VAD state.
142    pub fn state(&self) -> VadState {
143        self.state
144    }
145
146    /// Whether speech is currently detected (Speech or Hangover state).
147    pub fn is_speaking(&self) -> bool {
148        matches!(self.state, VadState::Speech | VadState::Hangover)
149    }
150
151    /// Current noise floor estimate (dBFS).
152    pub fn noise_floor_db(&self) -> f64 {
153        self.noise_floor_db
154    }
155
156    /// Whether this detector is currently using the WaveKat backend.
157    pub fn is_wavekat_backed(&self) -> bool {
158        self.use_wavekat
159    }
160
161    /// Name of the active backend.
162    pub fn backend_name(&self) -> &'static str {
163        if self.use_wavekat {
164            "wavekat-webrtc"
165        } else {
166            "energy-zcr"
167        }
168    }
169
170    /// Last normalized speech probability or binary backend decision.
171    pub fn last_probability(&self) -> Option<f32> {
172        self.last_probability
173    }
174
175    /// Get pre-speech frames (the frames captured just before speech onset).
176    pub fn drain_pre_speech(&mut self) -> Vec<Vec<i16>> {
177        let frame_size = self.config.frame_size();
178        let mut fresh: Vec<Vec<i16>> = (0..self.config.pre_speech_frames)
179            .map(|_| vec![0i16; frame_size])
180            .collect();
181        std::mem::swap(&mut self.pre_speech_buf, &mut fresh);
182        self.pre_speech_idx = 0;
183        fresh
184    }
185
186    /// Process a single audio frame and return any state-change event.
187    pub fn process_frame(&mut self, samples: &[i16]) -> Option<VadEvent> {
188        let energy_db = compute_energy_db(samples);
189        let zcr = compute_zcr(samples);
190        let energy_above_noise = energy_db - self.noise_floor_db;
191
192        let wavekat_decision = self.wavekat_decision(samples);
193        let energy_speech_like = energy_above_noise > self.config.start_threshold_db
194            && zcr >= self.config.speech_zcr_range.0
195            && zcr <= self.config.speech_zcr_range.1;
196        let energy_above_stop = energy_above_noise > self.config.stop_threshold_db;
197        let is_speech_like = wavekat_decision.unwrap_or(energy_speech_like);
198        let is_above_stop = wavekat_decision.unwrap_or(energy_above_stop);
199        if wavekat_decision.is_none() {
200            self.last_probability = Some(if energy_speech_like { 1.0 } else { 0.0 });
201        }
202
203        match self.state {
204            VadState::Silence => {
205                // Update noise floor during confirmed silence
206                self.update_noise_floor(energy_db);
207
208                // Store pre-speech frame (copy into pre-allocated slot, zero-alloc)
209                if self.config.pre_speech_frames > 0 && !self.pre_speech_buf.is_empty() {
210                    let idx = self.pre_speech_idx % self.config.pre_speech_frames;
211                    let buf = &mut self.pre_speech_buf[idx];
212                    buf.resize(samples.len(), 0);
213                    buf.copy_from_slice(samples);
214                    self.pre_speech_idx += 1;
215                }
216
217                if is_speech_like {
218                    self.state = VadState::PendingSpeech;
219                    self.state_frames = 1;
220                }
221                None
222            }
223
224            VadState::PendingSpeech => {
225                if is_speech_like {
226                    self.state_frames += 1;
227                    if self.state_frames >= self.config.min_speech_frames {
228                        self.state = VadState::Speech;
229                        self.state_frames = 0;
230                        Some(VadEvent::SpeechStart)
231                    } else {
232                        None
233                    }
234                } else {
235                    // False alarm — go back to silence
236                    self.state = VadState::Silence;
237                    self.state_frames = 0;
238                    None
239                }
240            }
241
242            VadState::Speech => {
243                if !is_above_stop {
244                    self.state = VadState::Hangover;
245                    self.state_frames = 1;
246                }
247                None
248            }
249
250            VadState::Hangover => {
251                if is_above_stop {
252                    // Speech resumed — back to Speech
253                    self.state = VadState::Speech;
254                    self.state_frames = 0;
255                    None
256                } else {
257                    self.state_frames += 1;
258                    if self.state_frames >= self.config.hangover_frames {
259                        self.state = VadState::Silence;
260                        self.state_frames = 0;
261                        for buf in &mut self.pre_speech_buf {
262                            buf.iter_mut().for_each(|s| *s = 0);
263                        }
264                        self.pre_speech_idx = 0;
265                        Some(VadEvent::SpeechEnd)
266                    } else {
267                        None
268                    }
269                }
270            }
271        }
272    }
273
274    /// Update the adaptive noise floor using EWMA.
275    fn update_noise_floor(&mut self, energy_db: f64) {
276        self.noise_adapt_frames += 1;
277        // Alpha decreases over time: fast initial adaptation, slow drift
278        let alpha = 0.01_f64.min(1.0 / self.noise_adapt_frames as f64);
279        self.noise_floor_db = self.noise_floor_db * (1.0 - alpha) + energy_db * alpha;
280    }
281
282    fn wavekat_decision(&mut self, samples: &[i16]) -> Option<bool> {
283        #[cfg(feature = "vad-wavekat")]
284        {
285            let probability = self
286                .wavekat
287                .as_mut()
288                .and_then(|backend| backend.process(samples, self.config.sample_rate));
289            self.last_probability = probability;
290            probability.map(|probability| probability >= 0.5)
291        }
292        #[cfg(not(feature = "vad-wavekat"))]
293        {
294            let _ = samples;
295            None
296        }
297    }
298
299    /// Reset the VAD to its initial state.
300    pub fn reset(&mut self) {
301        self.state = VadState::Silence;
302        self.state_frames = 0;
303        self.noise_adapt_frames = 0;
304        self.last_probability = None;
305        self.noise_floor_db = self.config.initial_noise_floor_db;
306        for buf in &mut self.pre_speech_buf {
307            buf.iter_mut().for_each(|s| *s = 0);
308        }
309        self.pre_speech_idx = 0;
310    }
311}
312
/// Thin adapter over WaveKat's WebRTC VAD, used as the preferred backend when
/// the `vad-wavekat` feature is enabled.
#[cfg(feature = "vad-wavekat")]
struct WaveKatWebRtcBackend {
    /// Underlying WebRTC detector.
    detector: wavekat_vad::backends::webrtc::WebRtcVad,
}

#[cfg(feature = "vad-wavekat")]
impl WaveKatWebRtcBackend {
    /// Build an aggressive-mode WebRTC detector for `config`.
    ///
    /// Returns `None` unless the sample rate is 8, 16, 32 or 48 kHz and the
    /// frame duration is 10, 20 or 30 ms, or if the detector cannot be built.
    fn new(config: &VadConfig) -> Option<Self> {
        use wavekat_vad::backends::webrtc::{WebRtcVad, WebRtcVadMode};

        let rate_supported = matches!(config.sample_rate, 8000 | 16000 | 32000 | 48000);
        let frame_supported = matches!(config.frame_duration_ms, 10 | 20 | 30);
        if !(rate_supported && frame_supported) {
            return None;
        }

        WebRtcVad::with_frame_duration(
            config.sample_rate,
            WebRtcVadMode::Aggressive,
            config.frame_duration_ms,
        )
        .ok()
        .map(|detector| Self { detector })
    }

    /// Speech probability for one frame, or `None` if the backend reports an error.
    fn process(&mut self, samples: &[i16], sample_rate: u32) -> Option<f32> {
        // The trait must be in scope for the `process` method call below.
        use wavekat_vad::VoiceActivityDetector as _;

        let outcome = self.detector.process(samples, sample_rate);
        outcome.ok()
    }
}

/// Root-mean-square level of a PCM16 frame in dBFS (full scale = 32767),
/// floored at -96 dBFS. An empty frame reports the floor value.
fn compute_energy_db(samples: &[i16]) -> f64 {
    const FLOOR_DB: f64 = -96.0;
    const FULL_SCALE: f64 = 32767.0;

    if samples.is_empty() {
        return FLOOR_DB;
    }

    let mean_square = samples
        .iter()
        .map(|&s| {
            let x = s as f64;
            x * x
        })
        .sum::<f64>()
        / samples.len() as f64;
    let level_db = 20.0 * (mean_square.sqrt() / FULL_SCALE).log10();
    // An all-zero frame yields -inf here; clamp it to the floor.
    level_db.max(FLOOR_DB)
}

/// Zero-crossing rate of a PCM16 frame: the fraction of adjacent sample pairs
/// whose signs differ, with zero treated as non-negative. Frames shorter than
/// two samples report `0.0`.
fn compute_zcr(samples: &[i16]) -> f64 {
    let pairs = samples.len().saturating_sub(1);
    if pairs == 0 {
        return 0.0;
    }

    let mut crossings = 0usize;
    for pair in samples.windows(2) {
        if pair[0].is_negative() != pair[1].is_negative() {
            crossings += 1;
        }
    }
    crossings as f64 / pairs as f64
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Energy/ZCR detector with short timings so transitions are quick to drive.
    fn make_vad() -> VoiceActivityDetector {
        VoiceActivityDetector::new_energy(VadConfig {
            sample_rate: 16000,
            frame_duration_ms: 20,
            start_threshold_db: 15.0,
            stop_threshold_db: 10.0,
            min_speech_frames: 2,
            hangover_frames: 3,
            speech_zcr_range: (0.01, 0.9),
            initial_noise_floor_db: -60.0,
            pre_speech_frames: 2,
        })
    }

    fn silence_frame(len: usize) -> Vec<i16> {
        vec![0i16; len]
    }

    /// Square wave with a 4-sample period: loud enough to clear the start
    /// threshold and with a ZCR of about 0.5, inside the configured range.
    fn speech_frame(len: usize, amplitude: i16) -> Vec<i16> {
        (0..len)
            .map(|i| if i % 4 < 2 { amplitude } else { -amplitude })
            .collect()
    }

    #[test]
    fn starts_silent() {
        let vad = make_vad();
        assert_eq!(vad.state(), VadState::Silence);
        assert!(!vad.is_speaking());
    }

    #[cfg(feature = "vad-wavekat")]
    #[test]
    fn default_detector_uses_wavekat_for_supported_frames() {
        let vad = VoiceActivityDetector::new(VadConfig {
            sample_rate: 16000,
            frame_duration_ms: 20,
            ..VadConfig::default()
        });
        assert!(vad.is_wavekat_backed());
    }

    #[cfg(feature = "vad-wavekat")]
    #[test]
    fn unsupported_frames_fall_back_to_energy_detector() {
        let vad = VoiceActivityDetector::new(VadConfig {
            sample_rate: 16000,
            frame_duration_ms: 32,
            ..VadConfig::default()
        });
        assert!(!vad.is_wavekat_backed());
    }

    #[test]
    fn silence_stays_silent() {
        let mut vad = make_vad();
        let frame = silence_frame(320);
        for _ in 0..10 {
            let event = vad.process_frame(&frame);
            assert!(event.is_none());
        }
        assert_eq!(vad.state(), VadState::Silence);
    }

    #[test]
    fn speech_detected_after_min_frames() {
        let mut vad = make_vad();
        let frame = speech_frame(320, 10000);

        // Frame 1: PendingSpeech
        let e1 = vad.process_frame(&frame);
        assert!(e1.is_none());
        assert_eq!(vad.state(), VadState::PendingSpeech);

        // Frame 2: min_speech_frames = 2 → SpeechStart
        let e2 = vad.process_frame(&frame);
        assert_eq!(e2, Some(VadEvent::SpeechStart));
        assert_eq!(vad.state(), VadState::Speech);
        assert!(vad.is_speaking());
    }

    #[test]
    fn speech_end_after_hangover() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        let silence = silence_frame(320);

        // Trigger speech
        vad.process_frame(&speech);
        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::Speech);

        // Drop energy → hangover
        vad.process_frame(&silence);
        assert_eq!(vad.state(), VadState::Hangover);

        // Hangover frames 2 and 3
        vad.process_frame(&silence);
        let e = vad.process_frame(&silence);
        assert_eq!(e, Some(VadEvent::SpeechEnd));
        assert_eq!(vad.state(), VadState::Silence);
    }

    #[test]
    fn speech_resumes_during_hangover() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        let silence = silence_frame(320);

        // Trigger speech
        vad.process_frame(&speech);
        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::Speech);

        // Brief silence → hangover
        vad.process_frame(&silence);
        assert_eq!(vad.state(), VadState::Hangover);

        // Speech resumes
        let e = vad.process_frame(&speech);
        assert!(e.is_none()); // No event, just resumes
        assert_eq!(vad.state(), VadState::Speech);
    }

    #[test]
    fn false_alarm_returns_to_silence() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        let silence = silence_frame(320);

        // 1 speech frame → PendingSpeech
        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::PendingSpeech);

        // Then silence → back to Silence (false alarm)
        vad.process_frame(&silence);
        assert_eq!(vad.state(), VadState::Silence);
    }

    #[test]
    fn drain_pre_speech_returns_buffered_silence_frames() {
        let mut vad = make_vad();
        // Quiet DC frames: far below the start threshold and with zero ZCR.
        let first = vec![1i16; 320];
        let second = vec![2i16; 320];
        vad.process_frame(&first);
        vad.process_frame(&second);
        assert_eq!(vad.state(), VadState::Silence);

        assert_eq!(vad.drain_pre_speech(), vec![first, second]);
        // Draining leaves a zeroed buffer of the configured shape behind.
        assert_eq!(vad.drain_pre_speech(), vec![vec![0i16; 320]; 2]);
    }

    #[test]
    fn energy_db_calculation() {
        // Constant full-scale signal → 0 dBFS
        let full_scale: Vec<i16> = (0..320).map(|_| i16::MAX).collect();
        let db = compute_energy_db(&full_scale);
        assert!(db > -1.0);

        let silence = vec![0i16; 320];
        let db_silence = compute_energy_db(&silence);
        assert_eq!(db_silence, -96.0);
    }

    #[test]
    fn zcr_calculation() {
        // Alternating signal → high ZCR
        let alternating: Vec<i16> = (0..100)
            .map(|i| if i % 2 == 0 { 1000 } else { -1000 })
            .collect();
        let zcr = compute_zcr(&alternating);
        assert!(zcr > 0.9);

        // Constant signal → zero ZCR
        let constant = vec![1000i16; 100];
        let zcr_const = compute_zcr(&constant);
        assert_eq!(zcr_const, 0.0);
    }

    #[test]
    fn noise_floor_adapts() {
        let mut vad = make_vad();
        // 10-LSB frames sit around -70 dBFS, below the -60 dBFS initial floor.
        let low_noise: Vec<i16> = vec![10; 320];
        let target_db = compute_energy_db(&low_noise);
        for _ in 0..100 {
            vad.process_frame(&low_noise);
        }
        // The floor must have moved down from -60 dBFS toward the frames'
        // level, without overshooting it.
        let floor = vad.noise_floor_db();
        assert!(floor < -60.0, "floor did not adapt: {floor}");
        assert!(floor > target_db, "floor overshot: {floor} <= {target_db}");
    }

    #[test]
    fn reset_clears_state() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        vad.process_frame(&speech);
        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::Speech);

        vad.reset();
        assert_eq!(vad.state(), VadState::Silence);
        assert_eq!(vad.noise_floor_db(), -60.0);
    }
}