1use std::sync::Arc;
31
32use async_trait::async_trait;
33use parking_lot::Mutex;
34
35use gemini_genai_rs::prelude::{Content, FunctionResponse};
36use gemini_genai_rs::session::{SessionError, SessionWriter};
37
38pub struct PendingContext {
49 buffer: Mutex<Vec<Content>>,
50 prompt: Mutex<bool>,
52}
53
54impl PendingContext {
55 pub fn new() -> Self {
57 Self {
58 buffer: Mutex::new(Vec::new()),
59 prompt: Mutex::new(false),
60 }
61 }
62
63 pub fn push(&self, content: Content) {
65 self.buffer.lock().push(content);
66 }
67
68 pub fn extend(&self, contents: Vec<Content>) {
70 if !contents.is_empty() {
71 self.buffer.lock().extend(contents);
72 }
73 }
74
75 pub fn set_prompt(&self) {
77 *self.prompt.lock() = true;
78 }
79
80 pub fn drain(&self) -> (Vec<Content>, bool) {
84 let contents = self.drain_context();
85 let prompt = self.take_prompt();
86 (contents, prompt)
87 }
88
89 pub fn drain_context(&self) -> Vec<Content> {
91 let contents = {
92 let mut buf = self.buffer.lock();
93 std::mem::take(&mut *buf)
94 };
95 contents
96 }
97
98 pub fn take_prompt(&self) -> bool {
100 let mut p = self.prompt.lock();
101 std::mem::replace(&mut *p, false)
102 }
103
104 pub fn clear_prompt(&self) {
106 *self.prompt.lock() = false;
107 }
108
109 pub fn has_prompt(&self) -> bool {
111 *self.prompt.lock()
112 }
113
114 pub fn is_empty(&self) -> bool {
116 self.buffer.lock().is_empty() && !*self.prompt.lock()
117 }
118}
119
120impl Default for PendingContext {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
/// A `SessionWriter` wrapper that sends queued `PendingContext` turns to the
/// inner writer before forwarding audio, text, or video.
///
/// Tool responses, explicit client content, instruction updates, and activity
/// signals are forwarded without flushing the queue.
pub struct DeferredWriter {
    /// Writer that every call is ultimately forwarded to.
    inner: Arc<dyn SessionWriter>,
    /// Shared queue whose context turns are drained on flush; its prompt flag
    /// is not read or cleared by this writer.
    pending: Arc<PendingContext>,
}
154
155impl DeferredWriter {
156 pub fn new(inner: Arc<dyn SessionWriter>, pending: Arc<PendingContext>) -> Self {
158 Self { inner, pending }
159 }
160
161 async fn flush_context(&self) -> Result<(), SessionError> {
166 let contents = self.pending.drain_context();
167 if !contents.is_empty() {
168 self.inner.send_client_content(contents, false).await?;
169 }
170 Ok(())
171 }
172
173 pub fn pending(&self) -> &Arc<PendingContext> {
175 &self.pending
176 }
177}
178
179#[async_trait]
180impl SessionWriter for DeferredWriter {
181 async fn send_audio(&self, data: Vec<u8>) -> Result<(), SessionError> {
182 self.flush_context().await?;
183 self.inner.send_audio(data).await
184 }
185
186 async fn send_text(&self, text: String) -> Result<(), SessionError> {
187 self.flush_context().await?;
188 self.inner.send_text(text).await
189 }
190
191 async fn send_tool_response(
192 &self,
193 responses: Vec<FunctionResponse>,
194 ) -> Result<(), SessionError> {
195 self.inner.send_tool_response(responses).await
197 }
198
199 async fn send_client_content(
200 &self,
201 turns: Vec<Content>,
202 turn_complete: bool,
203 ) -> Result<(), SessionError> {
204 self.inner.send_client_content(turns, turn_complete).await
207 }
208
209 async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), SessionError> {
210 self.flush_context().await?;
211 self.inner.send_video(jpeg_data).await
212 }
213
214 async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
215 self.inner.update_instruction(instruction).await
217 }
218
219 async fn signal_activity_start(&self) -> Result<(), SessionError> {
220 self.inner.signal_activity_start().await
221 }
222
223 async fn signal_activity_end(&self) -> Result<(), SessionError> {
224 self.inner.signal_activity_end().await
225 }
226
227 async fn disconnect(&self) -> Result<(), SessionError> {
228 let _ = self.flush_context().await;
230 self.inner.disconnect().await
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Inner-writer double: counts calls per method and logs the order in
    /// which methods were invoked, so tests can check that queued context is
    /// sent *before* the call that triggered the flush.
    struct CountingWriter {
        audio_count: AtomicUsize,
        text_count: AtomicUsize,
        client_content_count: AtomicUsize,
        video_count: AtomicUsize,
        tool_response_count: AtomicUsize,
        disconnect_count: AtomicUsize,
        /// Method names in invocation order.
        calls: Mutex<Vec<&'static str>>,
    }

    impl CountingWriter {
        fn new() -> Self {
            Self {
                audio_count: AtomicUsize::new(0),
                text_count: AtomicUsize::new(0),
                client_content_count: AtomicUsize::new(0),
                video_count: AtomicUsize::new(0),
                tool_response_count: AtomicUsize::new(0),
                disconnect_count: AtomicUsize::new(0),
                calls: Mutex::new(Vec::new()),
            }
        }

        /// Bumps `counter` and appends `name` to the call log.
        fn record(&self, name: &'static str, counter: &AtomicUsize) {
            counter.fetch_add(1, Ordering::SeqCst);
            self.calls.lock().push(name);
        }

        /// Snapshot of the call log.
        fn call_log(&self) -> Vec<&'static str> {
            self.calls.lock().clone()
        }
    }

    #[async_trait]
    impl SessionWriter for CountingWriter {
        async fn send_audio(&self, _: Vec<u8>) -> Result<(), SessionError> {
            self.record("audio", &self.audio_count);
            Ok(())
        }
        async fn send_text(&self, _: String) -> Result<(), SessionError> {
            self.record("text", &self.text_count);
            Ok(())
        }
        async fn send_tool_response(&self, _: Vec<FunctionResponse>) -> Result<(), SessionError> {
            self.record("tool_response", &self.tool_response_count);
            Ok(())
        }
        async fn send_client_content(&self, _: Vec<Content>, _: bool) -> Result<(), SessionError> {
            self.record("client_content", &self.client_content_count);
            Ok(())
        }
        async fn send_video(&self, _: Vec<u8>) -> Result<(), SessionError> {
            self.record("video", &self.video_count);
            Ok(())
        }
        async fn update_instruction(&self, _: String) -> Result<(), SessionError> {
            Ok(())
        }
        async fn signal_activity_start(&self) -> Result<(), SessionError> {
            Ok(())
        }
        async fn signal_activity_end(&self) -> Result<(), SessionError> {
            Ok(())
        }
        async fn disconnect(&self) -> Result<(), SessionError> {
            self.record("disconnect", &self.disconnect_count);
            Ok(())
        }
    }

    #[test]
    fn pending_context_push_and_drain() {
        let pc = PendingContext::new();
        assert!(pc.is_empty());

        pc.push(Content::model("context 1"));
        pc.push(Content::model("context 2"));
        assert!(!pc.is_empty());

        let (contents, prompt) = pc.drain();
        assert_eq!(contents.len(), 2);
        assert!(!prompt);
        assert!(pc.is_empty());
    }

    #[test]
    fn pending_context_extend() {
        let pc = PendingContext::new();
        pc.extend(vec![
            Content::model("a"),
            Content::model("b"),
            Content::model("c"),
        ]);
        let (contents, _) = pc.drain();
        assert_eq!(contents.len(), 3);
    }

    #[test]
    fn pending_context_extend_empty_is_noop() {
        let pc = PendingContext::new();
        pc.extend(Vec::new());
        assert!(pc.is_empty());
    }

    #[test]
    fn pending_context_prompt_flag() {
        let pc = PendingContext::new();
        pc.push(Content::model("ctx"));
        pc.set_prompt();
        assert!(!pc.is_empty());

        let (contents, prompt) = pc.drain();
        assert_eq!(contents.len(), 1);
        assert!(prompt);
        assert!(pc.is_empty());
    }

    #[test]
    fn pending_context_clear_prompt_keeps_buffer() {
        let pc = PendingContext::new();
        pc.push(Content::model("ctx"));
        pc.set_prompt();
        assert!(pc.has_prompt());

        pc.clear_prompt();
        assert!(!pc.has_prompt());
        assert!(!pc.is_empty());
        assert_eq!(pc.drain_context().len(), 1);
        assert!(pc.is_empty());
    }

    #[test]
    fn pending_context_drain_clears() {
        let pc = PendingContext::new();
        pc.push(Content::model("a"));
        pc.set_prompt();
        let _ = pc.drain();

        let (contents, prompt) = pc.drain();
        assert!(contents.is_empty());
        assert!(!prompt);
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_audio() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("steering context"));
        pending.push(Content::model("phase instruction"));

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        // Both queued turns go out in one client-content call, before the audio.
        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.audio_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.call_log(), ["client_content", "audio"]);
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_text() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("context"));

        writer.send_text("hello".into()).await.unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.text_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.call_log(), ["client_content", "text"]);
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_send_video() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("context"));

        writer.send_video(vec![0xFFu8; 50]).await.unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.video_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.call_log(), ["client_content", "video"]);
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_no_flush_when_empty() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 0);
        assert_eq!(inner.audio_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.call_log(), ["audio"]);
    }

    #[tokio::test]
    async fn deferred_writer_keeps_prompt_pending_on_user_audio() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("repair nudge"));
        pending.set_prompt();

        writer.send_audio(vec![0u8; 100]).await.unwrap();

        // Context is flushed, but the prompt flag survives the flush.
        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.audio_count.load(Ordering::SeqCst), 1);
        assert!(!pending.is_empty());
        assert!(pending.take_prompt());
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_does_not_flush_on_tool_response() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("context"));

        writer.send_tool_response(vec![]).await.unwrap();

        // The response is forwarded, but the queue is left untouched.
        assert_eq!(inner.tool_response_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 0);
        assert!(!pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_client_content_passes_through() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("queued context"));

        writer
            .send_client_content(vec![Content::user("explicit")], true)
            .await
            .unwrap();

        assert_eq!(inner.client_content_count.load(Ordering::SeqCst), 1);
        assert!(!pending.is_empty());
    }

    #[tokio::test]
    async fn deferred_writer_flushes_on_disconnect() {
        let inner = Arc::new(CountingWriter::new());
        let pending = Arc::new(PendingContext::new());
        let writer = DeferredWriter::new(inner.clone(), pending.clone());

        pending.push(Content::model("context"));

        writer.disconnect().await.unwrap();

        assert_eq!(inner.disconnect_count.load(Ordering::SeqCst), 1);
        assert_eq!(inner.call_log(), ["client_content", "disconnect"]);
        assert!(pending.is_empty());
    }
}