gemini_adk_rs/live/
background_agent_dispatch.rs

1//! BackgroundAgentDispatcher — fire-and-forget text agent dispatch from live callbacks.
2//!
3//! Spawns [`TextAgent`] pipelines on background tokio tasks. Results are written
4//! to [`State`] under `{task_name}:result` / `{task_name}:error` keys, where
5//! watchers can react to them.
6//!
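//! A semaphore budget caps how many agents run at once; every dispatch is
//! spawned immediately, but tasks beyond the budget wait for a permit before running.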
8
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use tokio::sync::{Mutex, Semaphore};
13
14use crate::state::State;
15use crate::text::TextAgent;
16
17/// Dispatcher for running [`TextAgent`] pipelines as background tasks.
18///
19/// # Example
20///
21/// ```ignore
22/// let dispatcher = BackgroundAgentDispatcher::new(5); // max 5 concurrent
23///
24/// // From an on_turn_complete callback:
25/// dispatcher.dispatch("compliance_check", compliance_agent.clone(), state.clone());
26///
27/// // Results appear in state:
28/// //   "compliance_check:result" = "No violations detected"
29/// // OR
30/// //   "compliance_check:error" = "Agent failed: ..."
31/// ```
32pub struct BackgroundAgentDispatcher {
33    budget: Arc<Semaphore>,
34    tasks: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
35    max_concurrent: usize,
36}
37
38impl BackgroundAgentDispatcher {
39    /// Create a new dispatcher with the given concurrency budget.
40    pub fn new(max_concurrent: usize) -> Self {
41        Self {
42            budget: Arc::new(Semaphore::new(max_concurrent)),
43            tasks: Arc::new(Mutex::new(HashMap::new())),
44            max_concurrent,
45        }
46    }
47
48    /// Maximum concurrent background agents.
49    pub fn max_concurrent(&self) -> usize {
50        self.max_concurrent
51    }
52
53    /// Number of currently available permits.
54    pub fn available_permits(&self) -> usize {
55        self.budget.available_permits()
56    }
57
58    /// Dispatch a text agent to run in the background.
59    ///
60    /// Results are written to state under `{task_name}:result`.
61    /// Errors are written to `{task_name}:error`.
62    ///
63    /// If the budget is exhausted, the task will wait for a permit.
64    pub fn dispatch(&self, task_name: impl Into<String>, agent: Arc<dyn TextAgent>, state: State) {
65        let name = task_name.into();
66        let budget = self.budget.clone();
67        let tasks = self.tasks.clone();
68        let result_key = format!("{name}:result");
69        let error_key = format!("{name}:error");
70        let name_for_cleanup = name.clone();
71
72        let handle = tokio::spawn(async move {
73            // Acquire permit (waits if budget exhausted)
74            let _permit = match budget.acquire().await {
75                Ok(p) => p,
76                Err(_) => return, // Semaphore closed
77            };
78
79            match agent.run(&state).await {
80                Ok(result) => {
81                    state.set(&result_key, &result);
82                }
83                Err(e) => {
84                    state.set(&error_key, format!("{e}"));
85                }
86            }
87
88            // Clean up task handle
89            tasks.lock().await.remove(&name_for_cleanup);
90        });
91
92        // Store handle for cancellation. Use blocking try_lock to avoid
93        // making dispatch async — callers are typically in sync contexts.
94        // Fall back to fire-and-forget if lock is contended.
95        if let Ok(mut guard) = self.tasks.try_lock() {
96            guard.insert(name, handle);
97        }
98    }
99
100    /// Check if a named task is still running.
101    pub async fn is_running(&self, name: &str) -> bool {
102        let guard = self.tasks.lock().await;
103        guard.get(name).map(|h| !h.is_finished()).unwrap_or(false)
104    }
105
106    /// Cancel all running background agents.
107    pub async fn cancel_all(&self) {
108        let mut guard = self.tasks.lock().await;
109        for (_, handle) in guard.drain() {
110            handle.abort();
111        }
112    }
113
114    /// Cancel a specific named task.
115    pub async fn cancel(&self, name: &str) {
116        let mut guard = self.tasks.lock().await;
117        if let Some(handle) = guard.remove(name) {
118            handle.abort();
119        }
120    }
121
122    /// Number of tasks currently tracked (running or recently completed).
123    pub async fn active_count(&self) -> usize {
124        let guard = self.tasks.lock().await;
125        guard.values().filter(|h| !h.is_finished()).count()
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use async_trait::async_trait;

    /// Give spawned background tasks `millis` milliseconds to make progress.
    async fn settle(millis: u64) {
        tokio::time::sleep(std::time::Duration::from_millis(millis)).await;
    }

    /// Completes immediately with a preconfigured output.
    struct QuickAgent {
        output: String,
    }

    #[async_trait]
    impl TextAgent for QuickAgent {
        fn name(&self) -> &str {
            "quick"
        }

        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Ok(self.output.clone())
        }
    }

    /// Sleeps for 200ms, then succeeds with `"done"`.
    struct SlowAgent;

    #[async_trait]
    impl TextAgent for SlowAgent {
        fn name(&self) -> &str {
            "slow"
        }

        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            settle(200).await;
            Ok(String::from("done"))
        }
    }

    /// Always fails with a recognizable message.
    struct FailAgent;

    #[async_trait]
    impl TextAgent for FailAgent {
        fn name(&self) -> &str {
            "fail"
        }

        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Err(AgentError::Other("background failure".into()))
        }
    }

    /// Writes a side-channel flag into state before succeeding.
    struct StateWriterAgent;

    #[async_trait]
    impl TextAgent for StateWriterAgent {
        fn name(&self) -> &str {
            "writer"
        }

        async fn run(&self, state: &State) -> Result<String, AgentError> {
            state.set("bg_wrote", true);
            Ok(String::from("wrote state"))
        }
    }

    #[tokio::test]
    async fn dispatch_writes_result_to_state() {
        let bg = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let quick: Arc<dyn TextAgent> = Arc::new(QuickAgent {
            output: String::from("analysis complete"),
        });

        bg.dispatch("analysis", quick, shared.clone());
        settle(50).await;

        let stored = shared.get::<String>("analysis:result");
        assert_eq!(stored, Some(String::from("analysis complete")));
    }

    #[tokio::test]
    async fn dispatch_writes_error_to_state() {
        let bg = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let failing: Arc<dyn TextAgent> = Arc::new(FailAgent);

        bg.dispatch("check", failing, shared.clone());
        settle(50).await;

        let message = shared
            .get::<String>("check:error")
            .expect("failing agent should record an error");
        assert!(message.contains("background failure"));
    }

    #[tokio::test]
    async fn budget_limits_concurrency() {
        let bg = BackgroundAgentDispatcher::new(2);
        let shared = State::new();
        let slow: Arc<dyn TextAgent> = Arc::new(SlowAgent);

        assert_eq!(bg.available_permits(), 2);

        for label in ["task1", "task2"] {
            bg.dispatch(label, Arc::clone(&slow), shared.clone());
        }

        // Once both tasks have started, the whole budget is in use.
        settle(20).await;
        assert_eq!(bg.available_permits(), 0);

        // After both finish, every permit is returned.
        settle(300).await;
        assert_eq!(bg.available_permits(), 2);
    }

    #[tokio::test]
    async fn cancel_all_aborts_tasks() {
        let bg = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let slow: Arc<dyn TextAgent> = Arc::new(SlowAgent);

        bg.dispatch("long", slow, shared.clone());
        settle(20).await;
        assert!(bg.is_running("long").await);

        bg.cancel_all().await;
        settle(50).await;

        // The abort landed before the agent could publish a result.
        assert!(shared.get::<String>("long:result").is_none());
    }

    #[tokio::test]
    async fn state_mutations_visible_to_parent() {
        let bg = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let writer: Arc<dyn TextAgent> = Arc::new(StateWriterAgent);

        bg.dispatch("writer", writer, shared.clone());
        settle(50).await;

        assert_eq!(shared.get::<bool>("bg_wrote"), Some(true));
        let stored = shared.get::<String>("writer:result");
        assert_eq!(stored, Some(String::from("wrote state")));
    }

    #[tokio::test]
    async fn cancel_specific_task() {
        let bg = BackgroundAgentDispatcher::new(5);
        let shared = State::new();
        let slow: Arc<dyn TextAgent> = Arc::new(SlowAgent);

        bg.dispatch("keep", Arc::clone(&slow), shared.clone());
        bg.dispatch("abort", slow, shared.clone());
        settle(20).await;

        bg.cancel("abort").await;
        settle(300).await;

        // Only the cancelled task is missing its result.
        assert_eq!(shared.get::<String>("keep:result"), Some(String::from("done")));
        assert!(shared.get::<String>("abort:result").is_none());
    }
}
300}