gemini_adk_rs/
agent_tool.rs1use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::json;
14use tokio::sync::broadcast;
15
16use gemini_genai_rs::session::SessionEvent;
17
18use crate::agent::Agent;
19use crate::agent_session::{AgentSession, NoOpSessionWriter};
20use crate::context::{AgentEvent, InvocationContext};
21use crate::error::ToolError;
22use crate::tool::ToolFunction;
23
24pub struct AgentTool {
30 agent: Arc<dyn Agent>,
31 description: String,
32 parameters: Option<serde_json::Value>,
33}
34
35impl AgentTool {
36 pub fn new(agent: impl Agent + 'static) -> Self {
38 let description = format!("Delegate to the {} agent", agent.name());
39 Self {
40 agent: Arc::new(agent),
41 description,
42 parameters: Some(json!({
43 "type": "object",
44 "properties": {
45 "request": {
46 "type": "string",
47 "description": "The request to send to the agent"
48 }
49 },
50 "required": ["request"]
51 })),
52 }
53 }
54
55 pub fn from_arc(agent: Arc<dyn Agent>) -> Self {
57 let description = format!("Delegate to the {} agent", agent.name());
58 Self {
59 agent,
60 description,
61 parameters: Some(json!({
62 "type": "object",
63 "properties": {
64 "request": {
65 "type": "string",
66 "description": "The request to send to the agent"
67 }
68 },
69 "required": ["request"]
70 })),
71 }
72 }
73
74 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
76 self.description = desc.into();
77 self
78 }
79
80 pub fn with_parameters(mut self, params: serde_json::Value) -> Self {
82 self.parameters = Some(params);
83 self
84 }
85}
86
87#[async_trait]
88impl ToolFunction for AgentTool {
89 fn name(&self) -> &str {
90 self.agent.name()
91 }
92
93 fn description(&self) -> &str {
94 &self.description
95 }
96
97 fn parameters(&self) -> Option<serde_json::Value> {
98 self.parameters.clone()
99 }
100
101 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
102 let start = std::time::Instant::now();
103 let agent_name = self.agent.name().to_string();
104
105 crate::telemetry::logging::log_agent_tool_dispatch("parent", &agent_name);
107
108 let (event_tx, _) = broadcast::channel::<SessionEvent>(64);
110 let noop_writer: Arc<dyn gemini_genai_rs::session::SessionWriter> =
111 Arc::new(NoOpSessionWriter);
112 let isolated_session = AgentSession::from_writer(noop_writer, event_tx);
113
114 if let Some(request) = args.get("request").and_then(|r| r.as_str()) {
116 isolated_session.state().set("request_text", request);
117 }
118 isolated_session.state().set("request", &args);
119
120 let mut ctx = InvocationContext::new(isolated_session);
122
123 let mut events = ctx.subscribe();
125
126 let agent = self.agent.clone();
128 let run_result = tokio::spawn(async move { agent.run_live(&mut ctx).await }).await;
129
130 let mut output_parts = Vec::new();
132 while let Ok(event) = events.try_recv() {
133 match event {
134 AgentEvent::Session(SessionEvent::TextDelta(text)) => {
135 output_parts.push(text);
136 }
137 AgentEvent::Session(SessionEvent::TextComplete(text)) => {
138 if output_parts.is_empty() {
139 output_parts.push(text);
140 }
141 }
144 _ => {}
145 }
146 }
147
148 let elapsed = start.elapsed();
149 crate::telemetry::metrics::record_agent_tool_dispatch(
150 "parent",
151 &agent_name,
152 elapsed.as_millis() as f64,
153 );
154
155 match run_result {
157 Ok(Ok(())) => {
158 let output = if output_parts.is_empty() {
159 json!({"status": "completed"})
160 } else {
161 json!({"result": output_parts.join("")})
162 };
163 Ok(output)
164 }
165 Ok(Err(e)) => Err(ToolError::ExecutionFailed(format!(
166 "Agent '{}' failed: {}",
167 agent_name, e
168 ))),
169 Err(e) => Err(ToolError::ExecutionFailed(format!(
170 "Agent '{}' task panicked: {}",
171 agent_name, e
172 ))),
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;

    /// Replies with `"Echo: <request_text>"` as one text delta, then ends the turn.
    struct EchoAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for EchoAgent {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let request = ctx
                .state()
                .get::<String>("request_text")
                .unwrap_or_else(|| "no request".to_string());
            let reply = format!("Echo: {}", request);
            ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(reply)));
            ctx.emit(AgentEvent::Session(SessionEvent::TurnComplete));
            Ok(())
        }
    }

    /// Always fails with a recognizable message.
    struct FailingAgent;

    #[async_trait]
    impl Agent for FailingAgent {
        fn name(&self) -> &str {
            "failing"
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            let reason = String::from("intentional failure");
            Err(AgentError::Other(reason))
        }
    }

    /// Succeeds without emitting anything.
    struct SilentAgent;

    #[async_trait]
    impl Agent for SilentAgent {
        fn name(&self) -> &str {
            "silent"
        }

        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    fn echo_tool() -> AgentTool {
        AgentTool::new(EchoAgent {
            name: "echo".to_string(),
        })
    }

    #[tokio::test]
    async fn agent_tool_runs_agent_in_isolation() {
        let tool = echo_tool();

        assert_eq!(tool.name(), "echo");
        assert!(tool.description().contains("echo"));
    }

    #[tokio::test]
    async fn agent_tool_collects_text_output() {
        let output = echo_tool()
            .call(json!({"request": "hello world"}))
            .await
            .unwrap();

        assert_eq!(output["result"], "Echo: hello world");
    }

    #[tokio::test]
    async fn agent_tool_propagates_errors() {
        let outcome = AgentTool::new(FailingAgent)
            .call(json!({"request": "test"}))
            .await;

        match outcome {
            Err(ToolError::ExecutionFailed(msg)) => {
                assert!(msg.contains("intentional failure"));
            }
            other => panic!("expected ExecutionFailed, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn agent_tool_returns_completed_when_no_output() {
        let output = AgentTool::new(SilentAgent)
            .call(json!({"request": "test"}))
            .await
            .unwrap();

        assert_eq!(output["status"], "completed");
    }

    #[tokio::test]
    async fn agent_tool_state_injection() {
        /// Verifies both injected state keys, then reports success as text.
        struct StateCheckAgent;

        #[async_trait]
        impl Agent for StateCheckAgent {
            fn name(&self) -> &str {
                "state_check"
            }

            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let request_text = ctx.state().get::<String>("request_text");
                let request = ctx.state().get::<serde_json::Value>("request");

                assert!(request.is_some());
                assert_eq!(request_text.as_deref(), Some("check state"));

                let marker = "state ok".to_string();
                ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(marker)));
                Ok(())
            }
        }

        let output = AgentTool::new(StateCheckAgent)
            .call(json!({"request": "check state"}))
            .await
            .unwrap();

        assert_eq!(output["result"], "state ok");
    }

    #[tokio::test]
    async fn agent_tool_with_custom_description() {
        let tool = AgentTool::new(SilentAgent).with_description("Custom description");

        assert_eq!(tool.description(), "Custom description");
    }

    #[tokio::test]
    async fn agent_tool_with_custom_parameters() {
        let schema = json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            }
        });
        let tool = AgentTool::new(SilentAgent).with_parameters(schema.clone());

        assert_eq!(tool.parameters(), Some(schema));
    }
}