gemini_adk_rs/middleware/
latency.rs

//! Latency tracking middleware for tool calls.
2
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6
7use gemini_genai_rs::prelude::FunctionCall;
8
9use super::Middleware;
10use crate::error::{AgentError, ToolError};
11
/// A single recorded tool-call latency measurement.
///
/// One record is produced per completed tool call, whether it succeeded or failed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolLatency {
    /// Name of the tool that was called.
    pub name: String,
    /// Elapsed wall-clock time between the call starting and completing.
    ///
    /// Zero when no matching start time was available for the call.
    pub elapsed: Duration,
    /// `true` if the call completed successfully, `false` if it ended in an error.
    pub success: bool,
}
22
23/// Middleware that records latency metrics for tool calls.
24///
25/// Stores `ToolLatency` entries that can be retrieved via [`LatencyMiddleware::tool_latencies`].
26/// Thread-safe and suitable for use across async tasks.
27///
28/// # Examples
29///
30/// ```rust,ignore
31/// use gemini_adk_rs::middleware::{LatencyMiddleware, Middleware};
32/// use gemini_genai_rs::prelude::FunctionCall;
33///
34/// let lat = LatencyMiddleware::new();
35/// // In an async context:
36/// lat.before_tool(&call).await.unwrap();
37/// // ... tool executes ...
38/// lat.after_tool(&call, &result).await.unwrap();
39/// let records = lat.tool_latencies();
40/// println!("Tool {} took {:?}", records[0].name, records[0].elapsed);
41/// ```
42pub struct LatencyMiddleware {
43    /// In-flight tool start times, keyed by tool name.
44    /// Multiple concurrent calls to the same tool name will overwrite,
45    /// but this is acceptable for metrics collection.
46    in_flight: parking_lot::Mutex<std::collections::HashMap<String, Instant>>,
47    /// Completed tool latency records.
48    records: parking_lot::Mutex<Vec<ToolLatency>>,
49}
50
51impl LatencyMiddleware {
52    /// Create a new latency middleware with empty records.
53    pub fn new() -> Self {
54        Self {
55            in_flight: parking_lot::Mutex::new(std::collections::HashMap::new()),
56            records: parking_lot::Mutex::new(Vec::new()),
57        }
58    }
59
60    /// Returns a snapshot of all recorded tool latencies.
61    pub fn tool_latencies(&self) -> Vec<ToolLatency> {
62        self.records.lock().clone()
63    }
64
65    /// Clears all recorded latencies and in-flight state.
66    pub fn clear(&self) {
67        self.in_flight.lock().clear();
68        self.records.lock().clear();
69    }
70}
71
72impl Default for LatencyMiddleware {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78#[async_trait]
79impl Middleware for LatencyMiddleware {
80    fn name(&self) -> &str {
81        "latency"
82    }
83
84    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
85        self.in_flight
86            .lock()
87            .insert(call.name.clone(), Instant::now());
88        Ok(())
89    }
90
91    async fn after_tool(
92        &self,
93        call: &FunctionCall,
94        _result: &serde_json::Value,
95    ) -> Result<(), AgentError> {
96        let elapsed = self
97            .in_flight
98            .lock()
99            .remove(&call.name)
100            .map(|start| start.elapsed())
101            .unwrap_or_default();
102        self.records.lock().push(ToolLatency {
103            name: call.name.clone(),
104            elapsed,
105            success: true,
106        });
107        Ok(())
108    }
109
110    async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
111        let elapsed = self
112            .in_flight
113            .lock()
114            .remove(&call.name)
115            .map(|start| start.elapsed())
116            .unwrap_or_default();
117        self.records.lock().push(ToolLatency {
118            name: call.name.clone(),
119            elapsed,
120            success: false,
121        });
122        Ok(())
123    }
124}