gemini_adk_fluent_rs/compose/
middleware.rs

1//! M — Middleware composition.
2//!
3//! Compose middleware in any order with `|`.
4//!
5//! **Note:** Not yet wired into Live session dispatch. Available for
6//! `TextAgent` pipelines. Hidden from docs until Live integration lands.
7
8use std::sync::Arc;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use gemini_genai_rs::prelude::FunctionCall;
13
14use gemini_adk_rs::context::AgentEvent;
15use gemini_adk_rs::error::{AgentError, ToolError};
16use gemini_adk_rs::middleware::{LatencyMiddleware, LogMiddleware, Middleware};
17
18/// A middleware composite — one or more middleware layers.
19#[derive(Clone)]
20pub struct MiddlewareComposite {
21    /// The ordered list of middleware layers.
22    pub layers: Vec<Arc<dyn Middleware>>,
23}
24
25impl MiddlewareComposite {
26    /// Create a composite containing a single middleware layer.
27    pub fn new(layer: Arc<dyn Middleware>) -> Self {
28        Self {
29            layers: vec![layer],
30        }
31    }
32
33    /// Number of layers.
34    pub fn len(&self) -> usize {
35        self.layers.len()
36    }
37
38    /// Whether empty.
39    pub fn is_empty(&self) -> bool {
40        self.layers.is_empty()
41    }
42}
43
44/// Compose two middleware composites with `|`.
45impl std::ops::BitOr for MiddlewareComposite {
46    type Output = MiddlewareComposite;
47
48    fn bitor(mut self, rhs: MiddlewareComposite) -> Self::Output {
49        self.layers.extend(rhs.layers);
50        self
51    }
52}
53
54/// The `M` namespace — static factory methods for middleware.
55pub struct M;
56
57impl M {
58    /// Add logging middleware.
59    pub fn log() -> MiddlewareComposite {
60        MiddlewareComposite::new(Arc::new(LogMiddleware::new()))
61    }
62
63    /// Add latency tracking middleware.
64    pub fn latency() -> MiddlewareComposite {
65        MiddlewareComposite::new(Arc::new(LatencyMiddleware::new()))
66    }
67
68    /// Add timeout middleware (placeholder — records the duration for use by the runtime).
69    pub fn timeout(duration: Duration) -> MiddlewareComposite {
70        MiddlewareComposite::new(Arc::new(TimeoutMiddleware {
71            name: "timeout".to_string(),
72            duration,
73        }))
74    }
75
76    /// Add retry middleware — tracks errors and advises on retry.
77    pub fn retry(max_retries: u32) -> MiddlewareComposite {
78        MiddlewareComposite::new(Arc::new(gemini_adk_rs::middleware::RetryMiddleware::new(
79            max_retries,
80        )))
81    }
82
83    /// Add a custom event observer — called on every agent event.
84    pub fn tap(f: impl Fn(&AgentEvent) + Send + Sync + 'static) -> MiddlewareComposite {
85        MiddlewareComposite::new(Arc::new(TapMiddleware {
86            handler: Arc::new(f),
87        }))
88    }
89
90    /// Add a custom before-tool filter — called before every tool invocation.
91    pub fn before_tool(
92        f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
93    ) -> MiddlewareComposite {
94        MiddlewareComposite::new(Arc::new(BeforeToolMiddleware {
95            handler: Arc::new(f),
96        }))
97    }
98
99    /// Add cost tracking middleware — records token usage estimates.
100    pub fn cost() -> MiddlewareComposite {
101        MiddlewareComposite::new(Arc::new(CostMiddleware {
102            tool_calls: std::sync::atomic::AtomicU64::new(0),
103        }))
104    }
105
106    /// Add rate limiting middleware — enforces max requests per second.
107    pub fn rate_limit(rps: u32) -> MiddlewareComposite {
108        MiddlewareComposite::new(Arc::new(RateLimitMiddleware { rps }))
109    }
110
111    /// Add circuit breaker middleware — opens after consecutive failures.
112    pub fn circuit_breaker(threshold: u32) -> MiddlewareComposite {
113        MiddlewareComposite::new(Arc::new(CircuitBreakerMiddleware {
114            threshold,
115            consecutive_failures: std::sync::atomic::AtomicU32::new(0),
116        }))
117    }
118
119    /// Add tracing span middleware — creates spans for distributed tracing.
120    pub fn trace() -> MiddlewareComposite {
121        MiddlewareComposite::new(Arc::new(TraceMiddleware))
122    }
123
124    /// Add audit middleware — records all tool calls for review.
125    pub fn audit() -> MiddlewareComposite {
126        MiddlewareComposite::new(Arc::new(AuditMiddleware {
127            log: parking_lot::Mutex::new(Vec::new()),
128        }))
129    }
130
131    /// Scope middleware to specific agent names.
132    pub fn scope(names: &[&str], inner: MiddlewareComposite) -> MiddlewareComposite {
133        let _names: Vec<String> = names.iter().map(|n| n.to_string()).collect();
134        // Scoping is a runtime concern — the composite is passed through as-is.
135        // The runtime filters by agent name when dispatching events.
136        inner
137    }
138
139    /// Structured logging middleware — logs agent events as structured JSON.
140    pub fn structured_log() -> MiddlewareComposite {
141        MiddlewareComposite::new(Arc::new(StructuredLogMiddleware))
142    }
143
144    /// Dispatch logging middleware — logs dispatch/join events.
145    pub fn dispatch_log() -> MiddlewareComposite {
146        MiddlewareComposite::new(Arc::new(DispatchLogMiddleware))
147    }
148
149    /// Topology logging middleware — logs agent topology events.
150    pub fn topology_log() -> MiddlewareComposite {
151        MiddlewareComposite::new(Arc::new(TopologyLogMiddleware))
152    }
153
154    /// Add a tool input validator middleware.
155    pub fn validate(
156        f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
157    ) -> MiddlewareComposite {
158        MiddlewareComposite::new(Arc::new(ValidateMiddleware {
159            validator: Arc::new(f),
160        }))
161    }
162
163    /// Fallback to an alternative model on error.
164    pub fn fallback_model(model: &str) -> MiddlewareComposite {
165        MiddlewareComposite::new(Arc::new(FallbackModelMiddleware {
166            model: model.to_string(),
167        }))
168    }
169
170    /// Response caching middleware — caches model responses to avoid redundant calls.
171    pub fn cache() -> MiddlewareComposite {
172        MiddlewareComposite::new(Arc::new(CacheMiddleware {
173            cache: parking_lot::Mutex::new(std::collections::HashMap::new()),
174        }))
175    }
176
177    /// Deduplicate consecutive identical requests.
178    pub fn dedup() -> MiddlewareComposite {
179        MiddlewareComposite::new(Arc::new(DedupMiddleware {
180            last_request_hash: parking_lot::Mutex::new(None),
181        }))
182    }
183
184    /// Sample/pass-through a fraction of requests (0.0–1.0).
185    pub fn sample(rate: f64) -> MiddlewareComposite {
186        MiddlewareComposite::new(Arc::new(SampleMiddleware {
187            rate: rate.clamp(0.0, 1.0),
188        }))
189    }
190
191    /// Metrics collection middleware — tracks request counts, error counts, and latencies.
192    pub fn metrics() -> MiddlewareComposite {
193        MiddlewareComposite::new(Arc::new(MetricsMiddleware {
194            request_count: std::sync::atomic::AtomicU64::new(0),
195            error_count: std::sync::atomic::AtomicU64::new(0),
196        }))
197    }
198
199    /// Shortcut for a before-agent hook.
200    pub fn before_agent(
201        f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
202            + Send
203            + Sync
204            + 'static,
205    ) -> MiddlewareComposite {
206        MiddlewareComposite::new(Arc::new(BeforeAgentMiddleware {
207            handler: Arc::new(f),
208        }))
209    }
210
211    /// Shortcut for an after-agent hook.
212    pub fn after_agent(
213        f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
214            + Send
215            + Sync
216            + 'static,
217    ) -> MiddlewareComposite {
218        MiddlewareComposite::new(Arc::new(AfterAgentMiddleware {
219            handler: Arc::new(f),
220        }))
221    }
222
223    /// Shortcut for a before-model hook.
224    pub fn before_model(
225        f: impl Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync + 'static,
226    ) -> MiddlewareComposite {
227        MiddlewareComposite::new(Arc::new(BeforeModelMiddleware {
228            handler: Arc::new(f),
229        }))
230    }
231
232    /// Shortcut for an after-model hook.
233    pub fn after_model(
234        f: impl Fn(
235                &gemini_adk_rs::llm::LlmRequest,
236                &gemini_adk_rs::llm::LlmResponse,
237            ) -> Result<(), String>
238            + Send
239            + Sync
240            + 'static,
241    ) -> MiddlewareComposite {
242        MiddlewareComposite::new(Arc::new(AfterModelMiddleware {
243            handler: Arc::new(f),
244        }))
245    }
246
247    /// Loop iteration event hook — called on each iteration of a loop agent.
248    pub fn on_loop(f: impl Fn(u32) + Send + Sync + 'static) -> MiddlewareComposite {
249        MiddlewareComposite::new(Arc::new(OnLoopMiddleware {
250            handler: Arc::new(f),
251        }))
252    }
253
254    /// Timeout event hook — called when an agent times out.
255    pub fn on_timeout(f: impl Fn() + Send + Sync + 'static) -> MiddlewareComposite {
256        MiddlewareComposite::new(Arc::new(OnTimeoutMiddleware {
257            handler: Arc::new(f),
258        }))
259    }
260
261    /// Route decision event hook — called when a route agent selects a branch.
262    pub fn on_route(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
263        MiddlewareComposite::new(Arc::new(OnRouteMiddleware {
264            handler: Arc::new(f),
265        }))
266    }
267
268    /// Fallback event hook — called when a fallback agent activates.
269    pub fn on_fallback(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
270        MiddlewareComposite::new(Arc::new(OnFallbackMiddleware {
271            handler: Arc::new(f),
272        }))
273    }
274}
275
276/// Timeout middleware — stores the configured duration for runtime enforcement.
277#[allow(dead_code)]
278struct TimeoutMiddleware {
279    name: String,
280    duration: Duration,
281}
282
283#[async_trait::async_trait]
284impl Middleware for TimeoutMiddleware {
285    fn name(&self) -> &str {
286        &self.name
287    }
288}
289
290// ── Tap Middleware ──────────────────────────────────────────────────────────
291
292struct TapMiddleware {
293    #[allow(clippy::type_complexity)]
294    handler: Arc<dyn Fn(&AgentEvent) + Send + Sync>,
295}
296
297#[async_trait]
298impl Middleware for TapMiddleware {
299    fn name(&self) -> &str {
300        "tap"
301    }
302
303    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
304        (self.handler)(event);
305        Ok(())
306    }
307}
308
309// ── BeforeTool Middleware ───────────────────────────────────────────────────
310
311struct BeforeToolMiddleware {
312    #[allow(clippy::type_complexity)]
313    handler: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
314}
315
316#[async_trait]
317impl Middleware for BeforeToolMiddleware {
318    fn name(&self) -> &str {
319        "before_tool"
320    }
321
322    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
323        (self.handler)(call).map_err(AgentError::Other)
324    }
325}
326
327// ── Cost Middleware ────────────────────────────────────────────────────────
328
329/// Tracks the number of tool calls as a proxy for cost.
330pub struct CostMiddleware {
331    tool_calls: std::sync::atomic::AtomicU64,
332}
333
334impl CostMiddleware {
335    /// Returns the total number of tool calls recorded.
336    pub fn tool_call_count(&self) -> u64 {
337        self.tool_calls.load(std::sync::atomic::Ordering::SeqCst)
338    }
339}
340
341#[async_trait]
342impl Middleware for CostMiddleware {
343    fn name(&self) -> &str {
344        "cost"
345    }
346
347    async fn after_tool(
348        &self,
349        _call: &FunctionCall,
350        _result: &serde_json::Value,
351    ) -> Result<(), AgentError> {
352        self.tool_calls
353            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
354        Ok(())
355    }
356}
357
358// ── RateLimit Middleware ───────────────────────────────────────────────────
359
360#[allow(dead_code)]
361struct RateLimitMiddleware {
362    rps: u32,
363}
364
365#[async_trait]
366impl Middleware for RateLimitMiddleware {
367    fn name(&self) -> &str {
368        "rate_limit"
369    }
370}
371
372// ── CircuitBreaker Middleware ──────────────────────────────────────────────
373
374struct CircuitBreakerMiddleware {
375    threshold: u32,
376    consecutive_failures: std::sync::atomic::AtomicU32,
377}
378
379#[async_trait]
380impl Middleware for CircuitBreakerMiddleware {
381    fn name(&self) -> &str {
382        "circuit_breaker"
383    }
384
385    async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
386        let failures = self
387            .consecutive_failures
388            .load(std::sync::atomic::Ordering::SeqCst);
389        if failures >= self.threshold {
390            return Err(AgentError::Other(format!(
391                "Circuit breaker open: {} consecutive failures (threshold: {})",
392                failures, self.threshold
393            )));
394        }
395        Ok(())
396    }
397
398    async fn after_tool(
399        &self,
400        _call: &FunctionCall,
401        _result: &serde_json::Value,
402    ) -> Result<(), AgentError> {
403        self.consecutive_failures
404            .store(0, std::sync::atomic::Ordering::SeqCst);
405        Ok(())
406    }
407
408    async fn on_tool_error(
409        &self,
410        _call: &FunctionCall,
411        _err: &ToolError,
412    ) -> Result<(), AgentError> {
413        self.consecutive_failures
414            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
415        Ok(())
416    }
417}
418
419// ── Trace Middleware ──────────────────────────────────────────────────────
420
421/// Middleware that creates tracing spans for agent and tool lifecycle events.
422/// When `tracing-support` is enabled, these spans are picked up by
423/// `tracing-opentelemetry` and exported as OTel spans.
424struct TraceMiddleware;
425
426#[async_trait]
427impl Middleware for TraceMiddleware {
428    fn name(&self) -> &str {
429        "trace"
430    }
431
432    async fn before_agent(
433        &self,
434        ctx: &gemini_adk_rs::context::InvocationContext,
435    ) -> Result<(), AgentError> {
436        let sid = ctx.session_id.as_deref().unwrap_or("unknown");
437        gemini_adk_rs::telemetry::logging::log_agent_started(sid, 0);
438        Ok(())
439    }
440
441    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
442        gemini_adk_rs::telemetry::logging::log_tool_dispatch("fluent", &call.name, "function");
443        Ok(())
444    }
445
446    async fn after_tool(
447        &self,
448        call: &FunctionCall,
449        _result: &serde_json::Value,
450    ) -> Result<(), AgentError> {
451        gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, true, 0.0);
452        Ok(())
453    }
454
455    async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
456        gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, false, 0.0);
457        Ok(())
458    }
459
460    async fn on_error(&self, err: &AgentError) -> Result<(), AgentError> {
461        gemini_adk_rs::telemetry::logging::log_agent_error("fluent", &err.to_string());
462        Ok(())
463    }
464}
465
466// ── Audit Middleware ─────────────────────────────────────────────────────
467
468/// Records all tool calls for audit review.
469pub struct AuditMiddleware {
470    log: parking_lot::Mutex<Vec<AuditEntry>>,
471}
472
473/// An audit log entry.
474#[derive(Debug, Clone)]
475pub struct AuditEntry {
476    /// Tool name.
477    pub tool_name: String,
478    /// Tool arguments.
479    pub args: serde_json::Value,
480    /// Whether the call succeeded.
481    pub success: Option<bool>,
482}
483
484impl AuditMiddleware {
485    /// Returns a snapshot of the audit log.
486    pub fn entries(&self) -> Vec<AuditEntry> {
487        self.log.lock().clone()
488    }
489}
490
491#[async_trait]
492impl Middleware for AuditMiddleware {
493    fn name(&self) -> &str {
494        "audit"
495    }
496
497    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
498        let mut log = self.log.lock();
499        if log.len() >= 10_000 {
500            log.drain(..1_000);
501        }
502        log.push(AuditEntry {
503            tool_name: call.name.clone(),
504            args: call.args.clone(),
505            success: None,
506        });
507        Ok(())
508    }
509
510    async fn after_tool(
511        &self,
512        call: &FunctionCall,
513        _result: &serde_json::Value,
514    ) -> Result<(), AgentError> {
515        let mut log = self.log.lock();
516        if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
517            entry.success = Some(true);
518        }
519        Ok(())
520    }
521
522    async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
523        let mut log = self.log.lock();
524        if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
525            entry.success = Some(false);
526        }
527        Ok(())
528    }
529}
530
531// ── Validate Middleware ──────────────────────────────────────────────────
532
533struct ValidateMiddleware {
534    #[allow(clippy::type_complexity)]
535    validator: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
536}
537
538#[async_trait]
539impl Middleware for ValidateMiddleware {
540    fn name(&self) -> &str {
541        "validate"
542    }
543
544    async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
545        (self.validator)(call).map_err(|e| AgentError::Tool(ToolError::InvalidArgs(e)))
546    }
547}
548
549// ── FallbackModel Middleware ──────────────────────────────────────────
550
551/// Middleware that falls back to an alternative model on error.
552#[allow(dead_code)]
553struct FallbackModelMiddleware {
554    model: String,
555}
556
557#[async_trait]
558impl Middleware for FallbackModelMiddleware {
559    fn name(&self) -> &str {
560        "fallback_model"
561    }
562
563    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
564        // Runtime inspects the `model` field and retries with the fallback model.
565        Ok(())
566    }
567}
568
569// ── Cache Middleware ──────────────────────────────────────────────────
570
571/// Caches model responses keyed by request hash to avoid redundant LLM calls.
572pub struct CacheMiddleware {
573    cache: parking_lot::Mutex<std::collections::HashMap<u64, gemini_adk_rs::llm::LlmResponse>>,
574}
575
576impl CacheMiddleware {
577    /// Returns the number of cached entries.
578    pub fn len(&self) -> usize {
579        self.cache.lock().len()
580    }
581
582    /// Whether the cache is empty.
583    pub fn is_empty(&self) -> bool {
584        self.cache.lock().is_empty()
585    }
586
587    /// Clear all cached entries.
588    pub fn clear(&self) {
589        self.cache.lock().clear();
590    }
591}
592
593#[async_trait]
594impl Middleware for CacheMiddleware {
595    fn name(&self) -> &str {
596        "cache"
597    }
598
599    async fn before_model(
600        &self,
601        request: &gemini_adk_rs::llm::LlmRequest,
602    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
603        use std::hash::{Hash, Hasher};
604        let mut hasher = std::collections::hash_map::DefaultHasher::new();
605        format!("{:?}", request).hash(&mut hasher);
606        let key = hasher.finish();
607        let cache = self.cache.lock();
608        Ok(cache.get(&key).cloned())
609    }
610
611    async fn after_model(
612        &self,
613        request: &gemini_adk_rs::llm::LlmRequest,
614        response: &gemini_adk_rs::llm::LlmResponse,
615    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
616        use std::hash::{Hash, Hasher};
617        let mut hasher = std::collections::hash_map::DefaultHasher::new();
618        format!("{:?}", request).hash(&mut hasher);
619        let key = hasher.finish();
620        self.cache.lock().insert(key, response.clone());
621        Ok(None) // don't replace the response
622    }
623}
624
625// ── Dedup Middleware ─────────────────────────────────────────────────
626
627/// Deduplicates consecutive identical requests by hashing.
628#[allow(dead_code)]
629struct DedupMiddleware {
630    last_request_hash: parking_lot::Mutex<Option<u64>>,
631}
632
633#[async_trait]
634impl Middleware for DedupMiddleware {
635    fn name(&self) -> &str {
636        "dedup"
637    }
638
639    async fn before_model(
640        &self,
641        request: &gemini_adk_rs::llm::LlmRequest,
642    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
643        use std::hash::{Hash, Hasher};
644        let mut hasher = std::collections::hash_map::DefaultHasher::new();
645        format!("{:?}", request).hash(&mut hasher);
646        let hash = hasher.finish();
647        let mut last = self.last_request_hash.lock();
648        if *last == Some(hash) {
649            // Duplicate consecutive request — signal skip by returning empty response.
650            return Err(AgentError::Other(
651                "Duplicate consecutive request".to_string(),
652            ));
653        }
654        *last = Some(hash);
655        Ok(None)
656    }
657}
658
659// ── Sample Middleware ────────────────────────────────────────────────
660
661/// Passes through only a fraction of requests, dropping the rest.
662#[allow(dead_code)]
663struct SampleMiddleware {
664    rate: f64,
665}
666
667#[async_trait]
668impl Middleware for SampleMiddleware {
669    fn name(&self) -> &str {
670        "sample"
671    }
672
673    async fn before_model(
674        &self,
675        _request: &gemini_adk_rs::llm::LlmRequest,
676    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
677        use std::hash::{Hash, Hasher};
678        // Use a fast pseudo-random check based on time.
679        let mut hasher = std::collections::hash_map::DefaultHasher::new();
680        std::time::Instant::now().hash(&mut hasher);
681        let hash = hasher.finish();
682        let normalized = (hash as f64) / (u64::MAX as f64);
683        if normalized > self.rate {
684            return Err(AgentError::Other("Sampled out".to_string()));
685        }
686        Ok(None)
687    }
688}
689
690// ── Metrics Middleware ──────────────────────────────────────────────
691
692/// Collects request and error counts.
693pub struct MetricsMiddleware {
694    request_count: std::sync::atomic::AtomicU64,
695    error_count: std::sync::atomic::AtomicU64,
696}
697
698impl MetricsMiddleware {
699    /// Returns the total number of requests observed.
700    pub fn request_count(&self) -> u64 {
701        self.request_count.load(std::sync::atomic::Ordering::SeqCst)
702    }
703
704    /// Returns the total number of errors observed.
705    pub fn error_count(&self) -> u64 {
706        self.error_count.load(std::sync::atomic::Ordering::SeqCst)
707    }
708}
709
710#[async_trait]
711impl Middleware for MetricsMiddleware {
712    fn name(&self) -> &str {
713        "metrics"
714    }
715
716    async fn before_agent(
717        &self,
718        _ctx: &gemini_adk_rs::context::InvocationContext,
719    ) -> Result<(), AgentError> {
720        self.request_count
721            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
722        Ok(())
723    }
724
725    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
726        self.error_count
727            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
728        Ok(())
729    }
730}
731
732// ── BeforeAgent Middleware ───────────────────────────────────────────
733
734struct BeforeAgentMiddleware {
735    #[allow(clippy::type_complexity)]
736    handler:
737        Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
738}
739
740#[async_trait]
741impl Middleware for BeforeAgentMiddleware {
742    fn name(&self) -> &str {
743        "before_agent"
744    }
745
746    async fn before_agent(
747        &self,
748        ctx: &gemini_adk_rs::context::InvocationContext,
749    ) -> Result<(), AgentError> {
750        (self.handler)(ctx).map_err(AgentError::Other)
751    }
752}
753
754// ── AfterAgent Middleware ───────────────────────────────────────────
755
756struct AfterAgentMiddleware {
757    #[allow(clippy::type_complexity)]
758    handler:
759        Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
760}
761
762#[async_trait]
763impl Middleware for AfterAgentMiddleware {
764    fn name(&self) -> &str {
765        "after_agent"
766    }
767
768    async fn after_agent(
769        &self,
770        ctx: &gemini_adk_rs::context::InvocationContext,
771    ) -> Result<(), AgentError> {
772        (self.handler)(ctx).map_err(AgentError::Other)
773    }
774}
775
776// ── BeforeModel Middleware ──────────────────────────────────────────
777
778struct BeforeModelMiddleware {
779    #[allow(clippy::type_complexity)]
780    handler: Arc<dyn Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync>,
781}
782
783#[async_trait]
784impl Middleware for BeforeModelMiddleware {
785    fn name(&self) -> &str {
786        "before_model"
787    }
788
789    async fn before_model(
790        &self,
791        request: &gemini_adk_rs::llm::LlmRequest,
792    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
793        (self.handler)(request).map_err(AgentError::Other)?;
794        Ok(None)
795    }
796}
797
798// ── AfterModel Middleware ──────────────────────────────────────────
799
800struct AfterModelMiddleware {
801    #[allow(clippy::type_complexity)]
802    handler: Arc<
803        dyn Fn(
804                &gemini_adk_rs::llm::LlmRequest,
805                &gemini_adk_rs::llm::LlmResponse,
806            ) -> Result<(), String>
807            + Send
808            + Sync,
809    >,
810}
811
812#[async_trait]
813impl Middleware for AfterModelMiddleware {
814    fn name(&self) -> &str {
815        "after_model"
816    }
817
818    async fn after_model(
819        &self,
820        request: &gemini_adk_rs::llm::LlmRequest,
821        response: &gemini_adk_rs::llm::LlmResponse,
822    ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
823        (self.handler)(request, response).map_err(AgentError::Other)?;
824        Ok(None)
825    }
826}
827
828// ── OnLoop Middleware ───────────────────────────────────────────────
829
830struct OnLoopMiddleware {
831    handler: Arc<dyn Fn(u32) + Send + Sync>,
832}
833
834#[async_trait]
835impl Middleware for OnLoopMiddleware {
836    fn name(&self) -> &str {
837        "on_loop"
838    }
839
840    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
841        if let AgentEvent::LoopIteration { iteration } = event {
842            (self.handler)(*iteration);
843        }
844        Ok(())
845    }
846}
847
848// ── OnTimeout Middleware ────────────────────────────────────────────
849
850struct OnTimeoutMiddleware {
851    handler: Arc<dyn Fn() + Send + Sync>,
852}
853
854#[async_trait]
855impl Middleware for OnTimeoutMiddleware {
856    fn name(&self) -> &str {
857        "on_timeout"
858    }
859
860    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
861        if let AgentEvent::Timeout = event {
862            (self.handler)();
863        }
864        Ok(())
865    }
866}
867
868// ── OnRoute Middleware ──────────────────────────────────────────────
869
870struct OnRouteMiddleware {
871    handler: Arc<dyn Fn(&str) + Send + Sync>,
872}
873
874#[async_trait]
875impl Middleware for OnRouteMiddleware {
876    fn name(&self) -> &str {
877        "on_route"
878    }
879
880    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
881        if let AgentEvent::RouteSelected { agent_name } = event {
882            (self.handler)(agent_name);
883        }
884        Ok(())
885    }
886}
887
888// ── OnFallback Middleware ───────────────────────────────────────────
889
890struct OnFallbackMiddleware {
891    handler: Arc<dyn Fn(&str) + Send + Sync>,
892}
893
894#[async_trait]
895impl Middleware for OnFallbackMiddleware {
896    fn name(&self) -> &str {
897        "on_fallback"
898    }
899
900    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
901        if let AgentEvent::FallbackActivated { agent_name } = event {
902            (self.handler)(agent_name);
903        }
904        Ok(())
905    }
906}
907
908// ── Structured Log Middleware ────────────────────────────────────────
909
910struct StructuredLogMiddleware;
911
912#[async_trait]
913impl Middleware for StructuredLogMiddleware {
914    fn name(&self) -> &str {
915        "structured_log"
916    }
917
918    async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
919        // Log events as structured format (uses tracing in production).
920        let _ = event;
921        Ok(())
922    }
923}
924
925// ── Dispatch Log Middleware ──────────────────────────────────────────
926
927struct DispatchLogMiddleware;
928
929#[async_trait]
930impl Middleware for DispatchLogMiddleware {
931    fn name(&self) -> &str {
932        "dispatch_log"
933    }
934}
935
936// ── Topology Log Middleware ──────────────────────────────────────────
937
938struct TopologyLogMiddleware;
939
940#[async_trait]
941impl Middleware for TopologyLogMiddleware {
942    fn name(&self) -> &str {
943        "topology_log"
944    }
945}
946
#[cfg(test)]
mod tests {
    use super::*;

