gemini_genai_rs/transport/
flow.rs

//! Flow control — token bucket rate limiter for send pacing.
//!
//! Prevents overwhelming the WebSocket with audio faster than
//! the network or server can absorb.
5
6use std::time::{Duration, Instant};
7
/// Configuration for the token bucket rate limiter.
///
/// Both fields are measured in bytes: the bucket holds at most
/// `bucket_capacity` bytes of send credit and regains credit at
/// `refill_rate_bps` bytes per second.
#[derive(Debug, Clone)]
pub struct FlowConfig {
    /// Maximum tokens (bytes) in the bucket, i.e. the largest burst that can
    /// be sent without waiting.
    pub bucket_capacity: usize,
    /// Refill rate in bytes per second (the `bps` suffix means bytes, not bits).
    pub refill_rate_bps: usize,
}

impl Default for FlowConfig {
    /// Defaults sized for 16 kHz × 16-bit PCM audio (256 kbit/s).
    fn default() -> Self {
        Self {
            // 2 s of audio at the default refill rate (64_000 / 32_000).
            bucket_capacity: 64_000,
            // 16_000 samples/s × 2 bytes per sample = 32_000 bytes/s.
            refill_rate_bps: 32_000,
        }
    }
}
26
27/// Token bucket rate limiter for send pacing.
28///
29/// Allows bursts up to `bucket_capacity` bytes, then throttles
30/// to `refill_rate_bps` sustained rate.
31pub struct TokenBucket {
32    config: FlowConfig,
33    /// Current token count.
34    tokens: f64,
35    /// Last refill timestamp.
36    last_refill: Instant,
37}
38
39impl TokenBucket {
40    /// Create a new token bucket with the given configuration.
41    pub fn new(config: FlowConfig) -> Self {
42        let tokens = config.bucket_capacity as f64;
43        Self {
44            config,
45            tokens,
46            last_refill: Instant::now(),
47        }
48    }
49
50    /// Try to consume `n` tokens. Returns `true` if allowed, `false` if rate-limited.
51    pub fn try_consume(&mut self, n: usize) -> bool {
52        self.refill();
53        if self.tokens >= n as f64 {
54            self.tokens -= n as f64;
55            true
56        } else {
57            false
58        }
59    }
60
61    /// How long to wait before `n` tokens are available.
62    pub fn wait_duration(&mut self, n: usize) -> Duration {
63        self.refill();
64        if self.tokens >= n as f64 {
65            Duration::ZERO
66        } else {
67            let deficit = n as f64 - self.tokens;
68            let secs = deficit / self.config.refill_rate_bps as f64;
69            Duration::from_secs_f64(secs)
70        }
71    }
72
73    /// Consume tokens, waiting if necessary.
74    pub async fn consume(&mut self, n: usize) {
75        let wait = self.wait_duration(n);
76        if !wait.is_zero() {
77            tokio::time::sleep(wait).await;
78        }
79        self.refill();
80        self.tokens -= n as f64;
81        // Tokens can go negative after a long wait; that's fine,
82        // subsequent calls will wait proportionally.
83    }
84
85    /// Current available tokens.
86    pub fn available(&mut self) -> usize {
87        self.refill();
88        self.tokens.max(0.0) as usize
89    }
90
91    /// Refill tokens based on elapsed time.
92    fn refill(&mut self) {
93        let now = Instant::now();
94        let elapsed = now.duration_since(self.last_refill);
95        self.last_refill = now;
96
97        let added = elapsed.as_secs_f64() * self.config.refill_rate_bps as f64;
98        self.tokens = (self.tokens + added).min(self.config.bucket_capacity as f64);
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a bucket with the given capacity and refill rate (both in bytes).
    fn make_bucket(capacity: usize, rate: usize) -> TokenBucket {
        TokenBucket::new(FlowConfig {
            bucket_capacity: capacity,
            refill_rate_bps: rate,
        })
    }

    #[test]
    fn initial_burst() {
        let mut tb = make_bucket(1000, 100);

        // A fresh bucket is full, so the whole capacity goes through at once...
        assert!(tb.try_consume(1000));
        // ...after which not even a single byte is left.
        assert!(!tb.try_consume(1));
    }

    #[test]
    fn refill_over_time() {
        let mut tb = make_bucket(1000, 1000);

        // Empty the bucket completely.
        tb.try_consume(1000);
        assert!(!tb.try_consume(1));

        // Pretend one second has passed by back-dating the last refill;
        // at 1000 bytes/s that restores roughly a full bucket.
        let one_second_ago = Instant::now() - Duration::from_secs(1);
        tb.last_refill = one_second_ago;
        assert!(tb.try_consume(500));
    }

    #[test]
    fn wait_duration_calculation() {
        let mut tb = make_bucket(1000, 100);

        // Empty the bucket completely.
        tb.try_consume(1000);

        // 100 tokens at 100 tokens/s should take about one second.
        let wait = tb.wait_duration(100);
        let expected = Duration::from_millis(900)..=Duration::from_millis(1100);
        assert!(expected.contains(&wait));
    }
}