gemini_adk_rs/tool/
dispatcher.rs

1//! Tool dispatcher — routes function calls to the right tool implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio_util::sync::CancellationToken;
8
9use gemini_genai_rs::prelude::{FunctionCall, FunctionDeclaration, FunctionResponse, Tool};
10
11use crate::error::ToolError;
12
13use super::{ActiveStreamingTool, ToolClass, ToolFunction, ToolKind, DEFAULT_TOOL_TIMEOUT};
14
15/// Routes function calls to the right tool implementation.
16pub struct ToolDispatcher {
17    tools: HashMap<String, ToolKind>,
18    active: Arc<tokio::sync::Mutex<HashMap<String, ActiveStreamingTool>>>,
19    default_timeout: Duration,
20    /// Cached tool declarations — computed once on first access.
21    cached_declarations: std::sync::OnceLock<Vec<Tool>>,
22    /// Optional provider consulted before running confirmation-gated tools.
23    confirmation_provider: Option<Arc<dyn crate::confirmation::ConfirmationProvider>>,
24}
25
26impl ToolDispatcher {
27    /// Create a new empty tool dispatcher with the default 30-second timeout.
28    ///
29    /// # Examples
30    ///
31    /// ```rust,ignore
32    /// use gemini_adk_rs::tool::{ToolDispatcher, SimpleTool};
33    /// use serde_json::json;
34    ///
35    /// let mut dispatcher = ToolDispatcher::new();
36    /// dispatcher.register(SimpleTool::new(
37    ///     "echo", "Echo input", None,
38    ///     |args| async move { Ok(args) },
39    /// ));
40    /// ```
41    pub fn new() -> Self {
42        Self {
43            tools: HashMap::new(),
44            active: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
45            default_timeout: DEFAULT_TOOL_TIMEOUT,
46            cached_declarations: std::sync::OnceLock::new(),
47            confirmation_provider: None,
48        }
49    }
50
51    /// Set the default timeout for tool calls.
52    pub fn with_timeout(mut self, timeout: Duration) -> Self {
53        self.default_timeout = timeout;
54        self
55    }
56
57    /// Attach a confirmation provider (builder form).
58    ///
59    /// Once set, any tool reporting
60    /// [`requires_confirmation`](crate::tool::ToolFunction::requires_confirmation)
61    /// — e.g. one built with `T::confirm(..)` — is checked against the provider
62    /// before it executes; a denied decision returns a `ToolError` instead of
63    /// running the tool. With no provider configured, confirmation-gated tools
64    /// run normally (enforcement is opt-in).
65    pub fn with_confirmation_provider(
66        mut self,
67        provider: Arc<dyn crate::confirmation::ConfirmationProvider>,
68    ) -> Self {
69        self.confirmation_provider = Some(provider);
70        self
71    }
72
73    /// Attach a confirmation provider in place. See
74    /// [`with_confirmation_provider`](Self::with_confirmation_provider).
75    pub fn set_confirmation_provider(
76        &mut self,
77        provider: Arc<dyn crate::confirmation::ConfirmationProvider>,
78    ) {
79        self.confirmation_provider = Some(provider);
80    }
81
82    /// Whether a confirmation provider is configured.
83    pub fn has_confirmation_provider(&self) -> bool {
84        self.confirmation_provider.is_some()
85    }
86
87    /// Consult the confirmation provider for a gated tool. Returns `Ok(())`
88    /// when the tool is not gated, no provider is set, or the call is approved;
89    /// returns a `ToolError` when the provider denies it.
90    async fn ensure_confirmed(
91        &self,
92        func: &Arc<dyn ToolFunction>,
93        args: &serde_json::Value,
94    ) -> Result<(), ToolError> {
95        if !func.requires_confirmation() {
96            return Ok(());
97        }
98        let Some(provider) = &self.confirmation_provider else {
99            return Ok(());
100        };
101        let request = crate::confirmation::ConfirmationRequest {
102            tool_name: func.name().to_string(),
103            args: args.clone(),
104            message: func.confirmation_message().map(str::to_string),
105        };
106        let decision = provider.confirm(request).await;
107        if decision.confirmed {
108            Ok(())
109        } else {
110            Err(ToolError::Cancelled)
111        }
112    }
113
114    /// Returns the configured default timeout.
115    pub fn default_timeout(&self) -> Duration {
116        self.default_timeout
117    }
118
119    /// Register a tool that implements [`ToolFunction`].
120    pub fn register(&mut self, tool: impl ToolFunction) {
121        let tool = Arc::new(tool);
122        self.tools
123            .insert(tool.name().to_string(), ToolKind::Function(tool));
124    }
125
126    /// Register a regular function tool (pre-wrapped in Arc).
127    pub fn register_function(&mut self, tool: Arc<dyn ToolFunction>) {
128        self.tools
129            .insert(tool.name().to_string(), ToolKind::Function(tool));
130    }
131
132    /// Register a streaming tool.
133    pub fn register_streaming(&mut self, tool: Arc<dyn super::StreamingTool>) {
134        self.tools
135            .insert(tool.name().to_string(), ToolKind::Streaming(tool));
136    }
137
138    /// Register an input-streaming tool.
139    pub fn register_input_streaming(&mut self, tool: Arc<dyn super::InputStreamingTool>) {
140        self.tools
141            .insert(tool.name().to_string(), ToolKind::InputStream(tool));
142    }
143
144    /// Get a tool by name (for introspection/streaming tool spawning).
145    pub fn get_tool(&self, name: &str) -> Option<&ToolKind> {
146        self.tools.get(name)
147    }
148
149    /// Classify a tool by name.
150    pub fn classify(&self, name: &str) -> Option<ToolClass> {
151        self.tools.get(name).map(|t| match t {
152            ToolKind::Function(_) => ToolClass::Regular,
153            ToolKind::Streaming(_) => ToolClass::Streaming,
154            ToolKind::InputStream(_) => ToolClass::InputStream,
155        })
156    }
157
158    /// Call a regular function tool by name, using the default timeout.
159    pub async fn call_function(
160        &self,
161        name: &str,
162        args: serde_json::Value,
163    ) -> Result<serde_json::Value, ToolError> {
164        self.call_function_with_timeout(name, args, self.default_timeout)
165            .await
166    }
167
168    /// Call a regular function tool by name with an explicit timeout.
169    ///
170    /// If the tool does not complete within the given duration, its future is
171    /// dropped (cancelling it) and `ToolError::Timeout` is returned.
172    pub async fn call_function_with_timeout(
173        &self,
174        name: &str,
175        args: serde_json::Value,
176        timeout: Duration,
177    ) -> Result<serde_json::Value, ToolError> {
178        let func = match self.tools.get(name) {
179            Some(ToolKind::Function(f)) => f.clone(),
180            Some(_) => {
181                return Err(ToolError::Other(format!(
182                    "{name} is not a regular function tool"
183                )))
184            }
185            None => return Err(ToolError::NotFound(name.to_string())),
186        };
187
188        self.ensure_confirmed(&func, &args).await?;
189
190        match tokio::time::timeout(timeout, func.call(args)).await {
191            Ok(result) => result,
192            Err(_elapsed) => Err(ToolError::Timeout(timeout)),
193        }
194    }
195
196    /// Call a regular function tool by name, racing against a cancellation token.
197    ///
198    /// If the token is cancelled before the tool completes, its future is
199    /// dropped and `ToolError::Cancelled` is returned.
200    pub async fn call_function_with_cancel(
201        &self,
202        name: &str,
203        args: serde_json::Value,
204        cancel: CancellationToken,
205    ) -> Result<serde_json::Value, ToolError> {
206        let func = match self.tools.get(name) {
207            Some(ToolKind::Function(f)) => f.clone(),
208            Some(_) => {
209                return Err(ToolError::Other(format!(
210                    "{name} is not a regular function tool"
211                )))
212            }
213            None => return Err(ToolError::NotFound(name.to_string())),
214        };
215
216        self.ensure_confirmed(&func, &args).await?;
217
218        tokio::select! {
219            result = func.call(args) => result,
220            _ = cancel.cancelled() => Err(ToolError::Cancelled),
221        }
222    }
223
224    /// Build a FunctionResponse from a FunctionCall result.
225    pub fn build_response(
226        call: &FunctionCall,
227        result: Result<serde_json::Value, ToolError>,
228    ) -> FunctionResponse {
229        match result {
230            Ok(value) => FunctionResponse {
231                name: call.name.clone(),
232                response: value,
233                id: call.id.clone(),
234                scheduling: None,
235            },
236            Err(e) => FunctionResponse {
237                name: call.name.clone(),
238                response: serde_json::json!({"error": e.to_string()}),
239                id: call.id.clone(),
240                scheduling: None,
241            },
242        }
243    }
244
245    /// Cancel a streaming tool by name.
246    pub async fn cancel_streaming(&self, name: &str) {
247        let mut active = self.active.lock().await;
248        if let Some(tool) = active.remove(name) {
249            tool.cancel.cancel();
250            tool.task.abort();
251        }
252    }
253
254    /// Store an active streaming tool (for cancellation tracking).
255    pub(crate) async fn store_active(&self, id: String, tool: ActiveStreamingTool) {
256        self.active.lock().await.insert(id, tool);
257    }
258
259    /// Cancel streaming tools by IDs.
260    pub async fn cancel_by_ids(&self, ids: &[String]) {
261        let mut active = self.active.lock().await;
262        for id in ids {
263            if let Some(tool) = active.remove(id.as_str()) {
264                tool.cancel.cancel();
265                tool.task.abort();
266            }
267        }
268    }
269
270    /// Generate Tool declarations for the setup message.
271    ///
272    /// Results are cached after first computation. The cache is invalidated
273    /// when tools are registered via `register*()` methods.
274    pub fn to_tool_declarations(&self) -> Vec<Tool> {
275        self.cached_declarations
276            .get_or_init(|| {
277                let declarations: Vec<FunctionDeclaration> = self
278                    .tools
279                    .values()
280                    .map(|t| {
281                        let (name, desc, params) = match t {
282                            ToolKind::Function(f) => (f.name(), f.description(), f.parameters()),
283                            ToolKind::Streaming(s) => (s.name(), s.description(), s.parameters()),
284                            ToolKind::InputStream(i) => (i.name(), i.description(), i.parameters()),
285                        };
286                        FunctionDeclaration {
287                            name: name.to_string(),
288                            description: desc.to_string(),
289                            parameters: params,
290                            behavior: None,
291                        }
292                    })
293                    .collect();
294
295                if declarations.is_empty() {
296                    vec![]
297                } else {
298                    vec![Tool::functions(declarations)]
299                }
300            })
301            .clone()
302    }
303
304    /// Number of registered tools.
305    pub fn len(&self) -> usize {
306        self.tools.len()
307    }
308
309    /// Whether no tools are registered.
310    pub fn is_empty(&self) -> bool {
311        self.tools.is_empty()
312    }
313}
314
315impl Default for ToolDispatcher {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
321impl gemini_genai_rs::prelude::ToolProvider for ToolDispatcher {
322    fn declarations(&self) -> Vec<gemini_genai_rs::prelude::Tool> {
323        self.to_tool_declarations()
324    }
325}
326
#[cfg(test)]
mod confirmation_tests {
    use super::*;
    use crate::confirmation::StaticConfirmation;
    use crate::tool::{policy::ToolPolicy, PolicyTool, SimpleTool};
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Build the "danger" tool: a call counter wrapped in a confirm policy.
    fn confirm_tool(runs: Arc<AtomicUsize>) -> Arc<dyn ToolFunction> {
        let inner: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "danger",
            "does something sensitive",
            None,
            move |_| {
                let runs = runs.clone();
                async move {
                    runs.fetch_add(1, Ordering::SeqCst);
                    Ok(json!({ "ok": true }))
                }
            },
        ));
        Arc::new(PolicyTool::new(
            inner,
            ToolPolicy::new().with_confirm(Some("delete production data?".into())),
        ))
    }

    #[tokio::test]
    async fn denied_confirmation_blocks_execution() {
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(StaticConfirmation::deny_all("blocked by policy"));

        let result = d.call_function("danger", json!({})).await;
        assert!(matches!(result, Err(ToolError::Cancelled)));
        assert_eq!(
            runs.load(Ordering::SeqCst),
            0,
            "tool must not run when denied"
        );
    }

    #[tokio::test]
    async fn approved_confirmation_runs() {
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(StaticConfirmation::allow_all());

        let out = d.call_function("danger", json!({})).await.unwrap();
        assert_eq!(out["ok"], true);
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn no_provider_runs_optin() {
        // Enforcement is opt-in: a confirm-gated tool runs when no provider is set.
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));

        let out = d.call_function("danger", json!({})).await.unwrap();
        assert_eq!(out["ok"], true);
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn non_gated_tools_skip_the_provider() {
        // A tool without a confirm policy is never sent to the (deny-all) provider.
        let mut d = ToolDispatcher::new();
        d.register(SimpleTool::new(
            "plain",
            "no confirmation",
            None,
            |_| async move { Ok(json!({ "ran": true })) },
        ));
        d.set_confirmation_provider(StaticConfirmation::deny_all("should not be consulted"));

        let out = d.call_function("plain", json!({})).await.unwrap();
        assert_eq!(out["ran"], true);
    }

    #[tokio::test]
    async fn provider_receives_request_details() {
        // The provider is consulted exactly once, with the gated tool's name
        // and the caller's arguments.
        let runs = Arc::new(AtomicUsize::new(0));
        let seen: Arc<std::sync::Mutex<Vec<(String, serde_json::Value)>>> = Arc::default();
        let recorder = Arc::clone(&seen);

        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(Arc::new(
            move |req: crate::confirmation::ConfirmationRequest| {
                recorder
                    .lock()
                    .unwrap()
                    .push((req.tool_name.clone(), req.args.clone()));
                async move { crate::confirmation::ToolConfirmation::confirmed() }
            },
        ));

        let out = d
            .call_function("danger", json!({ "target": "prod" }))
            .await
            .unwrap();
        assert_eq!(out["ok"], true);
        assert_eq!(runs.load(Ordering::SeqCst), 1);

        let seen = seen.lock().unwrap();
        assert_eq!(seen.len(), 1, "provider must be consulted exactly once");
        assert_eq!(seen[0].0, "danger");
        assert_eq!(seen[0].1, json!({ "target": "prod" }));
    }

    #[tokio::test]
    async fn nested_policy_wrapper_does_not_bypass_confirmation() {
        // T::cached(T::confirm(tool)): an outer cache PolicyTool (confirm=false)
        // wraps an inner confirm PolicyTool. The gate must still fire.
        let runs = Arc::new(AtomicUsize::new(0));
        let inner_confirm = confirm_tool(runs.clone()); // Arc<PolicyTool{confirm}>
        let outer_cached: Arc<dyn ToolFunction> = Arc::new(PolicyTool::new(
            inner_confirm,
            ToolPolicy::new().with_cache(),
        ));
        assert!(
            outer_cached.requires_confirmation(),
            "must propagate through nesting"
        );

        let mut d = ToolDispatcher::new();
        d.register_function(outer_cached);
        d.set_confirmation_provider(StaticConfirmation::deny_all("blocked"));

        let result = d.call_function("danger", json!({})).await;
        assert!(matches!(result, Err(ToolError::Cancelled)));
        assert_eq!(
            runs.load(Ordering::SeqCst),
            0,
            "nested confirm must not run when denied"
        );
    }

    #[tokio::test]
    async fn closure_provider_can_gate_by_name() {
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(Arc::new(
            |req: crate::confirmation::ConfirmationRequest| async move {
                if req.tool_name == "danger" {
                    crate::confirmation::ToolConfirmation::denied("name-gated")
                } else {
                    crate::confirmation::ToolConfirmation::confirmed()
                }
            },
        ));

        let result = d.call_function("danger", json!({})).await;
        assert!(matches!(result, Err(ToolError::Cancelled)));
        assert_eq!(runs.load(Ordering::SeqCst), 0);
    }
}