1use std::sync::Arc;
4
5use gemini_genai_rs::prelude::SessionPhase;
6use gemini_genai_rs::session::{SessionError, SessionWriter};
7use tokio::sync::broadcast;
8
9use super::context_writer::PendingContext;
10use super::events::LiveEvent;
11use super::reactor::{EffectMode, LiveEffect, Reaction};
12
13#[derive(Clone)]
15pub struct LiveEffectExecutor {
16 writer: Arc<dyn SessionWriter>,
17 pending_context: Option<Arc<PendingContext>>,
18 event_tx: broadcast::Sender<LiveEvent>,
19}
20
21impl LiveEffectExecutor {
22 pub fn new(
24 writer: Arc<dyn SessionWriter>,
25 pending_context: Option<Arc<PendingContext>>,
26 event_tx: broadcast::Sender<LiveEvent>,
27 ) -> Self {
28 Self {
29 writer,
30 pending_context,
31 event_tx,
32 }
33 }
34
35 pub async fn execute_reactions(&self, reactions: Vec<Reaction>) -> Result<(), SessionError> {
37 for reaction in reactions {
38 match reaction.policy.mode {
39 EffectMode::Blocking => {
40 let executor = self.clone();
41 let fut = executor.execute(reaction.effect);
42 if let Some(timeout) = reaction.policy.timeout {
43 tokio::time::timeout(timeout, fut).await.map_err(|_| {
44 SessionError::Timeout {
45 phase: SessionPhase::Active,
46 elapsed: timeout,
47 }
48 })??;
49 } else {
50 fut.await?;
51 }
52 }
53 EffectMode::Concurrent => {
54 let executor = self.clone();
55 let timeout = reaction.policy.timeout;
56 tokio::spawn(async move {
57 let fut = executor.execute(reaction.effect);
58 if let Some(timeout) = timeout {
59 let _ = tokio::time::timeout(timeout, fut).await;
60 } else {
61 let _ = fut.await;
62 }
63 });
64 }
65 }
66 }
67 Ok(())
68 }
69
70 pub async fn execute(&self, effect: LiveEffect) -> Result<(), SessionError> {
72 match effect {
73 LiveEffect::Noop => Ok(()),
74 LiveEffect::SendContext(contents) => {
75 if !contents.is_empty() {
76 self.writer.send_client_content(contents, false).await?;
77 }
78 Ok(())
79 }
80 LiveEffect::PromptModel => self.flush_deferred_prompt().await,
81 LiveEffect::CancelDeferredPrompt => {
82 if let Some(pending) = &self.pending_context {
83 pending.clear_prompt();
84 }
85 Ok(())
86 }
87 LiveEffect::SignalUserActivityStart => self.writer.signal_activity_start().await,
88 LiveEffect::SignalUserActivityEnd => self.writer.signal_activity_end().await,
89 LiveEffect::UpdateInstruction(instruction) => {
90 self.writer.update_instruction(instruction).await
91 }
92 LiveEffect::Emit(event) => {
93 let _ = self.event_tx.send(event);
94 Ok(())
95 }
96 LiveEffect::TransitionPhase(_phase) => Ok(()),
97 }
98 }
99
100 pub async fn flush_deferred_prompt(&self) -> Result<(), SessionError> {
106 let Some(pending) = &self.pending_context else {
107 return Ok(());
108 };
109
110 let contents = pending.drain_context();
111 if !contents.is_empty() {
112 self.writer.send_client_content(contents, false).await?;
113 }
114 if pending.take_prompt() {
115 self.writer.send_client_content(vec![], true).await?;
116 }
117 Ok(())
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use gemini_genai_rs::prelude::{Content, FunctionResponse};
    use parking_lot::Mutex;

    /// One observable call made against [`MockWriter`], reduced to what the tests check.
    #[derive(Debug, Clone, PartialEq, Eq)]
    enum Write {
        ClientContent { turns: usize, turn_complete: bool },
        Instruction(String),
        ActivityStart,
        ActivityEnd,
    }

    /// `SessionWriter` test double that records the calls the executor makes, in order.
    /// Audio, text, tool-response, video and disconnect calls succeed without being recorded.
    #[derive(Default)]
    struct MockWriter {
        writes: Mutex<Vec<Write>>,
    }

    impl MockWriter {
        /// Appends `entry` to the call log.
        fn record(&self, entry: Write) {
            self.writes.lock().push(entry);
        }

        /// Returns a snapshot of every call recorded so far.
        fn recorded(&self) -> Vec<Write> {
            self.writes.lock().clone()
        }
    }

    #[async_trait]
    impl SessionWriter for MockWriter {
        async fn send_audio(&self, _data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_text(&self, _text: String) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_tool_response(
            &self,
            _responses: Vec<FunctionResponse>,
        ) -> Result<(), SessionError> {
            Ok(())
        }

        async fn send_client_content(
            &self,
            turns: Vec<Content>,
            turn_complete: bool,
        ) -> Result<(), SessionError> {
            let turns = turns.len();
            self.record(Write::ClientContent {
                turns,
                turn_complete,
            });
            Ok(())
        }

        async fn send_video(&self, _jpeg_data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }

        async fn update_instruction(&self, instruction: String) -> Result<(), SessionError> {
            self.record(Write::Instruction(instruction));
            Ok(())
        }

        async fn signal_activity_start(&self) -> Result<(), SessionError> {
            self.record(Write::ActivityStart);
            Ok(())
        }

        async fn signal_activity_end(&self) -> Result<(), SessionError> {
            self.record(Write::ActivityEnd);
            Ok(())
        }

        async fn disconnect(&self) -> Result<(), SessionError> {
            Ok(())
        }
    }

    /// Builds an executor over `writer`; the event receiver is dropped immediately.
    fn executor_for(
        writer: &Arc<MockWriter>,
        pending: Option<Arc<PendingContext>>,
    ) -> LiveEffectExecutor {
        let (event_tx, _) = broadcast::channel(8);
        LiveEffectExecutor::new(writer.clone(), pending, event_tx)
    }

    #[tokio::test]
    async fn prompt_model_flushes_context_then_armed_prompt() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("phase context"));
        pending.set_prompt();
        let executor = executor_for(&writer, Some(Arc::clone(&pending)));

        executor.execute(LiveEffect::PromptModel).await.unwrap();

        let expected = vec![
            Write::ClientContent {
                turns: 1,
                turn_complete: false,
            },
            Write::ClientContent {
                turns: 0,
                turn_complete: true,
            },
        ];
        assert_eq!(writer.recorded(), expected);
        assert!(pending.is_empty());
    }

    #[tokio::test]
    async fn prompt_model_without_armed_prompt_only_flushes_context() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("phase context"));
        let executor = executor_for(&writer, Some(pending));

        executor.execute(LiveEffect::PromptModel).await.unwrap();

        let expected = vec![Write::ClientContent {
            turns: 1,
            turn_complete: false,
        }];
        assert_eq!(writer.recorded(), expected);
    }

    #[tokio::test]
    async fn update_instruction_uses_writer() {
        let writer = Arc::new(MockWriter::default());
        let executor = executor_for(&writer, None);

        let effect = LiveEffect::UpdateInstruction("new instruction".into());
        executor.execute(effect).await.unwrap();

        assert_eq!(
            writer.recorded(),
            vec![Write::Instruction("new instruction".into())]
        );
    }

    #[tokio::test]
    async fn cancel_deferred_prompt_keeps_context() {
        let writer = Arc::new(MockWriter::default());
        let pending = Arc::new(PendingContext::new());
        pending.push(Content::model("still useful with user audio"));
        pending.set_prompt();
        let executor = executor_for(&writer, Some(Arc::clone(&pending)));

        executor
            .execute(LiveEffect::CancelDeferredPrompt)
            .await
            .unwrap();

        assert!(!pending.has_prompt());
        assert_eq!(pending.drain_context().len(), 1);
    }

    #[tokio::test]
    async fn user_activity_effects_signal_writer() {
        let writer = Arc::new(MockWriter::default());
        let executor = executor_for(&writer, None);

        let reactions = [
            LiveEffect::SignalUserActivityStart,
            LiveEffect::SignalUserActivityEnd,
        ]
        .into_iter()
        .map(|effect| Reaction::blocking("test", effect))
        .collect();
        executor.execute_reactions(reactions).await.unwrap();

        assert_eq!(
            writer.recorded(),
            vec![Write::ActivityStart, Write::ActivityEnd]
        );
    }
}