gemini_genai_rs/transport/auth/
mod.rs

1//! Authentication providers for Gemini API connections.
2//!
3//! This module defines the [`AuthProvider`] trait and built-in implementations
4//! for Google AI (API key and OAuth2 token) and Vertex AI (Bearer token).
5//!
6//! The [`ServiceEndpoint`] enum allows constructing URLs for both WebSocket (Live)
7//! and REST API endpoints from the same auth provider.
8
9pub mod google_ai;
10pub(crate) mod url_builders;
11pub mod vertex;
12
13pub use google_ai::*;
14pub use vertex::*;
15
16use async_trait::async_trait;
17
18use crate::protocol::types::GeminiModel;
19use crate::session::AuthError;
20
/// Selects which Gemini API service a URL should target.
///
/// Handed to [`AuthProvider::rest_url`] to pick the REST endpoint to build.
/// The paths listed on each variant are relative to the provider-specific API
/// base; Google AI and Vertex AI prefix them differently.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ServiceEndpoint {
    /// Live (bidirectional streaming) service over WebSocket.
    LiveWs,
    /// `POST models/{model}:generateContent`
    GenerateContent,
    /// `POST models/{model}:streamGenerateContent`
    StreamGenerateContent,
    /// `POST models/{model}:embedContent`
    EmbedContent,
    /// `POST models/{model}:countTokens`
    CountTokens,
    /// `POST models/{model}:computeTokens`
    ComputeTokens,
    /// `GET models` — enumerate available models.
    ListModels,
    /// `GET models/{model}` — fetch a single model's metadata.
    GetModel,
    /// File resources: upload, get, list and delete.
    Files,
    /// Cached-content resources: create, read, update and delete.
    CachedContents,
    /// Tuning-job resources: create, read, update and delete.
    TuningJobs,
    /// Batch-job resources: create, read, update and delete.
    BatchJobs,
}
51
52impl ServiceEndpoint {
53    /// REST method suffix appended to the model path (e.g., `:generateContent`).
54    /// Returns `None` for endpoints that don't use a model suffix.
55    pub fn model_method(&self) -> Option<&'static str> {
56        match self {
57            Self::GenerateContent => Some("generateContent"),
58            Self::StreamGenerateContent => Some("streamGenerateContent"),
59            Self::EmbedContent => Some("embedContent"),
60            Self::CountTokens => Some("countTokens"),
61            Self::ComputeTokens => Some("computeTokens"),
62            _ => None,
63        }
64    }
65
66    /// Whether this endpoint requires a model ID in the path.
67    pub fn requires_model(&self) -> bool {
68        matches!(
69            self,
70            Self::GenerateContent
71                | Self::StreamGenerateContent
72                | Self::EmbedContent
73                | Self::CountTokens
74                | Self::ComputeTokens
75                | Self::GetModel
76        )
77    }
78}
79
80/// Provides authentication credentials and URL construction for Gemini API connections.
81#[async_trait]
82pub trait AuthProvider: Send + Sync + 'static {
83    /// Build the WebSocket URL for the given model.
84    fn ws_url(&self, model: &GeminiModel) -> String;
85
86    /// Build a REST API URL for the given service endpoint and model.
87    ///
88    /// Default implementation panics — override when using HTTP client features.
89    fn rest_url(&self, endpoint: ServiceEndpoint, model: Option<&GeminiModel>) -> String {
90        let _ = (endpoint, model);
91        unimplemented!("REST URLs require a concrete auth provider (GoogleAIAuth or VertexAIAuth)")
92    }
93
94    /// HTTP headers for the WebSocket upgrade request (e.g., Bearer token).
95    async fn auth_headers(&self) -> Result<Vec<(String, String)>, AuthError>;
96
97    /// Query parameters to append to the URL (e.g., API key).
98    fn query_params(&self) -> Vec<(String, String)> {
99        vec![]
100    }
101
102    /// Called on auth failure to allow token refresh. Default: no-op.
103    async fn refresh(&self) -> Result<(), AuthError> {
104        Ok(())
105    }
106}
107
108// ---------------------------------------------------------------------------
109// Tests
110// ---------------------------------------------------------------------------
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::GeminiModel;

    /// Every [`ServiceEndpoint`] variant, used to check per-variant invariants.
    /// Extend this list when a variant is added.
    const ALL_ENDPOINTS: [ServiceEndpoint; 12] = [
        ServiceEndpoint::LiveWs,
        ServiceEndpoint::GenerateContent,
        ServiceEndpoint::StreamGenerateContent,
        ServiceEndpoint::EmbedContent,
        ServiceEndpoint::CountTokens,
        ServiceEndpoint::ComputeTokens,
        ServiceEndpoint::ListModels,
        ServiceEndpoint::GetModel,
        ServiceEndpoint::Files,
        ServiceEndpoint::CachedContents,
        ServiceEndpoint::TuningJobs,
        ServiceEndpoint::BatchJobs,
    ];

    #[test]
    fn google_ai_auth_url() {
        let auth = GoogleAIAuth::new("test-key-123");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("v1beta"));
        assert!(url.contains("key=test-key-123"));
    }

    #[test]
    fn google_ai_auth_query_params() {
        let auth = GoogleAIAuth::new("my-api-key");
        let params = auth.query_params();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0, "key");
        assert_eq!(params[0].1, "my-api-key");
    }

    #[tokio::test]
    async fn google_ai_auth_headers_empty() {
        let auth = GoogleAIAuth::new("test-key");
        let headers = auth.auth_headers().await.unwrap();
        assert!(headers.is_empty());
    }

    #[test]
    fn google_ai_token_auth_url() {
        let auth = GoogleAITokenAuth::new("oauth2-token-abc");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("access_token=oauth2-token-abc"));
        assert!(url.contains("v1alpha"));
    }

    #[test]
    fn vertex_ai_auth_url_regional() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("us-central1-aiplatform.googleapis.com"));
        assert!(url.contains("v1beta1"));
        assert!(url.contains("x-goog-project-id=my-project"));
    }

    #[test]
    fn vertex_ai_auth_url_global() {
        let auth = VertexAIAuth::new("my-project", "global", "token");
        let url = auth.ws_url(&GeminiModel::default());
        // Global uses aiplatform.googleapis.com without location prefix.
        assert!(url.starts_with("wss://aiplatform.googleapis.com/"));
        assert!(!url.contains("global-aiplatform"));
    }

    #[tokio::test]
    async fn vertex_ai_auth_headers() {
        let auth = VertexAIAuth::new("proj", "us-central1", "my-bearer-token");
        let headers = auth.auth_headers().await.unwrap();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Authorization");
        assert_eq!(headers[0].1, "Bearer my-bearer-token");
    }

    #[test]
    fn vertex_ai_auth_url_contains_model() {
        let auth = VertexAIAuth::new("proj", "us-central1", "tok");
        let url = auth.ws_url(&GeminiModel::Gemini2_0FlashLive);
        assert!(url.contains("model=gemini-2.0-flash-live-001"));
    }

    #[test]
    fn auth_provider_is_object_safe() {
        // Compiles only if `AuthProvider` can be used as a trait object.
        fn _assert(_: &dyn AuthProvider) {}
    }

    #[tokio::test]
    async fn default_refresh_is_noop() {
        let auth = GoogleAIAuth::new("key");
        // Should succeed without error.
        auth.refresh().await.unwrap();
    }

    #[test]
    fn default_query_params_empty_for_vertex() {
        // `query_params` is synchronous, so no async runtime is needed.
        let auth = VertexAIAuth::new("proj", "loc", "tok");
        let params = auth.query_params();
        assert!(params.is_empty());
    }

    // -----------------------------------------------------------------------
    // REST URL tests
    // -----------------------------------------------------------------------

    #[test]
    fn google_ai_rest_url_generate_content() {
        let auth = GoogleAIAuth::new("test-key");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://generativelanguage.googleapis.com/v1beta/"));
        assert!(url.contains(":generateContent"));
        assert!(url.contains("key=test-key"));
    }

    #[test]
    fn google_ai_rest_url_list_models() {
        let auth = GoogleAIAuth::new("key123");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("/models?key=key123"));
    }

    #[test]
    fn google_ai_rest_url_files() {
        let auth = GoogleAIAuth::new("key");
        let url = auth.rest_url(ServiceEndpoint::Files, None);
        assert!(url.contains("/files?key=key"));
    }

    #[test]
    fn google_ai_token_rest_url_no_key_in_url() {
        let auth = GoogleAITokenAuth::new("oauth-token");
        let url = auth.rest_url(ServiceEndpoint::CountTokens, Some(&GeminiModel::default()));
        assert!(url.contains(":countTokens"));
        assert!(!url.contains("key="));
        assert!(!url.contains("access_token="));
    }

    #[test]
    fn vertex_rest_url_generate_content() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://us-central1-aiplatform.googleapis.com/v1beta1/"));
        assert!(url.contains("projects/my-project/locations/us-central1"));
        assert!(url.contains(":generateContent"));
    }

    #[test]
    fn vertex_rest_url_list_models() {
        let auth = VertexAIAuth::new("proj", "us-east1", "tok");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("publishers/google/models"));
    }

    #[test]
    fn vertex_rest_url_global() {
        let auth = VertexAIAuth::new("proj", "global", "tok");
        let model = GeminiModel::default();
        let url = auth.rest_url(ServiceEndpoint::EmbedContent, Some(&model));
        assert!(url.starts_with("https://aiplatform.googleapis.com/"));
        assert!(!url.contains("global-aiplatform"));
        assert!(url.contains(":embedContent"));
    }

    #[test]
    fn service_endpoint_model_method() {
        assert_eq!(
            ServiceEndpoint::GenerateContent.model_method(),
            Some("generateContent")
        );
        assert_eq!(
            ServiceEndpoint::StreamGenerateContent.model_method(),
            Some("streamGenerateContent")
        );
        assert_eq!(ServiceEndpoint::EmbedContent.model_method(), Some("embedContent"));
        assert_eq!(ServiceEndpoint::CountTokens.model_method(), Some("countTokens"));
        assert_eq!(ServiceEndpoint::ComputeTokens.model_method(), Some("computeTokens"));
        assert_eq!(ServiceEndpoint::ListModels.model_method(), None);
        assert_eq!(ServiceEndpoint::Files.model_method(), None);
        assert_eq!(ServiceEndpoint::GetModel.model_method(), None);
        assert_eq!(ServiceEndpoint::LiveWs.model_method(), None);
    }

    #[tokio::test]
    async fn vertex_ai_refreshable_token() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;

        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        let auth = VertexAIAuth::with_token_refresher("proj", "us-central1", move || {
            // Use the value returned by `fetch_add` instead of re-loading the
            // counter, so each call sees exactly its own increment.
            let n = c.fetch_add(1, Ordering::SeqCst) + 1;
            format!("token-{}", n)
        });
        let h1 = auth.auth_headers().await.unwrap();
        assert_eq!(h1[0].1, "Bearer token-1");
        let h2 = auth.auth_headers().await.unwrap();
        assert_eq!(h2[0].1, "Bearer token-2");
        // Refresher called twice (once per auth_headers)
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn service_endpoint_requires_model() {
        assert!(ServiceEndpoint::GenerateContent.requires_model());
        assert!(ServiceEndpoint::CountTokens.requires_model());
        assert!(ServiceEndpoint::GetModel.requires_model());
        assert!(!ServiceEndpoint::ListModels.requires_model());
        assert!(!ServiceEndpoint::Files.requires_model());
        assert!(!ServiceEndpoint::LiveWs.requires_model());
    }

    #[test]
    fn service_endpoint_model_method_implies_requires_model() {
        // A model-level method is always addressed as `models/{model}:method`,
        // so it can never be built without a model. `GetModel` is the only
        // model-addressed endpoint without a method suffix.
        for ep in ALL_ENDPOINTS.iter() {
            let expected = ep.model_method().is_some() || *ep == ServiceEndpoint::GetModel;
            assert_eq!(ep.requires_model(), expected, "{:?}", ep);
        }
    }
}