1pub mod dispatcher;
4pub mod simple;
5pub mod typed;
6
7pub use dispatcher::*;
8pub use simple::*;
9pub use typed::*;
10
11use std::sync::Arc;
12use std::time::Duration;
13
14use async_trait::async_trait;
15use tokio::sync::{broadcast, mpsc};
16use tokio::task::JoinHandle;
17use tokio_util::sync::CancellationToken;
18
19use crate::agent_session::InputEvent;
20use crate::error::ToolError;
21
22#[async_trait]
44pub trait ToolFunction: Send + Sync + 'static {
45 fn name(&self) -> &str;
47 fn description(&self) -> &str;
49 fn parameters(&self) -> Option<serde_json::Value>;
51 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError>;
53}
54
55#[async_trait]
57pub trait StreamingTool: Send + Sync + 'static {
58 fn name(&self) -> &str;
60 fn description(&self) -> &str;
62 fn parameters(&self) -> Option<serde_json::Value>;
64 async fn run(
66 &self,
67 args: serde_json::Value,
68 yield_tx: mpsc::Sender<serde_json::Value>,
69 ) -> Result<(), ToolError>;
70}
71
72#[async_trait]
74pub trait InputStreamingTool: Send + Sync + 'static {
75 fn name(&self) -> &str;
77 fn description(&self) -> &str;
79 fn parameters(&self) -> Option<serde_json::Value>;
81 async fn run(
83 &self,
84 args: serde_json::Value,
85 input_rx: broadcast::Receiver<InputEvent>,
86 yield_tx: mpsc::Sender<serde_json::Value>,
87 ) -> Result<(), ToolError>;
88}
89
/// Coarse category of a registered tool, as reported by `ToolDispatcher::classify`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ToolClass {
    /// Request/response tool; a registered [`ToolFunction`] classifies as this.
    Regular,
    /// Presumably a tool registered as a [`StreamingTool`] — confirm against the dispatcher.
    Streaming,
    /// Presumably a tool registered as an [`InputStreamingTool`] — confirm against the dispatcher.
    InputStream,
}
100
101pub enum ToolKind {
103 Function(Arc<dyn ToolFunction>),
105 Streaming(Arc<dyn StreamingTool>),
107 InputStream(Arc<dyn InputStreamingTool>),
109}
110
111pub struct ActiveStreamingTool {
113 pub task: JoinHandle<()>,
115 pub cancel: CancellationToken,
117}
118
/// Default per-call timeout for tool execution (30 seconds).
///
/// A `ToolDispatcher` can replace it with `with_timeout`.
pub(crate) const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_millis(30_000);
121
#[cfg(test)]
mod tests {
    use super::*;
    use gemini_genai_rs::prelude::FunctionCall;
    use serde_json::json;

    // Trivial tool that always succeeds with `{"result": "ok"}`.
    struct MockTool;

