gemini_adk_rs/
llm_agent.rs

1//! LlmAgent — concrete Agent implementation with builder pattern.
2//!
3//! The builder freezes tools at `build()` time (respecting Gemini Live's
4//! constraint that tools are fixed at session setup). Auto-registers
5//! `transfer_to_{name}` tools for each sub-agent.
6//!
7//! The event loop subscribes to SessionEvents, auto-dispatches tool calls,
8//! detects transfers via `__transfer_to` signal in tool results, and handles
9//! streaming/input-streaming tools.
10
11use std::sync::Arc;
12use std::time::Duration;
13
14use async_trait::async_trait;
15use serde_json::json;
16use tokio::sync::broadcast;
17
18use gemini_genai_rs::prelude::{recv_event, FunctionResponse, Tool};
19use gemini_genai_rs::session::SessionEvent;
20
21use crate::agent::Agent;
22use crate::context::{AgentEvent, InvocationContext};
23use crate::error::{AgentError, ToolError};
24use crate::middleware::MiddlewareChain;
25use crate::plugin::{PluginManager, PluginResult};
26use crate::tool::{
27    ActiveStreamingTool, InputStreamingTool, SimpleTool, StreamingTool, ToolClass, ToolDispatcher,
28    ToolFunction, ToolKind, TypedTool,
29};
30
31/// Concrete Agent implementation that runs a Gemini Live event loop.
32///
33/// Tools are declared at build time and sent during session setup.
34/// The event loop subscribes to SessionEvents, auto-dispatches tool calls,
35/// detects transfers, and emits AgentEvents.
36pub struct LlmAgent {
37    name: String,
38    dispatcher: ToolDispatcher,
39    middleware: MiddlewareChain,
40    plugins: PluginManager,
41    sub_agents: Vec<Arc<dyn Agent>>,
42}
43
44impl LlmAgent {
45    /// Start building a new LlmAgent.
46    pub fn builder(name: impl Into<String>) -> LlmAgentBuilder {
47        LlmAgentBuilder {
48            name: name.into(),
49            dispatcher: ToolDispatcher::new(),
50            middleware: MiddlewareChain::new(),
51            plugins: PluginManager::new(),
52            sub_agents: Vec::new(),
53        }
54    }
55
56    /// Access the tool dispatcher (for testing/introspection).
57    pub fn dispatcher(&self) -> &ToolDispatcher {
58        &self.dispatcher
59    }
60
61    /// Access the middleware chain.
62    pub fn middleware(&self) -> &MiddlewareChain {
63        &self.middleware
64    }
65
66    /// Access the plugin manager.
67    pub fn plugins(&self) -> &PluginManager {
68        &self.plugins
69    }
70
71    /// Core event loop -- processes SessionEvents, dispatches tools, detects transfers.
72    async fn event_loop(
73        &self,
74        ctx: &mut InvocationContext,
75        events: &mut broadcast::Receiver<SessionEvent>,
76        agent_name: &str,
77    ) -> Result<(), AgentError> {
78        loop {
79            let event = match recv_event(events).await {
80                Some(e) => e,
81                None => break, // channel closed
82            };
83
84            match event {
85                SessionEvent::ToolCall(calls) => {
86                    let mut responses = Vec::new();
87                    let mut transfer_target = None;
88
89                    for call in &calls {
90                        // Emit events + middleware hooks
91                        ctx.emit(AgentEvent::ToolCallStarted {
92                            name: call.name.clone(),
93                            args: call.args.clone(),
94                        });
95                        let _ = ctx.middleware.run_before_tool(call).await;
96
97                        // Plugin before_tool hook — can deny or short-circuit
98                        let plugin_result = self.plugins.run_before_tool(call, ctx).await;
99                        match &plugin_result {
100                            PluginResult::Deny(reason) => {
101                                ctx.emit(AgentEvent::ToolCallFailed {
102                                    name: call.name.clone(),
103                                    error: format!("Denied by plugin: {}", reason),
104                                });
105                                responses.push(ToolDispatcher::build_response(
106                                    call,
107                                    Err(ToolError::ExecutionFailed(format!(
108                                        "Denied by plugin: {}",
109                                        reason
110                                    ))),
111                                ));
112                                continue;
113                            }
114                            PluginResult::ShortCircuit(value) => {
115                                let _ = ctx.middleware.run_after_tool(call, value).await;
116                                ctx.emit(AgentEvent::ToolCallCompleted {
117                                    name: call.name.clone(),
118                                    result: value.clone(),
119                                    duration: std::time::Duration::ZERO,
120                                });
121                                responses
122                                    .push(ToolDispatcher::build_response(call, Ok(value.clone())));
123                                continue;
124                            }
125                            PluginResult::Continue => {}
126                        }
127
128                        let tool_start = std::time::Instant::now();
129                        let tool_class = self.dispatcher.classify(&call.name);
130
131                        match tool_class {
132                            Some(ToolClass::Regular) => {
133                                crate::telemetry::logging::log_tool_dispatch(
134                                    agent_name, &call.name, "function",
135                                );
136                                crate::telemetry::metrics::record_agent_tool_dispatched(
137                                    agent_name, &call.name,
138                                );
139
140                                let result = self
141                                    .dispatcher
142                                    .call_function(&call.name, call.args.clone())
143                                    .await;
144                                let elapsed = tool_start.elapsed();
145
146                                match &result {
147                                    Ok(value) => {
148                                        // Check for transfer signal
149                                        if let Some(target) =
150                                            value.get("__transfer_to").and_then(|v| v.as_str())
151                                        {
152                                            transfer_target = Some(target.to_string());
153                                        }
154
155                                        let _ = ctx.middleware.run_after_tool(call, value).await;
156                                        let _ = self.plugins.run_after_tool(call, value, ctx).await;
157                                        ctx.emit(AgentEvent::ToolCallCompleted {
158                                            name: call.name.clone(),
159                                            result: value.clone(),
160                                            duration: elapsed,
161                                        });
162                                        crate::telemetry::logging::log_tool_result(
163                                            agent_name,
164                                            &call.name,
165                                            true,
166                                            elapsed.as_millis() as f64,
167                                        );
168                                        crate::telemetry::metrics::record_agent_tool_duration(
169                                            agent_name,
170                                            &call.name,
171                                            elapsed.as_millis() as f64,
172                                        );
173                                    }
174                                    Err(e) => {
175                                        let _ = ctx.middleware.run_on_tool_error(call, e).await;
176                                        ctx.emit(AgentEvent::ToolCallFailed {
177                                            name: call.name.clone(),
178                                            error: e.to_string(),
179                                        });
180                                        crate::telemetry::logging::log_tool_result(
181                                            agent_name,
182                                            &call.name,
183                                            false,
184                                            elapsed.as_millis() as f64,
185                                        );
186                                    }
187                                }
188
189                                responses.push(ToolDispatcher::build_response(call, result));
190                            }
191                            Some(ToolClass::Streaming) | Some(ToolClass::InputStream) => {
192                                let class_str = if tool_class == Some(ToolClass::Streaming) {
193                                    "streaming"
194                                } else {
195                                    "input_stream"
196                                };
197                                crate::telemetry::logging::log_tool_dispatch(
198                                    agent_name, &call.name, class_str,
199                                );
200
201                                self.spawn_streaming_tool(call, ctx, agent_name).await;
202
203                                responses.push(FunctionResponse {
204                                    name: call.name.clone(),
205                                    response: json!({"status": "streaming"}),
206                                    id: call.id.clone(),
207                                    scheduling: None,
208                                });
209                            }
210                            None => {
211                                ctx.emit(AgentEvent::ToolCallFailed {
212                                    name: call.name.clone(),
213                                    error: format!("Tool not found: {}", call.name),
214                                });
215                                responses.push(ToolDispatcher::build_response(
216                                    call,
217                                    Err(ToolError::NotFound(call.name.clone())),
218                                ));
219                            }
220                        }
221                    }
222
223                    // Send all responses back to Gemini
224                    ctx.agent_session.send_tool_response(responses).await?;
225
226                    // Handle transfer AFTER sending response
227                    if let Some(target) = transfer_target {
228                        ctx.emit(AgentEvent::AgentTransfer {
229                            from: agent_name.to_string(),
230                            to: target.clone(),
231                        });
232                        crate::telemetry::metrics::record_agent_transfer(agent_name, &target);
233                        crate::telemetry::logging::log_agent_transfer(agent_name, &target);
234                        return Err(AgentError::TransferRequested(target));
235                    }
236                }
237                SessionEvent::ToolCallCancelled(ids) => {
238                    self.dispatcher.cancel_by_ids(&ids).await;
239                }
240                SessionEvent::TurnComplete => {
241                    ctx.emit(AgentEvent::Session(SessionEvent::TurnComplete));
242                    break;
243                }
244                SessionEvent::Disconnected(reason) => {
245                    ctx.emit(AgentEvent::Session(SessionEvent::Disconnected(reason)));
246                    break;
247                }
248                SessionEvent::Error(ref e) => {
249                    ctx.emit(AgentEvent::Session(event.clone()));
250                    crate::telemetry::metrics::record_agent_error(agent_name, "session_error");
251                    crate::telemetry::logging::log_agent_error(agent_name, e);
252                }
253                other => {
254                    // Pass through all other events (TextDelta, AudioData, etc.)
255                    ctx.emit(AgentEvent::Session(other));
256                }
257            }
258        }
259        Ok(())
260    }
261
262    /// Spawn a streaming or input-streaming tool as a background task.
263    async fn spawn_streaming_tool(
264        &self,
265        call: &gemini_genai_rs::prelude::FunctionCall,
266        ctx: &InvocationContext,
267        _agent_name: &str,
268    ) {
269        let tool_kind = match self.dispatcher.get_tool(&call.name) {
270            Some(kind) => kind,
271            None => return,
272        };
273
274        let (yield_tx, mut yield_rx) = tokio::sync::mpsc::channel::<serde_json::Value>(32);
275        let cancel = tokio_util::sync::CancellationToken::new();
276
277        let tool_name = call.name.clone();
278        let call_id = call.id.clone();
279        let args = call.args.clone();
280        let event_tx = ctx.event_tx.clone();
281        let agent_session = ctx.agent_session.clone();
282
283        match tool_kind {
284            ToolKind::Streaming(tool) => {
285                let tool = tool.clone();
286                let cancel_clone = cancel.clone();
287                let tool_name_err = tool_name.clone();
288                let event_tx_err = event_tx.clone();
289
290                let tool_task = tokio::spawn(async move {
291                    tokio::select! {
292                        result = tool.run(args, yield_tx) => {
293                            if let Err(e) = result {
294                                let _ = event_tx_err.send(AgentEvent::ToolCallFailed {
295                                    name: tool_name_err,
296                                    error: e.to_string(),
297                                });
298                            }
299                        }
300                        _ = cancel_clone.cancelled() => {}
301                    }
302                });
303
304                let active = ActiveStreamingTool {
305                    task: tool_task,
306                    cancel,
307                };
308                let id = call_id.clone().unwrap_or_else(|| tool_name.clone());
309                self.dispatcher.store_active(id, active).await;
310            }
311            ToolKind::InputStream(tool) => {
312                let tool = tool.clone();
313                let input_rx = ctx.agent_session.subscribe_input();
314                let cancel_clone = cancel.clone();
315                let tool_name_err = tool_name.clone();
316                let event_tx_err = event_tx.clone();
317
318                let tool_task = tokio::spawn(async move {
319                    tokio::select! {
320                        result = tool.run(args, input_rx, yield_tx) => {
321                            if let Err(e) = result {
322                                let _ = event_tx_err.send(AgentEvent::ToolCallFailed {
323                                    name: tool_name_err,
324                                    error: e.to_string(),
325                                });
326                            }
327                        }
328                        _ = cancel_clone.cancelled() => {}
329                    }
330                });
331
332                let active = ActiveStreamingTool {
333                    task: tool_task,
334                    cancel,
335                };
336                let id = call_id.clone().unwrap_or_else(|| tool_name.clone());
337                self.dispatcher.store_active(id, active).await;
338            }
339            ToolKind::Function(_) => {} // shouldn't reach here
340        }
341
342        // Spawn collector: reads yields and forwards as events + sends final FunctionResponse
343        let yield_tool_name = call.name.clone();
344        let yield_call_id = call.id.clone();
345
346        tokio::spawn(async move {
347            let mut all_yields = Vec::new();
348            while let Some(value) = yield_rx.recv().await {
349                let _ = event_tx.send(AgentEvent::StreamingToolYield {
350                    name: yield_tool_name.clone(),
351                    value: value.clone(),
352                });
353                all_yields.push(value);
354            }
355
356            // Send final response when tool completes
357            let final_response = if all_yields.is_empty() {
358                json!({"status": "completed"})
359            } else if all_yields.len() == 1 {
360                all_yields.into_iter().next().unwrap()
361            } else {
362                json!({"results": all_yields})
363            };
364
365            let resp = FunctionResponse {
366                name: yield_tool_name,
367                response: final_response,
368                id: yield_call_id,
369                scheduling: None,
370            };
371            let _ = agent_session.send_tool_response(vec![resp]).await;
372        });
373    }
374}
375
376/// Builder for LlmAgent -- fluent API for declaring tools, middleware, sub-agents.
377pub struct LlmAgentBuilder {
378    name: String,
379    dispatcher: ToolDispatcher,
380    middleware: MiddlewareChain,
381    plugins: PluginManager,
382    sub_agents: Vec<Arc<dyn Agent>>,
383}
384
385impl LlmAgentBuilder {
386    /// Register a regular function tool.
387    pub fn tool(mut self, tool: impl ToolFunction + 'static) -> Self {
388        self.dispatcher.register_function(Arc::new(tool));
389        self
390    }
391
392    /// Register a typed tool with auto-generated JSON Schema.
393    pub fn typed_tool<T>(mut self, tool: TypedTool<T>) -> Self
394    where
395        T: serde::de::DeserializeOwned + schemars::JsonSchema + Send + Sync + 'static,
396    {
397        self.dispatcher.register_function(Arc::new(tool));
398        self
399    }
400
401    /// Register a streaming tool.
402    pub fn streaming_tool(mut self, tool: impl StreamingTool + 'static) -> Self {
403        self.dispatcher.register_streaming(Arc::new(tool));
404        self
405    }
406
407    /// Register an input-streaming tool.
408    pub fn input_streaming_tool(mut self, tool: impl InputStreamingTool + 'static) -> Self {
409        self.dispatcher.register_input_streaming(Arc::new(tool));
410        self
411    }
412
413    /// Add middleware to the agent.
414    pub fn middleware(mut self, mw: impl crate::middleware::Middleware + 'static) -> Self {
415        self.middleware.add(Arc::new(mw));
416        self
417    }
418
419    /// Add a plugin to the agent.
420    pub fn plugin(mut self, plugin: impl crate::plugin::Plugin + 'static) -> Self {
421        self.plugins.add(Arc::new(plugin));
422        self
423    }
424
425    /// Register a sub-agent (enables transfer_to_{name} tool).
426    pub fn sub_agent(mut self, agent: impl Agent + 'static) -> Self {
427        self.sub_agents.push(Arc::new(agent));
428        self
429    }
430
431    /// Set the default timeout for tool execution.
432    pub fn tool_timeout(mut self, timeout: Duration) -> Self {
433        self.dispatcher = self.dispatcher.with_timeout(timeout);
434        self
435    }
436
437    /// Build the LlmAgent, freezing all tool declarations.
438    ///
439    /// This:
440    /// 1. Auto-registers `transfer_to_{name}` SimpleTool for each sub_agent
441    /// 2. Prepends TelemetryMiddleware
442    /// 3. Returns the frozen LlmAgent
443    pub fn build(mut self) -> LlmAgent {
444        // Auto-register transfer tools for sub-agents
445        for sub in &self.sub_agents {
446            let target_name = sub.name().to_string();
447            let tool_name = format!("transfer_to_{}", target_name);
448            let transfer_tool = SimpleTool::new(
449                tool_name,
450                format!("Transfer conversation to the {} agent", target_name),
451                Some(json!({
452                    "type": "object",
453                    "properties": {},
454                })),
455                move |_args| {
456                    let name = target_name.clone();
457                    async move { Ok(json!({"__transfer_to": name})) }
458                },
459            );
460            self.dispatcher.register_function(Arc::new(transfer_tool));
461        }
462
463        // Prepend TelemetryMiddleware so it runs first
464        self.middleware
465            .prepend(Arc::new(crate::telemetry::TelemetryMiddleware::new(
466                &self.name,
467            )));
468
469        LlmAgent {
470            name: self.name,
471            dispatcher: self.dispatcher,
472            middleware: self.middleware,
473            plugins: self.plugins,
474            sub_agents: self.sub_agents,
475        }
476    }
477}
478
479#[async_trait]
480impl Agent for LlmAgent {
481    fn name(&self) -> &str {
482        &self.name
483    }
484
485    async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
486        let agent_name = self.name.clone();
487        let start = std::time::Instant::now();
488
489        // Telemetry + middleware + plugins
490        crate::telemetry::logging::log_agent_started(&agent_name, self.dispatcher.len());
491        crate::telemetry::metrics::record_agent_started(&agent_name);
492        ctx.middleware.run_before_agent(ctx).await?;
493
494        // Plugin before_agent hook
495        let plugin_result = self.plugins.run_before_agent(ctx).await;
496        if let PluginResult::Deny(reason) = plugin_result {
497            return Err(AgentError::Other(format!(
498                "Agent denied by plugin: {}",
499                reason
500            )));
501        }
502
503        ctx.emit(AgentEvent::AgentStarted {
504            name: agent_name.clone(),
505        });
506
507        let mut events = ctx.agent_session.subscribe_events();
508
509        let result = self.event_loop(ctx, &mut events, &agent_name).await;
510
511        // Cleanup
512        let elapsed = start.elapsed();
513        ctx.middleware.run_after_agent(ctx).await?;
514        let _ = self.plugins.run_after_agent(ctx).await;
515        ctx.emit(AgentEvent::AgentCompleted {
516            name: agent_name.clone(),
517        });
518        crate::telemetry::logging::log_agent_completed(&agent_name, elapsed.as_millis() as f64);
519        crate::telemetry::metrics::record_agent_completed(&agent_name, elapsed.as_millis() as f64);
520
521        result
522    }
523
524    fn tools(&self) -> Vec<Tool> {
525        self.dispatcher.to_tool_declarations()
526    }
527
528    fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
529        self.sub_agents.clone()
530    }
531}
532
533#[cfg(test)]
534mod tests {
535    use super::*;
536    use gemini_genai_rs::prelude::FunctionCall;
537    use gemini_genai_rs::session::{SessionError, SessionWriter};
538    use serde_json::json;
539
540    struct NoopAgent {
541        name: String,
542    }
543
544    #[async_trait]
545    impl Agent for NoopAgent {
546        fn name(&self) -> &str {
547            &self.name
548        }
549        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
550            Ok(())
551        }
552    }
553
554    /// Mock writer that accepts all commands without error.
555    struct MockWriter;
556
557    #[async_trait]
558    impl SessionWriter for MockWriter {
559        async fn send_audio(&self, _data: Vec<u8>) -> Result<(), SessionError> {
560            Ok(())
561        }
562        async fn send_text(&self, _text: String) -> Result<(), SessionError> {
563            Ok(())
564        }
565        async fn send_tool_response(
566            &self,
567            _responses: Vec<FunctionResponse>,
568        ) -> Result<(), SessionError> {
569            Ok(())
570        }
571        async fn send_client_content(
572            &self,
573            _turns: Vec<gemini_genai_rs::prelude::Content>,
574            _turn_complete: bool,
575        ) -> Result<(), SessionError> {
576            Ok(())
577        }
578        async fn send_video(&self, _jpeg_data: Vec<u8>) -> Result<(), SessionError> {
579            Ok(())
580        }
581        async fn update_instruction(&self, _instruction: String) -> Result<(), SessionError> {
582            Ok(())
583        }
584        async fn signal_activity_start(&self) -> Result<(), SessionError> {
585            Ok(())
586        }
587        async fn signal_activity_end(&self) -> Result<(), SessionError> {
588            Ok(())
589        }
590        async fn disconnect(&self) -> Result<(), SessionError> {
591            Ok(())
592        }
593    }
594
595    /// Create a mock AgentSession backed by MockWriter, returning the session
596    /// and the event sender so tests can inject SessionEvents.
597    fn mock_agent_session() -> (
598        crate::agent_session::AgentSession,
599        broadcast::Sender<SessionEvent>,
600    ) {
601        let (evt_tx, _) = broadcast::channel(64);
602        let writer: Arc<dyn SessionWriter> = Arc::new(MockWriter);
603        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx.clone());
604        (session, evt_tx)
605    }
606
607    #[test]
608    fn builder_creates_agent_with_name() {
609        let agent = LlmAgent::builder("test_agent").build();
610        assert_eq!(agent.name(), "test_agent");
611    }
612
613    #[test]
614    fn builder_registers_tools() {
615        let tool = SimpleTool::new("my_tool", "desc", None, |_| async { Ok(json!({})) });
616        let agent = LlmAgent::builder("test").tool(tool).build();
617        // my_tool is the only user tool (TelemetryMiddleware doesn't add tools)
618        assert_eq!(agent.dispatcher().len(), 1);
619    }
620
621    #[test]
622    fn builder_auto_registers_transfer_tools() {
623        let sub = NoopAgent {
624            name: "billing".to_string(),
625        };
626        let agent = LlmAgent::builder("root").sub_agent(sub).build();
627
628        // Should have transfer_to_billing auto-registered
629        assert!(agent.dispatcher().classify("transfer_to_billing").is_some());
630    }
631
632    #[test]
633    fn builder_with_multiple_sub_agents() {
634        let sub1 = NoopAgent {
635            name: "billing".to_string(),
636        };
637        let sub2 = NoopAgent {
638            name: "tech".to_string(),
639        };
640        let agent = LlmAgent::builder("root")
641            .sub_agent(sub1)
642            .sub_agent(sub2)
643            .build();
644
645        assert!(agent.dispatcher().classify("transfer_to_billing").is_some());
646        assert!(agent.dispatcher().classify("transfer_to_tech").is_some());
647        assert_eq!(agent.sub_agents().len(), 2);
648    }
649
650    #[test]
651    fn tools_returns_declarations() {
652        let tool = SimpleTool::new("my_tool", "desc", None, |_| async { Ok(json!({})) });
653        let agent = LlmAgent::builder("test").tool(tool).build();
654        let tools = agent.tools();
655        assert!(!tools.is_empty());
656    }
657
658    #[test]
659    fn transfer_requested_error() {
660        let err = AgentError::TransferRequested("billing".to_string());
661        assert!(err.to_string().contains("billing"));
662    }
663
664    #[test]
665    fn builder_prepends_telemetry_middleware() {
666        let agent = LlmAgent::builder("test").build();
667        // TelemetryMiddleware is auto-prepended
668        assert_eq!(agent.middleware().len(), 1);
669    }
670
671    #[test]
672    fn builder_with_user_middleware_and_telemetry() {
673        use crate::middleware::LogMiddleware;
674
675        let agent = LlmAgent::builder("test")
676            .middleware(LogMiddleware::new())
677            .build();
678        // TelemetryMiddleware (prepended) + LogMiddleware (user-added)
679        assert_eq!(agent.middleware().len(), 2);
680    }
681
682    #[test]
683    fn get_tool_returns_tool_kind() {
684        let tool = SimpleTool::new("lookup", "desc", None, |_| async { Ok(json!({})) });
685        let agent = LlmAgent::builder("test").tool(tool).build();
686        assert!(agent.dispatcher().get_tool("lookup").is_some());
687        assert!(agent.dispatcher().get_tool("nonexistent").is_none());
688    }
689
690    // ── Event loop tests ──────────────────────────────────────────────────
691
692    #[tokio::test]
693    async fn event_loop_breaks_on_turn_complete() {
694        let agent = LlmAgent::builder("test").build();
695        let (session, evt_tx) = mock_agent_session();
696        let mut ctx = InvocationContext::new(session);
697
698        // Send TurnComplete after a short delay
699        tokio::spawn(async move {
700            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
701            let _ = evt_tx.send(SessionEvent::TurnComplete);
702        });
703
704        let result = agent.run_live(&mut ctx).await;
705        assert!(result.is_ok());
706    }
707
708    #[tokio::test]
709    async fn event_loop_breaks_on_disconnect() {
710        let agent = LlmAgent::builder("test").build();
711        let (session, evt_tx) = mock_agent_session();
712        let mut ctx = InvocationContext::new(session);
713
714        tokio::spawn(async move {
715            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
716            let _ = evt_tx.send(SessionEvent::Disconnected(Some("bye".to_string())));
717        });
718
719        let result = agent.run_live(&mut ctx).await;
720        assert!(result.is_ok());
721    }
722
723    #[tokio::test]
724    async fn event_loop_dispatches_tool_call() {
725        let tool = SimpleTool::new("get_weather", "Get weather", None, |_| async {
726            Ok(json!({"temp": 22}))
727        });
728        let agent = LlmAgent::builder("test").tool(tool).build();
729        let (session, evt_tx) = mock_agent_session();
730        let mut ctx = InvocationContext::new(session);
731        let mut agent_events = ctx.subscribe();
732
733        tokio::spawn(async move {
734            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
735            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
736                name: "get_weather".to_string(),
737                args: json!({"city": "London"}),
738                id: Some("call-1".to_string()),
739            }]));
740            // The tool response will be sent back; then end the turn.
741            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
742            let _ = evt_tx.send(SessionEvent::TurnComplete);
743        });
744
745        let result = agent.run_live(&mut ctx).await;
746        assert!(result.is_ok());
747
748        // Check that we got ToolCallStarted and ToolCallCompleted events
749        let mut saw_tool_started = false;
750        let mut saw_tool_completed = false;
751        while let Ok(event) = agent_events.try_recv() {
752            match event {
753                AgentEvent::ToolCallStarted { name, .. } if name == "get_weather" => {
754                    saw_tool_started = true;
755                }
756                AgentEvent::ToolCallCompleted { name, result, .. } if name == "get_weather" => {
757                    assert_eq!(result["temp"], 22);
758                    saw_tool_completed = true;
759                }
760                _ => {}
761            }
762        }
763        assert!(saw_tool_started, "should have emitted ToolCallStarted");
764        assert!(saw_tool_completed, "should have emitted ToolCallCompleted");
765    }
766
767    #[tokio::test]
768    async fn event_loop_handles_unknown_tool() {
769        let agent = LlmAgent::builder("test").build();
770        let (session, evt_tx) = mock_agent_session();
771        let mut ctx = InvocationContext::new(session);
772        let mut agent_events = ctx.subscribe();
773
774        tokio::spawn(async move {
775            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
776            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
777                name: "nonexistent_tool".to_string(),
778                args: json!({}),
779                id: Some("call-1".to_string()),
780            }]));
781            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
782            let _ = evt_tx.send(SessionEvent::TurnComplete);
783        });
784
785        let result = agent.run_live(&mut ctx).await;
786        assert!(result.is_ok());
787
788        // Check that we got a ToolCallFailed event
789        let mut saw_tool_failed = false;
790        while let Ok(event) = agent_events.try_recv() {
791            if let AgentEvent::ToolCallFailed { name, error } = event {
792                if name == "nonexistent_tool" {
793                    assert!(error.contains("not found") || error.contains("Not found"));
794                    saw_tool_failed = true;
795                }
796            }
797        }
798        assert!(
799            saw_tool_failed,
800            "should have emitted ToolCallFailed for unknown tool"
801        );
802    }
803
804    #[tokio::test]
805    async fn event_loop_detects_transfer() {
806        let sub = NoopAgent {
807            name: "billing".to_string(),
808        };
809        let agent = LlmAgent::builder("root").sub_agent(sub).build();
810
811        let (session, evt_tx) = mock_agent_session();
812        let mut ctx = InvocationContext::new(session);
813        let mut agent_events = ctx.subscribe();
814
815        tokio::spawn(async move {
816            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
817            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
818                name: "transfer_to_billing".to_string(),
819                args: json!({}),
820                id: Some("call-1".to_string()),
821            }]));
822        });
823
824        let result = agent.run_live(&mut ctx).await;
825        match result {
826            Err(AgentError::TransferRequested(target)) => assert_eq!(target, "billing"),
827            other => panic!("expected TransferRequested, got: {:?}", other),
828        }
829
830        // Check that AgentTransfer event was emitted
831        let mut saw_transfer = false;
832        while let Ok(event) = agent_events.try_recv() {
833            if let AgentEvent::AgentTransfer { from, to } = event {
834                assert_eq!(from, "root");
835                assert_eq!(to, "billing");
836                saw_transfer = true;
837            }
838        }
839        assert!(saw_transfer, "should have emitted AgentTransfer event");
840    }
841
842    #[tokio::test]
843    async fn event_loop_passes_through_events() {
844        let agent = LlmAgent::builder("test").build();
845        let (session, evt_tx) = mock_agent_session();
846        let mut ctx = InvocationContext::new(session);
847        let mut agent_events = ctx.subscribe();
848
849        tokio::spawn(async move {
850            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
851            let _ = evt_tx.send(SessionEvent::TextDelta("hello".to_string()));
852            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
853            let _ = evt_tx.send(SessionEvent::TurnComplete);
854        });
855
856        agent.run_live(&mut ctx).await.unwrap();
857
858        // Check that we got AgentStarted, TextDelta passthrough, TurnComplete, AgentCompleted
859        let mut saw_text_delta = false;
860        let mut saw_started = false;
861        let mut saw_completed = false;
862        while let Ok(event) = agent_events.try_recv() {
863            match event {
864                AgentEvent::AgentStarted { .. } => saw_started = true,
865                AgentEvent::AgentCompleted { .. } => saw_completed = true,
866                AgentEvent::Session(SessionEvent::TextDelta(t)) if t == "hello" => {
867                    saw_text_delta = true;
868                }
869                _ => {}
870            }
871        }
872        assert!(saw_started, "should have emitted AgentStarted");
873        assert!(saw_text_delta, "should have passed through TextDelta");
874        assert!(saw_completed, "should have emitted AgentCompleted");
875    }
876
877    #[tokio::test]
878    async fn event_loop_handles_error_event() {
879        let agent = LlmAgent::builder("test").build();
880        let (session, evt_tx) = mock_agent_session();
881        let mut ctx = InvocationContext::new(session);
882        let mut agent_events = ctx.subscribe();
883
884        tokio::spawn(async move {
885            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
886            let _ = evt_tx.send(SessionEvent::Error("something broke".to_string()));
887            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
888            let _ = evt_tx.send(SessionEvent::TurnComplete);
889        });
890
891        agent.run_live(&mut ctx).await.unwrap();
892
893        // Check that the error event was passed through
894        let mut saw_error = false;
895        while let Ok(event) = agent_events.try_recv() {
896            if let AgentEvent::Session(SessionEvent::Error(e)) = event {
897                assert_eq!(e, "something broke");
898                saw_error = true;
899            }
900        }
901        assert!(saw_error, "should have passed through Error event");
902    }
903
904    #[tokio::test]
905    async fn event_loop_emits_lifecycle_events() {
906        let agent = LlmAgent::builder("lifecycle_test").build();
907        let (session, evt_tx) = mock_agent_session();
908        let mut ctx = InvocationContext::new(session);
909        let mut agent_events = ctx.subscribe();
910
911        tokio::spawn(async move {
912            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
913            let _ = evt_tx.send(SessionEvent::TurnComplete);
914        });
915
916        agent.run_live(&mut ctx).await.unwrap();
917
918        let mut events = Vec::new();
919        while let Ok(event) = agent_events.try_recv() {
920            events.push(event);
921        }
922
923        // First event should be AgentStarted
924        assert!(
925            matches!(&events[0], AgentEvent::AgentStarted { name } if name == "lifecycle_test"),
926            "first event should be AgentStarted, got: {:?}",
927            events[0]
928        );
929
930        // Last event should be AgentCompleted
931        let last = events.last().unwrap();
932        assert!(
933            matches!(last, AgentEvent::AgentCompleted { name } if name == "lifecycle_test"),
934            "last event should be AgentCompleted, got: {:?}",
935            last
936        );
937    }
938
939    #[tokio::test]
940    async fn event_loop_tool_failure_emits_failed_event() {
941        let tool = SimpleTool::new("failing_tool", "Always fails", None, |_| async {
942            Err(ToolError::ExecutionFailed("kaboom".to_string()))
943        });
944        let agent = LlmAgent::builder("test").tool(tool).build();
945        let (session, evt_tx) = mock_agent_session();
946        let mut ctx = InvocationContext::new(session);
947        let mut agent_events = ctx.subscribe();
948
949        tokio::spawn(async move {
950            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
951            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
952                name: "failing_tool".to_string(),
953                args: json!({}),
954                id: Some("call-1".to_string()),
955            }]));
956            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
957            let _ = evt_tx.send(SessionEvent::TurnComplete);
958        });
959
960        agent.run_live(&mut ctx).await.unwrap();
961
962        let mut saw_tool_failed = false;
963        while let Ok(event) = agent_events.try_recv() {
964            if let AgentEvent::ToolCallFailed { name, error } = event {
965                if name == "failing_tool" {
966                    assert!(error.contains("kaboom"));
967                    saw_tool_failed = true;
968                }
969            }
970        }
971        assert!(saw_tool_failed, "should have emitted ToolCallFailed");
972    }
973}