gemini_adk_rs/auth/
exchanger.rs

1//! Credential exchanger — trait and registry for exchanging/transforming credentials
2//! (e.g. auth code to access token).
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use super::credential::AuthCredential;
10use super::schemes::AuthScheme;
11
12/// Error from credential exchange.
13#[derive(Debug, thiserror::Error)]
14pub enum CredentialExchangeError {
15    /// The credential exchange operation failed.
16    #[error("Exchange failed: {0}")]
17    ExchangeFailed(String),
18    /// No exchanger is registered for the given scheme type.
19    #[error("No exchanger registered for scheme type: {0}")]
20    NoExchanger(String),
21    /// A catch-all for other exchange errors.
22    #[error("{0}")]
23    Other(String),
24}
25
26/// Trait for exchanging/transforming credentials (e.g. auth code -> access token).
27#[async_trait]
28pub trait CredentialExchanger: Send + Sync {
29    /// Exchange or transform a credential (e.g., auth code to access token).
30    async fn exchange(
31        &self,
32        credential: &AuthCredential,
33        scheme: Option<&AuthScheme>,
34    ) -> Result<AuthCredential, CredentialExchangeError>;
35}
36
37/// Registry of credential exchangers, keyed by scheme type name.
38pub struct CredentialExchangerRegistry {
39    exchangers: HashMap<String, Arc<dyn CredentialExchanger>>,
40}
41
42impl CredentialExchangerRegistry {
43    /// Create a new empty exchanger registry.
44    pub fn new() -> Self {
45        Self {
46            exchangers: HashMap::new(),
47        }
48    }
49
50    /// Register an exchanger for a given scheme type (e.g., "oauth2").
51    pub fn register(&mut self, scheme_type: &str, exchanger: Arc<dyn CredentialExchanger>) {
52        self.exchangers.insert(scheme_type.to_string(), exchanger);
53    }
54
55    /// Exchange a credential using the appropriate registered exchanger.
56    pub async fn exchange(
57        &self,
58        credential: &AuthCredential,
59        scheme: &AuthScheme,
60    ) -> Result<AuthCredential, CredentialExchangeError> {
61        let scheme_type = match scheme {
62            AuthScheme::ApiKey { .. } => "apiKey",
63            AuthScheme::Http { .. } => "http",
64            AuthScheme::OAuth2 { .. } => "oauth2",
65            AuthScheme::OpenIdConnect { .. } => "openIdConnect",
66        };
67        let exchanger = self
68            .exchangers
69            .get(scheme_type)
70            .ok_or_else(|| CredentialExchangeError::NoExchanger(scheme_type.to_string()))?;
71        exchanger.exchange(credential, Some(scheme)).await
72    }
73}
74
75impl Default for CredentialExchangerRegistry {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::credential::{AuthCredentialType, OAuth2Auth};
    use crate::auth::schemes::OAuthGrantType;

    /// Mock exchanger that clones the input credential and sets its OAuth2
    /// `access_token` to `"exchanged"`, leaving every other field untouched.
    struct MockExchanger;

    #[async_trait]
    impl CredentialExchanger for MockExchanger {
        async fn exchange(
            &self,
            credential: &AuthCredential,
            _scheme: Option<&AuthScheme>,
        ) -> Result<AuthCredential, CredentialExchangeError> {
            let mut result = credential.clone();
            if let Some(ref mut oauth2) = result.oauth2 {
                oauth2.access_token = Some("exchanged".into());
            }
            Ok(result)
        }
    }

    /// Mock exchanger that always fails. Used to tell exchangers apart and to check
    /// that an exchanger's error is returned unchanged by the registry.
    struct FailingExchanger;

    #[async_trait]
    impl CredentialExchanger for FailingExchanger {
        async fn exchange(
            &self,
            _credential: &AuthCredential,
            _scheme: Option<&AuthScheme>,
        ) -> Result<AuthCredential, CredentialExchangeError> {
            Err(CredentialExchangeError::ExchangeFailed("boom".into()))
        }
    }

    /// OAuth2 credential holding an auth code but no access token yet.
    fn test_credential() -> AuthCredential {
        AuthCredential {
            auth_type: AuthCredentialType::OAuth2,
            resource_ref: None,
            api_key: None,
            http: None,
            oauth2: Some(OAuth2Auth {
                client_id: Some("client-123".into()),
                client_secret: Some("secret-456".into()),
                auth_uri: None,
                token_uri: None,
                redirect_uri: None,
                auth_code: Some("auth-code-789".into()),
                access_token: None,
                refresh_token: None,
                expires_at: None,
                scopes: None,
                auth_response_uri: None,
            }),
            service_account: None,
        }
    }

    /// OAuth2 authorization-code scheme; resolves to the `"oauth2"` registry key.
    fn oauth2_scheme() -> AuthScheme {
        AuthScheme::OAuth2 {
            grant_type: Some(OAuthGrantType::AuthorizationCode),
            authorization_url: Some("https://example.com/authorize".into()),
            token_url: Some("https://example.com/token".into()),
            scopes: None,
        }
    }

    /// API-key scheme; resolves to the `"apiKey"` registry key.
    fn api_key_scheme() -> AuthScheme {
        AuthScheme::ApiKey {
            location: "header".into(),
            name: "X-API-Key".into(),
        }
    }

    #[tokio::test]
    async fn register_and_exchange_with_mock() {
        let mut registry = CredentialExchangerRegistry::new();
        registry.register("oauth2", Arc::new(MockExchanger));

        let result = registry
            .exchange(&test_credential(), &oauth2_scheme())
            .await
            .expect("registered oauth2 exchanger should succeed");
        let oauth2 = result
            .oauth2
            .as_ref()
            .expect("exchanged credential should keep its oauth2 section");

        assert_eq!(oauth2.access_token.as_deref(), Some("exchanged"));
        // Fields the exchanger does not touch are preserved.
        assert_eq!(oauth2.client_id.as_deref(), Some("client-123"));
    }

    #[tokio::test]
    async fn exchange_unregistered_scheme_returns_error() {
        let registry = CredentialExchangerRegistry::new();

        match registry.exchange(&test_credential(), &api_key_scheme()).await {
            Err(CredentialExchangeError::NoExchanger(scheme_type)) => {
                assert_eq!(scheme_type, "apiKey");
            }
            other => panic!("expected NoExchanger error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn exchange_only_uses_exchanger_for_matching_scheme_type() {
        let mut registry = CredentialExchangerRegistry::new();
        registry.register("oauth2", Arc::new(MockExchanger));

        // An exchanger registered for another scheme type must not be picked up.
        match registry.exchange(&test_credential(), &api_key_scheme()).await {
            Err(CredentialExchangeError::NoExchanger(scheme_type)) => {
                assert_eq!(scheme_type, "apiKey");
            }
            other => panic!("expected NoExchanger error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn re_registering_replaces_exchanger_and_propagates_its_error() {
        let mut registry = CredentialExchangerRegistry::new();
        registry.register("oauth2", Arc::new(MockExchanger));
        registry.register("oauth2", Arc::new(FailingExchanger));
        assert_eq!(registry.exchangers.len(), 1);

        match registry.exchange(&test_credential(), &oauth2_scheme()).await {
            Err(CredentialExchangeError::ExchangeFailed(msg)) => assert_eq!(msg, "boom"),
            other => panic!("expected ExchangeFailed error, got {other:?}"),
        }
    }

    #[test]
    fn credential_exchanger_trait_is_object_safe() {
        // This test verifies that the trait can be used as a trait object.
        // If CredentialExchanger is not object-safe, this will fail to compile.
        fn _assert_object_safe(_: &dyn CredentialExchanger) {}
        fn _assert_arc_object_safe(_: Arc<dyn CredentialExchanger>) {}
    }

    #[test]
    fn default_registry_is_empty() {
        let registry = CredentialExchangerRegistry::default();
        assert!(registry.exchangers.is_empty());
    }
}