// gemini_adk_rs/text/loop_agent.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::context::AgentEvent;
7use crate::error::AgentError;
8use crate::middleware::MiddlewareChain;
9use crate::state::State;
10
11/// Loop-termination predicate over shared state.
12type UntilFn = Arc<dyn Fn(&State) -> bool + Send + Sync>;
13
14/// Runs a text agent repeatedly until max iterations or a state predicate.
15pub struct LoopTextAgent {
16    name: String,
17    body: Arc<dyn TextAgent>,
18    max: u32,
19    until: Option<UntilFn>,
20    middleware: MiddlewareChain,
21}
22
23impl LoopTextAgent {
24    /// Create a new loop agent that repeats up to `max` iterations.
25    pub fn new(name: impl Into<String>, body: Arc<dyn TextAgent>, max: u32) -> Self {
26        Self {
27            name: name.into(),
28            body,
29            max,
30            until: None,
31            middleware: MiddlewareChain::new(),
32        }
33    }
34
35    /// Add a predicate — loop breaks when predicate returns true.
36    pub fn until(mut self, pred: impl Fn(&State) -> bool + Send + Sync + 'static) -> Self {
37        self.until = Some(Arc::new(pred));
38        self
39    }
40
41    /// Attach a middleware chain. `AgentEvent::LoopIteration` is emitted through
42    /// it on every iteration, so `on_event` observers (e.g. `M::on_loop`) fire.
43    pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
44        self.middleware = chain;
45        self
46    }
47}
48
49#[async_trait]
50impl TextAgent for LoopTextAgent {
51    fn name(&self) -> &str {
52        &self.name
53    }
54
55    async fn run(&self, state: &State) -> Result<String, AgentError> {
56        let mut last_output = String::new();
57
58        for iter in 0..self.max {
59            let _ = self
60                .middleware
61                .run_on_event(&AgentEvent::LoopIteration { iteration: iter })
62                .await;
63
64            last_output = self.body.run(state).await?;
65
66            if let Some(pred) = &self.until {
67                if pred(state) {
68                    break;
69                }
70            }
71        }
72
73        Ok(last_output)
74    }
75}