// gemini_adk_rs/llm/registry.rs
use std::fmt;
use std::sync::Arc;

use super::BaseLlm;

10type LlmFactory = Box<dyn Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync>;
11
12pub struct LlmRegistry {
17 factories: Vec<(String, LlmFactory)>,
18}
19
20impl LlmRegistry {
21 pub fn new() -> Self {
23 Self {
24 factories: Vec::new(),
25 }
26 }
27
28 pub fn register(
30 &mut self,
31 pattern: impl Into<String>,
32 factory: impl Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync + 'static,
33 ) {
34 self.factories.push((pattern.into(), Box::new(factory)));
35 }
36
37 pub fn resolve(&self, model_name: &str) -> Option<Arc<dyn BaseLlm>> {
40 for (pattern, factory) in &self.factories {
41 if model_name.starts_with(pattern.as_str()) {
42 return Some(factory(model_name));
43 }
44 }
45 None
46 }
47
48 pub fn len(&self) -> usize {
50 self.factories.len()
51 }
52
53 pub fn is_empty(&self) -> bool {
55 self.factories.is_empty()
56 }
57}
58
59impl Default for LlmRegistry {
60 fn default() -> Self {
61 Self::new()
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{LlmError, LlmRequest, LlmResponse};
    use async_trait::async_trait;

    /// Minimal `BaseLlm` that only records the model id it was built with.
    struct MockLlm {
        model: String,
    }

    #[async_trait]
    impl BaseLlm for MockLlm {
        fn model_id(&self) -> &str {
            &self.model
        }
        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            Err(LlmError::Other("mock".into()))
        }
    }

    /// Factory producing a `MockLlm` whose id is `prefix` followed by the
    /// requested model name, so tests can tell which factory resolved a name.
    fn mock_factory(
        prefix: &'static str,
    ) -> impl Fn(&str) -> Arc<dyn BaseLlm> + Send + Sync + 'static {
        move |name: &str| -> Arc<dyn BaseLlm> {
            Arc::new(MockLlm {
                model: format!("{prefix}{name}"),
            })
        }
    }

    #[test]
    fn register_and_resolve() {
        let mut registry = LlmRegistry::new();
        registry.register("gemini", mock_factory(""));

        let llm = registry.resolve("gemini-2.5-flash").unwrap();
        assert_eq!(llm.model_id(), "gemini-2.5-flash");
    }

    #[test]
    fn resolve_unknown_returns_none() {
        let registry = LlmRegistry::new();
        assert!(registry.resolve("gpt-4").is_none());
    }

    #[test]
    fn first_match_wins() {
        let mut registry = LlmRegistry::new();
        registry.register("gemini-2.5", mock_factory("v2.5:"));
        registry.register("gemini", mock_factory("generic:"));

        let llm = registry.resolve("gemini-2.5-flash").unwrap();
        assert_eq!(llm.model_id(), "v2.5:gemini-2.5-flash");

        let llm2 = registry.resolve("gemini-1.5-pro").unwrap();
        assert_eq!(llm2.model_id(), "generic:gemini-1.5-pro");
    }

    #[test]
    fn empty_pattern_is_catch_all() {
        let mut registry = LlmRegistry::new();
        registry.register("gemini", mock_factory("gemini:"));
        registry.register("", mock_factory("fallback:"));

        let llm = registry.resolve("gemini-2.5-pro").unwrap();
        assert_eq!(llm.model_id(), "gemini:gemini-2.5-pro");

        let llm2 = registry.resolve("gpt-4").unwrap();
        assert_eq!(llm2.model_id(), "fallback:gpt-4");
    }

    #[test]
    fn pattern_matches_prefix_only() {
        let mut registry = LlmRegistry::new();
        registry.register("flash", mock_factory(""));

        // A pattern occurring later in the name is not a match...
        assert!(registry.resolve("gemini-2.5-flash").is_none());
        // ...and matching is case-sensitive.
        assert!(registry.resolve("Flash").is_none());

        let llm = registry.resolve("flash").unwrap();
        assert_eq!(llm.model_id(), "flash");
    }

    #[test]
    fn len_and_is_empty() {
        let mut registry = LlmRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register("test", mock_factory(""));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }
}