1use std::sync::Arc;
9
10use dashmap::DashMap;
11use serde_json::Value;
12use tokio::task::JoinHandle;
13use tokio_util::sync::CancellationToken;
14
15use gemini_genai_rs::prelude::FunctionCall;
16
17use crate::error::ToolError;
18
19pub trait ResultFormatter: Send + Sync + 'static {
29 fn format_running(&self, call: &FunctionCall) -> Value;
31
32 fn format_result(&self, call: &FunctionCall, result: Result<Value, ToolError>) -> Value;
34
35 fn format_cancelled(&self, call_id: &str) -> Value;
37}
38
/// Stateless `ResultFormatter` that emits simple JSON envelopes.
///
/// Each envelope carries a `"status"` field (`"running"`, `"completed"`,
/// `"error"` or `"cancelled"`) plus the tool name or call id.
///
/// It is a zero-sized unit struct, so it derives `Copy`/`Default` and can be
/// constructed or duplicated freely. `Debug` makes it usable in diagnostics.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultResultFormatter;
53
54impl ResultFormatter for DefaultResultFormatter {
55 fn format_running(&self, call: &FunctionCall) -> Value {
56 serde_json::json!({
57 "status": "running",
58 "tool": call.name,
59 })
60 }
61
62 fn format_result(&self, call: &FunctionCall, result: Result<Value, ToolError>) -> Value {
63 match result {
64 Ok(value) => serde_json::json!({
65 "status": "completed",
66 "tool": call.name,
67 "result": value,
68 }),
69 Err(e) => serde_json::json!({
70 "status": "error",
71 "tool": call.name,
72 "error": e.to_string(),
73 }),
74 }
75 }
76
77 fn format_cancelled(&self, call_id: &str) -> Value {
78 serde_json::json!({
79 "status": "cancelled",
80 "call_id": call_id,
81 })
82 }
83}
84
85#[derive(Clone, Default)]
124pub enum ToolExecutionMode {
125 #[default]
127 Standard,
128
129 Background {
140 formatter: Option<Arc<dyn ResultFormatter>>,
142 scheduling: Option<gemini_genai_rs::prelude::FunctionResponseScheduling>,
144 },
145}
146
147impl std::fmt::Debug for ToolExecutionMode {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Self::Standard => write!(f, "Standard"),
151 Self::Background {
152 formatter,
153 scheduling,
154 } => {
155 write!(
156 f,
157 "Background(formatter={}, scheduling={:?})",
158 formatter.is_some(),
159 scheduling
160 )
161 }
162 }
163 }
164}
165
166pub struct BackgroundToolTracker {
176 tasks: DashMap<String, (JoinHandle<()>, CancellationToken)>,
177}
178
179impl BackgroundToolTracker {
180 pub fn new() -> Self {
182 Self {
183 tasks: DashMap::new(),
184 }
185 }
186
187 pub fn spawn(&self, call_id: String, task: JoinHandle<()>, cancel: CancellationToken) {
194 self.tasks.insert(call_id, (task, cancel));
195 }
196
197 pub fn cancel(&self, call_ids: &[String]) {
203 for id in call_ids {
204 if let Some((_, (handle, token))) = self.tasks.remove(id) {
205 token.cancel();
206 handle.abort();
207 }
208 }
209 }
210
211 pub fn cancel_all(&self) {
215 let keys: Vec<String> = self.tasks.iter().map(|r| r.key().clone()).collect();
216 for key in keys {
217 if let Some((_, (handle, token))) = self.tasks.remove(&key) {
218 token.cancel();
219 handle.abort();
220 }
221 }
222 }
223
224 pub fn active_ids(&self) -> Vec<String> {
226 self.tasks.iter().map(|r| r.key().clone()).collect()
227 }
228
229 pub fn remove(&self, call_id: &str) {
234 self.tasks.remove(call_id);
235 }
236
237 pub fn active_count(&self) -> usize {
239 self.tasks.len()
240 }
241}
242
243impl Default for BackgroundToolTracker {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    // ---- BackgroundToolTracker -------------------------------------------

    /// Spawns a task that idles until its token is cancelled and registers it
    /// with `tracker` under `id`. Returns a clone of the token so the test can
    /// inspect or trigger cancellation.
    fn register_idle_task(tracker: &BackgroundToolTracker, id: &str) -> CancellationToken {
        let token = CancellationToken::new();
        let waiter = token.clone();
        let handle = tokio::spawn(async move {
            waiter.cancelled().await;
        });
        tracker.spawn(id.to_string(), handle, token.clone());
        token
    }

    #[test]
    fn tracker_new_is_empty() {
        let tracker = BackgroundToolTracker::new();
        assert_eq!(tracker.active_count(), 0);
        assert!(tracker.active_ids().is_empty());
    }

    #[tokio::test]
    async fn spawn_shows_active_id() {
        let tracker = BackgroundToolTracker::new();
        let token = register_idle_task(&tracker, "call1");

        assert_eq!(tracker.active_ids(), vec!["call1".to_string()]);

        // Let the idle task finish.
        token.cancel();
    }

    #[tokio::test]
    async fn spawn_increments_active_count() {
        let tracker = BackgroundToolTracker::new();
        let first = register_idle_task(&tracker, "call1");
        let second = register_idle_task(&tracker, "call2");

        assert_eq!(tracker.active_count(), 2);

        first.cancel();
        second.cancel();
    }

    #[tokio::test]
    async fn cancel_removes_task_and_cancels_token() {
        let tracker = BackgroundToolTracker::new();
        let token = register_idle_task(&tracker, "call1");
        assert_eq!(tracker.active_count(), 1);

        tracker.cancel(&["call1".into()]);

        assert_eq!(tracker.active_count(), 0);
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn cancel_all_clears_all_tasks() {
        let tracker = BackgroundToolTracker::new();
        let tokens: Vec<CancellationToken> = ["call1", "call2", "call3"]
            .iter()
            .map(|id| register_idle_task(&tracker, id))
            .collect();
        assert_eq!(tracker.active_count(), 3);

        tracker.cancel_all();

        assert_eq!(tracker.active_count(), 0);
        for (index, token) in tokens.iter().enumerate() {
            assert!(token.is_cancelled(), "token #{} was not cancelled", index + 1);
        }
    }

    #[tokio::test]
    async fn remove_cleans_up_completed_task() {
        let tracker = BackgroundToolTracker::new();
        let token = register_idle_task(&tracker, "call1");
        assert_eq!(tracker.active_count(), 1);

        tracker.remove("call1");

        assert_eq!(tracker.active_count(), 0);
        assert!(tracker.active_ids().is_empty());

        // `remove` does not cancel, so end the idle task explicitly.
        token.cancel();
    }

    #[test]
    fn cancel_nonexistent_id_is_noop() {
        let tracker = BackgroundToolTracker::new();
        tracker.cancel(&["nonexistent".into()]);
        assert_eq!(tracker.active_count(), 0);
    }

    // ---- DefaultResultFormatter ------------------------------------------

    /// Builds a `FunctionCall` named `name` with fixed arguments and id.
    fn make_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            args: serde_json::json!({"query": "test"}),
            id: Some("fc_123".to_string()),
        }
    }

    #[test]
    fn format_running_output() {
        let running = DefaultResultFormatter.format_running(&make_call("search"));

        assert_eq!(running["status"], "running");
        assert_eq!(running["tool"], "search");
    }

    #[test]
    fn format_result_ok() {
        let value = serde_json::json!({"items": [1, 2, 3]});
        let completed = DefaultResultFormatter.format_result(&make_call("search"), Ok(value.clone()));

        assert_eq!(completed["status"], "completed");
        assert_eq!(completed["tool"], "search");
        assert_eq!(completed["result"], value);
    }

    #[test]
    fn format_result_err() {
        let failure = ToolError::ExecutionFailed("connection timeout".into());
        let errored = DefaultResultFormatter.format_result(&make_call("search"), Err(failure));

        assert_eq!(errored["status"], "error");
        assert_eq!(errored["tool"], "search");
        let message = errored["error"].as_str().unwrap();
        assert!(message.contains("connection timeout"));
    }

    #[test]
    fn format_cancelled_output() {
        let cancelled = DefaultResultFormatter.format_cancelled("fc_456");

        assert_eq!(cancelled["status"], "cancelled");
        assert_eq!(cancelled["call_id"], "fc_456");
    }

    // ---- ToolExecutionMode -----------------------------------------------

    #[test]
    fn tool_execution_mode_default_is_standard() {
        assert!(matches!(
            ToolExecutionMode::default(),
            ToolExecutionMode::Standard
        ));
    }

    #[test]
    fn tool_execution_mode_debug_standard() {
        assert_eq!(format!("{:?}", ToolExecutionMode::Standard), "Standard");
    }

    #[test]
    fn tool_execution_mode_debug_background_none() {
        let mode = ToolExecutionMode::Background {
            formatter: None,
            scheduling: None,
        };
        assert_eq!(
            format!("{:?}", mode),
            "Background(formatter=false, scheduling=None)"
        );
    }

    #[test]
    fn tool_execution_mode_debug_background_some() {
        let mode = ToolExecutionMode::Background {
            formatter: Some(Arc::new(DefaultResultFormatter)),
            scheduling: None,
        };
        assert_eq!(
            format!("{:?}", mode),
            "Background(formatter=true, scheduling=None)"
        );
    }
}