gemini_adk_rs/tool/
mod.rs

1//! Tool dispatch — regular, streaming, and input-streaming tools.
2
3pub mod dispatcher;
4pub mod simple;
5pub mod typed;
6
7pub use dispatcher::*;
8pub use simple::*;
9pub use typed::*;
10
11use std::sync::Arc;
12use std::time::Duration;
13
14use async_trait::async_trait;
15use tokio::sync::{broadcast, mpsc};
16use tokio::task::JoinHandle;
17use tokio_util::sync::CancellationToken;
18
19use crate::agent_session::InputEvent;
20use crate::error::ToolError;
21
/// A regular tool — called once, returns a result.
///
/// Implementors must be `Send + Sync + 'static`. Within this module they are
/// stored type-erased as `Arc<dyn ToolFunction>` (see [`ToolKind::Function`]).
///
/// # Examples
///
/// ```rust,ignore
/// use async_trait::async_trait;
/// use gemini_adk_rs::tool::ToolFunction;
/// use gemini_adk_rs::error::ToolError;
///
/// struct MyTool;
///
/// #[async_trait]
/// impl ToolFunction for MyTool {
///     fn name(&self) -> &str { "my_tool" }
///     fn description(&self) -> &str { "Does something useful" }
///     fn parameters(&self) -> Option<serde_json::Value> { None }
///     async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
///         Ok(serde_json::json!({"status": "ok"}))
///     }
/// }
/// ```
#[async_trait]
pub trait ToolFunction: Send + Sync + 'static {
    /// The unique name of this tool.
    fn name(&self) -> &str;
    /// Human-readable description of what this tool does.
    fn description(&self) -> &str;
    /// JSON Schema for the tool's input parameters, or `None` if parameterless.
    fn parameters(&self) -> Option<serde_json::Value>;
    /// Execute the tool with the given arguments and return the result.
    ///
    /// `args` is the raw JSON argument value; the `Ok` value is the tool's JSON
    /// result and failures are reported as a [`ToolError`].
    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError>;
}
54
/// A streaming tool — runs in background, yields multiple results.
///
/// Unlike [`ToolFunction`], results are not returned from the call itself:
/// each intermediate JSON value is pushed through the `yield_tx` channel
/// handed to [`StreamingTool::run`].
#[async_trait]
pub trait StreamingTool: Send + Sync + 'static {
    /// The unique name of this tool.
    fn name(&self) -> &str;
    /// Human-readable description of what this tool does.
    fn description(&self) -> &str;
    /// JSON Schema for the tool's input parameters, or `None` if parameterless.
    fn parameters(&self) -> Option<serde_json::Value>;
    /// Execute the tool, sending intermediate results via `yield_tx`.
    ///
    /// Returns `Ok(())` when the tool finishes, or a [`ToolError`] on failure.
    async fn run(
        &self,
        args: serde_json::Value,
        yield_tx: mpsc::Sender<serde_json::Value>,
    ) -> Result<(), ToolError>;
}
71
/// An input-streaming tool — receives duplicated live input while running.
///
/// Same shape as [`StreamingTool`], with one extra argument to `run`: a
/// `broadcast::Receiver<InputEvent>` from which the tool reads live input.
#[async_trait]
pub trait InputStreamingTool: Send + Sync + 'static {
    /// The unique name of this tool.
    fn name(&self) -> &str;
    /// Human-readable description of what this tool does.
    fn description(&self) -> &str;
    /// JSON Schema for the tool's input parameters, or `None` if parameterless.
    fn parameters(&self) -> Option<serde_json::Value>;
    /// Execute the tool, receiving live input via `input_rx` and sending results via `yield_tx`.
    ///
    /// Returns `Ok(())` when the tool finishes, or a [`ToolError`] on failure.
    async fn run(
        &self,
        args: serde_json::Value,
        input_rx: broadcast::Receiver<InputEvent>,
        yield_tx: mpsc::Sender<serde_json::Value>,
    ) -> Result<(), ToolError>;
}
89
/// Classification of a registered tool.
///
/// A plain `Copy` tag with one variant per tool trait declared in this module
/// ([`ToolFunction`], [`StreamingTool`], [`InputStreamingTool`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolClass {
    /// A one-shot tool that returns a single result.
    Regular,
    /// A tool that yields multiple results over time.
    Streaming,
    /// A tool that receives live input while producing results.
    InputStream,
}
100
/// Unified tool storage.
///
/// One variant per tool trait; each variant holds the implementation as a
/// reference-counted trait object (`Arc<dyn …>`).
pub enum ToolKind {
    /// A regular one-shot function tool.
    Function(Arc<dyn ToolFunction>),
    /// A streaming tool that yields multiple results.
    Streaming(Arc<dyn StreamingTool>),
    /// An input-streaming tool that receives live input.
    InputStream(Arc<dyn InputStreamingTool>),
}
110
/// Handle to a running streaming tool.
///
/// Pairs the spawned task with the token used to cancel it.
// NOTE(review): who creates and stores these handles is not visible in this
// file — presumably the dispatcher; confirm against `dispatcher.rs`.
pub struct ActiveStreamingTool {
    /// The spawned task handle.
    pub task: JoinHandle<()>,
    /// Token to cancel this streaming tool.
    pub cancel: CancellationToken,
}
118
/// Default timeout for tool execution (30 seconds).
// NOTE(review): presumably the initial value behind `ToolDispatcher::default_timeout()`
// (the tests in this file expect 30 s from a fresh dispatcher) — confirm in `dispatcher.rs`.
pub(crate) const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(30);
121
122#[cfg(test)]
123mod tests {
124    use super::*;
125    use gemini_genai_rs::prelude::FunctionCall;
126    use serde_json::json;
127
128    struct MockTool;
129
130    #[async_trait]
131    impl ToolFunction for MockTool {
132        fn name(&self) -> &str {
133            "mock_tool"
134        }
135        fn description(&self) -> &str {
136            "A mock tool"
137        }
138        fn parameters(&self) -> Option<serde_json::Value> {
139            None
140        }
141        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
142            Ok(json!({"result": "ok"}))
143        }
144    }
145
146    #[tokio::test]
147    async fn register_and_call_function_tool() {
148        let mut dispatcher = ToolDispatcher::new();
149        dispatcher.register_function(Arc::new(MockTool));
150        let result = dispatcher
151            .call_function("mock_tool", json!({}))
152            .await
153            .unwrap();
154        assert_eq!(result["result"], "ok");
155    }
156
157    #[tokio::test]
158    async fn call_unknown_tool_returns_error() {
159        let dispatcher = ToolDispatcher::new();
160        let result = dispatcher.call_function("nonexistent", json!({})).await;
161        assert!(result.is_err());
162    }
163
164    #[test]
165    fn to_tool_declarations() {
166        let mut dispatcher = ToolDispatcher::new();
167        dispatcher.register_function(Arc::new(MockTool));
168        let decls = dispatcher.to_tool_declarations();
169        assert_eq!(decls.len(), 1);
170    }
171
172    #[test]
173    fn classify_tool() {
174        let mut dispatcher = ToolDispatcher::new();
175        dispatcher.register_function(Arc::new(MockTool));
176        assert_eq!(dispatcher.classify("mock_tool"), Some(ToolClass::Regular));
177        assert_eq!(dispatcher.classify("nonexistent"), None);
178    }
179
180    #[test]
181    fn empty_dispatcher() {
182        let dispatcher = ToolDispatcher::new();
183        assert!(dispatcher.is_empty());
184        assert_eq!(dispatcher.len(), 0);
185        assert!(dispatcher.to_tool_declarations().is_empty());
186    }
187
188    #[test]
189    fn build_response_success() {
190        let call = FunctionCall {
191            name: "test".to_string(),
192            args: json!({}),
193            id: Some("call-1".to_string()),
194        };
195        let resp = ToolDispatcher::build_response(&call, Ok(json!({"ok": true})));
196        assert_eq!(resp.name, "test");
197        assert_eq!(resp.response["ok"], true);
198    }
199
200    #[test]
201    fn build_response_error() {
202        let call = FunctionCall {
203            name: "test".to_string(),
204            args: json!({}),
205            id: Some("call-1".to_string()),
206        };
207        let resp = ToolDispatcher::build_response(
208            &call,
209            Err(ToolError::ExecutionFailed("boom".to_string())),
210        );
211        assert!(resp.response["error"].as_str().unwrap().contains("boom"));
212    }
213
214    #[test]
215    fn tool_dispatcher_implements_tool_provider() {
216        use gemini_genai_rs::prelude::ToolProvider;
217        let mut dispatcher = ToolDispatcher::new();
218        dispatcher.register_function(Arc::new(MockTool));
219        let decls = dispatcher.declarations();
220        assert_eq!(decls.len(), 1);
221    }
222
223    #[tokio::test]
224    async fn simple_tool_closure() {
225        let tool = SimpleTool::new(
226            "add",
227            "Add two numbers",
228            Some(
229                json!({"type": "object", "properties": {"a": {"type": "number"}, "b": {"type": "number"}}}),
230            ),
231            |args| async move {
232                let a = args["a"].as_f64().unwrap_or(0.0);
233                let b = args["b"].as_f64().unwrap_or(0.0);
234                Ok(json!({"sum": a + b}))
235            },
236        );
237
238        let mut dispatcher = ToolDispatcher::new();
239        dispatcher.register_function(Arc::new(tool));
240        let result = dispatcher
241            .call_function("add", json!({"a": 3, "b": 4}))
242            .await
243            .unwrap();
244        assert_eq!(result["sum"], 7.0);
245    }
246
247    // --- TypedTool tests ---
248
    // Argument struct used by the TypedTool tests in this module.
    // `city` has no serde default; `units` falls back to `default_units()`.
    // NOTE(review): the `///` field comments are kept verbatim — presumably the
    // `JsonSchema` derive copies them into the generated schema as
    // descriptions; confirm before rewording them.
    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct WeatherArgs {
        /// The city to get weather for
        city: String,
        /// Temperature units (celsius or fahrenheit)
        #[serde(default = "default_units")]
        units: String,
    }
