gemini_adk_rs/
agent_tool.rs

//! AgentTool — wraps an Agent as a ToolFunction for "agent as a tool" dispatch.
//!
//! When the live model calls this tool, the wrapped agent runs in an isolated
//! context (no live WebSocket). The agent's text output is collected and returned
//! as the tool result. The isolated context starts from fresh state, so state
//! changes made by the wrapped agent are NOT propagated back to the parent context.
6//!
7//! This bridges live<->non-live: the wrapped agent can use regular Gemini API,
8//! external services, or pure computation — it doesn't need a WebSocket.
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use serde_json::json;
14use tokio::sync::broadcast;
15
16use gemini_genai_rs::session::SessionEvent;
17
18use crate::agent::Agent;
19use crate::agent_session::{AgentSession, NoOpSessionWriter};
20use crate::context::{AgentEvent, InvocationContext};
21use crate::error::ToolError;
22use crate::tool::ToolFunction;
23
24/// Wraps an Agent as a ToolFunction for "agent as a tool" dispatch.
25///
26/// When the live model calls this tool, the wrapped agent runs in an isolated
27/// context (no live WebSocket). The agent's text output is collected and returned
28/// as the tool result.
29pub struct AgentTool {
30    agent: Arc<dyn Agent>,
31    description: String,
32    parameters: Option<serde_json::Value>,
33}
34
35impl AgentTool {
36    /// Create a new AgentTool wrapping the given agent.
37    pub fn new(agent: impl Agent + 'static) -> Self {
38        let description = format!("Delegate to the {} agent", agent.name());
39        Self {
40            agent: Arc::new(agent),
41            description,
42            parameters: Some(json!({
43                "type": "object",
44                "properties": {
45                    "request": {
46                        "type": "string",
47                        "description": "The request to send to the agent"
48                    }
49                },
50                "required": ["request"]
51            })),
52        }
53    }
54
55    /// Create from an already-Arc'd agent.
56    pub fn from_arc(agent: Arc<dyn Agent>) -> Self {
57        let description = format!("Delegate to the {} agent", agent.name());
58        Self {
59            agent,
60            description,
61            parameters: Some(json!({
62                "type": "object",
63                "properties": {
64                    "request": {
65                        "type": "string",
66                        "description": "The request to send to the agent"
67                    }
68                },
69                "required": ["request"]
70            })),
71        }
72    }
73
74    /// Override the tool description.
75    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
76        self.description = desc.into();
77        self
78    }
79
80    /// Override the tool parameters schema.
81    pub fn with_parameters(mut self, params: serde_json::Value) -> Self {
82        self.parameters = Some(params);
83        self
84    }
85}
86
87#[async_trait]
88impl ToolFunction for AgentTool {
89    fn name(&self) -> &str {
90        self.agent.name()
91    }
92
93    fn description(&self) -> &str {
94        &self.description
95    }
96
97    fn parameters(&self) -> Option<serde_json::Value> {
98        self.parameters.clone()
99    }
100
101    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
102        let start = std::time::Instant::now();
103        let agent_name = self.agent.name().to_string();
104
105        // Telemetry
106        crate::telemetry::logging::log_agent_tool_dispatch("parent", &agent_name);
107
108        // 1. Create isolated context with NoOpSessionWriter
109        let (event_tx, _) = broadcast::channel::<SessionEvent>(64);
110        let noop_writer: Arc<dyn gemini_genai_rs::session::SessionWriter> =
111            Arc::new(NoOpSessionWriter);
112        let isolated_session = AgentSession::from_writer(noop_writer, event_tx);
113
114        // 2. Inject args into state
115        if let Some(request) = args.get("request").and_then(|r| r.as_str()) {
116            isolated_session.state().set("request_text", request);
117        }
118        isolated_session.state().set("request", &args);
119
120        // 3. Create isolated InvocationContext
121        let mut ctx = InvocationContext::new(isolated_session);
122
123        // 4. Subscribe to events before running (to collect text output)
124        let mut events = ctx.subscribe();
125
126        // 5. Run the agent
127        let agent = self.agent.clone();
128        let run_result = tokio::spawn(async move { agent.run_live(&mut ctx).await }).await;
129
130        // 6. Collect text output from events
131        let mut output_parts = Vec::new();
132        while let Ok(event) = events.try_recv() {
133            match event {
134                AgentEvent::Session(SessionEvent::TextDelta(text)) => {
135                    output_parts.push(text);
136                }
137                AgentEvent::Session(SessionEvent::TextComplete(text)) => {
138                    if output_parts.is_empty() {
139                        output_parts.push(text);
140                    }
141                    // If we already have deltas, TextComplete is the full assembled text
142                    // Don't double-count — deltas already captured incrementally
143                }
144                _ => {}
145            }
146        }
147
148        let elapsed = start.elapsed();
149        crate::telemetry::metrics::record_agent_tool_dispatch(
150            "parent",
151            &agent_name,
152            elapsed.as_millis() as f64,
153        );
154
155        // 7. Handle result
156        match run_result {
157            Ok(Ok(())) => {
158                let output = if output_parts.is_empty() {
159                    json!({"status": "completed"})
160                } else {
161                    json!({"result": output_parts.join("")})
162                };
163                Ok(output)
164            }
165            Ok(Err(e)) => Err(ToolError::ExecutionFailed(format!(
166                "Agent '{}' failed: {}",
167                agent_name, e
168            ))),
169            Err(e) => Err(ToolError::ExecutionFailed(format!(
170                "Agent '{}' task panicked: {}",
171                agent_name, e
172            ))),
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;

    /// Echoes `request_text` from state back as a single text delta.
    struct EchoAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for EchoAgent {
        fn name(&self) -> &str {
            &self.name
        }
        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            // Read the request from state and echo it back as a text event
            let request = ctx
                .state()
                .get::<String>("request_text")
                .unwrap_or_else(|| "no request".to_string());
            ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(format!(
                "Echo: {}",
                request
            ))));
            ctx.emit(AgentEvent::Session(SessionEvent::TurnComplete));
            Ok(())
        }
    }

    /// Always fails, to exercise error propagation.
    struct FailingAgent;

    #[async_trait]
    impl Agent for FailingAgent {
        fn name(&self) -> &str {
            "failing"
        }
        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Err(AgentError::Other("intentional failure".to_string()))
        }
    }

    /// Succeeds without emitting any events.
    struct SilentAgent;

    #[async_trait]
    impl Agent for SilentAgent {
        fn name(&self) -> &str {
            "silent"
        }
        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    /// Streams text as deltas, then repeats the assembled text as `TextComplete`.
    struct StreamingAgent;

    #[async_trait]
    impl Agent for StreamingAgent {
        fn name(&self) -> &str {
            "streaming"
        }
        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(
                "Hello, ".to_string(),
            )));
            ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(
                "world".to_string(),
            )));
            ctx.emit(AgentEvent::Session(SessionEvent::TextComplete(
                "Hello, world".to_string(),
            )));
            Ok(())
        }
    }

    /// Emits only a `TextComplete` event, with no preceding deltas.
    struct CompleteOnlyAgent;

    #[async_trait]
    impl Agent for CompleteOnlyAgent {
        fn name(&self) -> &str {
            "complete_only"
        }
        async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
            ctx.emit(AgentEvent::Session(SessionEvent::TextComplete(
                "full text".to_string(),
            )));
            Ok(())
        }
    }

    #[tokio::test]
    async fn agent_tool_exposes_agent_name_and_default_description() {
        let agent = EchoAgent {
            name: "echo".to_string(),
        };
        let tool = AgentTool::new(agent);

        assert_eq!(tool.name(), "echo");
        assert!(tool.description().contains("echo"));
    }

    #[tokio::test]
    async fn agent_tool_default_parameters_require_request() {
        let tool = AgentTool::new(SilentAgent);
        let params = tool.parameters().expect("default schema should be set");
        assert_eq!(params["type"], "object");
        assert_eq!(params["properties"]["request"]["type"], "string");
        assert_eq!(params["required"], json!(["request"]));
    }

    #[tokio::test]
    async fn agent_tool_collects_text_output() {
        let agent = EchoAgent {
            name: "echo".to_string(),
        };
        let tool = AgentTool::new(agent);

        let result = tool.call(json!({"request": "hello world"})).await.unwrap();
        assert_eq!(result["result"], "Echo: hello world");
    }

    #[tokio::test]
    async fn agent_tool_without_request_field_leaves_request_text_unset() {
        let agent = EchoAgent {
            name: "echo".to_string(),
        };
        let tool = AgentTool::new(agent);

        let result = tool.call(json!({})).await.unwrap();
        assert_eq!(result["result"], "Echo: no request");
    }

    #[tokio::test]
    async fn agent_tool_does_not_double_count_text_complete() {
        let tool = AgentTool::new(StreamingAgent);
        let result = tool.call(json!({"request": "go"})).await.unwrap();
        assert_eq!(result["result"], "Hello, world");
    }

    #[tokio::test]
    async fn agent_tool_uses_text_complete_without_deltas() {
        let tool = AgentTool::new(CompleteOnlyAgent);
        let result = tool.call(json!({"request": "go"})).await.unwrap();
        assert_eq!(result["result"], "full text");
    }

    #[tokio::test]
    async fn agent_tool_propagates_errors() {
        let tool = AgentTool::new(FailingAgent);
        let result = tool.call(json!({"request": "test"})).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        match err {
            ToolError::ExecutionFailed(msg) => {
                assert!(msg.contains("intentional failure"));
            }
            other => panic!("expected ExecutionFailed, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn agent_tool_returns_completed_when_no_output() {
        let tool = AgentTool::new(SilentAgent);
        let result = tool.call(json!({"request": "test"})).await.unwrap();
        assert_eq!(result["status"], "completed");
    }

    #[tokio::test]
    async fn agent_tool_state_injection() {
        // Verify that args are injected into state
        struct StateCheckAgent;

        #[async_trait]
        impl Agent for StateCheckAgent {
            fn name(&self) -> &str {
                "state_check"
            }
            async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
                let request_text = ctx.state().get::<String>("request_text");
                let request = ctx.state().get::<serde_json::Value>("request");

                assert!(request_text.is_some());
                assert!(request.is_some());
                assert_eq!(request_text.unwrap(), "check state");

                ctx.emit(AgentEvent::Session(SessionEvent::TextDelta(
                    "state ok".to_string(),
                )));
                Ok(())
            }
        }

        let tool = AgentTool::new(StateCheckAgent);
        let result = tool.call(json!({"request": "check state"})).await.unwrap();
        assert_eq!(result["result"], "state ok");
    }

    #[tokio::test]
    async fn agent_tool_with_custom_description() {
        let tool = AgentTool::new(SilentAgent).with_description("Custom description");
        assert_eq!(tool.description(), "Custom description");
    }

    #[tokio::test]
    async fn agent_tool_with_custom_parameters() {
        let params = json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            }
        });
        let tool = AgentTool::new(SilentAgent).with_parameters(params.clone());
        assert_eq!(tool.parameters().unwrap(), params);
    }
}