gemini_adk_rs/middleware/
mod.rs

1//! Middleware trait and chain — wraps agent execution at lifecycle points.
2
3pub mod latency;
4pub mod log;
5pub mod retry;
6
7pub use latency::*;
8pub use log::*;
9pub use retry::*;
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use gemini_genai_rs::prelude::FunctionCall;
16
17use crate::context::AgentEvent;
18use crate::context::InvocationContext;
19use crate::error::{AgentError, ToolError};
20use crate::llm::{LlmRequest, LlmResponse};
21
22/// Middleware hooks — all optional, implement only what you need.
23///
24/// # Examples
25///
26/// ```rust,ignore
27/// use async_trait::async_trait;
28/// use gemini_adk_rs::middleware::Middleware;
29/// use gemini_adk_rs::error::AgentError;
30/// use gemini_genai_rs::prelude::FunctionCall;
31///
32/// struct AuditMiddleware;
33///
34/// #[async_trait]
35/// impl Middleware for AuditMiddleware {
36///     fn name(&self) -> &str { "audit" }
37///
38///     async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
39///         println!("Calling tool: {}", call.name);
40///         Ok(())
41///     }
42/// }
43/// ```
44#[async_trait]
45pub trait Middleware: Send + Sync + 'static {
46    /// Unique name for this middleware (used in logging/debugging).
47    fn name(&self) -> &str;
48
49    /// Called before an agent begins execution.
50    async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
51        Ok(())
52    }
53    /// Called after an agent completes execution.
54    async fn after_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
55        Ok(())
56    }
57
58    /// Called before a tool is invoked.
59    async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
60        Ok(())
61    }
62    /// Called after a tool completes successfully.
63    async fn after_tool(
64        &self,
65        _call: &FunctionCall,
66        _result: &serde_json::Value,
67    ) -> Result<(), AgentError> {
68        Ok(())
69    }
70    /// Called when a tool execution fails.
71    async fn on_tool_error(
72        &self,
73        _call: &FunctionCall,
74        _err: &ToolError,
75    ) -> Result<(), AgentError> {
76        Ok(())
77    }
78
79    /// Called when an agent event is emitted.
80    async fn on_event(&self, _event: &AgentEvent) -> Result<(), AgentError> {
81        Ok(())
82    }
83
84    /// Called when an agent error occurs.
85    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
86        Ok(())
87    }
88
89    /// Called before an LLM model call is made. Return `Some(LlmResponse)` to skip the LLM call
90    /// and use the returned response instead (e.g., for caching, guardrails). Return `None` to proceed.
91    async fn before_model(&self, _request: &LlmRequest) -> Result<Option<LlmResponse>, AgentError> {
92        Ok(None)
93    }
94
95    /// Called after an LLM model call completes. Return `Some(LlmResponse)` to replace the model's
96    /// response (e.g., for output filtering, safety). Return `None` to use the original response.
97    async fn after_model(
98        &self,
99        _request: &LlmRequest,
100        _response: &LlmResponse,
101    ) -> Result<Option<LlmResponse>, AgentError> {
102        Ok(None)
103    }
104
105    /// Called with the fully-built request *before* it is sent to the model,
106    /// allowing in-place mutation (e.g. trimming or rewriting conversation
107    /// history). Runs ahead of [`Middleware::before_model`]. This mirrors the
108    /// mutable `before_model_callback` request hook in the ADK Python SDK.
109    async fn transform_request(&self, _request: &mut LlmRequest) -> Result<(), AgentError> {
110        Ok(())
111    }
112
113    /// Maximum wall-clock duration this middleware imposes on the agent run.
114    /// `None` (the default) means no limit. The agent enforces the *tightest*
115    /// timeout across its middleware chain by bounding the whole run.
116    fn timeout(&self) -> Option<std::time::Duration> {
117        None
118    }
119}
120
121/// Ordered chain of middleware.
122#[derive(Clone, Default)]
123pub struct MiddlewareChain {
124    layers: Vec<Arc<dyn Middleware>>,
125}
126
127impl MiddlewareChain {
128    /// Create a new empty middleware chain.
129    pub fn new() -> Self {
130        Self::default()
131    }
132
133    /// Append a middleware to the end of the chain.
134    pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
135        self.layers.push(middleware);
136    }
137
138    /// Prepend a middleware to the front of the chain.
139    pub fn prepend(&mut self, middleware: Arc<dyn Middleware>) {
140        self.layers.insert(0, middleware);
141    }
142
143    /// Run all `before_agent` hooks in order.
144    pub async fn run_before_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
145        for m in &self.layers {
146            m.before_agent(ctx).await?;
147        }
148        Ok(())
149    }
150
151    /// Run all `after_agent` hooks in reverse order.
152    pub async fn run_after_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
153        for m in self.layers.iter().rev() {
154            m.after_agent(ctx).await?;
155        }
156        Ok(())
157    }
158
159    /// Run all `before_tool` hooks in order.
160    pub async fn run_before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
161        for m in &self.layers {
162            m.before_tool(call).await?;
163        }
164        Ok(())
165    }
166
167    /// Run all `after_tool` hooks in reverse order.
168    pub async fn run_after_tool(
169        &self,
170        call: &FunctionCall,
171        result: &serde_json::Value,
172    ) -> Result<(), AgentError> {
173        for m in self.layers.iter().rev() {
174            m.after_tool(call, result).await?;
175        }
176        Ok(())
177    }
178
179    /// Run all `on_tool_error` hooks in order.
180    pub async fn run_on_tool_error(
181        &self,
182        call: &FunctionCall,
183        err: &ToolError,
184    ) -> Result<(), AgentError> {
185        for m in &self.layers {
186            m.on_tool_error(call, err).await?;
187        }
188        Ok(())
189    }
190
191    /// Run all `on_event` hooks in order.
192    pub async fn run_on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
193        for m in &self.layers {
194            m.on_event(event).await?;
195        }
196        Ok(())
197    }
198
199    /// Run all `on_error` hooks in order.
200    pub async fn run_on_error(&self, err: &AgentError) -> Result<(), AgentError> {
201        for m in &self.layers {
202            m.on_error(err).await?;
203        }
204        Ok(())
205    }
206
207    /// Run all `transform_request` hooks in order, mutating the request in place.
208    pub async fn run_transform_request(&self, request: &mut LlmRequest) -> Result<(), AgentError> {
209        for m in &self.layers {
210            m.transform_request(request).await?;
211        }
212        Ok(())
213    }
214
215    /// Run all `before_model` hooks in order. Returns the first non-None override response.
216    pub async fn run_before_model(
217        &self,
218        request: &LlmRequest,
219    ) -> Result<Option<LlmResponse>, AgentError> {
220        for m in &self.layers {
221            if let Some(response) = m.before_model(request).await? {
222                return Ok(Some(response));
223            }
224        }
225        Ok(None)
226    }
227
228    /// Run all `after_model` hooks in reverse order. Returns the first non-None override response.
229    pub async fn run_after_model(
230        &self,
231        request: &LlmRequest,
232        response: &LlmResponse,
233    ) -> Result<Option<LlmResponse>, AgentError> {
234        for m in self.layers.iter().rev() {
235            if let Some(replacement) = m.after_model(request, response).await? {
236                return Ok(Some(replacement));
237            }
238        }
239        Ok(None)
240    }
241
242    /// The tightest timeout imposed by any middleware in the chain, if any.
243    pub fn timeout(&self) -> Option<std::time::Duration> {
244        self.layers.iter().filter_map(|m| m.timeout()).min()
245    }
246
247    /// Whether the chain has no middleware layers.
248    pub fn is_empty(&self) -> bool {
249        self.layers.is_empty()
250    }
251
252    /// Number of middleware layers in the chain.
253    pub fn len(&self) -> usize {
254        self.layers.len()
255    }
256}
257
258// ── Tests ────────────────────────────────────────────────────────────────────
259
#[cfg(test)]
mod tests {
    //! Unit tests for the middleware chain and the built-in middleware.

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;
    use std::time::Duration;

