gemini_adk_rs/auth/
exchanger.rs1use std::collections::HashMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use super::credential::AuthCredential;
10use super::schemes::AuthScheme;
11
12#[derive(Debug, thiserror::Error)]
14pub enum CredentialExchangeError {
15 #[error("Exchange failed: {0}")]
17 ExchangeFailed(String),
18 #[error("No exchanger registered for scheme type: {0}")]
20 NoExchanger(String),
21 #[error("{0}")]
23 Other(String),
24}
25
/// A pluggable strategy that turns one [`AuthCredential`] into another.
///
/// Implementors must be `Send + Sync` (supertrait bound). Instances are stored as
/// `Arc<dyn CredentialExchanger>` in [`CredentialExchangerRegistry`], so the trait
/// must remain object safe.
#[async_trait]
pub trait CredentialExchanger: Send + Sync {
    /// Exchanges `credential`, returning the resulting credential.
    ///
    /// `scheme` is the auth scheme the credential is used with, when known;
    /// `CredentialExchangerRegistry::exchange` always passes `Some(scheme)`.
    ///
    /// # Errors
    ///
    /// Returns a [`CredentialExchangeError`] when the exchange cannot be performed.
    async fn exchange(
        &self,
        credential: &AuthCredential,
        scheme: Option<&AuthScheme>,
    ) -> Result<AuthCredential, CredentialExchangeError>;
}
36
37pub struct CredentialExchangerRegistry {
39 exchangers: HashMap<String, Arc<dyn CredentialExchanger>>,
40}
41
42impl CredentialExchangerRegistry {
43 pub fn new() -> Self {
45 Self {
46 exchangers: HashMap::new(),
47 }
48 }
49
50 pub fn register(&mut self, scheme_type: &str, exchanger: Arc<dyn CredentialExchanger>) {
52 self.exchangers.insert(scheme_type.to_string(), exchanger);
53 }
54
55 pub async fn exchange(
57 &self,
58 credential: &AuthCredential,
59 scheme: &AuthScheme,
60 ) -> Result<AuthCredential, CredentialExchangeError> {
61 let scheme_type = match scheme {
62 AuthScheme::ApiKey { .. } => "apiKey",
63 AuthScheme::Http { .. } => "http",
64 AuthScheme::OAuth2 { .. } => "oauth2",
65 AuthScheme::OpenIdConnect { .. } => "openIdConnect",
66 };
67 let exchanger = self
68 .exchangers
69 .get(scheme_type)
70 .ok_or_else(|| CredentialExchangeError::NoExchanger(scheme_type.to_string()))?;
71 exchanger.exchange(credential, Some(scheme)).await
72 }
73}
74
75impl Default for CredentialExchangerRegistry {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::credential::{AuthCredentialType, OAuth2Auth};
    use crate::auth::schemes::OAuthGrantType;

    /// Copies the input credential and sets `oauth2.access_token` to `"exchanged"`.
    struct MockExchanger;

    #[async_trait]
    impl CredentialExchanger for MockExchanger {
        async fn exchange(
            &self,
            credential: &AuthCredential,
            _scheme: Option<&AuthScheme>,
        ) -> Result<AuthCredential, CredentialExchangeError> {
            let mut result = credential.clone();
            if let Some(ref mut oauth2) = result.oauth2 {
                oauth2.access_token = Some("exchanged".into());
            }
            Ok(result)
        }
    }

    /// Always fails with `ExchangeFailed("boom")`; used to tell registrations apart.
    struct FailingExchanger;

    #[async_trait]
    impl CredentialExchanger for FailingExchanger {
        async fn exchange(
            &self,
            _credential: &AuthCredential,
            _scheme: Option<&AuthScheme>,
        ) -> Result<AuthCredential, CredentialExchangeError> {
            Err(CredentialExchangeError::ExchangeFailed("boom".into()))
        }
    }

    /// Succeeds only when it is handed an OAuth2 scheme.
    struct SchemeCheckingExchanger;

    #[async_trait]
    impl CredentialExchanger for SchemeCheckingExchanger {
        async fn exchange(
            &self,
            credential: &AuthCredential,
            scheme: Option<&AuthScheme>,
        ) -> Result<AuthCredential, CredentialExchangeError> {
            match scheme {
                Some(AuthScheme::OAuth2 { .. }) => Ok(credential.clone()),
                _ => Err(CredentialExchangeError::Other("expected OAuth2 scheme".into())),
            }
        }
    }

    fn test_credential() -> AuthCredential {
        AuthCredential {
            auth_type: AuthCredentialType::OAuth2,
            resource_ref: None,
            api_key: None,
            http: None,
            oauth2: Some(OAuth2Auth {
                client_id: Some("client-123".into()),
                client_secret: Some("secret-456".into()),
                auth_uri: None,
                token_uri: None,
                redirect_uri: None,
                auth_code: Some("auth-code-789".into()),
                access_token: None,
                refresh_token: None,
                expires_at: None,
                scopes: None,
                auth_response_uri: None,
            }),
            service_account: None,
        }
    }

    fn oauth2_scheme() -> AuthScheme {
        AuthScheme::OAuth2 {
            grant_type: Some(OAuthGrantType::AuthorizationCode),
            authorization_url: Some("https://example.com/authorize".into()),
            token_url: Some("https://example.com/token".into()),
            scopes: None,
        }
    }

    #[tokio::test]
    async fn register_and_exchange_with_mock() {
        let mut registry = CredentialExchangerRegistry::new();
        registry.register("oauth2", Arc::new(MockExchanger));

        let cred = test_credential();
        let scheme = oauth2_scheme();

        let result = registry.exchange(&cred, &scheme).await.unwrap();
        assert_eq!(
            result.oauth2.as_ref().unwrap().access_token.as_deref(),
            Some("exchanged")
        );
        // Fields the exchanger does not touch must survive the exchange.
        assert_eq!(
            result.oauth2.as_ref().unwrap().client_id.as_deref(),
            Some("client-123")
        );
    }

    #[tokio::test]
    async fn exchange_unregistered_scheme_returns_error() {
        let registry = CredentialExchangerRegistry::new();

        let cred = test_credential();
        let scheme = AuthScheme::ApiKey {
            location: "header".into(),
            name: "X-API-Key".into(),
        };

        match registry.exchange(&cred, &scheme).await {
            Err(CredentialExchangeError::NoExchanger(scheme_type)) => {
                assert_eq!(scheme_type, "apiKey");
            }
            other => panic!("expected NoExchanger error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn register_replaces_existing_exchanger() {
        let mut registry = CredentialExchangerRegistry::new();
        registry.register("oauth2", Arc::new(MockExchanger));
        registry.register("oauth2", Arc::new(FailingExchanger));
        assert_eq!(registry.exchangers.len(), 1);

        let result = registry.exchange(&test_credential(), &oauth2_scheme()).await;
        assert!(matches!(
            result,
            Err(CredentialExchangeError::ExchangeFailed(ref msg)) if msg == "boom"
        ));
    }

    #[tokio::test]
    async fn exchange_forwards_scheme_to_exchanger() {
        let mut registry = CredentialExchangerRegistry::new();
        registry.register("oauth2", Arc::new(SchemeCheckingExchanger));

        let result = registry.exchange(&test_credential(), &oauth2_scheme()).await;
        assert!(result.is_ok());
    }

    #[test]
    fn credential_exchanger_trait_is_object_safe() {
        // Compile-time check only: these signatures fail to build if the trait
        // stops being object safe.
        fn _assert_object_safe(_: &dyn CredentialExchanger) {}
        fn _assert_arc_object_safe(_: Arc<dyn CredentialExchanger>) {}
    }

    #[test]
    fn default_registry_is_empty() {
        let registry = CredentialExchangerRegistry::default();
        assert!(registry.exchangers.is_empty());
    }
}