gemini_adk_rs/live/
effect_executor.rs

1//! Execute typed Live reactor effects against a session writer.
2
3use std::sync::Arc;
4
5use gemini_genai_rs::prelude::SessionPhase;
6use gemini_genai_rs::session::{SessionError, SessionWriter};
7use tokio::sync::broadcast;
8
9use super::context_writer::PendingContext;
10use super::events::LiveEvent;
11use super::reactor::{EffectMode, LiveEffect, Reaction};
12
13/// Executes [`LiveEffect`] values emitted by the Live reactor.
14#[derive(Clone)]
15pub struct LiveEffectExecutor {
16    writer: Arc<dyn SessionWriter>,
17    pending_context: Option<Arc<PendingContext>>,
18    event_tx: broadcast::Sender<LiveEvent>,
19}
20
21impl LiveEffectExecutor {
22    /// Create an executor backed by a session writer.
23    pub fn new(
24        writer: Arc<dyn SessionWriter>,
25        pending_context: Option<Arc<PendingContext>>,
26        event_tx: broadcast::Sender<LiveEvent>,
27    ) -> Self {
28        Self {
29            writer,
30            pending_context,
31            event_tx,
32        }
33    }
34
35    /// Execute a list of policy-wrapped reactions.
36    pub async fn execute_reactions(&self, reactions: Vec<Reaction>) -> Result<(), SessionError> {
37        for reaction in reactions {
38            match reaction.policy.mode {
39                EffectMode::Blocking => {
40                    let executor = self.clone();
41                    let fut = executor.execute(reaction.effect);
42                    if let Some(timeout) = reaction.policy.timeout {
43                        tokio::time::timeout(timeout, fut).await.map_err(|_| {
44                            SessionError::Timeout {
45                                phase: SessionPhase::Active,
46                                elapsed: timeout,
47                            }
48                        })??;
49                    } else {
50                        fut.await?;
51                    }
52                }
53                EffectMode::Concurrent => {
54                    let executor = self.clone();
55                    let timeout = reaction.policy.timeout;
56                    let source = reaction.source;
57                    let effect = reaction.effect;
58                    tokio::spawn(async move {
59                        let result = match timeout {
60                            Some(timeout) => {
61                                tokio::time::timeout(timeout, executor.execute(effect))
62                                    .await
63                                    .unwrap_or(Err(SessionError::Timeout {
64                                        phase: SessionPhase::Active,
65                                        elapsed: timeout,
66                                    }))
67                            }
68                            None => executor.execute(effect).await,
69                        };
70                        // Supervise: surface concurrent failures rather than
71                        // silently dropping them.
72                        if let Err(err) = result {
73                            let _ = executor.event_tx.send(LiveEvent::Error(format!(
74                                "reaction '{source}' failed: {err}"
75                            )));
76                        }
77                    });
78                }
79            }
80        }
81        Ok(())
82    }
83
84    /// Execute one typed effect.
85    pub async fn execute(&self, effect: LiveEffect) -> Result<(), SessionError> {
86        match effect {
87            LiveEffect::Noop => Ok(()),
88            LiveEffect::SendContext(contents) => {
89                if !contents.is_empty() {
90                    self.writer.send_client_content(contents, false).await?;
91                }
92                Ok(())
93            }
94            LiveEffect::PromptModel => self.flush_deferred_prompt().await,
95            LiveEffect::CancelDeferredPrompt => {
96                if let Some(pending) = &self.pending_context {
97                    pending.clear_prompt();
98                }
99                Ok(())
100            }
101            LiveEffect::SignalUserActivityStart => self.writer.signal_activity_start().await,
102            LiveEffect::SignalUserActivityEnd => self.writer.signal_activity_end().await,
103            LiveEffect::UpdateInstruction(instruction) => {
104                self.writer.update_instruction(instruction).await
105            }
106            LiveEffect::Emit(event) => {
107                let _ = self.event_tx.send(event);
108                Ok(())
109            }
110        }
111    }
112
113    /// Flush deferred context and an armed prompt.
114    ///
115    /// This is intentionally gated by [`PendingContext::take_prompt`], so a
116    /// playback-drained event cannot trigger a new empty model turn unless the
117    /// control plane explicitly armed one.
118    pub async fn flush_deferred_prompt(&self) -> Result<(), SessionError> {
119        let Some(pending) = &self.pending_context else {
120            return Ok(());
121        };
122
123        let contents = pending.drain_context();
124        if !contents.is_empty() {
125            self.writer.send_client_content(contents, false).await?;
126        }
127        if pending.take_prompt() {
128            self.writer.send_client_content(vec![], true).await?;
129        }
130        Ok(())
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use gemini_genai_rs::prelude::{Content, FunctionResponse};
    use parking_lot::Mutex;

    /// One recorded call on [`MockWriter`]; only calls the executor can issue
    /// are tracked.
    #[derive(Debug, Clone, PartialEq, Eq)]
    enum Write {
        ClientContent { turns: usize, turn_complete: bool },
        Instruction(String),
        ActivityStart,
        ActivityEnd,
    }

    /// Writer that records content, instruction and activity calls in order
    /// and accepts every other call silently.
    #[derive(Default)]
    struct MockWriter {
        writes: Mutex<Vec<Write>>,
    }

    #[async_trait]
    impl SessionWriter for MockWriter {
        async fn send_audio(&self, _data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_text(&self, _text: String) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_tool_response(
            &self,
            _responses: Vec<FunctionResponse>,
        ) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_client_content(
            &self,
            turns: Vec<Content>,
            turn_complete: bool,
        ) -> Result<(), SessionError> {
            self.writes.lock().push(Write::ClientContent {
                turns: turns.len(),
                turn_complete,
            });
            Ok(())
        }

        async fn send_video(&self, _jpeg_data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }

        async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
            self.writes.lock().push(Write::Instruction(instruction));
            Ok(())
        }

        async fn signal_activity_start(&self) -> Result<(), SessionError> {
            self.writes.lock().push(Write::ActivityStart);
            Ok(())
        }

        async fn signal_activity_end(&self) -> Result<(), SessionError> {
            self.writes.lock().push(Write::ActivityEnd);
            Ok(())
        }

        async fn disconnect(&self) -> Result<(), SessionError> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn prompt_model_flushes_context_then_armed_prompt() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("phase context"));
        pending.set_prompt();
        let (event_tx, _) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(writer.clone(), Some(pending.clone()), event_tx);

        executor.execute(LiveEffect::PromptModel).await.unwrap();

        assert_eq!(
            writer.writes.lock().as_slice(),
            &[
                Write::ClientContent {
                    turns: 1,
                    turn_complete: false
                },
                Write::ClientContent {
                    turns: 0,
                    turn_complete: true
                }
            ]
        );
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn prompt_model_without_armed_prompt_only_flushes_context() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("phase context"));
        let (event_tx, _) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(writer.clone(), Some(pending), event_tx);

        executor.execute(LiveEffect::PromptModel).await.unwrap();

        assert_eq!(
            writer.writes.lock().as_slice(),
            &[Write::ClientContent {
                turns: 1,
                turn_complete: false
            }]
        );
    }

    #[tokio::test]
    async fn update_instruction_uses_writer() {
        let writer = Arc::new(MockWriter::default());
        let (event_tx, _) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(writer.clone(), None, event_tx);

        executor
            .execute(LiveEffect::UpdateInstruction("new instruction".into()))
            .await
            .unwrap();

        assert_eq!(
            writer.writes.lock().as_slice(),
            &[Write::Instruction("new instruction".into())]
        );
    }

    #[tokio::test]
    async fn cancel_deferred_prompt_keeps_context() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("still useful with user audio"));
        pending.set_prompt();
        let (event_tx, _) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(writer, Some(pending.clone()), event_tx);

        executor
            .execute(LiveEffect::CancelDeferredPrompt)
            .await
            .unwrap();

        assert!(!pending.has_prompt());
        assert_eq!(pending.drain_context().len(), 1);
    }

    #[tokio::test]
    async fn user_activity_effects_signal_writer() {
        let writer = Arc::new(MockWriter::default());
        let (event_tx, _) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(writer.clone(), None, event_tx);

        executor
            .execute_reactions(vec![
                Reaction::blocking("test", LiveEffect::SignalUserActivityStart),
                Reaction::blocking("test", LiveEffect::SignalUserActivityEnd),
            ])
            .await
            .unwrap();

        assert_eq!(
            writer.writes.lock().as_slice(),
            &[Write::ActivityStart, Write::ActivityEnd]
        );
    }

    #[tokio::test]
    async fn empty_context_and_missing_pending_buffer_do_not_write() {
        let writer = Arc::new(MockWriter::default());
        let (event_tx, _) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(writer.clone(), None, event_tx);

        // None of these may reach the writer: empty payloads are skipped and
        // prompt handling is a no-op without a pending-context buffer.
        executor
            .execute(LiveEffect::SendContext(vec![]))
            .await
            .unwrap();
        executor.execute(LiveEffect::Noop).await.unwrap();
        executor.execute(LiveEffect::PromptModel).await.unwrap();
        executor
            .execute(LiveEffect::CancelDeferredPrompt)
            .await
            .unwrap();

        assert!(writer.writes.lock().is_empty());
    }

    #[tokio::test]
    async fn emit_effect_broadcasts_event() {
        let writer = Arc::new(MockWriter::default());
        let (event_tx, mut rx) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(writer.clone(), None, event_tx);

        executor
            .execute(LiveEffect::Emit(LiveEvent::Error("boom".into())))
            .await
            .unwrap();

        // `Emit` sends synchronously, so the event is already queued.
        let event = rx.try_recv().expect("emitted event is queued");
        assert!(
            matches!(&event, LiveEvent::Error(msg) if msg.as_str() == "boom"),
            "expected the emitted error event, got {event:?}"
        );
        assert!(writer.writes.lock().is_empty());
    }

    #[tokio::test]
    async fn concurrent_effect_failure_is_surfaced_as_event() {
        /// Writer whose client-content sends always fail.
        struct FailWriter;
        #[async_trait]
        impl SessionWriter for FailWriter {
            async fn send_audio(&self, _: Vec<u8>) -> Result<(), SessionError> {
                Ok(())
            }
            async fn send_text(&self, _: String) -> Result<(), SessionError> {
                Ok(())
            }
            async fn send_tool_response(
                &self,
                _: Vec<FunctionResponse>,
            ) -> Result<(), SessionError> {
                Ok(())
            }
            async fn send_client_content(
                &self,
                _: Vec<Content>,
                _: bool,
            ) -> Result<(), SessionError> {
                Err(SessionError::NotConnected)
            }
            async fn send_video(&self, _: Vec<u8>) -> Result<(), SessionError> {
                Ok(())
            }
            async fn update_instruction(&self, _: String) -> Result<(), SessionError> {
                Ok(())
            }
            async fn signal_activity_start(&self) -> Result<(), SessionError> {
                Ok(())
            }
            async fn signal_activity_end(&self) -> Result<(), SessionError> {
                Ok(())
            }
            async fn disconnect(&self) -> Result<(), SessionError> {
                Ok(())
            }
        }

        let (event_tx, mut rx) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(Arc::new(FailWriter), None, event_tx);

        // A concurrent effect that fails must surface as a LiveEvent, not vanish.
        executor
            .execute_reactions(vec![Reaction::concurrent(
                "test",
                LiveEffect::SendContext(vec![Content::model("x")]),
            )])
            .await
            .unwrap();

        let event = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("a reaction-failure event within the timeout")
            .expect("event received");
        assert!(
            matches!(&event, LiveEvent::Error(msg) if msg.contains("reaction 'test' failed")),
            "expected a reaction-failure error event, got {event:?}"
        );
    }
}