gemini_adk_fluent_rs/compose/
middleware.rs

1//! M — Middleware composition.
2//!
3//! Compose middleware in any order with `|`.
4//!
5//! ## Wiring status
6//!
//! **TextAgent pipelines** (via `AgentBuilder::middleware` + `AgentBuilder::build`) —
//! **fully wired**.  Every factory in this module produces a `MiddlewareComposite`
//! whose layers are installed into the `LlmTextAgent` middleware chain at build
//! time (the hidden `M::scope` and `M::fallback_model` factories are accepted but
//! do not yet change behavior).  Hooks fire in this order per `run()` call:
11//!
12//! 1. `before_model` (forward order) — may short-circuit with a cached response.
13//! 2. LLM call (skipped if `before_model` returned `Some`).
14//! 3. `after_model` (reverse order) — may replace the LLM response.
15//! 4. `before_tool` (forward) / `after_tool` (reverse) / `on_tool_error` (forward)
16//!    — called for each tool dispatch round.
17//! 5. `on_error` (forward) — called once if `run()` returns an error.
18//!
19//! **Live sessions** (via `Live::middleware`) — the **tool-lifecycle hooks**
20//! are wired: `before_tool` (a returned error vetoes the call), `after_tool`,
21//! and `on_tool_error` fire around every tool dispatch in the control lane,
22//! including background tools. Model-level hooks (`before_model`/`after_model`)
23//! do **not** apply to Live — a Live session streams over the wire and has no
24//! discrete `LlmRequest`/`LlmResponse` to intercept; use them on TextAgent
25//! pipelines instead.
26
27use std::sync::Arc;
28use std::time::Duration;
29
30use async_trait::async_trait;
31use gemini_genai_rs::prelude::FunctionCall;
32
33use gemini_adk_rs::context::AgentEvent;
34use gemini_adk_rs::error::{AgentError, ToolError};
35use gemini_adk_rs::middleware::{LatencyMiddleware, LogMiddleware, Middleware};
36
37/// A middleware composite — one or more middleware layers.
38#[derive(Clone)]
39pub struct MiddlewareComposite {
40    /// The ordered list of middleware layers.
41    pub layers: Vec<Arc<dyn Middleware>>,
42}
43
44impl MiddlewareComposite {
45    /// Create a composite containing a single middleware layer.
46    pub fn new(layer: Arc<dyn Middleware>) -> Self {
47        Self {
48            layers: vec![layer],
49        }
50    }
51
52    /// Number of layers.
53    pub fn len(&self) -> usize {
54        self.layers.len()
55    }
56
57    /// Whether empty.
58    pub fn is_empty(&self) -> bool {
59        self.layers.is_empty()
60    }
61}
62
63/// Compose two middleware composites with `|`.
64impl std::ops::BitOr for MiddlewareComposite {
65    type Output = MiddlewareComposite;
66
67    fn bitor(mut self, rhs: MiddlewareComposite) -> Self::Output {
68        self.layers.extend(rhs.layers);
69        self
70    }
71}
72
/// Factory namespace for middleware: each associated function builds a
/// ready-to-compose [`MiddlewareComposite`].
pub struct M;
75
76impl M {
77    /// Add logging middleware.
78    pub fn log() -> MiddlewareComposite {
79        MiddlewareComposite::new(Arc::new(LogMiddleware::new()))
80    }
81
82    /// Add latency tracking middleware.
83    pub fn latency() -> MiddlewareComposite {
84        MiddlewareComposite::new(Arc::new(LatencyMiddleware::new()))
85    }
86
87    /// Bound the agent run to `duration`. The text agent enforces the tightest
88    /// timeout across its middleware chain by wrapping the whole run; on elapse
89    /// it emits `AgentEvent::Timeout` and returns an error.
90    pub fn timeout(duration: Duration) -> MiddlewareComposite {
91        MiddlewareComposite::new(Arc::new(TimeoutMiddleware {
92            name: "timeout".to_string(),
93            duration,
94        }))
95    }
96
97    /// Add retry middleware — tracks errors and advises on retry.
98    pub fn retry(max_retries: u32) -> MiddlewareComposite {
99        MiddlewareComposite::new(Arc::new(gemini_adk_rs::middleware::RetryMiddleware::new(
100            max_retries,
101        )))
102    }
103
104    /// Add a custom event observer — called on every agent event.
105    pub fn tap(f: impl Fn(&AgentEvent) + Send + Sync + 'static) -> MiddlewareComposite {
106        MiddlewareComposite::new(Arc::new(TapMiddleware {
107            handler: Arc::new(f),
108        }))
109    }
110
111    /// Add a custom before-tool filter — called before every tool invocation.
112    pub fn before_tool(
113        f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
114    ) -> MiddlewareComposite {
115        MiddlewareComposite::new(Arc::new(BeforeToolMiddleware {
116            handler: Arc::new(f),
117        }))
118    }
119
120    /// Add a custom after-tool hook — called after every successful tool invocation.
121    pub fn after_tool(
122        f: impl Fn(&FunctionCall, &serde_json::Value) -> Result<(), String> + Send + Sync + 'static,
123    ) -> MiddlewareComposite {
124        MiddlewareComposite::new(Arc::new(AfterToolMiddleware {
125            handler: Arc::new(f),
126        }))
127    }
128
129    /// Add a custom error observer — called when an agent-level error occurs.
130    pub fn on_error(
131        f: impl Fn(&AgentError) -> Result<(), String> + Send + Sync + 'static,
132    ) -> MiddlewareComposite {
133        MiddlewareComposite::new(Arc::new(OnErrorMiddleware {
134            handler: Arc::new(f),
135        }))
136    }
137
138    /// Add cost tracking middleware — records token usage estimates.
139    pub fn cost() -> MiddlewareComposite {
140        MiddlewareComposite::new(Arc::new(CostMiddleware {
141            tool_calls: std::sync::atomic::AtomicU64::new(0),
142        }))
143    }
144
145    /// Add rate-limiting middleware — spaces tool calls to at most `rps` per
146    /// second by delaying `before_tool` (concurrent calls queue rather than
147    /// burst).
148    pub fn rate_limit(rps: u32) -> MiddlewareComposite {
149        MiddlewareComposite::new(Arc::new(RateLimitMiddleware::new(rps)))
150    }
151
152    /// Add circuit breaker middleware — opens after consecutive failures.
153    pub fn circuit_breaker(threshold: u32) -> MiddlewareComposite {
154        MiddlewareComposite::new(Arc::new(CircuitBreakerMiddleware {
155            threshold,
156            consecutive_failures: std::sync::atomic::AtomicU32::new(0),
157        }))
158    }
159
160    /// Add tracing span middleware — creates spans for distributed tracing.
161    pub fn trace() -> MiddlewareComposite {
162        MiddlewareComposite::new(Arc::new(TraceMiddleware))
163    }
164
165    /// Add audit middleware — records all tool calls for review.
166    pub fn audit() -> MiddlewareComposite {
167        MiddlewareComposite::new(Arc::new(AuditMiddleware {
168            log: parking_lot::Mutex::new(Vec::new()),
169        }))
170    }
171
172    /// Scope middleware to specific agent names.
173    ///
174    /// Not yet enforced — agent-name routing requires dispatch-time filtering
175    /// the middleware chain doesn't expose, so this currently returns `inner`
176    /// unchanged. Hidden until real scoping lands to avoid implying behavior.
177    #[doc(hidden)]
178    pub fn scope(_names: &[&str], inner: MiddlewareComposite) -> MiddlewareComposite {
179        inner
180    }
181
182    /// Structured logging middleware — logs agent events as structured JSON.
183    pub fn structured_log() -> MiddlewareComposite {
184        MiddlewareComposite::new(Arc::new(StructuredLogMiddleware))
185    }
186
187    /// Dispatch logging middleware — logs dispatch/join events.
188    pub fn dispatch_log() -> MiddlewareComposite {
189        MiddlewareComposite::new(Arc::new(DispatchLogMiddleware))
190    }
191
192    /// Topology logging middleware — logs agent topology events.
193    pub fn topology_log() -> MiddlewareComposite {
194        MiddlewareComposite::new(Arc::new(TopologyLogMiddleware))
195    }
196
197    /// Add a tool input validator middleware.
198    pub fn validate(
199        f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
200    ) -> MiddlewareComposite {
201        MiddlewareComposite::new(Arc::new(ValidateMiddleware {
202            validator: Arc::new(f),
203        }))
204    }
205
206    /// Fallback to an alternative model on error.
207    ///
208    /// Not yet enforced — swapping the model and retrying requires re-issuing
209    /// the LLM call, which the current `after_model`/`on_error` hooks can't do.
210    /// Hidden until real fallback lands to avoid implying behavior.
211    #[doc(hidden)]
212    pub fn fallback_model(model: &str) -> MiddlewareComposite {
213        MiddlewareComposite::new(Arc::new(FallbackModelMiddleware {
214            model: model.to_string(),
215        }))
216    }
217
218    /// Response caching middleware — caches model responses to avoid redundant calls.
219    pub fn cache() -> MiddlewareComposite {
220        MiddlewareComposite::new(Arc::new(CacheMiddleware {
221            cache: parking_lot::Mutex::new(std::collections::HashMap::new()),
222        }))
223    }
224
225    /// Deduplicate consecutive identical requests.
226    pub fn dedup() -> MiddlewareComposite {
227        MiddlewareComposite::new(Arc::new(DedupMiddleware {
228            last_request_hash: parking_lot::Mutex::new(None),
229        }))
230    }
231
232    /// Sample/pass-through a fraction of requests (0.0–1.0).
233    pub fn sample(rate: f64) -> MiddlewareComposite {
234        MiddlewareComposite::new(Arc::new(SampleMiddleware {
235            rate: rate.clamp(0.0, 1.0),
236        }))
237    }
238
239    /// Metrics collection middleware — tracks request counts, error counts, and latencies.
240    pub fn metrics() -> MiddlewareComposite {
241        MiddlewareComposite::new(Arc::new(MetricsMiddleware {
242            request_count: std::sync::atomic::AtomicU64::new(0),
243            error_count: std::sync::atomic::AtomicU64::new(0),
244        }))
245    }
246
247    /// Shortcut for a before-agent hook.
248    pub fn before_agent(
249        f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
250            + Send
251            + Sync
252            + 'static,
253    ) -> MiddlewareComposite {
254        MiddlewareComposite::new(Arc::new(BeforeAgentMiddleware {
255            handler: Arc::new(f),
256        }))
257    }
258
259    /// Shortcut for an after-agent hook.
260    pub fn after_agent(
261        f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
262            + Send
263            + Sync
264            + 'static,
265    ) -> MiddlewareComposite {
266        MiddlewareComposite::new(Arc::new(AfterAgentMiddleware {
267            handler: Arc::new(f),
268        }))
269    }
270
271    /// Shortcut for a before-model hook.
272    pub fn before_model(
273        f: impl Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync + 'static,
274    ) -> MiddlewareComposite {
275        MiddlewareComposite::new(Arc::new(BeforeModelMiddleware {
276            handler: Arc::new(f),
277        }))
278    }
279
280    /// Shortcut for an after-model hook.
281    pub fn after_model(
282        f: impl Fn(
283                &gemini_adk_rs::llm::LlmRequest,
284                &gemini_adk_rs::llm::LlmResponse,
285            ) -> Result<(), String>
286            + Send
287            + Sync
288            + 'static,
289    ) -> MiddlewareComposite {
290        MiddlewareComposite::new(Arc::new(AfterModelMiddleware {
291            handler: Arc::new(f),
292        }))
293    }
294
295    /// Loop iteration event hook — called on each iteration of a loop agent.
296    pub fn on_loop(f: impl Fn(u32) + Send + Sync + 'static) -> MiddlewareComposite {
297        MiddlewareComposite::new(Arc::new(OnLoopMiddleware {
298            handler: Arc::new(f),
299        }))
300    }
301
302    /// Timeout event hook — called when an agent times out.
303    pub fn on_timeout(f: impl Fn() + Send + Sync + 'static) -> MiddlewareComposite {
304        MiddlewareComposite::new(Arc::new(OnTimeoutMiddleware {
305            handler: Arc::new(f),
306        }))
307    }
308
309    /// Route decision event hook — called when a route agent selects a branch.
310    pub fn on_route(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
311        MiddlewareComposite::new(Arc::new(OnRouteMiddleware {
312            handler: Arc::new(f),
313        }))
314    }
315
316    /// Fallback event hook — called when a fallback agent activates.
317    pub fn on_fallback(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
318        MiddlewareComposite::new(Arc::new(OnFallbackMiddleware {
319            handler: Arc::new(f),
320        }))
321    }
322}
323
324/// Timeout middleware — stores the configured duration for runtime enforcement.
325#[allow(dead_code)]
326struct TimeoutMiddleware {
327    name: String,
328    duration: Duration,
329}
330
331#[async_trait::async_trait]
332impl Middleware for TimeoutMiddleware {
333    fn name(&self) -> &str {
334        &self.name
335    }
336
337    fn timeout(&self) -> Option<Duration> {
338        Some(self.duration)
339    }
340}
341
342// ── Tap Middleware ──────────────────────────────────────────────────────────
343
344struct TapMiddleware {
345    #[allow(clippy::type_complexity)]
346    handler: Arc<dyn Fn(&AgentEvent) + Send + Sync>,
347}
348
349#[async_trait]
350impl Middleware for TapMiddleware {
351    fn name(&self) -> &str {
352        "tap"
353    }
354
355    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
356        (self.handler)(event);
357        Ok(())
358    }
359}
360
361// ── BeforeTool Middleware ───────────────────────────────────────────────────
362
363struct BeforeToolMiddleware {
364    #[allow(clippy::type_complexity)]
365    handler: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
366}
367
368#[async_trait]
369impl Middleware for BeforeToolMiddleware {
370    fn name(&self) -> &str {
371        "before_tool"
372    }
373
374    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
375        (self.handler)(call).map_err(AgentError::Other)
376    }
377}
378
379// ── AfterTool Middleware ────────────────────────────────────────────────────
380
381struct AfterToolMiddleware {
382    #[allow(clippy::type_complexity)]
383    handler: Arc<dyn Fn(&FunctionCall, &serde_json::Value) -> Result<(), String> + Send + Sync>,
384}
385
386#[async_trait]
387impl Middleware for AfterToolMiddleware {
388    fn name(&self) -> &str {
389        "after_tool"
390    }
391
392    async fn after_tool(
393        &self,
394        call: &FunctionCall,
395        result: &serde_json::Value,
396    ) -> Result<(), AgentError> {
397        (self.handler)(call, result).map_err(AgentError::Other)
398    }
399}
400
401// ── OnError Middleware ──────────────────────────────────────────────────────
402
403struct OnErrorMiddleware {
404    #[allow(clippy::type_complexity)]
405    handler: Arc<dyn Fn(&AgentError) -> Result<(), String> + Send + Sync>,
406}
407
408#[async_trait]
409impl Middleware for OnErrorMiddleware {
410    fn name(&self) -> &str {
411        "on_error"
412    }
413
414    async fn on_error(&self, err: &AgentError) -> Result<(), AgentError> {
415        (self.handler)(err).map_err(AgentError::Other)
416    }
417}
418
419// ── Cost Middleware ────────────────────────────────────────────────────────
420
421/// Tracks the number of tool calls as a proxy for cost.
422pub struct CostMiddleware {
423    tool_calls: std::sync::atomic::AtomicU64,
424}
425
426impl CostMiddleware {
427    /// Returns the total number of tool calls recorded.
428    pub fn tool_call_count(&self) -> u64 {
429        self.tool_calls.load(std::sync::atomic::Ordering::SeqCst)
430    }
431}
432
433#[async_trait]
434impl Middleware for CostMiddleware {
435    fn name(&self) -> &str {
436        "cost"
437    }
438
439    async fn after_tool(
440        &self,
441        _call: &FunctionCall,
442        _result: &serde_json::Value,
443    ) -> Result<(), AgentError> {
444        self.tool_calls
445            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
446        Ok(())
447    }
448}
449
450// ── RateLimit Middleware ───────────────────────────────────────────────────
451
452#[allow(dead_code)]
453struct RateLimitMiddleware {
454    /// Minimum spacing between successive tool calls.
455    min_interval: Duration,
456    /// Reserved start time of the most recent call (advances as calls queue).
457    last: parking_lot::Mutex<Option<std::time::Instant>>,
458}
459
460impl RateLimitMiddleware {
461    fn new(rps: u32) -> Self {
462        let rps = rps.max(1);
463        Self {
464            min_interval: Duration::from_secs_f64(1.0 / rps as f64),
465            last: parking_lot::Mutex::new(None),
466        }
467    }
468}
469
470#[async_trait]
471impl Middleware for RateLimitMiddleware {
472    fn name(&self) -> &str {
473        "rate_limit"
474    }
475
476    async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
477        // Compute the wait needed to honor min_interval, reserving this call's
478        // slot so concurrent callers queue instead of bursting. The lock is
479        // released before the await (never held across it).
480        let wait = {
481            let mut last = self.last.lock();
482            let now = std::time::Instant::now();
483            let scheduled = match *last {
484                Some(prev) if prev + self.min_interval > now => prev + self.min_interval,
485                _ => now,
486            };
487            *last = Some(scheduled);
488            scheduled.saturating_duration_since(now)
489        };
490        if !wait.is_zero() {
491            tokio::time::sleep(wait).await;
492        }
493        Ok(())
494    }
495}
496
497// ── CircuitBreaker Middleware ──────────────────────────────────────────────
498
499struct CircuitBreakerMiddleware {
500    threshold: u32,
501    consecutive_failures: std::sync::atomic::AtomicU32,
502}
503
504#[async_trait]
505impl Middleware for CircuitBreakerMiddleware {
506    fn name(&self) -> &str {
507        "circuit_breaker"
508    }
509
510    async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
511        let failures = self
512            .consecutive_failures
513            .load(std::sync::atomic::Ordering::SeqCst);
514        if failures >= self.threshold {
515            return Err(AgentError::Other(format!(
516                "Circuit breaker open: {} consecutive failures (threshold: {})",
517                failures, self.threshold
518            )));
519        }
520        Ok(())
521    }
522
523    async fn after_tool(
524        &self,
525        _call: &FunctionCall,
526        _result: &serde_json::Value,
527    ) -> Result<(), AgentError> {
528        self.consecutive_failures
529            .store(0, std::sync::atomic::Ordering::SeqCst);
530        Ok(())
531    }
532
533    async fn on_tool_error(
534        &self,
535        _call: &FunctionCall,
536        _err: &ToolError,
537    ) -> Result<(), AgentError> {
538        self.consecutive_failures
539            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
540        Ok(())
541    }
542}
543
544// ── Trace Middleware ──────────────────────────────────────────────────────
545
546/// Middleware that creates tracing spans for agent and tool lifecycle events.
547/// When `tracing-support` is enabled, these spans are picked up by
548/// `tracing-opentelemetry` and exported as OTel spans.
549struct TraceMiddleware;
550
551#[async_trait]
552impl Middleware for TraceMiddleware {
553    fn name(&self) -> &str {
554        "trace"
555    }
556
557    async fn before_agent(
558        &self,
559        ctx: &gemini_adk_rs::context::InvocationContext,
560    ) -> Result<(), AgentError> {
561        let sid = ctx.session_id.as_deref().unwrap_or("unknown");
562        gemini_adk_rs::telemetry::logging::log_agent_started(sid, 0);
563        Ok(())
564    }
565
566    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
567        gemini_adk_rs::telemetry::logging::log_tool_dispatch("fluent", &call.name, "function");
568        Ok(())
569    }
570
571    async fn after_tool(
572        &self,
573        call: &FunctionCall,
574        _result: &serde_json::Value,
575    ) -> Result<(), AgentError> {
576        gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, true, 0.0);
577        Ok(())
578    }
579
580    async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
581        gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, false, 0.0);
582        Ok(())
583    }
584
585    async fn on_error(&self, err: &AgentError) -> Result<(), AgentError> {
586        gemini_adk_rs::telemetry::logging::log_agent_error("fluent", &err.to_string());
587        Ok(())
588    }
589}
590
591// ── Audit Middleware ─────────────────────────────────────────────────────
592
593/// Records all tool calls for audit review.
594pub struct AuditMiddleware {
595    log: parking_lot::Mutex<Vec<AuditEntry>>,
596}
597
598/// An audit log entry.
599#[derive(Debug, Clone)]
600pub struct AuditEntry {
601    /// Tool name.
602    pub tool_name: String,
603    /// Tool arguments.
604    pub args: serde_json::Value,
605    /// Whether the call succeeded.
606    pub success: Option<bool>,
607}
608
609impl AuditMiddleware {
610    /// Returns a snapshot of the audit log.
611    pub fn entries(&self) -> Vec<AuditEntry> {
612        self.log.lock().clone()
613    }
614}
615
616#[async_trait]
617impl Middleware for AuditMiddleware {
618    fn name(&self) -> &str {
619        "audit"
620    }
621
622    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
623        let mut log = self.log.lock();
624        if log.len() >= 10_000 {
625            log.drain(..1_000);
626        }
627        log.push(AuditEntry {
628            tool_name: call.name.clone(),
629            args: call.args.clone(),
630            success: None,
631        });
632        Ok(())
633    }
634
635    async fn after_tool(
636        &self,
637        call: &FunctionCall,
638        _result: &serde_json::Value,
639    ) -> Result<(), AgentError> {
640        let mut log = self.log.lock();
641        if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
642            entry.success = Some(true);
643        }
644        Ok(())
645    }
646
647    async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
648        let mut log = self.log.lock();
649        if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
650            entry.success = Some(false);
651        }
652        Ok(())
653    }
654}
655
656// ── Validate Middleware ──────────────────────────────────────────────────
657
658struct ValidateMiddleware {
659    #[allow(clippy::type_complexity)]
660    validator: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
661}
662
663#[async_trait]
664impl Middleware for ValidateMiddleware {
665    fn name(&self) -> &str {
666        "validate"
667    }
668
669    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
670        (self.validator)(call).map_err(|e| AgentError::Tool(ToolError::InvalidArgs(e)))
671    }
672}
673
674// ── FallbackModel Middleware ──────────────────────────────────────────
675
676/// Middleware that falls back to an alternative model on error.
677#[allow(dead_code)]
678struct FallbackModelMiddleware {
679    model: String,
680}
681
682#[async_trait]
683impl Middleware for FallbackModelMiddleware {
684    fn name(&self) -> &str {
685        "fallback_model"
686    }
687
688    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
689        // Runtime inspects the `model` field and retries with the fallback model.
690        Ok(())
691    }
692}
693
694// ── Cache Middleware ──────────────────────────────────────────────────
695
696/// Caches model responses keyed by request hash to avoid redundant LLM calls.
697pub struct CacheMiddleware {
698    cache: parking_lot::Mutex<std::collections::HashMap<u64, gemini_adk_rs::llm::LlmResponse>>,
699}
700
701impl CacheMiddleware {
702    /// Returns the number of cached entries.
703    pub fn len(&self) -> usize {
704        self.cache.lock().len()
705    }
706
707    /// Whether the cache is empty.
708    pub fn is_empty(&self) -> bool {
709        self.cache.lock().is_empty()
710    }
711
712    /// Clear all cached entries.
713    pub fn clear(&self) {
714        self.cache.lock().clear();
715    }
716}
717
718#[async_trait]
719impl Middleware for CacheMiddleware {
720    fn name(&self) -> &str {
721        "cache"
722    }
723
724    async fn before_model(
725        &self,
726        request: &gemini_adk_rs::llm::LlmRequest,
727    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
728        use std::hash::{Hash, Hasher};
729        let mut hasher = std::collections::hash_map::DefaultHasher::new();
730        format!("{:?}", request).hash(&mut hasher);
731        let key = hasher.finish();
732        let cache = self.cache.lock();
733        Ok(cache.get(&key).cloned())
734    }
735
736    async fn after_model(
737        &self,
738        request: &gemini_adk_rs::llm::LlmRequest,
739        response: &gemini_adk_rs::llm::LlmResponse,
740    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
741        use std::hash::{Hash, Hasher};
742        let mut hasher = std::collections::hash_map::DefaultHasher::new();
743        format!("{:?}", request).hash(&mut hasher);
744        let key = hasher.finish();
745        self.cache.lock().insert(key, response.clone());
746        Ok(None) // don't replace the response
747    }
748}
749
750// ── Dedup Middleware ─────────────────────────────────────────────────
751
752/// Deduplicates consecutive identical requests by hashing.
753#[allow(dead_code)]
754struct DedupMiddleware {
755    last_request_hash: parking_lot::Mutex<Option<u64>>,
756}
757
758#[async_trait]
759impl Middleware for DedupMiddleware {
760    fn name(&self) -> &str {
761        "dedup"
762    }
763
764    async fn before_model(
765        &self,
766        request: &gemini_adk_rs::llm::LlmRequest,
767    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
768        use std::hash::{Hash, Hasher};
769        let mut hasher = std::collections::hash_map::DefaultHasher::new();
770        format!("{:?}", request).hash(&mut hasher);
771        let hash = hasher.finish();
772        let mut last = self.last_request_hash.lock();
773        if *last == Some(hash) {
774            // Duplicate consecutive request — signal skip by returning empty response.
775            return Err(AgentError::Other(
776                "Duplicate consecutive request".to_string(),
777            ));
778        }
779        *last = Some(hash);
780        Ok(None)
781    }
782}
783
784// ── Sample Middleware ────────────────────────────────────────────────
785
786/// Passes through only a fraction of requests, dropping the rest.
787#[allow(dead_code)]
788struct SampleMiddleware {
789    rate: f64,
790}
791
792#[async_trait]
793impl Middleware for SampleMiddleware {
794    fn name(&self) -> &str {
795        "sample"
796    }
797
798    async fn before_model(
799        &self,
800        _request: &gemini_adk_rs::llm::LlmRequest,
801    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
802        use std::hash::{Hash, Hasher};
803        // Use a fast pseudo-random check based on time.
804        let mut hasher = std::collections::hash_map::DefaultHasher::new();
805        std::time::Instant::now().hash(&mut hasher);
806        let hash = hasher.finish();
807        let normalized = (hash as f64) / (u64::MAX as f64);
808        if normalized > self.rate {
809            return Err(AgentError::Other("Sampled out".to_string()));
810        }
811        Ok(None)
812    }
813}
814
815// ── Metrics Middleware ──────────────────────────────────────────────
816
817/// Collects request and error counts.
818pub struct MetricsMiddleware {
819    request_count: std::sync::atomic::AtomicU64,
820    error_count: std::sync::atomic::AtomicU64,
821}
822
823impl MetricsMiddleware {
824    /// Returns the total number of requests observed.
825    pub fn request_count(&self) -> u64 {
826        self.request_count.load(std::sync::atomic::Ordering::SeqCst)
827    }
828
829    /// Returns the total number of errors observed.
830    pub fn error_count(&self) -> u64 {
831        self.error_count.load(std::sync::atomic::Ordering::SeqCst)
832    }
833}
834
835#[async_trait]
836impl Middleware for MetricsMiddleware {
837    fn name(&self) -> &str {
838        "metrics"
839    }
840
841    async fn before_agent(
842        &self,
843        _ctx: &gemini_adk_rs::context::InvocationContext,
844    ) -> Result<(), AgentError> {
845        self.request_count
846            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
847        Ok(())
848    }
849
850    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
851        self.error_count
852            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
853        Ok(())
854    }
855}
856
857// ── BeforeAgent Middleware ───────────────────────────────────────────
858
859struct BeforeAgentMiddleware {
860    #[allow(clippy::type_complexity)]
861    handler:
862        Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
863}
864
865#[async_trait]
866impl Middleware for BeforeAgentMiddleware {
867    fn name(&self) -> &str {
868        "before_agent"
869    }
870
871    async fn before_agent(
872        &self,
873        ctx: &gemini_adk_rs::context::InvocationContext,
874    ) -> Result<(), AgentError> {
875        (self.handler)(ctx).map_err(AgentError::Other)
876    }
877}
878
879// ── AfterAgent Middleware ───────────────────────────────────────────
880
881struct AfterAgentMiddleware {
882    #[allow(clippy::type_complexity)]
883    handler:
884        Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
885}
886
887#[async_trait]
888impl Middleware for AfterAgentMiddleware {
889    fn name(&self) -> &str {
890        "after_agent"
891    }
892
893    async fn after_agent(
894        &self,
895        ctx: &gemini_adk_rs::context::InvocationContext,
896    ) -> Result<(), AgentError> {
897        (self.handler)(ctx).map_err(AgentError::Other)
898    }
899}
900
901// ── BeforeModel Middleware ──────────────────────────────────────────
902
903struct BeforeModelMiddleware {
904    #[allow(clippy::type_complexity)]
905    handler: Arc<dyn Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync>,
906}
907
908#[async_trait]
909impl Middleware for BeforeModelMiddleware {
910    fn name(&self) -> &str {
911        "before_model"
912    }
913
914    async fn before_model(
915        &self,
916        request: &gemini_adk_rs::llm::LlmRequest,
917    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
918        (self.handler)(request).map_err(AgentError::Other)?;
919        Ok(None)
920    }
921}
922
923// ── AfterModel Middleware ──────────────────────────────────────────
924
925struct AfterModelMiddleware {
926    #[allow(clippy::type_complexity)]
927    handler: Arc<
928        dyn Fn(
929                &gemini_adk_rs::llm::LlmRequest,
930                &gemini_adk_rs::llm::LlmResponse,
931            ) -> Result<(), String>
932            + Send
933            + Sync,
934    >,
935}
936
937#[async_trait]
938impl Middleware for AfterModelMiddleware {
939    fn name(&self) -> &str {
940        "after_model"
941    }
942
943    async fn after_model(
944        &self,
945        request: &gemini_adk_rs::llm::LlmRequest,
946        response: &gemini_adk_rs::llm::LlmResponse,
947    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
948        (self.handler)(request, response).map_err(AgentError::Other)?;
949        Ok(None)
950    }
951}
952
953// ── OnLoop Middleware ───────────────────────────────────────────────
954
955struct OnLoopMiddleware {
956    handler: Arc<dyn Fn(u32) + Send + Sync>,
957}
958
959#[async_trait]
960impl Middleware for OnLoopMiddleware {
961    fn name(&self) -> &str {
962        "on_loop"
963    }
964
965    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
966        if let AgentEvent::LoopIteration { iteration } = event {
967            (self.handler)(*iteration);
968        }
969        Ok(())
970    }
971}
972
973// ── OnTimeout Middleware ────────────────────────────────────────────
974
975struct OnTimeoutMiddleware {
976    handler: Arc<dyn Fn() + Send + Sync>,
977}
978
979#[async_trait]
980impl Middleware for OnTimeoutMiddleware {
981    fn name(&self) -> &str {
982        "on_timeout"
983    }
984
985    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
986        if let AgentEvent::Timeout = event {
987            (self.handler)();
988        }
989        Ok(())
990    }
991}
992
993// ── OnRoute Middleware ──────────────────────────────────────────────
994
995struct OnRouteMiddleware {
996    handler: Arc<dyn Fn(&str) + Send + Sync>,
997}
998
999#[async_trait]
1000impl Middleware for OnRouteMiddleware {
1001    fn name(&self) -> &str {
1002        "on_route"
1003    }
1004
1005    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
1006        if let AgentEvent::RouteSelected { agent_name } = event {
1007            (self.handler)(agent_name);
1008        }
1009        Ok(())
1010    }
1011}
1012
1013// ── OnFallback Middleware ───────────────────────────────────────────
1014
1015struct OnFallbackMiddleware {
1016    handler: Arc<dyn Fn(&str) + Send + Sync>,
1017}
1018
1019#[async_trait]
1020impl Middleware for OnFallbackMiddleware {
1021    fn name(&self) -> &str {
1022        "on_fallback"
1023    }
1024
1025    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
1026        if let AgentEvent::FallbackActivated { agent_name } = event {
1027            (self.handler)(agent_name);
1028        }
1029        Ok(())
1030    }
1031}
1032
1033// ── Structured Log Middleware ────────────────────────────────────────
1034
1035struct StructuredLogMiddleware;
1036
1037#[async_trait]
1038impl Middleware for StructuredLogMiddleware {
1039    fn name(&self) -> &str {
1040        "structured_log"
1041    }
1042
1043    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
1044        // Log events as structured format (uses tracing in production).
1045        let _ = event;
1046        Ok(())
1047    }
1048}
1049
1050// ── Dispatch Log Middleware ──────────────────────────────────────────
1051
1052struct DispatchLogMiddleware;
1053
1054#[async_trait]
1055impl Middleware for DispatchLogMiddleware {
1056    fn name(&self) -> &str {
1057        "dispatch_log"
1058    }
1059}
1060
1061// ── Topology Log Middleware ──────────────────────────────────────────
1062
1063struct TopologyLogMiddleware;
1064
1065#[async_trait]
1066impl Middleware for TopologyLogMiddleware {
1067    fn name(&self) -> &str {
1068        "topology_log"
1069    }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a factory produced a composite with exactly one layer.
    fn assert_single_layer(composite: MiddlewareComposite) {
        assert_eq!(composite.len(), 1);
    }

    #[test]
    fn log_creates_composite() {
        assert_single_layer(M::log());
    }

    #[test]
    fn latency_creates_composite() {
        assert_single_layer(M::latency());
    }

    #[test]
    fn timeout_creates_composite() {
        assert_single_layer(M::timeout(Duration::from_secs(30)));
    }

    #[test]
    fn compose_with_bitor() {
        let combined = M::log() | M::latency() | M::timeout(Duration::from_secs(5));
        assert_eq!(combined.len(), 3);
    }

    #[test]
    fn retry_creates_composite() {
        assert_single_layer(M::retry(3));
    }

    #[test]
    fn tap_creates_composite() {
        assert_single_layer(M::tap(|_event| {}));
    }

    #[test]
    fn before_tool_creates_composite() {
        assert_single_layer(M::before_tool(|_call| Ok(())));
    }

    #[test]
    fn cost_creates_composite() {
        assert_single_layer(M::cost());
    }

    #[test]
    fn rate_limit_creates_composite() {
        assert_single_layer(M::rate_limit(10));
    }

    #[test]
    fn circuit_breaker_creates_composite() {
        assert_single_layer(M::circuit_breaker(5));
    }

    #[test]
    fn trace_creates_composite() {
        assert_single_layer(M::trace());
    }

    #[test]
    fn audit_creates_composite() {
        assert_single_layer(M::audit());
    }

    #[test]
    fn validate_creates_composite() {
        assert_single_layer(M::validate(|_call| Ok(())));
    }

    #[test]
    fn fallback_model_creates_composite() {
        assert_single_layer(M::fallback_model("gemini-1.5-flash"));
    }

    #[test]
    fn cache_creates_composite() {
        assert_single_layer(M::cache());
    }

    #[test]
    fn dedup_creates_composite() {
        assert_single_layer(M::dedup());
    }

    #[test]
    fn sample_creates_composite() {
        assert_single_layer(M::sample(0.5));
    }

    #[test]
    fn sample_clamps_rate() {
        // Only checks that out-of-range rates are accepted without panicking;
        // the clamped rate itself is not observable through `len()`.
        for out_of_range in [2.0, -1.0] {
            assert_single_layer(M::sample(out_of_range));
        }
    }

    #[test]
    fn metrics_creates_composite() {
        assert_single_layer(M::metrics());
    }

    #[test]
    fn before_agent_creates_composite() {
        assert_single_layer(M::before_agent(|_ctx| Ok(())));
    }

    #[test]
    fn after_agent_creates_composite() {
        assert_single_layer(M::after_agent(|_ctx| Ok(())));
    }

    #[test]
    fn before_model_creates_composite() {
        assert_single_layer(M::before_model(|_req| Ok(())));
    }

    #[test]
    fn after_model_creates_composite() {
        assert_single_layer(M::after_model(|_req, _resp| Ok(())));
    }

    #[test]
    fn on_loop_creates_composite() {
        assert_single_layer(M::on_loop(|_iteration| {}));
    }

    #[test]
    fn on_timeout_creates_composite() {
        assert_single_layer(M::on_timeout(|| {}));
    }

    #[test]
    fn on_route_creates_composite() {
        assert_single_layer(M::on_route(|_name| {}));
    }

    #[test]
    fn on_fallback_creates_composite() {
        assert_single_layer(M::on_fallback(|_name| {}));
    }

    #[test]
    fn compose_all_middleware() {
        // 15 built-in layers + 4 agent/model hooks + 4 event hooks = 23.
        let everything = M::log()
            | M::latency()
            | M::timeout(Duration::from_secs(30))
            | M::retry(3)
            | M::cost()
            | M::rate_limit(10)
            | M::circuit_breaker(5)
            | M::trace()
            | M::audit()
            | M::validate(|_| Ok(()))
            | M::fallback_model("gemini-1.5-flash")
            | M::cache()
            | M::dedup()
            | M::sample(0.5)
            | M::metrics()
            | M::before_agent(|_| Ok(()))
            | M::after_agent(|_| Ok(()))
            | M::before_model(|_| Ok(()))
            | M::after_model(|_, _| Ok(()))
            | M::on_loop(|_| {})
            | M::on_timeout(|| {})
            | M::on_route(|_| {})
            | M::on_fallback(|_| {});
        assert_eq!(everything.len(), 23);
    }
}