gemini_adk_rs/llm/
mod.rs

1//! LLM abstraction — decouples agents from specific model providers.
2//!
3//! The `BaseLlm` trait provides a unified interface for generating content
4//! from any LLM. The `GeminiLlm` implementation wraps gemini-live's `Client`
5//! for Gemini models.
6
7pub mod gemini;
8pub mod registry;
9
10pub use gemini::{GeminiLlm, GeminiLlmParams};
11pub use registry::LlmRegistry;
12
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15
16use gemini_genai_rs::prelude::{Content, Part, Tool};
17
/// Source of access tokens used to authenticate VertexAI requests.
///
/// Implement this trait to plug in tokens that are refreshed dynamically.
/// [`EnvTokenProvider`] is the default and reads `GOOGLE_ACCESS_TOKEN`
/// from the environment.
pub trait TokenProvider: Send + Sync {
    /// Produce an access token that is currently valid.
    ///
    /// Invoked ahead of every `generate()` request made through the
    /// VertexAI variant.
    fn token(&self) -> String;
}
27
/// Default token provider — reads `GOOGLE_ACCESS_TOKEN` from the environment.
///
/// Stateless: every `token()` call re-reads the variable, so environment
/// changes are picked up immediately. An unset (or non-UTF-8) variable
/// yields an empty string.
#[derive(Debug, Clone, Copy, Default)]
pub struct EnvTokenProvider;
30
31impl TokenProvider for EnvTokenProvider {
32    fn token(&self) -> String {
33        std::env::var("GOOGLE_ACCESS_TOKEN").unwrap_or_default()
34    }
35}
36
37/// Token provider that shells out to `gcloud auth print-access-token`,
38/// caching the result with a configurable TTL.
39pub struct GcloudTokenProvider {
40    cache: parking_lot::Mutex<(String, std::time::Instant)>,
41    ttl: std::time::Duration,
42}
43
44impl GcloudTokenProvider {
45    /// Create a new provider with the given cache TTL (recommended: 45 minutes).
46    pub fn new(ttl: std::time::Duration) -> Self {
47        Self {
48            cache: parking_lot::Mutex::new((String::new(), std::time::Instant::now())),
49            ttl,
50        }
51    }
52}
53
54impl TokenProvider for GcloudTokenProvider {
55    fn token(&self) -> String {
56        let mut guard = self.cache.lock();
57        let (ref mut cached_token, ref mut fetched_at) = *guard;
58        if !cached_token.is_empty() && fetched_at.elapsed() < self.ttl {
59            return cached_token.clone();
60        }
61        // Shell out to gcloud
62        match std::process::Command::new("gcloud")
63            .args(["auth", "print-access-token"])
64            .output()
65        {
66            Ok(output) if output.status.success() => {
67                let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
68                *cached_token = token.clone();
69                *fetched_at = std::time::Instant::now();
70                token
71            }
72            _ => {
73                // Fall back to env var
74                std::env::var("GOOGLE_ACCESS_TOKEN").unwrap_or_default()
75            }
76        }
77    }
78}
79
80/// Configuration for an LLM generation request.
81#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82pub struct LlmRequest {
83    /// The messages/contents to send.
84    pub contents: Vec<Content>,
85    /// System instruction.
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub system_instruction: Option<String>,
88    /// Available tools.
89    #[serde(skip_serializing_if = "Vec::is_empty", default)]
90    pub tools: Vec<Tool>,
91    /// Temperature for generation.
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub temperature: Option<f32>,
94    /// Maximum output tokens.
95    #[serde(skip_serializing_if = "Option::is_none")]
96    pub max_output_tokens: Option<u32>,
97    /// MIME type for structured output (e.g., `"application/json"`).
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub response_mime_type: Option<String>,
100    /// JSON Schema for structured output. Requires `response_mime_type = "application/json"`.
101    #[serde(skip_serializing_if = "Option::is_none")]
102    pub response_json_schema: Option<serde_json::Value>,
103}
104
105impl LlmRequest {
106    /// Create a request from a single user message.
107    pub fn from_text(text: impl Into<String>) -> Self {
108        Self {
109            contents: vec![Content {
110                role: Some(gemini_genai_rs::prelude::Role::User),
111                parts: vec![Part::Text { text: text.into() }],
112            }],
113            ..Default::default()
114        }
115    }
116
117    /// Create a request from existing contents.
118    pub fn from_contents(contents: Vec<Content>) -> Self {
119        Self {
120            contents,
121            ..Default::default()
122        }
123    }
124}
125
126/// The response from an LLM generation request.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct LlmResponse {
129    /// The generated content.
130    pub content: Content,
131    /// Finish reason (if available).
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub finish_reason: Option<String>,
134    /// Token usage (if available).
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub usage: Option<TokenUsage>,
137}
138
139impl LlmResponse {
140    /// Extract text from the response, concatenating all text parts.
141    pub fn text(&self) -> String {
142        self.content
143            .parts
144            .iter()
145            .filter_map(|p| match p {
146                Part::Text { text } => Some(text.as_str()),
147                _ => None,
148            })
149            .collect::<Vec<_>>()
150            .join("")
151    }
152
153    /// Extract function calls from the response.
154    pub fn function_calls(&self) -> Vec<&gemini_genai_rs::prelude::FunctionCall> {
155        self.content
156            .parts
157            .iter()
158            .filter_map(|p| match p {
159                Part::FunctionCall { function_call } => Some(function_call),
160                _ => None,
161            })
162            .collect()
163    }
164}
165
166/// Token usage statistics.
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct TokenUsage {
169    /// Input/prompt tokens.
170    pub prompt_tokens: u32,
171    /// Output/completion tokens.
172    pub completion_tokens: u32,
173    /// Total tokens.
174    pub total_tokens: u32,
175}
176
177/// Errors from LLM operations.
178#[derive(Debug, thiserror::Error)]
179pub enum LlmError {
180    /// The HTTP request to the LLM API failed.
181    #[error("LLM request failed: {0}")]
182    RequestFailed(String),
183    /// The requested model is not available.
184    #[error("Model not available: {0}")]
185    ModelNotAvailable(String),
186    /// The request was rate-limited by the provider.
187    #[error("Rate limited")]
188    RateLimited,
189    /// The response was filtered by content safety.
190    #[error("Content filtered")]
191    ContentFiltered,
192    /// A catch-all for other LLM errors.
193    #[error("{0}")]
194    Other(String),
195}
196
197/// Trait for LLM providers — decouples agents from specific models.
198///
199/// Implementations must be `Send + Sync` for use across async tasks.
200#[async_trait]
201pub trait BaseLlm: Send + Sync {
202    /// The model identifier (e.g., "gemini-2.5-flash").
203    fn model_id(&self) -> &str;
204
205    /// Generate content from the LLM.
206    async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError>;
207
208    /// Pre-warm the HTTP connection pool to avoid cold-start latency.
209    ///
210    /// The default implementation is a no-op. `GeminiLlm` overrides this to
211    /// establish the TCP+TLS connection so the first real `generate()` call
212    /// doesn't pay the ~100-300ms handshake penalty.
213    async fn warm_up(&self) -> Result<(), LlmError> {
214        Ok(())
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use gemini_genai_rs::prelude::{FunctionCall, Role};

    /// Shorthand for building a single text part.
    fn text_part(s: &str) -> Part {
        Part::Text { text: s.into() }
    }

    #[test]
    fn llm_request_from_text() {
        let request = LlmRequest::from_text("Hello!");
        assert_eq!(request.contents.len(), 1);
        assert!(request.system_instruction.is_none());
        assert!(request.tools.is_empty());
    }

    #[test]
    fn llm_request_from_contents() {
        let single = Content {
            role: Some(Role::User),
            parts: vec![text_part("Hello")],
        };
        let request = LlmRequest::from_contents(vec![single]);
        assert_eq!(request.contents.len(), 1);
    }

    #[test]
    fn llm_response_text() {
        let response = LlmResponse {
            content: Content {
                role: Some(Role::Model),
                parts: vec![text_part("Hello "), text_part("world!")],
            },
            finish_reason: Some("STOP".into()),
            usage: None,
        };
        assert_eq!(response.text(), "Hello world!");
    }

    #[test]
    fn llm_response_function_calls() {
        let call = FunctionCall {
            name: "get_weather".into(),
            args: serde_json::json!({"city": "London"}),
            id: None,
        };
        let response = LlmResponse {
            content: Content {
                role: Some(Role::Model),
                parts: vec![Part::FunctionCall {
                    function_call: call,
                }],
            },
            finish_reason: None,
            usage: None,
        };
        let calls = response.function_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
    }

    #[test]
    fn base_llm_is_object_safe() {
        // Compiles only if `BaseLlm` can be used as a trait object.
        fn _accepts_trait_object(_: &dyn BaseLlm) {}
    }

    #[test]
    fn token_usage() {
        let usage = TokenUsage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
        };
        assert_eq!(usage.total_tokens, 30);
    }
}