gemini_adk_rs/llm/
gemini.rs

//! Concrete Gemini LLM implementation using gemini-live `Client`.
//!
//! The [`GeminiLlm`] struct is always available for type references and registry
//! wiring. Actual HTTP generation requires the `gemini-llm` feature flag, which
//! pulls in `gemini-live/http` and `gemini-live/generate`.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use once_cell::sync::Lazy;
12use regex::Regex;
13
14#[cfg(feature = "gemini-llm")]
15use crate::llm::TokenUsage;
16use crate::llm::{
17    BaseLlm, EnvTokenProvider, GcloudTokenProvider, LlmError, LlmRequest, LlmResponse,
18    TokenProvider,
19};
20use crate::utils::variant::{get_google_llm_variant, GoogleLlmVariant};
21
22/// Parameters for constructing a [`GeminiLlm`].
23#[derive(Default)]
24pub struct GeminiLlmParams {
25    /// Model name (defaults to "gemini-2.5-flash").
26    pub model: Option<String>,
27    /// API key for Gemini API (non-Vertex).
28    pub api_key: Option<String>,
29    /// Whether to use Vertex AI backend.
30    pub vertexai: Option<bool>,
31    /// Google Cloud project ID (Vertex AI only).
32    pub project: Option<String>,
33    /// Google Cloud region (Vertex AI only, defaults to "us-central1").
34    pub location: Option<String>,
35    /// Custom HTTP headers for requests.
36    pub headers: Option<HashMap<String, String>>,
37    /// Custom token provider for VertexAI. Defaults to reading `GOOGLE_ACCESS_TOKEN` env var.
38    pub token_provider: Option<Arc<dyn TokenProvider>>,
39}
40
41/// Concrete Gemini LLM implementation using gemini-live `Client`.
42///
43/// The gemini-live `Client` is created once at construction time and reused for
44/// all `generate()` calls, matching the JS GenAI SDK pattern where a single
45/// `GoogleGenAI` instance is shared across requests.
46pub struct GeminiLlm {
47    model: String,
48    variant: GoogleLlmVariant,
49    /// Stored for constructing the gemini-live `Client` when `gemini-llm` is enabled.
50    #[allow(dead_code)]
51    params: GeminiLlmParams,
52    /// Token provider for VertexAI token refresh.
53    #[allow(dead_code)]
54    token_provider: Arc<dyn TokenProvider>,
55    /// Cached gemini-live Client, created once at construction time.
56    #[cfg(feature = "gemini-llm")]
57    client: gemini_genai_rs::prelude::Client,
58}
59
60static SUPPORTED_PATTERNS: Lazy<Vec<Regex>> = Lazy::new(|| {
61    vec![
62        Regex::new(r"^gemini-.*$").unwrap(),
63        Regex::new(r"^projects/.*/endpoints/.*$").unwrap(),
64        Regex::new(r"^projects/.*/models/gemini.*$").unwrap(),
65    ]
66});
67
68impl GeminiLlm {
69    /// Create a new `GeminiLlm` from parameters.
70    ///
71    /// Resolves defaults for model, variant, API key, project, and location
72    /// from parameters first, then falls back to environment variables.
73    /// The gemini-live `Client` is created once here and reused for all calls.
74    pub fn new(mut params: GeminiLlmParams) -> Self {
75        // Resolve model (default to "gemini-2.5-flash")
76        let model = params
77            .model
78            .clone()
79            .unwrap_or_else(|| "gemini-2.5-flash".to_string());
80
81        // Resolve variant from params or env
82        let variant = if let Some(true) = params.vertexai {
83            GoogleLlmVariant::VertexAi
84        } else if let Some(false) = params.vertexai {
85            GoogleLlmVariant::GeminiApi
86        } else {
87            get_google_llm_variant()
88        };
89
90        // Resolve API key from params or env
91        if params.api_key.is_none() && variant == GoogleLlmVariant::GeminiApi {
92            params.api_key = std::env::var("GOOGLE_GENAI_API_KEY")
93                .or_else(|_| std::env::var("GEMINI_API_KEY"))
94                .ok();
95        }
96
97        // Resolve project/location from env for Vertex AI
98        if variant == GoogleLlmVariant::VertexAi {
99            if params.project.is_none() {
100                params.project = std::env::var("GOOGLE_CLOUD_PROJECT").ok();
101            }
102            if params.location.is_none() {
103                params.location = std::env::var("GOOGLE_CLOUD_LOCATION").ok();
104            }
105        }
106
107        // Resolve token provider for VertexAI.
108        // Default to GcloudTokenProvider (env var -> gcloud CLI fallback) for VertexAI,
109        // matching the auth resolution in build_session_config(). For GeminiApi, use
110        // EnvTokenProvider since API key auth doesn't need token refresh.
111        let token_provider: Arc<dyn TokenProvider> =
112            params.token_provider.take().unwrap_or_else(|| {
113                if variant == GoogleLlmVariant::VertexAi {
114                    Arc::new(GcloudTokenProvider::new(std::time::Duration::from_secs(
115                        45 * 60,
116                    )))
117                } else {
118                    Arc::new(EnvTokenProvider)
119                }
120            });
121
122        // Create the gemini-live Client once, reuse across generate() calls.
123        // For VertexAI, use from_vertex_refreshable() so the token is dynamically
124        // refreshed on every REST API call (via auth_headers()), preventing 401
125        // errors from stale tokens during long-running sessions.
126        #[cfg(feature = "gemini-llm")]
127        let client = {
128            use gemini_genai_rs::prelude::*;
129            match variant {
130                GoogleLlmVariant::GeminiApi => {
131                    let api_key = params.api_key.as_deref().unwrap_or("");
132                    Client::from_api_key(api_key).model(GeminiModel::Custom(model.clone()))
133                }
134                GoogleLlmVariant::VertexAi => {
135                    let project = params.project.as_deref().unwrap_or("").to_string();
136                    let location = params
137                        .location
138                        .as_deref()
139                        .unwrap_or("us-central1")
140                        .to_string();
141                    let tp = token_provider.clone();
142                    Client::from_vertex_refreshable(project, location, move || tp.token())
143                        .model(GeminiModel::Custom(model.clone()))
144                }
145            }
146        };
147
148        Self {
149            model,
150            variant,
151            params,
152            token_provider,
153            #[cfg(feature = "gemini-llm")]
154            client,
155        }
156    }
157
158    /// Check if a model name is supported by `GeminiLlm`.
159    pub fn is_supported(model: &str) -> bool {
160        SUPPORTED_PATTERNS.iter().any(|re| re.is_match(model))
161    }
162
163    /// Get the variant (VertexAI vs GeminiApi).
164    pub fn variant(&self) -> GoogleLlmVariant {
165        self.variant
166    }
167
168    /// Preprocess request: remove labels and displayName for non-Vertex (Gemini API).
169    fn preprocess_request(&self, _request: &mut LlmRequest) {
170        // For Gemini API backend: remove labels and displayName from tools.
171        // This is a no-op for now since LlmRequest doesn't have those fields yet.
172        // In a full implementation, this would strip Vertex-only fields.
173    }
174}
175
176#[async_trait]
177impl BaseLlm for GeminiLlm {
178    fn model_id(&self) -> &str {
179        &self.model
180    }
181
182    async fn generate(&self, mut request: LlmRequest) -> Result<LlmResponse, LlmError> {
183        self.preprocess_request(&mut request);
184
185        // Feature-gate the actual HTTP call behind gemini-live's generate + http features.
186        #[cfg(feature = "gemini-llm")]
187        {
188            use gemini_genai_rs::generate::GenerateContentConfig;
189            use gemini_genai_rs::prelude::*;
190
191            // Build GenerateContentConfig from LlmRequest — move, don't clone.
192            let mut config = if request.contents.is_empty() {
193                GenerateContentConfig::from_text("")
194            } else {
195                GenerateContentConfig::from_contents(std::mem::take(&mut request.contents))
196            };
197
198            if let Some(sys) = request.system_instruction.take() {
199                config = config.system_instruction(&sys);
200            }
201            if !request.tools.is_empty() {
202                config.tools = std::mem::take(&mut request.tools);
203            }
204            if let Some(temp) = request.temperature {
205                config = config.temperature(temp);
206            }
207            if let Some(max) = request.max_output_tokens {
208                config = config.max_output_tokens(max);
209            }
210            if request.response_mime_type.is_some() || request.response_json_schema.is_some() {
211                let gc = config
212                    .generation_config
213                    .get_or_insert_with(gemini_genai_rs::prelude::GenerationConfig::default);
214                if let Some(mime) = request.response_mime_type.take() {
215                    gc.response_mime_type = Some(mime);
216                }
217                if let Some(schema) = request.response_json_schema.take() {
218                    gc.response_json_schema = Some(schema);
219                }
220            }
221
222            let response = self
223                .client
224                .generate_content_with(config, None)
225                .await
226                .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
227
228            let content = response
229                .candidates
230                .first()
231                .and_then(|c| c.content.clone())
232                .unwrap_or_else(|| Content {
233                    role: Some(Role::Model),
234                    parts: vec![],
235                });
236
237            let finish_reason = response
238                .candidates
239                .first()
240                .and_then(|c| c.finish_reason)
241                .map(|r| format!("{:?}", r));
242
243            let usage = response.usage_metadata.map(|u| TokenUsage {
244                prompt_tokens: u.prompt_token_count.unwrap_or(0),
245                completion_tokens: u.response_token_count.unwrap_or(0),
246                total_tokens: u.total_token_count.unwrap_or(0),
247            });
248
249            Ok(LlmResponse {
250                content,
251                finish_reason,
252                usage,
253            })
254        }
255
256        #[cfg(not(feature = "gemini-llm"))]
257        {
258            // Suppress unused-variable warnings when the feature is disabled.
259            let _ = request;
260            Err(LlmError::RequestFailed(
261                "GeminiLlm requires the 'gemini-llm' feature flag \
262                 (depends on gemini-live HTTP client)"
263                    .into(),
264            ))
265        }
266    }
267
268    /// Pre-warm the HTTP connection pool by making a lightweight request.
269    ///
270    /// Establishes the TCP+TLS connection so the first real `generate()`
271    /// call doesn't pay the ~100-300ms handshake penalty. reqwest's
272    /// connection pool keeps it alive for subsequent calls.
273    async fn warm_up(&self) -> Result<(), LlmError> {
274        #[cfg(feature = "gemini-llm")]
275        {
276            use gemini_genai_rs::generate::GenerateContentConfig;
277            let config = GenerateContentConfig::from_text(".").max_output_tokens(1);
278            let _ = self.client.generate_content_with(config, None).await;
279        }
280        Ok(())
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GeminiLlm` whose only explicit setting is the model name.
    fn llm_with_model(name: &str) -> GeminiLlm {
        GeminiLlm::new(GeminiLlmParams {
            model: Some(name.into()),
            ..Default::default()
        })
    }

    /// Builds a `GeminiLlm` whose only explicit setting is the backend flag.
    fn llm_with_backend(vertexai: bool) -> GeminiLlm {
        GeminiLlm::new(GeminiLlmParams {
            vertexai: Some(vertexai),
            ..Default::default()
        })
    }

    #[test]
    fn default_model_is_gemini_2_5_flash() {
        let params = GeminiLlmParams::default();
        assert_eq!(GeminiLlm::new(params).model_id(), "gemini-2.5-flash");
    }

    #[test]
    fn explicit_model() {
        assert_eq!(llm_with_model("gemini-2.0-pro").model_id(), "gemini-2.0-pro");
    }

    #[test]
    fn variant_from_params_vertex() {
        assert_eq!(llm_with_backend(true).variant(), GoogleLlmVariant::VertexAi);
    }

    #[test]
    fn variant_from_params_gemini_api() {
        assert_eq!(
            llm_with_backend(false).variant(),
            GoogleLlmVariant::GeminiApi
        );
    }

    #[test]
    fn is_supported_gemini_models() {
        let accepted = ["gemini-2.5-flash", "gemini-2.0-pro", "gemini-1.5-pro-001"];
        for name in accepted.iter() {
            assert!(GeminiLlm::is_supported(name), "{} should be supported", name);
        }
    }

    #[test]
    fn is_supported_non_gemini_models() {
        let rejected = ["gpt-4", "claude-3-opus", "llama-3"];
        for name in rejected.iter() {
            assert!(
                !GeminiLlm::is_supported(name),
                "{} should not be supported",
                name
            );
        }
    }

    #[test]
    fn is_supported_vertex_ai_resource_paths() {
        let resource_paths = [
            "projects/my-project/endpoints/12345",
            "projects/my-project/models/gemini-2.5-flash",
        ];
        for path in resource_paths.iter() {
            assert!(GeminiLlm::is_supported(path), "{} should be supported", path);
        }
    }

    #[test]
    fn model_id_returns_correct_string() {
        let name = "gemini-2.5-flash-preview-04-17";
        assert_eq!(llm_with_model(name).model_id(), name);
    }

    /// Compile-time check: `BaseLlm` must be usable as a trait object.
    #[test]
    fn base_llm_is_object_safe() {
        fn _accepts_trait_object(_: &dyn BaseLlm) {}
    }

    /// Compile-time check: `GeminiLlm` must be shareable across threads.
    #[test]
    fn gemini_llm_is_send_sync() {
        fn _require_send_sync<T: Send + Sync>() {}
        _require_send_sync::<GeminiLlm>();
    }
}
364}