gemini_adk_fluent_rs/live/
connect.rs

1//! Connection methods for `Live`.
2
3use gemini_adk_rs::live::{LiveHandle, LiveSessionBuilder, PhaseMachine};
4use gemini_adk_rs::State;
5use gemini_genai_rs::prelude::*;
6
7use super::Live;
8
9impl Live {
10    /// Connect using a Google AI API key.
11    pub async fn connect_google_ai(
12        mut self,
13        api_key: impl Into<String>,
14    ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
15        self.config.endpoint = ApiEndpoint::google_ai(api_key);
16        self.build_and_connect().await
17    }
18
19    /// Connect using Vertex AI credentials.
20    pub async fn connect_vertex(
21        mut self,
22        project: impl Into<String>,
23        location: impl Into<String>,
24        access_token: impl Into<String>,
25    ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
26        self.config.endpoint = ApiEndpoint::vertex(project, location, access_token);
27        self.build_and_connect().await
28    }
29
30    /// Connect by resolving the platform and credentials from standard
31    /// environment variables — the zero-ceremony entry point.
32    ///
33    /// Resolution (see [`ApiEndpoint::from_env`]):
34    /// - `GOOGLE_GENAI_USE_VERTEXAI=true` → Vertex AI using
35    ///   `GOOGLE_CLOUD_PROJECT`, `GOOGLE_CLOUD_LOCATION` (default
36    ///   `us-central1`), and a token from `GOOGLE_ACCESS_TOKEN`. If that
37    ///   token is unset, this falls back to running
38    ///   `gcloud auth print-access-token`.
39    /// - otherwise → Google AI using `GEMINI_API_KEY` (or
40    ///   `GOOGLE_GENAI_API_KEY` / `GOOGLE_API_KEY`).
41    ///
42    /// ```no_run
43    /// # use gemini_adk_fluent_rs::prelude::*;
44    /// # async fn run() -> Result<(), AgentError> {
45    /// let handle = Live::builder()
46    ///     .model(GeminiModel::Gemini2_0FlashLive)
47    ///     .voice(Voice::Kore)
48    ///     .connect_from_env()
49    ///     .await?;
50    /// # let _ = handle; Ok(())
51    /// # }
52    /// ```
53    pub async fn connect_from_env(
54        mut self,
55    ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
56        self.config.endpoint = resolve_endpoint_from_env()?;
57        self.build_and_connect().await
58    }
59
60    /// Connect using a pre-configured SessionConfig for auth and model.
61    ///
62    /// Merges the provided config's `endpoint` and `model` into the builder's
63    /// config, preserving system instruction, tools, voice, transcription, and
64    /// all other settings configured via the fluent API.
65    pub async fn connect(
66        mut self,
67        config: SessionConfig,
68    ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
69        // Merge auth/model from external config, keep everything else from builder.
70        self.config.endpoint = config.endpoint;
71        self.config.model = config.model;
72        self.build_and_connect().await
73    }
74
75    async fn build_and_connect(mut self) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
76        if uses_audio_output(&self.config) {
77            self.config = self.config.voice_realtime_defaults();
78        }
79
80        // Resolve a `.record_wire(path)` request into a FileWireRecorder now
81        // that we are actually connecting.
82        if let Some(path) = self.record_wire_path.take() {
83            let recorder = FileWireRecorder::create(&path).map_err(|e| {
84                gemini_adk_rs::error::AgentError::Config(format!(
85                    "failed to create wire log at {}: {e}",
86                    path.display()
87                ))
88            })?;
89            self.config = self.config.record_wire(std::sync::Arc::new(recorder));
90        }
91
92        let mut builder = LiveSessionBuilder::new(self.config);
93
94        // Resolve deferred agent tools: create shared State, register TextAgentTools
95        let mut dispatcher = self.dispatcher;
96        if !self.deferred_agent_tools.is_empty() {
97            let state = State::new();
98            let d = dispatcher.get_or_insert_with(gemini_adk_rs::tool::ToolDispatcher::new);
99            for deferred in self.deferred_agent_tools {
100                d.register(gemini_adk_rs::TextAgentTool::from_arc(
101                    deferred.name,
102                    deferred.description,
103                    deferred.agent,
104                    state.clone(),
105                ));
106            }
107            builder = builder.with_state(state);
108        }
109
110        // Resolve deferred async tools (MCP connections, etc.).
111        if !self.deferred_tools.is_empty() {
112            let d = dispatcher.get_or_insert_with(gemini_adk_rs::tool::ToolDispatcher::new);
113            for deferred in std::mem::take(&mut self.deferred_tools) {
114                resolve_deferred_tool(deferred, d).await?;
115            }
116        }
117
118        // Attach the confirmation provider so `T::confirm(..)` tools are gated.
119        if let Some(provider) = self.confirmation_provider {
120            dispatcher
121                .get_or_insert_with(gemini_adk_rs::tool::ToolDispatcher::new)
122                .set_confirmation_provider(provider);
123        }
124
125        if let Some(dispatcher) = dispatcher {
126            builder = builder.dispatcher(dispatcher);
127        }
128        if let Some(greeting) = self.greeting {
129            builder = builder.greeting(greeting);
130        }
131        builder = builder.callbacks(self.callbacks);
132        for ext in self.extractors {
133            builder = builder.extractor(ext);
134        }
135
136        // Pass L1 registries
137        if !self.computed.is_empty() {
138            builder = builder.computed(self.computed);
139        }
140        if let Some(initial) = self.initial_phase {
141            let mut pm = PhaseMachine::new(&initial);
142            for phase in self.phases {
143                pm.add_phase(phase);
144            }
145            builder = builder.phase_machine(pm);
146        }
147        if !self.watchers.observed_keys().is_empty() {
148            builder = builder.watchers(self.watchers);
149        }
150        builder = builder.temporal(self.temporal);
151
152        // Pass tool execution modes
153        for (name, mode) in self.tool_execution_modes {
154            builder = builder.tool_execution_mode(name, mode);
155        }
156
157        // Pass control plane configuration
158        if let Some(timeout) = self.soft_turn_timeout {
159            builder = builder.soft_turn_timeout(timeout);
160        }
161        builder = builder.steering_mode(self.steering_mode);
162        builder = builder.context_delivery(self.context_delivery);
163        builder = builder.delivery(self.delivery);
164        if let Some(config) = self.repair_config {
165            builder = builder.repair(config);
166        }
167        if let Some(p) = self.persistence {
168            builder = builder.persistence(p);
169        }
170        if let Some(id) = self.session_id {
171            builder = builder.session_id(id);
172        }
173        for layer in self.middleware_layers {
174            builder = builder.middleware(layer);
175        }
176        if let Some(flow) = self.flow {
177            let mut monitor = gemini_adk_rs::flow::FlowMonitor::new(flow, self.flow_mode);
178            for (step, agent, mode) in self.flow_actions {
179                monitor = monitor.on_enter(step, gemini_adk_rs::flow::run(agent, mode));
180            }
181            builder = builder.flow_monitor(monitor);
182        }
183        builder = builder.tool_advisory(self.tool_advisory);
184        if let Some(interval) = self.telemetry_interval {
185            builder = builder.telemetry_interval(interval);
186        }
187
188        // Spawn fire-and-forget warm-up tasks for OOB LLMs
189        // (pre-establishes TCP+TLS so first extract call is fast)
190        for llm in self.warm_up_llms {
191            tokio::spawn(async move {
192                let _ = llm.warm_up().await;
193            });
194        }
195
196        builder.connect().await
197    }
198}
199
200/// Resolve an [`ApiEndpoint`] from the environment, with a `gcloud` token
201/// fallback for Vertex AI when `GOOGLE_ACCESS_TOKEN` is not set.
202fn resolve_endpoint_from_env() -> Result<ApiEndpoint, gemini_adk_rs::error::AgentError> {
203    use gemini_adk_rs::error::AgentError;
204    use gemini_genai_rs::protocol::types::EndpointEnvError;
205
206    match ApiEndpoint::from_env() {
207        Ok(endpoint) => Ok(endpoint),
208        // Vertex was selected but no token was in the environment — fall back
209        // to Application Default Credentials via the gcloud CLI.
210        Err(EndpointEnvError::Missing("GOOGLE_ACCESS_TOKEN")) => {
211            let project = std::env::var("GOOGLE_CLOUD_PROJECT").map_err(|_| {
212                AgentError::Config("GOOGLE_CLOUD_PROJECT is required for Vertex AI".into())
213            })?;
214            let location = std::env::var("GOOGLE_CLOUD_LOCATION")
215                .unwrap_or_else(|_| "us-central1".to_string());
216            let token = gcloud_access_token()?;
217            Ok(ApiEndpoint::vertex(project, location, token))
218        }
219        Err(e) => Err(AgentError::Config(format!(
220            "connect_from_env: {e}. For Google AI set GEMINI_API_KEY; for Vertex AI set \
221             GOOGLE_GENAI_USE_VERTEXAI=true and GOOGLE_CLOUD_PROJECT (token via \
222             GOOGLE_ACCESS_TOKEN or the gcloud CLI)."
223        ))),
224    }
225}
226
227/// Fetch an OAuth2 access token via `gcloud auth print-access-token`.
228fn gcloud_access_token() -> Result<String, gemini_adk_rs::error::AgentError> {
229    use gemini_adk_rs::error::AgentError;
230
231    let output = std::process::Command::new("gcloud")
232        .args(["auth", "print-access-token"])
233        .output()
234        .map_err(|e| {
235            AgentError::Config(format!(
236                "Vertex AI needs an access token: set GOOGLE_ACCESS_TOKEN, or install the \
237                 gcloud CLI (failed to run `gcloud auth print-access-token`: {e})"
238            ))
239        })?;
240    if !output.status.success() {
241        return Err(AgentError::Config(format!(
242            "`gcloud auth print-access-token` failed: {}",
243            String::from_utf8_lossy(&output.stderr).trim()
244        )));
245    }
246    let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
247    if token.is_empty() {
248        return Err(AgentError::Config(
249            "`gcloud auth print-access-token` returned an empty token".into(),
250        ));
251    }
252    Ok(token)
253}
254
255/// Resolve a single [`DeferredTool`](crate::compose::tools::DeferredTool) into
256/// concrete tool registrations on the dispatcher. Runs at connect time because
257/// these tools require async I/O (a network call or a subprocess handshake).
258async fn resolve_deferred_tool(
259    tool: crate::compose::tools::DeferredTool,
260    dispatcher: &mut gemini_adk_rs::tool::ToolDispatcher,
261) -> Result<(), gemini_adk_rs::error::AgentError> {
262    use crate::compose::tools::DeferredTool;
263    use gemini_adk_rs::error::AgentError;
264    use gemini_adk_rs::tools::mcp::{McpSessionManager, McpTool};
265    use std::sync::Arc;
266
267    match tool {
268        DeferredTool::Mcp { params } => {
269            let manager = Arc::new(McpSessionManager::new(parse_mcp_params(&params)));
270            let infos = manager.list_tools().await.map_err(|e| {
271                AgentError::Config(format!("MCP tool discovery failed for {params:?}: {e}"))
272            })?;
273            for info in infos {
274                dispatcher.register_function(Arc::new(McpTool::new(
275                    info.name,
276                    info.description,
277                    Some(info.input_schema),
278                    manager.clone(),
279                )));
280            }
281            Ok(())
282        }
283        // The following are part of the ADK-parity toolset roadmap; they are
284        // surfaced as explicit connect-time errors rather than silently dropped.
285        DeferredTool::A2a { url, skill } => Err(AgentError::Config(format!(
286            "T::a2a(url={url:?}, skill={skill:?}) is not yet implemented; tracked for ADK parity"
287        ))),
288        DeferredTool::OpenApi { name, spec_url } => Err(AgentError::Config(format!(
289            "T::openapi(name={name:?}, spec_url={spec_url:?}) is not yet implemented; \
290             tracked for ADK parity"
291        ))),
292        DeferredTool::Search { name, .. } => Err(AgentError::Config(format!(
293            "T::search(name={name:?}) is not yet implemented; tracked for ADK parity"
294        ))),
295    }
296}
297
298/// Parse an MCP connection string: an `http(s)://` URL becomes an SSE/HTTP
299/// connection, anything else is treated as a stdio command line.
300fn parse_mcp_params(params: &str) -> gemini_adk_rs::tools::mcp::McpConnectionParams {
301    use gemini_adk_rs::tools::mcp::McpConnectionParams;
302
303    let trimmed = params.trim();
304    if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
305        McpConnectionParams::Sse {
306            url: trimmed.to_string(),
307            headers: None,
308        }
309    } else {
310        let mut parts = trimmed.split_whitespace();
311        let command = parts.next().unwrap_or_default().to_string();
312        let args = parts.map(str::to_string).collect();
313        McpConnectionParams::Stdio {
314            command,
315            args,
316            timeout: Some(std::time::Duration::from_secs(30)),
317        }
318    }
319}
320
321fn uses_audio_output(config: &SessionConfig) -> bool {
322    config
323        .generation_config
324        .response_modalities
325        .as_ref()
326        .map(|modalities| modalities.iter().any(|m| matches!(m, Modality::Audio)))
327        .unwrap_or(true)
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A config straight from `SessionConfig::new` must be reported as
    /// producing audio.
    #[test]
    fn uses_audio_output_defaults_to_audio() {
        let default_config = SessionConfig::new("key");
        let audio = uses_audio_output(&default_config);
        assert!(audio);
    }

    /// A config switched to text-only must not be reported as producing
    /// audio.
    #[test]
    fn uses_audio_output_respects_text_only() {
        let text_config = SessionConfig::new("key").text_only();
        let audio = uses_audio_output(&text_config);
        assert!(!audio);
    }
}
345}