gemini_adk_fluent_rs/compose/
tools.rs

//! T — Tool composition.
//!
//! Compose tools in any order with `|`.
4
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use gemini_adk_rs::text::TextAgent;
10use gemini_adk_rs::tool::{PolicyTool, SimpleTool, ToolFunction, ToolPolicy};
11use gemini_genai_rs::prelude::{FunctionDeclaration, Tool};
12
13/// A tool composite — one or more tool entries.
14#[derive(Clone)]
15pub struct ToolComposite {
16    /// The tool entries in this composite.
17    pub entries: Vec<ToolCompositeEntry>,
18}
19
20/// Async transformer applied to a tool result value.
21pub type TransformFn = Arc<
22    dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
23        + Send
24        + Sync,
25>;
26
27/// An entry in a tool composite.
28#[derive(Clone)]
29pub enum ToolCompositeEntry {
30    /// A runtime tool function.
31    Function(Arc<dyn ToolFunction>),
32    /// A built-in Gemini tool declaration.
33    BuiltIn(Tool),
34    /// A text agent wrapped as a tool.
35    Agent {
36        /// Tool name exposed to the model.
37        name: String,
38        /// Tool description exposed to the model.
39        description: String,
40        /// The text agent to invoke.
41        agent: Arc<dyn TextAgent>,
42    },
43    /// An MCP (Model Context Protocol) toolset connection.
44    Mcp {
45        /// Connection params (e.g. URL or command string).
46        params: String,
47    },
48    /// A remote agent-to-agent tool.
49    A2a {
50        /// URL of the remote agent.
51        url: String,
52        /// Skill to invoke on the remote agent.
53        skill: String,
54    },
55    /// A mock tool that returns a fixed response (useful for testing).
56    Mock {
57        /// Tool name.
58        name: String,
59        /// Tool description.
60        description: String,
61        /// Fixed response to return.
62        response: serde_json::Value,
63    },
64    /// An OpenAPI spec-driven tool (placeholder/marker).
65    OpenApi {
66        /// Tool name.
67        name: String,
68        /// URL to the OpenAPI spec.
69        spec_url: String,
70    },
71    /// A BM25 search tool (placeholder/marker).
72    Search {
73        /// Tool name.
74        name: String,
75        /// Tool description.
76        description: String,
77    },
78    /// A schema-defined tool (placeholder/marker).
79    Schema {
80        /// Tool name.
81        name: String,
82        /// JSON Schema defining the tool's parameters.
83        schema: serde_json::Value,
84    },
85    /// A tool wrapped with a result transformer.
86    Transform {
87        /// The inner tool entry.
88        inner: Box<ToolCompositeEntry>,
89        /// Transformer function applied to the tool result.
90        transformer: TransformFn,
91    },
92}
93
94impl ToolComposite {
95    /// Create a composite containing a single runtime tool function.
96    pub fn from_function(f: Arc<dyn ToolFunction>) -> Self {
97        Self {
98            entries: vec![ToolCompositeEntry::Function(f)],
99        }
100    }
101
102    /// Create a composite containing a single built-in tool declaration.
103    pub fn from_built_in(tool: Tool) -> Self {
104        Self {
105            entries: vec![ToolCompositeEntry::BuiltIn(tool)],
106        }
107    }
108
109    /// Number of tool entries.
110    pub fn len(&self) -> usize {
111        self.entries.len()
112    }
113
114    /// Whether empty.
115    pub fn is_empty(&self) -> bool {
116        self.entries.is_empty()
117    }
118
119    /// Apply a per-tool [`ToolPolicy`] transform to every function entry.
120    ///
121    /// Each [`ToolCompositeEntry::Function`] is wrapped in a [`PolicyTool`]
122    /// carrying the policy. Successive modifiers nest (e.g. `T::cached(T::timeout(..))`
123    /// applies both timeout and cache), since a `PolicyTool` is itself a
124    /// [`ToolFunction`]. Other entry kinds are left untouched.
125    fn map_function_policy(
126        mut self,
127        f: impl Fn(ToolPolicy) -> ToolPolicy + Send + Sync + 'static,
128    ) -> Self {
129        self.entries = self
130            .entries
131            .into_iter()
132            .map(|entry| match entry {
133                ToolCompositeEntry::Function(func) => {
134                    let policy = f(ToolPolicy::new());
135                    ToolCompositeEntry::Function(PolicyTool::wrap(func, policy))
136                }
137                other => other,
138            })
139            .collect();
140        self
141    }
142}
143
144/// Compose two tool composites with `|`.
145impl std::ops::BitOr for ToolComposite {
146    type Output = ToolComposite;
147
148    fn bitor(mut self, rhs: ToolComposite) -> Self::Output {
149        self.entries.extend(rhs.entries);
150        self
151    }
152}
153
/// Namespace type for tool composition: every associated function is a static
/// factory or modifier that returns a [`ToolComposite`].
pub struct T;
156
157impl T {
158    /// Register a function tool.
159    pub fn function(f: Arc<dyn ToolFunction>) -> ToolComposite {
160        ToolComposite::from_function(f)
161    }
162
163    /// Add Google Search built-in tool.
164    pub fn google_search() -> ToolComposite {
165        ToolComposite::from_built_in(Tool::google_search())
166    }
167
168    /// Add URL context built-in tool.
169    pub fn url_context() -> ToolComposite {
170        ToolComposite::from_built_in(Tool::url_context())
171    }
172
173    /// Add code execution built-in tool.
174    pub fn code_execution() -> ToolComposite {
175        ToolComposite::from_built_in(Tool::code_execution())
176    }
177
178    /// Create a simple tool from a name, description, and async closure.
179    pub fn simple<F, Fut>(
180        name: impl Into<String>,
181        description: impl Into<String>,
182        f: F,
183    ) -> ToolComposite
184    where
185        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
186        Fut: Future<Output = Result<serde_json::Value, gemini_adk_rs::ToolError>> + Send + 'static,
187    {
188        let tool = SimpleTool::new(name, description, None, f);
189        ToolComposite::from_function(Arc::new(tool))
190    }
191
192    /// Alias for [`simple`](Self::simple) — matches upstream Python `T.fn()`.
193    ///
194    /// Named `fn_tool` because `fn` is a reserved keyword in Rust.
195    pub fn fn_tool<F, Fut>(
196        name: impl Into<String>,
197        description: impl Into<String>,
198        f: F,
199    ) -> ToolComposite
200    where
201        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
202        Fut: Future<Output = Result<serde_json::Value, gemini_adk_rs::ToolError>> + Send + 'static,
203    {
204        Self::simple(name, description, f)
205    }
206
207    /// Require user confirmation before each function tool in the composite runs.
208    ///
209    /// The confirmation flag is recorded on the tool's [`ToolPolicy`] and surfaced
210    /// to the runtime via [`PolicyTool::requires_confirmation`] — it is never
211    /// silently dropped. The `message` becomes the confirmation hint. Built-in and
212    /// placeholder entries are left unchanged.
213    pub fn confirm(tool: ToolComposite, message: &str) -> ToolComposite {
214        let msg = if message.is_empty() {
215            None
216        } else {
217            Some(message.to_string())
218        };
219        tool.map_function_policy(move |p| p.with_confirm(msg.clone()))
220    }
221
222    /// Bound each function tool in the composite by a timeout.
223    ///
224    /// At dispatch the tool's future is raced against the duration; on elapse the
225    /// call returns [`ToolError::Timeout`](gemini_adk_rs::ToolError::Timeout).
226    /// Built-in and placeholder entries are left unchanged.
227    pub fn timeout(tool: ToolComposite, duration: std::time::Duration) -> ToolComposite {
228        tool.map_function_policy(move |p| p.with_timeout(duration))
229    }
230
231    /// Memoize each function tool's successful results.
232    ///
233    /// Results are cached by `(tool name, canonical-JSON args)`; repeat calls with
234    /// identical arguments return the cached value without re-invoking the tool.
235    /// Errors are not cached. Built-in/placeholder entries are left unchanged.
236    pub fn cached(tool: ToolComposite) -> ToolComposite {
237        tool.map_function_policy(|p| p.with_cache())
238    }
239
240    /// Combine multiple tool functions into a single composite.
241    pub fn toolset(tools: Vec<Arc<dyn ToolFunction>>) -> ToolComposite {
242        ToolComposite {
243            entries: tools
244                .into_iter()
245                .map(ToolCompositeEntry::Function)
246                .collect(),
247        }
248    }
249
250    /// Wrap a [`TextAgent`] as a tool (shorthand for creating an agent tool entry).
251    ///
252    /// When invoked, the agent runs via `BaseLlm::generate()` and returns its
253    /// text output as the tool result. State is shared with the parent session.
254    pub fn agent(
255        name: impl Into<String>,
256        description: impl Into<String>,
257        agent: impl TextAgent + 'static,
258    ) -> ToolComposite {
259        ToolComposite {
260            entries: vec![ToolCompositeEntry::Agent {
261                name: name.into(),
262                description: description.into(),
263                agent: Arc::new(agent),
264            }],
265        }
266    }
267
268    /// Create an MCP (Model Context Protocol) toolset entry.
269    ///
270    /// `params` is the connection string (e.g. a URL or command) used to
271    /// establish the MCP session at runtime.
272    pub fn mcp(params: impl Into<String>) -> ToolComposite {
273        ToolComposite {
274            entries: vec![ToolCompositeEntry::Mcp {
275                params: params.into(),
276            }],
277        }
278    }
279
280    /// Create a remote agent-to-agent tool.
281    ///
282    /// Routes tool calls to a remote agent at `url`, invoking the given `skill`.
283    pub fn a2a(url: impl Into<String>, skill: impl Into<String>) -> ToolComposite {
284        ToolComposite {
285            entries: vec![ToolCompositeEntry::A2a {
286                url: url.into(),
287                skill: skill.into(),
288            }],
289        }
290    }
291
292    /// Create a mock tool that returns a fixed response.
293    ///
294    /// Useful for testing and prototyping without real tool implementations.
295    pub fn mock(
296        name: impl Into<String>,
297        description: impl Into<String>,
298        response: serde_json::Value,
299    ) -> ToolComposite {
300        ToolComposite {
301            entries: vec![ToolCompositeEntry::Mock {
302                name: name.into(),
303                description: description.into(),
304                response,
305            }],
306        }
307    }
308
309    /// Create an OpenAPI spec-driven tool (placeholder/marker).
310    ///
311    /// At runtime, the spec at `spec_url` is fetched and used to generate
312    /// tool declarations and HTTP call routing.
313    pub fn openapi(name: impl Into<String>, spec_url: impl Into<String>) -> ToolComposite {
314        ToolComposite {
315            entries: vec![ToolCompositeEntry::OpenApi {
316                name: name.into(),
317                spec_url: spec_url.into(),
318            }],
319        }
320    }
321
322    /// Create a BM25 search tool (placeholder/marker).
323    ///
324    /// Declares a search tool that performs BM25 retrieval at runtime.
325    pub fn search(name: impl Into<String>, description: impl Into<String>) -> ToolComposite {
326        ToolComposite {
327            entries: vec![ToolCompositeEntry::Search {
328                name: name.into(),
329                description: description.into(),
330            }],
331        }
332    }
333
334    /// Create a schema-defined tool (placeholder/marker).
335    ///
336    /// The tool's parameters are defined by the given JSON Schema value.
337    pub fn schema(name: impl Into<String>, schema: serde_json::Value) -> ToolComposite {
338        ToolComposite {
339            entries: vec![ToolCompositeEntry::Schema {
340                name: name.into(),
341                schema,
342            }],
343        }
344    }
345
346    /// Wrap each tool entry in a composite with a result transformer.
347    ///
348    /// The transformer function is applied to the tool's output value before
349    /// it is returned to the model.
350    pub fn transform<F, Fut>(tool: ToolComposite, f: F) -> ToolComposite
351    where
352        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
353        Fut: Future<Output = serde_json::Value> + Send + 'static,
354    {
355        let f: TransformFn = Arc::new(
356            move |v: serde_json::Value| -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>> {
357                Box::pin(f(v))
358            },
359        );
360        ToolComposite {
361            entries: tool
362                .entries
363                .into_iter()
364                .map(|entry| ToolCompositeEntry::Transform {
365                    inner: Box::new(entry),
366                    transformer: Arc::clone(&f),
367                })
368                .collect(),
369        }
370    }
371}
372
// ── Resolution ────────────────────────────────────────────────────────────
374
/// A tool entry whose resolution requires asynchronous I/O (network or
/// subprocess).
///
/// Such entries are resolved at connect time, not when the composite is built.
/// See the connection methods of [`crate::live::Live`].
#[derive(Clone, Debug)]
pub enum DeferredTool {
    /// An MCP server connection, given as a stdio command line or an SSE/HTTP URL.
    Mcp {
        /// Either an `http(s)://` URL (SSE) or a command line (stdio).
        params: String,
    },
    /// A skill invoked on a remote agent (agent-to-agent).
    A2a {
        /// Address of the remote agent.
        url: String,
        /// Name of the skill to invoke there.
        skill: String,
    },
    /// An OpenAPI spec-driven toolset with one tool per operation in the spec.
    OpenApi {
        /// Name of the toolset.
        name: String,
        /// Location of the OpenAPI document.
        spec_url: String,
    },
    /// A search/retrieval tool.
    Search {
        /// Name of the tool.
        name: String,
        /// Description of the tool.
        description: String,
    },
}
407
408/// The concrete outcome of classifying a single [`ToolCompositeEntry`].
409///
410/// This is the *single* exhaustive mapping from the composable tool algebra to
411/// the runtime; both [`crate::builder::AgentBuilder`] and [`crate::live::Live`]
412/// resolve through it, so no entry can be silently dropped.
413pub(crate) enum ToolResolution {
414    /// A runtime-executable tool function (register with a dispatcher).
415    Runtime(Arc<dyn ToolFunction>),
416    /// A built-in / declaration-only Gemini tool (add to the session config).
417    BuiltIn(Tool),
418    /// A text agent to expose as a tool (needs a shared session `State`).
419    Agent {
420        /// Tool name exposed to the model.
421        name: String,
422        /// Tool description exposed to the model.
423        description: String,
424        /// The text agent to invoke.
425        agent: Arc<dyn TextAgent>,
426    },
427    /// A tool that can only be resolved with async I/O at connect time.
428    Deferred(DeferredTool),
429}
430
431impl ToolCompositeEntry {
432    /// Classify this entry into its concrete [`ToolResolution`]. Exhaustive by
433    /// construction — adding a variant forces every consumer to handle it.
434    pub(crate) fn classify(self) -> ToolResolution {
435        match self {
436            ToolCompositeEntry::Function(f) => ToolResolution::Runtime(f),
437            ToolCompositeEntry::BuiltIn(t) => ToolResolution::BuiltIn(t),
438            ToolCompositeEntry::Agent {
439                name,
440                description,
441                agent,
442            } => ToolResolution::Agent {
443                name,
444                description,
445                agent,
446            },
447            ToolCompositeEntry::Mock {
448                name,
449                description,
450                response,
451            } => ToolResolution::Runtime(Arc::new(SimpleTool::new(
452                name,
453                description,
454                None,
455                move |_args| {
456                    let r = response.clone();
457                    async move { Ok(r) }
458                },
459            ))),
460            ToolCompositeEntry::Transform { inner, transformer } => match inner.classify() {
461                ToolResolution::Runtime(f) => ToolResolution::Runtime(Arc::new(TransformTool {
462                    inner: f,
463                    transformer,
464                })),
465                // A transformer only applies to a runtime function; for any other
466                // inner kind the transform is a no-op and the inner resolution
467                // passes through unchanged.
468                other => other,
469            },
470            ToolCompositeEntry::Schema { name, schema } => {
471                // A declaration-only tool: the model is told the function exists
472                // and the application services the call (e.g. via on_tool_call).
473                ToolResolution::BuiltIn(Tool::functions(vec![FunctionDeclaration {
474                    name,
475                    description: String::new(),
476                    parameters: Some(schema),
477                    behavior: None,
478                }]))
479            }
480            ToolCompositeEntry::Mcp { params } => {
481                ToolResolution::Deferred(DeferredTool::Mcp { params })
482            }
483            ToolCompositeEntry::A2a { url, skill } => {
484                ToolResolution::Deferred(DeferredTool::A2a { url, skill })
485            }
486            ToolCompositeEntry::OpenApi { name, spec_url } => {
487                ToolResolution::Deferred(DeferredTool::OpenApi { name, spec_url })
488            }
489            ToolCompositeEntry::Search { name, description } => {
490                ToolResolution::Deferred(DeferredTool::Search { name, description })
491            }
492        }
493    }
494}
495
496/// A [`ToolFunction`] that applies an async transformer to another tool's result.
497struct TransformTool {
498    inner: Arc<dyn ToolFunction>,
499    #[allow(clippy::type_complexity)]
500    transformer: Arc<
501        dyn Fn(serde_json::Value) -> Pin<Box<dyn Future<Output = serde_json::Value> + Send>>
502            + Send
503            + Sync,
504    >,
505}
506
507#[async_trait::async_trait]
508impl ToolFunction for TransformTool {
509    fn name(&self) -> &str {
510        self.inner.name()
511    }
512
513    fn description(&self) -> &str {
514        self.inner.description()
515    }
516
517    fn parameters(&self) -> Option<serde_json::Value> {
518        self.inner.parameters()
519    }
520
521    async fn call(
522        &self,
523        args: serde_json::Value,
524    ) -> Result<serde_json::Value, gemini_adk_rs::error::ToolError> {
525        let result = self.inner.call(args).await?;
526        Ok((self.transformer)(result).await)
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify the only entry of a single-entry composite.
    fn classify_one(c: ToolComposite) -> ToolResolution {
        let mut entries = c.entries.into_iter();
        entries.next().unwrap().classify()
    }

    /// Borrow the first entry as a function tool, panicking for any other kind.
    fn first_function(c: &ToolComposite) -> &Arc<dyn ToolFunction> {
        match &c.entries[0] {
            ToolCompositeEntry::Function(f) => f,
            _ => panic!("expected Function entry"),
        }
    }

    #[test]
    fn classify_maps_every_variant() {
        // Runtime-executable without any I/O.
        let mock = classify_one(T::mock("m", "d", serde_json::json!({"ok": true})));
        assert!(matches!(mock, ToolResolution::Runtime(_)));
        let simple = classify_one(T::simple("s", "d", |a| async move { Ok(a) }));
        assert!(matches!(simple, ToolResolution::Runtime(_)));

        // Built-in / declaration-only.
        let search_builtin = classify_one(T::google_search());
        assert!(matches!(search_builtin, ToolResolution::BuiltIn(_)));
        let schema = classify_one(T::schema("s", serde_json::json!({"type": "object"})));
        assert!(matches!(schema, ToolResolution::BuiltIn(_)));

        // Resolved asynchronously at connect time.
        let mcp = classify_one(T::mcp("node ./server.js"));
        assert!(matches!(
            mcp,
            ToolResolution::Deferred(DeferredTool::Mcp { .. })
        ));
        let a2a = classify_one(T::a2a("http://x", "skill"));
        assert!(matches!(
            a2a,
            ToolResolution::Deferred(DeferredTool::A2a { .. })
        ));
        let openapi = classify_one(T::openapi("o", "http://x/openapi.json"));
        assert!(matches!(
            openapi,
            ToolResolution::Deferred(DeferredTool::OpenApi { .. })
        ));
        let search = classify_one(T::search("s", "d"));
        assert!(matches!(
            search,
            ToolResolution::Deferred(DeferredTool::Search { .. })
        ));
    }

    #[tokio::test]
    async fn mock_resolves_to_callable_runtime_tool() {
        let composite = T::mock("weather", "Mock weather", serde_json::json!({"temp": 22}));
        let ToolResolution::Runtime(tool) = classify_one(composite) else {
            panic!("mock should resolve to a runtime tool");
        };
        assert_eq!(tool.name(), "weather");
        let reply = tool.call(serde_json::json!({})).await.unwrap();
        assert_eq!(reply, serde_json::json!({"temp": 22}));
    }

    #[tokio::test]
    async fn transform_wraps_inner_runtime_result() {
        let base = T::mock("base", "d", serde_json::json!({"n": 1}));
        let composite = T::transform(base, |mut value| async move {
            value["doubled"] = serde_json::json!(true);
            value
        });
        let ToolResolution::Runtime(tool) = classify_one(composite) else {
            panic!("transform over a mock should resolve to a runtime tool");
        };
        // The wrapper keeps the inner tool's name and rewrites only the result.
        assert_eq!(tool.name(), "base");
        let reply = tool.call(serde_json::json!({})).await.unwrap();
        assert_eq!(reply, serde_json::json!({"n": 1, "doubled": true}));
    }

    #[test]
    fn google_search_creates_composite() {
        assert_eq!(T::google_search().len(), 1);
    }

    #[test]
    fn url_context_creates_composite() {
        assert_eq!(T::url_context().len(), 1);
    }

    #[test]
    fn code_execution_creates_composite() {
        assert_eq!(T::code_execution().len(), 1);
    }

    #[test]
    fn compose_with_bitor() {
        let combined = T::google_search() | T::url_context() | T::code_execution();
        assert_eq!(combined.len(), 3);
    }

    #[test]
    fn simple_creates_tool() {
        let composite = T::simple("greet", "Greets the user", |_args| async {
            Ok(serde_json::json!({"message": "hello"}))
        });
        assert_eq!(composite.len(), 1);
        assert_eq!(first_function(&composite).name(), "greet");
    }

    #[tokio::test]
    async fn timeout_modifier_enforces_timeout() {
        use gemini_adk_rs::ToolError;
        use std::time::Duration;

        // The tool would sleep for an hour; the 50 ms policy must cut it short.
        let slow = T::simple("slow", "slow tool", |_| async move {
            tokio::time::sleep(Duration::from_secs(3600)).await;
            Ok(serde_json::json!({"ok": true}))
        });
        let limited = T::timeout(slow, Duration::from_millis(50));

        let outcome = first_function(&limited).call(serde_json::json!({})).await;
        match outcome {
            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn cached_modifier_memoizes_results() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let counter = Arc::new(AtomicU32::new(0));
        let calls = counter.clone();
        let counting = T::simple("count", "counts calls", move |_| {
            let calls = calls.clone();
            async move {
                let n = calls.fetch_add(1, Ordering::SeqCst) + 1;
                Ok(serde_json::json!({"n": n}))
            }
        });
        let cached = T::cached(counting);

        let tool = first_function(&cached);
        let first = tool.call(serde_json::json!({"x": 1})).await.unwrap();
        let second = tool.call(serde_json::json!({"x": 1})).await.unwrap();
        // Identical arguments: the second call is served from the cache.
        assert_eq!(first, second);
        assert_eq!(first["n"], 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn confirm_modifier_wraps_function() {
        // confirm() wraps the function (preserving its name) so the policy flag
        // travels to the runtime rather than being silently dropped.
        let dangerous = T::simple("danger", "dangerous", |_| async move {
            Ok(serde_json::json!({}))
        });
        let guarded = T::confirm(dangerous, "are you sure?");
        assert_eq!(first_function(&guarded).name(), "danger");
    }

    #[test]
    fn toolset_combines_functions() {
        let first: Arc<dyn ToolFunction> =
            Arc::new(SimpleTool::new("a", "tool a", None, |_| async {
                Ok(serde_json::json!(null))
            }));
        let second: Arc<dyn ToolFunction> =
            Arc::new(SimpleTool::new("b", "tool b", None, |_| async {
                Ok(serde_json::json!(null))
            }));
        let combined = T::toolset(vec![first, second]);
        assert_eq!(combined.len(), 2);
    }
}