gemini_adk_rs/text/
loop_agent.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::context::AgentEvent;
7use crate::error::AgentError;
8use crate::middleware::MiddlewareChain;
9use crate::state::State;
10
11type UntilFn = Arc<dyn Fn(&State) -> bool + Send + Sync>;
13
14pub struct LoopTextAgent {
16 name: String,
17 body: Arc<dyn TextAgent>,
18 max: u32,
19 until: Option<UntilFn>,
20 middleware: MiddlewareChain,
21}
22
23impl LoopTextAgent {
24 pub fn new(name: impl Into<String>, body: Arc<dyn TextAgent>, max: u32) -> Self {
26 Self {
27 name: name.into(),
28 body,
29 max,
30 until: None,
31 middleware: MiddlewareChain::new(),
32 }
33 }
34
35 pub fn until(mut self, pred: impl Fn(&State) -> bool + Send + Sync + 'static) -> Self {
37 self.until = Some(Arc::new(pred));
38 self
39 }
40
41 pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
44 self.middleware = chain;
45 self
46 }
47}
48
49#[async_trait]
50impl TextAgent for LoopTextAgent {
51 fn name(&self) -> &str {
52 &self.name
53 }
54
55 async fn run(&self, state: &State) -> Result<String, AgentError> {
56 let mut last_output = String::new();
57
58 for iter in 0..self.max {
59 let _ = self
60 .middleware
61 .run_on_event(&AgentEvent::LoopIteration { iteration: iter })
62 .await;
63
64 last_output = self.body.run(state).await?;
65
66 if let Some(pred) = &self.until {
67 if pred(state) {
68 break;
69 }
70 }
71 }
72
73 Ok(last_output)
74 }
75}