gemini_adk_rs/live/
context_writer.rs

1//! Deferred context delivery — flush pending context alongside user content.
2//!
3//! When the control plane produces model-role context turns (tool advisory,
4//! repair nudge, steering modifiers, phase instructions, on_enter_context),
5//! they can be queued in a [`PendingContext`] buffer instead of sent immediately.
6//!
7//! [`DeferredWriter`] wraps any [`SessionWriter`] and transparently drains the
8//! pending queue before forwarding user-initiated sends (`send_audio`,
9//! `send_text`, `send_video`).  This ensures context arrives in the same burst
10//! as user content rather than as isolated WebSocket frames that can confuse
11//! the model or clash with concurrent user input.
12//!
13//! # Architecture
14//!
15//! ```text
16//!   Control lane (lifecycle)         User code (LiveHandle)
17//!          |                                |
18//!   push context to                  send_audio / send_text
19//!   PendingContext                          |
20//!          |                         DeferredWriter
21//!          v                          1. drain PendingContext
22//!   +---------------+                2. send_client_content(drained, false)
23//!   | PendingContext | <-- drain ---  3. forward original send
24//!   +---------------+
25//! ```
26//!
27//! The queue uses `parking_lot::Mutex` for fast, uncontested locking — the
28//! control lane pushes once per turn, and user sends drain before each frame.
29
30use std::sync::Arc;
31
32use async_trait::async_trait;
33use parking_lot::Mutex;
34
35use gemini_genai_rs::prelude::{Content, FunctionResponse};
36use gemini_genai_rs::session::{SessionError, SessionWriter};
37
/// Thread-safe buffer for pending context turns awaiting delivery.
///
/// Context is queued by the control plane (lifecycle steps 7d/7e/7f/12/13)
/// and drained by [`DeferredWriter`] before the next user interaction.
///
/// Besides the queued turns, the buffer carries a single "prompt" flag that
/// records whether a `turnComplete:true` prompt should follow a flush.
///
/// # Thread safety
///
/// Uses `parking_lot::Mutex` — fast uncontested locking, no poisoning.
/// The control lane pushes once per turn; user sends drain once per frame.
/// Contention is near-zero.
///
/// The queued turns and the prompt flag are guarded by two separate mutexes,
/// so each can be read or updated without touching the other.
pub struct PendingContext {
    /// Context turns queued for delivery, in the order they were pushed.
    buffer: Mutex<Vec<Content>>,
    /// Whether a prompt (turnComplete:true) should be sent after flushing.
    prompt: Mutex<bool>,
}
53
54impl PendingContext {
55    /// Create an empty pending context buffer.
56    pub fn new() -> Self {
57        Self {
58            buffer: Mutex::new(Vec::new()),
59            prompt: Mutex::new(false),
60        }
61    }
62
63    /// Push a single context turn into the buffer.
64    pub fn push(&self, content: Content) {
65        self.buffer.lock().push(content);
66    }
67
68    /// Push multiple context turns into the buffer.
69    pub fn extend(&self, contents: Vec<Content>) {
70        if !contents.is_empty() {
71            self.buffer.lock().extend(contents);
72        }
73    }
74
75    /// Mark that a prompt (turnComplete:true) should follow the next flush.
76    pub fn set_prompt(&self) {
77        *self.prompt.lock() = true;
78    }
79
80    /// Drain all pending context, returning the contents and whether to prompt.
81    ///
82    /// After this call, the buffer is empty and the prompt flag is cleared.
83    pub fn drain(&self) -> (Vec<Content>, bool) {
84        let contents = self.drain_context();
85        let prompt = self.take_prompt();
86        (contents, prompt)
87    }
88
89    /// Drain only context turns, leaving any pending prompt armed.
90    pub fn drain_context(&self) -> Vec<Content> {
91        let contents = {
92            let mut buf = self.buffer.lock();
93            std::mem::take(&mut *buf)
94        };
95        contents
96    }
97
98    /// Take and clear the pending prompt flag without touching queued context.
99    pub fn take_prompt(&self) -> bool {
100        let mut p = self.prompt.lock();
101        std::mem::replace(&mut *p, false)
102    }
103
104    /// Clear any armed prompt without touching queued context.
105    pub fn clear_prompt(&self) {
106        *self.prompt.lock() = false;
107    }
108
109    /// Return whether a prompt is currently armed.
110    pub fn has_prompt(&self) -> bool {
111        *self.prompt.lock()
112    }
113
114    /// Check if the buffer is empty (no pending context or prompt).
115    pub fn is_empty(&self) -> bool {
116        self.buffer.lock().is_empty() && !*self.prompt.lock()
117    }
118}
119
120impl Default for PendingContext {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
/// A [`SessionWriter`] wrapper that flushes pending context before user content.
///
/// Wraps an inner writer and drains a shared [`PendingContext`] buffer before
/// forwarding `send_audio`, `send_text`, or `send_video` calls.  This ensures
/// model-role context turns arrive in the same burst as user content.
///
/// Only queued context turns are flushed; a prompt armed on the
/// [`PendingContext`] is left untouched by this writer.
///
/// # When context is flushed
///
/// - **`send_audio`**: Context is flushed as `send_client_content(drained, false)`
///   immediately before the audio frame.  Audio goes via `realtimeInput` (different
///   wire message), so they are two frames — but sent back-to-back with no gap.
///
/// - **`send_text`**: Context is flushed, then user text is sent.  Both go via
///   `clientContent`, but as separate messages since the user text needs
///   `turn_complete: true` to trigger a model response.
///
/// - **`send_video`**: Same as audio — flush then forward.
///
/// - **`disconnect`**: Remaining context is flushed on a best-effort basis
///   before the inner writer disconnects.  A flush error is ignored so that it
///   cannot prevent the disconnect.
///
/// # When context is NOT flushed
///
/// `send_tool_response`, `update_instruction`, `send_client_content`, and
/// `signal_activity_start/end` do NOT trigger a flush.
/// These are either internal SDK operations or explicit user control — flushing
/// context before them would be surprising.
pub struct DeferredWriter {
    /// The wrapped writer that every call is ultimately forwarded to.
    inner: Arc<dyn SessionWriter>,
    /// Shared queue of context turns to flush ahead of user content.
    pending: Arc<PendingContext>,
}
154
155impl DeferredWriter {
156    /// Create a new deferred writer wrapping the given writer.
157    pub fn new(inner: Arc<dyn SessionWriter>, pending: Arc<PendingContext>) -> Self {
158        Self { inner, pending }
159    }
160
161    /// Flush any pending context to the wire without triggering a model prompt.
162    ///
163    /// User sends (audio/text/video) use this context-only flush so queued
164    /// phase prompts cannot make the model speak while the user is speaking.
165    async fn flush_context(&self) -> Result<(), SessionError> {
166        let contents = self.pending.drain_context();
167        if !contents.is_empty() {
168            self.inner.send_client_content(contents, false).await?;
169        }
170        Ok(())
171    }
172
173    /// Get a reference to the shared pending context buffer.
174    pub fn pending(&self) -> &Arc<PendingContext> {
175        &self.pending
176    }
177}
178
179#[async_trait]
180impl SessionWriter for DeferredWriter {
181    async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
182        self.flush_context().await?;
183        self.inner.send_audio(data).await
184    }
185
186    async fn send_text(&self, text: String) -> Result<(), SessionError> {
187        self.flush_context().await?;
188        self.inner.send_text(text).await
189    }
190
191    async fn send_tool_response(
192        &self,
193        responses: Vec<FunctionResponse>,
194    ) -> Result<(), SessionError> {
195        // Tool responses are SDK-internal — don't flush context here.
196        self.inner.send_tool_response(responses).await
197    }
198
199    async fn send_client_content(
200        &self,
201        turns: Vec<Content>,
202        turn_complete: bool,
203    ) -> Result<(), SessionError> {
204        // Explicit client content calls pass through unchanged.
205        // The caller knows what they're doing.
206        self.inner.send_client_content(turns, turn_complete).await
207    }
208
209    async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
210        self.flush_context().await?;
211        self.inner.send_video(jpeg_data).await
212    }
213
214    async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
215        // Instruction updates are SDK-internal — don't flush context here.
216        self.inner.update_instruction(instruction).await
217    }
218
219    async fn signal_activity_start(&self) -> Result<(), SessionError> {
220        self.inner.signal_activity_start().await
221    }
222
223    async fn signal_activity_end(&self) -> Result<(), SessionError> {
224        self.inner.signal_activity_end().await
225    }
226
227    async fn disconnect(&self) -> Result<(), SessionError> {
228        // Flush any remaining context before disconnecting so it's not lost.
229        let _ = self.flush_context().await;
230        self.inner.disconnect().await
231    }
232}
233
#[cfg(test)]
mod tests {
    //! Unit tests for [`PendingContext`] and [`DeferredWriter`].
    //!
    //! The writer tests use a stub [`SessionWriter`] that only counts calls,
    //! so they check *which* inner sends happen, not what is put on the wire.

    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal writer that counts calls by type.
    ///
    /// Every method succeeds; only audio, text, client-content and video
    /// sends are counted.
    struct CountingWriter {
        audio_count: AtomicUsize,
        text_count: AtomicUsize,
        client_content_count: AtomicUsize,
        video_count: AtomicUsize,
    }

