gemini_adk_rs/orchestration/
mod.rs1use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22
23use serde_json::Value;
24
25use crate::error::AgentError;
26use crate::llm::{BaseLlm, LlmRequest};
27use crate::state::State;
28use crate::text::TextAgent;
29
/// How an orchestration step is executed relative to its caller.
///
/// NOTE(review): nothing in this module matches on `Mode`. The per-variant
/// descriptions below are inferred from the variant names and from the
/// `call` / `Resolver::dispatch` helpers; confirm against the real callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Presumably: run inline and await the outcome (cf. `call`).
    Call,
    /// Presumably: spawn detached and keep going (cf. `Resolver::dispatch`).
    Dispatch,
    /// Presumably: run in the background. TODO confirm how this differs from `Dispatch`.
    Background,
}
42
/// State key under which a step named `name` stores its successful value.
///
/// The key has the form `"{name}:result"`.
pub fn result_key(name: &str) -> String {
    [name, ":result"].concat()
}
47
/// State key under which a step named `name` stores its error message.
///
/// The key has the form `"{name}:error"`.
pub fn error_key(name: &str) -> String {
    [name, ":error"].concat()
}
52
53pub fn provenance(state: &State, key: &str) -> Option<String> {
56 state
57 .get::<serde_json::Value>(&format!("state_meta:{key}"))
58 .and_then(|m| m.get("source").and_then(|s| s.as_str().map(String::from)))
59}
60
61pub async fn call(
68 name: &str,
69 agent: Arc<dyn TextAgent>,
70 state: &State,
71) -> Result<String, AgentError> {
72 let result = agent.run(state).await;
73 match &result {
74 Ok(r) => {
75 let key = result_key(name);
76 let _ = state.set(
77 format!("state_meta:{key}"),
78 serde_json::json!({ "source": "agent", "resolver": name }),
79 );
80 let _ = state.set(key, r);
81 }
82 Err(e) => {
83 let _ = state.set(error_key(name), e.to_string());
84 }
85 }
86 result
87}
88
89type FetchFn =
91 Arc<dyn Fn(State) -> Pin<Box<dyn Future<Output = Result<Value, String>> + Send>> + Send + Sync>;
92
93enum Source {
94 Agent(Arc<dyn TextAgent>),
96 Fetch(FetchFn),
99 Llm {
101 llm: Arc<dyn BaseLlm>,
103 prompt: String,
105 },
106}
107
108fn interpolate(template: &str, state: &State) -> String {
110 let mut out = String::with_capacity(template.len());
111 let mut rest = template;
112 while let Some(open) = rest.find('{') {
113 out.push_str(&rest[..open]);
114 let after = &rest[open + 1..];
115 let Some(close) = after.find('}') else {
116 out.push_str(&rest[open..]);
117 return out;
118 };
119 let key = after[..close].trim();
120 match state.get::<serde_json::Value>(key) {
121 Some(serde_json::Value::String(s)) => out.push_str(&s),
122 Some(v) => out.push_str(&v.to_string()),
123 None => {}
124 }
125 rest = &after[close + 1..];
126 }
127 out.push_str(rest);
128 out
129}
130
131pub struct Resolver {
143 name: String,
144 source: Source,
145}
146
147impl Resolver {
148 pub fn agent(name: impl Into<String>, agent: Arc<dyn TextAgent>) -> Self {
150 Self {
151 name: name.into(),
152 source: Source::Agent(agent),
153 }
154 }
155
156 pub fn fetch<F, Fut>(name: impl Into<String>, f: F) -> Self
160 where
161 F: Fn(State) -> Fut + Send + Sync + 'static,
162 Fut: Future<Output = Result<Value, String>> + Send + 'static,
163 {
164 let f = Arc::new(f);
165 Self {
166 name: name.into(),
167 source: Source::Fetch(Arc::new(move |state| {
168 let f = f.clone();
169 Box::pin(async move { f(state).await })
170 })),
171 }
172 }
173
174 pub fn llm(name: impl Into<String>, llm: Arc<dyn BaseLlm>, prompt: impl Into<String>) -> Self {
177 Self {
178 name: name.into(),
179 source: Source::Llm {
180 llm,
181 prompt: prompt.into(),
182 },
183 }
184 }
185
186 pub fn name(&self) -> &str {
188 &self.name
189 }
190
191 fn source_kind(&self) -> &'static str {
193 match &self.source {
194 Source::Agent(_) => "agent",
195 Source::Fetch(_) => "fetch",
196 Source::Llm { .. } => "llm",
197 }
198 }
199
200 pub async fn resolve(&self, state: &State) -> Result<Value, String> {
204 let outcome = match &self.source {
205 Source::Agent(a) => a
206 .run(state)
207 .await
208 .map(Value::from)
209 .map_err(|e| e.to_string()),
210 Source::Fetch(f) => f(state.clone()).await,
211 Source::Llm { llm, prompt } => {
212 let rendered = interpolate(prompt, state);
213 llm.generate(LlmRequest::from_text(rendered))
214 .await
215 .map(|r| Value::from(r.text()))
216 .map_err(|e| e.to_string())
217 }
218 };
219 match &outcome {
220 Ok(v) => {
221 let key = result_key(&self.name);
222 let _ = state.set(
223 format!("state_meta:{key}"),
224 serde_json::json!({ "source": self.source_kind(), "resolver": self.name }),
225 );
226 let _ = state.set(key, v.clone());
227 }
228 Err(e) => {
229 let _ = state.set(error_key(&self.name), e);
230 }
231 }
232 outcome
233 }
234
235 pub fn dispatch(self, state: State) {
239 tokio::spawn(async move {
240 let _ = self.resolve(&state).await;
241 });
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Agent that always succeeds with a fixed string.
    struct Echo(&'static str);
    #[async_trait]
    impl TextAgent for Echo {
        fn name(&self) -> &str {
            "echo"
        }
        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Ok(self.0.to_string())
        }
    }

    /// Agent that always fails with `AgentError::Other`.
    struct Boom;
    #[async_trait]
    impl TextAgent for Boom {
        fn name(&self) -> &str {
            "boom"
        }
        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Err(AgentError::Other("kaboom".into()))
        }
    }

    #[test]
    fn key_helpers_follow_naming_convention() {
        assert_eq!(result_key("verify"), "verify:result");
        assert_eq!(error_key("verify"), "verify:error");
    }

    #[tokio::test]
    async fn call_writes_result_to_state() {
        let state = State::new();
        let out = call("verify", Arc::new(Echo("ok-123")), &state)
            .await
            .unwrap();
        assert_eq!(out, "ok-123");
        assert_eq!(
            state.get::<String>("verify:result").as_deref(),
            Some("ok-123")
        );
    }

    #[tokio::test]
    async fn call_records_agent_provenance() {
        let state = State::new();
        call("verify", Arc::new(Echo("ok")), &state).await.unwrap();
        assert_eq!(
            provenance(&state, "verify:result").as_deref(),
            Some("agent")
        );
    }

    #[tokio::test]
    async fn call_writes_error_to_state() {
        let state = State::new();
        let r = call("verify", Arc::new(Boom), &state).await;
        assert!(r.is_err());
        assert!(state.contains("verify:error"));
        assert!(!state.contains("verify:result"));
    }

    #[tokio::test]
    async fn provenance_is_none_without_metadata() {
        let state = State::new();
        assert_eq!(provenance(&state, "unknown:result"), None);
    }

    #[tokio::test]
    async fn interpolate_handles_values_missing_keys_and_unclosed_braces() {
        let state = State::new();
        let _ = state.set("who", "ada");
        let _ = state.set("n", json!(3));
        // String values are inserted without quotes; whitespace inside braces is trimmed.
        assert_eq!(interpolate("hi { who }", &state), "hi ada");
        // Non-string values use their JSON rendering.
        assert_eq!(interpolate("n={n}", &state), "n=3");
        // Unknown keys render as empty text.
        assert_eq!(interpolate("[{missing}]", &state), "[]");
        // An unmatched `{` and everything after it are copied verbatim.
        assert_eq!(interpolate("a {who} {oops", &state), "a ada {oops");
        // Templates without placeholders pass through unchanged.
        assert_eq!(interpolate("plain", &state), "plain");
        assert_eq!(interpolate("", &state), "");
    }

    #[tokio::test]
    async fn resolver_fetch_binds_state_and_writes_result() {
        let state = State::new();
        let _ = state.set("slot", "afternoon");
        let r = Resolver::fetch("availability", |s: State| async move {
            let slot = s.get::<String>("slot").unwrap_or_default();
            Ok(json!({ "open": slot == "afternoon" }))
        });
        let out = r.resolve(&state).await.unwrap();
        assert_eq!(out, json!({ "open": true }));
        assert_eq!(
            state.get::<Value>("availability:result"),
            Some(json!({ "open": true }))
        );
        assert_eq!(
            provenance(&state, "availability:result").as_deref(),
            Some("fetch")
        );
    }

    #[tokio::test]
    async fn resolver_agent_uses_result_convention() {
        let state = State::new();
        Resolver::agent("verify", Arc::new(Echo("ok-9")))
            .resolve(&state)
            .await
            .unwrap();
        assert_eq!(
            state.get::<String>("verify:result").as_deref(),
            Some("ok-9")
        );
    }

    #[tokio::test]
    async fn resolver_agent_error_is_recorded_without_provenance() {
        let state = State::new();
        let err = Resolver::agent("verify", Arc::new(Boom))
            .resolve(&state)
            .await
            .unwrap_err();
        assert_eq!(
            state.get::<String>("verify:error").as_deref(),
            Some(err.as_str())
        );
        assert!(!state.contains("verify:result"));
        assert_eq!(provenance(&state, "verify:result"), None);
    }

    #[tokio::test]
    async fn resolver_fetch_records_error() {
        let state = State::new();
        let r = Resolver::fetch("lookup", |_s: State| async move {
            Err::<Value, String>("upstream 503".into())
        });
        assert!(r.resolve(&state).await.is_err());
        assert_eq!(
            state.get::<String>("lookup:error").as_deref(),
            Some("upstream 503")
        );
        assert!(!state.contains("lookup:result"));
    }

    #[tokio::test]
    async fn resolver_llm_interpolates_prompt_and_stores_text() {
        use crate::llm::{LlmError, LlmResponse};
        use gemini_genai_rs::prelude::Content;

        /// Model stub that replies with the first text part of its prompt.
        struct EchoLlm;
        #[async_trait]
        impl BaseLlm for EchoLlm {
            fn model_id(&self) -> &str {
                "echo"
            }
            async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
                let prompt = request.contents[0].parts.iter().find_map(|p| match p {
                    gemini_genai_rs::prelude::Part::Text { text } => Some(text.clone()),
                    _ => None,
                });
                Ok(LlmResponse {
                    content: Content::model(prompt.unwrap_or_default()),
                    finish_reason: None,
                    usage: None,
                })
            }
        }

        let state = State::new();
        let _ = state.set("topic", "billing");
        let out = Resolver::llm("summary", Arc::new(EchoLlm), "Summarize the {topic} issue")
            .resolve(&state)
            .await
            .unwrap();
        assert_eq!(out, json!("Summarize the billing issue"));
        assert_eq!(
            state.get::<String>("summary:result").as_deref(),
            Some("Summarize the billing issue")
        );
    }

    #[tokio::test]
    async fn resolver_dispatch_runs_detached() {
        let state = State::new();
        Resolver::fetch("ping", |_s: State| async move { Ok(json!("pong")) })
            .dispatch(state.clone());
        // Poll for up to ~500 ms. Sleeping yields, which lets the spawned task run.
        for _ in 0..100 {
            if state.contains("ping:result") {
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert_eq!(state.get::<String>("ping:result").as_deref(), Some("pong"));
    }
}