gemini_adk_rs/text/
dispatch.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use super::TextAgent;
8use crate::error::AgentError;
9use crate::state::State;
10
11type TaskMap = HashMap<String, tokio::task::JoinHandle<Result<String, String>>>;
13
14#[derive(Clone, Default)]
16pub struct TaskRegistry {
17 pub(crate) inner: Arc<tokio::sync::Mutex<TaskMap>>,
18}
19
20impl TaskRegistry {
21 pub fn new() -> Self {
23 Self::default()
24 }
25}
26
27pub struct DispatchTextAgent {
32 name: String,
33 children: Vec<(String, Arc<dyn TextAgent>)>,
34 registry: TaskRegistry,
35 budget: Arc<tokio::sync::Semaphore>,
36}
37
38impl DispatchTextAgent {
39 pub fn new(
41 name: impl Into<String>,
42 children: Vec<(String, Arc<dyn TextAgent>)>,
43 registry: TaskRegistry,
44 budget: Arc<tokio::sync::Semaphore>,
45 ) -> Self {
46 Self {
47 name: name.into(),
48 children,
49 registry,
50 budget,
51 }
52 }
53}
54
55#[async_trait]
56impl TextAgent for DispatchTextAgent {
57 fn name(&self) -> &str {
58 &self.name
59 }
60
61 async fn run(&self, state: &State) -> Result<String, AgentError> {
62 let mut registry = self.registry.inner.lock().await;
63
64 for (task_name, agent) in &self.children {
65 let agent = agent.clone();
66 let state = state.clone();
67 let budget = self.budget.clone();
68 let task_name_owned = task_name.clone();
69
70 let handle = tokio::spawn(async move {
71 let _permit = budget
72 .acquire()
73 .await
74 .map_err(|e| format!("Semaphore closed: {e}"))?;
75 agent
76 .run(&state)
77 .await
78 .map_err(|e| format!("Task '{}' failed: {}", task_name_owned, e))
79 });
80
81 registry.insert(task_name.clone(), handle);
82 }
83
84 let _ = state.set(
85 "_dispatch_status",
86 self.children
87 .iter()
88 .map(|(name, _)| (name.clone(), "running".to_string()))
89 .collect::<HashMap<String, String>>(),
90 );
91
92 Ok(String::new())
93 }
94}
95
96pub struct JoinTextAgent {
100 name: String,
101 registry: TaskRegistry,
102 target_names: Option<Vec<String>>,
103 timeout: Option<Duration>,
104}
105
106impl JoinTextAgent {
107 pub fn new(name: impl Into<String>, registry: TaskRegistry) -> Self {
109 Self {
110 name: name.into(),
111 registry,
112 target_names: None,
113 timeout: None,
114 }
115 }
116
117 pub fn targets(mut self, names: Vec<String>) -> Self {
119 self.target_names = Some(names);
120 self
121 }
122
123 pub fn timeout(mut self, timeout: Duration) -> Self {
125 self.timeout = Some(timeout);
126 self
127 }
128}
129
130#[async_trait]
131impl TextAgent for JoinTextAgent {
132 fn name(&self) -> &str {
133 &self.name
134 }
135
136 async fn run(&self, state: &State) -> Result<String, AgentError> {
137 let mut registry = self.registry.inner.lock().await;
138
139 let tasks: HashMap<String, _> = if let Some(targets) = &self.target_names {
141 targets
142 .iter()
143 .filter_map(|name| registry.remove(name).map(|h| (name.clone(), h)))
144 .collect()
145 } else {
146 std::mem::take(&mut *registry)
147 };
148 drop(registry);
149
150 let mut results = Vec::new();
151
152 for (task_name, handle) in tasks {
153 let result = if let Some(timeout) = self.timeout {
154 match tokio::time::timeout(timeout, handle).await {
155 Ok(Ok(Ok(text))) => {
156 let _ = state.set(format!("_result_{}", task_name), &text);
157 Ok(text)
158 }
159 Ok(Ok(Err(e))) => Err(AgentError::Other(e)),
160 Ok(Err(e)) => Err(AgentError::Other(format!("Join error: {e}"))),
161 Err(_) => Err(AgentError::Timeout),
162 }
163 } else {
164 match handle.await {
165 Ok(Ok(text)) => {
166 let _ = state.set(format!("_result_{}", task_name), &text);
167 Ok(text)
168 }
169 Ok(Err(e)) => Err(AgentError::Other(e)),
170 Err(e) => Err(AgentError::Other(format!("Join error: {e}"))),
171 }
172 };
173
174 results.push(result?);
175 }
176
177 let combined = results.join("\n");
178 let _ = state.set("output", &combined);
179 Ok(combined)
180 }
181}