1use std::sync::Arc;
12use std::time::Duration;
13
14use async_trait::async_trait;
15use serde_json::json;
16use tokio::sync::broadcast;
17
18use gemini_genai_rs::prelude::{recv_event, FunctionResponse, Tool};
19use gemini_genai_rs::session::SessionEvent;
20
21use crate::agent::Agent;
22use crate::context::{AgentEvent, InvocationContext};
23use crate::error::{AgentError, ToolError};
24use crate::middleware::MiddlewareChain;
25use crate::plugin::{PluginManager, PluginResult};
26use crate::tool::{
27 ActiveStreamingTool, InputStreamingTool, SimpleTool, StreamingTool, ToolClass, ToolDispatcher,
28 ToolFunction, ToolKind, TypedTool,
29};
30
/// An agent that drives a live LLM session.
///
/// It forwards session events, dispatches the model's tool calls, and can
/// hand the conversation off to sub-agents via auto-generated
/// `transfer_to_<name>` tools.
///
/// Construct with [`LlmAgent::builder`].
pub struct LlmAgent {
    /// Agent name, reported in lifecycle events and telemetry.
    name: String,
    /// Registered tools (function, streaming, and input-streaming).
    dispatcher: ToolDispatcher,
    /// Middleware configured on this agent.
    middleware: MiddlewareChain,
    /// Plugins consulted around tool calls and around the agent run.
    plugins: PluginManager,
    /// Agents this agent can transfer the conversation to.
    sub_agents: Vec<Arc<dyn Agent>>,
}
43
44impl LlmAgent {
45 pub fn builder(name: impl Into<String>) -> LlmAgentBuilder {
47 LlmAgentBuilder {
48 name: name.into(),
49 dispatcher: ToolDispatcher::new(),
50 middleware: MiddlewareChain::new(),
51 plugins: PluginManager::new(),
52 sub_agents: Vec::new(),
53 }
54 }
55
56 pub fn dispatcher(&self) -> &ToolDispatcher {
58 &self.dispatcher
59 }
60
61 pub fn middleware(&self) -> &MiddlewareChain {
63 &self.middleware
64 }
65
66 pub fn plugins(&self) -> &PluginManager {
68 &self.plugins
69 }
70
71 async fn event_loop(
73 &self,
74 ctx: &mut InvocationContext,
75 events: &mut broadcast::Receiver<SessionEvent>,
76 agent_name: &str,
77 ) -> Result<(), AgentError> {
78 loop {
79 let event = match recv_event(events).await {
80 Some(e) => e,
81 None => break, };
83
84 match event {
85 SessionEvent::ToolCall(calls) => {
86 let mut responses = Vec::new();
87 let mut transfer_target = None;
88
89 for call in &calls {
90 ctx.emit(AgentEvent::ToolCallStarted {
92 name: call.name.clone(),
93 args: call.args.clone(),
94 });
95 let _ = ctx.middleware.run_before_tool(call).await;
96
97 let plugin_result = self.plugins.run_before_tool(call, ctx).await;
99 match &plugin_result {
100 PluginResult::Deny(reason) => {
101 ctx.emit(AgentEvent::ToolCallFailed {
102 name: call.name.clone(),
103 error: format!("Denied by plugin: {}", reason),
104 });
105 responses.push(ToolDispatcher::build_response(
106 call,
107 Err(ToolError::ExecutionFailed(format!(
108 "Denied by plugin: {}",
109 reason
110 ))),
111 ));
112 continue;
113 }
114 PluginResult::ShortCircuit(value) => {
115 let _ = ctx.middleware.run_after_tool(call, value).await;
116 ctx.emit(AgentEvent::ToolCallCompleted {
117 name: call.name.clone(),
118 result: value.clone(),
119 duration: std::time::Duration::ZERO,
120 });
121 responses
122 .push(ToolDispatcher::build_response(call, Ok(value.clone())));
123 continue;
124 }
125 PluginResult::Continue => {}
126 }
127
128 let tool_start = std::time::Instant::now();
129 let tool_class = self.dispatcher.classify(&call.name);
130
131 match tool_class {
132 Some(ToolClass::Regular) => {
133 crate::telemetry::logging::log_tool_dispatch(
134 agent_name, &call.name, "function",
135 );
136 crate::telemetry::metrics::record_agent_tool_dispatched(
137 agent_name, &call.name,
138 );
139
140 let result = self
141 .dispatcher
142 .call_function(&call.name, call.args.clone())
143 .await;
144 let elapsed = tool_start.elapsed();
145
146 match &result {
147 Ok(value) => {
148 if let Some(target) =
150 value.get("__transfer_to").and_then(|v| v.as_str())
151 {
152 transfer_target = Some(target.to_string());
153 }
154
155 let _ = ctx.middleware.run_after_tool(call, value).await;
156 let _ = self.plugins.run_after_tool(call, value, ctx).await;
157 ctx.emit(AgentEvent::ToolCallCompleted {
158 name: call.name.clone(),
159 result: value.clone(),
160 duration: elapsed,
161 });
162 crate::telemetry::logging::log_tool_result(
163 agent_name,
164 &call.name,
165 true,
166 elapsed.as_millis() as f64,
167 );
168 crate::telemetry::metrics::record_agent_tool_duration(
169 agent_name,
170 &call.name,
171 elapsed.as_millis() as f64,
172 );
173 }
174 Err(e) => {
175 let _ = ctx.middleware.run_on_tool_error(call, e).await;
176 ctx.emit(AgentEvent::ToolCallFailed {
177 name: call.name.clone(),
178 error: e.to_string(),
179 });
180 crate::telemetry::logging::log_tool_result(
181 agent_name,
182 &call.name,
183 false,
184 elapsed.as_millis() as f64,
185 );
186 }
187 }
188
189 responses.push(ToolDispatcher::build_response(call, result));
190 }
191 Some(ToolClass::Streaming) | Some(ToolClass::InputStream) => {
192 let class_str = if tool_class == Some(ToolClass::Streaming) {
193 "streaming"
194 } else {
195 "input_stream"
196 };
197 crate::telemetry::logging::log_tool_dispatch(
198 agent_name, &call.name, class_str,
199 );
200
201 self.spawn_streaming_tool(call, ctx, agent_name).await;
202
203 responses.push(FunctionResponse {
204 name: call.name.clone(),
205 response: json!({"status": "streaming"}),
206 id: call.id.clone(),
207 scheduling: None,
208 });
209 }
210 None => {
211 ctx.emit(AgentEvent::ToolCallFailed {
212 name: call.name.clone(),
213 error: format!("Tool not found: {}", call.name),
214 });
215 responses.push(ToolDispatcher::build_response(
216 call,
217 Err(ToolError::NotFound(call.name.clone())),
218 ));
219 }
220 }
221 }
222
223 ctx.agent_session.send_tool_response(responses).await?;
225
226 if let Some(target) = transfer_target {
228 ctx.emit(AgentEvent::AgentTransfer {
229 from: agent_name.to_string(),
230 to: target.clone(),
231 });
232 crate::telemetry::metrics::record_agent_transfer(agent_name, &target);
233 crate::telemetry::logging::log_agent_transfer(agent_name, &target);
234 return Err(AgentError::TransferRequested(target));
235 }
236 }
237 SessionEvent::ToolCallCancelled(ids) => {
238 self.dispatcher.cancel_by_ids(&ids).await;
239 }
240 SessionEvent::TurnComplete => {
241 ctx.emit(AgentEvent::Session(SessionEvent::TurnComplete));
242 break;
243 }
244 SessionEvent::Disconnected(reason) => {
245 ctx.emit(AgentEvent::Session(SessionEvent::Disconnected(reason)));
246 break;
247 }
248 SessionEvent::Error(ref e) => {
249 ctx.emit(AgentEvent::Session(event.clone()));
250 crate::telemetry::metrics::record_agent_error(agent_name, "session_error");
251 crate::telemetry::logging::log_agent_error(agent_name, e);
252 }
253 other => {
254 ctx.emit(AgentEvent::Session(other));
256 }
257 }
258 }
259 Ok(())
260 }
261
262 async fn spawn_streaming_tool(
264 &self,
265 call: &gemini_genai_rs::prelude::FunctionCall,
266 ctx: &InvocationContext,
267 _agent_name: &str,
268 ) {
269 let tool_kind = match self.dispatcher.get_tool(&call.name) {
270 Some(kind) => kind,
271 None => return,
272 };
273
274 let (yield_tx, mut yield_rx) = tokio::sync::mpsc::channel::<serde_json::Value>(32);
275 let cancel = tokio_util::sync::CancellationToken::new();
276
277 let tool_name = call.name.clone();
278 let call_id = call.id.clone();
279 let args = call.args.clone();
280 let event_tx = ctx.event_tx.clone();
281 let agent_session = ctx.agent_session.clone();
282
283 match tool_kind {
284 ToolKind::Streaming(tool) => {
285 let tool = tool.clone();
286 let cancel_clone = cancel.clone();
287 let tool_name_err = tool_name.clone();
288 let event_tx_err = event_tx.clone();
289
290 let tool_task = tokio::spawn(async move {
291 tokio::select! {
292 result = tool.run(args, yield_tx) => {
293 if let Err(e) = result {
294 let _ = event_tx_err.send(AgentEvent::ToolCallFailed {
295 name: tool_name_err,
296 error: e.to_string(),
297 });
298 }
299 }
300 _ = cancel_clone.cancelled() => {}
301 }
302 });
303
304 let active = ActiveStreamingTool {
305 task: tool_task,
306 cancel,
307 };
308 let id = call_id.clone().unwrap_or_else(|| tool_name.clone());
309 self.dispatcher.store_active(id, active).await;
310 }
311 ToolKind::InputStream(tool) => {
312 let tool = tool.clone();
313 let input_rx = ctx.agent_session.subscribe_input();
314 let cancel_clone = cancel.clone();
315 let tool_name_err = tool_name.clone();
316 let event_tx_err = event_tx.clone();
317
318 let tool_task = tokio::spawn(async move {
319 tokio::select! {
320 result = tool.run(args, input_rx, yield_tx) => {
321 if let Err(e) = result {
322 let _ = event_tx_err.send(AgentEvent::ToolCallFailed {
323 name: tool_name_err,
324 error: e.to_string(),
325 });
326 }
327 }
328 _ = cancel_clone.cancelled() => {}
329 }
330 });
331
332 let active = ActiveStreamingTool {
333 task: tool_task,
334 cancel,
335 };
336 let id = call_id.clone().unwrap_or_else(|| tool_name.clone());
337 self.dispatcher.store_active(id, active).await;
338 }
339 ToolKind::Function(_) => {} }
341
342 let yield_tool_name = call.name.clone();
344 let yield_call_id = call.id.clone();
345
346 tokio::spawn(async move {
347 let mut all_yields = Vec::new();
348 while let Some(value) = yield_rx.recv().await {
349 let _ = event_tx.send(AgentEvent::StreamingToolYield {
350 name: yield_tool_name.clone(),
351 value: value.clone(),
352 });
353 all_yields.push(value);
354 }
355
356 let final_response = if all_yields.is_empty() {
358 json!({"status": "completed"})
359 } else if all_yields.len() == 1 {
360 all_yields.into_iter().next().unwrap()
361 } else {
362 json!({"results": all_yields})
363 };
364
365 let resp = FunctionResponse {
366 name: yield_tool_name,
367 response: final_response,
368 id: yield_call_id,
369 scheduling: None,
370 };
371 let _ = agent_session.send_tool_response(vec![resp]).await;
372 });
373 }
374}
375
376pub struct LlmAgentBuilder {
378 name: String,
379 dispatcher: ToolDispatcher,
380 middleware: MiddlewareChain,
381 plugins: PluginManager,
382 sub_agents: Vec<Arc<dyn Agent>>,
383}
384
385impl LlmAgentBuilder {
386 pub fn tool(mut self, tool: impl ToolFunction + 'static) -> Self {
388 self.dispatcher.register_function(Arc::new(tool));
389 self
390 }
391
392 pub fn typed_tool<T>(mut self, tool: TypedTool<T>) -> Self
394 where
395 T: serde::de::DeserializeOwned + schemars::JsonSchema + Send + Sync + 'static,
396 {
397 self.dispatcher.register_function(Arc::new(tool));
398 self
399 }
400
401 pub fn streaming_tool(mut self, tool: impl StreamingTool + 'static) -> Self {
403 self.dispatcher.register_streaming(Arc::new(tool));
404 self
405 }
406
407 pub fn input_streaming_tool(mut self, tool: impl InputStreamingTool + 'static) -> Self {
409 self.dispatcher.register_input_streaming(Arc::new(tool));
410 self
411 }
412
413 pub fn middleware(mut self, mw: impl crate::middleware::Middleware + 'static) -> Self {
415 self.middleware.add(Arc::new(mw));
416 self
417 }
418
419 pub fn plugin(mut self, plugin: impl crate::plugin::Plugin + 'static) -> Self {
421 self.plugins.add(Arc::new(plugin));
422 self
423 }
424
425 pub fn sub_agent(mut self, agent: impl Agent + 'static) -> Self {
427 self.sub_agents.push(Arc::new(agent));
428 self
429 }
430
431 pub fn tool_timeout(mut self, timeout: Duration) -> Self {
433 self.dispatcher = self.dispatcher.with_timeout(timeout);
434 self
435 }
436
437 pub fn build(mut self) -> LlmAgent {
444 for sub in &self.sub_agents {
446 let target_name = sub.name().to_string();
447 let tool_name = format!("transfer_to_{}", target_name);
448 let transfer_tool = SimpleTool::new(
449 tool_name,
450 format!("Transfer conversation to the {} agent", target_name),
451 Some(json!({
452 "type": "object",
453 "properties": {},
454 })),
455 move |_args| {
456 let name = target_name.clone();
457 async move { Ok(json!({"__transfer_to": name})) }
458 },
459 );
460 self.dispatcher.register_function(Arc::new(transfer_tool));
461 }
462
463 self.middleware
465 .prepend(Arc::new(crate::telemetry::TelemetryMiddleware::new(
466 &self.name,
467 )));
468
469 LlmAgent {
470 name: self.name,
471 dispatcher: self.dispatcher,
472 middleware: self.middleware,
473 plugins: self.plugins,
474 sub_agents: self.sub_agents,
475 }
476 }
477}
478
479#[async_trait]
480impl Agent for LlmAgent {
481 fn name(&self) -> &str {
482 &self.name
483 }
484
485 async fn run_live(&self, ctx: &mut InvocationContext) -> Result<(), AgentError> {
486 let agent_name = self.name.clone();
487 let start = std::time::Instant::now();
488
489 crate::telemetry::logging::log_agent_started(&agent_name, self.dispatcher.len());
491 crate::telemetry::metrics::record_agent_started(&agent_name);
492 ctx.middleware.run_before_agent(ctx).await?;
493
494 let plugin_result = self.plugins.run_before_agent(ctx).await;
496 if let PluginResult::Deny(reason) = plugin_result {
497 return Err(AgentError::Other(format!(
498 "Agent denied by plugin: {}",
499 reason
500 )));
501 }
502
503 ctx.emit(AgentEvent::AgentStarted {
504 name: agent_name.clone(),
505 });
506
507 let mut events = ctx.agent_session.subscribe_events();
508
509 let result = self.event_loop(ctx, &mut events, &agent_name).await;
510
511 let elapsed = start.elapsed();
513 ctx.middleware.run_after_agent(ctx).await?;
514 let _ = self.plugins.run_after_agent(ctx).await;
515 ctx.emit(AgentEvent::AgentCompleted {
516 name: agent_name.clone(),
517 });
518 crate::telemetry::logging::log_agent_completed(&agent_name, elapsed.as_millis() as f64);
519 crate::telemetry::metrics::record_agent_completed(&agent_name, elapsed.as_millis() as f64);
520
521 result
522 }
523
524 fn tools(&self) -> Vec<Tool> {
525 self.dispatcher.to_tool_declarations()
526 }
527
528 fn sub_agents(&self) -> Vec<Arc<dyn Agent>> {
529 self.sub_agents.clone()
530 }
531}
532
#[cfg(test)]
mod tests {
    // Unit tests for the builder and async tests that drive `run_live`
    // against a mock session. The async tests inject `SessionEvent`s through a
    // broadcast sender after a short sleep. The sleep presumably gives
    // `run_live` time to call `subscribe_events` first, because a broadcast
    // receiver only sees values sent after it subscribed.
    use super::*;
    use gemini_genai_rs::prelude::FunctionCall;
    use gemini_genai_rs::session::{SessionError, SessionWriter};
    use serde_json::json;

