gemini_genai_rs/transport/auth/mod.rs

1//! Authentication providers for Gemini API connections.
2//!
3//! This module defines the [`AuthProvider`] trait and built-in implementations
4//! for Google AI (API key and OAuth2 token) and Vertex AI (Bearer token).
5//!
6//! The [`ServiceEndpoint`] enum allows constructing URLs for both WebSocket (Live)
7//! and REST API endpoints from the same auth provider.
8
9pub mod google_ai;
10pub(crate) mod url_builders;
11pub mod vertex;
12
13pub use google_ai::*;
14pub use vertex::*;
15
16use async_trait::async_trait;
17
18use crate::protocol::types::GeminiModel;
19use crate::session::AuthError;
20
/// Identifies which Gemini API service to connect to.
///
/// Used by [`RestAuth::rest_url`] to construct the correct REST endpoint URL.
/// The paths in the variant docs are relative to the API base URL; each auth
/// provider maps them onto its own URL layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ServiceEndpoint {
    /// WebSocket Live/Bidi streaming endpoint.
    LiveWs,
    /// POST /models/{model}:generateContent
    GenerateContent,
    /// POST /models/{model}:streamGenerateContent
    StreamGenerateContent,
    /// POST /models/{model}:embedContent
    EmbedContent,
    /// POST /models/{model}:countTokens
    CountTokens,
    /// POST /models/{model}:computeTokens
    ComputeTokens,
    /// GET /models
    ListModels,
    /// GET /models/{model}
    GetModel,
    /// Files CRUD (upload, get, list, delete)
    Files,
    /// Cached content CRUD
    CachedContents,
    /// Tuning jobs CRUD
    TuningJobs,
    /// Batch jobs CRUD
    BatchJobs,
}

impl ServiceEndpoint {
    /// REST method suffix appended to the model path (e.g. `generateContent`
    /// in `models/{model}:generateContent`).
    ///
    /// Returns `None` for endpoints that are not addressed as a `:method`
    /// call on a model resource.
    #[must_use]
    pub fn model_method(&self) -> Option<&'static str> {
        // Deliberately exhaustive (no `_` arm): adding a variant must force an
        // explicit decision here instead of silently falling through to `None`.
        match self {
            Self::GenerateContent => Some("generateContent"),
            Self::StreamGenerateContent => Some("streamGenerateContent"),
            Self::EmbedContent => Some("embedContent"),
            Self::CountTokens => Some("countTokens"),
            Self::ComputeTokens => Some("computeTokens"),
            Self::LiveWs
            | Self::ListModels
            | Self::GetModel
            | Self::Files
            | Self::CachedContents
            | Self::TuningJobs
            | Self::BatchJobs => None,
        }
    }

    /// Whether this endpoint requires a model ID in the path.
    ///
    /// True for every endpoint that has a [`model_method`](Self::model_method)
    /// suffix, plus [`GetModel`](Self::GetModel), which addresses the model
    /// resource directly without a method suffix.
    #[must_use]
    pub fn requires_model(&self) -> bool {
        // Derived from `model_method` so the two lists cannot drift apart.
        self.model_method().is_some() || matches!(self, Self::GetModel)
    }
}
79
80/// Provides authentication credentials and URL construction for Gemini API connections.
81#[async_trait]
82pub trait AuthProvider: Send + Sync + 'static {
83    /// Build the WebSocket URL for the given model.
84    fn ws_url(&self, model: &GeminiModel) -> String;
85
86    /// HTTP headers for the WebSocket upgrade request (e.g., Bearer token).
87    async fn auth_headers(&self) -> Result<Vec<(String, String)>, AuthError>;
88
89    /// Query parameters to append to the URL (e.g., API key).
90    fn query_params(&self) -> Vec<(String, String)> {
91        vec![]
92    }
93
94    /// Called on auth failure to allow token refresh. Default: no-op.
95    async fn refresh(&self) -> Result<(), AuthError> {
96        Ok(())
97    }
98}
99
100/// Auth providers that additionally support REST endpoint URL construction.
101///
102/// Split from [`AuthProvider`] so that Live-only providers are not forced to
103/// implement REST URL building, and so the REST [`Client`](crate::client::Client)
104/// can require it at the type level — replacing the previous runtime
105/// `unimplemented!()` default with a compile-time guarantee.
106pub trait RestAuth: AuthProvider {
107    /// Build a REST API URL for the given service endpoint and model.
108    fn rest_url(&self, endpoint: ServiceEndpoint, model: Option<&GeminiModel>) -> String;
109}
110
111// ---------------------------------------------------------------------------
112// Tests
113// ---------------------------------------------------------------------------
114
#[cfg(test)]
mod tests {
    // Unit tests for the built-in auth providers and `ServiceEndpoint` helpers.

