gemini_adk_rs/live/
background_agent_dispatch.rs1use std::collections::HashMap;
10use std::sync::Arc;
11
12use tokio::sync::{Mutex, Semaphore};
13
14use crate::state::State;
15use crate::text::TextAgent;
16
17pub struct BackgroundAgentDispatcher {
33 budget: Arc<Semaphore>,
34 tasks: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
35 max_concurrent: usize,
36}
37
38impl BackgroundAgentDispatcher {
39 pub fn new(max_concurrent: usize) -> Self {
41 Self {
42 budget: Arc::new(Semaphore::new(max_concurrent)),
43 tasks: Arc::new(Mutex::new(HashMap::new())),
44 max_concurrent,
45 }
46 }
47
48 pub fn max_concurrent(&self) -> usize {
50 self.max_concurrent
51 }
52
53 pub fn available_permits(&self) -> usize {
55 self.budget.available_permits()
56 }
57
58 pub fn dispatch(&self, task_name: impl Into<String>, agent: Arc<dyn TextAgent>, state: State) {
65 let name = task_name.into();
66 let budget = self.budget.clone();
67 let tasks = self.tasks.clone();
68 let result_key = format!("{name}:result");
69 let error_key = format!("{name}:error");
70 let name_for_cleanup = name.clone();
71
72 let handle = tokio::spawn(async move {
73 let _permit = match budget.acquire().await {
75 Ok(p) => p,
76 Err(_) => return, };
78
79 match agent.run(&state).await {
80 Ok(result) => {
81 state.set(&result_key, &result);
82 }
83 Err(e) => {
84 state.set(&error_key, format!("{e}"));
85 }
86 }
87
88 tasks.lock().await.remove(&name_for_cleanup);
90 });
91
92 if let Ok(mut guard) = self.tasks.try_lock() {
96 guard.insert(name, handle);
97 }
98 }
99
100 pub async fn is_running(&self, name: &str) -> bool {
102 let guard = self.tasks.lock().await;
103 guard.get(name).map(|h| !h.is_finished()).unwrap_or(false)
104 }
105
106 pub async fn cancel_all(&self) {
108 let mut guard = self.tasks.lock().await;
109 for (_, handle) in guard.drain() {
110 handle.abort();
111 }
112 }
113
114 pub async fn cancel(&self, name: &str) {
116 let mut guard = self.tasks.lock().await;
117 if let Some(handle) = guard.remove(name) {
118 handle.abort();
119 }
120 }
121
122 pub async fn active_count(&self) -> usize {
124 let guard = self.tasks.lock().await;
125 guard.values().filter(|h| !h.is_finished()).count()
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use async_trait::async_trait;

    /// Yields to the runtime for `millis` milliseconds so background tasks can progress.
    async fn pause(millis: u64) {
        tokio::time::sleep(std::time::Duration::from_millis(millis)).await;
    }

    /// Agent that succeeds immediately with a preset output.
    struct QuickAgent {
        output: String,
    }

    #[async_trait]
    impl TextAgent for QuickAgent {
        fn name(&self) -> &str {
            "quick"
        }

        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Ok(self.output.clone())
        }
    }

    /// Agent that takes 200 ms before succeeding with `"done"`.
    struct SlowAgent;

    #[async_trait]
    impl TextAgent for SlowAgent {
        fn name(&self) -> &str {
            "slow"
        }

        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            pause(200).await;
            Ok("done".into())
        }
    }

    /// Agent that always fails.
    struct FailAgent;

    #[async_trait]
    impl TextAgent for FailAgent {
        fn name(&self) -> &str {
            "fail"
        }

        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Err(AgentError::Other("background failure".into()))
        }
    }

    /// Agent that writes a flag into the shared state before succeeding.
    struct StateWriterAgent;

    #[async_trait]
    impl TextAgent for StateWriterAgent {
        fn name(&self) -> &str {
            "writer"
        }

        async fn run(&self, state: &State) -> Result<String, AgentError> {
            state.set("bg_wrote", true);
            Ok("wrote state".into())
        }
    }

    /// A successful agent's output lands under `"<name>:result"`.
    #[tokio::test]
    async fn dispatch_writes_result_to_state() {
        let dispatcher = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let quick: Arc<dyn TextAgent> = Arc::new(QuickAgent {
            output: "analysis complete".into(),
        });

        dispatcher.dispatch("analysis", quick, shared.clone());
        pause(50).await;

        let stored = shared.get::<String>("analysis:result");
        assert_eq!(stored, Some("analysis complete".into()));
    }

    /// A failing agent's error text lands under `"<name>:error"`.
    #[tokio::test]
    async fn dispatch_writes_error_to_state() {
        let dispatcher = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let failing: Arc<dyn TextAgent> = Arc::new(FailAgent);

        dispatcher.dispatch("check", failing, shared.clone());
        pause(50).await;

        let recorded = shared.get::<String>("check:error");
        assert!(recorded.is_some());
        assert!(recorded.unwrap().contains("background failure"));
    }

    /// Permits are consumed while agents run and returned once they finish.
    #[tokio::test]
    async fn budget_limits_concurrency() {
        let dispatcher = BackgroundAgentDispatcher::new(2);
        let shared = State::new();
        let slow: Arc<dyn TextAgent> = Arc::new(SlowAgent);

        assert_eq!(dispatcher.available_permits(), 2);

        for task in ["task1", "task2"] {
            dispatcher.dispatch(task, Arc::clone(&slow), shared.clone());
        }

        // Both tasks should now be holding a permit.
        pause(20).await;
        assert_eq!(dispatcher.available_permits(), 0);

        // After both agents finish, every permit is back.
        pause(300).await;
        assert_eq!(dispatcher.available_permits(), 2);
    }

    /// `cancel_all` stops a running task before it can record a result.
    #[tokio::test]
    async fn cancel_all_aborts_tasks() {
        let dispatcher = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let slow: Arc<dyn TextAgent> = Arc::new(SlowAgent);

        dispatcher.dispatch("long", slow, shared.clone());

        pause(20).await;
        assert!(dispatcher.is_running("long").await);

        dispatcher.cancel_all().await;
        pause(50).await;

        assert!(shared.get::<String>("long:result").is_none());
    }

    /// Writes made by the agent itself are visible through the caller's state handle.
    #[tokio::test]
    async fn state_mutations_visible_to_parent() {
        let dispatcher = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let writer: Arc<dyn TextAgent> = Arc::new(StateWriterAgent);

        dispatcher.dispatch("writer", writer, shared.clone());
        pause(50).await;

        assert_eq!(shared.get::<bool>("bg_wrote"), Some(true));
        let stored = shared.get::<String>("writer:result");
        assert_eq!(stored, Some("wrote state".into()));
    }

    /// `cancel` stops only the named task; others run to completion.
    #[tokio::test]
    async fn cancel_specific_task() {
        let dispatcher = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let slow: Arc<dyn TextAgent> = Arc::new(SlowAgent);

        dispatcher.dispatch("keep", Arc::clone(&slow), shared.clone());
        dispatcher.dispatch("abort", slow, shared.clone());

        pause(20).await;
        dispatcher.cancel("abort").await;
        pause(300).await;

        assert_eq!(shared.get::<String>("keep:result"), Some("done".into()));
        assert!(shared.get::<String>("abort:result").is_none());
    }
}