gemini_adk_rs/tool/
dispatcher.rs

1//! Tool dispatcher — routes function calls to the right tool implementation.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio_util::sync::CancellationToken;
8
9use gemini_genai_rs::prelude::{FunctionCall, FunctionDeclaration, FunctionResponse, Tool};
10
11use crate::error::ToolError;
12
13use super::{ActiveStreamingTool, ToolClass, ToolFunction, ToolKind, DEFAULT_TOOL_TIMEOUT};
14
15/// Routes function calls to the right tool implementation.
16pub struct ToolDispatcher {
17    tools: HashMap<String, ToolKind>,
18    active: Arc<tokio::sync::Mutex<HashMap<String, ActiveStreamingTool>>>,
19    default_timeout: Duration,
20    /// Cached tool declarations — computed once on first access.
21    cached_declarations: std::sync::OnceLock<Vec<Tool>>,
22}
23
24impl ToolDispatcher {
25    /// Create a new empty tool dispatcher with the default 30-second timeout.
26    ///
27    /// # Examples
28    ///
29    /// ```rust,ignore
30    /// use gemini_adk_rs::tool::{ToolDispatcher, SimpleTool};
31    /// use serde_json::json;
32    ///
33    /// let mut dispatcher = ToolDispatcher::new();
34    /// dispatcher.register(SimpleTool::new(
35    ///     "echo", "Echo input", None,
36    ///     |args| async move { Ok(args) },
37    /// ));
38    /// ```
39    pub fn new() -> Self {
40        Self {
41            tools: HashMap::new(),
42            active: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
43            default_timeout: DEFAULT_TOOL_TIMEOUT,
44            cached_declarations: std::sync::OnceLock::new(),
45        }
46    }
47
48    /// Set the default timeout for tool calls.
49    pub fn with_timeout(mut self, timeout: Duration) -> Self {
50        self.default_timeout = timeout;
51        self
52    }
53
54    /// Returns the configured default timeout.
55    pub fn default_timeout(&self) -> Duration {
56        self.default_timeout
57    }
58
59    /// Register a tool that implements [`ToolFunction`].
60    pub fn register(&mut self, tool: impl ToolFunction) {
61        let tool = Arc::new(tool);
62        self.tools
63            .insert(tool.name().to_string(), ToolKind::Function(tool));
64    }
65
66    /// Register a regular function tool (pre-wrapped in Arc).
67    pub fn register_function(&mut self, tool: Arc<dyn ToolFunction>) {
68        self.tools
69            .insert(tool.name().to_string(), ToolKind::Function(tool));
70    }
71
72    /// Register a streaming tool.
73    pub fn register_streaming(&mut self, tool: Arc<dyn super::StreamingTool>) {
74        self.tools
75            .insert(tool.name().to_string(), ToolKind::Streaming(tool));
76    }
77
78    /// Register an input-streaming tool.
79    pub fn register_input_streaming(&mut self, tool: Arc<dyn super::InputStreamingTool>) {
80        self.tools
81            .insert(tool.name().to_string(), ToolKind::InputStream(tool));
82    }
83
84    /// Get a tool by name (for introspection/streaming tool spawning).
85    pub fn get_tool(&self, name: &str) -> Option<&ToolKind> {
86        self.tools.get(name)
87    }
88
89    /// Classify a tool by name.
90    pub fn classify(&self, name: &str) -> Option<ToolClass> {
91        self.tools.get(name).map(|t| match t {
92            ToolKind::Function(_) => ToolClass::Regular,
93            ToolKind::Streaming(_) => ToolClass::Streaming,
94            ToolKind::InputStream(_) => ToolClass::InputStream,
95        })
96    }
97
98    /// Call a regular function tool by name, using the default timeout.
99    pub async fn call_function(
100        &self,
101        name: &str,
102        args: serde_json::Value,
103    ) -> Result<serde_json::Value, ToolError> {
104        self.call_function_with_timeout(name, args, self.default_timeout)
105            .await
106    }
107
108    /// Call a regular function tool by name with an explicit timeout.
109    ///
110    /// If the tool does not complete within the given duration, its future is
111    /// dropped (cancelling it) and `ToolError::Timeout` is returned.
112    pub async fn call_function_with_timeout(
113        &self,
114        name: &str,
115        args: serde_json::Value,
116        timeout: Duration,
117    ) -> Result<serde_json::Value, ToolError> {
118        let func = match self.tools.get(name) {
119            Some(ToolKind::Function(f)) => f.clone(),
120            Some(_) => {
121                return Err(ToolError::Other(format!(
122                    "{name} is not a regular function tool"
123                )))
124            }
125            None => return Err(ToolError::NotFound(name.to_string())),
126        };
127
128        match tokio::time::timeout(timeout, func.call(args)).await {
129            Ok(result) => result,
130            Err(_elapsed) => Err(ToolError::Timeout(timeout)),
131        }
132    }
133
134    /// Call a regular function tool by name, racing against a cancellation token.
135    ///
136    /// If the token is cancelled before the tool completes, its future is
137    /// dropped and `ToolError::Cancelled` is returned.
138    pub async fn call_function_with_cancel(
139        &self,
140        name: &str,
141        args: serde_json::Value,
142        cancel: CancellationToken,
143    ) -> Result<serde_json::Value, ToolError> {
144        let func = match self.tools.get(name) {
145            Some(ToolKind::Function(f)) => f.clone(),
146            Some(_) => {
147                return Err(ToolError::Other(format!(
148                    "{name} is not a regular function tool"
149                )))
150            }
151            None => return Err(ToolError::NotFound(name.to_string())),
152        };
153
154        tokio::select! {
155            result = func.call(args) => result,
156            _ = cancel.cancelled() => Err(ToolError::Cancelled),
157        }
158    }
159
160    /// Build a FunctionResponse from a FunctionCall result.
161    pub fn build_response(
162        call: &FunctionCall,
163        result: Result<serde_json::Value, ToolError>,
164    ) -> FunctionResponse {
165        match result {
166            Ok(value) => FunctionResponse {
167                name: call.name.clone(),
168                response: value,
169                id: call.id.clone(),
170                scheduling: None,
171            },
172            Err(e) => FunctionResponse {
173                name: call.name.clone(),
174                response: serde_json::json!({"error": e.to_string()}),
175                id: call.id.clone(),
176                scheduling: None,
177            },
178        }
179    }
180
181    /// Cancel a streaming tool by name.
182    pub async fn cancel_streaming(&self, name: &str) {
183        let mut active = self.active.lock().await;
184        if let Some(tool) = active.remove(name) {
185            tool.cancel.cancel();
186            tool.task.abort();
187        }
188    }
189
190    /// Store an active streaming tool (for cancellation tracking).
191    pub(crate) async fn store_active(&self, id: String, tool: ActiveStreamingTool) {
192        self.active.lock().await.insert(id, tool);
193    }
194
195    /// Cancel streaming tools by IDs.
196    pub async fn cancel_by_ids(&self, ids: &[String]) {
197        let mut active = self.active.lock().await;
198        for id in ids {
199            if let Some(tool) = active.remove(id.as_str()) {
200                tool.cancel.cancel();
201                tool.task.abort();
202            }
203        }
204    }
205
206    /// Generate Tool declarations for the setup message.
207    ///
208    /// Results are cached after first computation. The cache is invalidated
209    /// when tools are registered via `register*()` methods.
210    pub fn to_tool_declarations(&self) -> Vec<Tool> {
211        self.cached_declarations
212            .get_or_init(|| {
213                let declarations: Vec<FunctionDeclaration> = self
214                    .tools
215                    .values()
216                    .map(|t| {
217                        let (name, desc, params) = match t {
218                            ToolKind::Function(f) => (f.name(), f.description(), f.parameters()),
219                            ToolKind::Streaming(s) => (s.name(), s.description(), s.parameters()),
220                            ToolKind::InputStream(i) => (i.name(), i.description(), i.parameters()),
221                        };
222                        FunctionDeclaration {
223                            name: name.to_string(),
224                            description: desc.to_string(),
225                            parameters: params,
226                            behavior: None,
227                        }
228                    })
229                    .collect();
230
231                if declarations.is_empty() {
232                    vec![]
233                } else {
234                    vec![Tool::functions(declarations)]
235                }
236            })
237            .clone()
238    }
239
240    /// Number of registered tools.
241    pub fn len(&self) -> usize {
242        self.tools.len()
243    }
244
245    /// Whether no tools are registered.
246    pub fn is_empty(&self) -> bool {
247        self.tools.is_empty()
248    }
249}
250
251impl Default for ToolDispatcher {
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
257impl gemini_genai_rs::prelude::ToolProvider for ToolDispatcher {
258    fn declarations(&self) -> Vec<gemini_genai_rs::prelude::Tool> {
259        self.to_tool_declarations()
260    }
261}