1pub mod dispatcher;
4pub mod policy;
5pub mod simple;
6pub mod typed;
7
8pub use dispatcher::*;
9pub use policy::*;
10pub use simple::*;
11pub use typed::*;
12
13use std::sync::Arc;
14use std::time::Duration;
15
16use async_trait::async_trait;
17use tokio::sync::{broadcast, mpsc};
18use tokio::task::JoinHandle;
19use tokio_util::sync::CancellationToken;
20
21use crate::agent_session::InputEvent;
22use crate::error::ToolError;
23
24#[async_trait]
46pub trait ToolFunction: Send + Sync + 'static {
47 fn name(&self) -> &str;
49 fn description(&self) -> &str;
51 fn parameters(&self) -> Option<serde_json::Value>;
53 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError>;
55
56 fn requires_confirmation(&self) -> bool {
62 false
63 }
64
65 fn confirmation_message(&self) -> Option<&str> {
67 None
68 }
69}
70
71#[async_trait]
73pub trait StreamingTool: Send + Sync + 'static {
74 fn name(&self) -> &str;
76 fn description(&self) -> &str;
78 fn parameters(&self) -> Option<serde_json::Value>;
80 async fn run(
82 &self,
83 args: serde_json::Value,
84 yield_tx: mpsc::Sender<serde_json::Value>,
85 ) -> Result<(), ToolError>;
86}
87
88#[async_trait]
90pub trait InputStreamingTool: Send + Sync + 'static {
91 fn name(&self) -> &str;
93 fn description(&self) -> &str;
95 fn parameters(&self) -> Option<serde_json::Value>;
97 async fn run(
99 &self,
100 args: serde_json::Value,
101 input_rx: broadcast::Receiver<InputEvent>,
102 yield_tx: mpsc::Sender<serde_json::Value>,
103 ) -> Result<(), ToolError>;
104}
105
/// Coarse category of a registered tool, as reported by `ToolDispatcher::classify`.
///
/// `Hash` is derived alongside `Eq` so a class can be used as a key in hashed
/// collections (e.g. grouping tools by class).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolClass {
    /// Request/response tool; a registered `ToolFunction` classifies as this.
    Regular,
    /// Tool that streams values while running (see `StreamingTool`).
    Streaming,
    /// Streaming tool that also consumes input events (see `InputStreamingTool`).
    InputStream,
}
116
117pub enum ToolKind {
119 Function(Arc<dyn ToolFunction>),
121 Streaming(Arc<dyn StreamingTool>),
123 InputStream(Arc<dyn InputStreamingTool>),
125}
126
127pub struct ActiveStreamingTool {
129 pub task: JoinHandle<()>,
131 pub cancel: CancellationToken,
133}
134
/// Timeout applied to a tool call when no explicit timeout is configured (30 seconds).
pub(crate) const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_millis(30_000);
137
#[cfg(test)]
mod tests {
    use super::*;
    use gemini_genai_rs::prelude::FunctionCall;
    use serde_json::json;

    /// Tool that immediately succeeds with `{"result": "ok"}`.
    struct MockTool;

