1use std::sync::Arc;
4
5use gemini_genai_rs::prelude::SessionPhase;
6use gemini_genai_rs::session::{SessionError, SessionWriter};
7use tokio::sync::broadcast;
8
9use super::context_writer::PendingContext;
10use super::events::LiveEvent;
11use super::reactor::{EffectMode, LiveEffect, Reaction};
12
13#[derive(Clone)]
15pub struct LiveEffectExecutor {
16 writer: Arc<dyn SessionWriter>,
17 pending_context: Option<Arc<PendingContext>>,
18 event_tx: broadcast::Sender<LiveEvent>,
19}
20
21impl LiveEffectExecutor {
22 pub fn new(
24 writer: Arc<dyn SessionWriter>,
25 pending_context: Option<Arc<PendingContext>>,
26 event_tx: broadcast::Sender<LiveEvent>,
27 ) -> Self {
28 Self {
29 writer,
30 pending_context,
31 event_tx,
32 }
33 }
34
35 pub async fn execute_reactions(&self, reactions: Vec<Reaction>) -> Result<(), SessionError> {
37 for reaction in reactions {
38 match reaction.policy.mode {
39 EffectMode::Blocking => {
40 let executor = self.clone();
41 let fut = executor.execute(reaction.effect);
42 if let Some(timeout) = reaction.policy.timeout {
43 tokio::time::timeout(timeout, fut).await.map_err(|_| {
44 SessionError::Timeout {
45 phase: SessionPhase::Active,
46 elapsed: timeout,
47 }
48 })??;
49 } else {
50 fut.await?;
51 }
52 }
53 EffectMode::Concurrent => {
54 let executor = self.clone();
55 let timeout = reaction.policy.timeout;
56 let source = reaction.source;
57 let effect = reaction.effect;
58 tokio::spawn(async move {
59 let result = match timeout {
60 Some(timeout) => {
61 tokio::time::timeout(timeout, executor.execute(effect))
62 .await
63 .unwrap_or(Err(SessionError::Timeout {
64 phase: SessionPhase::Active,
65 elapsed: timeout,
66 }))
67 }
68 None => executor.execute(effect).await,
69 };
70 if let Err(err) = result {
73 let _ = executor.event_tx.send(LiveEvent::Error(format!(
74 "reaction '{source}' failed: {err}"
75 )));
76 }
77 });
78 }
79 }
80 }
81 Ok(())
82 }
83
84 pub async fn execute(&self, effect: LiveEffect) -> Result<(), SessionError> {
86 match effect {
87 LiveEffect::Noop => Ok(()),
88 LiveEffect::SendContext(contents) => {
89 if !contents.is_empty() {
90 self.writer.send_client_content(contents, false).await?;
91 }
92 Ok(())
93 }
94 LiveEffect::PromptModel => self.flush_deferred_prompt().await,
95 LiveEffect::CancelDeferredPrompt => {
96 if let Some(pending) = &self.pending_context {
97 pending.clear_prompt();
98 }
99 Ok(())
100 }
101 LiveEffect::SignalUserActivityStart => self.writer.signal_activity_start().await,
102 LiveEffect::SignalUserActivityEnd => self.writer.signal_activity_end().await,
103 LiveEffect::UpdateInstruction(instruction) => {
104 self.writer.update_instruction(instruction).await
105 }
106 LiveEffect::Emit(event) => {
107 let _ = self.event_tx.send(event);
108 Ok(())
109 }
110 }
111 }
112
113 pub async fn flush_deferred_prompt(&self) -> Result<(), SessionError> {
119 let Some(pending) = &self.pending_context else {
120 return Ok(());
121 };
122
123 let contents = pending.drain_context();
124 if !contents.is_empty() {
125 self.writer.send_client_content(contents, false).await?;
126 }
127 if pending.take_prompt() {
128 self.writer.send_client_content(vec![], true).await?;
129 }
130 Ok(())
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use gemini_genai_rs::prelude::{Content, FunctionResponse};
    use parking_lot::Mutex;

    /// A single call observed by [`MockWriter`].
    #[derive(Debug, Clone, PartialEq, Eq)]
    enum Write {
        ClientContent { turns: usize, turn_complete: bool },
        Instruction(String),
        ActivityStart,
        ActivityEnd,
    }

    /// Expected `send_client_content` call carrying `turns` contents.
    fn client_content(turns: usize, turn_complete: bool) -> Write {
        Write::ClientContent {
            turns,
            turn_complete,
        }
    }

    /// Recording [`SessionWriter`] test double.
    ///
    /// Client-content, instruction and activity calls are appended to
    /// `writes`. Audio, text, tool-response, video and disconnect calls are
    /// accepted and ignored. When `fail_client_content` is set,
    /// `send_client_content` fails with `SessionError::NotConnected` and
    /// records nothing.
    #[derive(Default)]
    struct MockWriter {
        writes: Mutex<Vec<Write>>,
        fail_client_content: bool,
    }

    impl MockWriter {
        /// A writer whose `send_client_content` always fails.
        fn failing() -> Self {
            Self {
                fail_client_content: true,
                ..Self::default()
            }
        }

        fn record(&self, write: Write) {
            self.writes.lock().push(write);
        }

        /// Snapshot of every recorded call, in call order.
        fn recorded(&self) -> Vec<Write> {
            self.writes.lock().clone()
        }
    }

    #[async_trait]
    impl SessionWriter for MockWriter {
        async fn send_audio(&self, _data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_text(&self, _text: String) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_tool_response(
            &self,
            _responses: Vec<FunctionResponse>,
        ) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_client_content(
            &self,
            turns: Vec<Content>,
            turn_complete: bool,
        ) -> Result<(), SessionError> {
            if self.fail_client_content {
                return Err(SessionError::NotConnected);
            }
            self.record(client_content(turns.len(), turn_complete));
            Ok(())
        }

        async fn send_video(&self, _jpeg_data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }

        async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
            self.record(Write::Instruction(instruction));
            Ok(())
        }

        async fn signal_activity_start(&self) -> Result<(), SessionError> {
            self.record(Write::ActivityStart);
            Ok(())
        }

        async fn signal_activity_end(&self) -> Result<(), SessionError> {
            self.record(Write::ActivityEnd);
            Ok(())
        }

        async fn disconnect(&self) -> Result<(), SessionError> {
            Ok(())
        }
    }

    /// Builds an executor over `writer` whose event receiver is dropped
    /// right away; tests using this helper do not observe events.
    fn executor_with(
        writer: Arc<MockWriter>,
        pending: Option<Arc<PendingContext>>,
    ) -> LiveEffectExecutor {
        let (event_tx, _) = broadcast::channel(8);
        LiveEffectExecutor::new(writer, pending, event_tx)
    }

    #[tokio::test]
    async fn prompt_model_flushes_context_then_armed_prompt() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("phase context"));
        pending.set_prompt();
        let executor = executor_with(Arc::clone(&writer), Some(Arc::clone(&pending)));

        executor.execute(LiveEffect::PromptModel).await.unwrap();

        assert_eq!(
            writer.recorded(),
            vec![client_content(1, false), client_content(0, true)]
        );
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn prompt_model_without_armed_prompt_only_flushes_context() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("phase context"));
        let executor = executor_with(Arc::clone(&writer), Some(pending));

        executor.execute(LiveEffect::PromptModel).await.unwrap();

        assert_eq!(writer.recorded(), vec![client_content(1, false)]);
    }

    #[tokio::test]
    async fn update_instruction_uses_writer() {
        let writer = Arc::new(MockWriter::default());
        let executor = executor_with(Arc::clone(&writer), None);

        let effect = LiveEffect::UpdateInstruction("new instruction".into());
        executor.execute(effect).await.unwrap();

        assert_eq!(
            writer.recorded(),
            vec![Write::Instruction(String::from("new instruction"))]
        );
    }

    #[tokio::test]
    async fn cancel_deferred_prompt_keeps_context() {
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("still useful with user audio"));
        pending.set_prompt();
        let executor = executor_with(Arc::new(MockWriter::default()), Some(Arc::clone(&pending)));

        executor
            .execute(LiveEffect::CancelDeferredPrompt)
            .await
            .unwrap();

        assert!(!pending.has_prompt());
        assert_eq!(pending.drain_context().len(), 1);
    }

    #[tokio::test]
    async fn user_activity_effects_signal_writer() {
        let writer = Arc::new(MockWriter::default());
        let executor = executor_with(Arc::clone(&writer), None);

        let reactions = vec![
            Reaction::blocking("test", LiveEffect::SignalUserActivityStart),
            Reaction::blocking("test", LiveEffect::SignalUserActivityEnd),
        ];
        executor.execute_reactions(reactions).await.unwrap();

        assert_eq!(
            writer.recorded(),
            vec![Write::ActivityStart, Write::ActivityEnd]
        );
    }

    #[tokio::test]
    async fn concurrent_effect_failure_is_surfaced_as_event() {
        let (event_tx, mut rx) = broadcast::channel(8);
        let executor = LiveEffectExecutor::new(Arc::new(MockWriter::failing()), None, event_tx);

        // The reaction runs in a spawned task, so execute_reactions itself succeeds.
        let reaction = Reaction::concurrent(
            "test",
            LiveEffect::SendContext(vec![Content::model("x")]),
        );
        executor.execute_reactions(vec![reaction]).await.unwrap();

        let event = tokio::time::timeout(std::time::Duration::from_secs(1), rx.recv())
            .await
            .expect("a reaction-failure event within the timeout")
            .expect("event received");
        assert!(
            matches!(&event, LiveEvent::Error(msg) if msg.contains("reaction 'test' failed")),
            "expected a reaction-failure error event, got {event:?}"
        );
    }
}
370}