/// Tuning parameters for [`VoiceActivityDetector`].
#[derive(Debug, Clone)]
pub struct VadConfig {
    /// Input sample rate in Hz (also passed to the WaveKat backend when it is used).
    pub sample_rate: u32,
    /// Length of one analysis frame in milliseconds; see [`VadConfig::frame_size`].
    pub frame_duration_ms: u32,
    /// Energy above the noise floor (dB) a frame needs to count as speech-like.
    /// Only consulted by the energy/ZCR path.
    pub start_threshold_db: f64,
    /// Energy above the noise floor (dB) a frame needs to keep speech going.
    /// Only consulted by the energy/ZCR path.
    pub stop_threshold_db: f64,
    /// Number of consecutive speech-like frames that must be seen before `SpeechStart`.
    pub min_speech_frames: u32,
    /// Number of consecutive frames below the stop criterion that must be seen before `SpeechEnd`.
    pub hangover_frames: u32,
    /// Inclusive `(min, max)` zero-crossing rate (crossings per adjacent sample pair)
    /// for a frame to count as speech-like on the energy/ZCR path.
    pub speech_zcr_range: (f64, f64),
    /// Noise-floor estimate (dB) the detector starts from and returns to on `reset()`.
    pub initial_noise_floor_db: f64,
    /// Number of recent frames kept in the pre-speech ring buffer while in `Silence`.
    pub pre_speech_frames: usize,
}
33
34impl Default for VadConfig {
35 fn default() -> Self {
36 Self {
37 sample_rate: 16000,
38 frame_duration_ms: 30,
39 start_threshold_db: 15.0,
40 stop_threshold_db: 10.0,
41 min_speech_frames: 3,
42 hangover_frames: 10,
43 speech_zcr_range: (0.02, 0.5),
44 initial_noise_floor_db: -60.0,
45 pre_speech_frames: 3,
46 }
47 }
48}
49
50impl VadConfig {
51 pub fn frame_size(&self) -> usize {
53 (self.sample_rate * self.frame_duration_ms / 1000) as usize
54 }
55}
56
/// State of the detector's speech/silence state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VadState {
    /// No speech. The noise floor adapts, and frames are recorded in the pre-speech buffer.
    Silence,
    /// Speech-like frames have been seen, but not yet enough to confirm speech onset.
    PendingSpeech,
    /// Speech confirmed (`SpeechStart` has been emitted).
    Speech,
    /// Frames fell below the stop criterion; waiting out the hangover before `SpeechEnd`.
    Hangover,
}
69
/// Transition events returned by [`VoiceActivityDetector::process_frame`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VadEvent {
    /// Speech onset has been confirmed.
    SpeechStart,
    /// Speech has ended after the hangover period elapsed.
    SpeechEnd,
}
78
/// Frame-by-frame voice activity detector.
///
/// Uses the WaveKat WebRTC backend when it is enabled and usable for the configuration.
/// Otherwise it falls back to an energy/zero-crossing-rate detector.
pub struct VoiceActivityDetector {
    /// Configuration the detector was built with.
    config: VadConfig,
    /// WebRTC backend; `None` when it could not be constructed for `config`.
    #[cfg(feature = "vad-wavekat")]
    wavekat: Option<WaveKatWebRtcBackend>,
    /// Whether the WaveKat backend was selected at construction.
    use_wavekat: bool,
    /// Current state-machine state.
    state: VadState,
    /// Adaptive noise-floor estimate in dB.
    noise_floor_db: f64,
    /// Frames counted in the current `PendingSpeech` / `Hangover` state.
    state_frames: u32,
    /// Number of noise-floor updates since construction or the last reset.
    noise_adapt_frames: u64,
    /// Speech probability reported for the most recent frame.
    last_probability: Option<f32>,
    /// Ring buffer of recent frames seen while in `Silence`.
    pre_speech_buf: Vec<Vec<i16>>,
    /// Tracks the next write position in `pre_speech_buf`.
    pre_speech_idx: usize,
}
98
99impl VoiceActivityDetector {
100 pub fn new(config: VadConfig) -> Self {
102 Self::new_with_backend(config, true)
103 }
104
105 fn new_with_backend(config: VadConfig, prefer_wavekat: bool) -> Self {
106 let frame_size = config.frame_size();
107 let pre_speech_buf: Vec<Vec<i16>> = (0..config.pre_speech_frames)
108 .map(|_| vec![0i16; frame_size])
109 .collect();
110 #[cfg(feature = "vad-wavekat")]
111 let wavekat = prefer_wavekat
112 .then(|| WaveKatWebRtcBackend::new(&config))
113 .flatten();
114 #[cfg(feature = "vad-wavekat")]
115 let use_wavekat = wavekat.is_some();
116 #[cfg(not(feature = "vad-wavekat"))]
117 let use_wavekat = {
118 let _ = prefer_wavekat;
119 false
120 };
121 Self {
122 noise_floor_db: config.initial_noise_floor_db,
123 state: VadState::Silence,
124 state_frames: 0,
125 noise_adapt_frames: 0,
126 last_probability: None,
127 pre_speech_buf,
128 pre_speech_idx: 0,
129 #[cfg(feature = "vad-wavekat")]
130 wavekat,
131 use_wavekat,
132 config,
133 }
134 }
135
136 #[cfg(test)]
137 fn new_energy(config: VadConfig) -> Self {
138 Self::new_with_backend(config, false)
139 }
140
141 pub fn state(&self) -> VadState {
143 self.state
144 }
145
146 pub fn is_speaking(&self) -> bool {
148 matches!(self.state, VadState::Speech | VadState::Hangover)
149 }
150
151 pub fn noise_floor_db(&self) -> f64 {
153 self.noise_floor_db
154 }
155
156 pub fn is_wavekat_backed(&self) -> bool {
158 self.use_wavekat
159 }
160
161 pub fn backend_name(&self) -> &'static str {
163 if self.use_wavekat {
164 "wavekat-webrtc"
165 } else {
166 "energy-zcr"
167 }
168 }
169
170 pub fn last_probability(&self) -> Option<f32> {
172 self.last_probability
173 }
174
175 pub fn drain_pre_speech(&mut self) -> Vec<Vec<i16>> {
177 let frame_size = self.config.frame_size();
178 let mut fresh: Vec<Vec<i16>> = (0..self.config.pre_speech_frames)
179 .map(|_| vec![0i16; frame_size])
180 .collect();
181 std::mem::swap(&mut self.pre_speech_buf, &mut fresh);
182 self.pre_speech_idx = 0;
183 fresh
184 }
185
186 pub fn process_frame(&mut self, samples: &[i16]) -> Option<VadEvent> {
188 let energy_db = compute_energy_db(samples);
189 let zcr = compute_zcr(samples);
190 let energy_above_noise = energy_db - self.noise_floor_db;
191
192 let wavekat_decision = self.wavekat_decision(samples);
193 let energy_speech_like = energy_above_noise > self.config.start_threshold_db
194 && zcr >= self.config.speech_zcr_range.0
195 && zcr <= self.config.speech_zcr_range.1;
196 let energy_above_stop = energy_above_noise > self.config.stop_threshold_db;
197 let is_speech_like = wavekat_decision.unwrap_or(energy_speech_like);
198 let is_above_stop = wavekat_decision.unwrap_or(energy_above_stop);
199 if wavekat_decision.is_none() {
200 self.last_probability = Some(if energy_speech_like { 1.0 } else { 0.0 });
201 }
202
203 match self.state {
204 VadState::Silence => {
205 self.update_noise_floor(energy_db);
207
208 if self.config.pre_speech_frames > 0 && !self.pre_speech_buf.is_empty() {
210 let idx = self.pre_speech_idx % self.config.pre_speech_frames;
211 let buf = &mut self.pre_speech_buf[idx];
212 buf.resize(samples.len(), 0);
213 buf.copy_from_slice(samples);
214 self.pre_speech_idx += 1;
215 }
216
217 if is_speech_like {
218 self.state = VadState::PendingSpeech;
219 self.state_frames = 1;
220 }
221 None
222 }
223
224 VadState::PendingSpeech => {
225 if is_speech_like {
226 self.state_frames += 1;
227 if self.state_frames >= self.config.min_speech_frames {
228 self.state = VadState::Speech;
229 self.state_frames = 0;
230 Some(VadEvent::SpeechStart)
231 } else {
232 None
233 }
234 } else {
235 self.state = VadState::Silence;
237 self.state_frames = 0;
238 None
239 }
240 }
241
242 VadState::Speech => {
243 if !is_above_stop {
244 self.state = VadState::Hangover;
245 self.state_frames = 1;
246 }
247 None
248 }
249
250 VadState::Hangover => {
251 if is_above_stop {
252 self.state = VadState::Speech;
254 self.state_frames = 0;
255 None
256 } else {
257 self.state_frames += 1;
258 if self.state_frames >= self.config.hangover_frames {
259 self.state = VadState::Silence;
260 self.state_frames = 0;
261 for buf in &mut self.pre_speech_buf {
262 buf.iter_mut().for_each(|s| *s = 0);
263 }
264 self.pre_speech_idx = 0;
265 Some(VadEvent::SpeechEnd)
266 } else {
267 None
268 }
269 }
270 }
271 }
272 }
273
274 fn update_noise_floor(&mut self, energy_db: f64) {
276 self.noise_adapt_frames += 1;
277 let alpha = 0.01_f64.min(1.0 / self.noise_adapt_frames as f64);
279 self.noise_floor_db = self.noise_floor_db * (1.0 - alpha) + energy_db * alpha;
280 }
281
282 fn wavekat_decision(&mut self, samples: &[i16]) -> Option<bool> {
283 #[cfg(feature = "vad-wavekat")]
284 {
285 let probability = self
286 .wavekat
287 .as_mut()
288 .and_then(|backend| backend.process(samples, self.config.sample_rate));
289 self.last_probability = probability;
290 probability.map(|probability| probability >= 0.5)
291 }
292 #[cfg(not(feature = "vad-wavekat"))]
293 {
294 let _ = samples;
295 None
296 }
297 }
298
299 pub fn reset(&mut self) {
301 self.state = VadState::Silence;
302 self.state_frames = 0;
303 self.noise_adapt_frames = 0;
304 self.last_probability = None;
305 self.noise_floor_db = self.config.initial_noise_floor_db;
306 for buf in &mut self.pre_speech_buf {
307 buf.iter_mut().for_each(|s| *s = 0);
308 }
309 self.pre_speech_idx = 0;
310 }
311}
312
/// Thin wrapper around the WaveKat WebRTC detector, used when the `vad-wavekat` feature is on.
#[cfg(feature = "vad-wavekat")]
struct WaveKatWebRtcBackend {
    /// Underlying WebRTC detector, created in `Aggressive` mode.
    detector: wavekat_vad::backends::webrtc::WebRtcVad,
}
317
#[cfg(feature = "vad-wavekat")]
impl WaveKatWebRtcBackend {
    /// Builds a WebRTC-backed detector for `config`.
    ///
    /// Returns `None` when the configuration is outside what this wrapper accepts
    /// (8/16/32/48 kHz sample rate, 10/20/30 ms frames) or when construction fails.
    fn new(config: &VadConfig) -> Option<Self> {
        use wavekat_vad::backends::webrtc::{WebRtcVad, WebRtcVadMode};

        let rate_supported = matches!(config.sample_rate, 8000 | 16000 | 32000 | 48000);
        let frame_supported = matches!(config.frame_duration_ms, 10 | 20 | 30);
        if !(rate_supported && frame_supported) {
            return None;
        }

        WebRtcVad::with_frame_duration(
            config.sample_rate,
            WebRtcVadMode::Aggressive,
            config.frame_duration_ms,
        )
        .ok()
        .map(|detector| Self { detector })
    }