    impl CountingWriter {
        fn new() -> Self {
            Self {
                audio_count: AtomicUsize::new(0),
                text_count: AtomicUsize::new(0),
                client_content_count: AtomicUsize::new(0),
                video_count: AtomicUsize::new(0),
            }
        }
    }

    #[async_trait]
    impl SessionWriter for CountingWriter {
        async fn send_audio(&self, _: Vec<u8>) -> Result<(), SessionError> {
            self.audio_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn send_text(&self, _: String) -> Result<(), SessionError> {
            self.text_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn send_tool_response(&self, _: Vec<FunctionResponse>) -> Result<(), SessionError> {
            Ok(())
        }
        async fn send_client_content(&self, _: Vec<Content>, _: bool) -> Result<(), SessionError> {
            self.client_content_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn send_video(&self, _: Vec<u8>) -> Result<(), SessionError> {
            self.video_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
        async fn update_instruction(&self, _: String) -> Result<(), SessionError> {
            Ok(())
        }
        async fn signal_activity_start(&self) -> Result<(), SessionError> {
            Ok(())
        }
        async fn signal_activity_end(&self) -> Result<(), SessionError> {
            Ok(())
        }
        async fn disconnect(&self) -> Result<(), SessionError> {
            Ok(())
        }
    }

    #[test]
    fn pending_context_push_and_drain() {
        let pc = PendingContext::new();
        assert!(pc.is_empty());

        pc.push(Content::model("context 1"));
        pc.push(Content::model("context 2"));
        assert!(!pc.is_empty());

        let (contents, prompt) = pc.drain();
        assert_eq!(contents.len(), 2);
        assert!(!prompt);
        assert!(pc.is_empty());
    }

    #[test]
    fn pending_context_extend() {
        let pc = PendingContext::new();
        pc.extend(vec![
            Content::model("a"),
            Content::model("b"),
            Content::model("c"),
        ]);
        let (contents, _) = pc.drain();
        assert_eq!(contents.len(), 3);
    }

    #[test]
    fn pending_context_prompt_flag() {
        let pc = PendingContext::new();
        pc.push(Content::model("ctx"));
        pc.set_prompt();
        assert!(!pc.is_empty());

        let (contents, prompt) = pc.drain();
        assert_eq!(contents.len(), 1);
        assert!(prompt);
        assert!(pc.is_empty());
    }

    #[test]
    fn pending_context_drain_clears() {
        let pc = PendingContext::new();
        pc.push(Content::model("a"));
        pc.set_prompt();
        let _ = pc.drain();

        // Second drain should be empty
        let (contents, prompt) = pc.drain();
        assert!(contents.is_empty());
        assert!(!prompt);
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_audio() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("steering context"));
        pending.push(Content::model("phase instruction"));

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        // Should have flushed: 1 client_content + 1 audio
        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.audio_count.load(Ordering::SeqCst), 1);
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_text() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("context"));

        writer.send_text("hello".into()).await.unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.text_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_video() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("context"));

