gemini_adk_fluent_rs/live/
connect.rs

1//! Connection methods for `Live`.
2
3use gemini_adk_rs::live::{LiveHandle, LiveSessionBuilder, PhaseMachine};
4use gemini_adk_rs::State;
5use gemini_genai_rs::prelude::*;
6
7use super::Live;
8
9impl Live {
10    /// Connect using a Google AI API key.
11    pub async fn connect_google_ai(
12        mut self,
13        api_key: impl Into<String>,
14    ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
15        self.config.endpoint = ApiEndpoint::google_ai(api_key);
16        self.build_and_connect().await
17    }
18
19    /// Connect using Vertex AI credentials.
20    pub async fn connect_vertex(
21        mut self,
22        project: impl Into<String>,
23        location: impl Into<String>,
24        access_token: impl Into<String>,
25    ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
26        self.config.endpoint = ApiEndpoint::vertex(project, location, access_token);
27        self.build_and_connect().await
28    }
29
30    /// Connect using a pre-configured SessionConfig for auth and model.
31    ///
32    /// Merges the provided config's `endpoint` and `model` into the builder's
33    /// config, preserving system instruction, tools, voice, transcription, and
34    /// all other settings configured via the fluent API.
35    pub async fn connect(
36        mut self,
37        config: SessionConfig,
38    ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
39        // Merge auth/model from external config, keep everything else from builder.
40        self.config.endpoint = config.endpoint;
41        self.config.model = config.model;
42        self.build_and_connect().await
43    }
44
45    async fn build_and_connect(mut self) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
46        if uses_audio_output(&self.config) {
47            self.config = self.config.voice_realtime_defaults();
48        }
49
50        let mut builder = LiveSessionBuilder::new(self.config);
51
52        // Resolve deferred agent tools: create shared State, register TextAgentTools
53        let mut dispatcher = self.dispatcher;
54        if !self.deferred_agent_tools.is_empty() {
55            let state = State::new();
56            let d = dispatcher.get_or_insert_with(gemini_adk_rs::tool::ToolDispatcher::new);
57            for deferred in self.deferred_agent_tools {
58                d.register(gemini_adk_rs::TextAgentTool::from_arc(
59                    deferred.name,
60                    deferred.description,
61                    deferred.agent,
62                    state.clone(),
63                ));
64            }
65            builder = builder.with_state(state);
66        }
67
68        if let Some(dispatcher) = dispatcher {
69            builder = builder.dispatcher(dispatcher);
70        }
71        if let Some(greeting) = self.greeting {
72            builder = builder.greeting(greeting);
73        }
74        builder = builder.callbacks(self.callbacks);
75        for ext in self.extractors {
76            builder = builder.extractor(ext);
77        }
78
79        // Pass L1 registries
80        if !self.computed.is_empty() {
81            builder = builder.computed(self.computed);
82        }
83        if let Some(initial) = self.initial_phase {
84            let mut pm = PhaseMachine::new(&initial);
85            for phase in self.phases {
86                pm.add_phase(phase);
87            }
88            builder = builder.phase_machine(pm);
89        }
90        if !self.watchers.observed_keys().is_empty() {
91            builder = builder.watchers(self.watchers);
92        }
93        builder = builder.temporal(self.temporal);
94
95        // Pass tool execution modes
96        for (name, mode) in self.tool_execution_modes {
97            builder = builder.tool_execution_mode(name, mode);
98        }
99
100        // Pass control plane configuration
101        if let Some(timeout) = self.soft_turn_timeout {
102            builder = builder.soft_turn_timeout(timeout);
103        }
104        builder = builder.steering_mode(self.steering_mode);
105        builder = builder.context_delivery(self.context_delivery);
106        if let Some(config) = self.repair_config {
107            builder = builder.repair(config);
108        }
109        if let Some(p) = self.persistence {
110            builder = builder.persistence(p);
111        }
112        if let Some(id) = self.session_id {
113            builder = builder.session_id(id);
114        }
115        builder = builder.tool_advisory(self.tool_advisory);
116        if let Some(interval) = self.telemetry_interval {
117            builder = builder.telemetry_interval(interval);
118        }
119
120        // Spawn fire-and-forget warm-up tasks for OOB LLMs
121        // (pre-establishes TCP+TLS so first extract call is fast)
122        for llm in self.warm_up_llms {
123            tokio::spawn(async move {
124                let _ = llm.warm_up().await;
125            });
126        }
127
128        builder.connect().await
129    }
130}
131
132fn uses_audio_output(config: &SessionConfig) -> bool {
133    config
134        .generation_config
135        .response_modalities
136        .as_ref()
137        .map(|modalities| modalities.iter().any(|m| matches!(m, Modality::Audio)))
138        .unwrap_or(true)
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal config shared by the tests below.
    fn key_only_config() -> SessionConfig {
        SessionConfig::new("key")
    }

    /// A config built from just an API key is expected to report audio output.
    #[test]
    fn uses_audio_output_defaults_to_audio() {
        let config = key_only_config();
        assert!(uses_audio_output(&config));
    }

    /// Switching the config to text-only must turn audio output off.
    #[test]
    fn uses_audio_output_respects_text_only() {
        let config = key_only_config().text_only();
        assert!(!uses_audio_output(&config));
    }
}