    /// Sub-agent stand-in. Only its name matters, because it determines the
    /// generated `transfer_to_<name>` tool; `run_live` does nothing.
    struct NoopAgent {
        name: String,
    }

    #[async_trait]
    impl Agent for NoopAgent {
        fn name(&self) -> &str {
            &self.name
        }
        async fn run_live(&self, _ctx: &mut InvocationContext) -> Result<(), AgentError> {
            Ok(())
        }
    }

    /// `SessionWriter` whose methods all succeed without doing anything, so
    /// everything the agent sends (tool responses included) is accepted.
    struct MockWriter;

    #[async_trait]
    impl SessionWriter for MockWriter {
        async fn send_audio(&self, _data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }
        async fn send_text(&self, _text: String) -> Result<(), SessionError> {
            Ok(())
        }
        async fn send_tool_response(
            &self,
            _responses: Vec<FunctionResponse>,
        ) -> Result<(), SessionError> {
            Ok(())
        }
        async fn send_client_content(
            &self,
            _turns: Vec<gemini_genai_rs::prelude::Content>,
            _turn_complete: bool,
        ) -> Result<(), SessionError> {
            Ok(())
        }
        async fn send_video(&self, _jpeg_data: Vec<u8>) -> Result<(), SessionError> {
            Ok(())
        }
        async fn update_instruction(&self, _instruction: String) -> Result<(), SessionError> {
            Ok(())
        }
        async fn signal_activity_start(&self) -> Result<(), SessionError> {
            Ok(())
        }
        async fn signal_activity_end(&self) -> Result<(), SessionError> {
            Ok(())
        }
        async fn disconnect(&self) -> Result<(), SessionError> {
            Ok(())
        }
    }

