gemini_genai_rs/transport/auth/
mod.rs1pub mod google_ai;
10pub(crate) mod url_builders;
11pub mod vertex;
12
13pub use google_ai::*;
14pub use vertex::*;
15
16use async_trait::async_trait;
17
18use crate::protocol::types::GeminiModel;
19use crate::session::AuthError;
20
/// Identifies which Gemini API operation a URL is being built for.
///
/// Variants that correspond to a per-model RPC expose their method name via
/// [`ServiceEndpoint::model_method`]; [`ServiceEndpoint::requires_model`]
/// reports whether a model must be supplied when building the URL.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
pub enum ServiceEndpoint {
    /// Live API WebSocket connection.
    LiveWs,
    /// Per-model `generateContent` call.
    GenerateContent,
    /// Per-model `streamGenerateContent` call.
    StreamGenerateContent,
    /// Per-model `embedContent` call.
    EmbedContent,
    /// Per-model `countTokens` call.
    CountTokens,
    /// Per-model `computeTokens` call.
    ComputeTokens,
    /// Model listing.
    ListModels,
    /// Single-model lookup.
    GetModel,
    /// Files collection.
    Files,
    /// Cached-contents collection.
    CachedContents,
    /// Tuning-jobs collection.
    TuningJobs,
    /// Batch-jobs collection.
    BatchJobs,
}
51
52impl ServiceEndpoint {
53 pub fn model_method(&self) -> Option<&'static str> {
56 match self {
57 Self::GenerateContent => Some("generateContent"),
58 Self::StreamGenerateContent => Some("streamGenerateContent"),
59 Self::EmbedContent => Some("embedContent"),
60 Self::CountTokens => Some("countTokens"),
61 Self::ComputeTokens => Some("computeTokens"),
62 _ => None,
63 }
64 }
65
66 pub fn requires_model(&self) -> bool {
68 matches!(
69 self,
70 Self::GenerateContent
71 | Self::StreamGenerateContent
72 | Self::EmbedContent
73 | Self::CountTokens
74 | Self::ComputeTokens
75 | Self::GetModel
76 )
77 }
78}
79
80#[async_trait]
82pub trait AuthProvider: Send + Sync + 'static {
83 fn ws_url(&self, model: &GeminiModel) -> String;
85
86 async fn auth_headers(&self) -> Result<Vec<(String, String)>, AuthError>;
88
89 fn query_params(&self) -> Vec<(String, String)> {
91 vec![]
92 }
93
94 async fn refresh(&self) -> Result<(), AuthError> {
96 Ok(())
97 }
98}
99
100pub trait RestAuth: AuthProvider {
107 fn rest_url(&self, endpoint: ServiceEndpoint, model: Option<&GeminiModel>) -> String;
109}
110
#[cfg(test)]
mod tests {
    // Unit tests for the auth providers and `ServiceEndpoint` helpers.
    use super::*;
    use crate::protocol::types::GeminiModel;

