gemini_adk_fluent_rs/live/
connect.rs1use gemini_adk_rs::live::{LiveHandle, LiveSessionBuilder, PhaseMachine};
4use gemini_adk_rs::State;
5use gemini_genai_rs::prelude::*;
6
7use super::Live;
8
9impl Live {
10 pub async fn connect_google_ai(
12 mut self,
13 api_key: impl Into<String>,
14 ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
15 self.config.endpoint = ApiEndpoint::google_ai(api_key);
16 self.build_and_connect().await
17 }
18
19 pub async fn connect_vertex(
21 mut self,
22 project: impl Into<String>,
23 location: impl Into<String>,
24 access_token: impl Into<String>,
25 ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
26 self.config.endpoint = ApiEndpoint::vertex(project, location, access_token);
27 self.build_and_connect().await
28 }
29
30 pub async fn connect_from_env(
54 mut self,
55 ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
56 self.config.endpoint = resolve_endpoint_from_env()?;
57 self.build_and_connect().await
58 }
59
60 pub async fn connect(
66 mut self,
67 config: SessionConfig,
68 ) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
69 self.config.endpoint = config.endpoint;
71 self.config.model = config.model;
72 self.build_and_connect().await
73 }
74
75 async fn build_and_connect(mut self) -> Result<LiveHandle, gemini_adk_rs::error::AgentError> {
76 if uses_audio_output(&self.config) {
77 self.config = self.config.voice_realtime_defaults();
78 }
79
80 if let Some(path) = self.record_wire_path.take() {
83 let recorder = FileWireRecorder::create(&path).map_err(|e| {
84 gemini_adk_rs::error::AgentError::Config(format!(
85 "failed to create wire log at {}: {e}",
86 path.display()
87 ))
88 })?;
89 self.config = self.config.record_wire(std::sync::Arc::new(recorder));
90 }
91
92 let mut builder = LiveSessionBuilder::new(self.config);
93
94 let mut dispatcher = self.dispatcher;
96 if !self.deferred_agent_tools.is_empty() {
97 let state = State::new();
98 let d = dispatcher.get_or_insert_with(gemini_adk_rs::tool::ToolDispatcher::new);
99 for deferred in self.deferred_agent_tools {
100 d.register(gemini_adk_rs::TextAgentTool::from_arc(
101 deferred.name,
102 deferred.description,
103 deferred.agent,
104 state.clone(),
105 ));
106 }
107 builder = builder.with_state(state);
108 }
109
110 if !self.deferred_tools.is_empty() {
112 let d = dispatcher.get_or_insert_with(gemini_adk_rs::tool::ToolDispatcher::new);
113 for deferred in std::mem::take(&mut self.deferred_tools) {
114 resolve_deferred_tool(deferred, d).await?;
115 }
116 }
117
118 if let Some(provider) = self.confirmation_provider {
120 dispatcher
121 .get_or_insert_with(gemini_adk_rs::tool::ToolDispatcher::new)
122 .set_confirmation_provider(provider);
123 }
124
125 if let Some(dispatcher) = dispatcher {
126 builder = builder.dispatcher(dispatcher);
127 }
128 if let Some(greeting) = self.greeting {
129 builder = builder.greeting(greeting);
130 }
131 builder = builder.callbacks(self.callbacks);
132 for ext in self.extractors {
133 builder = builder.extractor(ext);
134 }
135
136 if !self.computed.is_empty() {
138 builder = builder.computed(self.computed);
139 }
140 if let Some(initial) = self.initial_phase {
141 let mut pm = PhaseMachine::new(&initial);
142 for phase in self.phases {
143 pm.add_phase(phase);
144 }
145 builder = builder.phase_machine(pm);
146 }
147 if !self.watchers.observed_keys().is_empty() {
148 builder = builder.watchers(self.watchers);
149 }
150 builder = builder.temporal(self.temporal);
151
152 for (name, mode) in self.tool_execution_modes {
154 builder = builder.tool_execution_mode(name, mode);
155 }
156
157 if let Some(timeout) = self.soft_turn_timeout {
159 builder = builder.soft_turn_timeout(timeout);
160 }
161 builder = builder.steering_mode(self.steering_mode);
162 builder = builder.context_delivery(self.context_delivery);
163 builder = builder.delivery(self.delivery);
164 if let Some(config) = self.repair_config {
165 builder = builder.repair(config);
166 }
167 if let Some(p) = self.persistence {
168 builder = builder.persistence(p);
169 }
170 if let Some(id) = self.session_id {
171 builder = builder.session_id(id);
172 }
173 for layer in self.middleware_layers {
174 builder = builder.middleware(layer);
175 }
176 if let Some(flow) = self.flow {
177 let mut monitor = gemini_adk_rs::flow::FlowMonitor::new(flow, self.flow_mode);
178 for (step, agent, mode) in self.flow_actions {
179 monitor = monitor.on_enter(step, gemini_adk_rs::flow::run(agent, mode));
180 }
181 builder = builder.flow_monitor(monitor);
182 }
183 builder = builder.tool_advisory(self.tool_advisory);
184 if let Some(interval) = self.telemetry_interval {
185 builder = builder.telemetry_interval(interval);
186 }
187
188 for llm in self.warm_up_llms {
191 tokio::spawn(async move {
192 let _ = llm.warm_up().await;
193 });
194 }
195
196 builder.connect().await
197 }
198}
199
200fn resolve_endpoint_from_env() -> Result<ApiEndpoint, gemini_adk_rs::error::AgentError> {
203 use gemini_adk_rs::error::AgentError;
204 use gemini_genai_rs::protocol::types::EndpointEnvError;
205
206 match ApiEndpoint::from_env() {
207 Ok(endpoint) => Ok(endpoint),
208 Err(EndpointEnvError::Missing("GOOGLE_ACCESS_TOKEN")) => {
211 let project = std::env::var("GOOGLE_CLOUD_PROJECT").map_err(|_| {
212 AgentError::Config("GOOGLE_CLOUD_PROJECT is required for Vertex AI".into())
213 })?;
214 let location = std::env::var("GOOGLE_CLOUD_LOCATION")
215 .unwrap_or_else(|_| "us-central1".to_string());
216 let token = gcloud_access_token()?;
217 Ok(ApiEndpoint::vertex(project, location, token))
218 }
219 Err(e) => Err(AgentError::Config(format!(
220 "connect_from_env: {e}. For Google AI set GEMINI_API_KEY; for Vertex AI set \
221 GOOGLE_GENAI_USE_VERTEXAI=true and GOOGLE_CLOUD_PROJECT (token via \
222 GOOGLE_ACCESS_TOKEN or the gcloud CLI)."
223 ))),
224 }
225}
226
227fn gcloud_access_token() -> Result<String, gemini_adk_rs::error::AgentError> {
229 use gemini_adk_rs::error::AgentError;
230
231 let output = std::process::Command::new("gcloud")
232 .args(["auth", "print-access-token"])
233 .output()
234 .map_err(|e| {
235 AgentError::Config(format!(
236 "Vertex AI needs an access token: set GOOGLE_ACCESS_TOKEN, or install the \
237 gcloud CLI (failed to run `gcloud auth print-access-token`: {e})"
238 ))
239 })?;
240 if !output.status.success() {
241 return Err(AgentError::Config(format!(
242 "`gcloud auth print-access-token` failed: {}",
243 String::from_utf8_lossy(&output.stderr).trim()
244 )));
245 }
246 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
247 if token.is_empty() {
248 return Err(AgentError::Config(
249 "`gcloud auth print-access-token` returned an empty token".into(),
250 ));
251 }
252 Ok(token)
253}
254
255async fn resolve_deferred_tool(
259 tool: crate::compose::tools::DeferredTool,
260 dispatcher: &mut gemini_adk_rs::tool::ToolDispatcher,
261) -> Result<(), gemini_adk_rs::error::AgentError> {
262 use crate::compose::tools::DeferredTool;
263 use gemini_adk_rs::error::AgentError;
264 use gemini_adk_rs::tools::mcp::{McpSessionManager, McpTool};
265 use std::sync::Arc;
266
267 match tool {
268 DeferredTool::Mcp { params } => {
269 let manager = Arc::new(McpSessionManager::new(parse_mcp_params(¶ms)));
270 let infos = manager.list_tools().await.map_err(|e| {
271 AgentError::Config(format!("MCP tool discovery failed for {params:?}: {e}"))
272 })?;
273 for info in infos {
274 dispatcher.register_function(Arc::new(McpTool::new(
275 info.name,
276 info.description,
277 Some(info.input_schema),
278 manager.clone(),
279 )));
280 }
281 Ok(())
282 }
283 DeferredTool::A2a { url, skill } => Err(AgentError::Config(format!(
286 "T::a2a(url={url:?}, skill={skill:?}) is not yet implemented; tracked for ADK parity"
287 ))),
288 DeferredTool::OpenApi { name, spec_url } => Err(AgentError::Config(format!(
289 "T::openapi(name={name:?}, spec_url={spec_url:?}) is not yet implemented; \
290 tracked for ADK parity"
291 ))),
292 DeferredTool::Search { name, .. } => Err(AgentError::Config(format!(
293 "T::search(name={name:?}) is not yet implemented; tracked for ADK parity"
294 ))),
295 }
296}
297
298fn parse_mcp_params(params: &str) -> gemini_adk_rs::tools::mcp::McpConnectionParams {
301 use gemini_adk_rs::tools::mcp::McpConnectionParams;
302
303 let trimmed = params.trim();
304 if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
305 McpConnectionParams::Sse {
306 url: trimmed.to_string(),
307 headers: None,
308 }
309 } else {
310 let mut parts = trimmed.split_whitespace();
311 let command = parts.next().unwrap_or_default().to_string();
312 let args = parts.map(str::to_string).collect();
313 McpConnectionParams::Stdio {
314 command,
315 args,
316 timeout: Some(std::time::Duration::from_secs(30)),
317 }
318 }
319}
320
321fn uses_audio_output(config: &SessionConfig) -> bool {
322 config
323 .generation_config
324 .response_modalities
325 .as_ref()
326 .map(|modalities| modalities.iter().any(|m| matches!(m, Modality::Audio)))
327 .unwrap_or(true)
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed config (no `text_only()`) counts as audio output.
    #[test]
    fn uses_audio_output_defaults_to_audio() {
        assert!(uses_audio_output(&SessionConfig::new("key")));
    }

    /// Opting into text-only responses turns audio output off.
    #[test]
    fn uses_audio_output_respects_text_only() {
        let text_only = SessionConfig::new("key").text_only();
        assert!(
            !uses_audio_output(&text_only),
            "a text_only() config must not be treated as audio output"
        );
    }
}