1use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use gemini_adk_rs::text::TextAgent;
10use gemini_adk_rs::tool::{PolicyTool, SimpleTool, ToolFunction, ToolPolicy};
11use gemini_genai_rs::prelude::{FunctionDeclaration, Tool};
12
13#[derive(Clone)]
15pub struct ToolComposite {
16 pub entries: Vec<ToolCompositeEntry>,
18}
19
20pub type TransformFn = Arc<
22 dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
23 + Send
24 + Sync,
25>;
26
/// A single tool source inside a [`ToolComposite`].
///
/// `classify` turns each entry into a [`ToolResolution`]:
/// - `Function` and `Mock` become runtime tools;
/// - `BuiltIn` and `Schema` become built-in tool declarations;
/// - `Agent` stays an agent;
/// - `Mcp`, `A2a`, `OpenApi` and `Search` are deferred;
/// - `Transform` becomes a runtime tool only when its inner entry does. Otherwise
///   the inner entry's resolution is returned and the transformer is dropped.
#[derive(Clone)]
pub enum ToolCompositeEntry {
    /// A runtime function tool.
    Function(Arc<dyn ToolFunction>),
    /// A built-in tool (e.g. Google Search, URL context, code execution).
    BuiltIn(Tool),
    /// A text agent registered under a tool name and description.
    Agent {
        /// Name the agent is registered under.
        name: String,
        /// Description the agent is registered with.
        description: String,
        /// The agent itself.
        agent: Arc<dyn TextAgent>,
    },
    /// An MCP tool source, resolved later (see `DeferredTool::Mcp`).
    Mcp {
        /// Opaque parameter string (tests use a command such as `node ./server.js`).
        params: String,
    },
    /// An A2A tool source, resolved later (see `DeferredTool::A2a`).
    A2a {
        /// Remote endpoint URL.
        url: String,
        /// Skill identifier to use at that endpoint.
        skill: String,
    },
    /// A runtime tool that ignores its arguments and always returns `response`.
    Mock {
        /// Tool name.
        name: String,
        /// Tool description.
        description: String,
        /// Value returned (cloned) on every call.
        response: serde_json::Value,
    },
    /// An OpenAPI tool source, resolved later (see `DeferredTool::OpenApi`).
    OpenApi {
        /// Tool name.
        name: String,
        /// URL of the OpenAPI specification.
        spec_url: String,
    },
    /// A search tool source, resolved later (see `DeferredTool::Search`).
    Search {
        /// Tool name.
        name: String,
        /// Tool description.
        description: String,
    },
    /// A schema-only function declaration with an empty description.
    Schema {
        /// Function name.
        name: String,
        /// JSON value used as the declaration's `parameters`.
        schema: serde_json::Value,
    },
    /// Post-processes the result of `inner` with `transformer`.
    Transform {
        /// The wrapped entry.
        inner: Box<ToolCompositeEntry>,
        /// Async function applied to the inner tool's successful result.
        transformer: TransformFn,
    },
}
93
94impl ToolComposite {
95 pub fn from_function(f: Arc<dyn ToolFunction>) -> Self {
97 Self {
98 entries: vec![ToolCompositeEntry::Function(f)],
99 }
100 }
101
102 pub fn from_built_in(tool: Tool) -> Self {
104 Self {
105 entries: vec![ToolCompositeEntry::BuiltIn(tool)],
106 }
107 }
108
109 pub fn len(&self) -> usize {
111 self.entries.len()
112 }
113
114 pub fn is_empty(&self) -> bool {
116 self.entries.is_empty()
117 }
118
119 fn map_function_policy(
126 mut self,
127 f: impl Fn(ToolPolicy) -> ToolPolicy + Send + Sync + 'static,
128 ) -> Self {
129 self.entries = self
130 .entries
131 .into_iter()
132 .map(|entry| match entry {
133 ToolCompositeEntry::Function(func) => {
134 let policy = f(ToolPolicy::new());
135 ToolCompositeEntry::Function(PolicyTool::wrap(func, policy))
136 }
137 other => other,
138 })
139 .collect();
140 self
141 }
142}
143
144impl std::ops::BitOr for ToolComposite {
146 type Output = ToolComposite;
147
148 fn bitor(mut self, rhs: ToolComposite) -> Self::Output {
149 self.entries.extend(rhs.entries);
150 self
151 }
152}
153
/// Namespace for tool factory functions such as `T::google_search()` or `T::simple(...)`.
///
/// Every factory returns a [`ToolComposite`], so results can be combined with `|`.
pub struct T;
156
157impl T {
158 pub fn function(f: Arc<dyn ToolFunction>) -> ToolComposite {
160 ToolComposite::from_function(f)
161 }
162
163 pub fn google_search() -> ToolComposite {
165 ToolComposite::from_built_in(Tool::google_search())
166 }
167
168 pub fn url_context() -> ToolComposite {
170 ToolComposite::from_built_in(Tool::url_context())
171 }
172
173 pub fn code_execution() -> ToolComposite {
175 ToolComposite::from_built_in(Tool::code_execution())
176 }
177
178 pub fn simple<F, Fut>(
180 name: impl Into<String>,
181 description: impl Into<String>,
182 f: F,
183 ) -> ToolComposite
184 where
185 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
186 Fut: Future<Output = Result<serde_json::Value, gemini_adk_rs::ToolError>> + Send + 'static,
187 {
188 let tool = SimpleTool::new(name, description, None, f);
189 ToolComposite::from_function(Arc::new(tool))
190 }
191
192 pub fn fn_tool<F, Fut>(
196 name: impl Into<String>,
197 description: impl Into<String>,
198 f: F,
199 ) -> ToolComposite
200 where
201 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
202 Fut: Future<Output = Result<serde_json::Value, gemini_adk_rs::ToolError>> + Send + 'static,
203 {
204 Self::simple(name, description, f)
205 }
206
207 pub fn confirm(tool: ToolComposite, message: &str) -> ToolComposite {
214 let msg = if message.is_empty() {
215 None
216 } else {
217 Some(message.to_string())
218 };
219 tool.map_function_policy(move |p| p.with_confirm(msg.clone()))
220 }
221
222 pub fn timeout(tool: ToolComposite, duration: std::time::Duration) -> ToolComposite {
228 tool.map_function_policy(move |p| p.with_timeout(duration))
229 }
230
231 pub fn cached(tool: ToolComposite) -> ToolComposite {
237 tool.map_function_policy(|p| p.with_cache())
238 }
239
240 pub fn toolset(tools: Vec<Arc<dyn ToolFunction>>) -> ToolComposite {
242 ToolComposite {
243 entries: tools
244 .into_iter()
245 .map(ToolCompositeEntry::Function)
246 .collect(),
247 }
248 }
249
250 pub fn agent(
255 name: impl Into<String>,
256 description: impl Into<String>,
257 agent: impl TextAgent + 'static,
258 ) -> ToolComposite {
259 ToolComposite {
260 entries: vec![ToolCompositeEntry::Agent {
261 name: name.into(),
262 description: description.into(),
263 agent: Arc::new(agent),
264 }],
265 }
266 }
267
268 pub fn mcp(params: impl Into<String>) -> ToolComposite {
273 ToolComposite {
274 entries: vec![ToolCompositeEntry::Mcp {
275 params: params.into(),
276 }],
277 }
278 }
279
280 pub fn a2a(url: impl Into<String>, skill: impl Into<String>) -> ToolComposite {
284 ToolComposite {
285 entries: vec![ToolCompositeEntry::A2a {
286 url: url.into(),
287 skill: skill.into(),
288 }],
289 }
290 }
291
292 pub fn mock(
296 name: impl Into<String>,
297 description: impl Into<String>,
298 response: serde_json::Value,
299 ) -> ToolComposite {
300 ToolComposite {
301 entries: vec![ToolCompositeEntry::Mock {
302 name: name.into(),
303 description: description.into(),
304 response,
305 }],
306 }
307 }
308
309 pub fn openapi(name: impl Into<String>, spec_url: impl Into<String>) -> ToolComposite {
314 ToolComposite {
315 entries: vec![ToolCompositeEntry::OpenApi {
316 name: name.into(),
317 spec_url: spec_url.into(),
318 }],
319 }
320 }
321
322 pub fn search(name: impl Into<String>, description: impl Into<String>) -> ToolComposite {
326 ToolComposite {
327 entries: vec![ToolCompositeEntry::Search {
328 name: name.into(),
329 description: description.into(),
330 }],
331 }
332 }
333
334 pub fn schema(name: impl Into<String>, schema: serde_json::Value) -> ToolComposite {
338 ToolComposite {
339 entries: vec![ToolCompositeEntry::Schema {
340 name: name.into(),
341 schema,
342 }],
343 }
344 }
345
346 pub fn transform<F, Fut>(tool: ToolComposite, f: F) -> ToolComposite
351 where
352 F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
353 Fut: Future<Output = serde_json::Value> + Send + 'static,
354 {
355 let f: TransformFn = Arc::new(
356 move |v: serde_json::Value| -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> {
357 Box::pin(f(v))
358 },
359 );
360 ToolComposite {
361 entries: tool
362 .entries
363 .into_iter()
364 .map(|entry| ToolCompositeEntry::Transform {
365 inner: Box::new(entry),
366 transformer: Arc::clone(&f),
367 })
368 .collect(),
369 }
370 }
371}
372
/// A tool source that `ToolCompositeEntry::classify` does not resolve itself.
///
/// Produced from the matching `Mcp`, `A2a`, `OpenApi` and `Search` composite
/// entries. NOTE(review): presumably resolved by a later stage not shown here.
#[derive(Clone, Debug)]
pub enum DeferredTool {
    /// An MCP tool source.
    Mcp {
        /// Opaque parameter string copied from the composite entry.
        params: String,
    },
    /// An A2A tool source.
    A2a {
        /// Remote endpoint URL.
        url: String,
        /// Skill identifier at that endpoint.
        skill: String,
    },
    /// An OpenAPI tool source.
    OpenApi {
        /// Tool name.
        name: String,
        /// URL of the OpenAPI specification.
        spec_url: String,
    },
    /// A search tool source.
    Search {
        /// Tool name.
        name: String,
        /// Tool description.
        description: String,
    },
}
407
/// The result of classifying a single [`ToolCompositeEntry`].
pub(crate) enum ToolResolution {
    /// A directly callable tool (functions, mocks, and transforms over runtime tools).
    Runtime(Arc<dyn ToolFunction>),
    /// A built-in tool declaration (built-in tools and schema-only function declarations).
    BuiltIn(Tool),
    /// A text agent together with the name and description it was registered under.
    Agent {
        /// Name the agent was registered under.
        name: String,
        /// Description the agent was registered with.
        description: String,
        /// The agent itself.
        agent: Arc<dyn TextAgent>,
    },
    /// A source that must be resolved outside `classify`.
    Deferred(DeferredTool),
}
430
431impl ToolCompositeEntry {
432 pub(crate) fn classify(self) -> ToolResolution {
435 match self {
436 ToolCompositeEntry::Function(f) => ToolResolution::Runtime(f),
437 ToolCompositeEntry::BuiltIn(t) => ToolResolution::BuiltIn(t),
438 ToolCompositeEntry::Agent {
439 name,
440 description,
441 agent,
442 } => ToolResolution::Agent {
443 name,
444 description,
445 agent,
446 },
447 ToolCompositeEntry::Mock {
448 name,
449 description,
450 response,
451 } => ToolResolution::Runtime(Arc::new(SimpleTool::new(
452 name,
453 description,
454 None,
455 move |_args| {
456 let r = response.clone();
457 async move { Ok(r) }
458 },
459 ))),
460 ToolCompositeEntry::Transform { inner, transformer } => match inner.classify() {
461 ToolResolution::Runtime(f) => ToolResolution::Runtime(Arc::new(TransformTool {
462 inner: f,
463 transformer,
464 })),
465 other => other,
469 },
470 ToolCompositeEntry::Schema { name, schema } => {
471 ToolResolution::BuiltIn(Tool::functions(vec![FunctionDeclaration {
474 name,
475 description: String::new(),
476 parameters: Some(schema),
477 behavior: None,
478 }]))
479 }
480 ToolCompositeEntry::Mcp { params } => {
481 ToolResolution::Deferred(DeferredTool::Mcp { params })
482 }
483 ToolCompositeEntry::A2a { url, skill } => {
484 ToolResolution::Deferred(DeferredTool::A2a { url, skill })
485 }
486 ToolCompositeEntry::OpenApi { name, spec_url } => {
487 ToolResolution::Deferred(DeferredTool::OpenApi { name, spec_url })
488 }
489 ToolCompositeEntry::Search { name, description } => {
490 ToolResolution::Deferred(DeferredTool::Search { name, description })
491 }
492 }
493 }
494}
495
496struct TransformTool {
498 inner: Arc<dyn ToolFunction>,
499 #[allow(clippy::type_complexity)]
500 transformer: Arc<
501 dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
502 + Send
503 + Sync,
504 >,
505}
506
507#[async_trait::async_trait]
508impl ToolFunction for TransformTool {
509 fn name(&self) -> &str {
510 self.inner.name()
511 }
512
513 fn description(&self) -> &str {
514 self.inner.description()
515 }
516
517 fn parameters(&self) -> Option<serde_json::Value> {
518 self.inner.parameters()
519 }
520
521 async fn call(
522 &self,
523 args: serde_json::Value,
524 ) -> Result<serde_json::Value, gemini_adk_rs::error::ToolError> {
525 let result = self.inner.call(args).await?;
526 Ok((self.transformer)(result).await)
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies the first entry of `c`; panics if `c` is empty.
    fn classify_one(c: ToolComposite) -> ToolResolution {
        c.entries
            .into_iter()
            .next()
            .expect("composite should contain at least one entry")
            .classify()
    }

    /// Names of all entries in `c`, which must all be `Function` entries.
    fn function_names(c: &ToolComposite) -> Vec<&str> {
        c.entries
            .iter()
            .map(|entry| match entry {
                ToolCompositeEntry::Function(f) => f.name(),
                _ => panic!("expected Function entry"),
            })
            .collect()
    }

    #[test]
    fn classify_maps_every_variant() {
        assert!(matches!(
            classify_one(T::mock("m", "d", serde_json::json!({"ok": true}))),
            ToolResolution::Runtime(_)
        ));
        assert!(matches!(
            classify_one(T::simple("s", "d", |a| async move { Ok(a) })),
            ToolResolution::Runtime(_)
        ));
        assert!(matches!(
            classify_one(T::google_search()),
            ToolResolution::BuiltIn(_)
        ));
        assert!(matches!(
            classify_one(T::url_context()),
            ToolResolution::BuiltIn(_)
        ));
        assert!(matches!(
            classify_one(T::code_execution()),
            ToolResolution::BuiltIn(_)
        ));
        assert!(matches!(
            classify_one(T::schema("s", serde_json::json!({"type": "object"}))),
            ToolResolution::BuiltIn(_)
        ));
        assert!(matches!(
            classify_one(T::mcp("node ./server.js")),
            ToolResolution::Deferred(DeferredTool::Mcp { .. })
        ));
        assert!(matches!(
            classify_one(T::a2a("http://x", "skill")),
            ToolResolution::Deferred(DeferredTool::A2a { .. })
        ));
        assert!(matches!(
            classify_one(T::openapi("o", "http://x/openapi.json")),
            ToolResolution::Deferred(DeferredTool::OpenApi { .. })
        ));
        assert!(matches!(
            classify_one(T::search("s", "d")),
            ToolResolution::Deferred(DeferredTool::Search { .. })
        ));
        // A transform over a runtime tool stays a runtime tool...
        assert!(matches!(
            classify_one(T::transform(
                T::mock("m", "d", serde_json::json!(1)),
                |v| async move { v }
            )),
            ToolResolution::Runtime(_)
        ));
        // ...while a transform over any other kind resolves like its inner entry.
        assert!(matches!(
            classify_one(T::transform(T::google_search(), |v| async move { v })),
            ToolResolution::BuiltIn(_)
        ));
        assert!(matches!(
            classify_one(T::transform(T::mcp("cmd"), |v| async move { v })),
            ToolResolution::Deferred(DeferredTool::Mcp { .. })
        ));
    }

    #[tokio::test]
    async fn mock_resolves_to_callable_runtime_tool() {
        let resolution = classify_one(T::mock(
            "weather",
            "Mock weather",
            serde_json::json!({"temp": 22}),
        ));
        let ToolResolution::Runtime(tool) = resolution else {
            panic!("mock should resolve to a runtime tool");
        };
        assert_eq!(tool.name(), "weather");
        let out = tool.call(serde_json::json!({})).await.unwrap();
        assert_eq!(out, serde_json::json!({"temp": 22}));
    }

    #[tokio::test]
    async fn transform_wraps_inner_runtime_result() {
        let composite = T::transform(
            T::mock("base", "d", serde_json::json!({"n": 1})),
            |mut v| async move {
                v["doubled"] = serde_json::json!(true);
                v
            },
        );
        let ToolResolution::Runtime(tool) = classify_one(composite) else {
            panic!("transform over a mock should resolve to a runtime tool");
        };
        assert_eq!(tool.name(), "base");
        let out = tool.call(serde_json::json!({})).await.unwrap();
        assert_eq!(out, serde_json::json!({"n": 1, "doubled": true}));
    }

    #[test]
    fn transform_wraps_every_entry() {
        let t = T::transform(T::google_search() | T::url_context(), |v| async move { v });
        assert_eq!(t.len(), 2);
        assert!(t
            .entries
            .iter()
            .all(|e| matches!(e, ToolCompositeEntry::Transform { .. })));
    }

    #[test]
    fn google_search_creates_composite() {
        let t = T::google_search();
        assert_eq!(t.len(), 1);
    }

    #[test]
    fn url_context_creates_composite() {
        let t = T::url_context();
        assert_eq!(t.len(), 1);
    }

    #[test]
    fn code_execution_creates_composite() {
        let t = T::code_execution();
        assert_eq!(t.len(), 1);
    }

    #[test]
    fn compose_with_bitor() {
        let t = T::google_search() | T::url_context() | T::code_execution();
        assert_eq!(t.len(), 3);
    }

    #[test]
    fn bitor_keeps_left_entries_first() {
        let t = T::simple("first", "d", |a| async move { Ok(a) })
            | T::simple("second", "d", |a| async move { Ok(a) });
        assert_eq!(function_names(&t), ["first", "second"]);
    }

    #[test]
    fn simple_creates_tool() {
        let t = T::simple("greet", "Greets the user", |_args| async {
            Ok(serde_json::json!({"message": "hello"}))
        });
        assert_eq!(t.len(), 1);
        match &t.entries[0] {
            ToolCompositeEntry::Function(f) => assert_eq!(f.name(), "greet"),
            _ => panic!("expected Function entry"),
        }
    }

    #[tokio::test]
    async fn timeout_modifier_enforces_timeout() {
        use gemini_adk_rs::ToolError;
        use std::time::Duration;

        let t = T::timeout(
            T::simple("slow", "slow tool", |_| async move {
                tokio::time::sleep(Duration::from_secs(3600)).await;
                Ok(serde_json::json!({"ok": true}))
            }),
            Duration::from_millis(50),
        );
        match &t.entries[0] {
            ToolCompositeEntry::Function(f) => match f.call(serde_json::json!({})).await {
                Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
                other => panic!("expected Timeout, got {other:?}"),
            },
            _ => panic!("expected Function entry"),
        }
    }

    #[tokio::test]
    async fn cached_modifier_memoizes_results() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let counter = Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let t = T::cached(T::simple("count", "counts calls", move |_| {
            let c = c.clone();
            async move {
                let n = c.fetch_add(1, Ordering::SeqCst) + 1;
                Ok(serde_json::json!({"n": n}))
            }
        }));
        match &t.entries[0] {
            ToolCompositeEntry::Function(f) => {
                let first = f.call(serde_json::json!({"x": 1})).await.unwrap();
                let second = f.call(serde_json::json!({"x": 1})).await.unwrap();
                assert_eq!(first, second);
                assert_eq!(first["n"], 1);
                assert_eq!(counter.load(Ordering::SeqCst), 1);
            }
            _ => panic!("expected Function entry"),
        }
    }

    #[test]
    fn confirm_modifier_wraps_function() {
        let t = T::confirm(
            T::simple("danger", "dangerous", |_| async move {
                Ok(serde_json::json!({}))
            }),
            "are you sure?",
        );
        match &t.entries[0] {
            ToolCompositeEntry::Function(f) => assert_eq!(f.name(), "danger"),
            _ => panic!("expected Function entry"),
        }
    }

    #[test]
    fn modifiers_leave_non_function_entries_untouched() {
        let t = T::cached(T::google_search() | T::mcp("cmd"));
        assert_eq!(t.len(), 2);
        assert!(matches!(&t.entries[0], ToolCompositeEntry::BuiltIn(_)));
        assert!(matches!(&t.entries[1], ToolCompositeEntry::Mcp { .. }));
    }

    #[test]
    fn toolset_combines_functions() {
        let tool_a: Arc<dyn ToolFunction> =
            Arc::new(SimpleTool::new("a", "tool a", None, |_| async {
                Ok(serde_json::json!(null))
            }));
        let tool_b: Arc<dyn ToolFunction> =
            Arc::new(SimpleTool::new("b", "tool b", None, |_| async {
                Ok(serde_json::json!(null))
            }));
        let t = T::toolset(vec![tool_a, tool_b]);
        assert_eq!(t.len(), 2);
        assert_eq!(function_names(&t), ["a", "b"]);
    }

    #[test]
    fn empty_toolset_is_empty() {
        let t = T::toolset(Vec::new());
        assert!(t.is_empty());
        assert_eq!(t.len(), 0);
    }
}