    // ── Fixtures ──

    /// Build a `FunctionCall` with the given tool name and a fixed argument
    /// payload.
    fn test_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            args: serde_json::json!({"key": "value"}),
            id: None,
        }
    }

    /// Middleware that counts how many times `before_agent` fires.
    struct CountingMiddleware {
        call_count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        fn name(&self) -> &str {
            "counter"
        }

        async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Middleware that appends `"<label>:<hook>"` to a shared journal each
    /// time one of its tool hooks runs, so tests can assert dispatch order.
    struct RecordingMiddleware {
        label: &'static str,
        journal: Arc<Mutex<Vec<String>>>,
    }

    impl RecordingMiddleware {
        /// Append one journal entry; the lock is released before returning,
        /// so it is never held across an `.await`.
        fn record(&self, hook: &str) {
            self.journal
                .lock()
                .expect("journal mutex poisoned")
                .push(format!("{}:{}", self.label, hook));
        }
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        fn name(&self) -> &str {
            self.label
        }

        async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
            self.record("before_tool");
            Ok(())
        }

        async fn after_tool(
            &self,
            _call: &FunctionCall,
            _result: &serde_json::Value,
        ) -> Result<(), AgentError> {
            self.record("after_tool");
            Ok(())
        }
    }

    /// Middleware whose only customisation is the timeout it reports.
    struct DeadlineMiddleware(Option<Duration>);

    #[async_trait]
    impl Middleware for DeadlineMiddleware {
        fn name(&self) -> &str {
            "deadline"
        }

        fn timeout(&self) -> Option<Duration> {
            self.0
        }
    }

    // ── MiddlewareChain tests ──

    /// A freshly constructed chain holds no layers.
    #[test]
    fn new_chain_is_empty() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    /// `add` appends, `prepend` inserts at the front, `before_tool` hooks run
    /// front to back and `after_tool` hooks run back to front.
    #[tokio::test]
    async fn middleware_chain_ordering() {
        let journal: Arc<Mutex<Vec<String>>> = Arc::default();
        let layer = |label: &'static str| -> Arc<dyn Middleware> {
            Arc::new(RecordingMiddleware {
                label,
                journal: Arc::clone(&journal),
            })
        };

        let mut chain = MiddlewareChain::new();
        chain.add(layer("a"));
        chain.add(layer("b"));
        chain.prepend(layer("first"));
        assert_eq!(chain.len(), 3);

        let call = test_call("ordered_tool");
        let result = serde_json::json!({"ok": true});
        chain.run_before_tool(&call).await.unwrap();
        chain.run_after_tool(&call, &result).await.unwrap();

        let seen = journal.lock().expect("journal mutex poisoned").clone();
        assert_eq!(
            seen,
            vec![
                "first:before_tool",
                "a:before_tool",
                "b:before_tool",
                "b:after_tool",
                "a:after_tool",
                "first:after_tool",
            ]
        );
    }

    /// The chain reports the smallest timeout among its layers, and `None`
    /// when no layer declares one.
    #[test]
    fn chain_timeout_is_tightest() {
        let mut chain = MiddlewareChain::new();
        assert_eq!(chain.timeout(), None);

        chain.add(Arc::new(DeadlineMiddleware(None)));
        assert_eq!(chain.timeout(), None);

        chain.add(Arc::new(DeadlineMiddleware(Some(Duration::from_secs(30)))));
        chain.add(Arc::new(DeadlineMiddleware(Some(Duration::from_secs(5)))));
        assert_eq!(chain.timeout(), Some(Duration::from_secs(5)));
    }

    /// Compile-time check that `Middleware` can be used as a trait object.
    #[test]
    fn middleware_is_object_safe() {
        fn _assert(_: &dyn Middleware) {}
    }

    /// Adding one layer makes the chain non-empty with length 1.
    #[test]
    fn add_middleware_to_chain() {
        let mut chain = MiddlewareChain::new();
        let counter = Arc::new(CountingMiddleware {
            call_count: Arc::new(AtomicU32::new(0)),
        });
        chain.add(counter);
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }

    /// A cloned chain keeps the same number of layers.
    #[test]
    fn chain_is_clone() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        let chain2 = chain.clone();
        assert_eq!(chain2.len(), 1);
    }

    /// All three built-in middleware can be stacked in one chain.
    #[test]
    fn chain_with_all_builtin_middleware() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        chain.add(Arc::new(LatencyMiddleware::new()));
        chain.add(Arc::new(RetryMiddleware::new(3)));
        assert_eq!(chain.len(), 3);
    }

    // ── LogMiddleware tests ──

    #[test]
    fn log_middleware_defaults() {
        let log = LogMiddleware::new();
        assert_eq!(log.name(), "log");
    }

    /// Every tool/error hook on `LogMiddleware` returns `Ok` without
    /// panicking.
    #[tokio::test]
    async fn logging_middleware_doesnt_panic() {
        let log = LogMiddleware::new();
        let call = test_call("my_tool");
        let result = serde_json::json!({"ok": true});
        let tool_err = ToolError::ExecutionFailed("boom".to_string());
        let agent_err = AgentError::Other("oops".to_string());

        assert!(log.before_tool(&call).await.is_ok());
        assert!(log.after_tool(&call, &result).await.is_ok());
        assert!(log.on_tool_error(&call, &tool_err).await.is_ok());
        assert!(log.on_error(&agent_err).await.is_ok());
    }

    // ── LatencyMiddleware tests ──

    #[test]
    fn latency_middleware_defaults() {
        let lat = LatencyMiddleware::new();
        assert_eq!(lat.name(), "latency");
    }

    /// A `before_tool` / `after_tool` pair yields one successful record with
    /// a non-trivial elapsed time.
    #[tokio::test]
    async fn latency_middleware_records_timing() {
        let lat = LatencyMiddleware::new();
        let call = test_call("slow_tool");
        let result = serde_json::json!("done");

        lat.before_tool(&call).await.unwrap();
        // Short pause so the measured elapsed time is reliably above zero.
        tokio::time::sleep(Duration::from_millis(5)).await;
        lat.after_tool(&call, &result).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "slow_tool");
        assert!(records[0].success);
        assert!(records[0].elapsed >= Duration::from_millis(1));
    }

    /// A `before_tool` / `on_tool_error` pair yields one failed record.
    #[tokio::test]
    async fn latency_middleware_records_failure() {
        let lat = LatencyMiddleware::new();
        let call = test_call("failing_tool");
        let err = ToolError::ExecutionFailed("kaboom".to_string());

        lat.before_tool(&call).await.unwrap();
        lat.on_tool_error(&call, &err).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "failing_tool");
        assert!(!records[0].success);
    }

    /// `clear` discards previously recorded latencies.
    #[tokio::test]
    async fn latency_middleware_clear() {
        let lat = LatencyMiddleware::new();
        let call = test_call("tool_a");
        let result = serde_json::json!(null);

        lat.before_tool(&call).await.unwrap();
        lat.after_tool(&call, &result).await.unwrap();
        assert_eq!(lat.tool_latencies().len(), 1);

        lat.clear();
        assert!(lat.tool_latencies().is_empty());
    }

    // ── RetryMiddleware tests ──

    /// Walks the retry state through three error/attempt cycles and checks
    /// that a fourth error no longer requests a retry.
    #[tokio::test]
    async fn retry_middleware_tracks_retries() {
        let retry = RetryMiddleware::new(3);
        assert_eq!(retry.max_retries(), 3);
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry(), "no error yet, should not retry");

        // First error, then first attempt.
        let err = AgentError::Other("transient".to_string());
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry(), "error recorded, should retry");

        retry.record_attempt();
        assert_eq!(retry.attempts(), 1);
        assert!(!retry.should_retry(), "error was cleared by record_attempt");

        // Second error + attempt cycle.
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 2);

        // Third error + attempt cycle.
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 3);

        // Budget exhausted: a new error must not request another retry.
        retry.on_error(&err).await.unwrap();
        assert!(!retry.should_retry(), "at max retries, should not retry");
    }

    /// `reset` returns the attempt counter to zero and clears any pending
    /// retry request.
    #[test]
    fn retry_middleware_reset() {
        let retry = RetryMiddleware::new(2);
        retry.error_count.store(1, Ordering::SeqCst);
        retry.attempt.store(1, Ordering::SeqCst);
        retry.reset();
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry());
    }
}