// gemini_adk_rs/tool/mod.rs

//! Tool dispatch — regular, streaming, and input-streaming tools.

3pub mod dispatcher;
4pub mod policy;
5pub mod simple;
6pub mod typed;
7
8pub use dispatcher::*;
9pub use policy::*;
10pub use simple::*;
11pub use typed::*;
12
13use std::sync::Arc;
14use std::time::Duration;
15
16use async_trait::async_trait;
17use tokio::sync::{broadcast, mpsc};
18use tokio::task::JoinHandle;
19use tokio_util::sync::CancellationToken;
20
21use crate::agent_session::InputEvent;
22use crate::error::ToolError;
23
24/// A regular tool — called once, returns a result.
25///
26/// # Examples
27///
28/// ```rust,ignore
29/// use async_trait::async_trait;
30/// use gemini_adk_rs::tool::ToolFunction;
31/// use gemini_adk_rs::error::ToolError;
32///
33/// struct MyTool;
34///
35/// #[async_trait]
36/// impl ToolFunction for MyTool {
37///     fn name(&self) -> &str { "my_tool" }
38///     fn description(&self) -> &str { "Does something useful" }
39///     fn parameters(&self) -> Option<serde_json::Value> { None }
40///     async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
41///         Ok(serde_json::json!({"status": "ok"}))
42///     }
43/// }
44/// ```
45#[async_trait]
46pub trait ToolFunction: Send + Sync + 'static {
47    /// The unique name of this tool.
48    fn name(&self) -> &str;
49    /// Human-readable description of what this tool does.
50    fn description(&self) -> &str;
51    /// JSON Schema for the tool's input parameters, or `None` if parameterless.
52    fn parameters(&self) -> Option<serde_json::Value>;
53    /// Execute the tool with the given arguments and return the result.
54    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError>;
55
56    /// Whether this tool must be confirmed before it runs. Defaults to `false`.
57    ///
58    /// Tools built with `T::confirm(..)` (a [`PolicyTool`] with a confirm
59    /// policy) return `true`, prompting the [`ToolDispatcher`] to consult its
60    /// configured `ConfirmationProvider` before executing.
61    fn requires_confirmation(&self) -> bool {
62        false
63    }
64
65    /// Optional hint shown when confirmation is requested.
66    fn confirmation_message(&self) -> Option<&str> {
67        None
68    }
69}
70
71/// A streaming tool — runs in background, yields multiple results.
72#[async_trait]
73pub trait StreamingTool: Send + Sync + 'static {
74    /// The unique name of this tool.
75    fn name(&self) -> &str;
76    /// Human-readable description of what this tool does.
77    fn description(&self) -> &str;
78    /// JSON Schema for the tool's input parameters, or `None` if parameterless.
79    fn parameters(&self) -> Option<serde_json::Value>;
80    /// Execute the tool, sending intermediate results via `yield_tx`.
81    async fn run(
82        &self,
83        args: serde_json::Value,
84        yield_tx: mpsc::Sender<serde_json::Value>,
85    ) -> Result<(), ToolError>;
86}
87
88/// An input-streaming tool — receives duplicated live input while running.
89#[async_trait]
90pub trait InputStreamingTool: Send + Sync + 'static {
91    /// The unique name of this tool.
92    fn name(&self) -> &str;
93    /// Human-readable description of what this tool does.
94    fn description(&self) -> &str;
95    /// JSON Schema for the tool's input parameters, or `None` if parameterless.
96    fn parameters(&self) -> Option<serde_json::Value>;
97    /// Execute the tool, receiving live input via `input_rx` and sending results via `yield_tx`.
98    async fn run(
99        &self,
100        args: serde_json::Value,
101        input_rx: broadcast::Receiver<InputEvent>,
102        yield_tx: mpsc::Sender<serde_json::Value>,
103    ) -> Result<(), ToolError>;
104}
105
/// Classification of a registered tool.
///
/// One variant per [`ToolKind`] variant (`Function` maps to `Regular`). It is
/// a cheap `Copy` tag, and it derives `Hash` so it can be used as a key in
/// hashed collections (e.g. grouping tools by class).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolClass {
    /// A one-shot tool that returns a single result.
    Regular,
    /// A tool that yields multiple results over time.
    Streaming,
    /// A tool that receives live input while producing results.
    InputStream,
}

