gemini_adk_rs/tool/
typed.rs1use std::marker::PhantomData;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::de::DeserializeOwned;
8
9use crate::error::ToolError;
10
11use super::ToolFunction;
12
13pub struct TypedTool<T: DeserializeOwned + JsonSchema + Send + Sync + 'static> {
41 name: String,
42 description: String,
43 schema: serde_json::Value,
44 #[allow(clippy::type_complexity)]
45 handler: Box<
46 dyn Fn(
47 T,
48 ) -> std::pin::Pin<
49 Box<dyn std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send>,
50 > + Send
51 + Sync,
52 >,
53 _phantom: PhantomData<T>,
54}
55
56impl<T: DeserializeOwned + JsonSchema + Send + Sync + 'static> TypedTool<T> {
57 pub fn new<F, Fut>(name: impl Into<String>, description: impl Into<String>, handler: F) -> Self
62 where
63 F: Fn(T) -> Fut + Send + Sync + 'static,
64 Fut: std::future::Future<Output = Result<serde_json::Value, ToolError>> + Send + 'static,
65 {
66 let root_schema = schemars::schema_for!(T);
67 let schema =
68 serde_json::to_value(root_schema).expect("schemars schema should serialize to JSON");
69
70 Self {
71 name: name.into(),
72 description: description.into(),
73 schema,
74 handler: Box::new(move |args| Box::pin(handler(args))),
75 _phantom: PhantomData,
76 }
77 }
78}
79
80#[async_trait]
81impl<T: DeserializeOwned + JsonSchema + Send + Sync + 'static> ToolFunction for TypedTool<T> {
82 fn name(&self) -> &str {
83 &self.name
84 }
85
86 fn description(&self) -> &str {
87 &self.description
88 }
89
90 fn parameters(&self) -> Option<serde_json::Value> {
91 Some(self.schema.clone())
92 }
93
94 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
95 let typed_args: T = serde_json::from_value(args)
96 .map_err(|e| ToolError::InvalidArgs(format!("Failed to deserialize arguments: {e}")))?;
97 (self.handler)(typed_args).await
98 }
99}