1use std::sync::Arc;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use gemini_genai_rs::prelude::FunctionCall;
13
14use gemini_adk_rs::context::AgentEvent;
15use gemini_adk_rs::error::{AgentError, ToolError};
16use gemini_adk_rs::middleware::{LatencyMiddleware, LogMiddleware, Middleware};
17
18#[derive(Clone)]
20pub struct MiddlewareComposite {
21 pub layers: Vec<Arc<dyn Middleware>>,
23}
24
25impl MiddlewareComposite {
26 pub fn new(layer: Arc<dyn Middleware>) -> Self {
28 Self {
29 layers: vec![layer],
30 }
31 }
32
33 pub fn len(&self) -> usize {
35 self.layers.len()
36 }
37
38 pub fn is_empty(&self) -> bool {
40 self.layers.is_empty()
41 }
42}
43
44impl std::ops::BitOr for MiddlewareComposite {
46 type Output = MiddlewareComposite;
47
48 fn bitor(mut self, rhs: MiddlewareComposite) -> Self::Output {
49 self.layers.extend(rhs.layers);
50 self
51 }
52}
53
54pub struct M;
56
57impl M {
58 pub fn log() -> MiddlewareComposite {
60 MiddlewareComposite::new(Arc::new(LogMiddleware::new()))
61 }
62
63 pub fn latency() -> MiddlewareComposite {
65 MiddlewareComposite::new(Arc::new(LatencyMiddleware::new()))
66 }
67
68 pub fn timeout(duration: Duration) -> MiddlewareComposite {
70 MiddlewareComposite::new(Arc::new(TimeoutMiddleware {
71 name: "timeout".to_string(),
72 duration,
73 }))
74 }
75
76 pub fn retry(max_retries: u32) -> MiddlewareComposite {
78 MiddlewareComposite::new(Arc::new(gemini_adk_rs::middleware::RetryMiddleware::new(
79 max_retries,
80 )))
81 }
82
83 pub fn tap(f: impl Fn(&AgentEvent) + Send + Sync + 'static) -> MiddlewareComposite {
85 MiddlewareComposite::new(Arc::new(TapMiddleware {
86 handler: Arc::new(f),
87 }))
88 }
89
90 pub fn before_tool(
92 f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
93 ) -> MiddlewareComposite {
94 MiddlewareComposite::new(Arc::new(BeforeToolMiddleware {
95 handler: Arc::new(f),
96 }))
97 }
98
99 pub fn cost() -> MiddlewareComposite {
101 MiddlewareComposite::new(Arc::new(CostMiddleware {
102 tool_calls: std::sync::atomic::AtomicU64::new(0),
103 }))
104 }
105
106 pub fn rate_limit(rps: u32) -> MiddlewareComposite {
108 MiddlewareComposite::new(Arc::new(RateLimitMiddleware { rps }))
109 }
110
111 pub fn circuit_breaker(threshold: u32) -> MiddlewareComposite {
113 MiddlewareComposite::new(Arc::new(CircuitBreakerMiddleware {
114 threshold,
115 consecutive_failures: std::sync::atomic::AtomicU32::new(0),
116 }))
117 }
118
119 pub fn trace() -> MiddlewareComposite {
121 MiddlewareComposite::new(Arc::new(TraceMiddleware))
122 }
123
124 pub fn audit() -> MiddlewareComposite {
126 MiddlewareComposite::new(Arc::new(AuditMiddleware {
127 log: parking_lot::Mutex::new(Vec::new()),
128 }))
129 }
130
131 pub fn scope(names: &[&str], inner: MiddlewareComposite) -> MiddlewareComposite {
133 let _names: Vec<String> = names.iter().map(|n| n.to_string()).collect();
134 inner
137 }
138
139 pub fn structured_log() -> MiddlewareComposite {
141 MiddlewareComposite::new(Arc::new(StructuredLogMiddleware))
142 }
143
144 pub fn dispatch_log() -> MiddlewareComposite {
146 MiddlewareComposite::new(Arc::new(DispatchLogMiddleware))
147 }
148
149 pub fn topology_log() -> MiddlewareComposite {
151 MiddlewareComposite::new(Arc::new(TopologyLogMiddleware))
152 }
153
154 pub fn validate(
156 f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
157 ) -> MiddlewareComposite {
158 MiddlewareComposite::new(Arc::new(ValidateMiddleware {
159 validator: Arc::new(f),
160 }))
161 }
162
163 pub fn fallback_model(model: &str) -> MiddlewareComposite {
165 MiddlewareComposite::new(Arc::new(FallbackModelMiddleware {
166 model: model.to_string(),
167 }))
168 }
169
170 pub fn cache() -> MiddlewareComposite {
172 MiddlewareComposite::new(Arc::new(CacheMiddleware {
173 cache: parking_lot::Mutex::new(std::collections::HashMap::new()),
174 }))
175 }
176
177 pub fn dedup() -> MiddlewareComposite {
179 MiddlewareComposite::new(Arc::new(DedupMiddleware {
180 last_request_hash: parking_lot::Mutex::new(None),
181 }))
182 }
183
184 pub fn sample(rate: f64) -> MiddlewareComposite {
186 MiddlewareComposite::new(Arc::new(SampleMiddleware {
187 rate: rate.clamp(0.0, 1.0),
188 }))
189 }
190
191 pub fn metrics() -> MiddlewareComposite {
193 MiddlewareComposite::new(Arc::new(MetricsMiddleware {
194 request_count: std::sync::atomic::AtomicU64::new(0),
195 error_count: std::sync::atomic::AtomicU64::new(0),
196 }))
197 }
198
199 pub fn before_agent(
201 f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
202 + Send
203 + Sync
204 + 'static,
205 ) -> MiddlewareComposite {
206 MiddlewareComposite::new(Arc::new(BeforeAgentMiddleware {
207 handler: Arc::new(f),
208 }))
209 }
210
211 pub fn after_agent(
213 f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
214 + Send
215 + Sync
216 + 'static,
217 ) -> MiddlewareComposite {
218 MiddlewareComposite::new(Arc::new(AfterAgentMiddleware {
219 handler: Arc::new(f),
220 }))
221 }
222
223 pub fn before_model(
225 f: impl Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync + 'static,
226 ) -> MiddlewareComposite {
227 MiddlewareComposite::new(Arc::new(BeforeModelMiddleware {
228 handler: Arc::new(f),
229 }))
230 }
231
232 pub fn after_model(
234 f: impl Fn(
235 &gemini_adk_rs::llm::LlmRequest,
236 &gemini_adk_rs::llm::LlmResponse,
237 ) -> Result<(), String>
238 + Send
239 + Sync
240 + 'static,
241 ) -> MiddlewareComposite {
242 MiddlewareComposite::new(Arc::new(AfterModelMiddleware {
243 handler: Arc::new(f),
244 }))
245 }
246
247 pub fn on_loop(f: impl Fn(u32) + Send + Sync + 'static) -> MiddlewareComposite {
249 MiddlewareComposite::new(Arc::new(OnLoopMiddleware {
250 handler: Arc::new(f),
251 }))
252 }
253
254 pub fn on_timeout(f: impl Fn() + Send + Sync + 'static) -> MiddlewareComposite {
256 MiddlewareComposite::new(Arc::new(OnTimeoutMiddleware {
257 handler: Arc::new(f),
258 }))
259 }
260
261 pub fn on_route(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
263 MiddlewareComposite::new(Arc::new(OnRouteMiddleware {
264 handler: Arc::new(f),
265 }))
266 }
267
268 pub fn on_fallback(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
270 MiddlewareComposite::new(Arc::new(OnFallbackMiddleware {
271 handler: Arc::new(f),
272 }))
273 }
274}
275
276#[allow(dead_code)]
278struct TimeoutMiddleware {
279 name: String,
280 duration: Duration,
281}
282
283#[async_trait::async_trait]
284impl Middleware for TimeoutMiddleware {
285 fn name(&self) -> &str {
286 &self.name
287 }
288}
289
290struct TapMiddleware {
293 #[allow(clippy::type_complexity)]
294 handler: Arc<dyn Fn(&AgentEvent) + Send + Sync>,
295}
296
297#[async_trait]
298impl Middleware for TapMiddleware {
299 fn name(&self) -> &str {
300 "tap"
301 }
302
303 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
304 (self.handler)(event);
305 Ok(())
306 }
307}
308
309struct BeforeToolMiddleware {
312 #[allow(clippy::type_complexity)]
313 handler: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
314}
315
316#[async_trait]
317impl Middleware for BeforeToolMiddleware {
318 fn name(&self) -> &str {
319 "before_tool"
320 }
321
322 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
323 (self.handler)(call).map_err(AgentError::Other)
324 }
325}
326
327pub struct CostMiddleware {
331 tool_calls: std::sync::atomic::AtomicU64,
332}
333
334impl CostMiddleware {
335 pub fn tool_call_count(&self) -> u64 {
337 self.tool_calls.load(std::sync::atomic::Ordering::SeqCst)
338 }
339}
340
341#[async_trait]
342impl Middleware for CostMiddleware {
343 fn name(&self) -> &str {
344 "cost"
345 }
346
347 async fn after_tool(
348 &self,
349 _call: &FunctionCall,
350 _result: &serde_json::Value,
351 ) -> Result<(), AgentError> {
352 self.tool_calls
353 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
354 Ok(())
355 }
356}
357
358#[allow(dead_code)]
361struct RateLimitMiddleware {
362 rps: u32,
363}
364
365#[async_trait]
366impl Middleware for RateLimitMiddleware {
367 fn name(&self) -> &str {
368 "rate_limit"
369 }
370}
371
372struct CircuitBreakerMiddleware {
375 threshold: u32,
376 consecutive_failures: std::sync::atomic::AtomicU32,
377}
378
379#[async_trait]
380impl Middleware for CircuitBreakerMiddleware {
381 fn name(&self) -> &str {
382 "circuit_breaker"
383 }
384
385 async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
386 let failures = self
387 .consecutive_failures
388 .load(std::sync::atomic::Ordering::SeqCst);
389 if failures >= self.threshold {
390 return Err(AgentError::Other(format!(
391 "Circuit breaker open: {} consecutive failures (threshold: {})",
392 failures, self.threshold
393 )));
394 }
395 Ok(())
396 }
397
398 async fn after_tool(
399 &self,
400 _call: &FunctionCall,
401 _result: &serde_json::Value,
402 ) -> Result<(), AgentError> {
403 self.consecutive_failures
404 .store(0, std::sync::atomic::Ordering::SeqCst);
405 Ok(())
406 }
407
408 async fn on_tool_error(
409 &self,
410 _call: &FunctionCall,
411 _err: &ToolError,
412 ) -> Result<(), AgentError> {
413 self.consecutive_failures
414 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
415 Ok(())
416 }
417}
418
419struct TraceMiddleware;
425
426#[async_trait]
427impl Middleware for TraceMiddleware {
428 fn name(&self) -> &str {
429 "trace"
430 }
431
432 async fn before_agent(
433 &self,
434 ctx: &gemini_adk_rs::context::InvocationContext,
435 ) -> Result<(), AgentError> {
436 let sid = ctx.session_id.as_deref().unwrap_or("unknown");
437 gemini_adk_rs::telemetry::logging::log_agent_started(sid, 0);
438 Ok(())
439 }
440
441 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
442 gemini_adk_rs::telemetry::logging::log_tool_dispatch("fluent", &call.name, "function");
443 Ok(())
444 }
445
446 async fn after_tool(
447 &self,
448 call: &FunctionCall,
449 _result: &serde_json::Value,
450 ) -> Result<(), AgentError> {
451 gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, true, 0.0);
452 Ok(())
453 }
454
455 async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
456 gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, false, 0.0);
457 Ok(())
458 }
459
460 async fn on_error(&self, err: &AgentError) -> Result<(), AgentError> {
461 gemini_adk_rs::telemetry::logging::log_agent_error("fluent", &err.to_string());
462 Ok(())
463 }
464}
465
466pub struct AuditMiddleware {
470 log: parking_lot::Mutex<Vec<AuditEntry>>,
471}
472
473#[derive(Debug, Clone)]
475pub struct AuditEntry {
476 pub tool_name: String,
478 pub args: serde_json::Value,
480 pub success: Option<bool>,
482}
483
484impl AuditMiddleware {
485 pub fn entries(&self) -> Vec<AuditEntry> {
487 self.log.lock().clone()
488 }
489}
490
491#[async_trait]
492impl Middleware for AuditMiddleware {
493 fn name(&self) -> &str {
494 "audit"
495 }
496
497 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
498 let mut log = self.log.lock();
499 if log.len() >= 10_000 {
500 log.drain(..1_000);
501 }
502 log.push(AuditEntry {
503 tool_name: call.name.clone(),
504 args: call.args.clone(),
505 success: None,
506 });
507 Ok(())
508 }
509
510 async fn after_tool(
511 &self,
512 call: &FunctionCall,
513 _result: &serde_json::Value,
514 ) -> Result<(), AgentError> {
515 let mut log = self.log.lock();
516 if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
517 entry.success = Some(true);
518 }
519 Ok(())
520 }
521
522 async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
523 let mut log = self.log.lock();
524 if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
525 entry.success = Some(false);
526 }
527 Ok(())
528 }
529}
530
531struct ValidateMiddleware {
534 #[allow(clippy::type_complexity)]
535 validator: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
536}
537
538#[async_trait]
539impl Middleware for ValidateMiddleware {
540 fn name(&self) -> &str {
541 "validate"
542 }
543
544 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
545 (self.validator)(call).map_err(|e| AgentError::Tool(ToolError::InvalidArgs(e)))
546 }
547}
548
549#[allow(dead_code)]
553struct FallbackModelMiddleware {
554 model: String,
555}
556
557#[async_trait]
558impl Middleware for FallbackModelMiddleware {
559 fn name(&self) -> &str {
560 "fallback_model"
561 }
562
563 async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
564 Ok(())
566 }
567}
568
569pub struct CacheMiddleware {
573 cache: parking_lot::Mutex<std::collections::HashMap<u64, gemini_adk_rs::llm::LlmResponse>>,
574}
575
576impl CacheMiddleware {
577 pub fn len(&self) -> usize {
579 self.cache.lock().len()
580 }
581
582 pub fn is_empty(&self) -> bool {
584 self.cache.lock().is_empty()
585 }
586
587 pub fn clear(&self) {
589 self.cache.lock().clear();
590 }
591}
592
593#[async_trait]
594impl Middleware for CacheMiddleware {
595 fn name(&self) -> &str {
596 "cache"
597 }
598
599 async fn before_model(
600 &self,
601 request: &gemini_adk_rs::llm::LlmRequest,
602 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
603 use std::hash::{Hash, Hasher};
604 let mut hasher = std::collections::hash_map::DefaultHasher::new();
605 format!("{:?}", request).hash(&mut hasher);
606 let key = hasher.finish();
607 let cache = self.cache.lock();
608 Ok(cache.get(&key).cloned())
609 }
610
611 async fn after_model(
612 &self,
613 request: &gemini_adk_rs::llm::LlmRequest,
614 response: &gemini_adk_rs::llm::LlmResponse,
615 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
616 use std::hash::{Hash, Hasher};
617 let mut hasher = std::collections::hash_map::DefaultHasher::new();
618 format!("{:?}", request).hash(&mut hasher);
619 let key = hasher.finish();
620 self.cache.lock().insert(key, response.clone());
621 Ok(None) }
623}
624
625#[allow(dead_code)]
629struct DedupMiddleware {
630 last_request_hash: parking_lot::Mutex<Option<u64>>,
631}
632
633#[async_trait]
634impl Middleware for DedupMiddleware {
635 fn name(&self) -> &str {
636 "dedup"
637 }
638
639 async fn before_model(
640 &self,
641 request: &gemini_adk_rs::llm::LlmRequest,
642 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
643 use std::hash::{Hash, Hasher};
644 let mut hasher = std::collections::hash_map::DefaultHasher::new();
645 format!("{:?}", request).hash(&mut hasher);
646 let hash = hasher.finish();
647 let mut last = self.last_request_hash.lock();
648 if *last == Some(hash) {
649 return Err(AgentError::Other(
651 "Duplicate consecutive request".to_string(),
652 ));
653 }
654 *last = Some(hash);
655 Ok(None)
656 }
657}
658
659#[allow(dead_code)]
663struct SampleMiddleware {
664 rate: f64,
665}
666
667#[async_trait]
668impl Middleware for SampleMiddleware {
669 fn name(&self) -> &str {
670 "sample"
671 }
672
673 async fn before_model(
674 &self,
675 _request: &gemini_adk_rs::llm::LlmRequest,
676 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
677 use std::hash::{Hash, Hasher};
678 let mut hasher = std::collections::hash_map::DefaultHasher::new();
680 std::time::Instant::now().hash(&mut hasher);
681 let hash = hasher.finish();
682 let normalized = (hash as f64) / (u64::MAX as f64);
683 if normalized > self.rate {
684 return Err(AgentError::Other("Sampled out".to_string()));
685 }
686 Ok(None)
687 }
688}
689
690pub struct MetricsMiddleware {
694 request_count: std::sync::atomic::AtomicU64,
695 error_count: std::sync::atomic::AtomicU64,
696}
697
698impl MetricsMiddleware {
699 pub fn request_count(&self) -> u64 {
701 self.request_count.load(std::sync::atomic::Ordering::SeqCst)
702 }
703
704 pub fn error_count(&self) -> u64 {
706 self.error_count.load(std::sync::atomic::Ordering::SeqCst)
707 }
708}
709
710#[async_trait]
711impl Middleware for MetricsMiddleware {
712 fn name(&self) -> &str {
713 "metrics"
714 }
715
716 async fn before_agent(
717 &self,
718 _ctx: &gemini_adk_rs::context::InvocationContext,
719 ) -> Result<(), AgentError> {
720 self.request_count
721 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
722 Ok(())
723 }
724
725 async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
726 self.error_count
727 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
728 Ok(())
729 }
730}
731
732struct BeforeAgentMiddleware {
735 #[allow(clippy::type_complexity)]
736 handler:
737 Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
738}
739
740#[async_trait]
741impl Middleware for BeforeAgentMiddleware {
742 fn name(&self) -> &str {
743 "before_agent"
744 }
745
746 async fn before_agent(
747 &self,
748 ctx: &gemini_adk_rs::context::InvocationContext,
749 ) -> Result<(), AgentError> {
750 (self.handler)(ctx).map_err(AgentError::Other)
751 }
752}
753
754struct AfterAgentMiddleware {
757 #[allow(clippy::type_complexity)]
758 handler:
759 Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
760}
761
762#[async_trait]
763impl Middleware for AfterAgentMiddleware {
764 fn name(&self) -> &str {
765 "after_agent"
766 }
767
768 async fn after_agent(
769 &self,
770 ctx: &gemini_adk_rs::context::InvocationContext,
771 ) -> Result<(), AgentError> {
772 (self.handler)(ctx).map_err(AgentError::Other)
773 }
774}
775
776struct BeforeModelMiddleware {
779 #[allow(clippy::type_complexity)]
780 handler: Arc<dyn Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync>,
781}
782
783#[async_trait]
784impl Middleware for BeforeModelMiddleware {
785 fn name(&self) -> &str {
786 "before_model"
787 }
788
789 async fn before_model(
790 &self,
791 request: &gemini_adk_rs::llm::LlmRequest,
792 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
793 (self.handler)(request).map_err(AgentError::Other)?;
794 Ok(None)
795 }
796}
797
798struct AfterModelMiddleware {
801 #[allow(clippy::type_complexity)]
802 handler: Arc<
803 dyn Fn(
804 &gemini_adk_rs::llm::LlmRequest,
805 &gemini_adk_rs::llm::LlmResponse,
806 ) -> Result<(), String>
807 + Send
808 + Sync,
809 >,
810}
811
812#[async_trait]
813impl Middleware for AfterModelMiddleware {
814 fn name(&self) -> &str {
815 "after_model"
816 }
817
818 async fn after_model(
819 &self,
820 request: &gemini_adk_rs::llm::LlmRequest,
821 response: &gemini_adk_rs::llm::LlmResponse,
822 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
823 (self.handler)(request, response).map_err(AgentError::Other)?;
824 Ok(None)
825 }
826}
827
828struct OnLoopMiddleware {
831 handler: Arc<dyn Fn(u32) + Send + Sync>,
832}
833
834#[async_trait]
835impl Middleware for OnLoopMiddleware {
836 fn name(&self) -> &str {
837 "on_loop"
838 }
839
840 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
841 if let AgentEvent::LoopIteration { iteration } = event {
842 (self.handler)(*iteration);
843 }
844 Ok(())
845 }
846}
847
848struct OnTimeoutMiddleware {
851 handler: Arc<dyn Fn() + Send + Sync>,
852}
853
854#[async_trait]
855impl Middleware for OnTimeoutMiddleware {
856 fn name(&self) -> &str {
857 "on_timeout"
858 }
859
860 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
861 if let AgentEvent::Timeout = event {
862 (self.handler)();
863 }
864 Ok(())
865 }
866}
867
868struct OnRouteMiddleware {
871 handler: Arc<dyn Fn(&str) + Send + Sync>,
872}
873
874#[async_trait]
875impl Middleware for OnRouteMiddleware {
876 fn name(&self) -> &str {
877 "on_route"
878 }
879
880 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
881 if let AgentEvent::RouteSelected { agent_name } = event {
882 (self.handler)(agent_name);
883 }
884 Ok(())
885 }
886}
887
888struct OnFallbackMiddleware {
891 handler: Arc<dyn Fn(&str) + Send + Sync>,
892}
893
894#[async_trait]
895impl Middleware for OnFallbackMiddleware {
896 fn name(&self) -> &str {
897 "on_fallback"
898 }
899
900 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
901 if let AgentEvent::FallbackActivated { agent_name } = event {
902 (self.handler)(agent_name);
903 }
904 Ok(())
905 }
906}
907
908struct StructuredLogMiddleware;
911
912#[async_trait]
913impl Middleware for StructuredLogMiddleware {
914 fn name(&self) -> &str {
915 "structured_log"
916 }
917
918 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
919 let _ = event;
921 Ok(())
922 }
923}
924
925struct DispatchLogMiddleware;
928
929#[async_trait]
930impl Middleware for DispatchLogMiddleware {
931 fn name(&self) -> &str {
932 "dispatch_log"
933 }
934}
935
936struct TopologyLogMiddleware;
939
940#[async_trait]
941impl Middleware for TopologyLogMiddleware {
942 fn name(&self) -> &str {
943 "topology_log"
944 }
945}
946
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a factory produced exactly one layer.
    fn assert_single_layer(m: MiddlewareComposite) {
        assert_eq!(m.len(), 1);
        assert!(!m.is_empty());
    }

    /// Names of the composite's layers, in composition order.
    fn layer_names(m: &MiddlewareComposite) -> Vec<String> {
        m.layers.iter().map(|layer| layer.name().to_string()).collect()
    }

    #[test]
    fn log_creates_composite() {
        assert_single_layer(M::log());
    }

    #[test]
    fn latency_creates_composite() {
        assert_single_layer(M::latency());
    }

    #[test]
    fn timeout_creates_composite() {
        assert_single_layer(M::timeout(Duration::from_secs(30)));
    }

    #[test]
    fn compose_with_bitor() {
        let m = M::log() | M::latency() | M::timeout(Duration::from_secs(5));
        assert_eq!(m.len(), 3);
    }

    #[test]
    fn compose_preserves_order() {
        let m = M::tap(|_| {}) | M::cost() | M::audit();
        assert_eq!(layer_names(&m), ["tap", "cost", "audit"]);
    }

    #[test]
    fn compose_with_empty_keeps_layers() {
        let empty = MiddlewareComposite { layers: Vec::new() };
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        let m = empty.clone() | M::cost() | empty;
        assert_eq!(layer_names(&m), ["cost"]);
    }

    #[test]
    fn retry_creates_composite() {
        assert_single_layer(M::retry(3));
    }

    #[test]
    fn tap_creates_composite() {
        assert_single_layer(M::tap(|_event| {}));
    }

    #[test]
    fn before_tool_creates_composite() {
        assert_single_layer(M::before_tool(|_call| Ok(())));
    }

    #[test]
    fn cost_creates_composite() {
        assert_single_layer(M::cost());
    }

    #[test]
    fn rate_limit_creates_composite() {
        assert_single_layer(M::rate_limit(10));
    }

    #[test]
    fn circuit_breaker_creates_composite() {
        assert_single_layer(M::circuit_breaker(5));
    }

    #[test]
    fn trace_creates_composite() {
        assert_single_layer(M::trace());
    }

    #[test]
    fn audit_creates_composite() {
        assert_single_layer(M::audit());
    }

    #[test]
    fn validate_creates_composite() {
        assert_single_layer(M::validate(|_call| Ok(())));
    }

    #[test]
    fn fallback_model_creates_composite() {
        assert_single_layer(M::fallback_model("gemini-1.5-flash"));
    }

    #[test]
    fn cache_creates_composite() {
        assert_single_layer(M::cache());
    }

    #[test]
    fn dedup_creates_composite() {
        assert_single_layer(M::dedup());
    }

    #[test]
    fn sample_creates_composite() {
        assert_single_layer(M::sample(0.5));
    }

    #[test]
    fn sample_clamps_rate() {
        assert_single_layer(M::sample(2.0));
        assert_single_layer(M::sample(-1.0));
    }

    #[test]
    fn metrics_creates_composite() {
        assert_single_layer(M::metrics());
    }

    #[test]
    fn before_agent_creates_composite() {
        assert_single_layer(M::before_agent(|_ctx| Ok(())));
    }

    #[test]
    fn after_agent_creates_composite() {
        assert_single_layer(M::after_agent(|_ctx| Ok(())));
    }

    #[test]
    fn before_model_creates_composite() {
        assert_single_layer(M::before_model(|_req| Ok(())));
    }

    #[test]
    fn after_model_creates_composite() {
        assert_single_layer(M::after_model(|_req, _resp| Ok(())));
    }

    #[test]
    fn on_loop_creates_composite() {
        assert_single_layer(M::on_loop(|_iteration| {}));
    }

    #[test]
    fn on_timeout_creates_composite() {
        assert_single_layer(M::on_timeout(|| {}));
    }

    #[test]
    fn on_route_creates_composite() {
        assert_single_layer(M::on_route(|_name| {}));
    }

    #[test]
    fn on_fallback_creates_composite() {
        assert_single_layer(M::on_fallback(|_name| {}));
    }

    #[test]
    fn log_family_creates_composites() {
        assert_single_layer(M::structured_log());
        assert_single_layer(M::dispatch_log());
        assert_single_layer(M::topology_log());
    }

    #[test]
    fn scope_currently_returns_inner_unchanged() {
        let m = M::scope(&["agent_a"], M::cost() | M::audit());
        assert_eq!(layer_names(&m), ["cost", "audit"]);
    }

    #[test]
    fn factories_use_expected_names() {
        let m = M::timeout(Duration::from_secs(1))
            | M::tap(|_| {})
            | M::before_tool(|_| Ok(()))
            | M::cost()
            | M::rate_limit(1)
            | M::circuit_breaker(1)
            | M::trace()
            | M::audit()
            | M::structured_log()
            | M::dispatch_log()
            | M::topology_log()
            | M::validate(|_| Ok(()))
            | M::fallback_model("m")
            | M::cache()
            | M::dedup()
            | M::sample(1.0)
            | M::metrics()
            | M::before_agent(|_| Ok(()))
            | M::after_agent(|_| Ok(()))
            | M::before_model(|_| Ok(()))
            | M::after_model(|_, _| Ok(()))
            | M::on_loop(|_| {})
            | M::on_timeout(|| {})
            | M::on_route(|_| {})
            | M::on_fallback(|_| {});
        assert_eq!(
            layer_names(&m),
            [
                "timeout",
                "tap",
                "before_tool",
                "cost",
                "rate_limit",
                "circuit_breaker",
                "trace",
                "audit",
                "structured_log",
                "dispatch_log",
                "topology_log",
                "validate",
                "fallback_model",
                "cache",
                "dedup",
                "sample",
                "metrics",
                "before_agent",
                "after_agent",
                "before_model",
                "after_model",
                "on_loop",
                "on_timeout",
                "on_route",
                "on_fallback",
            ]
        );
    }

    #[test]
    fn stateful_middleware_starts_empty() {
        let cost = CostMiddleware {
            tool_calls: std::sync::atomic::AtomicU64::new(0),
        };
        assert_eq!(cost.tool_call_count(), 0);

        let metrics = MetricsMiddleware {
            request_count: std::sync::atomic::AtomicU64::new(0),
            error_count: std::sync::atomic::AtomicU64::new(0),
        };
        assert_eq!(metrics.request_count(), 0);
        assert_eq!(metrics.error_count(), 0);

        let audit = AuditMiddleware {
            log: parking_lot::Mutex::new(Vec::new()),
        };
        assert!(audit.entries().is_empty());

        let cache = CacheMiddleware {
            cache: parking_lot::Mutex::new(std::collections::HashMap::new()),
        };
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn compose_all_middleware() {
        let m = M::log()
            | M::latency()
            | M::timeout(Duration::from_secs(30))
            | M::retry(3)
            | M::cost()
            | M::rate_limit(10)
            | M::circuit_breaker(5)
            | M::trace()
            | M::audit()
            | M::validate(|_| Ok(()))
            | M::fallback_model("gemini-1.5-flash")
            | M::cache()
            | M::dedup()
            | M::sample(0.5)
            | M::metrics()
            | M::before_agent(|_| Ok(()))
            | M::after_agent(|_| Ok(()))
            | M::before_model(|_| Ok(()))
            | M::after_model(|_, _| Ok(()))
            | M::on_loop(|_| {})
            | M::on_timeout(|| {})
            | M::on_route(|_| {})
            | M::on_fallback(|_| {});
        assert_eq!(m.len(), 23);
    }
}