gemini_adk_rs/tool/
typed.rs

1//! Type-safe function tool with auto-generated JSON Schema.
2
3use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::de::DeserializeOwned;
8
9use crate::error::ToolError;
10
11use super::ToolFunction;
12
13/// Type-safe function tool with auto-generated JSON Schema.
14///
15/// Unlike [`super::SimpleTool`] which takes raw `serde_json::Value` arguments and
16/// requires a manually written schema, `TypedTool` auto-generates the JSON
17/// Schema from a struct that derives [`schemars::JsonSchema`] and deserializes
18/// the arguments into that struct before calling the handler.
19///
20/// # Example
21///
22/// ```ignore
23/// use schemars::JsonSchema;
24/// use serde::Deserialize;
25///
26/// #[derive(Deserialize, JsonSchema)]
27/// struct WeatherArgs {
28///     /// The city to get weather for
29///     city: String,
30/// }
31///
32/// let tool = TypedTool::new::<WeatherArgs>(
33///     "get_weather",
34///     "Get current weather for a city",
35///     |args: WeatherArgs| async move {
36///         Ok(serde_json::json!({ "temp": 22, "city": args.city }))
37///     },
38/// );
39/// ```
40pub struct TypedTool<T: DeserializeOwned + JsonSchema + Send + Sync + 'static> {
41    name: String,
42    description: String,
43    schema: serde_json::Value,
44    #[allow(clippy::type_complexity)]
45    handler: Box<
46        dyn Fn(
47                T,
48            ) -> std::pin::Pin<
49                Box<dyn std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send>,
50            > + Send
51            + Sync,
52    >,
53    _phantom: PhantomData<T>,
54}
55
56impl<T: DeserializeOwned + JsonSchema + Send + Sync + 'static> TypedTool<T> {
57    /// Create a new typed function tool with auto-generated schema.
58    ///
59    /// The JSON Schema is derived from `T`'s [`JsonSchema`] implementation,
60    /// including any doc-comment descriptions on fields.
61    pub fn new<F, Fut>(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self
62    where
63        F: Fn(T) -> Fut + Send + Sync + 'static,
64        Fut: std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
65    {
66        let root_schema = schemars::schema_for!(T);
67        let schema =
68            serde_json::to_value(root_schema).expect("schemars schema should serialize to JSON");
69
70        Self {
71            name: name.into(),
72            description: description.into(),
73            schema,
74            handler: Box::new(move |args| Box::pin(handler(args))),
75            _phantom: PhantomData,
76        }
77    }
78}
79
80#[async_trait]
81impl<T: DeserializeOwned + JsonSchema + Send + Sync + 'static> ToolFunction for TypedTool<T> {
82    fn name(&self) -> &str {
83        &self.name
84    }
85
86    fn description(&self) -> &str {
87        &self.description
88    }
89
90    fn parameters(&self) -> Option<serde_json::Value> {
91        Some(self.schema.clone())
92    }
93
94    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
95        let typed_args: T = serde_json::from_value(args)
96            .map_err(|e| ToolError::InvalidArgs(format!("Failed to deserialize arguments: {e}")))?;
97        (self.handler)(typed_args).await
98    }
99}