gemini_adk_fluent_rs/live/
callbacks.rs

1//! Event callback registration methods for `Live`.
2
3use std::future::Future;
4use std::sync::Arc;
5use std::time::Duration;
6
7use bytes::Bytes;
8
9use gemini_adk_rs::live::CallbackMode;
10use gemini_adk_rs::State;
11use gemini_genai_rs::prelude::*;
12
13use super::Live;
14
15impl Live {
16    // -- Outbound Interceptors --
17
18    /// Intercept tool responses before they are sent back to Gemini.
19    ///
20    /// Use this to rewrite, augment, or filter tool results based on
21    /// conversation state. The callback receives the tool responses and the
22    /// shared `State`, and returns (potentially modified) responses.
23    ///
24    /// # Example
25    /// ```ignore
26    /// .before_tool_response(|responses, state| async move {
27    ///     let order: OrderState = state.get("OrderState").unwrap_or_default();
28    ///     responses.into_iter().map(|mut r| {
29    ///         r.response["current_order"] = serde_json::to_value(&order).unwrap();
30    ///         r
31    ///     }).collect()
32    /// })
33    /// ```
34    pub fn before_tool_response<F, Fut>(mut self, f: F) -> Self
35    where
36        F: Fn(Vec<FunctionResponse>, gemini_adk_rs::State) -> Fut + Send + Sync + 'static,
37        Fut: Future<Output = Vec<FunctionResponse>> + Send + 'static,
38    {
39        self.callbacks.before_tool_response = Some(Arc::new(move |responses, state| {
40            Box::pin(f(responses, state))
41        }));
42        self
43    }
44
45    /// Hook called at turn boundaries — after extractors run, before `on_turn_complete`.
46    ///
47    /// Receives the shared `State` and a `SessionWriter` for injecting content
48    /// into the conversation. Use for context stuffing, K/V data injection,
49    /// condensed state summaries, or any outbound content interleaving.
50    ///
51    /// # Example
52    /// ```ignore
53    /// .on_turn_boundary(|state, writer| async move {
54    ///     let summary = state.get::<String>("summary").unwrap_or_default();
55    ///     writer.send_client_content(
56    ///         vec![Content::user().text(format!("[Context: {summary}]"))],
57    ///         false,
58    ///     ).await.ok();
59    /// })
60    /// ```
61    pub fn on_turn_boundary<F, Fut>(mut self, f: F) -> Self
62    where
63        F: Fn(gemini_adk_rs::State, Arc<dyn gemini_genai_rs::session::SessionWriter>) -> Fut
64            + Send
65            + Sync
66            + 'static,
67        Fut: Future<Output = ()> + Send + 'static,
68    {
69        self.callbacks.on_turn_boundary =
70            Some(Arc::new(move |state, writer| Box::pin(f(state, writer))));
71        self
72    }
73
74    // -- Fast Lane Callbacks (sync, < 1ms) --
75
76    /// Called for each audio chunk from the model (PCM16 24kHz).
77    pub fn on_audio(mut self, f: impl Fn(&Bytes) + Send + Sync + 'static) -> Self {
78        self.callbacks.on_audio = Some(Box::new(f));
79        self
80    }
81
82    /// Called for each incremental text delta.
83    pub fn on_text(mut self, f: impl Fn(&str) + Send + Sync + 'static) -> Self {
84        self.callbacks.on_text = Some(Box::new(f));
85        self
86    }
87
88    /// Called when model completes a text response.
89    pub fn on_text_complete(mut self, f: impl Fn(&str) + Send + Sync + 'static) -> Self {
90        self.callbacks.on_text_complete = Some(Box::new(f));
91        self
92    }
93
94    /// Called for input (user speech) transcription.
95    pub fn on_input_transcript(mut self, f: impl Fn(&str, bool) + Send + Sync + 'static) -> Self {
96        self.callbacks.on_input_transcript = Some(Box::new(f));
97        self
98    }
99
100    /// Called for output (model speech) transcription.
101    pub fn on_output_transcript(mut self, f: impl Fn(&str, bool) + Send + Sync + 'static) -> Self {
102        self.callbacks.on_output_transcript = Some(Box::new(f));
103        self
104    }
105
106    /// Called when the model emits a thought/reasoning summary.
107    ///
108    /// Requires `.include_thoughts()` on the session config. Fast lane callback
109    /// (sync, must complete in < 1ms).
110    pub fn on_thought(mut self, f: impl Fn(&str) + Send + Sync + 'static) -> Self {
111        self.callbacks.on_thought = Some(Box::new(f));
112        self
113    }
114
115    /// Called when server VAD detects voice activity start.
116    pub fn on_vad_start(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
117        self.callbacks.on_vad_start = Some(Box::new(f));
118        self
119    }
120
121    /// Called when server VAD detects voice activity end.
122    pub fn on_vad_end(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
123        self.callbacks.on_vad_end = Some(Box::new(f));
124        self
125    }
126
127    /// Called when server sends token usage metadata.
128    ///
129    /// Receives a reference to the full [`UsageMetadata`] including prompt,
130    /// response, cached, tool-use, and thoughts token counts plus per-modality
131    /// breakdowns. Fires on the telemetry lane (not the fast lane).
132    pub fn on_usage(mut self, f: impl Fn(&UsageMetadata) + Send + Sync + 'static) -> Self {
133        self.callbacks.on_usage = Some(Box::new(f));
134        self
135    }
136
137    /// Called on session phase transitions.
138    ///
139    /// Receives the new [`SessionPhase`]. Fast lane callback (sync, must
140    /// complete in < 1ms). Use for lightweight UI state updates or metrics.
141    pub fn on_phase(mut self, f: impl Fn(SessionPhase) + Send + Sync + 'static) -> Self {
142        self.callbacks.on_phase = Some(Box::new(f));
143        self
144    }
145
146    // -- Control Lane Callbacks (async, can block) --
147
148    /// Called when model is interrupted by barge-in.
149    pub fn on_interrupted<F, Fut>(mut self, f: F) -> Self
150    where
151        F: Fn() -> Fut + Send + Sync + 'static,
152        Fut: Future<Output = ()> + Send + 'static,
153    {
154        self.callbacks.on_interrupted = Some(Arc::new(move || Box::pin(f())));
155        self
156    }
157
158    /// Called when model requests tool execution.
159    /// Return `None` to auto-dispatch, `Some(responses)` to override.
160    /// Receives State for natural state promotion from tool results.
161    pub fn on_tool_call<F, Fut>(mut self, f: F) -> Self
162    where
163        F: Fn(Vec<FunctionCall>, State) -> Fut + Send + Sync + 'static,
164        Fut: Future<Output = Option<Vec<FunctionResponse>>> + Send + 'static,
165    {
166        self.callbacks.on_tool_call = Some(Arc::new(move |calls, state| Box::pin(f(calls, state))));
167        self
168    }
169
170    /// Called when the server cancels pending tool calls.
171    ///
172    /// Receives the list of cancelled tool call IDs. Use to clean up any
173    /// in-flight async work associated with those calls.
174    pub fn on_tool_cancelled<F, Fut>(mut self, f: F) -> Self
175    where
176        F: Fn(Vec<String>) -> Fut + Send + Sync + 'static,
177        Fut: Future<Output = ()> + Send + 'static,
178    {
179        self.callbacks.on_tool_cancelled = Some(Arc::new(move |ids| Box::pin(f(ids))));
180        self
181    }
182
183    /// Called when model turn completes.
184    pub fn on_turn_complete<F, Fut>(mut self, f: F) -> Self
185    where
186        F: Fn() -> Fut + Send + Sync + 'static,
187        Fut: Future<Output = ()> + Send + 'static,
188    {
189        self.callbacks.on_turn_complete = Some(Arc::new(move || Box::pin(f())));
190        self
191    }
192
193    /// Called when the model finishes generating its full intended response.
194    ///
195    /// Fires on the wire `GenerationComplete` event, before any interruption
196    /// truncation. Use this to capture the model's complete output even when
197    /// the user barges in. Paired with `.extract_on_generation()` for structured
198    /// extraction of the pre-truncation response.
199    pub fn on_generation_complete<F, Fut>(mut self, f: F) -> Self
200    where
201        F: Fn() -> Fut + Send + Sync + 'static,
202        Fut: Future<Output = ()> + Send + 'static,
203    {
204        self.callbacks.on_generation_complete = Some(Arc::new(move || Box::pin(f())));
205        self
206    }
207
208    /// Called when server sends GoAway.
209    pub fn on_go_away<F, Fut>(mut self, f: F) -> Self
210    where
211        F: Fn(Duration) -> Fut + Send + Sync + 'static,
212        Fut: Future<Output = ()> + Send + 'static,
213    {
214        self.callbacks.on_go_away = Some(Arc::new(move |d| Box::pin(f(d))));
215        self
216    }
217
218    /// Called when session connects (setup complete).
219    ///
220    /// Receives a `SessionWriter` for sending messages on connect.
221    pub fn on_connected<F, Fut>(mut self, f: F) -> Self
222    where
223        F: Fn(Arc<dyn gemini_genai_rs::session::SessionWriter>) -> Fut + Send + Sync + 'static,
224        Fut: Future<Output = ()> + Send + 'static,
225    {
226        self.callbacks.on_connected = Some(Arc::new(move |w| Box::pin(f(w))));
227        self
228    }
229
230    /// Called when session disconnects.
231    pub fn on_disconnected<F, Fut>(mut self, f: F) -> Self
232    where
233        F: Fn(Option<String>) -> Fut + Send + Sync + 'static,
234        Fut: Future<Output = ()> + Send + 'static,
235    {
236        self.callbacks.on_disconnected = Some(Arc::new(move |r| Box::pin(f(r))));
237        self
238    }
239
240    /// Called after the session resumes following a GoAway disconnect.
241    ///
242    /// Use to re-subscribe to external streams, reset UI state, or log
243    /// resume events. Paired with `.session_resume(true)` on the builder.
244    pub fn on_resumed<F, Fut>(mut self, f: F) -> Self
245    where
246        F: Fn() -> Fut + Send + Sync + 'static,
247        Fut: Future<Output = ()> + Send + 'static,
248    {
249        self.callbacks.on_resumed = Some(Arc::new(move || Box::pin(f())));
250        self
251    }
252
253    /// Called on non-fatal errors.
254    pub fn on_error<F, Fut>(mut self, f: F) -> Self
255    where
256        F: Fn(String) -> Fut + Send + Sync + 'static,
257        Fut: Future<Output = ()> + Send + 'static,
258    {
259        self.callbacks.on_error = Some(Arc::new(move |e| Box::pin(f(e))));
260        self
261    }
262
263    // -- Concurrent callback variants --
264    // These set CallbackMode::Concurrent so the callback is spawned as a
265    // detached tokio task instead of being awaited inline.
266
267    /// Called when model turn completes (spawned concurrently).
268    pub fn on_turn_complete_concurrent<F, Fut>(mut self, f: F) -> Self
269    where
270        F: Fn() -> Fut + Send + Sync + 'static,
271        Fut: Future<Output = ()> + Send + 'static,
272    {
273        self.callbacks.on_turn_complete = Some(Arc::new(move || Box::pin(f())));
274        self.callbacks.on_turn_complete_mode = CallbackMode::Concurrent;
275        self
276    }
277
278    /// Called when the model finishes generating its full intended response (spawned concurrently).
279    pub fn on_generation_complete_concurrent<F, Fut>(mut self, f: F) -> Self
280    where
281        F: Fn() -> Fut + Send + Sync + 'static,
282        Fut: Future<Output = ()> + Send + 'static,
283    {
284        self.callbacks.on_generation_complete = Some(Arc::new(move || Box::pin(f())));
285        self.callbacks.on_generation_complete_mode = CallbackMode::Concurrent;
286        self
287    }
288
289    /// Called when session connects (spawned concurrently).
290    pub fn on_connected_concurrent<F, Fut>(mut self, f: F) -> Self
291    where
292        F: Fn(Arc<dyn gemini_genai_rs::session::SessionWriter>) -> Fut + Send + Sync + 'static,
293        Fut: Future<Output = ()> + Send + 'static,
294    {
295        self.callbacks.on_connected = Some(Arc::new(move |w| Box::pin(f(w))));
296        self.callbacks.on_connected_mode = CallbackMode::Concurrent;
297        self
298    }
299
300    /// Called when session disconnects (spawned concurrently).
301    pub fn on_disconnected_concurrent<F, Fut>(mut self, f: F) -> Self
302    where
303        F: Fn(Option<String>) -> Fut + Send + Sync + 'static,
304        Fut: Future<Output = ()> + Send + 'static,
305    {
306        self.callbacks.on_disconnected = Some(Arc::new(move |r| Box::pin(f(r))));
307        self.callbacks.on_disconnected_mode = CallbackMode::Concurrent;
308        self
309    }
310
311    /// Called after session resumes from GoAway (spawned concurrently).
312    pub fn on_resumed_concurrent<F, Fut>(mut self, f: F) -> Self
313    where
314        F: Fn() -> Fut + Send + Sync + 'static,
315        Fut: Future<Output = ()> + Send + 'static,
316    {
317        self.callbacks.on_resumed = Some(Arc::new(move || Box::pin(f())));
318        self.callbacks.on_resumed_mode = CallbackMode::Concurrent;
319        self
320    }
321
322    /// Called on non-fatal errors (spawned concurrently).
323    pub fn on_error_concurrent<F, Fut>(mut self, f: F) -> Self
324    where
325        F: Fn(String) -> Fut + Send + Sync + 'static,
326        Fut: Future<Output = ()> + Send + 'static,
327    {
328        self.callbacks.on_error = Some(Arc::new(move |e| Box::pin(f(e))));
329        self.callbacks.on_error_mode = CallbackMode::Concurrent;
330        self
331    }
332
333    /// Called when server sends GoAway (spawned concurrently).
334    pub fn on_go_away_concurrent<F, Fut>(mut self, f: F) -> Self
335    where
336        F: Fn(Duration) -> Fut + Send + Sync + 'static,
337        Fut: Future<Output = ()> + Send + 'static,
338    {
339        self.callbacks.on_go_away = Some(Arc::new(move |d| Box::pin(f(d))));
340        self.callbacks.on_go_away_mode = CallbackMode::Concurrent;
341        self
342    }
343
344    /// Called when the server cancels pending tool calls (spawned concurrently).
345    pub fn on_tool_cancelled_concurrent<F, Fut>(mut self, f: F) -> Self
346    where
347        F: Fn(Vec<String>) -> Fut + Send + Sync + 'static,
348        Fut: Future<Output = ()> + Send + 'static,
349    {
350        self.callbacks.on_tool_cancelled = Some(Arc::new(move |ids| Box::pin(f(ids))));
351        self.callbacks.on_tool_cancelled_mode = CallbackMode::Concurrent;
352        self
353    }
354
355    /// Called when a TurnExtractor produces a result (spawned concurrently).
356    pub fn on_extracted_concurrent<F, Fut>(mut self, f: F) -> Self
357    where
358        F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
359        Fut: Future<Output = ()> + Send + 'static,
360    {
361        self.callbacks.on_extracted = Some(Arc::new(move |name, value| Box::pin(f(name, value))));
362        self.callbacks.on_extracted_mode = CallbackMode::Concurrent;
363        self
364    }
365
366    /// Called when a TurnExtractor fails (spawned concurrently).
367    pub fn on_extraction_error_concurrent<F, Fut>(mut self, f: F) -> Self
368    where
369        F: Fn(String, String) -> Fut + Send + Sync + 'static,
370        Fut: Future<Output = ()> + Send + 'static,
371    {
372        self.callbacks.on_extraction_error =
373            Some(Arc::new(move |name, error| Box::pin(f(name, error))));
374        self.callbacks.on_extraction_error_mode = CallbackMode::Concurrent;
375        self
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: the builder chain accepts the four newer callback
    /// setters, and each one hands back `Self` so the chain keeps going.
    /// There is nothing to assert at runtime; type-checking is the test.
    #[test]
    fn builder_accepts_new_callbacks() {
        let _builder = Live::builder()
            // sync, fast lane
            .on_phase(|_phase| {})
            // async, control lane, takes the cancelled IDs
            .on_tool_cancelled(|_ids| async {})
            // async, control lane, takes no arguments
            .on_generation_complete(|| async {})
            // async, control lane, takes no arguments
            .on_resumed(|| async {});
    }

    /// Compile-time check for the `_concurrent` flavours of the newer async
    /// setters. As above, successful compilation is the pass condition.
    #[test]
    fn builder_accepts_new_callbacks_concurrent() {
        let _builder = Live::builder()
            .on_tool_cancelled_concurrent(|_ids| async {})
            .on_generation_complete_concurrent(|| async {})
            .on_resumed_concurrent(|| async {});
    }
}