gemini_adk_rs/text/
llm.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use gemini_genai_rs::prelude::{Content, FunctionCall, FunctionResponse, Part, Role};
5
6use super::TextAgent;
7use crate::context::AgentEvent;
8use crate::error::AgentError;
9use crate::llm::{BaseLlm, LlmRequest};
10use crate::middleware::MiddlewareChain;
11use crate::state::State;
12use crate::tool::ToolDispatcher;
13
/// Upper bound on model → tool → model round-trips within a single run.
///
/// Once this many rounds have each ended in tool calls, the run fails with an
/// error instead of prompting the model again.
const MAX_TOOL_ROUNDS: usize = 10;
16
17/// Core text agent — calls `BaseLlm::generate()`, dispatches tools, loops
18/// until the model produces a final text response.
19///
20/// Middleware hooks fire at each lifecycle point:
21///
22/// - `before_model` / `after_model` — wraps each `BaseLlm::generate()` call;
23///   `before_model` may return a cached response to skip the LLM entirely.
24/// - `before_tool` / `after_tool` / `on_tool_error` — wraps each tool dispatch.
25/// - `on_error` — called when `run()` is about to return an error.
26///
27/// Note: `before_agent`/`after_agent` are Live-session hooks that require an
28/// `InvocationContext` (a Live WebSocket concept) and are therefore not invoked
29/// by `LlmTextAgent`.  Use `before_model` or wrap in a custom `TextAgent` if you
30/// need entry/exit hooks for the text path.
31pub struct LlmTextAgent {
32    name: String,
33    llm: Arc<dyn BaseLlm>,
34    instruction: Option<String>,
35    dispatcher: Option<Arc<ToolDispatcher>>,
36    temperature: Option<f32>,
37    max_output_tokens: Option<u32>,
38    middleware: MiddlewareChain,
39}
40
41impl LlmTextAgent {
42    /// Create a new LLM text agent.
43    pub fn new(name: impl Into<String>, llm: Arc<dyn BaseLlm>) -> Self {
44        Self {
45            name: name.into(),
46            llm,
47            instruction: None,
48            dispatcher: None,
49            temperature: None,
50            max_output_tokens: None,
51            middleware: MiddlewareChain::new(),
52        }
53    }
54
55    /// Set the system instruction.
56    pub fn instruction(mut self, inst: impl Into<String>) -> Self {
57        self.instruction = Some(inst.into());
58        self
59    }
60
61    /// Set the tool dispatcher.
62    pub fn tools(mut self, dispatcher: Arc<ToolDispatcher>) -> Self {
63        self.dispatcher = Some(dispatcher);
64        self
65    }
66
67    /// Set temperature.
68    pub fn temperature(mut self, t: f32) -> Self {
69        self.temperature = Some(t);
70        self
71    }
72
73    /// Set max output tokens.
74    pub fn max_output_tokens(mut self, n: u32) -> Self {
75        self.max_output_tokens = Some(n);
76        self
77    }
78
79    /// Append a middleware layer to the chain.
80    ///
81    /// Layers are run in insertion order for `before_*` / `on_error` hooks
82    /// and in reverse insertion order for `after_*` hooks (outermost last).
83    pub fn add_middleware(mut self, mw: Arc<dyn crate::middleware::Middleware>) -> Self {
84        self.middleware.add(mw);
85        self
86    }
87
88    /// Replace the entire middleware chain (advanced — prefer `add_middleware`).
89    pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
90        self.middleware = chain;
91        self
92    }
93
94    /// Build an LlmRequest, taking ownership of contents to avoid cloning.
95    fn build_request(&self, contents: Vec<Content>) -> LlmRequest {
96        let mut req = LlmRequest::from_contents(contents);
97        req.system_instruction = self.instruction.clone();
98        req.temperature = self.temperature;
99        req.max_output_tokens = self.max_output_tokens;
100
101        if let Some(dispatcher) = &self.dispatcher {
102            req.tools = dispatcher.to_tool_declarations();
103        }
104
105        req
106    }
107
108    /// Dispatch function calls and return function responses, firing middleware hooks.
109    async fn dispatch_tools(&self, calls: &[FunctionCall]) -> Vec<FunctionResponse> {
110        let dispatcher = match &self.dispatcher {
111            Some(d) => d,
112            None => return Vec::new(),
113        };
114
115        let mut responses = Vec::with_capacity(calls.len());
116        for call in calls {
117            // before_tool hook
118            if let Err(e) = self.middleware.run_before_tool(call).await {
119                // Hook error — record it and return an error response.
120                let _ = self
121                    .middleware
122                    .run_on_tool_error(
123                        call,
124                        &crate::error::ToolError::ExecutionFailed(e.to_string()),
125                    )
126                    .await;
127                responses.push(ToolDispatcher::build_response(
128                    call,
129                    Err(crate::error::ToolError::ExecutionFailed(e.to_string())),
130                ));
131                continue;
132            }
133
134            let result = dispatcher
135                .call_function(&call.name, call.args.clone())
136                .await;
137
138            match &result {
139                Ok(value) => {
140                    let _ = self.middleware.run_after_tool(call, value).await;
141                }
142                Err(e) => {
143                    let _ = self.middleware.run_on_tool_error(call, e).await;
144                }
145            }
146
147            responses.push(ToolDispatcher::build_response(call, result));
148        }
149        responses
150    }
151}
152
153#[async_trait]
154impl TextAgent for LlmTextAgent {
155    fn name(&self) -> &str {
156        &self.name
157    }
158
159    async fn run(&self, state: &State) -> Result<String, AgentError> {
160        // Build initial contents from state "input" key, or empty user message.
161        let input = state.get::<String>("input").unwrap_or_default();
162
163        let mut contents = vec![Content::user(&input)];
164
165        // Lifecycle event — makes `on_event` (e.g. M::tap) observe agent start.
166        let _ = self
167            .middleware
168            .run_on_event(&AgentEvent::AgentStarted {
169                name: self.name.clone(),
170            })
171            .await;
172
173        // Enforce the tightest middleware timeout (M::timeout) over the whole run.
174        let result = match self.middleware.timeout() {
175            Some(limit) => match tokio::time::timeout(limit, self.run_inner(&mut contents)).await {
176                Ok(r) => r,
177                Err(_) => {
178                    let _ = self.middleware.run_on_event(&AgentEvent::Timeout).await;
179                    Err(AgentError::Other(format!(
180                        "agent '{}' timed out after {:?}",
181                        self.name, limit
182                    )))
183                }
184            },
185            None => self.run_inner(&mut contents).await,
186        };
187
188        if let Err(ref e) = result {
189            let _ = self.middleware.run_on_error(e).await;
190        } else if let Ok(ref text) = result {
191            let _ = state.set("output", text);
192            let _ = self
193                .middleware
194                .run_on_event(&AgentEvent::AgentCompleted {
195                    name: self.name.clone(),
196                })
197                .await;
198        }
199
200        result
201    }
202}
203
204impl LlmTextAgent {
205    /// Inner execution loop — separated so `on_error` fires exactly once.
206    async fn run_inner(&self, contents: &mut Vec<Content>) -> Result<String, AgentError> {
207        for _round in 0..MAX_TOOL_ROUNDS {
208            let mut request = self.build_request(contents.clone());
209
210            // transform_request hook — may rewrite the request (e.g. context
211            // policies trimming conversation history) before it is sent.
212            self.middleware.run_transform_request(&mut request).await?;
213
214            // before_model hook — may short-circuit with a cached response.
215            let response = match self.middleware.run_before_model(&request).await? {
216                Some(cached) => cached,
217                None => {
218                    let llm_response = self
219                        .llm
220                        .generate(request.clone())
221                        .await
222                        .map_err(|e| AgentError::Other(format!("LLM error: {e}")))?;
223
224                    // after_model hook — may replace the response.
225                    match self
226                        .middleware
227                        .run_after_model(&request, &llm_response)
228                        .await?
229                    {
230                        Some(replaced) => replaced,
231                        None => llm_response,
232                    }
233                }
234            };
235
236            let calls: Vec<FunctionCall> = response.function_calls().into_iter().cloned().collect();
237
238            if calls.is_empty() {
239                // No tool calls — we have a final text response.
240                return Ok(response.text());
241            }
242
243            // Move model response into conversation (no clone needed).
244            contents.push(response.content);
245
246            // Dispatch tools (middleware hooks inside).
247            let tool_responses = self.dispatch_tools(&calls).await;
248            let response_parts: Vec<Part> = tool_responses
249                .into_iter()
250                .map(|fr| Part::FunctionResponse {
251                    function_response: fr,
252                })
253                .collect();
254
255            contents.push(Content {
256                role: Some(Role::User),
257                parts: response_parts,
258            });
259        }
260
261        Err(AgentError::Other(format!(
262            "Agent '{}' exceeded max tool rounds ({})",
263            self.name, MAX_TOOL_ROUNDS
264        )))
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentEvent;
    use crate::llm::{LlmError, LlmResponse};
    use crate::middleware::Middleware;
    use gemini_genai_rs::prelude::{Content, Part, Role};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    /// Build a finished model response holding a single text part.
    fn text_response(t: &str) -> LlmResponse {
        let content = Content {
            role: Some(Role::Model),
            parts: vec![Part::Text { text: t.into() }],
        };
        LlmResponse {
            content,
            finish_reason: Some("STOP".into()),
            usage: None,
        }
    }

    /// LLM stub that answers only after a 500 ms delay.
    struct SlowLlm;

    #[async_trait]
    impl BaseLlm for SlowLlm {
        fn model_id(&self) -> &str {
            "slow"
        }

        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            let delay = Duration::from_millis(500);
            tokio::time::sleep(delay).await;
            Ok(text_response("done"))
        }
    }

    /// LLM stub that answers immediately.
    struct FastLlm;

    #[async_trait]
    impl BaseLlm for FastLlm {
        fn model_id(&self) -> &str {
            "fast"
        }

        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Ok(text_response("hi"))
        }
    }

    /// Middleware that only contributes a 20 ms timeout.
    struct ShortTimeout;

    #[async_trait]
    impl Middleware for ShortTimeout {
        fn name(&self) -> &str {
            "short-timeout"
        }

        fn timeout(&self) -> Option<Duration> {
            Some(Duration::from_millis(20))
        }
    }

    /// Middleware that raises its flag when it observes `AgentStarted`.
    struct EventFlag(Arc<AtomicBool>);

    #[async_trait]
    impl Middleware for EventFlag {
        fn name(&self) -> &str {
            "event-flag"
        }

        async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
            if let AgentEvent::AgentStarted { .. } = event {
                self.0.store(true, Ordering::SeqCst);
            }
            Ok(())
        }
    }

    /// A 20 ms timeout must cut short a run whose LLM takes 500 ms.
    #[tokio::test]
    async fn timeout_aborts_slow_run() {
        let state = State::new();
        let _ = state.set("input", "hi");

        let agent =
            LlmTextAgent::new("slowpoke", Arc::new(SlowLlm)).add_middleware(Arc::new(ShortTimeout));

        let err = agent.run(&state).await.expect_err("expected timeout");
        let rendered = format!("{err:?}");
        assert!(rendered.contains("timed out"), "got: {err:?}");
    }

    /// Running the agent must deliver `AgentStarted` to `on_event`.
    #[tokio::test]
    async fn on_event_fires_for_agent_lifecycle() {
        let started = Arc::new(AtomicBool::new(false));
        let observer = EventFlag(Arc::clone(&started));

        let state = State::new();
        let _ = state.set("input", "hi");

        let agent = LlmTextAgent::new("a", Arc::new(FastLlm)).add_middleware(Arc::new(observer));
        let _ = agent.run(&state).await;

        assert!(
            started.load(Ordering::SeqCst),
            "on_event(AgentStarted) should fire"
        );
    }
}