    #[async_trait]
    impl ToolFunction for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }
        fn description(&self) -> &str {
            "A mock tool"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(json!({"result": "ok"}))
        }
    }

    #[tokio::test]
    async fn register_and_call_function_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let result = dispatcher
            .call_function("mock_tool", json!({}))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    #[tokio::test]
    async fn call_unknown_tool_returns_error() {
        let dispatcher = ToolDispatcher::new();
        let result = dispatcher.call_function("nonexistent", json!({})).await;
        assert!(result.is_err());
    }

    #[test]
    fn to_tool_declarations() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    #[test]
    fn classify_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        assert_eq!(dispatcher.classify("mock_tool"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.classify("nonexistent"), None);
    }

    #[test]
    fn empty_dispatcher() {
        let dispatcher = ToolDispatcher::new();
        assert!(dispatcher.is_empty());
        assert_eq!(dispatcher.len(), 0);
        assert!(dispatcher.to_tool_declarations().is_empty());
    }

    #[test]
    fn build_response_success() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(&call, Ok(json!({"ok": true})));
        assert_eq!(resp.name, "test");
        assert_eq!(resp.response["ok"], true);
    }

    #[test]
    fn build_response_error() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(
            &call,
            Err(ToolError::ExecutionFailed("boom".to_string())),
        );
        assert!(resp.response["error"].as_str().unwrap().contains("boom"));
    }

    #[test]
    fn tool_dispatcher_implements_tool_provider() {
        use gemini_genai_rs::prelude::ToolProvider;
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.declarations();
        assert_eq!(decls.len(), 1);
    }

    #[tokio::test]
    async fn simple_tool_closure() {
        let tool = SimpleTool::new(
            "add",
            "Add two numbers",
            Some(
                json!({"type": "object", "properties": {"a": {"type": "number"}, "b": {"type": "number"}}}),
            ),
            |args| async move {
                let a = args["a"].as_f64().unwrap_or(0.0);
                let b = args["b"].as_f64().unwrap_or(0.0);
                Ok(json!({"sum": a + b}))
            },
        );

        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));
        let result = dispatcher
            .call_function("add", json!({"a": 3, "b": 4}))
            .await
            .unwrap();
        assert_eq!(result["sum"], 7.0);
    }

    // Argument type used to exercise `TypedTool` schema generation and decoding:
    // `city` is required, `units` falls back to `default_units()` when omitted.
    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct WeatherArgs {
        city: String,
        #[serde(default = "default_units")]
        units: String,
    }

    fn default_units() -> String {
        "celsius".to_string()
    }

    #[test]
    fn typed_tool_auto_generates_schema() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );

        let params = tool.parameters().expect("should have parameters");

        // Borrow the properties object; this was previously garbled as `¶ms`.
        let props = &params["properties"];
        assert!(
            props.get("city").is_some(),
            "schema should contain 'city' property"
        );
        assert!(
            props.get("units").is_some(),
            "schema should contain 'units' property"
        );

        let required = params["required"]
            .as_array()
            .expect("should have required array");
        let required_names: Vec<&str> = required.iter().filter_map(|v| v.as_str()).collect();
        assert!(required_names.contains(&"city"), "city should be required");
    }

    #[tokio::test]
    async fn typed_tool_deserializes_args() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move {
                Ok(json!({
                    "temp": 22,
                    "city": args.city,
                    "units": args.units,
                }))
            },
        );

        let result = tool
            .call(json!({"city": "London", "units": "fahrenheit"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "London");
        assert_eq!(result["units"], "fahrenheit");
        assert_eq!(result["temp"], 22);
    }

    #[tokio::test]
    async fn typed_tool_invalid_args_returns_error() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );

        // Missing the required `city` field.
        let result = tool.call(json!({"units": "celsius"})).await;
        assert!(result.is_err(), "should fail with missing required field");
        let err = result.unwrap_err();
        match &err {
            ToolError::InvalidArgs(msg) => {
                assert!(
                    msg.contains("city"),
                    "error message should mention the missing field: {msg}"
                );
            }
            other => panic!("expected ToolError::InvalidArgs, got: {other:?}"),
        }

        // `city` present but with the wrong JSON type.
        let result = tool.call(json!({"city": 12345})).await;
        assert!(result.is_err(), "should fail with wrong type");
    }

    #[tokio::test]
    async fn typed_tool_registers_in_dispatcher() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move { Ok(json!({"city": args.city})) },
        );

        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));

        assert_eq!(dispatcher.classify("get_weather"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.len(), 1);

        let result = dispatcher
            .call_function("get_weather", json!({"city": "Paris"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "Paris");

        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    /// Tool whose call sleeps for an hour, so it only ends via timeout or cancellation.
    struct SlowTool;

    #[async_trait]
    impl ToolFunction for SlowTool {
        fn name(&self) -> &str {
            "slow_tool"
        }
        fn description(&self) -> &str {
            "A tool that never completes"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(3600)).await;
            Ok(json!({"result": "should never reach here"}))
        }
    }

    #[tokio::test]
    async fn tool_timeout_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));

        let timeout = Duration::from_millis(50);
        let result = dispatcher
            .call_function_with_timeout("slow_tool", json!({}), timeout)
            .await;

        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, timeout),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn tool_completes_before_timeout() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));

        let result = dispatcher
            .call_function_with_timeout("mock_tool", json!({}), Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    #[tokio::test]
    async fn tool_cancelled_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));

        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();

        // Fire the cancellation shortly after the call starts.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });

        let result = dispatcher
            .call_function_with_cancel("slow_tool", json!({}), cancel)
            .await;

        match result {
            Err(ToolError::Cancelled) => {}
            other => panic!("expected ToolError::Cancelled, got: {other:?}"),
        }
    }

    #[test]
    fn default_timeout_is_30s() {
        let dispatcher = ToolDispatcher::new();
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn with_timeout_overrides_default() {
        let dispatcher = ToolDispatcher::new().with_timeout(Duration::from_secs(10));
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(10));
    }

    #[tokio::test]
    async fn call_function_uses_default_timeout() {
        let mut dispatcher = ToolDispatcher::new().with_timeout(Duration::from_millis(50));
        dispatcher.register_function(Arc::new(SlowTool));

        let result = dispatcher.call_function("slow_tool", json!({})).await;

        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }
}