    /// Scores one frame; any backend error is reported as `None` rather than propagated.
    fn process(&mut self, samples: &[i16], sample_rate: u32) -> Option<f32> {
        use wavekat_vad::VoiceActivityDetector as _;

        match self.detector.process(samples, sample_rate) {
            Ok(probability) => Some(probability),
            Err(_) => None,
        }
    }
}
343
/// RMS level of `samples` in dB relative to i16 full scale (32767), floored at -96 dB.
///
/// An empty slice or digital silence yields exactly -96.0.
fn compute_energy_db(samples: &[i16]) -> f64 {
    const FLOOR_DB: f64 = -96.0;
    const FULL_SCALE: f64 = 32767.0;

    if samples.is_empty() {
        return FLOOR_DB;
    }

    let sum_of_squares = samples.iter().fold(0.0_f64, |acc, &sample| {
        let value = f64::from(sample);
        acc + value * value
    });
    let mean_square = sum_of_squares / samples.len() as f64;
    let level_db = 20.0 * (mean_square.sqrt() / FULL_SCALE).log10();
    // log10(0) is -inf for all-zero input; clamp it to the floor.
    level_db.max(FLOOR_DB)
}
355
/// Zero-crossing rate: the fraction of adjacent sample pairs whose signs differ.
///
/// Zero counts as non-negative. Slices with fewer than two samples have no pairs and
/// yield 0.0.
fn compute_zcr(samples: &[i16]) -> f64 {
    let pair_count = samples.len().saturating_sub(1);
    if pair_count == 0 {
        return 0.0;
    }

    let non_negative = |sample: i16| sample >= 0;
    let crossings = samples
        .iter()
        .zip(&samples[1..])
        .filter(|&(&prev, &next)| non_negative(prev) != non_negative(next))
        .count();

    crossings as f64 / pair_count as f64
}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Energy-backed detector with short debounce windows so tests stay small.
    fn make_vad() -> VoiceActivityDetector {
        VoiceActivityDetector::new_energy(VadConfig {
            sample_rate: 16000,
            frame_duration_ms: 20,
            start_threshold_db: 15.0,
            stop_threshold_db: 10.0,
            min_speech_frames: 2,
            hangover_frames: 3,
            speech_zcr_range: (0.01, 0.9),
            initial_noise_floor_db: -60.0,
            pre_speech_frames: 2,
        })
    }

    fn silence_frame(len: usize) -> Vec<i16> {
        vec![0i16; len]
    }

    /// Square-ish wave (two samples up, two down) with a ZCR of about 0.5.
    fn speech_frame(len: usize, amplitude: i16) -> Vec<i16> {
        (0..len)
            .map(|i| if i % 4 < 2 { amplitude } else { -amplitude })
            .collect()
    }

    #[test]
    fn starts_silent() {
        let vad = make_vad();
        assert_eq!(vad.state(), VadState::Silence);
        assert!(!vad.is_speaking());
    }

    #[cfg(feature = "vad-wavekat")]
    #[test]
    fn default_detector_uses_wavekat_for_supported_frames() {
        let vad = VoiceActivityDetector::new(VadConfig {
            sample_rate: 16000,
            frame_duration_ms: 20,
            ..VadConfig::default()
        });
        assert!(vad.is_wavekat_backed());
    }

    #[cfg(feature = "vad-wavekat")]
    #[test]
    fn unsupported_frames_fall_back_to_energy_detector() {
        let vad = VoiceActivityDetector::new(VadConfig {
            sample_rate: 16000,
            frame_duration_ms: 32,
            ..VadConfig::default()
        });
        assert!(!vad.is_wavekat_backed());
    }

    #[test]
    fn silence_stays_silent() {
        let mut vad = make_vad();
        let frame = silence_frame(320);
        for _ in 0..10 {
            let event = vad.process_frame(&frame);
            assert!(event.is_none());
        }
        assert_eq!(vad.state(), VadState::Silence);
    }

    #[test]
    fn speech_detected_after_min_frames() {
        let mut vad = make_vad();
        let frame = speech_frame(320, 10000);

        let e1 = vad.process_frame(&frame);
        assert!(e1.is_none());
        assert_eq!(vad.state(), VadState::PendingSpeech);

        let e2 = vad.process_frame(&frame);
        assert_eq!(e2, Some(VadEvent::SpeechStart));
        assert_eq!(vad.state(), VadState::Speech);
        assert!(vad.is_speaking());
    }

    #[test]
    fn speech_end_after_hangover() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        let silence = silence_frame(320);

        vad.process_frame(&speech);
        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::Speech);

        vad.process_frame(&silence);
        assert_eq!(vad.state(), VadState::Hangover);

        vad.process_frame(&silence);
        let e = vad.process_frame(&silence);
        assert_eq!(e, Some(VadEvent::SpeechEnd));
        assert_eq!(vad.state(), VadState::Silence);
    }

    #[test]
    fn speech_resumes_during_hangover() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        let silence = silence_frame(320);

        vad.process_frame(&speech);
        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::Speech);

        vad.process_frame(&silence);
        assert_eq!(vad.state(), VadState::Hangover);

        let e = vad.process_frame(&speech);
        assert!(e.is_none());
        assert_eq!(vad.state(), VadState::Speech);
    }

    #[test]
    fn false_alarm_returns_to_silence() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        let silence = silence_frame(320);

        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::PendingSpeech);

        vad.process_frame(&silence);
        assert_eq!(vad.state(), VadState::Silence);
    }

    #[test]
    fn energy_db_calculation() {
        let full_scale: Vec<i16> = (0..320).map(|_| i16::MAX).collect();
        let db = compute_energy_db(&full_scale);
        assert!(db > -1.0);

        let silence = vec![0i16; 320];
        let db_silence = compute_energy_db(&silence);
        assert_eq!(db_silence, -96.0);
    }

    #[test]
    fn zcr_calculation() {
        let alternating: Vec<i16> = (0..100)
            .map(|i| if i % 2 == 0 { 1000 } else { -1000 })
            .collect();
        let zcr = compute_zcr(&alternating);
        assert!(zcr > 0.9);

        let constant = vec![1000i16; 100];
        let zcr_const = compute_zcr(&constant);
        assert_eq!(zcr_const, 0.0);
    }

    #[test]
    fn noise_floor_adapts() {
        let mut vad = make_vad();
        // Constant low-level signal: about -70.3 dBFS with no zero crossings, so never
        // speech-like and the detector stays in Silence the whole time.
        let low_noise: Vec<i16> = vec![10; 320];
        for _ in 0..100 {
            vad.process_frame(&low_noise);
        }
        let floor = vad.noise_floor_db();
        // Starting from -60 dB, the estimate must move toward the signal level without
        // overshooting it. The old `> -96.0` check was already true before any frame.
        assert!(floor < -65.0, "noise floor did not adapt: {}", floor);
        assert!(floor > -70.4, "noise floor overshot the signal level: {}", floor);
        assert_eq!(vad.state(), VadState::Silence);
    }

    #[test]
    fn drain_pre_speech_returns_ring_and_resets_it() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        // Recorded in the pre-speech ring while still in Silence.
        vad.process_frame(&speech);

        let drained = vad.drain_pre_speech();
        assert_eq!(drained.len(), 2);
        assert!(drained.iter().all(|f| f.len() == 320));
        assert!(drained.iter().any(|f| *f == speech));

        // A second drain only sees the fresh, silent ring.
        let again = vad.drain_pre_speech();
        assert_eq!(again.len(), 2);
        assert!(again.iter().all(|f| f.iter().all(|&s| s == 0)));
    }

    #[test]
    fn reset_clears_state() {
        let mut vad = make_vad();
        let speech = speech_frame(320, 10000);
        vad.process_frame(&speech);
        vad.process_frame(&speech);
        assert_eq!(vad.state(), VadState::Speech);

        vad.reset();
        assert_eq!(vad.state(), VadState::Silence);
        assert!(!vad.is_speaking());
        assert_eq!(vad.noise_floor_db(), -60.0);
        assert_eq!(vad.last_probability(), None);
    }
}