1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio_util::sync::CancellationToken;
8
9use gemini_genai_rs::prelude::{FunctionCall, FunctionDeclaration, FunctionResponse, Tool};
10
11use crate::error::ToolError;
12
13use super::{ActiveStreamingTool, ToolClass, ToolFunction, ToolKind, DEFAULT_TOOL_TIMEOUT};
14
15pub struct ToolDispatcher {
17 tools: HashMap<String, ToolKind>,
18 active: Arc<tokio::sync::Mutex<HashMap<String, ActiveStreamingTool>>>,
19 default_timeout: Duration,
20 cached_declarations: std::sync::OnceLock<Vec<Tool>>,
22 confirmation_provider: Option<Arc<dyn crate::confirmation::ConfirmationProvider>>,
24}
25
26impl ToolDispatcher {
27 pub fn new() -> Self {
42 Self {
43 tools: HashMap::new(),
44 active: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
45 default_timeout: DEFAULT_TOOL_TIMEOUT,
46 cached_declarations: std::sync::OnceLock::new(),
47 confirmation_provider: None,
48 }
49 }
50
51 pub fn with_timeout(mut self, timeout: Duration) -> Self {
53 self.default_timeout = timeout;
54 self
55 }
56
57 pub fn with_confirmation_provider(
66 mut self,
67 provider: Arc<dyn crate::confirmation::ConfirmationProvider>,
68 ) -> Self {
69 self.confirmation_provider = Some(provider);
70 self
71 }
72
73 pub fn set_confirmation_provider(
76 &mut self,
77 provider: Arc<dyn crate::confirmation::ConfirmationProvider>,
78 ) {
79 self.confirmation_provider = Some(provider);
80 }
81
82 pub fn has_confirmation_provider(&self) -> bool {
84 self.confirmation_provider.is_some()
85 }
86
87 async fn ensure_confirmed(
91 &self,
92 func: &Arc<dyn ToolFunction>,
93 args: &serde_json::Value,
94 ) -> Result<(), ToolError> {
95 if !func.requires_confirmation() {
96 return Ok(());
97 }
98 let Some(provider) = &self.confirmation_provider else {
99 return Ok(());
100 };
101 let request = crate::confirmation::ConfirmationRequest {
102 tool_name: func.name().to_string(),
103 args: args.clone(),
104 message: func.confirmation_message().map(str::to_string),
105 };
106 let decision = provider.confirm(request).await;
107 if decision.confirmed {
108 Ok(())
109 } else {
110 Err(ToolError::Cancelled)
111 }
112 }
113
114 pub fn default_timeout(&self) -> Duration {
116 self.default_timeout
117 }
118
119 pub fn register(&mut self, tool: impl ToolFunction) {
121 let tool = Arc::new(tool);
122 self.tools
123 .insert(tool.name().to_string(), ToolKind::Function(tool));
124 }
125
126 pub fn register_function(&mut self, tool: Arc<dyn ToolFunction>) {
128 self.tools
129 .insert(tool.name().to_string(), ToolKind::Function(tool));
130 }
131
132 pub fn register_streaming(&mut self, tool: Arc<dyn super::StreamingTool>) {
134 self.tools
135 .insert(tool.name().to_string(), ToolKind::Streaming(tool));
136 }
137
138 pub fn register_input_streaming(&mut self, tool: Arc<dyn super::InputStreamingTool>) {
140 self.tools
141 .insert(tool.name().to_string(), ToolKind::InputStream(tool));
142 }
143
144 pub fn get_tool(&self, name: &str) -> Option<&ToolKind> {
146 self.tools.get(name)
147 }
148
149 pub fn classify(&self, name: &str) -> Option<ToolClass> {
151 self.tools.get(name).map(|t| match t {
152 ToolKind::Function(_) => ToolClass::Regular,
153 ToolKind::Streaming(_) => ToolClass::Streaming,
154 ToolKind::InputStream(_) => ToolClass::InputStream,
155 })
156 }
157
158 pub async fn call_function(
160 &self,
161 name: &str,
162 args: serde_json::Value,
163 ) -> Result<serde_json::Value, ToolError> {
164 self.call_function_with_timeout(name, args, self.default_timeout)
165 .await
166 }
167
168 pub async fn call_function_with_timeout(
173 &self,
174 name: &str,
175 args: serde_json::Value,
176 timeout: Duration,
177 ) -> Result<serde_json::Value, ToolError> {
178 let func = match self.tools.get(name) {
179 Some(ToolKind::Function(f)) => f.clone(),
180 Some(_) => {
181 return Err(ToolError::Other(format!(
182 "{name} is not a regular function tool"
183 )))
184 }
185 None => return Err(ToolError::NotFound(name.to_string())),
186 };
187
188 self.ensure_confirmed(&func, &args).await?;
189
190 match tokio::time::timeout(timeout, func.call(args)).await {
191 Ok(result) => result,
192 Err(_elapsed) => Err(ToolError::Timeout(timeout)),
193 }
194 }
195
196 pub async fn call_function_with_cancel(
201 &self,
202 name: &str,
203 args: serde_json::Value,
204 cancel: CancellationToken,
205 ) -> Result<serde_json::Value, ToolError> {
206 let func = match self.tools.get(name) {
207 Some(ToolKind::Function(f)) => f.clone(),
208 Some(_) => {
209 return Err(ToolError::Other(format!(
210 "{name} is not a regular function tool"
211 )))
212 }
213 None => return Err(ToolError::NotFound(name.to_string())),
214 };
215
216 self.ensure_confirmed(&func, &args).await?;
217
218 tokio::select! {
219 result = func.call(args) => result,
220 _ = cancel.cancelled() => Err(ToolError::Cancelled),
221 }
222 }
223
224 pub fn build_response(
226 call: &FunctionCall,
227 result: Result<serde_json::Value, ToolError>,
228 ) -> FunctionResponse {
229 match result {
230 Ok(value) => FunctionResponse {
231 name: call.name.clone(),
232 response: value,
233 id: call.id.clone(),
234 scheduling: None,
235 },
236 Err(e) => FunctionResponse {
237 name: call.name.clone(),
238 response: serde_json::json!({"error": e.to_string()}),
239 id: call.id.clone(),
240 scheduling: None,
241 },
242 }
243 }
244
245 pub async fn cancel_streaming(&self, name: &str) {
247 let mut active = self.active.lock().await;
248 if let Some(tool) = active.remove(name) {
249 tool.cancel.cancel();
250 tool.task.abort();
251 }
252 }
253
254 pub(crate) async fn store_active(&self, id: String, tool: ActiveStreamingTool) {
256 self.active.lock().await.insert(id, tool);
257 }
258
259 pub async fn cancel_by_ids(&self, ids: &[String]) {
261 let mut active = self.active.lock().await;
262 for id in ids {
263 if let Some(tool) = active.remove(id.as_str()) {
264 tool.cancel.cancel();
265 tool.task.abort();
266 }
267 }
268 }
269
270 pub fn to_tool_declarations(&self) -> Vec<Tool> {
275 self.cached_declarations
276 .get_or_init(|| {
277 let declarations: Vec<FunctionDeclaration> = self
278 .tools
279 .values()
280 .map(|t| {
281 let (name, desc, params) = match t {
282 ToolKind::Function(f) => (f.name(), f.description(), f.parameters()),
283 ToolKind::Streaming(s) => (s.name(), s.description(), s.parameters()),
284 ToolKind::InputStream(i) => (i.name(), i.description(), i.parameters()),
285 };
286 FunctionDeclaration {
287 name: name.to_string(),
288 description: desc.to_string(),
289 parameters: params,
290 behavior: None,
291 }
292 })
293 .collect();
294
295 if declarations.is_empty() {
296 vec![]
297 } else {
298 vec![Tool::functions(declarations)]
299 }
300 })
301 .clone()
302 }
303
304 pub fn len(&self) -> usize {
306 self.tools.len()
307 }
308
309 pub fn is_empty(&self) -> bool {
311 self.tools.is_empty()
312 }
313}
314
315impl Default for ToolDispatcher {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
321impl gemini_genai_rs::prelude::ToolProvider for ToolDispatcher {
322 fn declarations(&self) -> Vec<gemini_genai_rs::prelude::Tool> {
323 self.to_tool_declarations()
324 }
325}
326
#[cfg(test)]
mod confirmation_tests {
    //! Confirmation gating on the regular function-call path.

    use super::*;
    use crate::confirmation::StaticConfirmation;
    use crate::tool::{policy::ToolPolicy, PolicyTool, SimpleTool};
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a tool named "danger" that requires confirmation and increments
    /// `runs` every time its body actually executes.
    fn confirm_tool(runs: Arc<AtomicUsize>) -> Arc<dyn ToolFunction> {
        let inner: Arc<dyn ToolFunction> = Arc::new(SimpleTool::new(
            "danger",
            "does something sensitive",
            None,
            move |_| {
                let runs = runs.clone();
                async move {
                    runs.fetch_add(1, Ordering::SeqCst);
                    Ok(json!({ "ok": true }))
                }
            },
        ));
        Arc::new(PolicyTool::new(
            inner,
            ToolPolicy::new().with_confirm(Some("delete production data?".into())),
        ))
    }

    #[tokio::test]
    async fn denied_confirmation_blocks_execution() {
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(StaticConfirmation::deny_all("blocked by policy"));

        let result = d.call_function("danger", json!({})).await;
        assert!(matches!(result, Err(ToolError::Cancelled)));
        assert_eq!(
            runs.load(Ordering::SeqCst),
            0,
            "tool must not run when denied"
        );
    }

    #[tokio::test]
    async fn approved_confirmation_runs() {
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(StaticConfirmation::allow_all());

        let out = d.call_function("danger", json!({})).await.unwrap();
        assert_eq!(out["ok"], true);
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }

    /// Confirmation is opt-in: without a provider, gated tools still run.
    #[tokio::test]
    async fn gated_tool_runs_without_provider() {
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));

        let out = d.call_function("danger", json!({})).await.unwrap();
        assert_eq!(out["ok"], true);
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }

    /// A deny-all provider must not affect tools that don't require confirmation.
    #[tokio::test]
    async fn provider_ignores_non_gated_tools() {
        let mut d = ToolDispatcher::new();
        d.register(SimpleTool::new(
            "plain",
            "no confirmation",
            None,
            |_| async move { Ok(json!({ "ran": true })) },
        ));
        d.set_confirmation_provider(StaticConfirmation::deny_all("should not be consulted"));

        let out = d.call_function("plain", json!({})).await.unwrap();
        assert_eq!(out["ran"], true);
    }

    /// The provider receives the tool's name and the exact call arguments.
    #[tokio::test]
    async fn provider_receives_tool_name_and_args() {
        let runs = Arc::new(AtomicUsize::new(0));
        let seen: Arc<std::sync::Mutex<Vec<(String, serde_json::Value)>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);

        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(Arc::new(
            move |req: crate::confirmation::ConfirmationRequest| {
                recorder
                    .lock()
                    .unwrap()
                    .push((req.tool_name.clone(), req.args.clone()));
                async move { crate::confirmation::ToolConfirmation::confirmed() }
            },
        ));

        d.call_function("danger", json!({ "target": "prod" }))
            .await
            .unwrap();

        let recorded = seen.lock().unwrap();
        assert_eq!(recorded.len(), 1, "provider consulted exactly once");
        assert_eq!(recorded[0].0, "danger");
        assert_eq!(recorded[0].1, json!({ "target": "prod" }));
        assert_eq!(runs.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn nested_policy_wrapper_does_not_bypass_confirmation() {
        let runs = Arc::new(AtomicUsize::new(0));
        let inner_confirm = confirm_tool(runs.clone());
        let outer_cached: Arc<dyn ToolFunction> = Arc::new(PolicyTool::new(
            inner_confirm,
            ToolPolicy::new().with_cache(),
        ));
        assert!(
            outer_cached.requires_confirmation(),
            "must propagate through nesting"
        );

        let mut d = ToolDispatcher::new();
        d.register_function(outer_cached);
        d.set_confirmation_provider(StaticConfirmation::deny_all("blocked"));

        let result = d.call_function("danger", json!({})).await;
        assert!(matches!(result, Err(ToolError::Cancelled)));
        assert_eq!(
            runs.load(Ordering::SeqCst),
            0,
            "nested confirm must not run when denied"
        );
    }

    #[tokio::test]
    async fn closure_provider_can_gate_by_name() {
        let runs = Arc::new(AtomicUsize::new(0));
        let mut d = ToolDispatcher::new();
        d.register_function(confirm_tool(runs.clone()));
        d.set_confirmation_provider(Arc::new(
            |req: crate::confirmation::ConfirmationRequest| async move {
                if req.tool_name == "danger" {
                    crate::confirmation::ToolConfirmation::denied("name-gated")
                } else {
                    crate::confirmation::ToolConfirmation::confirmed()
                }
            },
        ));

        let result = d.call_function("danger", json!({})).await;
        assert!(matches!(result, Err(ToolError::Cancelled)));
        assert_eq!(runs.load(Ordering::SeqCst), 0);
    }
}