gemini_adk_rs/middleware/
retry.rs

1//! Advisory retry middleware for agent error tracking.
2
3use async_trait::async_trait;
4
5use super::Middleware;
6use crate::error::AgentError;
7
/// Advisory middleware that tracks errors and provides retry guidance.
///
/// The `Middleware` trait hooks are lifecycle callbacks, not control-flow points.
/// `RetryMiddleware` counts errors via [`Middleware::on_error`] and exposes a
/// [`RetryMiddleware::should_retry`] method the caller can query to decide
/// whether to re-invoke the agent. The middleware never re-runs anything
/// itself; the retry loop is owned by the caller.
///
/// All state lives in atomics, so every method takes `&self` and the value can
/// be shared behind an `Arc`. Each counter update is individually atomic, but
/// methods that touch both counters ([`RetryMiddleware::record_attempt`],
/// [`RetryMiddleware::reset`]) perform two separate atomic operations rather
/// than a single transaction.
///
/// # Examples
///
/// ```rust,ignore
/// use std::sync::Arc;
/// use gemini_adk_rs::middleware::RetryMiddleware;
///
/// let retry = Arc::new(RetryMiddleware::new(3));
/// // ... run agent, on_error is called by the middleware chain ...
/// while retry.should_retry() {
///     retry.record_attempt();
///     // re-run the agent
/// }
/// ```
#[derive(Debug)]
pub struct RetryMiddleware {
    /// Upper bound on the number of retry attempts `should_retry` will allow.
    max_retries: u32,
    /// Errors observed since the last `record_attempt` / `reset`.
    pub(crate) error_count: std::sync::atomic::AtomicU32,
    /// Retry attempts recorded since construction or the last `reset`.
    pub(crate) attempt: std::sync::atomic::AtomicU32,
}

impl RetryMiddleware {
    /// Memory ordering used for every atomic access in this type.
    const ORDER: std::sync::atomic::Ordering = std::sync::atomic::Ordering::SeqCst;

    /// Create a new retry middleware with the given maximum retry count.
    ///
    /// Both counters start at zero, so [`RetryMiddleware::should_retry`]
    /// returns `false` until an error has been recorded.
    #[must_use]
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            error_count: std::sync::atomic::AtomicU32::new(0),
            attempt: std::sync::atomic::AtomicU32::new(0),
        }
    }

    /// Returns `true` if the number of attempts is below `max_retries`
    /// and at least one error has been recorded since the last reset.
    #[must_use]
    pub fn should_retry(&self) -> bool {
        let attempts = self.attempt.load(Self::ORDER);
        let errors = self.error_count.load(Self::ORDER);
        errors > 0 && attempts < self.max_retries
    }

    /// Record that a retry attempt is being made.
    /// Call this before re-invoking the agent.
    ///
    /// The attempt counter saturates at `u32::MAX` instead of wrapping, so an
    /// excessive number of calls can never roll the counter back to zero and
    /// silently restore the retry budget.
    pub fn record_attempt(&self) {
        // `checked_add` yields `None` at `u32::MAX`, which makes `fetch_update`
        // leave the stored value untouched; that `Err` is the intended
        // saturation case, so the result is deliberately discarded.
        let _ = self
            .attempt
            .fetch_update(Self::ORDER, Self::ORDER, |current| current.checked_add(1));
        // Reset the error flag so we wait for a new error before retrying again.
        self.error_count.store(0, Self::ORDER);
    }

    /// Returns the number of retry attempts recorded so far.
    ///
    /// This is `0` until the first call to [`RetryMiddleware::record_attempt`].
    #[must_use]
    pub fn attempts(&self) -> u32 {
        self.attempt.load(Self::ORDER)
    }

    /// Returns the configured maximum number of retries.
    #[must_use]
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Reset all counters, allowing the middleware to be reused.
    pub fn reset(&self) {
        self.error_count.store(0, Self::ORDER);
        self.attempt.store(0, Self::ORDER);
    }
}
79
80#[async_trait]
81impl Middleware for RetryMiddleware {
82    fn name(&self) -> &str {
83        "retry"
84    }
85
86    async fn on_error(&self, _err: &AgentError) -> Result<(), AgentError> {
87        self.error_count
88            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
89        Ok(())
90    }
91}