117/// Unified tool storage.
118pub enum ToolKind {
119    /// A regular one-shot function tool.
120    Function(Arc<dyn ToolFunction>),
121    /// A streaming tool that yields multiple results.
122    Streaming(Arc<dyn StreamingTool>),
123    /// An input-streaming tool that receives live input.
124    InputStream(Arc<dyn InputStreamingTool>),
125}
126
127/// Handle to a running streaming tool.
128pub struct ActiveStreamingTool {
129    /// The spawned task handle.
130    pub task: JoinHandle<()>,
131    /// Token to cancel this streaming tool.
132    pub cancel: CancellationToken,
133}
134
/// Timeout applied to tool execution when no other value is configured: 30 seconds.
pub(crate) const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_millis(30_000);

#[cfg(test)]
mod tests {
    use super::*;
    use gemini_genai_rs::prelude::FunctionCall;
    use serde_json::json;

    /// Minimal regular tool that always succeeds with `{"result": "ok"}`.
    struct MockTool;

    #[async_trait]
    impl ToolFunction for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }
        fn description(&self) -> &str {
            "A mock tool"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(json!({"result": "ok"}))
        }
    }

    #[tokio::test]
    async fn register_and_call_function_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let result = dispatcher
            .call_function("mock_tool", json!({}))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    #[tokio::test]
    async fn call_unknown_tool_returns_error() {
        let dispatcher = ToolDispatcher::new();
        let result = dispatcher.call_function("nonexistent", json!({})).await;
        assert!(result.is_err());
    }

    #[test]
    fn to_tool_declarations() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    #[test]
    fn classify_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        assert_eq!(dispatcher.classify("mock_tool"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.classify("nonexistent"), None);
    }

    #[test]
    fn empty_dispatcher() {
        let dispatcher = ToolDispatcher::new();
        assert!(dispatcher.is_empty());
        assert_eq!(dispatcher.len(), 0);
        assert!(dispatcher.to_tool_declarations().is_empty());
    }

    #[test]
    fn build_response_success() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(&call, Ok(json!({"ok": true})));
        assert_eq!(resp.name, "test");
        assert_eq!(resp.response["ok"], true);
    }

    #[test]
    fn build_response_error() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(
            &call,
            Err(ToolError::ExecutionFailed("boom".to_string())),
        );
        let message = resp.response["error"]
            .as_str()
            .expect("error responses should carry an `error` string");
        assert!(message.contains("boom"));
    }

    #[test]
    fn tool_dispatcher_implements_tool_provider() {
        use gemini_genai_rs::prelude::ToolProvider;
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.declarations();
        assert_eq!(decls.len(), 1);
    }

    #[tokio::test]
    async fn simple_tool_closure() {
        let tool = SimpleTool::new(
            "add",
            "Add two numbers",
            Some(
                json!({"type": "object", "properties": {"a": {"type": "number"}, "b": {"type": "number"}}}),
            ),
            |args| async move {
                let a = args["a"].as_f64().unwrap_or(0.0);
                let b = args["b"].as_f64().unwrap_or(0.0);
                Ok(json!({"sum": a + b}))
            },
        );

        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));
        let result = dispatcher
            .call_function("add", json!({"a": 3, "b": 4}))
            .await
            .unwrap();
        assert_eq!(result["sum"], 7.0);
    }

    // --- TypedTool tests ---

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct WeatherArgs {
        /// The city to get weather for
        city: String,
        /// Temperature units (celsius or fahrenheit)
        #[serde(default = "default_units")]
        units: String,
    }

    fn default_units() -> String {
        "celsius".to_string()
    }

    #[test]
    fn typed_tool_auto_generates_schema() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );

        let params = tool.parameters().expect("should have parameters");

        // The schema should be an object type with "city" and "units" properties.
        let props = &params["properties"];
        assert!(
            props.get("city").is_some(),
            "schema should contain 'city' property"
        );
        assert!(
            props.get("units").is_some(),
            "schema should contain 'units' property"
        );

        // "city" has no default, so it must be required; "units" has a serde
        // default, so it may be omitted from `required`.
        let required = params["required"]
            .as_array()
            .expect("should have required array");
        let required_names: Vec<&str> = required.iter().filter_map(|v| v.as_str()).collect();
        assert!(required_names.contains(&"city"), "city should be required");
    }

    #[tokio::test]
    async fn typed_tool_deserializes_args() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move {
                Ok(json!({
                    "temp": 22,
                    "city": args.city,
                    "units": args.units,
                }))
            },
        );

        let result = tool
            .call(json!({"city": "London", "units": "fahrenheit"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "London");
        assert_eq!(result["units"], "fahrenheit");
        assert_eq!(result["temp"], 22);
    }

    #[tokio::test]
    async fn typed_tool_invalid_args_returns_error() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );

        // Missing required field "city"
        let result = tool.call(json!({"units": "celsius"})).await;
        assert!(result.is_err(), "should fail with missing required field");
        let err = result.unwrap_err();
        match &err {
            ToolError::InvalidArgs(msg) => {
                assert!(
                    msg.contains("city"),
                    "error message should mention the missing field: {msg}"
                );
            }
            other => panic!("expected ToolError::InvalidArgs, got: {other:?}"),
        }

        // Wrong type for "city" (number instead of string) must be reported the
        // same way as a missing field: as invalid arguments.
        let result = tool.call(json!({"city": 12345})).await;
        assert!(
            matches!(result, Err(ToolError::InvalidArgs(_))),
            "wrong argument type should yield ToolError::InvalidArgs, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn typed_tool_registers_in_dispatcher() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move { Ok(json!({"city": args.city})) },
        );

        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));

        assert_eq!(dispatcher.classify("get_weather"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.len(), 1);

        let result = dispatcher
            .call_function("get_weather", json!({"city": "Paris"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "Paris");

        // Verify it appears in tool declarations
        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    // --- Timeout and cancellation tests ---

    /// A tool that sleeps for an hour, i.e. until it is cancelled or timed out.
    struct SlowTool;

    #[async_trait]
    impl ToolFunction for SlowTool {
        fn name(&self) -> &str {
            "slow_tool"
        }
        fn description(&self) -> &str {
            "A tool that never completes"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            // Sleep effectively forever
            tokio::time::sleep(Duration::from_secs(3600)).await;
            Ok(json!({"result": "should never reach here"}))
        }
    }

    #[tokio::test]
    async fn tool_timeout_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));

        let timeout = Duration::from_millis(50);
        let result = dispatcher
            .call_function_with_timeout("slow_tool", json!({}), timeout)
            .await;

        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, timeout),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn tool_completes_before_timeout() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));

        let result = dispatcher
            .call_function_with_timeout("mock_tool", json!({}), Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    #[tokio::test]
    async fn tool_cancelled_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));

        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();

        // Cancel after a short delay
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });

        let result = dispatcher
            .call_function_with_cancel("slow_tool", json!({}), cancel)
            .await;

        match result {
            Err(ToolError::Cancelled) => {} // expected
            other => panic!("expected ToolError::Cancelled, got: {other:?}"),
        }
    }

    #[test]
    fn default_timeout_is_30s() {
        // Keep the shared constant and the dispatcher default in lock-step.
        assert_eq!(DEFAULT_TOOL_TIMEOUT, Duration::from_secs(30));
        let dispatcher = ToolDispatcher::new();
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn with_timeout_overrides_default() {
        let dispatcher = ToolDispatcher::new().with_timeout(Duration::from_secs(10));
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(10));
    }

    #[tokio::test]
    async fn call_function_uses_default_timeout() {
        // Set a very short default timeout so the slow tool times out
        let mut dispatcher = ToolDispatcher::new().with_timeout(Duration::from_millis(50));
        dispatcher.register_function(Arc::new(SlowTool));

        let result = dispatcher.call_function("slow_tool", json!({})).await;

        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }
}