gemini_genai_rs/transport/auth/vertex.rs

//! Vertex AI Bearer token authentication.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use crate::protocol::types::GeminiModel;
8use crate::session::AuthError;
9
10use super::url_builders::build_vertex_rest_url;
11use super::{AuthProvider, ServiceEndpoint};
12
13/// How a [`VertexAIAuth`] resolves its Bearer token.
14enum TokenSource {
15    /// Fixed token string — used for WebSocket connections where the token
16    /// is only needed once at connect time.
17    Fixed(parking_lot::Mutex<String>),
18    /// Dynamic token refresher — called on every `auth_headers()` invocation.
19    /// Used for HTTP REST calls (e.g., generate) where the token must remain
20    /// valid across many requests over a long session.
21    Refreshable(Arc<dyn Fn() -> String + Send + Sync>),
22}
23
24/// Vertex AI Bearer token authentication.
25///
26/// Uses a project/location pair to construct the Vertex AI WebSocket URL,
27/// and a Bearer token for the `Authorization` header.
28///
29/// Supports two token modes:
30/// - **Fixed** ([`new`](Self::new)): token is set once at construction.
31///   Best for WebSocket connections where the token is only needed at connect time.
32/// - **Refreshable** ([`with_token_refresher`](Self::with_token_refresher)):
33///   a closure is called on every `auth_headers()` invocation, ensuring fresh
34///   tokens for long-running HTTP clients (e.g., generate API calls).
35pub struct VertexAIAuth {
36    project: String,
37    location: String,
38    token_source: TokenSource,
39}
40
41impl VertexAIAuth {
42    /// Create a new Vertex AI auth provider with a fixed token.
43    ///
44    /// The token is stored and reused for all requests. Use this for
45    /// WebSocket connections where the token is only needed at connect time.
46    pub fn new(
47        project: impl Into<String>,
48        location: impl Into<String>,
49        token: impl Into<String>,
50    ) -> Self {
51        Self {
52            project: project.into(),
53            location: location.into(),
54            token_source: TokenSource::Fixed(parking_lot::Mutex::new(token.into())),
55        }
56    }
57
58    /// Create a Vertex AI auth provider with a dynamic token refresher.
59    ///
60    /// The `refresher` closure is called on every `auth_headers()` invocation,
61    /// allowing token refresh for long-running HTTP clients. The closure
62    /// should handle caching internally to avoid unnecessary overhead.
63    pub fn with_token_refresher(
64        project: impl Into<String>,
65        location: impl Into<String>,
66        refresher: impl Fn() -> String + Send + Sync + 'static,
67    ) -> Self {
68        Self {
69            project: project.into(),
70            location: location.into(),
71            token_source: TokenSource::Refreshable(Arc::new(refresher)),
72        }
73    }
74}
75
76#[async_trait]
77impl AuthProvider for VertexAIAuth {
78    fn ws_url(&self, model: &GeminiModel) -> String {
79        let host = if self.location == "global" {
80            "aiplatform.googleapis.com".to_string()
81        } else {
82            format!("{}-aiplatform.googleapis.com", self.location)
83        };
84        let model_id = model.to_string().trim_start_matches("models/").to_string();
85        format!(
86            "wss://{host}/ws/google.cloud.aiplatform.v1beta1.LlmBidiService/BidiGenerateContent\
87             ?alt=json\
88             &x-goog-project-id={project}\
89             &model={model_id}",
90            host = host,
91            project = self.project,
92            model_id = model_id,
93        )
94    }
95
96    fn rest_url(&self, endpoint: ServiceEndpoint, model: Option<&GeminiModel>) -> String {
97        let host = if self.location == "global" {
98            "aiplatform.googleapis.com".to_string()
99        } else {
100            format!("{}-aiplatform.googleapis.com", self.location)
101        };
102        build_vertex_rest_url(&host, &self.project, &self.location, endpoint, model)
103    }
104
105    async fn auth_headers(&self) -> Result<Vec<(String, String)>, AuthError> {
106        let token = match &self.token_source {
107            TokenSource::Fixed(m) => m.lock().clone(),
108            TokenSource::Refreshable(f) => f(),
109        };
110        Ok(vec![(
111            "Authorization".to_string(),
112            format!("Bearer {token}"),
113        )])
114    }
115}