gemini_adk_rs/text/dispatch.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use super::TextAgent;
8use crate::error::AgentError;
9use crate::state::State;
10
11/// Background task handle: agent name → join handle yielding `Ok(json)`/`Err(msg)`.
12type TaskMap = HashMap<String, tokio::task::JoinHandle<Result<String, String>>>;
13
14/// Shared registry for dispatched background tasks.
15#[derive(Clone, Default)]
16pub struct TaskRegistry {
17    pub(crate) inner: Arc<tokio::sync::Mutex<TaskMap>>,
18}
19
20impl TaskRegistry {
21    /// Create a new empty task registry.
22    pub fn new() -> Self {
23        Self::default()
24    }
25}
26
27/// Fire-and-forget background task launcher with global task budget.
28///
29/// Launches each child agent as a background `tokio::spawn` task,
30/// stores handles in a `TaskRegistry`, and returns immediately.
31pub struct DispatchTextAgent {
32    name: String,
33    children: Vec<(String, Arc<dyn TextAgent>)>,
34    registry: TaskRegistry,
35    budget: Arc<tokio::sync::Semaphore>,
36}
37
38impl DispatchTextAgent {
39    /// Create a new dispatch agent with named children and a concurrency budget.
40    pub fn new(
41        name: impl Into<String>,
42        children: Vec<(String, Arc<dyn TextAgent>)>,
43        registry: TaskRegistry,
44        budget: Arc<tokio::sync::Semaphore>,
45    ) -> Self {
46        Self {
47            name: name.into(),
48            children,
49            registry,
50            budget,
51        }
52    }
53}
54
55#[async_trait]
56impl TextAgent for DispatchTextAgent {
57    fn name(&self) -> &str {
58        &self.name
59    }
60
61    async fn run(&self, state: &State) -> Result<String, AgentError> {
62        let mut registry = self.registry.inner.lock().await;
63
64        for (task_name, agent) in &self.children {
65            let agent = agent.clone();
66            let state = state.clone();
67            let budget = self.budget.clone();
68            let task_name_owned = task_name.clone();
69
70            let handle = tokio::spawn(async move {
71                let _permit = budget
72                    .acquire()
73                    .await
74                    .map_err(|e| format!("Semaphore closed: {e}"))?;
75                agent
76                    .run(&state)
77                    .await
78                    .map_err(|e| format!("Task '{}' failed: {}", task_name_owned, e))
79            });
80
81            registry.insert(task_name.clone(), handle);
82        }
83
84        let _ = state.set(
85            "_dispatch_status",
86            self.children
87                .iter()
88                .map(|(name, _)| (name.clone(), "running".to_string()))
89                .collect::<HashMap<String, String>>(),
90        );
91
92        Ok(String::new())
93    }
94}
95
96// ── JoinTextAgent ─────────────────────────────────────────────────────────
97
98/// Waits for dispatched background tasks and collects their results.
99pub struct JoinTextAgent {
100    name: String,
101    registry: TaskRegistry,
102    target_names: Option<Vec<String>>,
103    timeout: Option<Duration>,
104}
105
106impl JoinTextAgent {
107    /// Create a new join agent that waits for dispatched tasks.
108    pub fn new(name: impl Into<String>, registry: TaskRegistry) -> Self {
109        Self {
110            name: name.into(),
111            registry,
112            target_names: None,
113            timeout: None,
114        }
115    }
116
117    /// Only wait for specific named tasks.
118    pub fn targets(mut self, names: Vec<String>) -> Self {
119        self.target_names = Some(names);
120        self
121    }
122
123    /// Set a timeout for waiting.
124    pub fn timeout(mut self, timeout: Duration) -> Self {
125        self.timeout = Some(timeout);
126        self
127    }
128}
129
130#[async_trait]
131impl TextAgent for JoinTextAgent {
132    fn name(&self) -> &str {
133        &self.name
134    }
135
136    async fn run(&self, state: &State) -> Result<String, AgentError> {
137        let mut registry = self.registry.inner.lock().await;
138
139        // Select tasks to wait for.
140        let tasks: HashMap<String, _> = if let Some(targets) = &self.target_names {
141            targets
142                .iter()
143                .filter_map(|name| registry.remove(name).map(|h| (name.clone(), h)))
144                .collect()
145        } else {
146            std::mem::take(&mut *registry)
147        };
148        drop(registry);
149
150        let mut results = Vec::new();
151
152        for (task_name, handle) in tasks {
153            let result = if let Some(timeout) = self.timeout {
154                match tokio::time::timeout(timeout, handle).await {
155                    Ok(Ok(Ok(text))) => {
156                        let _ = state.set(format!("_result_{}", task_name), &text);
157                        Ok(text)
158                    }
159                    Ok(Ok(Err(e))) => Err(AgentError::Other(e)),
160                    Ok(Err(e)) => Err(AgentError::Other(format!("Join error: {e}"))),
161                    Err(_) => Err(AgentError::Timeout),
162                }
163            } else {
164                match handle.await {
165                    Ok(Ok(text)) => {
166                        let _ = state.set(format!("_result_{}", task_name), &text);
167                        Ok(text)
168                    }
169                    Ok(Err(e)) => Err(AgentError::Other(e)),
170                    Err(e) => Err(AgentError::Other(format!("Join error: {e}"))),
171                }
172            };
173
174            results.push(result?);
175        }
176
177        let combined = results.join("\n");
178        let _ = state.set("output", &combined);
179        Ok(combined)
180    }
181}