    #[async_trait]
    impl ToolFunction for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }
        fn description(&self) -> &str {
            "A mock tool"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(json!({"result": "ok"}))
        }
    }

    #[tokio::test]
    async fn register_and_call_function_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let result = dispatcher
            .call_function("mock_tool", json!({}))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    #[tokio::test]
    async fn call_unknown_tool_returns_error() {
        let dispatcher = ToolDispatcher::new();
        let result = dispatcher.call_function("nonexistent", json!({})).await;
        assert!(result.is_err());
    }

    #[test]
    fn to_tool_declarations() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    #[test]
    fn classify_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        assert_eq!(dispatcher.classify("mock_tool"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.classify("nonexistent"), None);
    }

    #[test]
    fn empty_dispatcher() {
        let dispatcher = ToolDispatcher::new();
        assert!(dispatcher.is_empty());
        assert_eq!(dispatcher.len(), 0);
        assert!(dispatcher.to_tool_declarations().is_empty());
    }

    #[test]
    fn build_response_success() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(&call, Ok(json!({"ok": true})));
        assert_eq!(resp.name, "test");
        assert_eq!(resp.response["ok"], true);
    }

    #[test]
    fn build_response_error() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(
            &call,
            Err(ToolError::ExecutionFailed("boom".to_string())),
        );
        assert!(resp.response["error"].as_str().unwrap().contains("boom"));
    }

    #[test]
    fn tool_dispatcher_implements_tool_provider() {
        use gemini_genai_rs::prelude::ToolProvider;
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.declarations();
        assert_eq!(decls.len(), 1);
    }

    #[tokio::test]
    async fn simple_tool_closure() {
        let tool = SimpleTool::new(
            "add",
            "Add two numbers",
            Some(json!({
                "type": "object",
                "properties": {"a": {"type": "number"}, "b": {"type": "number"}}
            })),
            |args| async move {
                let a = args["a"].as_f64().unwrap_or(0.0);
                let b = args["b"].as_f64().unwrap_or(0.0);
                Ok(json!({"sum": a + b}))
            },
        );

        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));
        let result = dispatcher
            .call_function("add", json!({"a": 3, "b": 4}))
            .await
            .unwrap();
        assert_eq!(result["sum"], 7.0);
    }

    // Arguments for the `get_weather` typed-tool tests. Plain `//` comments are used
    // (not `///`) so the schemars-generated schema carries no extra descriptions.
    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct WeatherArgs {
        // Required: the city to look up.
        city: String,
        // Optional: falls back to `default_units()` when omitted.
        #[serde(default = "default_units")]
        units: String,
    }

    fn default_units() -> String {
        "celsius".to_string()
    }

    #[test]
    fn typed_tool_auto_generates_schema() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );

        let params = tool.parameters().expect("should have parameters");

        // Was mis-encoded as `¶ms` (an HTML `&para;` artifact), which does not compile.
        let props = &params["properties"];
        assert!(
            props.get("city").is_some(),
            "schema should contain 'city' property"
        );
        assert!(
            props.get("units").is_some(),
            "schema should contain 'units' property"
        );

        let required = params["required"]
            .as_array()
            .expect("should have required array");
        let required_names: Vec<&str> = required.iter().filter_map(|v| v.as_str()).collect();
        assert!(required_names.contains(&"city"), "city should be required");
    }

    #[tokio::test]
    async fn typed_tool_deserializes_args() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move {
                Ok(json!({
                    "temp": 22,
                    "city": args.city,
                    "units": args.units,
                }))
            },
        );

        let result = tool
            .call(json!({"city": "London", "units": "fahrenheit"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "London");
        assert_eq!(result["units"], "fahrenheit");
        assert_eq!(result["temp"], 22);
    }

    #[tokio::test]
    async fn typed_tool_invalid_args_returns_error() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );

        // Missing the required `city` field.
        let result = tool.call(json!({"units": "celsius"})).await;
        assert!(result.is_err(), "should fail with missing required field");
        let err = result.unwrap_err();
        match &err {
            ToolError::InvalidArgs(msg) => {
                assert!(
                    msg.contains("city"),
                    "error message should mention the missing field: {msg}"
                );
            }
            other => panic!("expected ToolError::InvalidArgs, got: {other:?}"),
        }

        // `city` has the wrong JSON type.
        let result = tool.call(json!({"city": 12345})).await;
        assert!(result.is_err(), "should fail with wrong type, got: {result:?}");
    }

    #[tokio::test]
    async fn typed_tool_registers_in_dispatcher() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move { Ok(json!({"city": args.city})) },
        );

        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));

        assert_eq!(dispatcher.classify("get_weather"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.len(), 1);

        let result = dispatcher
            .call_function("get_weather", json!({"city": "Paris"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "Paris");

        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    // Tool whose `call` sleeps for an hour; tests rely on a timeout or cancellation
    // firing long before it could return.
    struct SlowTool;

    #[async_trait]
    impl ToolFunction for SlowTool {
        fn name(&self) -> &str {
            "slow_tool"
        }
        fn description(&self) -> &str {
            "A tool that never completes"
        }
        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }
        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(3600)).await;
            Ok(json!({"result": "should never reach here"}))
        }
    }

    #[tokio::test]
    async fn tool_timeout_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));

        let timeout = Duration::from_millis(50);
        let result = dispatcher
            .call_function_with_timeout("slow_tool", json!({}), timeout)
            .await;

        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, timeout),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn tool_completes_before_timeout() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));

        let result = dispatcher
            .call_function_with_timeout("mock_tool", json!({}), Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    #[tokio::test]
    async fn tool_cancelled_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));

        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();

        // Fire the cancellation shortly after the call starts.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });

        let result = dispatcher
            .call_function_with_cancel("slow_tool", json!({}), cancel)
            .await;

        match result {
            Err(ToolError::Cancelled) => {}
            other => panic!("expected ToolError::Cancelled, got: {other:?}"),
        }
    }

    #[test]
    fn default_timeout_is_30s() {
        let dispatcher = ToolDispatcher::new();
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(30));
    }

    #[test]
    fn with_timeout_overrides_default() {
        let dispatcher = ToolDispatcher::new().with_timeout(Duration::from_secs(10));
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(10));
    }

    #[tokio::test]
    async fn call_function_uses_default_timeout() {
        // A short dispatcher-wide default should also apply to plain `call_function`.
        let mut dispatcher = ToolDispatcher::new().with_timeout(Duration::from_millis(50));
        dispatcher.register_function(Arc::new(SlowTool));

        let result = dispatcher.call_function("slow_tool", json!({})).await;

        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }
}