gemini_adk_rs/text/
fallback.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::context::AgentEvent;
7use crate::error::AgentError;
8use crate::middleware::MiddlewareChain;
9use crate::state::State;
10
11/// Tries each child agent in sequence. Returns the first successful result.
12/// If all fail, returns the last error.
13pub struct FallbackTextAgent {
14    name: String,
15    candidates: Vec<Arc<dyn TextAgent>>,
16    middleware: MiddlewareChain,
17}
18
19impl FallbackTextAgent {
20    /// Create a new fallback agent that tries candidates in order.
21    pub fn new(name: impl Into<String>, candidates: Vec<Arc<dyn TextAgent>>) -> Self {
22        Self {
23            name: name.into(),
24            candidates,
25            middleware: MiddlewareChain::new(),
26        }
27    }
28
29    /// Attach a middleware chain. `AgentEvent::FallbackActivated` is emitted
30    /// through it when a fallback branch (any candidate after the first) is
31    /// tried, so `on_event` observers (`M::on_fallback`) fire.
32    pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
33        self.middleware = chain;
34        self
35    }
36}
37
38#[async_trait]
39impl TextAgent for FallbackTextAgent {
40    fn name(&self) -> &str {
41        &self.name
42    }
43
44    async fn run(&self, state: &State) -> Result<String, AgentError> {
45        let mut last_err = AgentError::Other("No candidates in fallback".into());
46
47        for (i, candidate) in self.candidates.iter().enumerate() {
48            // The first candidate is the primary; subsequent ones are fallbacks.
49            if i > 0 {
50                let _ = self
51                    .middleware
52                    .run_on_event(&AgentEvent::FallbackActivated {
53                        agent_name: candidate.name().to_string(),
54                    })
55                    .await;
56            }
57            match candidate.run(state).await {
58                Ok(result) => return Ok(result),
59                Err(e) => last_err = e,
60            }
61        }
62
63        Err(last_err)
64    }
65}