gemini_adk_rs/live/
background_tool.rs

1//! Non-blocking tool execution infrastructure.
2//!
3//! Provides [`BackgroundToolTracker`] for managing in-flight background tool
4//! executions, [`ResultFormatter`] for customizing tool response formatting,
5//! and [`ToolExecutionMode`] for declaring whether a tool runs synchronously
6//! or in the background.
7
8use std::sync::Arc;
9
10use dashmap::DashMap;
11use serde_json::Value;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14
15use gemini_genai_rs::prelude::FunctionCall;
16
17use crate::error::ToolError;
18
19// ---------------------------------------------------------------------------
20// ResultFormatter trait
21// ---------------------------------------------------------------------------
22
23/// Formats tool responses for background execution lifecycle.
24///
25/// Implementations control the shape of JSON values sent back to the model
26/// at each stage of a background tool's lifecycle: when the tool starts
27/// running, when it completes (or fails), and when it is cancelled.
28pub trait ResultFormatter: Send + Sync + 'static {
29    /// Format the immediate acknowledgment sent when a background tool starts.
30    fn format_running(&self, call: &FunctionCall) -> Value;
31
32    /// Format the final result after tool completes or fails.
33    fn format_result(&self, call: &FunctionCall, result: Result<Value, ToolError>) -> Value;
34
35    /// Format a cancellation response.
36    fn format_cancelled(&self, call_id: &str) -> Value;
37}
38
39// ---------------------------------------------------------------------------
40// DefaultResultFormatter
41// ---------------------------------------------------------------------------
42
43/// Default formatter that wraps results in a status object.
44///
45/// Produces JSON like:
46/// ```json
47/// { "status": "running", "tool": "search" }
48/// { "status": "completed", "tool": "search", "result": { ... } }
49/// { "status": "error", "tool": "search", "error": "..." }
50/// { "status": "cancelled", "call_id": "abc123" }
51/// ```
52pub struct DefaultResultFormatter;
53
54impl ResultFormatter for DefaultResultFormatter {
55    fn format_running(&self, call: &FunctionCall) -> Value {
56        serde_json::json!({
57            "status": "running",
58            "tool": call.name,
59        })
60    }
61
62    fn format_result(&self, call: &FunctionCall, result: Result<Value, ToolError>) -> Value {
63        match result {
64            Ok(value) => serde_json::json!({
65                "status": "completed",
66                "tool": call.name,
67                "result": value,
68            }),
69            Err(e) => serde_json::json!({
70                "status": "error",
71                "tool": call.name,
72                "error": e.to_string(),
73            }),
74        }
75    }
76
77    fn format_cancelled(&self, call_id: &str) -> Value {
78        serde_json::json!({
79            "status": "cancelled",
80            "call_id": call_id,
81        })
82    }
83}
84
85// ---------------------------------------------------------------------------
86// ToolExecutionMode
87// ---------------------------------------------------------------------------
88
89/// Execution mode for a tool.
90///
91/// - [`Standard`](ToolExecutionMode::Standard): the tool runs inline and the
92///   model waits for the result before continuing.
93/// - [`Background`](ToolExecutionMode::Background): the tool is spawned as a
94///   background task. An immediate "running" acknowledgment is sent to the
95///   model, and the final result is delivered asynchronously when the task
96///   completes.
97///
98/// # With the L2 Fluent API
99///
100/// ```rust,ignore
101/// Live::builder()
102///     .tools(dispatcher)
103///     .tool_background("search_kb")           // uses DefaultResultFormatter
104///     .tool_background_with_formatter(         // custom formatter
105///         "analyze",
106///         Arc::new(MyFormatter),
107///     )
108///     .connect_vertex(project, location, token)
109///     .await?;
110/// ```
111///
112/// # With the L1 Builder
113///
114/// ```rust,ignore
115/// LiveSessionBuilder::new(config)
116///     .dispatcher(dispatcher)
117///     .tool_execution_mode("search_kb", ToolExecutionMode::Background {
118///         formatter: None,
119///     })
120///     .connect()
121///     .await?;
122/// ```
123#[derive(Clone, Default)]
124pub enum ToolExecutionMode {
125    /// The tool runs inline (blocking the model turn until complete).
126    #[default]
127    Standard,
128
129    /// The tool runs in the background.
130    ///
131    /// An optional [`ResultFormatter`] controls how acknowledgment, result,
132    /// and cancellation messages are shaped. When `None`, the
133    /// [`DefaultResultFormatter`] is used.
134    ///
135    /// The `scheduling` field controls how the model handles async results:
136    /// - `Interrupt`: halts current output, immediately reports the result
137    /// - `WhenIdle`: waits until current output finishes before handling
138    /// - `Silent`: integrates the result without notifying the user
139    Background {
140        /// Custom formatter for background tool results, or `None` for the default.
141        formatter: Option<Arc<dyn ResultFormatter>>,
142        /// How the model should handle the async result. Defaults to `WhenIdle`.
143        scheduling: Option<gemini_genai_rs::prelude::FunctionResponseScheduling>,
144    },
145}
146
147impl std::fmt::Debug for ToolExecutionMode {
148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149        match self {
150            Self::Standard => write!(f, "Standard"),
151            Self::Background {
152                formatter,
153                scheduling,
154            } => {
155                write!(
156                    f,
157                    "Background(formatter={}, scheduling={:?})",
158                    formatter.is_some(),
159                    scheduling
160                )
161            }
162        }
163    }
164}
165
166// ---------------------------------------------------------------------------
167// BackgroundToolTracker
168// ---------------------------------------------------------------------------
169
170/// Tracks in-flight background tool executions for cancellation.
171///
172/// Uses [`DashMap`] internally so that spawned tasks can remove themselves
173/// upon completion while the control lane concurrently spawns or cancels
174/// other tasks.
175pub struct BackgroundToolTracker {
176    tasks: DashMap<String, (JoinHandle<()>, CancellationToken)>,
177}
178
179impl BackgroundToolTracker {
180    /// Create a new, empty tracker.
181    pub fn new() -> Self {
182        Self {
183            tasks: DashMap::new(),
184        }
185    }
186
187    /// Register a spawned background task.
188    ///
189    /// The `call_id` is the unique identifier for the function call (usually
190    /// from [`FunctionCall::id`]). The caller provides both a
191    /// [`JoinHandle`] (for aborting) and a [`CancellationToken`] (for
192    /// cooperative cancellation).
193    pub fn spawn(&self, call_id: String, task: JoinHandle<()>, cancel: CancellationToken) {
194        self.tasks.insert(call_id, (task, cancel));
195    }
196
197    /// Cancel specific tool calls by their IDs.
198    ///
199    /// For each matching ID the cancellation token is triggered **and** the
200    /// task handle is aborted, providing belt-and-suspenders cleanup.
201    /// Non-existent IDs are silently ignored.
202    pub fn cancel(&self, call_ids: &[String]) {
203        for id in call_ids {
204            if let Some((_, (handle, token))) = self.tasks.remove(id) {
205                token.cancel();
206                handle.abort();
207            }
208        }
209    }
210
211    /// Cancel all in-flight background tasks.
212    ///
213    /// Useful during session shutdown to ensure no orphaned tasks remain.
214    pub fn cancel_all(&self) {
215        let keys: Vec<String> = self.tasks.iter().map(|r| r.key().clone()).collect();
216        for key in keys {
217            if let Some((_, (handle, token))) = self.tasks.remove(&key) {
218                token.cancel();
219                handle.abort();
220            }
221        }
222    }
223
224    /// Get IDs of active background tool calls.
225    pub fn active_ids(&self) -> Vec<String> {
226        self.tasks.iter().map(|r| r.key().clone()).collect()
227    }
228
229    /// Remove a completed task (called when background task finishes).
230    ///
231    /// This is typically invoked by the spawned task itself to clean up the
232    /// tracker entry once execution is done.
233    pub fn remove(&self, call_id: &str) {
234        self.tasks.remove(call_id);
235    }
236
237    /// Number of active background tasks.
238    pub fn active_count(&self) -> usize {
239        self.tasks.len()
240    }
241}
242
243impl Default for BackgroundToolTracker {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
249// ===========================================================================
250// Tests
251// ===========================================================================
252
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Shared helpers
    // -----------------------------------------------------------------------

    /// Spawns a task that parks until its cancellation token fires, registers
    /// it with `tracker` under `call_id`, and returns a clone of the token so
    /// the test can trigger or inspect cancellation.
    ///
    /// Must be called from inside a Tokio runtime (i.e. a `#[tokio::test]`).
    fn register_parked(tracker: &BackgroundToolTracker, call_id: &str) -> CancellationToken {
        let token = CancellationToken::new();
        let waiter = token.clone();
        let handle = tokio::spawn(async move {
            waiter.cancelled().await;
        });
        tracker.spawn(call_id.to_string(), handle, token.clone());
        token
    }

    /// Builds a `FunctionCall` with fixed arguments and ID for formatter tests.
    fn make_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            args: serde_json::json!({"query": "test"}),
            id: Some("fc_123".to_string()),
        }
    }

    // -----------------------------------------------------------------------
    // BackgroundToolTracker tests
    // -----------------------------------------------------------------------

    #[test]
    fn tracker_new_is_empty() {
        let tracker = BackgroundToolTracker::new();
        assert!(tracker.active_ids().is_empty());
        assert_eq!(tracker.active_count(), 0);
    }

    #[tokio::test]
    async fn spawn_shows_active_id() {
        let tracker = BackgroundToolTracker::new();
        let token = register_parked(&tracker, "call1");

        assert_eq!(tracker.active_ids(), vec!["call1".to_string()]);

        // Release the parked task.
        token.cancel();
    }

    #[tokio::test]
    async fn spawn_increments_active_count() {
        let tracker = BackgroundToolTracker::new();
        let first = register_parked(&tracker, "call1");
        let second = register_parked(&tracker, "call2");

        assert_eq!(tracker.active_count(), 2);

        // Release the parked tasks.
        first.cancel();
        second.cancel();
    }

    #[tokio::test]
    async fn cancel_removes_task_and_cancels_token() {
        let tracker = BackgroundToolTracker::new();
        let token = register_parked(&tracker, "call1");
        assert_eq!(tracker.active_count(), 1);

        tracker.cancel(&["call1".to_string()]);

        assert_eq!(tracker.active_count(), 0);
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn cancel_all_clears_all_tasks() {
        let tracker = BackgroundToolTracker::new();
        let tokens: Vec<CancellationToken> = ["call1", "call2", "call3"]
            .iter()
            .map(|id| register_parked(&tracker, id))
            .collect();
        assert_eq!(tracker.active_count(), 3);

        tracker.cancel_all();

        assert_eq!(tracker.active_count(), 0);
        for token in &tokens {
            assert!(token.is_cancelled());
        }
    }

    #[tokio::test]
    async fn remove_cleans_up_completed_task() {
        let tracker = BackgroundToolTracker::new();
        let token = register_parked(&tracker, "call1");
        assert_eq!(tracker.active_count(), 1);

        tracker.remove("call1");

        assert_eq!(tracker.active_count(), 0);
        assert!(tracker.active_ids().is_empty());

        // `remove` intentionally leaves the token untouched, so release the
        // parked task by hand.
        token.cancel();
    }

    #[test]
    fn cancel_nonexistent_id_is_noop() {
        let tracker = BackgroundToolTracker::new();
        // Unknown IDs must be ignored rather than panic.
        tracker.cancel(&["nonexistent".to_string()]);
        assert_eq!(tracker.active_count(), 0);
    }

    // -----------------------------------------------------------------------
    // DefaultResultFormatter tests
    // -----------------------------------------------------------------------

    #[test]
    fn format_running_output() {
        let running = DefaultResultFormatter.format_running(&make_call("search"));

        assert_eq!(running["status"], "running");
        assert_eq!(running["tool"], "search");
    }

    #[test]
    fn format_result_ok() {
        let payload = serde_json::json!({"items": [1, 2, 3]});
        let done = DefaultResultFormatter.format_result(&make_call("search"), Ok(payload.clone()));

        assert_eq!(done["status"], "completed");
        assert_eq!(done["tool"], "search");
        assert_eq!(done["result"], payload);
    }

    #[test]
    fn format_result_err() {
        let failure = ToolError::ExecutionFailed("connection timeout".into());
        let failed = DefaultResultFormatter.format_result(&make_call("search"), Err(failure));

        assert_eq!(failed["status"], "error");
        assert_eq!(failed["tool"], "search");
        let message = failed["error"].as_str().unwrap();
        assert!(message.contains("connection timeout"));
    }

    #[test]
    fn format_cancelled_output() {
        let cancelled = DefaultResultFormatter.format_cancelled("fc_456");

        assert_eq!(cancelled["status"], "cancelled");
        assert_eq!(cancelled["call_id"], "fc_456");
    }

    // -----------------------------------------------------------------------
    // ToolExecutionMode tests
    // -----------------------------------------------------------------------

    #[test]
    fn tool_execution_mode_default_is_standard() {
        assert!(matches!(
            ToolExecutionMode::default(),
            ToolExecutionMode::Standard
        ));
    }

    #[test]
    fn tool_execution_mode_debug_standard() {
        assert_eq!(format!("{:?}", ToolExecutionMode::Standard), "Standard");
    }

    #[test]
    fn tool_execution_mode_debug_background_none() {
        let without_formatter = ToolExecutionMode::Background {
            formatter: None,
            scheduling: None,
        };
        assert_eq!(
            format!("{:?}", without_formatter),
            "Background(formatter=false, scheduling=None)"
        );
    }

    #[test]
    fn tool_execution_mode_debug_background_some() {
        let with_formatter = ToolExecutionMode::Background {
            formatter: Some(Arc::new(DefaultResultFormatter)),
            scheduling: None,
        };
        assert_eq!(
            format!("{:?}", with_formatter),
            "Background(formatter=true, scheduling=None)"
        );
    }
}