gemini_adk_rs/llm/
gemini.rs1use std::collections::HashMap;
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use once_cell::sync::Lazy;
12use regex::Regex;
13
14#[cfg(feature = "gemini-llm")]
15use crate::llm::TokenUsage;
16use crate::llm::{
17 BaseLlm, EnvTokenProvider, GcloudTokenProvider, LlmError, LlmRequest, LlmResponse,
18 TokenProvider,
19};
20use crate::utils::variant::{get_google_llm_variant, GoogleLlmVariant};
21
22#[derive(Default)]
24pub struct GeminiLlmParams {
25 pub model: Option<String>,
27 pub api_key: Option<String>,
29 pub vertexai: Option<bool>,
31 pub project: Option<String>,
33 pub location: Option<String>,
35 pub headers: Option<HashMap<String, String>>,
37 pub token_provider: Option<Arc<dyn TokenProvider>>,
39}
40
/// [`BaseLlm`] implementation for Gemini models, served through either the
/// Gemini API or Vertex AI depending on its [`GoogleLlmVariant`].
pub struct GeminiLlm {
    /// Model identifier returned by `model_id()` and used to configure `client`.
    model: String,
    /// Backend selected at construction time.
    variant: GoogleLlmVariant,
    /// Construction params with environment defaults filled in and
    /// `token_provider` already moved out.
    // Retained but not read after construction, hence the lint allowance.
    #[allow(dead_code)]
    params: GeminiLlmParams,
    /// Token provider chosen at construction.
    // The Vertex AI client captures its own clone of this `Arc`; the field
    // itself is never read, hence the lint allowance.
    #[allow(dead_code)]
    token_provider: Arc<dyn TokenProvider>,
    /// HTTP client used for `generate_content_with` calls; only compiled in
    /// with the `gemini-llm` feature.
    #[cfg(feature = "gemini-llm")]
    client: gemini_genai_rs::prelude::Client,
}
59
60static SUPPORTED_PATTERNS: Lazy<Vec<Regex>> = Lazy::new(|| {
61 vec![
62 Regex::new(r"^gemini-.*$").unwrap(),
63 Regex::new(r"^projects/.*/endpoints/.*$").unwrap(),
64 Regex::new(r"^projects/.*/models/gemini.*$").unwrap(),
65 ]
66});
67
68impl GeminiLlm {
69 pub fn new(mut params: GeminiLlmParams) -> Self {
75 let model = params
77 .model
78 .clone()
79 .unwrap_or_else(|| "gemini-2.5-flash".to_string());
80
81 let variant = if let Some(true) = params.vertexai {
83 GoogleLlmVariant::VertexAi
84 } else if let Some(false) = params.vertexai {
85 GoogleLlmVariant::GeminiApi
86 } else {
87 get_google_llm_variant()
88 };
89
90 if params.api_key.is_none() && variant == GoogleLlmVariant::GeminiApi {
92 params.api_key = std::env::var("GOOGLE_GENAI_API_KEY")
93 .or_else(|_| std::env::var("GEMINI_API_KEY"))
94 .ok();
95 }
96
97 if variant == GoogleLlmVariant::VertexAi {
99 if params.project.is_none() {
100 params.project = std::env::var("GOOGLE_CLOUD_PROJECT").ok();
101 }
102 if params.location.is_none() {
103 params.location = std::env::var("GOOGLE_CLOUD_LOCATION").ok();
104 }
105 }
106
107 let token_provider: Arc<dyn TokenProvider> =
112 params.token_provider.take().unwrap_or_else(|| {
113 if variant == GoogleLlmVariant::VertexAi {
114 Arc::new(GcloudTokenProvider::new(std::time::Duration::from_secs(
115 45 * 60,
116 )))
117 } else {
118 Arc::new(EnvTokenProvider)
119 }
120 });
121
122 #[cfg(feature = "gemini-llm")]
127 let client = {
128 use gemini_genai_rs::prelude::*;
129 match variant {
130 GoogleLlmVariant::GeminiApi => {
131 let api_key = params.api_key.as_deref().unwrap_or("");
132 Client::from_api_key(api_key).model(GeminiModel::Custom(model.clone()))
133 }
134 GoogleLlmVariant::VertexAi => {
135 let project = params.project.as_deref().unwrap_or("").to_string();
136 let location = params
137 .location
138 .as_deref()
139 .unwrap_or("us-central1")
140 .to_string();
141 let tp = token_provider.clone();
142 Client::from_vertex_refreshable(project, location, move || tp.token())
143 .model(GeminiModel::Custom(model.clone()))
144 }
145 }
146 };
147
148 Self {
149 model,
150 variant,
151 params,
152 token_provider,
153 #[cfg(feature = "gemini-llm")]
154 client,
155 }
156 }
157
158 pub fn is_supported(model: &str) -> bool {
160 SUPPORTED_PATTERNS.iter().any(|re| re.is_match(model))
161 }
162
163 pub fn variant(&self) -> GoogleLlmVariant {
165 self.variant
166 }
167
168 fn preprocess_request(&self, _request: &mut LlmRequest) {
170 }
174}
175
176#[async_trait]
177impl BaseLlm for GeminiLlm {
178 fn model_id(&self) -> &str {
179 &self.model
180 }
181
182 async fn generate(&self, mut request: LlmRequest) -> Result<LlmResponse, LlmError> {
183 self.preprocess_request(&mut request);
184
185 #[cfg(feature = "gemini-llm")]
187 {
188 use gemini_genai_rs::generate::GenerateContentConfig;
189 use gemini_genai_rs::prelude::*;
190
191 let mut config = if request.contents.is_empty() {
193 GenerateContentConfig::from_text("")
194 } else {
195 GenerateContentConfig::from_contents(std::mem::take(&mut request.contents))
196 };
197
198 if let Some(sys) = request.system_instruction.take() {
199 config = config.system_instruction(&sys);
200 }
201 if !request.tools.is_empty() {
202 config.tools = std::mem::take(&mut request.tools);
203 }
204 if let Some(temp) = request.temperature {
205 config = config.temperature(temp);
206 }
207 if let Some(max) = request.max_output_tokens {
208 config = config.max_output_tokens(max);
209 }
210 if request.response_mime_type.is_some() || request.response_json_schema.is_some() {
211 let gc = config
212 .generation_config
213 .get_or_insert_with(gemini_genai_rs::prelude::GenerationConfig::default);
214 if let Some(mime) = request.response_mime_type.take() {
215 gc.response_mime_type = Some(mime);
216 }
217 if let Some(schema) = request.response_json_schema.take() {
218 gc.response_json_schema = Some(schema);
219 }
220 }
221
222 let response = self
223 .client
224 .generate_content_with(config, None)
225 .await
226 .map_err(|e| LlmError::RequestFailed(e.to_string()))?;
227
228 let content = response
229 .candidates
230 .first()
231 .and_then(|c| c.content.clone())
232 .unwrap_or_else(|| Content {
233 role: Some(Role::Model),
234 parts: vec![],
235 });
236
237 let finish_reason = response
238 .candidates
239 .first()
240 .and_then(|c| c.finish_reason)
241 .map(|r| format!("{:?}", r));
242
243 let usage = response.usage_metadata.map(|u| TokenUsage {
244 prompt_tokens: u.prompt_token_count.unwrap_or(0),
245 completion_tokens: u.response_token_count.unwrap_or(0),
246 total_tokens: u.total_token_count.unwrap_or(0),
247 });
248
249 Ok(LlmResponse {
250 content,
251 finish_reason,
252 usage,
253 })
254 }
255
256 #[cfg(not(feature = "gemini-llm"))]
257 {
258 let _ = request;
260 Err(LlmError::RequestFailed(
261 "GeminiLlm requires the 'gemini-llm' feature flag \
262 (depends on gemini-live HTTP client)"
263 .into(),
264 ))
265 }
266 }
267
268 async fn warm_up(&self) -> Result<(), LlmError> {
274 #[cfg(feature = "gemini-llm")]
275 {
276 use gemini_genai_rs::generate::GenerateContentConfig;
277 let config = GenerateContentConfig::from_text(".").max_output_tokens(1);
278 let _ = self.client.generate_content_with(config, None).await;
279 }
280 Ok(())
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_model_is_gemini_2_5_flash() {
        let llm = GeminiLlm::new(GeminiLlmParams::default());
        assert_eq!(llm.model_id(), "gemini-2.5-flash");
    }

    #[test]
    fn explicit_model() {
        let llm = GeminiLlm::new(GeminiLlmParams {
            model: Some("gemini-2.0-pro".into()),
            ..Default::default()
        });
        assert_eq!(llm.model_id(), "gemini-2.0-pro");
    }

    #[test]
    fn variant_from_params_vertex() {
        let llm = GeminiLlm::new(GeminiLlmParams {
            vertexai: Some(true),
            ..Default::default()
        });
        assert_eq!(llm.variant(), GoogleLlmVariant::VertexAi);
    }

    #[test]
    fn variant_from_params_gemini_api() {
        let llm = GeminiLlm::new(GeminiLlmParams {
            vertexai: Some(false),
            ..Default::default()
        });
        assert_eq!(llm.variant(), GoogleLlmVariant::GeminiApi);
    }

    #[test]
    fn is_supported_gemini_models() {
        assert!(GeminiLlm::is_supported("gemini-2.5-flash"));
        assert!(GeminiLlm::is_supported("gemini-2.0-pro"));
        assert!(GeminiLlm::is_supported("gemini-1.5-pro-001"));
    }

    #[test]
    fn is_supported_non_gemini_models() {
        assert!(!GeminiLlm::is_supported("gpt-4"));
        assert!(!GeminiLlm::is_supported("claude-3-opus"));
        assert!(!GeminiLlm::is_supported("llama-3"));
    }

    #[test]
    fn is_supported_requires_anchored_match() {
        // The patterns are anchored, so "gemini" appearing mid-string or
        // without the trailing dash must not be accepted.
        assert!(!GeminiLlm::is_supported("my-gemini-2.5-flash"));
        assert!(!GeminiLlm::is_supported("gemini"));
        assert!(!GeminiLlm::is_supported("projects/my-project/models/claude-3"));
    }

    #[test]
    fn is_supported_vertex_ai_resource_paths() {
        assert!(GeminiLlm::is_supported(
            "projects/my-project/endpoints/12345"
        ));
        assert!(GeminiLlm::is_supported(
            "projects/my-project/models/gemini-2.5-flash"
        ));
    }

    #[test]
    fn model_id_returns_correct_string() {
        let llm = GeminiLlm::new(GeminiLlmParams {
            model: Some("gemini-2.5-flash-preview-04-17".into()),
            ..Default::default()
        });
        assert_eq!(llm.model_id(), "gemini-2.5-flash-preview-04-17");
    }

    #[test]
    fn base_llm_is_object_safe() {
        fn _assert_object_safe(_: &dyn BaseLlm) {}
        // Coercing a real instance proves both object safety and that
        // GeminiLlm actually implements BaseLlm.
        _assert_object_safe(&GeminiLlm::new(GeminiLlmParams::default()));
    }

    #[test]
    fn gemini_llm_is_send_sync() {
        fn _assert_send_sync<T: Send + Sync>() {}
        _assert_send_sync::<GeminiLlm>();
    }
}