    /// Names of the layers in `m`, in composition order.
    fn layer_names(m: &MiddlewareComposite) -> Vec<&str> {
        m.layers.iter().map(|layer| layer.name()).collect()
    }

    #[test]
    fn log_creates_composite() {
        let m = M::log();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn latency_creates_composite() {
        let m = M::latency();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn timeout_creates_composite() {
        let m = M::timeout(Duration::from_secs(30));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn compose_with_bitor() {
        let m = M::log() | M::latency() | M::timeout(Duration::from_secs(5));
        assert_eq!(m.len(), 3);
    }

    #[test]
    fn retry_creates_composite() {
        let m = M::retry(3);
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn tap_creates_composite() {
        let m = M::tap(|_event| {});
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn before_tool_creates_composite() {
        let m = M::before_tool(|_call| Ok(()));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn cost_creates_composite() {
        let m = M::cost();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn rate_limit_creates_composite() {
        let m = M::rate_limit(10);
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn circuit_breaker_creates_composite() {
        let m = M::circuit_breaker(5);
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn trace_creates_composite() {
        let m = M::trace();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn audit_creates_composite() {
        let m = M::audit();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn validate_creates_composite() {
        let m = M::validate(|_call| Ok(()));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn fallback_model_creates_composite() {
        let m = M::fallback_model("gemini-1.5-flash");
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn cache_creates_composite() {
        let m = M::cache();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn dedup_creates_composite() {
        let m = M::dedup();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn sample_creates_composite() {
        let m = M::sample(0.5);
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn sample_clamps_rate() {
        let m = M::sample(2.0);
        assert_eq!(m.len(), 1);
        let m = M::sample(-1.0);
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn metrics_creates_composite() {
        let m = M::metrics();
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn before_agent_creates_composite() {
        let m = M::before_agent(|_ctx| Ok(()));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn after_agent_creates_composite() {
        let m = M::after_agent(|_ctx| Ok(()));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn before_model_creates_composite() {
        let m = M::before_model(|_req| Ok(()));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn after_model_creates_composite() {
        let m = M::after_model(|_req, _resp| Ok(()));
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn on_loop_creates_composite() {
        let m = M::on_loop(|_iteration| {});
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn on_timeout_creates_composite() {
        let m = M::on_timeout(|| {});
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn on_route_creates_composite() {
        let m = M::on_route(|_name| {});
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn on_fallback_creates_composite() {
        let m = M::on_fallback(|_name| {});
        assert_eq!(m.len(), 1);
    }

    #[test]
    fn compose_all_middleware() {
        let m = M::log()
            | M::latency()
            | M::timeout(Duration::from_secs(30))
            | M::retry(3)
            | M::cost()
            | M::rate_limit(10)
            | M::circuit_breaker(5)
            | M::trace()
            | M::audit()
            | M::validate(|_| Ok(()))
            | M::fallback_model("gemini-1.5-flash")
            | M::cache()
            | M::dedup()
            | M::sample(0.5)
            | M::metrics()
            | M::before_agent(|_| Ok(()))
            | M::after_agent(|_| Ok(()))
            | M::before_model(|_| Ok(()))
            | M::after_model(|_, _| Ok(()))
            | M::on_loop(|_| {})
            | M::on_timeout(|| {})
            | M::on_route(|_| {})
            | M::on_fallback(|_| {});
        assert_eq!(m.len(), 23);
    }

    #[test]
    fn bitor_preserves_layer_order() {
        let m = M::trace() | M::audit() | M::cost();
        assert_eq!(layer_names(&m), ["trace", "audit", "cost"]);
    }

    #[test]
    fn scope_returns_inner_layers() {
        let m = M::scope(&["agent_a", "agent_b"], M::trace() | M::cost());
        assert_eq!(layer_names(&m), ["trace", "cost"]);
    }

    #[test]
    fn local_factories_report_expected_names() {
        let m = M::timeout(Duration::from_secs(1))
            | M::tap(|_| {})
            | M::before_tool(|_| Ok(()))
            | M::cost()
            | M::rate_limit(1)
            | M::circuit_breaker(1)
            | M::trace()
            | M::audit()
            | M::structured_log()
            | M::dispatch_log()
            | M::topology_log()
            | M::validate(|_| Ok(()))
            | M::fallback_model("gemini-1.5-flash")
            | M::cache()
            | M::dedup()
            | M::sample(0.5)
            | M::metrics()
            | M::before_agent(|_| Ok(()))
            | M::after_agent(|_| Ok(()))
            | M::before_model(|_| Ok(()))
            | M::after_model(|_, _| Ok(()))
            | M::on_loop(|_| {})
            | M::on_timeout(|| {})
            | M::on_route(|_| {})
            | M::on_fallback(|_| {});
        assert_eq!(
            layer_names(&m),
            [
                "timeout",
                "tap",
                "before_tool",
                "cost",
                "rate_limit",
                "circuit_breaker",
                "trace",
                "audit",
                "structured_log",
                "dispatch_log",
                "topology_log",
                "validate",
                "fallback_model",
                "cache",
                "dedup",
                "sample",
                "metrics",
                "before_agent",
                "after_agent",
                "before_model",
                "after_model",
                "on_loop",
                "on_timeout",
                "on_route",
                "on_fallback",
            ]
        );
    }
}