gemini_genai_rs/transport/auth/
mod.rs1pub mod google_ai;
10pub(crate) mod url_builders;
11pub mod vertex;
12
13pub use google_ai::*;
14pub use vertex::*;
15
16use async_trait::async_trait;
17
18use crate::protocol::types::GeminiModel;
19use crate::session::AuthError;
20
/// Identifies which Gemini service a URL is being built for.
///
/// Endpoints that are invoked as a method on a model resource map to that
/// method's name via [`ServiceEndpoint::model_method`]; whether an endpoint
/// needs a model at all is reported by [`ServiceEndpoint::requires_model`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ServiceEndpoint {
    /// Live WebSocket session endpoint.
    LiveWs,
    /// The `generateContent` model method.
    GenerateContent,
    /// The `streamGenerateContent` model method.
    StreamGenerateContent,
    /// The `embedContent` model method.
    EmbedContent,
    /// The `countTokens` model method.
    CountTokens,
    /// The `computeTokens` model method.
    ComputeTokens,
    /// Listing of the models collection.
    ListModels,
    /// Lookup of a single model.
    GetModel,
    /// The files collection.
    Files,
    /// The cached-contents collection.
    CachedContents,
    /// The tuning-jobs collection.
    TuningJobs,
    /// The batch-jobs collection.
    BatchJobs,
}
51
52impl ServiceEndpoint {
53 pub fn model_method(&self) -> Option<&'static str> {
56 match self {
57 Self::GenerateContent => Some("generateContent"),
58 Self::StreamGenerateContent => Some("streamGenerateContent"),
59 Self::EmbedContent => Some("embedContent"),
60 Self::CountTokens => Some("countTokens"),
61 Self::ComputeTokens => Some("computeTokens"),
62 _ => None,
63 }
64 }
65
66 pub fn requires_model(&self) -> bool {
68 matches!(
69 self,
70 Self::GenerateContent
71 | Self::StreamGenerateContent
72 | Self::EmbedContent
73 | Self::CountTokens
74 | Self::ComputeTokens
75 | Self::GetModel
76 )
77 }
78}
79
80#[async_trait]
82pub trait AuthProvider: Send + Sync + 'static {
83 fn ws_url(&self, model: &GeminiModel) -> String;
85
86 fn rest_url(&self, endpoint: ServiceEndpoint, model: Option<&GeminiModel>) -> String {
90 let _ = (endpoint, model);
91 unimplemented!("REST URLs require a concrete auth provider (GoogleAIAuth or VertexAIAuth)")
92 }
93
94 async fn auth_headers(&self) -> Result<Vec<(String, String)>, AuthError>;
96
97 fn query_params(&self) -> Vec<(String, String)> {
99 vec![]
100 }
101
102 async fn refresh(&self) -> Result<(), AuthError> {
104 Ok(())
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::GeminiModel;

    #[test]
    fn google_ai_auth_url() {
        let auth = GoogleAIAuth::new("test-key-123");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("v1beta"));
        assert!(url.contains("key=test-key-123"));
    }

    #[test]
    fn google_ai_auth_query_params() {
        let auth = GoogleAIAuth::new("my-api-key");
        let params = auth.query_params();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0, "key");
        assert_eq!(params[0].1, "my-api-key");
    }

    #[tokio::test]
    async fn google_ai_auth_headers_empty() {
        let auth = GoogleAIAuth::new("test-key");
        let headers = auth.auth_headers().await.unwrap();
        assert!(headers.is_empty());
    }

    #[test]
    fn google_ai_token_auth_url() {
        let auth = GoogleAITokenAuth::new("oauth2-token-abc");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("access_token=oauth2-token-abc"));
        assert!(url.contains("v1alpha"));
    }

    #[test]
    fn vertex_ai_auth_url_regional() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("us-central1-aiplatform.googleapis.com"));
        assert!(url.contains("v1beta1"));
        assert!(url.contains("x-goog-project-id=my-project"));
    }

    #[test]
    fn vertex_ai_auth_url_global() {
        let auth = VertexAIAuth::new("my-project", "global", "token");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.starts_with("wss://aiplatform.googleapis.com/"));
        // The "global" location uses the bare host, not a regional prefix.
        assert!(!url.contains("global-aiplatform"));
    }

    #[tokio::test]
    async fn vertex_ai_auth_headers() {
        let auth = VertexAIAuth::new("proj", "us-central1", "my-bearer-token");
        let headers = auth.auth_headers().await.unwrap();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Authorization");
        assert_eq!(headers[0].1, "Bearer my-bearer-token");
    }

    #[test]
    fn vertex_ai_auth_url_contains_model() {
        let auth = VertexAIAuth::new("proj", "us-central1", "tok");
        let url = auth.ws_url(&GeminiModel::Gemini2_0FlashLive);
        assert!(url.contains("model=gemini-2.0-flash-live-001"));
    }

    #[test]
    fn auth_provider_is_object_safe() {
        // Compiles only if `AuthProvider` can be used as a trait object.
        fn _assert(_: &dyn AuthProvider) {}
    }

    #[tokio::test]
    async fn default_refresh_is_noop() {
        let auth = GoogleAIAuth::new("key");
        auth.refresh().await.unwrap();
    }

    // Purely synchronous: nothing is awaited, so no async runtime is needed.
    #[test]
    fn default_query_params_empty_for_vertex() {
        let auth = VertexAIAuth::new("proj", "loc", "tok");
        let params = auth.query_params();
        assert!(params.is_empty());
    }

    #[test]
    fn google_ai_rest_url_generate_content() {
        let auth = GoogleAIAuth::new("test-key");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://generativelanguage.googleapis.com/v1beta/"));
        assert!(url.contains(":generateContent"));
        assert!(url.contains("key=test-key"));
    }

    #[test]
    fn google_ai_rest_url_list_models() {
        let auth = GoogleAIAuth::new("key123");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("/models?key=key123"));
    }

    #[test]
    fn google_ai_rest_url_files() {
        let auth = GoogleAIAuth::new("key");
        let url = auth.rest_url(ServiceEndpoint::Files, None);
        assert!(url.contains("/files?key=key"));
    }

    #[test]
    fn google_ai_token_rest_url_no_key_in_url() {
        let auth = GoogleAITokenAuth::new("oauth-token");
        let url = auth.rest_url(ServiceEndpoint::CountTokens, Some(&GeminiModel::default()));
        assert!(url.contains(":countTokens"));
        assert!(!url.contains("key="));
        assert!(!url.contains("access_token="));
    }

    #[test]
    fn vertex_rest_url_generate_content() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://us-central1-aiplatform.googleapis.com/v1beta1/"));
        assert!(url.contains("projects/my-project/locations/us-central1"));
        assert!(url.contains(":generateContent"));
    }

    #[test]
    fn vertex_rest_url_list_models() {
        let auth = VertexAIAuth::new("proj", "us-east1", "tok");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("publishers/google/models"));
    }

    #[test]
    fn vertex_rest_url_global() {
        let auth = VertexAIAuth::new("proj", "global", "tok");
        let model = GeminiModel::default();
        let url = auth.rest_url(ServiceEndpoint::EmbedContent, Some(&model));
        assert!(url.starts_with("https://aiplatform.googleapis.com/"));
        assert!(!url.contains("global-aiplatform"));
        assert!(url.contains(":embedContent"));
    }

    #[test]
    fn service_endpoint_model_method() {
        assert_eq!(
            ServiceEndpoint::GenerateContent.model_method(),
            Some("generateContent")
        );
        assert_eq!(
            ServiceEndpoint::StreamGenerateContent.model_method(),
            Some("streamGenerateContent")
        );
        assert_eq!(ServiceEndpoint::EmbedContent.model_method(), Some("embedContent"));
        assert_eq!(ServiceEndpoint::CountTokens.model_method(), Some("countTokens"));
        assert_eq!(ServiceEndpoint::ComputeTokens.model_method(), Some("computeTokens"));
        assert_eq!(ServiceEndpoint::LiveWs.model_method(), None);
        assert_eq!(ServiceEndpoint::GetModel.model_method(), None);
        assert_eq!(ServiceEndpoint::ListModels.model_method(), None);
        assert_eq!(ServiceEndpoint::Files.model_method(), None);
    }

    #[tokio::test]
    async fn vertex_ai_refreshable_token() {
        use std::sync::atomic::{AtomicU32, Ordering};
        use std::sync::Arc;

        let counter = Arc::new(AtomicU32::new(0));
        let c = Arc::clone(&counter);
        let auth = VertexAIAuth::with_token_refresher("proj", "us-central1", move || {
            // `fetch_add` returns the previous value. Derive this call's number
            // from it rather than re-reading, which could observe another
            // call's increment.
            let n = c.fetch_add(1, Ordering::SeqCst) + 1;
            format!("token-{}", n)
        });
        let h1 = auth.auth_headers().await.unwrap();
        assert!(h1[0].1.starts_with("Bearer token-"));
        let h2 = auth.auth_headers().await.unwrap();
        assert!(h2[0].1.starts_with("Bearer token-"));
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn service_endpoint_requires_model() {
        assert!(ServiceEndpoint::GenerateContent.requires_model());
        assert!(ServiceEndpoint::CountTokens.requires_model());
        assert!(ServiceEndpoint::EmbedContent.requires_model());
        assert!(ServiceEndpoint::GetModel.requires_model());
        assert!(!ServiceEndpoint::LiveWs.requires_model());
        assert!(!ServiceEndpoint::ListModels.requires_model());
        assert!(!ServiceEndpoint::Files.requires_model());
    }
}