        writer.send_video(vec![0xFFu8; 50]).await.unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.video_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn deferred_writer_no_flush_when_empty() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        // No pending context — should just send audio, no client_content
        writer.send_audio(vec![0u8; 100]).await.unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 0);
        assert_eq!(inner.audio_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn deferred_writer_keeps_prompt_pending_on_user_audio() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("repair nudge"));
        pending.set_prompt();

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        // User audio only flushes context. Prompt remains armed until an
        // explicit idle/playback-drained flush.
        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.audio_count.load(Ordering::SeqCst), 1);
        assert!(!pending.is_empty());
        assert!(pending.take_prompt());
    }

    #[tokio::test]
    async fn deferred_writer_does_not_flush_on_tool_response() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("context"));

        writer.send_tool_response(vec![]).await.unwrap();

        // Tool response should NOT flush — context still pending
        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 0);
        assert!(!pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_client_content_passes_through() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("queued context"));

        // Explicit client_content should pass through without flushing
        writer
            .send_client_content(vec![Content::user("explicit")], true)
            .await
            .unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        // Queued context still pending
        assert!(!pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_disconnect() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("late context"));

        writer.disconnect().await.unwrap();

        // Disconnect delivers whatever context is still queued before the
        // inner writer is disconnected.
        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert!(pending.is_empty());
    }
}
456}