    /// Build an `AgentSession` over `MockWriter`, returning it together with
    /// the sender of its event channel so a test can inject `SessionEvent`s.
    fn mock_agent_session() -> (
        crate::agent_session::AgentSession,
        broadcast::Sender<SessionEvent>,
    ) {
        // The initial receiver is dropped; the agent subscribes its own.
        let (evt_tx, _) = broadcast::channel(64);
        let writer: Arc<dyn SessionWriter> = Arc::new(MockWriter);
        let session = crate::agent_session::AgentSession::from_writer(writer, evt_tx.clone());
        (session, evt_tx)
    }

    #[test]
    fn builder_creates_agent_with_name() {
        let agent = LlmAgent::builder("test_agent").build();
        assert_eq!(agent.name(), "test_agent");
    }

    #[test]
    fn builder_registers_tools() {
        let tool = SimpleTool::new("my_tool", "desc", None, |_| async { Ok(json!({})) });
        let agent = LlmAgent::builder("test").tool(tool).build();
        assert_eq!(agent.dispatcher().len(), 1);
    }

    // Each sub-agent yields a `transfer_to_<name>` tool.
    #[test]
    fn builder_auto_registers_transfer_tools() {
        let sub = NoopAgent {
            name: "billing".to_string(),
        };
        let agent = LlmAgent::builder("root").sub_agent(sub).build();

        assert!(agent.dispatcher().classify("transfer_to_billing").is_some());
    }

