gemini_adk_rs/text/
llm.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4use gemini_genai_rs::prelude::{Content, FunctionCall, FunctionResponse, Part, Role};
5
6use super::TextAgent;
7use crate::context::AgentEvent;
8use crate::error::AgentError;
9use crate::llm::{BaseLlm, LlmRequest};
10use crate::middleware::MiddlewareChain;
11use crate::state::State;
12use crate::tool::ToolDispatcher;
13
/// Upper bound on model round-trips that request tool calls within one agent run.
///
/// Once this many consecutive rounds have all produced function calls, the run is
/// abandoned with an error rather than looping indefinitely.
const MAX_TOOL_ROUNDS: usize = 10;
16
17pub struct LlmTextAgent {
32 name: String,
33 llm: Arc<dyn BaseLlm>,
34 instruction: Option<String>,
35 dispatcher: Option<Arc<ToolDispatcher>>,
36 temperature: Option<f32>,
37 max_output_tokens: Option<u32>,
38 middleware: MiddlewareChain,
39}
40
41impl LlmTextAgent {
42 pub fn new(name: impl Into<String>, llm: Arc<dyn BaseLlm>) -> Self {
44 Self {
45 name: name.into(),
46 llm,
47 instruction: None,
48 dispatcher: None,
49 temperature: None,
50 max_output_tokens: None,
51 middleware: MiddlewareChain::new(),
52 }
53 }
54
55 pub fn instruction(mut self, inst: impl Into<String>) -> Self {
57 self.instruction = Some(inst.into());
58 self
59 }
60
61 pub fn tools(mut self, dispatcher: Arc<ToolDispatcher>) -> Self {
63 self.dispatcher = Some(dispatcher);
64 self
65 }
66
67 pub fn temperature(mut self, t: f32) -> Self {
69 self.temperature = Some(t);
70 self
71 }
72
73 pub fn max_output_tokens(mut self, n: u32) -> Self {
75 self.max_output_tokens = Some(n);
76 self
77 }
78
79 pub fn add_middleware(mut self, mw: Arc<dyn crate::middleware::Middleware>) -> Self {
84 self.middleware.add(mw);
85 self
86 }
87
88 pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
90 self.middleware = chain;
91 self
92 }
93
94 fn build_request(&self, contents: Vec<Content>) -> LlmRequest {
96 let mut req = LlmRequest::from_contents(contents);
97 req.system_instruction = self.instruction.clone();
98 req.temperature = self.temperature;
99 req.max_output_tokens = self.max_output_tokens;
100
101 if let Some(dispatcher) = &self.dispatcher {
102 req.tools = dispatcher.to_tool_declarations();
103 }
104
105 req
106 }
107
108 async fn dispatch_tools(&self, calls: &[FunctionCall]) -> Vec<FunctionResponse> {
110 let dispatcher = match &self.dispatcher {
111 Some(d) => d,
112 None => return Vec::new(),
113 };
114
115 let mut responses = Vec::with_capacity(calls.len());
116 for call in calls {
117 if let Err(e) = self.middleware.run_before_tool(call).await {
119 let _ = self
121 .middleware
122 .run_on_tool_error(
123 call,
124 &crate::error::ToolError::ExecutionFailed(e.to_string()),
125 )
126 .await;
127 responses.push(ToolDispatcher::build_response(
128 call,
129 Err(crate::error::ToolError::ExecutionFailed(e.to_string())),
130 ));
131 continue;
132 }
133
134 let result = dispatcher
135 .call_function(&call.name, call.args.clone())
136 .await;
137
138 match &result {
139 Ok(value) => {
140 let _ = self.middleware.run_after_tool(call, value).await;
141 }
142 Err(e) => {
143 let _ = self.middleware.run_on_tool_error(call, e).await;
144 }
145 }
146
147 responses.push(ToolDispatcher::build_response(call, result));
148 }
149 responses
150 }
151}
152
153#[async_trait]
154impl TextAgent for LlmTextAgent {
155 fn name(&self) -> &str {
156 &self.name
157 }
158
159 async fn run(&self, state: &State) -> Result<String, AgentError> {
160 let input = state.get::<String>("input").unwrap_or_default();
162
163 let mut contents = vec![Content::user(&input)];
164
165 let _ = self
167 .middleware
168 .run_on_event(&AgentEvent::AgentStarted {
169 name: self.name.clone(),
170 })
171 .await;
172
173 let result = match self.middleware.timeout() {
175 Some(limit) => match tokio::time::timeout(limit, self.run_inner(&mut contents)).await {
176 Ok(r) => r,
177 Err(_) => {
178 let _ = self.middleware.run_on_event(&AgentEvent::Timeout).await;
179 Err(AgentError::Other(format!(
180 "agent '{}' timed out after {:?}",
181 self.name, limit
182 )))
183 }
184 },
185 None => self.run_inner(&mut contents).await,
186 };
187
188 if let Err(ref e) = result {
189 let _ = self.middleware.run_on_error(e).await;
190 } else if let Ok(ref text) = result {
191 let _ = state.set("output", text);
192 let _ = self
193 .middleware
194 .run_on_event(&AgentEvent::AgentCompleted {
195 name: self.name.clone(),
196 })
197 .await;
198 }
199
200 result
201 }
202}
203
204impl LlmTextAgent {
205 async fn run_inner(&self, contents: &mut Vec<Content>) -> Result<String, AgentError> {
207 for _round in 0..MAX_TOOL_ROUNDS {
208 let mut request = self.build_request(contents.clone());
209
210 self.middleware.run_transform_request(&mut request).await?;
213
214 let response = match self.middleware.run_before_model(&request).await? {
216 Some(cached) => cached,
217 None => {
218 let llm_response = self
219 .llm
220 .generate(request.clone())
221 .await
222 .map_err(|e| AgentError::Other(format!("LLM error: {e}")))?;
223
224 match self
226 .middleware
227 .run_after_model(&request, &llm_response)
228 .await?
229 {
230 Some(replaced) => replaced,
231 None => llm_response,
232 }
233 }
234 };
235
236 let calls: Vec<FunctionCall> = response.function_calls().into_iter().cloned().collect();
237
238 if calls.is_empty() {
239 return Ok(response.text());
241 }
242
243 contents.push(response.content);
245
246 let tool_responses = self.dispatch_tools(&calls).await;
248 let response_parts: Vec<Part> = tool_responses
249 .into_iter()
250 .map(|fr| Part::FunctionResponse {
251 function_response: fr,
252 })
253 .collect();
254
255 contents.push(Content {
256 role: Some(Role::User),
257 parts: response_parts,
258 });
259 }
260
261 Err(AgentError::Other(format!(
262 "Agent '{}' exceeded max tool rounds ({})",
263 self.name, MAX_TOOL_ROUNDS
264 )))
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::AgentEvent;
    use crate::llm::{LlmError, LlmResponse};
    use crate::middleware::Middleware;
    use gemini_genai_rs::prelude::{Content, Part, Role};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    /// Builds a single-part model text response with a `STOP` finish reason.
    fn text_response(t: &str) -> LlmResponse {
        LlmResponse {
            content: Content {
                role: Some(Role::Model),
                parts: vec![Part::Text { text: t.into() }],
            },
            finish_reason: Some("STOP".into()),
            usage: None,
        }
    }

    /// LLM that sleeps far longer than `ShortTimeout` allows before answering.
    struct SlowLlm;
    #[async_trait]
    impl BaseLlm for SlowLlm {
        fn model_id(&self) -> &str {
            "slow"
        }
        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            tokio::time::sleep(Duration::from_millis(500)).await;
            Ok(text_response("done"))
        }
    }

    /// LLM that answers "hi" immediately.
    struct FastLlm;
    #[async_trait]
    impl BaseLlm for FastLlm {
        fn model_id(&self) -> &str {
            "fast"
        }
        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Ok(text_response("hi"))
        }
    }

    /// Middleware that imposes a 20 ms run timeout.
    struct ShortTimeout;
    #[async_trait]
    impl Middleware for ShortTimeout {
        fn name(&self) -> &str {
            "short-timeout"
        }
        fn timeout(&self) -> Option<Duration> {
            Some(Duration::from_millis(20))
        }
    }

    /// Middleware that records whether the start and completion lifecycle events fired.
    struct LifecycleFlags {
        started: Arc<AtomicBool>,
        completed: Arc<AtomicBool>,
    }
    #[async_trait]
    impl Middleware for LifecycleFlags {
        fn name(&self) -> &str {
            "lifecycle-flags"
        }
        async fn on_event(&self, event: &AgentEvent) -> Result<(), AgentError> {
            match event {
                AgentEvent::AgentStarted { .. } => self.started.store(true, Ordering::SeqCst),
                AgentEvent::AgentCompleted { .. } => self.completed.store(true, Ordering::SeqCst),
                _ => {}
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn timeout_aborts_slow_run() {
        let agent =
            LlmTextAgent::new("slowpoke", Arc::new(SlowLlm)).add_middleware(Arc::new(ShortTimeout));
        let state = State::new();
        let _ = state.set("input", "hi");
        let err = agent.run(&state).await.expect_err("expected timeout");
        assert!(format!("{err:?}").contains("timed out"), "got: {err:?}");
    }

    #[tokio::test]
    async fn on_event_fires_for_agent_lifecycle() {
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));
        let agent = LlmTextAgent::new("a", Arc::new(FastLlm)).add_middleware(Arc::new(
            LifecycleFlags {
                started: started.clone(),
                completed: completed.clone(),
            },
        ));
        let state = State::new();
        let _ = state.set("input", "hi");

        let out = agent.run(&state).await.expect("fast run should succeed");

        assert_eq!(out, "hi");
        assert_eq!(state.get::<String>("output").unwrap_or_default(), "hi");
        assert!(
            started.load(Ordering::SeqCst),
            "on_event(AgentStarted) should fire"
        );
        assert!(
            completed.load(Ordering::SeqCst),
            "on_event(AgentCompleted) should fire"
        );
    }
}