gemini_adk_rs/middleware/
mod.rs1pub mod latency;
4pub mod log;
5pub mod retry;
6
7pub use latency::*;
8pub use log::*;
9pub use retry::*;
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use gemini_genai_rs::prelude::FunctionCall;
16
17use crate::context::AgentEvent;
18use crate::context::InvocationContext;
19use crate::error::{AgentError, ToolError};
20use crate::llm::{LlmRequest, LlmResponse};
21
22#[async_trait]
45pub trait Middleware: Send + Sync + 'static {
46 fn name(&self) -> &str;
48
49 async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
51 Ok(())
52 }
53 async fn after_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
55 Ok(())
56 }
57
58 async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
60 Ok(())
61 }
62 async fn after_tool(
64 &self,
65 _call: &FunctionCall,
66 _result: &serde_json::Value,
67 ) -> Result<(), AgentError> {
68 Ok(())
69 }
70 async fn on_tool_error(
72 &self,
73 _call: &FunctionCall,
74 _err: &ToolError,
75 ) -> Result<(), AgentError> {
76 Ok(())
77 }
78
79 async fn on_event(&self, _event: &AgentEvent) -> Result<(), AgentError> {
81 Ok(())
82 }
83
84 async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
86 Ok(())
87 }
88
89 async fn before_model(&self, _request: &LlmRequest) -> Result<Option<LlmResponse>, AgentError> {
92 Ok(None)
93 }
94
95 async fn after_model(
98 &self,
99 _request: &LlmRequest,
100 _response: &LlmResponse,
101 ) -> Result<Option<LlmResponse>, AgentError> {
102 Ok(None)
103 }
104
105 async fn transform_request(&self, _request: &mut LlmRequest) -> Result<(), AgentError> {
110 Ok(())
111 }
112
113 fn timeout(&self) -> Option<std::time::Duration> {
117 None
118 }
119}
120
121#[derive(Clone, Default)]
123pub struct MiddlewareChain {
124 layers: Vec<Arc<dyn Middleware>>,
125}
126
127impl MiddlewareChain {
128 pub fn new() -> Self {
130 Self::default()
131 }
132
133 pub fn add(&mut self, middleware: Arc<dyn Middleware>) {
135 self.layers.push(middleware);
136 }
137
138 pub fn prepend(&mut self, middleware: Arc<dyn Middleware>) {
140 self.layers.insert(0, middleware);
141 }
142
143 pub async fn run_before_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
145 for m in &self.layers {
146 m.before_agent(ctx).await?;
147 }
148 Ok(())
149 }
150
151 pub async fn run_after_agent(&self, ctx: &InvocationContext) -> Result<(), AgentError> {
153 for m in self.layers.iter().rev() {
154 m.after_agent(ctx).await?;
155 }
156 Ok(())
157 }
158
159 pub async fn run_before_tool(&self, call: &FunctionCall) -> Result<(), AgentError> {
161 for m in &self.layers {
162 m.before_tool(call).await?;
163 }
164 Ok(())
165 }
166
167 pub async fn run_after_tool(
169 &self,
170 call: &FunctionCall,
171 result: &serde_json::Value,
172 ) -> Result<(), AgentError> {
173 for m in self.layers.iter().rev() {
174 m.after_tool(call, result).await?;
175 }
176 Ok(())
177 }
178
179 pub async fn run_on_tool_error(
181 &self,
182 call: &FunctionCall,
183 err: &ToolError,
184 ) -> Result<(), AgentError> {
185 for m in &self.layers {
186 m.on_tool_error(call, err).await?;
187 }
188 Ok(())
189 }
190
191 pub async fn run_on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
193 for m in &self.layers {
194 m.on_event(event).await?;
195 }
196 Ok(())
197 }
198
199 pub async fn run_on_error(&self, err: &AgentError) -> Result<(), AgentError> {
201 for m in &self.layers {
202 m.on_error(err).await?;
203 }
204 Ok(())
205 }
206
207 pub async fn run_transform_request(&self, request: &mut LlmRequest) -> Result<(), AgentError> {
209 for m in &self.layers {
210 m.transform_request(request).await?;
211 }
212 Ok(())
213 }
214
215 pub async fn run_before_model(
217 &self,
218 request: &LlmRequest,
219 ) -> Result<Option<LlmResponse>, AgentError> {
220 for m in &self.layers {
221 if let Some(response) = m.before_model(request).await? {
222 return Ok(Some(response));
223 }
224 }
225 Ok(None)
226 }
227
228 pub async fn run_after_model(
230 &self,
231 request: &LlmRequest,
232 response: &LlmResponse,
233 ) -> Result<Option<LlmResponse>, AgentError> {
234 for m in self.layers.iter().rev() {
235 if let Some(replacement) = m.after_model(request, response).await? {
236 return Ok(Some(replacement));
237 }
238 }
239 Ok(None)
240 }
241
242 pub fn timeout(&self) -> Option<std::time::Duration> {
244 self.layers.iter().filter_map(|m| m.timeout()).min()
245 }
246
247 pub fn is_empty(&self) -> bool {
249 self.layers.is_empty()
250 }
251
252 pub fn len(&self) -> usize {
254 self.layers.len()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a `FunctionCall` with the given name and a fixed JSON argument object.
    fn test_call(name: &str) -> FunctionCall {
        FunctionCall {
            name: name.to_string(),
            args: serde_json::json!({"key": "value"}),
            id: None,
        }
    }

    /// Middleware that bumps a shared counter each time `before_agent` is invoked.
    struct CountingMiddleware {
        call_count: Arc<std::sync::atomic::AtomicU32>,
    }

    #[async_trait]
    impl Middleware for CountingMiddleware {
        fn name(&self) -> &str {
            "counter"
        }

        async fn before_agent(&self, _ctx: &InvocationContext) -> Result<(), AgentError> {
            self.call_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        }
    }

    /// Middleware that appends `"<hook>:<label>"` to a shared log on every tool
    /// hook and reports a configurable timeout. Used to observe chain ordering.
    struct RecordingMiddleware {
        label: &'static str,
        log: Arc<std::sync::Mutex<Vec<String>>>,
        timeout: Option<Duration>,
    }

    impl RecordingMiddleware {
        /// Appends one entry to the shared log; the guard is dropped before returning.
        fn record(&self, hook: &str) {
            self.log
                .lock()
                .expect("recording log mutex poisoned")
                .push(format!("{}:{}", hook, self.label));
        }
    }

    #[async_trait]
    impl Middleware for RecordingMiddleware {
        fn name(&self) -> &str {
            self.label
        }

        async fn before_tool(&self, _call: &FunctionCall) -> Result<(), AgentError> {
            self.record("before_tool");
            Ok(())
        }

        async fn after_tool(
            &self,
            _call: &FunctionCall,
            _result: &serde_json::Value,
        ) -> Result<(), AgentError> {
            self.record("after_tool");
            Ok(())
        }

        fn timeout(&self) -> Option<Duration> {
            self.timeout
        }
    }

    /// Creates a recording layer that writes to `log` and requests `timeout`.
    fn recorder(
        label: &'static str,
        log: &Arc<std::sync::Mutex<Vec<String>>>,
        timeout: Option<Duration>,
    ) -> Arc<RecordingMiddleware> {
        Arc::new(RecordingMiddleware {
            label,
            log: Arc::clone(log),
            timeout,
        })
    }

    #[test]
    fn new_chain_is_empty() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    /// "Before" hooks must run first-to-last and "after" hooks last-to-first,
    /// with `prepend` placing a layer ahead of everything added earlier.
    #[tokio::test]
    async fn middleware_chain_ordering() {
        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut chain = MiddlewareChain::new();
        chain.add(recorder("b", &log, None));
        chain.add(recorder("c", &log, None));
        chain.prepend(recorder("a", &log, None));
        assert_eq!(chain.len(), 3);

        let call = test_call("tool");
        let result = serde_json::json!(null);
        chain.run_before_tool(&call).await.unwrap();
        chain.run_after_tool(&call, &result).await.unwrap();

        let seen = log.lock().unwrap().clone();
        assert_eq!(
            seen,
            [
                "before_tool:a",
                "before_tool:b",
                "before_tool:c",
                "after_tool:c",
                "after_tool:b",
                "after_tool:a",
            ]
        );
    }

    /// The chain timeout is the smallest timeout any layer requests, ignoring
    /// layers that request none.
    #[test]
    fn chain_timeout_is_minimum_of_layers() {
        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        let mut chain = MiddlewareChain::new();
        assert_eq!(chain.timeout(), None);

        chain.add(recorder("none", &log, None));
        assert_eq!(chain.timeout(), None);

        chain.add(recorder("slow", &log, Some(Duration::from_secs(5))));
        chain.add(recorder("fast", &log, Some(Duration::from_secs(2))));
        assert_eq!(chain.timeout(), Some(Duration::from_secs(2)));
    }

    #[test]
    fn middleware_is_object_safe() {
        // Compiles only if `Middleware` can be used as a trait object.
        fn _assert(_: &dyn Middleware) {}
    }

    #[test]
    fn add_middleware_to_chain() {
        let mut chain = MiddlewareChain::new();
        let counter = Arc::new(CountingMiddleware {
            call_count: Arc::new(std::sync::atomic::AtomicU32::new(0)),
        });
        chain.add(counter);
        assert_eq!(chain.len(), 1);
        assert!(!chain.is_empty());
    }

    #[test]
    fn chain_is_clone() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        let chain2 = chain.clone();
        assert_eq!(chain2.len(), 1);
    }

    #[test]
    fn log_middleware_defaults() {
        let log = LogMiddleware::new();
        assert_eq!(log.name(), "log");
    }

    #[test]
    fn latency_middleware_defaults() {
        let lat = LatencyMiddleware::new();
        assert_eq!(lat.name(), "latency");
    }

    #[tokio::test]
    async fn logging_middleware_doesnt_panic() {
        let log = LogMiddleware::new();
        let call = test_call("my_tool");
        let result = serde_json::json!({"ok": true});
        let tool_err = ToolError::ExecutionFailed("boom".to_string());
        let agent_err = AgentError::Other("oops".to_string());

        assert!(log.before_tool(&call).await.is_ok());
        assert!(log.after_tool(&call, &result).await.is_ok());
        assert!(log.on_tool_error(&call, &tool_err).await.is_ok());
        assert!(log.on_error(&agent_err).await.is_ok());
    }

    #[tokio::test]
    async fn latency_middleware_records_timing() {
        let lat = LatencyMiddleware::new();
        let call = test_call("slow_tool");
        let result = serde_json::json!("done");

        lat.before_tool(&call).await.unwrap();
        // Short pause so the measured interval is comfortably above the 1 ms floor.
        tokio::time::sleep(Duration::from_millis(5)).await;
        lat.after_tool(&call, &result).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "slow_tool");
        assert!(records[0].success);
        assert!(records[0].elapsed >= Duration::from_millis(1));
    }

    #[tokio::test]
    async fn latency_middleware_records_failure() {
        let lat = LatencyMiddleware::new();
        let call = test_call("failing_tool");
        let err = ToolError::ExecutionFailed("kaboom".to_string());

        lat.before_tool(&call).await.unwrap();
        lat.on_tool_error(&call, &err).await.unwrap();

        let records = lat.tool_latencies();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].name, "failing_tool");
        assert!(!records[0].success);
    }

    #[tokio::test]
    async fn latency_middleware_clear() {
        let lat = LatencyMiddleware::new();
        let call = test_call("tool_a");
        let result = serde_json::json!(null);

        lat.before_tool(&call).await.unwrap();
        lat.after_tool(&call, &result).await.unwrap();
        assert_eq!(lat.tool_latencies().len(), 1);

        lat.clear();
        assert!(lat.tool_latencies().is_empty());
    }

    #[tokio::test]
    async fn retry_middleware_tracks_retries() {
        let retry = RetryMiddleware::new(3);
        assert_eq!(retry.max_retries(), 3);
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry(), "no error yet, should not retry");

        let err = AgentError::Other("transient".to_string());
        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry(), "error recorded, should retry");

        retry.record_attempt();
        assert_eq!(retry.attempts(), 1);
        assert!(!retry.should_retry(), "error was cleared by record_attempt");

        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 2);

        retry.on_error(&err).await.unwrap();
        assert!(retry.should_retry());
        retry.record_attempt();
        assert_eq!(retry.attempts(), 3);

        retry.on_error(&err).await.unwrap();
        assert!(!retry.should_retry(), "at max retries, should not retry");
    }

    #[test]
    fn retry_middleware_reset() {
        let retry = RetryMiddleware::new(2);
        retry
            .error_count
            .store(1, std::sync::atomic::Ordering::SeqCst);
        retry.attempt.store(1, std::sync::atomic::Ordering::SeqCst);
        retry.reset();
        assert_eq!(retry.attempts(), 0);
        assert!(!retry.should_retry());
    }

    #[test]
    fn chain_with_all_builtin_middleware() {
        let mut chain = MiddlewareChain::new();
        chain.add(Arc::new(LogMiddleware::new()));
        chain.add(Arc::new(LatencyMiddleware::new()));
        chain.add(Arc::new(RetryMiddleware::new(3)));
        assert_eq!(chain.len(), 3);
    }
}