    #[test]
    fn google_ai_auth_url() {
        let auth = GoogleAIAuth::new("test-key-123");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("v1beta"));
        assert!(url.contains("key=test-key-123"));
    }

    #[test]
    fn google_ai_auth_query_params() {
        let auth = GoogleAIAuth::new("my-api-key");
        let params = auth.query_params();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].0, "key");
        assert_eq!(params[0].1, "my-api-key");
    }

    #[tokio::test]
    async fn google_ai_auth_headers_empty() {
        let auth = GoogleAIAuth::new("test-key");
        let headers = auth.auth_headers().await.unwrap();
        assert!(headers.is_empty());
    }

    #[test]
    fn google_ai_token_auth_url() {
        let auth = GoogleAITokenAuth::new("oauth2-token-abc");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("generativelanguage.googleapis.com"));
        assert!(url.contains("access_token=oauth2-token-abc"));
        assert!(url.contains("v1alpha"));
    }

    #[test]
    fn vertex_ai_auth_url_regional() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.contains("us-central1-aiplatform.googleapis.com"));
        assert!(url.contains("v1beta1"));
        assert!(url.contains("x-goog-project-id=my-project"));
    }

    #[test]
    fn vertex_ai_auth_url_global() {
        let auth = VertexAIAuth::new("my-project", "global", "token");
        let url = auth.ws_url(&GeminiModel::default());
        assert!(url.starts_with("wss://aiplatform.googleapis.com/"));
        assert!(!url.contains("global-aiplatform"));
    }

    #[tokio::test]
    async fn vertex_ai_auth_headers() {
        let auth = VertexAIAuth::new("proj", "us-central1", "my-bearer-token");
        let headers = auth.auth_headers().await.unwrap();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0].0, "Authorization");
        assert_eq!(headers[0].1, "Bearer my-bearer-token");
    }

    #[test]
    fn vertex_ai_auth_url_contains_model() {
        let auth = VertexAIAuth::new("proj", "us-central1", "tok");
        let url = auth.ws_url(&GeminiModel::Gemini2_0FlashLive);
        assert!(url.contains("model=gemini-2.0-flash-live-001"));
    }

    #[test]
    fn auth_provider_is_object_safe() {
        // Compile-time check: this only builds if `dyn AuthProvider` is a valid type.
        fn _assert(_: &dyn AuthProvider) {}
    }

    #[tokio::test]
    async fn default_refresh_is_noop() {
        let auth = GoogleAIAuth::new("key");
        auth.refresh().await.unwrap();
    }

    // `query_params` is synchronous, so no async runtime is needed here.
    #[test]
    fn default_query_params_empty_for_vertex() {
        let auth = VertexAIAuth::new("proj", "loc", "tok");
        let params = auth.query_params();
        assert!(params.is_empty());
    }

    #[test]
    fn google_ai_rest_url_generate_content() {
        let auth = GoogleAIAuth::new("test-key");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://generativelanguage.googleapis.com/v1beta/"));
        assert!(url.contains(":generateContent"));
        assert!(url.contains("key=test-key"));
    }

    #[test]
    fn google_ai_rest_url_list_models() {
        let auth = GoogleAIAuth::new("key123");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("/models?key=key123"));
    }

    #[test]
    fn google_ai_rest_url_files() {
        let auth = GoogleAIAuth::new("key");
        let url = auth.rest_url(ServiceEndpoint::Files, None);
        assert!(url.contains("/files?key=key"));
    }

    #[test]
    fn google_ai_token_rest_url_no_key_in_url() {
        let auth = GoogleAITokenAuth::new("oauth-token");
        let url = auth.rest_url(ServiceEndpoint::CountTokens, Some(&GeminiModel::default()));
        assert!(url.contains(":countTokens"));
        assert!(!url.contains("key="));
        assert!(!url.contains("access_token="));
    }

    #[test]
    fn vertex_rest_url_generate_content() {
        let auth = VertexAIAuth::new("my-project", "us-central1", "token");
        let model = GeminiModel::Gemini2_0FlashLive;
        let url = auth.rest_url(ServiceEndpoint::GenerateContent, Some(&model));
        assert!(url.starts_with("https://us-central1-aiplatform.googleapis.com/v1beta1/"));
        assert!(url.contains("projects/my-project/locations/us-central1"));
        assert!(url.contains(":generateContent"));
    }

    #[test]
    fn vertex_rest_url_list_models() {
        let auth = VertexAIAuth::new("proj", "us-east1", "tok");
        let url = auth.rest_url(ServiceEndpoint::ListModels, None);
        assert!(url.contains("publishers/google/models"));
    }

    #[test]
    fn vertex_rest_url_global() {
        let auth = VertexAIAuth::new("proj", "global", "tok");
        let model = GeminiModel::default();
        let url = auth.rest_url(ServiceEndpoint::EmbedContent, Some(&model));
        assert!(url.starts_with("https://aiplatform.googleapis.com/"));
        assert!(!url.contains("global-aiplatform"));
        assert!(url.contains(":embedContent"));
    }

    #[test]
    fn service_endpoint_model_method() {
        assert_eq!(
            ServiceEndpoint::GenerateContent.model_method(),
            Some("generateContent")
        );
        assert_eq!(
            ServiceEndpoint::StreamGenerateContent.model_method(),
            Some("streamGenerateContent")
        );
        assert_eq!(ServiceEndpoint::EmbedContent.model_method(), Some("embedContent"));
        assert_eq!(ServiceEndpoint::CountTokens.model_method(), Some("countTokens"));
        assert_eq!(ServiceEndpoint::ComputeTokens.model_method(), Some("computeTokens"));
        assert_eq!(ServiceEndpoint::LiveWs.model_method(), None);
        assert_eq!(ServiceEndpoint::ListModels.model_method(), None);
        assert_eq!(ServiceEndpoint::GetModel.model_method(), None);
        assert_eq!(ServiceEndpoint::Files.model_method(), None);
        assert_eq!(ServiceEndpoint::CachedContents.model_method(), None);
        assert_eq!(ServiceEndpoint::TuningJobs.model_method(), None);
        assert_eq!(ServiceEndpoint::BatchJobs.model_method(), None);
    }

    #[tokio::test]
    async fn vertex_ai_refreshable_token() {
        use std::sync::atomic::{AtomicU32, Ordering};
        let counter = std::sync::Arc::new(AtomicU32::new(0));
        let c = counter.clone();
        let auth = VertexAIAuth::with_token_refresher("proj", "us-central1", move || {
            // Use the value returned by `fetch_add` itself; a separate `load`
            // could observe another caller's increment.
            let n = c.fetch_add(1, Ordering::SeqCst) + 1;
            format!("token-{}", n)
        });
        let h1 = auth.auth_headers().await.unwrap();
        assert!(h1[0].1.starts_with("Bearer token-"));
        let h2 = auth.auth_headers().await.unwrap();
        assert!(h2[0].1.starts_with("Bearer token-"));
        // Each call should reflect a freshly produced token.
        assert_ne!(h1[0].1, h2[0].1);
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn service_endpoint_requires_model() {
        assert!(ServiceEndpoint::GenerateContent.requires_model());
        assert!(ServiceEndpoint::StreamGenerateContent.requires_model());
        assert!(ServiceEndpoint::EmbedContent.requires_model());
        assert!(ServiceEndpoint::CountTokens.requires_model());
        assert!(ServiceEndpoint::ComputeTokens.requires_model());
        // Needs a model even though it has no `:method` suffix.
        assert!(ServiceEndpoint::GetModel.requires_model());
        assert!(!ServiceEndpoint::LiveWs.requires_model());
        assert!(!ServiceEndpoint::ListModels.requires_model());
        assert!(!ServiceEndpoint::Files.requires_model());
        assert!(!ServiceEndpoint::CachedContents.requires_model());
        assert!(!ServiceEndpoint::TuningJobs.requires_model());
        assert!(!ServiceEndpoint::BatchJobs.requires_model());
    }
}