gemini_genai_rs/transport/auth/
vertex.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::protocol::types::GeminiModel;
8use crate::session::AuthError;
9
10use super::url_builders::build_vertex_rest_url;
11use super::{AuthProvider, ServiceEndpoint};
12
13enum TokenSource {
15 Fixed(parking_lot::Mutex<String>),
18 Refreshable(Arc<dyn Fn() -> String + Send + Sync>),
22}
23
24pub struct VertexAIAuth {
36 project: String,
37 location: String,
38 token_source: TokenSource,
39}
40
41impl VertexAIAuth {
42 pub fn new(
47 project: impl Into<String>,
48 location: impl Into<String>,
49 token: impl Into<String>,
50 ) -> Self {
51 Self {
52 project: project.into(),
53 location: location.into(),
54 token_source: TokenSource::Fixed(parking_lot::Mutex::new(token.into())),
55 }
56 }
57
58 pub fn with_token_refresher(
64 project: impl Into<String>,
65 location: impl Into<String>,
66 refresher: impl Fn() -> String + Send + Sync + 'static,
67 ) -> Self {
68 Self {
69 project: project.into(),
70 location: location.into(),
71 token_source: TokenSource::Refreshable(Arc::new(refresher)),
72 }
73 }
74}
75
76#[async_trait]
77impl AuthProvider for VertexAIAuth {
78 fn ws_url(&self, model: &GeminiModel) -> String {
79 let host = if self.location == "global" {
80 "aiplatform.googleapis.com".to_string()
81 } else {
82 format!("{}-aiplatform.googleapis.com", self.location)
83 };
84 let model_id = model.to_string().trim_start_matches("models/").to_string();
85 format!(
86 "wss://{host}/ws/google.cloud.aiplatform.v1beta1.LlmBidiService/BidiGenerateContent\
87 ?alt=json\
88 &x-goog-project-id={project}\
89 &model={model_id}",
90 host = host,
91 project = self.project,
92 model_id = model_id,
93 )
94 }
95
96 fn rest_url(&self, endpoint: ServiceEndpoint, model: Option<&GeminiModel>) -> String {
97 let host = if self.location == "global" {
98 "aiplatform.googleapis.com".to_string()
99 } else {
100 format!("{}-aiplatform.googleapis.com", self.location)
101 };
102 build_vertex_rest_url(&host, &self.project, &self.location, endpoint, model)
103 }
104
105 async fn auth_headers(&self) -> Result<Vec<(String, String)>, AuthError> {
106 let token = match &self.token_source {
107 TokenSource::Fixed(m) => m.lock().clone(),
108 TokenSource::Refreshable(f) => f(),
109 };
110 Ok(vec![(
111 "Authorization".to_string(),
112 format!("Bearer {token}"),
113 )])
114 }
115}