gemini_genai_rs/buffer/
jitter.rs

1//! Adaptive jitter buffer for smooth playback of network audio.
2//!
3//! Network audio arrives in variable-size bursts. The jitter buffer
4//! smooths playback by accumulating a configurable minimum depth
5//! before starting playback, and adjusting depth dynamically based
6//! on measured inter-arrival jitter (EWMA, similar to TCP RTT estimation).
7
8use std::collections::VecDeque;
9use std::time::Instant;
10
/// Tunable parameters for an audio jitter buffer.
///
/// All depths are expressed in samples, not bytes or milliseconds.
#[derive(Debug, Clone)]
pub struct JitterConfig {
    /// Audio sample rate in Hz (e.g., 24000 for Gemini output).
    pub sample_rate: u32,
    /// Fewest samples that must be buffered before playback begins.
    pub min_depth_samples: usize,
    /// Most samples the buffer may hold; on overflow the oldest are dropped.
    pub max_depth_samples: usize,
    /// Weight given to each new measurement in the jitter EWMA (0.0–1.0).
    /// Smaller values smooth more; larger values react faster.
    pub jitter_alpha: f64,
    /// Factor applied to the jitter estimate when deriving the adaptive
    /// minimum depth.
    pub target_jitter_multiple: f64,
}
26
27impl Default for JitterConfig {
28    fn default() -> Self {
29        Self {
30            sample_rate: 24000,
31            min_depth_samples: 24000 / 5, // 200ms at 24kHz
32            max_depth_samples: 24000 * 2, // 2 seconds
33            jitter_alpha: 0.125,          // RFC 6298 default
34            target_jitter_multiple: 2.0,
35        }
36    }
37}
38
39impl JitterConfig {
40    /// Create a config for a given sample rate with sensible defaults.
41    pub fn for_sample_rate(sample_rate: u32) -> Self {
42        Self {
43            sample_rate,
44            min_depth_samples: sample_rate as usize / 5,
45            max_depth_samples: sample_rate as usize * 2,
46            ..Default::default()
47        }
48    }
49}
50
/// Playback phase of the jitter buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BufferState {
    /// Still building up the initial depth; playback has not started.
    Filling,
    /// Steady state: buffered samples are handed out for playback.
    Playing,
    /// The buffer ran dry during playback; output is padded with silence
    /// until enough audio has been buffered again.
    Underrun,
}
61
62/// Adaptive jitter buffer for audio playback.
63pub struct AudioJitterBuffer {
64    config: JitterConfig,
65    queue: VecDeque<i16>,
66    state: BufferState,
67    /// Smoothed jitter estimate in microseconds.
68    jitter_estimate_us: f64,
69    /// Timestamp of last push.
70    last_arrival: Option<Instant>,
71    /// Total underrun events.
72    underrun_count: u64,
73}
74
75impl AudioJitterBuffer {
76    /// Create a new jitter buffer with the given configuration.
77    pub fn new(config: JitterConfig) -> Self {
78        let initial_capacity = config.max_depth_samples;
79        Self {
80            config,
81            queue: VecDeque::with_capacity(initial_capacity),
82            state: BufferState::Filling,
83            jitter_estimate_us: 0.0,
84            last_arrival: None,
85            underrun_count: 0,
86        }
87    }
88
89    /// Current buffer state.
90    pub fn state(&self) -> BufferState {
91        self.state
92    }
93
94    /// Number of samples currently buffered.
95    pub fn depth(&self) -> usize {
96        self.queue.len()
97    }
98
99    /// Depth in milliseconds.
100    pub fn depth_ms(&self) -> f64 {
101        self.queue.len() as f64 / self.config.sample_rate as f64 * 1000.0
102    }
103
104    /// Total underrun events since creation.
105    pub fn underrun_count(&self) -> u64 {
106        self.underrun_count
107    }
108
109    /// Current smoothed jitter estimate in microseconds.
110    pub fn jitter_estimate_us(&self) -> f64 {
111        self.jitter_estimate_us
112    }
113
114    /// Compute the adaptive minimum depth based on measured jitter.
115    fn adaptive_min_depth(&self) -> usize {
116        let jitter_samples = (self.jitter_estimate_us / 1_000_000.0
117            * self.config.sample_rate as f64
118            * self.config.target_jitter_multiple) as usize;
119        jitter_samples.max(self.config.min_depth_samples)
120    }
121
122    /// Push audio samples into the buffer (called when network data arrives).
123    pub fn push(&mut self, samples: &[i16]) {
124        // Update jitter estimate
125        let now = Instant::now();
126        if let Some(last) = self.last_arrival {
127            let interval_us = now.duration_since(last).as_micros() as f64;
128            // EWMA jitter update (RFC 6298 style)
129            let deviation = (interval_us - self.jitter_estimate_us).abs();
130            self.jitter_estimate_us = self.jitter_estimate_us * (1.0 - self.config.jitter_alpha)
131                + deviation * self.config.jitter_alpha;
132        }
133        self.last_arrival = Some(now);
134
135        // Enforce max depth — drop oldest if overflow
136        let total_after = self.queue.len() + samples.len();
137        if total_after > self.config.max_depth_samples {
138            let to_drop = total_after - self.config.max_depth_samples;
139            self.queue.drain(..to_drop.min(self.queue.len()));
140        }
141
142        self.queue.extend(samples.iter());
143
144        // State transitions
145        if (self.state == BufferState::Filling || self.state == BufferState::Underrun)
146            && self.queue.len() >= self.adaptive_min_depth()
147        {
148            self.state = BufferState::Playing;
149        }
150    }
151
152    /// Pull audio samples for playback.
153    ///
154    /// Fills `out` with audio data. If the buffer underruns, fills remaining
155    /// slots with silence (zero) for click-free output.
156    ///
157    /// Returns the number of real (non-silence) samples written.
158    pub fn pull(&mut self, out: &mut [i16]) -> usize {
159        match self.state {
160            BufferState::Filling => {
161                // Not ready yet — fill with silence
162                out.fill(0);
163                0
164            }
165            BufferState::Playing | BufferState::Underrun => {
166                let available = self.queue.len().min(out.len());
167                for (i, sample) in self.queue.drain(..available).enumerate() {
168                    out[i] = sample;
169                }
170
171                // Fill remainder with silence if underrun
172                if available < out.len() {
173                    out[available..].fill(0);
174                    if self.state == BufferState::Playing {
175                        self.state = BufferState::Underrun;
176                        self.underrun_count += 1;
177                    }
178                } else if self.state == BufferState::Underrun
179                    && self.queue.len() >= self.adaptive_min_depth()
180                {
181                    self.state = BufferState::Playing;
182                }
183
184                available
185            }
186        }
187    }
188
189    /// Flush the buffer immediately (used for barge-in).
190    ///
191    /// Drops all buffered audio and resets to the Filling state.
192    /// This produces instant silence when the user starts speaking.
193    pub fn flush(&mut self) {
194        self.queue.clear();
195        self.state = BufferState::Filling;
196        self.last_arrival = None;
197    }
198
199    /// Reset the buffer completely, including jitter estimates.
200    pub fn reset(&mut self) {
201        self.flush();
202        self.jitter_estimate_us = 0.0;
203        self.underrun_count = 0;
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-zero fill for output buffers, so an all-zero result proves that
    /// `pull` actually wrote silence rather than leaving the buffer untouched.
    const SENTINEL: i16 = 7;

    /// 16 kHz buffer with a 100 ms minimum depth and a 1 s maximum depth.
    fn make_buffer() -> AudioJitterBuffer {
        AudioJitterBuffer::new(JitterConfig {
            sample_rate: 16000,
            min_depth_samples: 1600, // 100ms
            max_depth_samples: 16000,
            jitter_alpha: 0.125,
            target_jitter_multiple: 2.0,
        })
    }

    #[test]
    fn starts_in_filling_state() {
        let buf = make_buffer();
        assert_eq!(buf.state(), BufferState::Filling);
        assert_eq!(buf.depth(), 0);
    }

    #[test]
    fn filling_produces_silence() {
        let mut buf = make_buffer();
        buf.push(&[42i16; 800]); // < min_depth
        assert_eq!(buf.state(), BufferState::Filling);

        let mut out = [SENTINEL; 160];
        let real = buf.pull(&mut out);
        assert_eq!(real, 0);
        assert!(out.iter().all(|&s| s == 0));
        // Nothing is consumed while the buffer is still filling.
        assert_eq!(buf.depth(), 800);
    }

    #[test]
    fn transitions_to_playing() {
        let mut buf = make_buffer();
        buf.push(&[100i16; 1600]); // = min_depth

        assert_eq!(buf.state(), BufferState::Playing);

        let mut out = [SENTINEL; 160];
        let real = buf.pull(&mut out);
        assert_eq!(real, 160);
        assert!(out.iter().all(|&s| s == 100));
        assert_eq!(buf.depth(), 1600 - 160);
    }

    #[test]
    fn underrun_fills_silence() {
        let mut buf = make_buffer();
        buf.push(&[99i16; 1600]);
        assert_eq!(buf.state(), BufferState::Playing);

        // Drain the whole buffer; an exact fit is not an underrun.
        let mut out = [SENTINEL; 1600];
        assert_eq!(buf.pull(&mut out), 1600);
        assert_eq!(buf.state(), BufferState::Playing);
        assert_eq!(buf.underrun_count(), 0);

        // Now try to pull more — underrun
        let mut out2 = [SENTINEL; 160];
        let real = buf.pull(&mut out2);
        assert_eq!(real, 0);
        assert!(out2.iter().all(|&s| s == 0));
        assert_eq!(buf.state(), BufferState::Underrun);
        assert_eq!(buf.underrun_count(), 1);
    }

    #[test]
    fn partial_underrun_pads_with_silence() {
        let mut buf = make_buffer();
        buf.push(&[99i16; 1600]);

        let mut out = [SENTINEL; 2000];
        assert_eq!(buf.pull(&mut out), 1600);
        assert!(out[..1600].iter().all(|&s| s == 99));
        assert!(out[1600..].iter().all(|&s| s == 0));
        assert_eq!(buf.state(), BufferState::Underrun);
        assert_eq!(buf.underrun_count(), 1);
    }

    #[test]
    fn flush_clears_and_resets() {
        let mut buf = make_buffer();
        buf.push(&[42i16; 3200]);
        assert_eq!(buf.state(), BufferState::Playing);

        buf.flush();
        assert_eq!(buf.state(), BufferState::Filling);
        assert_eq!(buf.depth(), 0);
    }

    #[test]
    fn reset_clears_statistics() {
        let mut buf = make_buffer();
        buf.push(&[1i16; 1600]);
        let mut out = [SENTINEL; 2000];
        buf.pull(&mut out);
        assert_eq!(buf.underrun_count(), 1);

        buf.reset();
        assert_eq!(buf.state(), BufferState::Filling);
        assert_eq!(buf.depth(), 0);
        assert_eq!(buf.underrun_count(), 0);
        assert_eq!(buf.jitter_estimate_us(), 0.0);
    }

    #[test]
    fn overflow_drops_oldest() {
        let mut buf = AudioJitterBuffer::new(JitterConfig {
            sample_rate: 16000,
            min_depth_samples: 100,
            max_depth_samples: 500,
            ..Default::default()
        });

        buf.push(&[1i16; 400]);
        buf.push(&[2i16; 200]); // total 600 > max 500 → drop 100 oldest

        assert_eq!(buf.depth(), 500);

        // 100 of the oldest samples (1s) were dropped: 300 ones remain,
        // followed by all 200 twos.
        let mut out = [SENTINEL; 500];
        assert_eq!(buf.pull(&mut out), 500);
        assert!(out[..300].iter().all(|&s| s == 1));
        assert!(out[300..].iter().all(|&s| s == 2));
    }

    #[test]
    fn depth_ms_calculation() {
        let mut buf = make_buffer();
        buf.push(&[0i16; 1600]); // 100ms at 16kHz
        assert!((buf.depth_ms() - 100.0).abs() < 0.01);
    }
}