gemini_adk_rs/live/
builder.rs

1//! LiveSessionBuilder — combines SessionConfig + callbacks + tools into one setup.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio_util::sync::CancellationToken;
7
8use gemini_genai_rs::prelude::{ConnectBuilder, SessionConfig, SessionPhase};
9use gemini_genai_rs::session::SessionWriter;
10
11use crate::error::AgentError;
12use crate::state::State;
13use crate::tool::ToolDispatcher;
14
15use super::background_tool::{BackgroundToolTracker, ToolExecutionMode};
16use super::callbacks::EventCallbacks;
17use super::computed::ComputedRegistry;
18use super::context_writer::{DeferredWriter, PendingContext};
19use super::extractor::TurnExtractor;
20use super::handle::LiveHandle;
21use super::needs::{NeedsFulfillment, RepairConfig};
22use super::persistence::SessionPersistence;
23use super::phase::PhaseMachine;
24use super::processor::{spawn_event_processor, spawn_telemetry_lane, ControlPlaneConfig};
25use super::session_signals::SessionSignals;
26use super::soft_turn::SoftTurnDetector;
27use super::steering::{ContextDelivery, SteeringMode};
28use super::telemetry::SessionTelemetry;
29use super::temporal::TemporalRegistry;
30use super::watcher::WatcherRegistry;
31
32/// Builder for a callback-driven Live session.
33///
34/// Combines [`SessionConfig`], [`EventCallbacks`], tool dispatching, extractors,
35/// computed state, phase machines, watchers, and temporal patterns into a
36/// single connection setup. Call [`connect()`](Self::connect) to establish
37/// the WebSocket connection and start the three-lane event processor.
38///
39/// For ergonomic usage, prefer the L2 `Live` builder from `gemini-adk-fluent-rs`
40/// which wraps this with a fluent API.
41pub struct LiveSessionBuilder {
42    config: SessionConfig,
43    callbacks: EventCallbacks,
44    dispatcher: Option<Arc<ToolDispatcher>>,
45    extractors: Vec<Arc<dyn TurnExtractor>>,
46    computed: Option<ComputedRegistry>,
47    phase_machine: Option<PhaseMachine>,
48    watchers: Option<WatcherRegistry>,
49    temporal: Option<TemporalRegistry>,
50    greeting: Option<String>,
51    state: Option<State>,
52    execution_modes: HashMap<String, ToolExecutionMode>,
53    // Control plane configuration
54    soft_turn_timeout: Option<std::time::Duration>,
55    steering_mode: SteeringMode,
56    context_delivery: ContextDelivery,
57    repair_config: Option<RepairConfig>,
58    persistence: Option<Arc<dyn SessionPersistence>>,
59    session_id: Option<String>,
60    tool_advisory: bool,
61    telemetry_interval: Option<std::time::Duration>,
62}
63
64impl LiveSessionBuilder {
65    /// Create a new builder with the given session config.
66    pub fn new(config: SessionConfig) -> Self {
67        Self {
68            config,
69            callbacks: EventCallbacks::default(),
70            dispatcher: None,
71            extractors: Vec::new(),
72            computed: None,
73            phase_machine: None,
74            watchers: None,
75            temporal: None,
76            greeting: None,
77            state: None,
78            execution_modes: HashMap::new(),
79            soft_turn_timeout: None,
80            steering_mode: SteeringMode::default(),
81            context_delivery: ContextDelivery::default(),
82            repair_config: None,
83            persistence: None,
84            session_id: None,
85            tool_advisory: true,
86            telemetry_interval: None,
87        }
88    }
89
90    /// Provide a pre-created State to use for this session.
91    ///
92    /// If not set, a new State is created at connect time. Use this when
93    /// the State needs to be shared with tools or other components before
94    /// the session connects.
95    pub fn with_state(mut self, state: State) -> Self {
96        self.state = Some(state);
97        self
98    }
99
100    /// Set a greeting prompt sent on connect to trigger the model to speak first.
101    pub fn greeting(mut self, prompt: impl Into<String>) -> Self {
102        self.greeting = Some(prompt.into());
103        self
104    }
105
106    /// Set the tool dispatcher for auto-dispatch of tool calls.
107    pub fn dispatcher(mut self, dispatcher: ToolDispatcher) -> Self {
108        // Add tool declarations to session config
109        for tool in dispatcher.to_tool_declarations() {
110            self.config = self.config.add_tool(tool);
111        }
112        self.dispatcher = Some(Arc::new(dispatcher));
113        self
114    }
115
116    /// Set the event callbacks.
117    pub fn callbacks(mut self, callbacks: EventCallbacks) -> Self {
118        self.callbacks = callbacks;
119        self
120    }
121
122    /// Add a turn extractor that runs between turns.
123    pub fn extractor(mut self, extractor: Arc<dyn TurnExtractor>) -> Self {
124        self.extractors.push(extractor);
125        self
126    }
127
128    /// Set the computed variable registry for derived state.
129    pub fn computed(mut self, registry: ComputedRegistry) -> Self {
130        self.computed = Some(registry);
131        self
132    }
133
134    /// Set the phase machine for declarative conversation phase management.
135    pub fn phase_machine(mut self, machine: PhaseMachine) -> Self {
136        self.phase_machine = Some(machine);
137        self
138    }
139
140    /// Set the watcher registry for state change watchers.
141    pub fn watchers(mut self, registry: WatcherRegistry) -> Self {
142        self.watchers = Some(registry);
143        self
144    }
145
146    /// Set the temporal pattern registry.
147    pub fn temporal(mut self, registry: TemporalRegistry) -> Self {
148        self.temporal = Some(registry);
149        self
150    }
151
152    /// Set the execution mode for a named tool.
153    ///
154    /// Tools default to [`ToolExecutionMode::Standard`]. Set to
155    /// [`ToolExecutionMode::Background`] for zero-dead-air execution.
156    pub fn tool_execution_mode(
157        mut self,
158        tool_name: impl Into<String>,
159        mode: ToolExecutionMode,
160    ) -> Self {
161        self.execution_modes.insert(tool_name.into(), mode);
162        self
163    }
164
165    /// Enable soft turn detection for proactive silence awareness.
166    ///
167    /// When `proactiveAudio` is enabled, the model may choose not to respond.
168    /// This sets a timeout after VAD end — if the model stays silent, a
169    /// lightweight "soft turn" fires to keep state updated without forcing
170    /// the model to speak.
171    pub fn soft_turn_timeout(mut self, timeout: std::time::Duration) -> Self {
172        self.soft_turn_timeout = Some(timeout);
173        self
174    }
175
176    /// Set the steering mode for how the phase machine delivers instructions.
177    pub fn steering_mode(mut self, mode: SteeringMode) -> Self {
178        self.steering_mode = mode;
179        self
180    }
181
182    /// Set the context delivery timing.
183    ///
184    /// - `Immediate` (default): send batched context during TurnComplete.
185    /// - `Deferred`: queue context and flush with next user send.
186    pub fn context_delivery(mut self, mode: ContextDelivery) -> Self {
187        self.context_delivery = mode;
188        self
189    }
190
191    /// Enable the conversation repair protocol.
192    ///
193    /// Tracks need fulfillment per phase and nudges the model when the
194    /// conversation stalls on gathering required information.
195    pub fn repair(mut self, config: RepairConfig) -> Self {
196        self.repair_config = Some(config);
197        self
198    }
199
200    /// Set a session persistence backend for surviving process restarts.
201    pub fn persistence(mut self, backend: Arc<dyn SessionPersistence>) -> Self {
202        self.persistence = Some(backend);
203        self
204    }
205
206    /// Set the session ID for persistence.
207    pub fn session_id(mut self, id: impl Into<String>) -> Self {
208        self.session_id = Some(id.into());
209        self
210    }
211
212    /// Enable or disable tool availability advisory on phase transitions.
213    pub fn tool_advisory(mut self, enabled: bool) -> Self {
214        self.tool_advisory = enabled;
215        self
216    }
217
218    /// Set the periodic telemetry emission interval.
219    ///
220    /// When set, the processor periodically emits `LiveEvent::Telemetry`
221    /// and `LiveEvent::TurnMetrics` to the event stream.
222    pub fn telemetry_interval(mut self, interval: std::time::Duration) -> Self {
223        self.telemetry_interval = Some(interval);
224        self
225    }
226
227    /// Connect to Gemini and start the three-lane event processor.
228    pub async fn connect(self) -> Result<LiveHandle, AgentError> {
229        // Build-time validations
230        if let Some(ref pm) = self.phase_machine {
231            pm.validate().map_err(AgentError::Config)?;
232        }
233        if let Some(ref computed) = self.computed {
234            computed.validate().map_err(AgentError::Config)?;
235        }
236
237        // Apply NON_BLOCKING behavior to tool declarations for background tools
238        let mut config = self.config;
239        for (tool_name, mode) in &self.execution_modes {
240            if matches!(
241                mode,
242                super::background_tool::ToolExecutionMode::Background { .. }
243            ) {
244                for tool in &mut config.tools {
245                    if let Some(ref mut decls) = tool.function_declarations {
246                        for decl in decls {
247                            if decl.name == *tool_name {
248                                decl.behavior = Some(
249                                    gemini_genai_rs::prelude::FunctionCallingBehavior::NonBlocking,
250                                );
251                            }
252                        }
253                    }
254                }
255            }
256        }
257
258        // Connect via L0
259        let session = ConnectBuilder::new(config)
260            .build()
261            .await
262            .map_err(AgentError::Session)?;
263
264        // Wait for Active phase
265        session.wait_for_phase(SessionPhase::Active).await;
266
267        let mut callbacks = self.callbacks;
268        let on_usage_cb = callbacks.on_usage.take();
269        let callbacks = Arc::new(callbacks);
270        let raw_writer: Arc<dyn SessionWriter> = Arc::new(session.clone());
271        let state = self.state.unwrap_or_default();
272
273        // Subscribe twice: one for router → fast/ctrl, one for telemetry lane
274        let event_rx = session.subscribe();
275        let telem_rx = session.subscribe();
276
277        // Store initial phase's `needs` metadata for ContextBuilder.
278        if let Some(ref pm) = self.phase_machine {
279            state.session().set("phase", pm.current());
280            if let Some(phase) = pm.current_phase() {
281                if !phase.needs.is_empty() {
282                    state.set("session:phase_needs", phase.needs.clone());
283                }
284            }
285        }
286
287        let phase_machine_mutex = self.phase_machine.map(tokio::sync::Mutex::new);
288        let temporal_arc = self.temporal.map(Arc::new);
289        let background_tracker = Arc::new(BackgroundToolTracker::new());
290
291        // Create telemetry (auto-collected by the telemetry lane)
292        let telemetry = Arc::new(SessionTelemetry::new());
293        let telem_cancel = CancellationToken::new();
294
295        // Spawn telemetry lane (SessionSignals + SessionTelemetry on own broadcast rx)
296        let session_signals = SessionSignals::new(state.clone());
297        let _telem_handle = spawn_telemetry_lane(
298            telem_rx,
299            session_signals,
300            telemetry.clone(),
301            telem_cancel.clone(),
302            on_usage_cb,
303        );
304
305        // Build control plane config
306        let mut control_plane = ControlPlaneConfig {
307            soft_turn: self.soft_turn_timeout.map(SoftTurnDetector::new),
308            steering_mode: self.steering_mode,
309            context_delivery: self.context_delivery,
310            needs_fulfillment: self.repair_config.map(NeedsFulfillment::new),
311            persistence: self.persistence,
312            session_id: self.session_id,
313            tool_advisory: self.tool_advisory,
314            pending_context: None, // set after PendingContext is created below
315        };
316
317        // Create shared PendingContext for deferred delivery.
318        // The SAME Arc is given to both the DeferredWriter (which drains it before
319        // user sends) and the ControlPlaneConfig (which the processor uses to push
320        // context turns from the control lane).
321        let pending_context = if self.context_delivery == ContextDelivery::Deferred {
322            Some(Arc::new(PendingContext::new()))
323        } else {
324            None
325        };
326
327        // Wrap writer in DeferredWriter if deferred context delivery is enabled.
328        let (writer, user_writer) = if let Some(ref pending) = pending_context {
329            let deferred: Arc<dyn SessionWriter> =
330                Arc::new(DeferredWriter::new(raw_writer.clone(), pending.clone()));
331            // Processor uses raw_writer for internal sends (lifecycle context
332            // goes through PendingContext, not through the writer directly).
333            // User-facing LiveHandle uses the DeferredWriter.
334            (raw_writer, deferred)
335        } else {
336            (raw_writer.clone(), raw_writer)
337        };
338
339        // Pass shared pending context to control plane config
340        control_plane.pending_context = pending_context.clone();
341
342        // Create LiveEvent broadcast channel
343        use super::events::LiveEvent;
344        use tokio::sync::broadcast;
345        let (live_event_tx, _) = broadcast::channel::<LiveEvent>(4096);
346
347        // Spawn fast + control lanes (no session_signals, no transcript mutex)
348        let greeting_writer = user_writer.clone();
349        let (fast_handle, ctrl_handle) = spawn_event_processor(
350            event_rx,
351            callbacks,
352            self.dispatcher,
353            writer,
354            self.extractors,
355            state.clone(),
356            self.computed,
357            phase_machine_mutex,
358            self.watchers,
359            temporal_arc,
360            Some(background_tracker),
361            self.execution_modes,
362            control_plane,
363            live_event_tx.clone(),
364        );
365
366        // Spawn periodic telemetry emitter if interval is set
367        if let Some(interval) = self.telemetry_interval {
368            let telem_tx = live_event_tx.clone();
369            let telem_ref = telemetry.clone();
370            tokio::spawn(async move {
371                let mut tick = tokio::time::interval(interval);
372                let mut prev_turns = 0u64;
373                loop {
374                    tick.tick().await;
375                    let snap = telem_ref.snapshot();
376                    if let Some(obj) = snap.as_object() {
377                        let tc = obj
378                            .get("turn_count")
379                            .or_else(|| obj.get("response_count"))
380                            .and_then(|v| v.as_u64())
381                            .unwrap_or(0);
382                        if tc > prev_turns {
383                            let latency = obj
384                                .get("last_response_latency_ms")
385                                .and_then(|v| v.as_u64())
386                                .unwrap_or(0) as u32;
387                            let prompt = obj
388                                .get("prompt_token_count")
389                                .and_then(|v| v.as_u64())
390                                .unwrap_or(0) as u32;
391                            let response = obj
392                                .get("response_token_count")
393                                .and_then(|v| v.as_u64())
394                                .unwrap_or(0) as u32;
395                            let _ = telem_tx.send(LiveEvent::TurnMetrics {
396                                turn: tc as u32,
397                                latency_ms: latency,
398                                prompt_tokens: prompt,
399                                response_tokens: response,
400                            });
401                            prev_turns = tc;
402                        }
403                    }
404                    if telem_tx.send(LiveEvent::Telemetry(snap)).is_err() {
405                        break;
406                    }
407                }
408            });
409        }
410
411        // Send greeting prompt to trigger model-initiated conversation
412        if let Some(greeting) = self.greeting {
413            greeting_writer
414                .send_text(greeting)
415                .await
416                .map_err(AgentError::Session)?;
417        }
418
419        Ok(LiveHandle::new(
420            session,
421            user_writer,
422            fast_handle,
423            ctrl_handle,
424            state,
425            telemetry,
426            live_event_tx,
427            pending_context,
428        ))
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh builder over a dummy config; `new` performs no I/O.
    fn test_builder() -> LiveSessionBuilder {
        LiveSessionBuilder::new(SessionConfig::new("test-key"))
    }

    #[test]
    fn builder_creates_with_defaults() {
        let builder = test_builder();
        assert!(builder.dispatcher.is_none());
        assert!(builder.extractors.is_empty());
        assert!(builder.computed.is_none());
        assert!(builder.phase_machine.is_none());
        assert!(builder.watchers.is_none());
        assert!(builder.temporal.is_none());
        assert!(builder.greeting.is_none());
        assert!(builder.state.is_none());
        assert!(builder.execution_modes.is_empty());
        assert!(builder.soft_turn_timeout.is_none());
        assert!(builder.context_delivery == ContextDelivery::default());
        assert!(builder.repair_config.is_none());
        assert!(builder.persistence.is_none());
        assert!(builder.session_id.is_none());
        // Tool advisory is the one control-plane feature enabled by default.
        assert!(builder.tool_advisory);
        assert!(builder.telemetry_interval.is_none());
    }

    #[test]
    fn setters_store_values() {
        let interval = std::time::Duration::from_secs(5);
        let timeout = std::time::Duration::from_millis(750);
        let builder = test_builder()
            .greeting("hello")
            .session_id("abc")
            .tool_advisory(false)
            .telemetry_interval(interval)
            .soft_turn_timeout(timeout)
            .context_delivery(ContextDelivery::Deferred);

        assert_eq!(builder.greeting.as_deref(), Some("hello"));
        assert_eq!(builder.session_id.as_deref(), Some("abc"));
        assert!(!builder.tool_advisory);
        assert_eq!(builder.telemetry_interval, Some(interval));
        assert_eq!(builder.soft_turn_timeout, Some(timeout));
        assert!(builder.context_delivery == ContextDelivery::Deferred);
    }
}