257
258    fn default_units() -> String {
259        "celsius".to_string()
260    }
261
262    #[test]
263    fn typed_tool_auto_generates_schema() {
264        let tool = TypedTool::new(
265            "get_weather",
266            "Get current weather for a city",
267            |_args: WeatherArgs| async move { Ok(json!({})) },
268        );
269
270        let params = tool.parameters().expect("should have parameters");
271
272        // The schema should be an object type with "city" and "units" properties
273        let props = &params["properties"];
274        assert!(
275            props.get("city").is_some(),
276            "schema should contain 'city' property"
277        );
278        assert!(
279            props.get("units").is_some(),
280            "schema should contain 'units' property"
281        );
282
283        // "city" should be required (no default), "units" has a default so may not be
284        let required = params["required"]
285            .as_array()
286            .expect("should have required array");
287        let required_names: Vec<&str> = required.iter().filter_map(|v| v.as_str()).collect();
288        assert!(required_names.contains(&"city"), "city should be required");
289    }
290
291    #[tokio::test]
292    async fn typed_tool_deserializes_args() {
293        let tool = TypedTool::new(
294            "get_weather",
295            "Get current weather for a city",
296            |args: WeatherArgs| async move {
297                Ok(json!({
298                    "temp": 22,
299                    "city": args.city,
300                    "units": args.units,
301                }))
302            },
303        );
304
305        let result = tool
306            .call(json!({"city": "London", "units": "fahrenheit"}))
307            .await
308            .unwrap();
309        assert_eq!(result["city"], "London");
310        assert_eq!(result["units"], "fahrenheit");
311        assert_eq!(result["temp"], 22);
312    }
313
314    #[tokio::test]
315    async fn typed_tool_invalid_args_returns_error() {
316        let tool = TypedTool::new(
317            "get_weather",
318            "Get current weather for a city",
319            |_args: WeatherArgs| async move { Ok(json!({})) },
320        );
321
322        // Missing required field "city"
323        let result = tool.call(json!({"units": "celsius"})).await;
324        assert!(result.is_err(), "should fail with missing required field");
325        let err = result.unwrap_err();
326        match &err {
327            ToolError::InvalidArgs(msg) => {
328                assert!(
329                    msg.contains("city"),
330                    "error message should mention the missing field: {msg}"
331                );
332            }
333            other => panic!("expected ToolError::InvalidArgs, got: {other:?}"),
334        }
335
336        // Wrong type for "city" (number instead of string)
337        let result = tool.call(json!({"city": 12345})).await;
338        assert!(result.is_err(), "should fail with wrong type");
339    }
340
341    #[tokio::test]
342    async fn typed_tool_registers_in_dispatcher() {
343        let tool = TypedTool::new(
344            "get_weather",
345            "Get current weather for a city",
346            |args: WeatherArgs| async move { Ok(json!({"city": args.city})) },
347        );
348
349        let mut dispatcher = ToolDispatcher::new();
350        dispatcher.register_function(Arc::new(tool));
351
352        assert_eq!(dispatcher.classify("get_weather"), Some(ToolClass::Regular));
353        assert_eq!(dispatcher.len(), 1);
354
355        let result = dispatcher
356            .call_function("get_weather", json!({"city": "Paris"}))
357            .await
358            .unwrap();
359        assert_eq!(result["city"], "Paris");
360
361        // Verify it appears in tool declarations
362        let decls = dispatcher.to_tool_declarations();
363        assert_eq!(decls.len(), 1);
364    }
365
366    // --- Timeout and cancellation tests ---
367
368    /// A tool that sleeps forever (until cancelled/timed out).
369    struct SlowTool;
370
371    #[async_trait]
372    impl ToolFunction for SlowTool {
373        fn name(&self) -> &str {
374            "slow_tool"
375        }
376        fn description(&self) -> &str {
377            "A tool that never completes"
378        }
379        fn parameters(&self) -> Option<serde_json::Value> {
380            None
381        }
382        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
383            // Sleep effectively forever
384            tokio::time::sleep(Duration::from_secs(3600)).await;
385            Ok(json!({"result": "should never reach here"}))
386        }
387    }
388
389    #[tokio::test]
390    async fn tool_timeout_returns_error() {
391        let mut dispatcher = ToolDispatcher::new();
392        dispatcher.register_function(Arc::new(SlowTool));
393
394        let timeout = Duration::from_millis(50);
395        let result = dispatcher
396            .call_function_with_timeout("slow_tool", json!({}), timeout)
397            .await;
398
399        match result {
400            Err(ToolError::Timeout(d)) => assert_eq!(d, timeout),
401            other => panic!("expected ToolError::Timeout, got: {other:?}"),
402        }
403    }
404
405    #[tokio::test]
406    async fn tool_completes_before_timeout() {
407        let mut dispatcher = ToolDispatcher::new();
408        dispatcher.register_function(Arc::new(MockTool));
409
410        let result = dispatcher
411            .call_function_with_timeout("mock_tool", json!({}), Duration::from_secs(5))
412            .await
413            .unwrap();
414        assert_eq!(result["result"], "ok");
415    }
416
417    #[tokio::test]
418    async fn tool_cancelled_returns_error() {
419        let mut dispatcher = ToolDispatcher::new();
420        dispatcher.register_function(Arc::new(SlowTool));
421
422        let cancel = CancellationToken::new();
423        let cancel_clone = cancel.clone();
424
425        // Cancel after a short delay
426        tokio::spawn(async move {
427            tokio::time::sleep(Duration::from_millis(50)).await;
428            cancel_clone.cancel();
429        });
430
431        let result = dispatcher
432            .call_function_with_cancel("slow_tool", json!({}), cancel)
433            .await;
434
435        match result {
436            Err(ToolError::Cancelled) => {} // expected
437            other => panic!("expected ToolError::Cancelled, got: {other:?}"),
438        }
439    }
440
441    #[test]
442    fn default_timeout_is_30s() {
443        let dispatcher = ToolDispatcher::new();
444        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(30));
445    }
446
447    #[test]
448    fn with_timeout_overrides_default() {
449        let dispatcher = ToolDispatcher::new().with_timeout(Duration::from_secs(10));
450        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(10));
451    }
452
453    #[tokio::test]
454    async fn call_function_uses_default_timeout() {
455        // Set a very short default timeout so the slow tool times out
456        let mut dispatcher = ToolDispatcher::new().with_timeout(Duration::from_millis(50));
457        dispatcher.register_function(Arc::new(SlowTool));
458
459        let result = dispatcher.call_function("slow_tool", json!({})).await;
460
461        match result {
462            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
463            other => panic!("expected ToolError::Timeout, got: {other:?}"),
464        }
465    }
466}