1use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio_util::sync::CancellationToken;
7
8use gemini_genai_rs::prelude::{ConnectBuilder, SessionConfig, SessionPhase};
9use gemini_genai_rs::session::{SessionHandle, SessionWriter};
10
11use crate::error::AgentError;
12use crate::state::State;
13use crate::tool::ToolDispatcher;
14
15use super::background_tool::{BackgroundToolTracker, ToolExecutionMode};
16use super::callbacks::EventCallbacks;
17use super::computed::ComputedRegistry;
18use super::context_writer::{DeferredWriter, PendingContext};
19use super::extractor::TurnExtractor;
20use super::handle::LiveHandle;
21use super::needs::{NeedsFulfillment, RepairConfig};
22use super::persistence::SessionPersistence;
23use super::phase::PhaseMachine;
24use super::processor::{spawn_event_processor, spawn_telemetry_lane, ControlPlaneConfig};
25use super::session_signals::SessionSignals;
26use super::soft_turn::SoftTurnDetector;
27use super::steering::{ContextDelivery, SteeringMode};
28use super::telemetry::SessionTelemetry;
29use super::temporal::TemporalRegistry;
30use super::watcher::WatcherRegistry;
31
/// Builder for a live session with tools, state, phases and control-plane
/// features wired in.
///
/// Configure it with the chained setter methods, then call
/// [`LiveSessionBuilder::connect`] to open the session and spawn its lanes.
pub struct LiveSessionBuilder {
    /// Session configuration. The dispatcher's tool declarations are appended
    /// to it, and background tools are marked non-blocking when the plan is built.
    config: SessionConfig,
    /// Event callbacks; the usage callback is split off for the telemetry lane.
    callbacks: EventCallbacks,
    /// Tool dispatcher handed to the event processor, if any.
    dispatcher: Option<Arc<ToolDispatcher>>,
    /// Turn extractors, kept in registration order.
    extractors: Vec<Arc<dyn TurnExtractor>>,
    /// Computed-state registry; validated in `into_plan`.
    computed: Option<ComputedRegistry>,
    /// Phase machine; validated in `into_plan`.
    phase_machine: Option<PhaseMachine>,
    /// Watcher registry forwarded to the event processor.
    watchers: Option<WatcherRegistry>,
    /// Temporal registry forwarded to the event processor.
    temporal: Option<TemporalRegistry>,
    /// Optional text sent once the lanes are running.
    greeting: Option<String>,
    /// Initial state; `State::default()` is used when absent.
    state: Option<State>,
    /// Per-tool execution mode overrides, keyed by tool name.
    execution_modes: HashMap<String, ToolExecutionMode>,
    /// Soft-turn detection timeout; `None` disables soft-turn detection.
    soft_turn_timeout: Option<std::time::Duration>,
    /// Steering mode forwarded to the control plane.
    steering_mode: SteeringMode,
    /// Context delivery mode forwarded to the control plane.
    context_delivery: ContextDelivery,
    /// Delivery configuration forwarded to the control plane.
    delivery: super::processor::DeliveryConfig,
    /// Repair configuration; when set, needs fulfillment is enabled.
    repair_config: Option<RepairConfig>,
    /// Persistence backend forwarded to the control plane.
    persistence: Option<Arc<dyn SessionPersistence>>,
    /// Session identifier forwarded to the control plane.
    session_id: Option<String>,
    /// Tool advisory flag; defaults to `true`.
    tool_advisory: bool,
    /// Interval for publishing telemetry on the live event channel.
    telemetry_interval: Option<std::time::Duration>,
    /// Middleware layers, added to the chain in registration order.
    middleware: Vec<Arc<dyn crate::middleware::Middleware>>,
    /// Optional flow monitor; converted into a shared monitor at runtime build.
    flow: Option<crate::flow::FlowMonitor>,
}
66
67impl LiveSessionBuilder {
68 pub fn new(config: SessionConfig) -> Self {
70 Self {
71 config,
72 callbacks: EventCallbacks::default(),
73 dispatcher: None,
74 extractors: Vec::new(),
75 computed: None,
76 phase_machine: None,
77 watchers: None,
78 temporal: None,
79 greeting: None,
80 state: None,
81 execution_modes: HashMap::new(),
82 soft_turn_timeout: None,
83 steering_mode: SteeringMode::default(),
84 context_delivery: ContextDelivery::default(),
85 delivery: super::processor::DeliveryConfig::default(),
86 repair_config: None,
87 persistence: None,
88 session_id: None,
89 tool_advisory: true,
90 telemetry_interval: None,
91 middleware: Vec::new(),
92 flow: None,
93 }
94 }
95
96 pub fn middleware(mut self, layer: Arc<dyn crate::middleware::Middleware>) -> Self {
102 self.middleware.push(layer);
103 self
104 }
105
106 pub fn flow_monitor(mut self, monitor: crate::flow::FlowMonitor) -> Self {
108 self.flow = Some(monitor);
109 self
110 }
111
112 pub fn with_state(mut self, state: State) -> Self {
118 self.state = Some(state);
119 self
120 }
121
122 pub fn greeting(mut self, prompt: impl Into<String>) -> Self {
124 self.greeting = Some(prompt.into());
125 self
126 }
127
128 pub fn dispatcher(mut self, dispatcher: ToolDispatcher) -> Self {
130 for tool in dispatcher.to_tool_declarations() {
132 self.config = self.config.add_tool(tool);
133 }
134 self.dispatcher = Some(Arc::new(dispatcher));
135 self
136 }
137
138 pub fn callbacks(mut self, callbacks: EventCallbacks) -> Self {
140 self.callbacks = callbacks;
141 self
142 }
143
144 pub fn extractor(mut self, extractor: Arc<dyn TurnExtractor>) -> Self {
146 self.extractors.push(extractor);
147 self
148 }
149
150 pub fn computed(mut self, registry: ComputedRegistry) -> Self {
152 self.computed = Some(registry);
153 self
154 }
155
156 pub fn phase_machine(mut self, machine: PhaseMachine) -> Self {
158 self.phase_machine = Some(machine);
159 self
160 }
161
162 pub fn watchers(mut self, registry: WatcherRegistry) -> Self {
164 self.watchers = Some(registry);
165 self
166 }
167
168 pub fn temporal(mut self, registry: TemporalRegistry) -> Self {
170 self.temporal = Some(registry);
171 self
172 }
173
174 pub fn tool_execution_mode(
179 mut self,
180 tool_name: impl Into<String>,
181 mode: ToolExecutionMode,
182 ) -> Self {
183 self.execution_modes.insert(tool_name.into(), mode);
184 self
185 }
186
187 pub fn soft_turn_timeout(mut self, timeout: std::time::Duration) -> Self {
194 self.soft_turn_timeout = Some(timeout);
195 self
196 }
197
198 pub fn steering_mode(mut self, mode: SteeringMode) -> Self {
200 self.steering_mode = mode;
201 self
202 }
203
204 pub fn context_delivery(mut self, mode: ContextDelivery) -> Self {
209 self.context_delivery = mode;
210 self
211 }
212
213 pub fn delivery(mut self, delivery: super::processor::DeliveryConfig) -> Self {
220 self.delivery = delivery;
221 self
222 }
223
224 pub fn repair(mut self, config: RepairConfig) -> Self {
229 self.repair_config = Some(config);
230 self
231 }
232
233 pub fn persistence(mut self, backend: Arc<dyn SessionPersistence>) -> Self {
235 self.persistence = Some(backend);
236 self
237 }
238
239 pub fn session_id(mut self, id: impl Into<String>) -> Self {
241 self.session_id = Some(id.into());
242 self
243 }
244
245 pub fn tool_advisory(mut self, enabled: bool) -> Self {
247 self.tool_advisory = enabled;
248 self
249 }
250
251 pub fn telemetry_interval(mut self, interval: std::time::Duration) -> Self {
256 self.telemetry_interval = Some(interval);
257 self
258 }
259
260 pub async fn connect(self) -> Result<LiveHandle, AgentError> {
273 let mut plan = self.into_plan()?;
274
275 let config = plan.config.take().expect("plan always carries a config");
278 let session = ConnectBuilder::new(config)
279 .build()
280 .await
281 .map_err(AgentError::Session)?;
282
283 session.wait_for_phase(SessionPhase::Active).await;
285
286 let runtime = build_runtime(plan, session);
287 spawn_lanes(runtime).await
288 }
289
290 pub(crate) fn into_plan(self) -> Result<SessionPlan, AgentError> {
297 if let Some(ref pm) = self.phase_machine {
299 pm.validate().map_err(AgentError::Config)?;
300 }
301 if let Some(ref computed) = self.computed {
302 computed.validate().map_err(AgentError::Config)?;
303 }
304
305 let mut config = self.config;
307 for (tool_name, mode) in &self.execution_modes {
308 if matches!(
309 mode,
310 super::background_tool::ToolExecutionMode::Background { .. }
311 ) {
312 for tool in &mut config.tools {
313 if let Some(ref mut decls) = tool.function_declarations {
314 for decl in decls {
315 if decl.name == *tool_name {
316 decl.behavior = Some(
317 gemini_genai_rs::prelude::FunctionCallingBehavior::NonBlocking,
318 );
319 }
320 }
321 }
322 }
323 }
324 }
325
326 Ok(SessionPlan {
327 config: Some(config),
328 callbacks: self.callbacks,
329 dispatcher: self.dispatcher,
330 extractors: self.extractors,
331 computed: self.computed,
332 phase_machine: self.phase_machine,
333 watchers: self.watchers,
334 temporal: self.temporal,
335 greeting: self.greeting,
336 state: self.state,
337 execution_modes: self.execution_modes,
338 soft_turn_timeout: self.soft_turn_timeout,
339 steering_mode: self.steering_mode,
340 context_delivery: self.context_delivery,
341 delivery: self.delivery,
342 repair_config: self.repair_config,
343 persistence: self.persistence,
344 session_id: self.session_id,
345 tool_advisory: self.tool_advisory,
346 telemetry_interval: self.telemetry_interval,
347 middleware: self.middleware,
348 flow: self.flow,
349 })
350 }
351}
352
/// Validated, I/O-free snapshot of a [`LiveSessionBuilder`], produced by
/// `LiveSessionBuilder::into_plan`.
///
/// Keeping the plan separate from the connection lets validation and the
/// tool-declaration rewrite run, and be unit-tested, without opening a session.
/// Apart from `config`, the fields mirror the builder's fields one-to-one.
pub(crate) struct SessionPlan {
    /// Session config. It is `Some` when produced by `into_plan`; `connect`
    /// takes it out before the runtime is built.
    config: Option<SessionConfig>,
    callbacks: EventCallbacks,
    dispatcher: Option<Arc<ToolDispatcher>>,
    extractors: Vec<Arc<dyn TurnExtractor>>,
    computed: Option<ComputedRegistry>,
    phase_machine: Option<PhaseMachine>,
    watchers: Option<WatcherRegistry>,
    temporal: Option<TemporalRegistry>,
    greeting: Option<String>,
    state: Option<State>,
    execution_modes: HashMap<String, ToolExecutionMode>,
    soft_turn_timeout: Option<std::time::Duration>,
    steering_mode: SteeringMode,
    context_delivery: ContextDelivery,
    delivery: super::processor::DeliveryConfig,
    repair_config: Option<RepairConfig>,
    persistence: Option<Arc<dyn SessionPersistence>>,
    session_id: Option<String>,
    tool_advisory: bool,
    telemetry_interval: Option<std::time::Duration>,
    middleware: Vec<Arc<dyn crate::middleware::Middleware>>,
    flow: Option<crate::flow::FlowMonitor>,
}
386
/// Everything the processing lanes need once the session is connected.
///
/// Built by `build_runtime` from a [`SessionPlan`] and a [`SessionHandle`], and
/// consumed by `spawn_lanes`.
pub(crate) struct SessionRuntime {
    /// Connected session handle.
    session: SessionHandle,
    /// Shared callbacks (with `on_usage` already taken out).
    callbacks: Arc<EventCallbacks>,
    dispatcher: Option<Arc<ToolDispatcher>>,
    extractors: Vec<Arc<dyn TurnExtractor>>,
    computed: Option<ComputedRegistry>,
    /// Phase machine behind an async mutex for the event processor.
    phase_machine: Option<tokio::sync::Mutex<PhaseMachine>>,
    watchers: Option<WatcherRegistry>,
    temporal: Option<Arc<TemporalRegistry>>,
    greeting: Option<String>,
    /// Session state (the plan's state or `State::default()`).
    state: State,
    execution_modes: HashMap<String, ToolExecutionMode>,
    background_tracker: Arc<BackgroundToolTracker>,
    telemetry: Arc<SessionTelemetry>,
    telemetry_interval: Option<std::time::Duration>,
    /// Control-plane settings handed to the event processor.
    control_plane: ControlPlaneConfig,
    /// Present only when context delivery is `ContextDelivery::Deferred`.
    pending_context: Option<Arc<PendingContext>>,
    /// Raw session writer given to the event processor.
    writer: Arc<dyn SessionWriter>,
    /// Writer exposed through the handle and used for the greeting: a
    /// `DeferredWriter` when context delivery is deferred, else the raw writer.
    user_writer: Arc<dyn SessionWriter>,
    /// Session event subscription consumed by the event processor.
    event_rx: tokio::sync::broadcast::Receiver<gemini_genai_rs::prelude::SessionEvent>,
    /// Separate session event subscription consumed by the telemetry lane.
    telem_rx: tokio::sync::broadcast::Receiver<gemini_genai_rs::prelude::SessionEvent>,
    /// Usage callback split off from `callbacks` for the telemetry lane.
    on_usage_cb: Option<super::callbacks::UsageCallback>,
    /// Sender side of the live event broadcast channel.
    live_event_tx: tokio::sync::broadcast::Sender<super::events::LiveEvent>,
    /// Cancellation token shared by the telemetry lane and the handle.
    telem_cancel: CancellationToken,
    flow_monitor: Option<crate::flow::SharedFlowMonitor>,
}
422
423pub(crate) fn build_runtime(plan: SessionPlan, session: SessionHandle) -> SessionRuntime {
428 let flow_monitor = plan.flow.map(crate::flow::FlowMonitor::into_shared);
431 let mut callbacks = plan.callbacks;
432 let on_usage_cb = callbacks.on_usage.take();
433 let callbacks = Arc::new(callbacks);
434 let raw_writer: Arc<dyn SessionWriter> = Arc::new(session.clone());
435 let state = plan.state.unwrap_or_default();
436
437 let event_rx = session.subscribe();
439 let telem_rx = session.subscribe();
440
441 if let Some(ref pm) = plan.phase_machine {
443 let _ = state.session().set("phase", pm.current());
444 if let Some(phase) = pm.current_phase() {
445 if !phase.needs.is_empty() {
446 let _ = state.set("session:phase_needs", phase.needs.clone());
447 }
448 }
449 }
450
451 let phase_machine_mutex = plan.phase_machine.map(tokio::sync::Mutex::new);
452 let temporal_arc = plan.temporal.map(Arc::new);
453 let background_tracker = Arc::new(BackgroundToolTracker::new());
454
455 let telemetry = Arc::new(SessionTelemetry::new());
457 let telem_cancel = CancellationToken::new();
458
459 let mut control_plane = ControlPlaneConfig {
461 soft_turn: plan.soft_turn_timeout.map(SoftTurnDetector::new),
462 steering_mode: plan.steering_mode,
463 context_delivery: plan.context_delivery,
464 delivery: plan.delivery,
465 needs_fulfillment: plan.repair_config.map(NeedsFulfillment::new),
466 persistence: plan.persistence,
467 session_id: plan.session_id,
468 tool_advisory: plan.tool_advisory,
469 pending_context: None, middleware: {
471 let mut chain = crate::middleware::MiddlewareChain::new();
472 for layer in plan.middleware {
473 chain.add(layer);
474 }
475 Arc::new(chain)
476 },
477 flow: flow_monitor.clone(),
478 };
479
480 let pending_context = if plan.context_delivery == ContextDelivery::Deferred {
485 Some(Arc::new(PendingContext::new()))
486 } else {
487 None
488 };
489
490 let (writer, user_writer) = if let Some(ref pending) = pending_context {
492 let deferred: Arc<dyn SessionWriter> =
493 Arc::new(DeferredWriter::new(raw_writer.clone(), pending.clone()));
494 (raw_writer, deferred)
498 } else {
499 (raw_writer.clone(), raw_writer)
500 };
501
502 control_plane.pending_context = pending_context.clone();
504
505 use super::events::LiveEvent;
507 use tokio::sync::broadcast;
508 let (live_event_tx, _) = broadcast::channel::<LiveEvent>(4096);
509
510 SessionRuntime {
511 session,
512 callbacks,
513 dispatcher: plan.dispatcher,
514 extractors: plan.extractors,
515 computed: plan.computed,
516 phase_machine: phase_machine_mutex,
517 watchers: plan.watchers,
518 temporal: temporal_arc,
519 greeting: plan.greeting,
520 state,
521 execution_modes: plan.execution_modes,
522 background_tracker,
523 telemetry,
524 telemetry_interval: plan.telemetry_interval,
525 control_plane,
526 pending_context,
527 writer,
528 user_writer,
529 event_rx,
530 telem_rx,
531 on_usage_cb,
532 live_event_tx,
533 telem_cancel,
534 flow_monitor,
535 }
536}
537
538pub(crate) async fn spawn_lanes(rt: SessionRuntime) -> Result<LiveHandle, AgentError> {
541 use super::events::LiveEvent;
542
543 let session_signals = SessionSignals::new(rt.state.clone());
545 let _telem_handle = spawn_telemetry_lane(
546 rt.telem_rx,
547 session_signals,
548 rt.telemetry.clone(),
549 rt.telem_cancel.clone(),
550 rt.on_usage_cb,
551 );
552
553 let greeting_writer = rt.user_writer.clone();
555 let (fast_handle, ctrl_handle) = spawn_event_processor(
556 rt.event_rx,
557 rt.callbacks,
558 rt.dispatcher,
559 rt.writer,
560 rt.extractors,
561 rt.state.clone(),
562 rt.computed,
563 rt.phase_machine,
564 rt.watchers,
565 rt.temporal,
566 Some(rt.background_tracker.clone()),
567 rt.execution_modes,
568 rt.control_plane,
569 rt.live_event_tx.clone(),
570 );
571
572 if let Some(interval) = rt.telemetry_interval {
574 let telem_tx = rt.live_event_tx.clone();
575 let telem_ref = rt.telemetry.clone();
576 tokio::spawn(async move {
577 let mut tick = tokio::time::interval(interval);
578 let mut prev_turns = 0u64;
579 loop {
580 tick.tick().await;
581 let snap = telem_ref.snapshot();
582 if let Some(obj) = snap.as_object() {
583 let tc = obj
584 .get("turn_count")
585 .or_else(|| obj.get("response_count"))
586 .and_then(|v| v.as_u64())
587 .unwrap_or(0);
588 if tc > prev_turns {
589 let latency = obj
590 .get("last_response_latency_ms")
591 .and_then(|v| v.as_u64())
592 .unwrap_or(0) as u32;
593 let prompt = obj
594 .get("prompt_token_count")
595 .and_then(|v| v.as_u64())
596 .unwrap_or(0) as u32;
597 let response = obj
598 .get("response_token_count")
599 .and_then(|v| v.as_u64())
600 .unwrap_or(0) as u32;
601 let _ = telem_tx.send(LiveEvent::TurnMetrics {
602 turn: tc as u32,
603 latency_ms: latency,
604 prompt_tokens: prompt,
605 response_tokens: response,
606 });
607 prev_turns = tc;
608 }
609 }
610 if telem_tx.send(LiveEvent::Telemetry(snap)).is_err() {
611 break;
612 }
613 }
614 });
615 }
616
617 if let Some(greeting) = rt.greeting {
619 greeting_writer
620 .send_text(greeting)
621 .await
622 .map_err(AgentError::Session)?;
623 }
624
625 Ok(LiveHandle::new(
626 rt.session,
627 rt.user_writer,
628 fast_handle,
629 ctrl_handle,
630 rt.state,
631 rt.telemetry,
632 rt.live_event_tx,
633 rt.pending_context,
634 rt.flow_monitor,
635 rt.background_tracker,
636 rt.telem_cancel,
637 ))
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_creates_with_defaults() {
        let config = SessionConfig::new("test-key");
        let builder = LiveSessionBuilder::new(config);
        assert!(builder.dispatcher.is_none());
        assert!(builder.computed.is_none());
        assert!(builder.phase_machine.is_none());
        assert!(builder.watchers.is_none());
        assert!(builder.temporal.is_none());
    }

    #[test]
    fn into_plan_derives_defaults() {
        let config = SessionConfig::new("test-key");
        let plan = LiveSessionBuilder::new(config)
            .into_plan()
            .expect("default builder should produce a plan");

        assert!(plan.config.is_some());
        assert!(plan.dispatcher.is_none());
        assert!(plan.phase_machine.is_none());
        assert!(plan.persistence.is_none());
        assert!(plan.session_id.is_none());
        assert!(plan.greeting.is_none());
        assert!(plan.soft_turn_timeout.is_none());
        assert!(plan.telemetry_interval.is_none());
        assert!(plan.repair_config.is_none());
        assert!(plan.flow.is_none());
        assert!(plan.execution_modes.is_empty());
        assert!(plan.middleware.is_empty());
        assert_eq!(plan.steering_mode, SteeringMode::default());
        assert_eq!(plan.context_delivery, ContextDelivery::default());
        assert!(plan.tool_advisory);
    }

    // Only the session id is configured here; the persistence default is
    // covered by `into_plan_derives_defaults`.
    #[test]
    fn into_plan_carries_session_id() {
        let config = SessionConfig::new("test-key");
        let plan = LiveSessionBuilder::new(config)
            .session_id("user-123-session-456")
            .into_plan()
            .expect("plan derivation should succeed");

        assert_eq!(plan.session_id.as_deref(), Some("user-123-session-456"));
    }

    #[test]
    fn into_plan_carries_steering_and_context_delivery() {
        let config = SessionConfig::new("test-key");
        let plan = LiveSessionBuilder::new(config)
            .steering_mode(SteeringMode::ContextInjection)
            .context_delivery(ContextDelivery::Deferred)
            .tool_advisory(false)
            .into_plan()
            .expect("plan derivation should succeed");

        assert_eq!(plan.steering_mode, SteeringMode::ContextInjection);
        assert_eq!(plan.context_delivery, ContextDelivery::Deferred);
        assert!(!plan.tool_advisory);
    }

    #[test]
    fn into_plan_carries_greeting_and_telemetry_interval() {
        let config = SessionConfig::new("test-key");
        let plan = LiveSessionBuilder::new(config)
            .greeting("Hello there")
            .telemetry_interval(std::time::Duration::from_secs(5))
            .soft_turn_timeout(std::time::Duration::from_secs(2))
            .into_plan()
            .expect("plan derivation should succeed");

        assert_eq!(plan.greeting.as_deref(), Some("Hello there"));
        assert_eq!(
            plan.telemetry_interval,
            Some(std::time::Duration::from_secs(5))
        );
        assert_eq!(
            plan.soft_turn_timeout,
            Some(std::time::Duration::from_secs(2))
        );
    }

    #[test]
    fn into_plan_validates_phase_machine() {
        let config = SessionConfig::new("test-key");
        // The initial phase is never registered, which validation is expected
        // to reject.
        let pm = PhaseMachine::new("nonexistent");
        let result = LiveSessionBuilder::new(config)
            .phase_machine(pm)
            .into_plan();
        assert!(result.is_err(), "invalid phase machine should fail to plan");
    }

    #[test]
    fn into_plan_carries_valid_phase_machine_and_seeds_nothing() {
        let config = SessionConfig::new("test-key");
        let mut pm = PhaseMachine::new("start");
        pm.add_phase(crate::live::phase::Phase::new("start", "Start phase"));
        let plan = LiveSessionBuilder::new(config)
            .phase_machine(pm)
            .into_plan()
            .expect("valid phase machine should plan");

        assert!(plan.phase_machine.is_some());
    }

    #[test]
    fn into_plan_applies_non_blocking_to_background_tools() {
        use gemini_genai_rs::prelude::{FunctionCallingBehavior, FunctionDeclaration, Tool};

        let decl = FunctionDeclaration {
            name: "search_kb".into(),
            description: "Search".into(),
            parameters: None,
            behavior: None,
        };
        let config = SessionConfig::new("test-key").add_tool(Tool::functions(vec![decl]));

        let plan = LiveSessionBuilder::new(config)
            .tool_execution_mode(
                "search_kb",
                ToolExecutionMode::Background {
                    formatter: None,
                    scheduling: None,
                },
            )
            .into_plan()
            .expect("plan derivation should succeed");

        let cfg = plan.config.expect("config carried");
        let decl = cfg.tools[0]
            .function_declarations
            .as_ref()
            .unwrap()
            .iter()
            .find(|d| d.name == "search_kb")
            .unwrap();
        assert_eq!(decl.behavior, Some(FunctionCallingBehavior::NonBlocking));
    }

    #[test]
    fn into_plan_only_marks_background_tools_non_blocking() {
        use gemini_genai_rs::prelude::{FunctionCallingBehavior, FunctionDeclaration, Tool};

        let make_decl = |name: &str| FunctionDeclaration {
            name: name.into(),
            description: "desc".into(),
            parameters: None,
            behavior: None,
        };
        let config = SessionConfig::new("test-key").add_tool(Tool::functions(vec![
            make_decl("search_kb"),
            make_decl("lookup_order"),
        ]));

        let plan = LiveSessionBuilder::new(config)
            .tool_execution_mode(
                "search_kb",
                ToolExecutionMode::Background {
                    formatter: None,
                    scheduling: None,
                },
            )
            // A background mode for an undeclared tool must be ignored.
            .tool_execution_mode(
                "not_declared",
                ToolExecutionMode::Background {
                    formatter: None,
                    scheduling: None,
                },
            )
            .into_plan()
            .expect("plan derivation should succeed");

        let cfg = plan.config.expect("config carried");
        let decls = cfg.tools[0]
            .function_declarations
            .as_ref()
            .expect("declarations present");
        assert_eq!(decls.len(), 2);

        let search = decls
            .iter()
            .find(|d| d.name == "search_kb")
            .expect("search_kb declared");
        let lookup = decls
            .iter()
            .find(|d| d.name == "lookup_order")
            .expect("lookup_order declared");
        assert_eq!(search.behavior, Some(FunctionCallingBehavior::NonBlocking));
        assert!(lookup.behavior.is_none());
    }
}