    #[test]
    fn builder_with_multiple_sub_agents() {
        let sub1 = NoopAgent {
            name: "billing".to_string(),
        };
        let sub2 = NoopAgent {
            name: "tech".to_string(),
        };
        let agent = LlmAgent::builder("root")
            .sub_agent(sub1)
            .sub_agent(sub2)
            .build();

        assert!(agent.dispatcher().classify("transfer_to_billing").is_some());
        assert!(agent.dispatcher().classify("transfer_to_tech").is_some());
        assert_eq!(agent.sub_agents().len(), 2);
    }

    #[test]
    fn tools_returns_declarations() {
        let tool = SimpleTool::new("my_tool", "desc", None, |_| async { Ok(json!({})) });
        let agent = LlmAgent::builder("test").tool(tool).build();
        let tools = agent.tools();
        assert!(!tools.is_empty());
    }

    #[test]
    fn transfer_requested_error() {
        let err = AgentError::TransferRequested("billing".to_string());
        assert!(err.to_string().contains("billing"));
    }

    // With no user middleware, the chain holds only the telemetry middleware
    // that `build` prepends.
    #[test]
    fn builder_prepends_telemetry_middleware() {
        let agent = LlmAgent::builder("test").build();
        assert_eq!(agent.middleware().len(), 1);
    }

