1pub mod gemini;
8pub mod registry;
9
10pub use gemini::{GeminiLlm, GeminiLlmParams};
11pub use registry::LlmRegistry;
12
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15
16use gemini_genai_rs::prelude::{Content, Part, Tool};
17
/// A source of access tokens.
///
/// The `Send + Sync` supertraits allow a provider to be shared between
/// threads, for example behind an `Arc<dyn TokenProvider>`.
pub trait TokenProvider: Send + Sync {
    /// Returns the current access token as an owned string.
    ///
    /// The method is infallible by signature; the implementations in this
    /// module return an empty string when no token can be obtained.
    fn token(&self) -> String;
}
27
28pub struct EnvTokenProvider;
30
31impl TokenProvider for EnvTokenProvider {
32 fn token(&self) -> String {
33 std::env::var("GOOGLE_ACCESS_TOKEN").unwrap_or_default()
34 }
35}
36
37pub struct GcloudTokenProvider {
40 cache: parking_lot::Mutex<(String, std::time::Instant)>,
41 ttl: std::time::Duration,
42}
43
44impl GcloudTokenProvider {
45 pub fn new(ttl: std::time::Duration) -> Self {
47 Self {
48 cache: parking_lot::Mutex::new((String::new(), std::time::Instant::now())),
49 ttl,
50 }
51 }
52}
53
54impl TokenProvider for GcloudTokenProvider {
55 fn token(&self) -> String {
56 let mut guard = self.cache.lock();
57 let (ref mut cached_token, ref mut fetched_at) = *guard;
58 if !cached_token.is_empty() && fetched_at.elapsed() < self.ttl {
59 return cached_token.clone();
60 }
61 match std::process::Command::new("gcloud")
63 .args(["auth", "print-access-token"])
64 .output()
65 {
66 Ok(output) if output.status.success() => {
67 let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
68 *cached_token = token.clone();
69 *fetched_at = std::time::Instant::now();
70 token
71 }
72 _ => {
73 std::env::var("GOOGLE_ACCESS_TOKEN").unwrap_or_default()
75 }
76 }
77 }
78}
79
80#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82pub struct LlmRequest {
83 pub contents: Vec<Content>,
85 #[serde(skip_serializing_if = "Option::is_none")]
87 pub system_instruction: Option<String>,
88 #[serde(skip_serializing_if = "Vec::is_empty", default)]
90 pub tools: Vec<Tool>,
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub temperature: Option<f32>,
94 #[serde(skip_serializing_if = "Option::is_none")]
96 pub max_output_tokens: Option<u32>,
97 #[serde(skip_serializing_if = "Option::is_none")]
99 pub response_mime_type: Option<String>,
100 #[serde(skip_serializing_if = "Option::is_none")]
102 pub response_json_schema: Option<serde_json::Value>,
103}
104
105impl LlmRequest {
106 pub fn from_text(text: impl Into<String>) -> Self {
108 Self {
109 contents: vec![Content {
110 role: Some(gemini_genai_rs::prelude::Role::User),
111 parts: vec![Part::Text { text: text.into() }],
112 }],
113 ..Default::default()
114 }
115 }
116
117 pub fn from_contents(contents: Vec<Content>) -> Self {
119 Self {
120 contents,
121 ..Default::default()
122 }
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct LlmResponse {
129 pub content: Content,
131 #[serde(skip_serializing_if = "Option::is_none")]
133 pub finish_reason: Option<String>,
134 #[serde(skip_serializing_if = "Option::is_none")]
136 pub usage: Option<TokenUsage>,
137}
138
139impl LlmResponse {
140 pub fn text(&self) -> String {
142 self.content
143 .parts
144 .iter()
145 .filter_map(|p| match p {
146 Part::Text { text } => Some(text.as_str()),
147 _ => None,
148 })
149 .collect::<Vec<_>>()
150 .join("")
151 }
152
153 pub fn function_calls(&self) -> Vec<&gemini_genai_rs::prelude::FunctionCall> {
155 self.content
156 .parts
157 .iter()
158 .filter_map(|p| match p {
159 Part::FunctionCall { function_call } => Some(function_call),
160 _ => None,
161 })
162 .collect()
163 }
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct TokenUsage {
169 pub prompt_tokens: u32,
171 pub completion_tokens: u32,
173 pub total_tokens: u32,
175}
176
177#[derive(Debug, thiserror::Error)]
179pub enum LlmError {
180 #[error("LLM request failed: {0}")]
182 RequestFailed(String),
183 #[error("Model not available: {0}")]
185 ModelNotAvailable(String),
186 #[error("Rate limited")]
188 RateLimited,
189 #[error("Content filtered")]
191 ContentFiltered,
192 #[error("{0}")]
194 Other(String),
195}
196
197#[async_trait]
201pub trait BaseLlm: Send + Sync {
202 fn model_id(&self) -> &str;
204
205 async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError>;
207
208 async fn warm_up(&self) -> Result<(), LlmError> {
214 Ok(())
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use gemini_genai_rs::prelude::{FunctionCall, Role};

    /// Builds a text part from a string slice.
    fn text_part(text: &str) -> Part {
        Part::Text {
            text: text.to_string(),
        }
    }

    /// Builds a model-role response from `parts` with no usage data.
    fn model_response(parts: Vec<Part>, finish_reason: Option<&str>) -> LlmResponse {
        LlmResponse {
            content: Content {
                role: Some(Role::Model),
                parts,
            },
            finish_reason: finish_reason.map(str::to_string),
            usage: None,
        }
    }

    #[test]
    fn llm_request_from_text() {
        let request = LlmRequest::from_text("Hello!");

        assert_eq!(request.contents.len(), 1);
        assert!(request.system_instruction.is_none());
        assert!(request.tools.is_empty());
    }

    #[test]
    fn llm_request_from_contents() {
        let user_turn = Content {
            role: Some(Role::User),
            parts: vec![text_part("Hello")],
        };

        let request = LlmRequest::from_contents(vec![user_turn]);

        assert_eq!(request.contents.len(), 1);
    }

    #[test]
    fn llm_response_text() {
        let response = model_response(
            vec![text_part("Hello "), text_part("world!")],
            Some("STOP"),
        );

        assert_eq!(response.text(), "Hello world!");
    }

    #[test]
    fn llm_response_function_calls() {
        let weather_call = Part::FunctionCall {
            function_call: FunctionCall {
                name: "get_weather".into(),
                args: serde_json::json!({"city": "London"}),
                id: None,
            },
        };
        let response = model_response(vec![weather_call], None);

        let calls = response.function_calls();

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
    }

    #[test]
    fn base_llm_is_object_safe() {
        // Compiles only if `BaseLlm` can be used as a trait object.
        fn _assert(_: &dyn BaseLlm) {}
    }

    #[test]
    fn token_usage() {
        let usage = TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
        };

        assert_eq!(usage.total_tokens, 30);
    }
}