    use super::*;
    use crate::protocol::types::GeminiModel;

    // -----------------------------------------------------------------------
    // Live (WebSocket) URL and credential tests
    // -----------------------------------------------------------------------

    #[test]
    fn google_ai_auth_url() {
        let auth = GoogleAIAuth::new("test-key-123");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("v1beta"));
        assert!(url.contains("key=test-key-123"));
    }

    #[test]
    fn google_ai_auth_query_params() {
        let auth = GoogleAIAuth::new("my-api-key");
        assert_eq!(
            auth.query_params(),
            vec![("key".to_string(), "my-api-key".to_string())]
        );
    }

    #[tokio::test]
    async fn google_ai_auth_headers_empty() {
        let auth = GoogleAIAuth::new("test-key");
        let headers = auth.auth_headers().await.unwrap();
        assert!(headers.is_empty());
    }

    #[test]
    fn google_ai_token_auth_url() {
        let auth = GoogleAITokenAuth::new("oauth2-token-abc");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("access_token=oauth2-token-abc"));
        assert!(url.contains("v1alpha"));
    }

    #[test]
    fn vertex_ai_auth_url_regional() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("us-central1-aiplatform.googleapis.com"));
        assert!(url.contains("v1beta1"));
        assert!(url.contains("x-goog-project-id=my-project"));
    }

    #[test]
    fn vertex_ai_auth_url_global() {
        let auth = VertexAIAuth::new("my-project", "global", "token");
        let url = auth.ws_url(&GeminiModel::default());
        // Global uses aiplatform.googleapis.com without location prefix.
        assert!(url.starts_with("wss://aiplatform.googleapis.com/"));
        assert!(!url.contains("global-aiplatform"));
    }

    #[tokio::test]
    async fn vertex_ai_auth_headers() {
        let auth = VertexAIAuth::new("proj", "us-central1", "my-bearer-token");
        let headers = auth.auth_headers().await.unwrap();
        assert_eq!(
            headers,
            vec![(
                "Authorization".to_string(),
                "Bearer my-bearer-token".to_string()
            )]
        );
    }

    #[test]
    fn vertex_ai_auth_url_contains_model() {
        let auth = VertexAIAuth::new("proj", "us-central1", "tok");
        let url = auth.ws_url(&GeminiModel::Gemini2_0FlashLive);
        assert!(url.contains("model=gemini-2.0-flash-live-001"));
    }

    #[test]
    fn auth_provider_is_object_safe() {
        // These only compile if both traits can be used as trait objects.
        fn _assert_live(_: &dyn AuthProvider) {}
        fn _assert_rest(_: &dyn RestAuth) {}
    }

    #[tokio::test]
    async fn default_refresh_is_noop() {
        let auth = GoogleAIAuth::new("key");
        // Should succeed without error.
        auth.refresh().await.unwrap();
    }

    // `query_params` is synchronous, so no async runtime is needed here.
    #[test]
    fn default_query_params_empty_for_vertex() {
        let auth = VertexAIAuth::new("proj", "loc", "tok");
        assert!(auth.query_params().is_empty());
    }

    // -----------------------------------------------------------------------
    // REST URL tests
    // -----------------------------------------------------------------------

    #[test]
    fn google_ai_rest_url_generate_content() {
        let auth = GoogleAIAuth::new("test-key");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://generativelanguage.googleapis.com/v1beta/"));
        assert!(url.contains(":generateContent"));
        assert!(url.contains("key=test-key"));
    }

    #[test]
    fn google_ai_rest_url_list_models() {
        let auth = GoogleAIAuth::new("key123");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("/models?key=key123"));
    }

    #[test]
    fn google_ai_rest_url_files() {
        let auth = GoogleAIAuth::new("key");
        let url = auth.rest_url(ServiceEndpoint::Files, None);
        assert!(url.contains("/files?key=key"));
    }

    #[test]
    fn google_ai_token_rest_url_no_key_in_url() {
        let auth = GoogleAITokenAuth::new("oauth-token");
        let url = auth.rest_url(ServiceEndpoint::CountTokens, Some(&GeminiModel::default()));
        assert!(url.contains(":countTokens"));
        assert!(!url.contains("key="));
        assert!(!url.contains("access_token="));
    }

    #[test]
    fn vertex_rest_url_generate_content() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://us-central1-aiplatform.googleapis.com/v1beta1/"));
        assert!(url.contains("projects/my-project/locations/us-central1"));
        assert!(url.contains(":generateContent"));
    }

    #[test]
    fn vertex_rest_url_list_models() {
        let auth = VertexAIAuth::new("proj", "us-east1", "tok");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("publishers/google/models"));
    }

    #[test]
    fn vertex_rest_url_global() {
        let auth = VertexAIAuth::new("proj", "global", "tok");
        let model = GeminiModel::default();
        let url = auth.rest_url(ServiceEndpoint::EmbedContent, Some(&model));
        assert!(url.starts_with("https://aiplatform.googleapis.com/"));
        assert!(!url.contains("global-aiplatform"));
        assert!(url.contains(":embedContent"));
    }

    // -----------------------------------------------------------------------
    // ServiceEndpoint helper tests
    // -----------------------------------------------------------------------

    #[test]
    fn service_endpoint_model_method() {
        let expected = [
            (ServiceEndpoint::GenerateContent, Some("generateContent")),
            (ServiceEndpoint::StreamGenerateContent, Some("streamGenerateContent")),
            (ServiceEndpoint::EmbedContent, Some("embedContent")),
            (ServiceEndpoint::CountTokens, Some("countTokens")),
            (ServiceEndpoint::ComputeTokens, Some("computeTokens")),
            (ServiceEndpoint::LiveWs, None),
            (ServiceEndpoint::ListModels, None),
            (ServiceEndpoint::GetModel, None),
            (ServiceEndpoint::Files, None),
            (ServiceEndpoint::CachedContents, None),
            (ServiceEndpoint::TuningJobs, None),
            (ServiceEndpoint::BatchJobs, None),
        ];
        for (endpoint, method) in expected {
            assert_eq!(endpoint.model_method(), method, "{:?}", endpoint);
        }
    }

    #[tokio::test]
    async fn vertex_ai_refreshable_token() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;

        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        let auth = VertexAIAuth::with_token_refresher("proj", "us-central1", move || {
            // Use fetch_add's return value so the increment and the read are a
            // single atomic step rather than a separate add followed by a load.
            let n = c.fetch_add(1, Ordering::SeqCst) + 1;
            format!("token-{}", n)
        });
        let h1 = auth.auth_headers().await.unwrap();
        assert!(h1[0].1.starts_with("Bearer token-"));
        let h2 = auth.auth_headers().await.unwrap();
        assert!(h2[0].1.starts_with("Bearer token-"));
        // Refresher called twice (once per auth_headers)
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn service_endpoint_requires_model() {
        for endpoint in [
            ServiceEndpoint::GenerateContent,
            ServiceEndpoint::StreamGenerateContent,
            ServiceEndpoint::EmbedContent,
            ServiceEndpoint::CountTokens,
            ServiceEndpoint::ComputeTokens,
            ServiceEndpoint::GetModel,
        ] {
            assert!(endpoint.requires_model(), "{:?}", endpoint);
        }
        for endpoint in [
            ServiceEndpoint::LiveWs,
            ServiceEndpoint::ListModels,
            ServiceEndpoint::Files,
            ServiceEndpoint::CachedContents,
            ServiceEndpoint::TuningJobs,
            ServiceEndpoint::BatchJobs,
        ] {
            assert!(!endpoint.requires_model(), "{:?}", endpoint);
        }
    }
}