gemini_adk_rs/middleware/retry.rs
1//! Advisory retry middleware for agent error tracking.
2
3use async_trait::async_trait;
4
5use super::Middleware;
6use crate::error::AgentError;
7
/// Advisory middleware that tracks errors and provides retry guidance.
///
/// The `Middleware` trait hooks are lifecycle callbacks, not control-flow points.
/// `RetryMiddleware` counts errors via [`Middleware::on_error`] and exposes a
/// [`RetryMiddleware::should_retry`] method the caller can query to decide
/// whether to re-invoke the agent.
///
/// # Examples
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use gemini_adk_rs::middleware::RetryMiddleware;
///
/// let retry = Arc::new(RetryMiddleware::new(3));
/// // ... run agent, on_error is called by the middleware chain ...
/// while retry.should_retry() {
///     // Clears the error count, so the loop only continues if the
///     // re-run reports a new error.
///     retry.record_attempt();
///     // re-run the agent
/// }
/// ```
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Upper bound on the number of attempts `should_retry` allows.
    max_retries: u32,
    /// Errors reported via `on_error` since the last `record_attempt` or `reset`.
    pub(crate) error_count: std::sync::atomic::AtomicU32,
    /// Retry attempts recorded via `record_attempt` since the last `reset`.
    pub(crate) attempt: std::sync::atomic::AtomicU32,
}

impl RetryMiddleware {
    /// Create a new retry middleware that allows at most `max_retries` retries.
    ///
    /// With `max_retries == 0`, [`should_retry`](Self::should_retry) always
    /// returns `false`.
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            error_count: std::sync::atomic::AtomicU32::new(0),
            attempt: std::sync::atomic::AtomicU32::new(0),
        }
    }

    /// Returns `true` if at least one error has been recorded since the last
    /// [`record_attempt`](Self::record_attempt) or [`reset`](Self::reset),
    /// and fewer than `max_retries` attempts have been recorded.
    #[must_use]
    pub fn should_retry(&self) -> bool {
        let attempts = self.attempt.load(std::sync::atomic::Ordering::SeqCst);
        let errors = self.error_count.load(std::sync::atomic::Ordering::SeqCst);
        errors > 0 && attempts < self.max_retries
    }

    /// Record that a retry attempt is being made.
    ///
    /// Call this before re-invoking the agent. It increments the attempt
    /// counter and clears the error count, so a further retry requires a new
    /// error. The attempt counter saturates at `u32::MAX` instead of wrapping
    /// back to zero, because wrapping would silently re-open an exhausted
    /// retry budget.
    pub fn record_attempt(&self) {
        // The closure always returns `Some`, so `fetch_update` cannot return
        // `Err`; the previous value is not needed.
        let _ = self.attempt.fetch_update(
            std::sync::atomic::Ordering::SeqCst,
            std::sync::atomic::Ordering::SeqCst,
            |n| Some(n.saturating_add(1)),
        );
        // Reset the error flag so we wait for a new error before retrying again.
        self.error_count
            .store(0, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns how many retry attempts have been recorded.
    ///
    /// Starts at `0` after construction or [`reset`](Self::reset) and is
    /// incremented by each [`record_attempt`](Self::record_attempt).
    #[must_use]
    pub fn attempts(&self) -> u32 {
        self.attempt.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Returns the configured maximum number of retries.
    #[must_use]
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Reset both the error count and the attempt count to zero, allowing
    /// the middleware to be reused. `max_retries` is unchanged.
    pub fn reset(&self) {
        self.error_count
            .store(0, std::sync::atomic::Ordering::SeqCst);
        self.attempt.store(0, std::sync::atomic::Ordering::SeqCst);
    }
}
79
80#[async_trait]
81impl Middleware for RetryMiddleware {
82 fn name(&self) -> &str {
83 "retry"
84 }
85
86 async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
87 self.error_count
88 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
89 Ok(())
90 }
91}