1use std::sync::Arc;
28use std::time::Duration;
29
30use async_trait::async_trait;
31use gemini_genai_rs::prelude::FunctionCall;
32
33use gemini_adk_rs::context::AgentEvent;
34use gemini_adk_rs::error::{AgentError, ToolError};
35use gemini_adk_rs::middleware::{LatencyMiddleware, LogMiddleware, Middleware};
36
37#[derive(Clone)]
39pub struct MiddlewareComposite {
40 pub layers: Vec<Arc<dyn Middleware>>,
42}
43
44impl MiddlewareComposite {
45 pub fn new(layer: Arc<dyn Middleware>) -> Self {
47 Self {
48 layers: vec![layer],
49 }
50 }
51
52 pub fn len(&self) -> usize {
54 self.layers.len()
55 }
56
57 pub fn is_empty(&self) -> bool {
59 self.layers.is_empty()
60 }
61}
62
63impl std::ops::BitOr for MiddlewareComposite {
65 type Output = MiddlewareComposite;
66
67 fn bitor(mut self, rhs: MiddlewareComposite) -> Self::Output {
68 self.layers.extend(rhs.layers);
69 self
70 }
71}
72
73pub struct M;
75
76impl M {
77 pub fn log() -> MiddlewareComposite {
79 MiddlewareComposite::new(Arc::new(LogMiddleware::new()))
80 }
81
82 pub fn latency() -> MiddlewareComposite {
84 MiddlewareComposite::new(Arc::new(LatencyMiddleware::new()))
85 }
86
87 pub fn timeout(duration: Duration) -> MiddlewareComposite {
91 MiddlewareComposite::new(Arc::new(TimeoutMiddleware {
92 name: "timeout".to_string(),
93 duration,
94 }))
95 }
96
97 pub fn retry(max_retries: u32) -> MiddlewareComposite {
99 MiddlewareComposite::new(Arc::new(gemini_adk_rs::middleware::RetryMiddleware::new(
100 max_retries,
101 )))
102 }
103
104 pub fn tap(f: impl Fn(&AgentEvent) + Send + Sync + 'static) -> MiddlewareComposite {
106 MiddlewareComposite::new(Arc::new(TapMiddleware {
107 handler: Arc::new(f),
108 }))
109 }
110
111 pub fn before_tool(
113 f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
114 ) -> MiddlewareComposite {
115 MiddlewareComposite::new(Arc::new(BeforeToolMiddleware {
116 handler: Arc::new(f),
117 }))
118 }
119
120 pub fn after_tool(
122 f: impl Fn(&FunctionCall, &serde_json::Value) -> Result<(), String> + Send + Sync + 'static,
123 ) -> MiddlewareComposite {
124 MiddlewareComposite::new(Arc::new(AfterToolMiddleware {
125 handler: Arc::new(f),
126 }))
127 }
128
129 pub fn on_error(
131 f: impl Fn(&AgentError) -> Result<(), String> + Send + Sync + 'static,
132 ) -> MiddlewareComposite {
133 MiddlewareComposite::new(Arc::new(OnErrorMiddleware {
134 handler: Arc::new(f),
135 }))
136 }
137
138 pub fn cost() -> MiddlewareComposite {
140 MiddlewareComposite::new(Arc::new(CostMiddleware {
141 tool_calls: std::sync::atomic::AtomicU64::new(0),
142 }))
143 }
144
145 pub fn rate_limit(rps: u32) -> MiddlewareComposite {
149 MiddlewareComposite::new(Arc::new(RateLimitMiddleware::new(rps)))
150 }
151
152 pub fn circuit_breaker(threshold: u32) -> MiddlewareComposite {
154 MiddlewareComposite::new(Arc::new(CircuitBreakerMiddleware {
155 threshold,
156 consecutive_failures: std::sync::atomic::AtomicU32::new(0),
157 }))
158 }
159
160 pub fn trace() -> MiddlewareComposite {
162 MiddlewareComposite::new(Arc::new(TraceMiddleware))
163 }
164
165 pub fn audit() -> MiddlewareComposite {
167 MiddlewareComposite::new(Arc::new(AuditMiddleware {
168 log: parking_lot::Mutex::new(Vec::new()),
169 }))
170 }
171
172 #[doc(hidden)]
178 pub fn scope(_names: &[&str], inner: MiddlewareComposite) -> MiddlewareComposite {
179 inner
180 }
181
182 pub fn structured_log() -> MiddlewareComposite {
184 MiddlewareComposite::new(Arc::new(StructuredLogMiddleware))
185 }
186
187 pub fn dispatch_log() -> MiddlewareComposite {
189 MiddlewareComposite::new(Arc::new(DispatchLogMiddleware))
190 }
191
192 pub fn topology_log() -> MiddlewareComposite {
194 MiddlewareComposite::new(Arc::new(TopologyLogMiddleware))
195 }
196
197 pub fn validate(
199 f: impl Fn(&FunctionCall) -> Result<(), String> + Send + Sync + 'static,
200 ) -> MiddlewareComposite {
201 MiddlewareComposite::new(Arc::new(ValidateMiddleware {
202 validator: Arc::new(f),
203 }))
204 }
205
206 #[doc(hidden)]
212 pub fn fallback_model(model: &str) -> MiddlewareComposite {
213 MiddlewareComposite::new(Arc::new(FallbackModelMiddleware {
214 model: model.to_string(),
215 }))
216 }
217
218 pub fn cache() -> MiddlewareComposite {
220 MiddlewareComposite::new(Arc::new(CacheMiddleware {
221 cache: parking_lot::Mutex::new(std::collections::HashMap::new()),
222 }))
223 }
224
225 pub fn dedup() -> MiddlewareComposite {
227 MiddlewareComposite::new(Arc::new(DedupMiddleware {
228 last_request_hash: parking_lot::Mutex::new(None),
229 }))
230 }
231
232 pub fn sample(rate: f64) -> MiddlewareComposite {
234 MiddlewareComposite::new(Arc::new(SampleMiddleware {
235 rate: rate.clamp(0.0, 1.0),
236 }))
237 }
238
239 pub fn metrics() -> MiddlewareComposite {
241 MiddlewareComposite::new(Arc::new(MetricsMiddleware {
242 request_count: std::sync::atomic::AtomicU64::new(0),
243 error_count: std::sync::atomic::AtomicU64::new(0),
244 }))
245 }
246
247 pub fn before_agent(
249 f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
250 + Send
251 + Sync
252 + 'static,
253 ) -> MiddlewareComposite {
254 MiddlewareComposite::new(Arc::new(BeforeAgentMiddleware {
255 handler: Arc::new(f),
256 }))
257 }
258
259 pub fn after_agent(
261 f: impl Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String>
262 + Send
263 + Sync
264 + 'static,
265 ) -> MiddlewareComposite {
266 MiddlewareComposite::new(Arc::new(AfterAgentMiddleware {
267 handler: Arc::new(f),
268 }))
269 }
270
271 pub fn before_model(
273 f: impl Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync + 'static,
274 ) -> MiddlewareComposite {
275 MiddlewareComposite::new(Arc::new(BeforeModelMiddleware {
276 handler: Arc::new(f),
277 }))
278 }
279
280 pub fn after_model(
282 f: impl Fn(
283 &gemini_adk_rs::llm::LlmRequest,
284 &gemini_adk_rs::llm::LlmResponse,
285 ) -> Result<(), String>
286 + Send
287 + Sync
288 + 'static,
289 ) -> MiddlewareComposite {
290 MiddlewareComposite::new(Arc::new(AfterModelMiddleware {
291 handler: Arc::new(f),
292 }))
293 }
294
295 pub fn on_loop(f: impl Fn(u32) + Send + Sync + 'static) -> MiddlewareComposite {
297 MiddlewareComposite::new(Arc::new(OnLoopMiddleware {
298 handler: Arc::new(f),
299 }))
300 }
301
302 pub fn on_timeout(f: impl Fn() + Send + Sync + 'static) -> MiddlewareComposite {
304 MiddlewareComposite::new(Arc::new(OnTimeoutMiddleware {
305 handler: Arc::new(f),
306 }))
307 }
308
309 pub fn on_route(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
311 MiddlewareComposite::new(Arc::new(OnRouteMiddleware {
312 handler: Arc::new(f),
313 }))
314 }
315
316 pub fn on_fallback(f: impl Fn(&str) + Send + Sync + 'static) -> MiddlewareComposite {
318 MiddlewareComposite::new(Arc::new(OnFallbackMiddleware {
319 handler: Arc::new(f),
320 }))
321 }
322}
323
324#[allow(dead_code)]
326struct TimeoutMiddleware {
327 name: String,
328 duration: Duration,
329}
330
331#[async_trait::async_trait]
332impl Middleware for TimeoutMiddleware {
333 fn name(&self) -> &str {
334 &self.name
335 }
336
337 fn timeout(&self) -> Option<Duration> {
338 Some(self.duration)
339 }
340}
341
342struct TapMiddleware {
345 #[allow(clippy::type_complexity)]
346 handler: Arc<dyn Fn(&AgentEvent) + Send + Sync>,
347}
348
349#[async_trait]
350impl Middleware for TapMiddleware {
351 fn name(&self) -> &str {
352 "tap"
353 }
354
355 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
356 (self.handler)(event);
357 Ok(())
358 }
359}
360
361struct BeforeToolMiddleware {
364 #[allow(clippy::type_complexity)]
365 handler: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
366}
367
368#[async_trait]
369impl Middleware for BeforeToolMiddleware {
370 fn name(&self) -> &str {
371 "before_tool"
372 }
373
374 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
375 (self.handler)(call).map_err(AgentError::Other)
376 }
377}
378
379struct AfterToolMiddleware {
382 #[allow(clippy::type_complexity)]
383 handler: Arc<dyn Fn(&FunctionCall, &serde_json::Value) -> Result<(), String> + Send + Sync>,
384}
385
386#[async_trait]
387impl Middleware for AfterToolMiddleware {
388 fn name(&self) -> &str {
389 "after_tool"
390 }
391
392 async fn after_tool(
393 &self,
394 call: &FunctionCall,
395 result: &serde_json::Value,
396 ) -> Result<(), AgentError> {
397 (self.handler)(call, result).map_err(AgentError::Other)
398 }
399}
400
401struct OnErrorMiddleware {
404 #[allow(clippy::type_complexity)]
405 handler: Arc<dyn Fn(&AgentError) -> Result<(), String> + Send + Sync>,
406}
407
408#[async_trait]
409impl Middleware for OnErrorMiddleware {
410 fn name(&self) -> &str {
411 "on_error"
412 }
413
414 async fn on_error(&self, err: &AgentError) -> Result<(), AgentError> {
415 (self.handler)(err).map_err(AgentError::Other)
416 }
417}
418
419pub struct CostMiddleware {
423 tool_calls: std::sync::atomic::AtomicU64,
424}
425
426impl CostMiddleware {
427 pub fn tool_call_count(&self) -> u64 {
429 self.tool_calls.load(std::sync::atomic::Ordering::SeqCst)
430 }
431}
432
433#[async_trait]
434impl Middleware for CostMiddleware {
435 fn name(&self) -> &str {
436 "cost"
437 }
438
439 async fn after_tool(
440 &self,
441 _call: &FunctionCall,
442 _result: &serde_json::Value,
443 ) -> Result<(), AgentError> {
444 self.tool_calls
445 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
446 Ok(())
447 }
448}
449
450#[allow(dead_code)]
453struct RateLimitMiddleware {
454 min_interval: Duration,
456 last: parking_lot::Mutex<Option<std::time::Instant>>,
458}
459
460impl RateLimitMiddleware {
461 fn new(rps: u32) -> Self {
462 let rps = rps.max(1);
463 Self {
464 min_interval: Duration::from_secs_f64(1.0 / rps as f64),
465 last: parking_lot::Mutex::new(None),
466 }
467 }
468}
469
470#[async_trait]
471impl Middleware for RateLimitMiddleware {
472 fn name(&self) -> &str {
473 "rate_limit"
474 }
475
476 async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
477 let wait = {
481 let mut last = self.last.lock();
482 let now = std::time::Instant::now();
483 let scheduled = match *last {
484 Some(prev) if prev + self.min_interval > now => prev + self.min_interval,
485 _ => now,
486 };
487 *last = Some(scheduled);
488 scheduled.saturating_duration_since(now)
489 };
490 if !wait.is_zero() {
491 tokio::time::sleep(wait).await;
492 }
493 Ok(())
494 }
495}
496
497struct CircuitBreakerMiddleware {
500 threshold: u32,
501 consecutive_failures: std::sync::atomic::AtomicU32,
502}
503
504#[async_trait]
505impl Middleware for CircuitBreakerMiddleware {
506 fn name(&self) -> &str {
507 "circuit_breaker"
508 }
509
510 async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
511 let failures = self
512 .consecutive_failures
513 .load(std::sync::atomic::Ordering::SeqCst);
514 if failures >= self.threshold {
515 return Err(AgentError::Other(format!(
516 "Circuit breaker open: {} consecutive failures (threshold: {})",
517 failures, self.threshold
518 )));
519 }
520 Ok(())
521 }
522
523 async fn after_tool(
524 &self,
525 _call: &FunctionCall,
526 _result: &serde_json::Value,
527 ) -> Result<(), AgentError> {
528 self.consecutive_failures
529 .store(0, std::sync::atomic::Ordering::SeqCst);
530 Ok(())
531 }
532
533 async fn on_tool_error(
534 &self,
535 _call: &FunctionCall,
536 _err: &ToolError,
537 ) -> Result<(), AgentError> {
538 self.consecutive_failures
539 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
540 Ok(())
541 }
542}
543
544struct TraceMiddleware;
550
551#[async_trait]
552impl Middleware for TraceMiddleware {
553 fn name(&self) -> &str {
554 "trace"
555 }
556
557 async fn before_agent(
558 &self,
559 ctx: &gemini_adk_rs::context::InvocationContext,
560 ) -> Result<(), AgentError> {
561 let sid = ctx.session_id.as_deref().unwrap_or("unknown");
562 gemini_adk_rs::telemetry::logging::log_agent_started(sid, 0);
563 Ok(())
564 }
565
566 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
567 gemini_adk_rs::telemetry::logging::log_tool_dispatch("fluent", &call.name, "function");
568 Ok(())
569 }
570
571 async fn after_tool(
572 &self,
573 call: &FunctionCall,
574 _result: &serde_json::Value,
575 ) -> Result<(), AgentError> {
576 gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, true, 0.0);
577 Ok(())
578 }
579
580 async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
581 gemini_adk_rs::telemetry::logging::log_tool_result("fluent", &call.name, false, 0.0);
582 Ok(())
583 }
584
585 async fn on_error(&self, err: &AgentError) -> Result<(), AgentError> {
586 gemini_adk_rs::telemetry::logging::log_agent_error("fluent", &err.to_string());
587 Ok(())
588 }
589}
590
591pub struct AuditMiddleware {
595 log: parking_lot::Mutex<Vec<AuditEntry>>,
596}
597
598#[derive(Debug, Clone)]
600pub struct AuditEntry {
601 pub tool_name: String,
603 pub args: serde_json::Value,
605 pub success: Option<bool>,
607}
608
609impl AuditMiddleware {
610 pub fn entries(&self) -> Vec<AuditEntry> {
612 self.log.lock().clone()
613 }
614}
615
616#[async_trait]
617impl Middleware for AuditMiddleware {
618 fn name(&self) -> &str {
619 "audit"
620 }
621
622 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
623 let mut log = self.log.lock();
624 if log.len() >= 10_000 {
625 log.drain(..1_000);
626 }
627 log.push(AuditEntry {
628 tool_name: call.name.clone(),
629 args: call.args.clone(),
630 success: None,
631 });
632 Ok(())
633 }
634
635 async fn after_tool(
636 &self,
637 call: &FunctionCall,
638 _result: &serde_json::Value,
639 ) -> Result<(), AgentError> {
640 let mut log = self.log.lock();
641 if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
642 entry.success = Some(true);
643 }
644 Ok(())
645 }
646
647 async fn on_tool_error(&self, call: &FunctionCall, _err: &ToolError) -> Result<(), AgentError> {
648 let mut log = self.log.lock();
649 if let Some(entry) = log.iter_mut().rev().find(|e| e.tool_name == call.name) {
650 entry.success = Some(false);
651 }
652 Ok(())
653 }
654}
655
656struct ValidateMiddleware {
659 #[allow(clippy::type_complexity)]
660 validator: Arc<dyn Fn(&FunctionCall) -> Result<(), String> + Send + Sync>,
661}
662
663#[async_trait]
664impl Middleware for ValidateMiddleware {
665 fn name(&self) -> &str {
666 "validate"
667 }
668
669 async fn before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
670 (self.validator)(call).map_err(|e| AgentError::Tool(ToolError::InvalidArgs(e)))
671 }
672}
673
674#[allow(dead_code)]
678struct FallbackModelMiddleware {
679 model: String,
680}
681
682#[async_trait]
683impl Middleware for FallbackModelMiddleware {
684 fn name(&self) -> &str {
685 "fallback_model"
686 }
687
688 async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
689 Ok(())
691 }
692}
693
694pub struct CacheMiddleware {
698 cache: parking_lot::Mutex<std::collections::HashMap<u64, gemini_adk_rs::llm::LlmResponse>>,
699}
700
701impl CacheMiddleware {
702 pub fn len(&self) -> usize {
704 self.cache.lock().len()
705 }
706
707 pub fn is_empty(&self) -> bool {
709 self.cache.lock().is_empty()
710 }
711
712 pub fn clear(&self) {
714 self.cache.lock().clear();
715 }
716}
717
718#[async_trait]
719impl Middleware for CacheMiddleware {
720 fn name(&self) -> &str {
721 "cache"
722 }
723
724 async fn before_model(
725 &self,
726 request: &gemini_adk_rs::llm::LlmRequest,
727 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
728 use std::hash::{Hash, Hasher};
729 let mut hasher = std::collections::hash_map::DefaultHasher::new();
730 format!("{:?}", request).hash(&mut hasher);
731 let key = hasher.finish();
732 let cache = self.cache.lock();
733 Ok(cache.get(&key).cloned())
734 }
735
736 async fn after_model(
737 &self,
738 request: &gemini_adk_rs::llm::LlmRequest,
739 response: &gemini_adk_rs::llm::LlmResponse,
740 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
741 use std::hash::{Hash, Hasher};
742 let mut hasher = std::collections::hash_map::DefaultHasher::new();
743 format!("{:?}", request).hash(&mut hasher);
744 let key = hasher.finish();
745 self.cache.lock().insert(key, response.clone());
746 Ok(None) }
748}
749
750#[allow(dead_code)]
754struct DedupMiddleware {
755 last_request_hash: parking_lot::Mutex<Option<u64>>,
756}
757
758#[async_trait]
759impl Middleware for DedupMiddleware {
760 fn name(&self) -> &str {
761 "dedup"
762 }
763
764 async fn before_model(
765 &self,
766 request: &gemini_adk_rs::llm::LlmRequest,
767 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
768 use std::hash::{Hash, Hasher};
769 let mut hasher = std::collections::hash_map::DefaultHasher::new();
770 format!("{:?}", request).hash(&mut hasher);
771 let hash = hasher.finish();
772 let mut last = self.last_request_hash.lock();
773 if *last == Some(hash) {
774 return Err(AgentError::Other(
776 "Duplicate consecutive request".to_string(),
777 ));
778 }
779 *last = Some(hash);
780 Ok(None)
781 }
782}
783
784#[allow(dead_code)]
788struct SampleMiddleware {
789 rate: f64,
790}
791
792#[async_trait]
793impl Middleware for SampleMiddleware {
794 fn name(&self) -> &str {
795 "sample"
796 }
797
798 async fn before_model(
799 &self,
800 _request: &gemini_adk_rs::llm::LlmRequest,
801 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
802 use std::hash::{Hash, Hasher};
803 let mut hasher = std::collections::hash_map::DefaultHasher::new();
805 std::time::Instant::now().hash(&mut hasher);
806 let hash = hasher.finish();
807 let normalized = (hash as f64) / (u64::MAX as f64);
808 if normalized > self.rate {
809 return Err(AgentError::Other("Sampled out".to_string()));
810 }
811 Ok(None)
812 }
813}
814
815pub struct MetricsMiddleware {
819 request_count: std::sync::atomic::AtomicU64,
820 error_count: std::sync::atomic::AtomicU64,
821}
822
823impl MetricsMiddleware {
824 pub fn request_count(&self) -> u64 {
826 self.request_count.load(std::sync::atomic::Ordering::SeqCst)
827 }
828
829 pub fn error_count(&self) -> u64 {
831 self.error_count.load(std::sync::atomic::Ordering::SeqCst)
832 }
833}
834
835#[async_trait]
836impl Middleware for MetricsMiddleware {
837 fn name(&self) -> &str {
838 "metrics"
839 }
840
841 async fn before_agent(
842 &self,
843 _ctx: &gemini_adk_rs::context::InvocationContext,
844 ) -> Result<(), AgentError> {
845 self.request_count
846 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
847 Ok(())
848 }
849
850 async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
851 self.error_count
852 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
853 Ok(())
854 }
855}
856
857struct BeforeAgentMiddleware {
860 #[allow(clippy::type_complexity)]
861 handler:
862 Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
863}
864
865#[async_trait]
866impl Middleware for BeforeAgentMiddleware {
867 fn name(&self) -> &str {
868 "before_agent"
869 }
870
871 async fn before_agent(
872 &self,
873 ctx: &gemini_adk_rs::context::InvocationContext,
874 ) -> Result<(), AgentError> {
875 (self.handler)(ctx).map_err(AgentError::Other)
876 }
877}
878
879struct AfterAgentMiddleware {
882 #[allow(clippy::type_complexity)]
883 handler:
884 Arc<dyn Fn(&gemini_adk_rs::context::InvocationContext) -> Result<(), String> + Send + Sync>,
885}
886
887#[async_trait]
888impl Middleware for AfterAgentMiddleware {
889 fn name(&self) -> &str {
890 "after_agent"
891 }
892
893 async fn after_agent(
894 &self,
895 ctx: &gemini_adk_rs::context::InvocationContext,
896 ) -> Result<(), AgentError> {
897 (self.handler)(ctx).map_err(AgentError::Other)
898 }
899}
900
901struct BeforeModelMiddleware {
904 #[allow(clippy::type_complexity)]
905 handler: Arc<dyn Fn(&gemini_adk_rs::llm::LlmRequest) -> Result<(), String> + Send + Sync>,
906}
907
908#[async_trait]
909impl Middleware for BeforeModelMiddleware {
910 fn name(&self) -> &str {
911 "before_model"
912 }
913
914 async fn before_model(
915 &self,
916 request: &gemini_adk_rs::llm::LlmRequest,
917 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
918 (self.handler)(request).map_err(AgentError::Other)?;
919 Ok(None)
920 }
921}
922
923struct AfterModelMiddleware {
926 #[allow(clippy::type_complexity)]
927 handler: Arc<
928 dyn Fn(
929 &gemini_adk_rs::llm::LlmRequest,
930 &gemini_adk_rs::llm::LlmResponse,
931 ) -> Result<(), String>
932 + Send
933 + Sync,
934 >,
935}
936
937#[async_trait]
938impl Middleware for AfterModelMiddleware {
939 fn name(&self) -> &str {
940 "after_model"
941 }
942
943 async fn after_model(
944 &self,
945 request: &gemini_adk_rs::llm::LlmRequest,
946 response: &gemini_adk_rs::llm::LlmResponse,
947 ) -> Result<Option<gemini_adk_rs::llm::LlmResponse>, AgentError> {
948 (self.handler)(request, response).map_err(AgentError::Other)?;
949 Ok(None)
950 }
951}
952
953struct OnLoopMiddleware {
956 handler: Arc<dyn Fn(u32) + Send + Sync>,
957}
958
959#[async_trait]
960impl Middleware for OnLoopMiddleware {
961 fn name(&self) -> &str {
962 "on_loop"
963 }
964
965 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
966 if let AgentEvent::LoopIteration { iteration } = event {
967 (self.handler)(*iteration);
968 }
969 Ok(())
970 }
971}
972
973struct OnTimeoutMiddleware {
976 handler: Arc<dyn Fn() + Send + Sync>,
977}
978
979#[async_trait]
980impl Middleware for OnTimeoutMiddleware {
981 fn name(&self) -> &str {
982 "on_timeout"
983 }
984
985 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
986 if let AgentEvent::Timeout = event {
987 (self.handler)();
988 }
989 Ok(())
990 }
991}
992
993struct OnRouteMiddleware {
996 handler: Arc<dyn Fn(&str) + Send + Sync>,
997}
998
999#[async_trait]
1000impl Middleware for OnRouteMiddleware {
1001 fn name(&self) -> &str {
1002 "on_route"
1003 }
1004
1005 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
1006 if let AgentEvent::RouteSelected { agent_name } = event {
1007 (self.handler)(agent_name);
1008 }
1009 Ok(())
1010 }
1011}
1012
1013struct OnFallbackMiddleware {
1016 handler: Arc<dyn Fn(&str) + Send + Sync>,
1017}
1018
1019#[async_trait]
1020impl Middleware for OnFallbackMiddleware {
1021 fn name(&self) -> &str {
1022 "on_fallback"
1023 }
1024
1025 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
1026 if let AgentEvent::FallbackActivated { agent_name } = event {
1027 (self.handler)(agent_name);
1028 }
1029 Ok(())
1030 }
1031}
1032
1033struct StructuredLogMiddleware;
1036
1037#[async_trait]
1038impl Middleware for StructuredLogMiddleware {
1039 fn name(&self) -> &str {
1040 "structured_log"
1041 }
1042
1043 async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
1044 let _ = event;
1046 Ok(())
1047 }
1048}
1049
1050struct DispatchLogMiddleware;
1053
1054#[async_trait]
1055impl Middleware for DispatchLogMiddleware {
1056 fn name(&self) -> &str {
1057 "dispatch_log"
1058 }
1059}
1060
1061struct TopologyLogMiddleware;
1064
1065#[async_trait]
1066impl Middleware for TopologyLogMiddleware {
1067 fn name(&self) -> &str {
1068 "topology_log"
1069 }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a factory produced a composite with exactly one layer.
    #[track_caller]
    fn assert_single(composite: MiddlewareComposite) {
        assert_eq!(composite.len(), 1);
    }

    #[test]
    fn log_creates_composite() {
        assert_single(M::log());
    }

    #[test]
    fn latency_creates_composite() {
        assert_single(M::latency());
    }

    #[test]
    fn timeout_creates_composite() {
        assert_single(M::timeout(Duration::from_secs(30)));
    }

    #[test]
    fn compose_with_bitor() {
        let stack = M::log() | M::latency() | M::timeout(Duration::from_secs(5));
        assert_eq!(stack.len(), 3);
    }

    #[test]
    fn retry_creates_composite() {
        assert_single(M::retry(3));
    }

    #[test]
    fn tap_creates_composite() {
        assert_single(M::tap(|_event| {}));
    }

    #[test]
    fn before_tool_creates_composite() {
        assert_single(M::before_tool(|_call| Ok(())));
    }

    #[test]
    fn cost_creates_composite() {
        assert_single(M::cost());
    }

    #[test]
    fn rate_limit_creates_composite() {
        assert_single(M::rate_limit(10));
    }

    #[test]
    fn circuit_breaker_creates_composite() {
        assert_single(M::circuit_breaker(5));
    }

    #[test]
    fn trace_creates_composite() {
        assert_single(M::trace());
    }

    #[test]
    fn audit_creates_composite() {
        assert_single(M::audit());
    }

    #[test]
    fn validate_creates_composite() {
        assert_single(M::validate(|_call| Ok(())));
    }

    #[test]
    fn fallback_model_creates_composite() {
        assert_single(M::fallback_model("gemini-1.5-flash"));
    }

    #[test]
    fn cache_creates_composite() {
        assert_single(M::cache());
    }

    #[test]
    fn dedup_creates_composite() {
        assert_single(M::dedup());
    }

    #[test]
    fn sample_creates_composite() {
        assert_single(M::sample(0.5));
    }

    // Only checks that out-of-range rates are accepted; the stored rate is
    // not observable through the composite.
    #[test]
    fn sample_clamps_rate() {
        assert_single(M::sample(2.0));
        assert_single(M::sample(-1.0));
    }

    #[test]
    fn metrics_creates_composite() {
        assert_single(M::metrics());
    }

    #[test]
    fn before_agent_creates_composite() {
        assert_single(M::before_agent(|_ctx| Ok(())));
    }

    #[test]
    fn after_agent_creates_composite() {
        assert_single(M::after_agent(|_ctx| Ok(())));
    }

    #[test]
    fn before_model_creates_composite() {
        assert_single(M::before_model(|_req| Ok(())));
    }

    #[test]
    fn after_model_creates_composite() {
        assert_single(M::after_model(|_req, _resp| Ok(())));
    }

    #[test]
    fn on_loop_creates_composite() {
        assert_single(M::on_loop(|_iteration| {}));
    }

    #[test]
    fn on_timeout_creates_composite() {
        assert_single(M::on_timeout(|| {}));
    }

    #[test]
    fn on_route_creates_composite() {
        assert_single(M::on_route(|_name| {}));
    }

    #[test]
    fn on_fallback_creates_composite() {
        assert_single(M::on_fallback(|_name| {}));
    }

    #[test]
    fn compose_all_middleware() {
        let stack = M::log()
            | M::latency()
            | M::timeout(Duration::from_secs(30))
            | M::retry(3)
            | M::cost()
            | M::rate_limit(10)
            | M::circuit_breaker(5)
            | M::trace()
            | M::audit()
            | M::validate(|_| Ok(()))
            | M::fallback_model("gemini-1.5-flash")
            | M::cache()
            | M::dedup()
            | M::sample(0.5)
            | M::metrics()
            | M::before_agent(|_| Ok(()))
            | M::after_agent(|_| Ok(()))
            | M::before_model(|_| Ok(()))
            | M::after_model(|_, _| Ok(()))
            | M::on_loop(|_| {})
            | M::on_timeout(|| {})
            | M::on_route(|_| {})
            | M::on_fallback(|_| {});
        let layer_count = stack.len();
        assert_eq!(layer_count, 23);
    }
}