1use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio_util::sync::CancellationToken;
7
8use gemini_genai_rs::prelude::{ConnectBuilder, SessionConfig, SessionPhase};
9use gemini_genai_rs::session::SessionWriter;
10
11use crate::error::AgentError;
12use crate::state::State;
13use crate::tool::ToolDispatcher;
14
15use super::background_tool::{BackgroundToolTracker, ToolExecutionMode};
16use super::callbacks::EventCallbacks;
17use super::computed::ComputedRegistry;
18use super::context_writer::{DeferredWriter, PendingContext};
19use super::extractor::TurnExtractor;
20use super::handle::LiveHandle;
21use super::needs::{NeedsFulfillment, RepairConfig};
22use super::persistence::SessionPersistence;
23use super::phase::PhaseMachine;
24use super::processor::{spawn_event_processor, spawn_telemetry_lane, ControlPlaneConfig};
25use super::session_signals::SessionSignals;
26use super::soft_turn::SoftTurnDetector;
27use super::steering::{ContextDelivery, SteeringMode};
28use super::telemetry::SessionTelemetry;
29use super::temporal::TemporalRegistry;
30use super::watcher::WatcherRegistry;
31
/// Builder for a live session.
///
/// Collects the session configuration plus the optional agent subsystems
/// (tools, extractors, registries, phase machine, persistence, telemetry, ...)
/// and opens the session in [`LiveSessionBuilder::connect`].
pub struct LiveSessionBuilder {
    /// Session configuration; `dispatcher()` appends its tool declarations here.
    config: SessionConfig,
    /// Event callbacks; `on_usage` is split off and given to the telemetry lane
    /// at connect time, the rest go to the event processor.
    callbacks: EventCallbacks,
    /// Tool dispatcher handed to the event processor, if any.
    dispatcher: Option<Arc<ToolDispatcher>>,
    /// Turn extractors handed to the event processor.
    extractors: Vec<Arc<dyn TurnExtractor>>,
    /// Computed-state registry; validated in `connect` before connecting.
    computed: Option<ComputedRegistry>,
    /// Phase machine; validated in `connect`, and its current phase is written
    /// into the session state before the processor starts.
    phase_machine: Option<PhaseMachine>,
    /// Watcher registry handed to the event processor.
    watchers: Option<WatcherRegistry>,
    /// Temporal registry, wrapped in an `Arc` and handed to the event processor.
    temporal: Option<TemporalRegistry>,
    /// Text sent through the user-facing writer once the processors are running.
    greeting: Option<String>,
    /// Initial state; `State::default()` is used when unset.
    state: Option<State>,
    /// Per-tool execution mode keyed by tool name. Tools in `Background` mode
    /// get their function declaration marked `NonBlocking` at connect time.
    execution_modes: HashMap<String, ToolExecutionMode>,
    /// When set, enables a `SoftTurnDetector` with this timeout.
    soft_turn_timeout: Option<std::time::Duration>,
    /// Steering mode forwarded to the control plane.
    steering_mode: SteeringMode,
    /// Context delivery mode; `Deferred` routes user-facing writes through a
    /// `DeferredWriter` backed by a shared `PendingContext`.
    context_delivery: ContextDelivery,
    /// When set, enables `NeedsFulfillment` with this repair configuration.
    repair_config: Option<RepairConfig>,
    /// Optional persistence backend forwarded to the control plane.
    persistence: Option<Arc<dyn SessionPersistence>>,
    /// Optional session identifier forwarded to the control plane.
    session_id: Option<String>,
    /// Tool-advisory flag forwarded to the control plane (`true` by default).
    tool_advisory: bool,
    /// When set, periodic telemetry events are published on the live event channel.
    telemetry_interval: Option<std::time::Duration>,
}
63
64impl LiveSessionBuilder {
65 pub fn new(config: SessionConfig) -> Self {
67 Self {
68 config,
69 callbacks: EventCallbacks::default(),
70 dispatcher: None,
71 extractors: Vec::new(),
72 computed: None,
73 phase_machine: None,
74 watchers: None,
75 temporal: None,
76 greeting: None,
77 state: None,
78 execution_modes: HashMap::new(),
79 soft_turn_timeout: None,
80 steering_mode: SteeringMode::default(),
81 context_delivery: ContextDelivery::default(),
82 repair_config: None,
83 persistence: None,
84 session_id: None,
85 tool_advisory: true,
86 telemetry_interval: None,
87 }
88 }
89
90 pub fn with_state(mut self, state: State) -> Self {
96 self.state = Some(state);
97 self
98 }
99
100 pub fn greeting(mut self, prompt: impl Into<String>) -> Self {
102 self.greeting = Some(prompt.into());
103 self
104 }
105
106 pub fn dispatcher(mut self, dispatcher: ToolDispatcher) -> Self {
108 for tool in dispatcher.to_tool_declarations() {
110 self.config = self.config.add_tool(tool);
111 }
112 self.dispatcher = Some(Arc::new(dispatcher));
113 self
114 }
115
116 pub fn callbacks(mut self, callbacks: EventCallbacks) -> Self {
118 self.callbacks = callbacks;
119 self
120 }
121
122 pub fn extractor(mut self, extractor: Arc<dyn TurnExtractor>) -> Self {
124 self.extractors.push(extractor);
125 self
126 }
127
128 pub fn computed(mut self, registry: ComputedRegistry) -> Self {
130 self.computed = Some(registry);
131 self
132 }
133
134 pub fn phase_machine(mut self, machine: PhaseMachine) -> Self {
136 self.phase_machine = Some(machine);
137 self
138 }
139
140 pub fn watchers(mut self, registry: WatcherRegistry) -> Self {
142 self.watchers = Some(registry);
143 self
144 }
145
146 pub fn temporal(mut self, registry: TemporalRegistry) -> Self {
148 self.temporal = Some(registry);
149 self
150 }
151
152 pub fn tool_execution_mode(
157 mut self,
158 tool_name: impl Into<String>,
159 mode: ToolExecutionMode,
160 ) -> Self {
161 self.execution_modes.insert(tool_name.into(), mode);
162 self
163 }
164
165 pub fn soft_turn_timeout(mut self, timeout: std::time::Duration) -> Self {
172 self.soft_turn_timeout = Some(timeout);
173 self
174 }
175
176 pub fn steering_mode(mut self, mode: SteeringMode) -> Self {
178 self.steering_mode = mode;
179 self
180 }
181
182 pub fn context_delivery(mut self, mode: ContextDelivery) -> Self {
187 self.context_delivery = mode;
188 self
189 }
190
191 pub fn repair(mut self, config: RepairConfig) -> Self {
196 self.repair_config = Some(config);
197 self
198 }
199
200 pub fn persistence(mut self, backend: Arc<dyn SessionPersistence>) -> Self {
202 self.persistence = Some(backend);
203 self
204 }
205
206 pub fn session_id(mut self, id: impl Into<String>) -> Self {
208 self.session_id = Some(id.into());
209 self
210 }
211
212 pub fn tool_advisory(mut self, enabled: bool) -> Self {
214 self.tool_advisory = enabled;
215 self
216 }
217
218 pub fn telemetry_interval(mut self, interval: std::time::Duration) -> Self {
223 self.telemetry_interval = Some(interval);
224 self
225 }
226
227 pub async fn connect(self) -> Result<LiveHandle, AgentError> {
229 if let Some(ref pm) = self.phase_machine {
231 pm.validate().map_err(AgentError::Config)?;
232 }
233 if let Some(ref computed) = self.computed {
234 computed.validate().map_err(AgentError::Config)?;
235 }
236
237 let mut config = self.config;
239 for (tool_name, mode) in &self.execution_modes {
240 if matches!(
241 mode,
242 super::background_tool::ToolExecutionMode::Background { .. }
243 ) {
244 for tool in &mut config.tools {
245 if let Some(ref mut decls) = tool.function_declarations {
246 for decl in decls {
247 if decl.name == *tool_name {
248 decl.behavior = Some(
249 gemini_genai_rs::prelude::FunctionCallingBehavior::NonBlocking,
250 );
251 }
252 }
253 }
254 }
255 }
256 }
257
258 let session = ConnectBuilder::new(config)
260 .build()
261 .await
262 .map_err(AgentError::Session)?;
263
264 session.wait_for_phase(SessionPhase::Active).await;
266
267 let mut callbacks = self.callbacks;
268 let on_usage_cb = callbacks.on_usage.take();
269 let callbacks = Arc::new(callbacks);
270 let raw_writer: Arc<dyn SessionWriter> = Arc::new(session.clone());
271 let state = self.state.unwrap_or_default();
272
273 let event_rx = session.subscribe();
275 let telem_rx = session.subscribe();
276
277 if let Some(ref pm) = self.phase_machine {
279 state.session().set("phase", pm.current());
280 if let Some(phase) = pm.current_phase() {
281 if !phase.needs.is_empty() {
282 state.set("session:phase_needs", phase.needs.clone());
283 }
284 }
285 }
286
287 let phase_machine_mutex = self.phase_machine.map(tokio::sync::Mutex::new);
288 let temporal_arc = self.temporal.map(Arc::new);
289 let background_tracker = Arc::new(BackgroundToolTracker::new());
290
291 let telemetry = Arc::new(SessionTelemetry::new());
293 let telem_cancel = CancellationToken::new();
294
295 let session_signals = SessionSignals::new(state.clone());
297 let _telem_handle = spawn_telemetry_lane(
298 telem_rx,
299 session_signals,
300 telemetry.clone(),
301 telem_cancel.clone(),
302 on_usage_cb,
303 );
304
305 let mut control_plane = ControlPlaneConfig {
307 soft_turn: self.soft_turn_timeout.map(SoftTurnDetector::new),
308 steering_mode: self.steering_mode,
309 context_delivery: self.context_delivery,
310 needs_fulfillment: self.repair_config.map(NeedsFulfillment::new),
311 persistence: self.persistence,
312 session_id: self.session_id,
313 tool_advisory: self.tool_advisory,
314 pending_context: None, };
316
317 let pending_context = if self.context_delivery == ContextDelivery::Deferred {
322 Some(Arc::new(PendingContext::new()))
323 } else {
324 None
325 };
326
327 let (writer, user_writer) = if let Some(ref pending) = pending_context {
329 let deferred: Arc<dyn SessionWriter> =
330 Arc::new(DeferredWriter::new(raw_writer.clone(), pending.clone()));
331 (raw_writer, deferred)
335 } else {
336 (raw_writer.clone(), raw_writer)
337 };
338
339 control_plane.pending_context = pending_context.clone();
341
342 use super::events::LiveEvent;
344 use tokio::sync::broadcast;
345 let (live_event_tx, _) = broadcast::channel::<LiveEvent>(4096);
346
347 let greeting_writer = user_writer.clone();
349 let (fast_handle, ctrl_handle) = spawn_event_processor(
350 event_rx,
351 callbacks,
352 self.dispatcher,
353 writer,
354 self.extractors,
355 state.clone(),
356 self.computed,
357 phase_machine_mutex,
358 self.watchers,
359 temporal_arc,
360 Some(background_tracker),
361 self.execution_modes,
362 control_plane,
363 live_event_tx.clone(),
364 );
365
366 if let Some(interval) = self.telemetry_interval {
368 let telem_tx = live_event_tx.clone();
369 let telem_ref = telemetry.clone();
370 tokio::spawn(async move {
371 let mut tick = tokio::time::interval(interval);
372 let mut prev_turns = 0u64;
373 loop {
374 tick.tick().await;
375 let snap = telem_ref.snapshot();
376 if let Some(obj) = snap.as_object() {
377 let tc = obj
378 .get("turn_count")
379 .or_else(|| obj.get("response_count"))
380 .and_then(|v| v.as_u64())
381 .unwrap_or(0);
382 if tc > prev_turns {
383 let latency = obj
384 .get("last_response_latency_ms")
385 .and_then(|v| v.as_u64())
386 .unwrap_or(0) as u32;
387 let prompt = obj
388 .get("prompt_token_count")
389 .and_then(|v| v.as_u64())
390 .unwrap_or(0) as u32;
391 let response = obj
392 .get("response_token_count")
393 .and_then(|v| v.as_u64())
394 .unwrap_or(0) as u32;
395 let _ = telem_tx.send(LiveEvent::TurnMetrics {
396 turn: tc as u32,
397 latency_ms: latency,
398 prompt_tokens: prompt,
399 response_tokens: response,
400 });
401 prev_turns = tc;
402 }
403 }
404 if telem_tx.send(LiveEvent::Telemetry(snap)).is_err() {
405 break;
406 }
407 }
408 });
409 }
410
411 if let Some(greeting) = self.greeting {
413 greeting_writer
414 .send_text(greeting)
415 .await
416 .map_err(AgentError::Session)?;
417 }
418
419 Ok(LiveHandle::new(
420 session,
421 user_writer,
422 fast_handle,
423 ctrl_handle,
424 state,
425 telemetry,
426 live_event_tx,
427 pending_context,
428 ))
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field set by `LiveSessionBuilder::new` has its documented default.
    #[test]
    fn builder_creates_with_defaults() {
        let config = SessionConfig::new("test-key");
        let builder = LiveSessionBuilder::new(config);
        assert!(builder.dispatcher.is_none());
        assert!(builder.extractors.is_empty());
        assert!(builder.computed.is_none());
        assert!(builder.phase_machine.is_none());
        assert!(builder.watchers.is_none());
        assert!(builder.temporal.is_none());
        assert!(builder.greeting.is_none());
        assert!(builder.state.is_none());
        assert!(builder.execution_modes.is_empty());
        assert!(builder.soft_turn_timeout.is_none());
        assert!(builder.context_delivery == ContextDelivery::default());
        assert!(builder.repair_config.is_none());
        assert!(builder.persistence.is_none());
        assert!(builder.session_id.is_none());
        // Tool advisory is the one flag that defaults to enabled.
        assert!(builder.tool_advisory);
        assert!(builder.telemetry_interval.is_none());
    }

    /// Chained setters record their values on the builder.
    #[test]
    fn setters_record_values() {
        let timeout = std::time::Duration::from_millis(250);
        let interval = std::time::Duration::from_secs(1);
        let builder = LiveSessionBuilder::new(SessionConfig::new("test-key"))
            .greeting("hello")
            .session_id("session-1")
            .soft_turn_timeout(timeout)
            .telemetry_interval(interval)
            .tool_advisory(false);
        assert_eq!(builder.greeting.as_deref(), Some("hello"));
        assert_eq!(builder.session_id.as_deref(), Some("session-1"));
        assert_eq!(builder.soft_turn_timeout, Some(timeout));
        assert_eq!(builder.telemetry_interval, Some(interval));
        assert!(!builder.tool_advisory);
    }
}