1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio_util::sync::CancellationToken;
8
9use gemini_genai_rs::prelude::{FunctionCall, FunctionDeclaration, FunctionResponse, Tool};
10
11use crate::error::ToolError;
12
13use super::{ActiveStreamingTool, ToolClass, ToolFunction, ToolKind, DEFAULT_TOOL_TIMEOUT};
14
/// Registry of named tools plus bookkeeping for streaming tools that are
/// currently running.
///
/// Each registered tool is one of the three `ToolKind` variants: a regular
/// function tool, a streaming tool, or an input-streaming tool.
pub struct ToolDispatcher {
    /// Registered tools, keyed by the name each tool reports via `name()`.
    tools: HashMap<String, ToolKind>,
    /// Streaming tools currently running, keyed by the id passed to
    /// `store_active`. Shared behind an async (`tokio`) mutex.
    active: Arc<tokio::sync::Mutex<HashMap<String, ActiveStreamingTool>>>,
    /// Timeout applied by `call_function`.
    default_timeout: Duration,
    /// Lazily built, memoized output of `to_tool_declarations`.
    cached_declarations: std::sync::OnceLock<Vec<Tool>>,
}
23
24impl ToolDispatcher {
25 pub fn new() -> Self {
40 Self {
41 tools: HashMap::new(),
42 active: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
43 default_timeout: DEFAULT_TOOL_TIMEOUT,
44 cached_declarations: std::sync::OnceLock::new(),
45 }
46 }
47
48 pub fn with_timeout(mut self, timeout: Duration) -> Self {
50 self.default_timeout = timeout;
51 self
52 }
53
54 pub fn default_timeout(&self) -> Duration {
56 self.default_timeout
57 }
58
59 pub fn register(&mut self, tool: impl ToolFunction) {
61 let tool = Arc::new(tool);
62 self.tools
63 .insert(tool.name().to_string(), ToolKind::Function(tool));
64 }
65
66 pub fn register_function(&mut self, tool: Arc<dyn ToolFunction>) {
68 self.tools
69 .insert(tool.name().to_string(), ToolKind::Function(tool));
70 }
71
72 pub fn register_streaming(&mut self, tool: Arc<dyn super::StreamingTool>) {
74 self.tools
75 .insert(tool.name().to_string(), ToolKind::Streaming(tool));
76 }
77
78 pub fn register_input_streaming(&mut self, tool: Arc<dyn super::InputStreamingTool>) {
80 self.tools
81 .insert(tool.name().to_string(), ToolKind::InputStream(tool));
82 }
83
84 pub fn get_tool(&self, name: &str) -> Option<&ToolKind> {
86 self.tools.get(name)
87 }
88
89 pub fn classify(&self, name: &str) -> Option<ToolClass> {
91 self.tools.get(name).map(|t| match t {
92 ToolKind::Function(_) => ToolClass::Regular,
93 ToolKind::Streaming(_) => ToolClass::Streaming,
94 ToolKind::InputStream(_) => ToolClass::InputStream,
95 })
96 }
97
98 pub async fn call_function(
100 &self,
101 name: &str,
102 args: serde_json::Value,
103 ) -> Result<serde_json::Value, ToolError> {
104 self.call_function_with_timeout(name, args, self.default_timeout)
105 .await
106 }
107
108 pub async fn call_function_with_timeout(
113 &self,
114 name: &str,
115 args: serde_json::Value,
116 timeout: Duration,
117 ) -> Result<serde_json::Value, ToolError> {
118 let func = match self.tools.get(name) {
119 Some(ToolKind::Function(f)) => f.clone(),
120 Some(_) => {
121 return Err(ToolError::Other(format!(
122 "{name} is not a regular function tool"
123 )))
124 }
125 None => return Err(ToolError::NotFound(name.to_string())),
126 };
127
128 match tokio::time::timeout(timeout, func.call(args)).await {
129 Ok(result) => result,
130 Err(_elapsed) => Err(ToolError::Timeout(timeout)),
131 }
132 }
133
134 pub async fn call_function_with_cancel(
139 &self,
140 name: &str,
141 args: serde_json::Value,
142 cancel: CancellationToken,
143 ) -> Result<serde_json::Value, ToolError> {
144 let func = match self.tools.get(name) {
145 Some(ToolKind::Function(f)) => f.clone(),
146 Some(_) => {
147 return Err(ToolError::Other(format!(
148 "{name} is not a regular function tool"
149 )))
150 }
151 None => return Err(ToolError::NotFound(name.to_string())),
152 };
153
154 tokio::select! {
155 result = func.call(args) => result,
156 _ = cancel.cancelled() => Err(ToolError::Cancelled),
157 }
158 }
159
160 pub fn build_response(
162 call: &FunctionCall,
163 result: Result<serde_json::Value, ToolError>,
164 ) -> FunctionResponse {
165 match result {
166 Ok(value) => FunctionResponse {
167 name: call.name.clone(),
168 response: value,
169 id: call.id.clone(),
170 scheduling: None,
171 },
172 Err(e) => FunctionResponse {
173 name: call.name.clone(),
174 response: serde_json::json!({"error": e.to_string()}),
175 id: call.id.clone(),
176 scheduling: None,
177 },
178 }
179 }
180
181 pub async fn cancel_streaming(&self, name: &str) {
183 let mut active = self.active.lock().await;
184 if let Some(tool) = active.remove(name) {
185 tool.cancel.cancel();
186 tool.task.abort();
187 }
188 }
189
190 pub(crate) async fn store_active(&self, id: String, tool: ActiveStreamingTool) {
192 self.active.lock().await.insert(id, tool);
193 }
194
195 pub async fn cancel_by_ids(&self, ids: &[String]) {
197 let mut active = self.active.lock().await;
198 for id in ids {
199 if let Some(tool) = active.remove(id.as_str()) {
200 tool.cancel.cancel();
201 tool.task.abort();
202 }
203 }
204 }
205
206 pub fn to_tool_declarations(&self) -> Vec<Tool> {
211 self.cached_declarations
212 .get_or_init(|| {
213 let declarations: Vec<FunctionDeclaration> = self
214 .tools
215 .values()
216 .map(|t| {
217 let (name, desc, params) = match t {
218 ToolKind::Function(f) => (f.name(), f.description(), f.parameters()),
219 ToolKind::Streaming(s) => (s.name(), s.description(), s.parameters()),
220 ToolKind::InputStream(i) => (i.name(), i.description(), i.parameters()),
221 };
222 FunctionDeclaration {
223 name: name.to_string(),
224 description: desc.to_string(),
225 parameters: params,
226 behavior: None,
227 }
228 })
229 .collect();
230
231 if declarations.is_empty() {
232 vec![]
233 } else {
234 vec![Tool::functions(declarations)]
235 }
236 })
237 .clone()
238 }
239
240 pub fn len(&self) -> usize {
242 self.tools.len()
243 }
244
245 pub fn is_empty(&self) -> bool {
247 self.tools.is_empty()
248 }
249}
250
251impl Default for ToolDispatcher {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
257impl gemini_genai_rs::prelude::ToolProvider for ToolDispatcher {
258 fn declarations(&self) -> Vec<gemini_genai_rs::prelude::Tool> {
259 self.to_tool_declarations()
260 }
261}