gemini_adk_rs/llm/
registry.rs

1//! LLM registry — pattern-based resolution of LLM providers.
2//!
3//! Allows registering factory functions keyed by model name patterns.
4//! When resolving, the first matching pattern wins.
5
6use std::sync::Arc;
7
8use super::BaseLlm;
9
10type LlmFactory = Box<dyn Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync>;
11
12/// Registry that maps model name patterns to LLM factory functions.
13///
14/// Patterns are matched by prefix: a pattern `"gemini"` matches model names
15/// `"gemini-2.5-flash"`, `"gemini-2.0-pro"`, etc.
16pub struct LlmRegistry {
17    factories: Vec<(String, LlmFactory)>,
18}
19
20impl LlmRegistry {
21    /// Create a new empty LLM registry.
22    pub fn new() -> Self {
23        Self {
24            factories: Vec::new(),
25        }
26    }
27
28    /// Register a factory for model names matching the given pattern (prefix match).
29    pub fn register(
30        &mut self,
31        pattern: impl Into<String>,
32        factory: impl Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync + 'static,
33    ) {
34        self.factories.push((pattern.into(), Box::new(factory)));
35    }
36
37    /// Resolve a model name to an LLM instance.
38    /// Returns the first factory whose pattern is a prefix of `model_name`.
39    pub fn resolve(&self, model_name: &str) -> Option<Arc<dyn BaseLlm>> {
40        for (pattern, factory) in &self.factories {
41            if model_name.starts_with(pattern.as_str()) {
42                return Some(factory(model_name));
43            }
44        }
45        None
46    }
47
48    /// Number of registered factories.
49    pub fn len(&self) -> usize {
50        self.factories.len()
51    }
52
53    /// Whether no factories are registered.
54    pub fn is_empty(&self) -> bool {
55        self.factories.is_empty()
56    }
57}
58
59impl Default for LlmRegistry {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmError, LlmRequest, LlmResponse};
    use async_trait::async_trait;

    /// Minimal `BaseLlm` stand-in: `model_id` echoes the string it was built
    /// with and `generate` always fails.
    struct MockLlm {
        model: String,
    }

    #[async_trait]
    impl BaseLlm for MockLlm {
        fn model_id(&self) -> &str {
            self.model.as_str()
        }

        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            Err(LlmError::Other("mock".into()))
        }
    }

    /// Wrap `model` in a `MockLlm` behind the trait-object pointer that
    /// registry factories return.
    fn mock(model: String) -> Arc<dyn BaseLlm> {
        Arc::new(MockLlm { model })
    }

    #[test]
    fn register_and_resolve() {
        let mut reg = LlmRegistry::new();
        reg.register("gemini", |name: &str| mock(name.to_string()));

        // The factory receives the full model name, not just the pattern.
        let resolved = reg.resolve("gemini-2.5-flash").unwrap();
        assert_eq!(resolved.model_id(), "gemini-2.5-flash");
    }

    #[test]
    fn resolve_unknown_returns_none() {
        let reg = LlmRegistry::new();
        let outcome = reg.resolve("gpt-4");
        assert!(outcome.is_none());
    }

    #[test]
    fn first_match_wins() {
        let mut reg = LlmRegistry::new();
        // Specific prefix first, broad prefix second.
        reg.register("gemini-2.5", |name: &str| mock(format!("v2.5:{name}")));
        reg.register("gemini", |name: &str| mock(format!("generic:{name}")));

        // Both patterns are prefixes of this name; the earlier entry is used.
        let specific = reg.resolve("gemini-2.5-flash").unwrap();
        assert_eq!(specific.model_id(), "v2.5:gemini-2.5-flash");

        // Only the broad pattern is a prefix of this name.
        let fallback = reg.resolve("gemini-1.5-pro").unwrap();
        assert_eq!(fallback.model_id(), "generic:gemini-1.5-pro");
    }

    #[test]
    fn len_and_is_empty() {
        let mut reg = LlmRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);

        reg.register("test", |name: &str| mock(name.to_string()));

        assert!(!reg.is_empty());
        assert_eq!(reg.len(), 1);
    }
}