    #[test]
    fn builder_with_user_middleware_and_telemetry() {
        use crate::middleware::LogMiddleware;

        let agent = LlmAgent::builder("test")
            .middleware(LogMiddleware::new())
            .build();
        assert_eq!(agent.middleware().len(), 2);
    }

    #[test]
    fn get_tool_returns_tool_kind() {
        let tool = SimpleTool::new("lookup", "desc", None, |_| async { Ok(json!({})) });
        let agent = LlmAgent::builder("test").tool(tool).build();
        assert!(agent.dispatcher().get_tool("lookup").is_some());
        assert!(agent.dispatcher().get_tool("nonexistent").is_none());
    }

    #[tokio::test]
    async fn event_loop_breaks_on_turn_complete() {
        let agent = LlmAgent::builder("test").build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::TurnComplete);
        });

        let result = agent.run_live(&mut ctx).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn event_loop_breaks_on_disconnect() {
        let agent = LlmAgent::builder("test").build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::Disconnected(Some("bye".to_string())));
        });

        let result = agent.run_live(&mut ctx).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn event_loop_dispatches_tool_call() {
        let tool = SimpleTool::new("get_weather", "Get weather", None, |_| async {
            Ok(json!({"temp": 22}))
        });
        let agent = LlmAgent::builder("test").tool(tool).build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);
        let mut agent_events = ctx.subscribe();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
                name: "get_weather".to_string(),
                args: json!({"city": "London"}),
                id: Some("call-1".to_string()),
            }]));
            // Let the tool call finish before ending the turn.
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            let _ = evt_tx.send(SessionEvent::TurnComplete);
        });

        let result = agent.run_live(&mut ctx).await;
        assert!(result.is_ok());

        // Drain what was emitted during the run.
        let mut saw_tool_started = false;
        let mut saw_tool_completed = false;
        while let Ok(event) = agent_events.try_recv() {
            match event {
                AgentEvent::ToolCallStarted { name, .. } if name == "get_weather" => {
                    saw_tool_started = true;
                }
                AgentEvent::ToolCallCompleted { name, result, .. } if name == "get_weather" => {
                    assert_eq!(result["temp"], 22);
                    saw_tool_completed = true;
                }
                _ => {}
            }
        }
        assert!(saw_tool_started, "should have emitted ToolCallStarted");
        assert!(saw_tool_completed, "should have emitted ToolCallCompleted");
    }

    #[tokio::test]
    async fn event_loop_handles_unknown_tool() {
        let agent = LlmAgent::builder("test").build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);
        let mut agent_events = ctx.subscribe();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
                name: "nonexistent_tool".to_string(),
                args: json!({}),
                id: Some("call-1".to_string()),
            }]));
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            let _ = evt_tx.send(SessionEvent::TurnComplete);
        });

        let result = agent.run_live(&mut ctx).await;
        assert!(result.is_ok());

        let mut saw_tool_failed = false;
        while let Ok(event) = agent_events.try_recv() {
            if let AgentEvent::ToolCallFailed { name, error } = event {
                if name == "nonexistent_tool" {
                    assert!(error.contains("not found") || error.contains("Not found"));
                    saw_tool_failed = true;
                }
            }
        }
        assert!(
            saw_tool_failed,
            "should have emitted ToolCallFailed for unknown tool"
        );
    }

    // No TurnComplete is sent: the transfer itself ends `run_live`.
    #[tokio::test]
    async fn event_loop_detects_transfer() {
        let sub = NoopAgent {
            name: "billing".to_string(),
        };
        let agent = LlmAgent::builder("root").sub_agent(sub).build();

        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);
        let mut agent_events = ctx.subscribe();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
                name: "transfer_to_billing".to_string(),
                args: json!({}),
                id: Some("call-1".to_string()),
            }]));
        });

        let result = agent.run_live(&mut ctx).await;
        match result {
            Err(AgentError::TransferRequested(target)) => assert_eq!(target, "billing"),
            other => panic!("expected TransferRequested, got: {:?}", other),
        }

        let mut saw_transfer = false;
        while let Ok(event) = agent_events.try_recv() {
            if let AgentEvent::AgentTransfer { from, to } = event {
                assert_eq!(from, "root");
                assert_eq!(to, "billing");
                saw_transfer = true;
            }
        }
        assert!(saw_transfer, "should have emitted AgentTransfer event");
    }

    #[tokio::test]
    async fn event_loop_passes_through_events() {
        let agent = LlmAgent::builder("test").build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);
        let mut agent_events = ctx.subscribe();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::TextDelta("hello".to_string()));
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::TurnComplete);
        });

        agent.run_live(&mut ctx).await.unwrap();

        let mut saw_text_delta = false;
        let mut saw_started = false;
        let mut saw_completed = false;
        while let Ok(event) = agent_events.try_recv() {
            match event {
                AgentEvent::AgentStarted { .. } => saw_started = true,
                AgentEvent::AgentCompleted { .. } => saw_completed = true,
                AgentEvent::Session(SessionEvent::TextDelta(t)) if t == "hello" => {
                    saw_text_delta = true;
                }
                _ => {}
            }
        }
        assert!(saw_started, "should have emitted AgentStarted");
        assert!(saw_text_delta, "should have passed through TextDelta");
        assert!(saw_completed, "should have emitted AgentCompleted");
    }

    // A session Error is forwarded but does not end the loop; TurnComplete does.
    #[tokio::test]
    async fn event_loop_handles_error_event() {
        let agent = LlmAgent::builder("test").build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);
        let mut agent_events = ctx.subscribe();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::Error("something broke".to_string()));
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::TurnComplete);
        });

        agent.run_live(&mut ctx).await.unwrap();

        let mut saw_error = false;
        while let Ok(event) = agent_events.try_recv() {
            if let AgentEvent::Session(SessionEvent::Error(e)) = event {
                assert_eq!(e, "something broke");
                saw_error = true;
            }
        }
        assert!(saw_error, "should have passed through Error event");
    }

    // The run is bracketed by AgentStarted (first event) and AgentCompleted
    // (last event).
    #[tokio::test]
    async fn event_loop_emits_lifecycle_events() {
        let agent = LlmAgent::builder("lifecycle_test").build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);
        let mut agent_events = ctx.subscribe();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::TurnComplete);
        });

        agent.run_live(&mut ctx).await.unwrap();

        let mut events = Vec::new();
        while let Ok(event) = agent_events.try_recv() {
            events.push(event);
        }

        assert!(
            matches!(&events[0], AgentEvent::AgentStarted { name } if name == "lifecycle_test"),
            "first event should be AgentStarted, got: {:?}",
            events[0]
        );

        let last = events.last().unwrap();
        assert!(
            matches!(last, AgentEvent::AgentCompleted { name } if name == "lifecycle_test"),
            "last event should be AgentCompleted, got: {:?}",
            last
        );
    }

    #[tokio::test]
    async fn event_loop_tool_failure_emits_failed_event() {
        let tool = SimpleTool::new("failing_tool", "Always fails", None, |_| async {
            Err(ToolError::ExecutionFailed("kaboom".to_string()))
        });
        let agent = LlmAgent::builder("test").tool(tool).build();
        let (session, evt_tx) = mock_agent_session();
        let mut ctx = InvocationContext::new(session);
        let mut agent_events = ctx.subscribe();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let _ = evt_tx.send(SessionEvent::ToolCall(vec![FunctionCall {
                name: "failing_tool".to_string(),
                args: json!({}),
                id: Some("call-1".to_string()),
            }]));
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            let _ = evt_tx.send(SessionEvent::TurnComplete);
        });

        agent.run_live(&mut ctx).await.unwrap();

        let mut saw_tool_failed = false;
        while let Ok(event) = agent_events.try_recv() {
            if let AgentEvent::ToolCallFailed { name, error } = event {
                if name == "failing_tool" {
                    assert!(error.contains("kaboom"));
                    saw_tool_failed = true;
                }
            }
        }
        assert!(saw_tool_failed, "should have emitted ToolCallFailed");
    }
}