gemini_adk_rs/orchestration/
mod.rs

1//! Agent orchestration — invoke an agent in a [`Mode`].
2//!
3//! An agent is a value ([`TextAgent`] — a local agent, a composed pipeline, or a
4//! remote A2A agent). Orchestration is the single question of *how* you invoke
5//! it; the result always lands in governed `State` under `{name}:result` (or
6//! `{name}:error`), so coordination is reactive and uniform regardless of the
7//! invoker (the model, a `Flow`, an `Extract`, or a watcher).
8//!
9//! | Mode | Sync? | Lowers to |
10//! |------|-------|-----------|
11//! | [`Mode::Call`] | sync — caller awaits | [`call`] (agent-as-tool, awaited inline) |
12//! | [`Mode::Dispatch`] | async, fire-and-forget | [`BackgroundAgentDispatcher::dispatch`](crate::live::BackgroundAgentDispatcher) |
13//! | [`Mode::Background`] | async, model-aware | an agent-tool marked [`ToolExecutionMode::Background`](crate::live::ToolExecutionMode) |
14//!
15//! All three write `{name}:result`, so a `Flow` step can complete on a resolved
16//! result via [`Guard::resolved`](crate::flow::Guard::resolved), and any
17//! consumer reads the value the same way.
18
19use std::future::Future;
20use std::pin::Pin;
21use std::sync::Arc;
22
23use serde_json::Value;
24
25use crate::error::AgentError;
26use crate::llm::{BaseLlm, LlmRequest};
27use crate::state::State;
28use crate::text::TextAgent;
29
/// The invocation strategy for an agent.
///
/// Whatever the mode, the outcome is published under `{name}:result`
/// (see [`result_key`]); the mode only decides who waits for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Inline and awaited: the caller blocks until the agent finishes.
    /// Reserve this for *fast* dependencies — a voice session should not
    /// stall on slow work.
    Call,
    /// Detached, fire-and-forget: the conversation carries on without
    /// waiting for the outcome.
    Dispatch,
    /// Detached but model-aware: the work runs in the background and its
    /// result is handed back to the model via `FunctionResponseScheduling`.
    Background,
}
42
/// Build the `State` key under which the agent called `name` publishes a
/// successful result: the name followed by the `:result` suffix.
pub fn result_key(name: &str) -> String {
    [name, ":result"].concat()
}
47
/// Build the `State` key under which the agent called `name` publishes a
/// failure: the name followed by the `:error` suffix.
pub fn error_key(name: &str) -> String {
    [name, ":error"].concat()
}
52
53/// The provenance source of a value at `key` (e.g. `"agent"`, `"fetch"`,
54/// `"llm"`, or `"extraction"`), if one was recorded under `state_meta:{key}`.
55pub fn provenance(state: &State, key: &str) -> Option<String> {
56    state
57        .get::<serde_json::Value>(&format!("state_meta:{key}"))
58        .and_then(|m| m.get("source").and_then(|s| s.as_str().map(String::from)))
59}
60
61/// Invoke `agent` **synchronously**: run it to completion, write its result to
62/// `{name}:result` (or its error to `{name}:error`), and return the result.
63///
64/// This is the [`Mode::Call`] lowering. It uses the same `{name}:result`
65/// convention as [`BackgroundAgentDispatcher::dispatch`](crate::live::BackgroundAgentDispatcher),
66/// so sync and async invocations are observed identically.
67pub async fn call(
68    name: &str,
69    agent: Arc<dyn TextAgent>,
70    state: &State,
71) -> Result<String, AgentError> {
72    let result = agent.run(state).await;
73    match &result {
74        Ok(r) => {
75            let key = result_key(name);
76            let _ = state.set(
77                format!("state_meta:{key}"),
78                serde_json::json!({ "source": "agent", "resolver": name }),
79            );
80            let _ = state.set(key, r);
81        }
82        Err(e) => {
83            let _ = state.set(error_key(name), e.to_string());
84        }
85    }
86    result
87}
88
89/// The async source of a value, bound from `State`.
90type FetchFn =
91    Arc<dyn Fn(State) -> Pin<Box<dyn Future<Output = Result<Value, String>> + Send>> + Send + Sync>;
92
93enum Source {
94    /// Run a [`TextAgent`] (which reads its inputs from `State`).
95    Agent(Arc<dyn TextAgent>),
96    /// Run an async closure that reads `State` and returns a value — the seam
97    /// for a tool call, an HTTP fetch, or an MCP request.
98    Fetch(FetchFn),
99    /// One-shot OOB LLM completion over a `State`-interpolated prompt.
100    Llm {
101        /// The out-of-band LLM.
102        llm: Arc<dyn BaseLlm>,
103        /// Prompt template; `{key}` interpolates the `State` value at `key`.
104        prompt: String,
105    },
106}
107
108/// Interpolate `{key}` placeholders in `template` with `State` string values.
109fn interpolate(template: &str, state: &State) -> String {
110    let mut out = String::with_capacity(template.len());
111    let mut rest = template;
112    while let Some(open) = rest.find('{') {
113        out.push_str(&rest[..open]);
114        let after = &rest[open + 1..];
115        let Some(close) = after.find('}') else {
116            out.push_str(&rest[open..]);
117            return out;
118        };
119        let key = after[..close].trim();
120        match state.get::<serde_json::Value>(key) {
121            Some(serde_json::Value::String(s)) => out.push_str(&s),
122            Some(v) => out.push_str(&v.to_string()),
123            None => {}
124        }
125        rest = &after[close + 1..];
126    }
127    out.push_str(rest);
128    out
129}
130
131/// A named async value source whose inputs come from `State` and whose result
132/// lands back in `State` under `{name}:result` (or `{name}:error`).
133///
134/// `Resolver` is the async sibling of the deterministic
135/// [`Recognizer`](crate::extract::Recognizer): both are *inputs from State →
136/// value*. A `Resolver`
137/// generalizes [`call`] from "a sub-agent" to **any** async source — a sub-agent
138/// ([`Resolver::agent`]) or a system fetch / tool call / MCP request
139/// ([`Resolver::fetch`]) — under one result convention, so a `Flow` step can
140/// complete on it via [`Guard::resolved`](crate::flow::Guard::resolved)
141/// regardless of where the value came from.
142pub struct Resolver {
143    name: String,
144    source: Source,
145}
146
147impl Resolver {
148    /// Resolve by running a sub-agent. Its `String` output becomes the result.
149    pub fn agent(name: impl Into<String>, agent: Arc<dyn TextAgent>) -> Self {
150        Self {
151            name: name.into(),
152            source: Source::Agent(agent),
153        }
154    }
155
156    /// Resolve by running an async closure over a clone of `State` — the seam
157    /// for an HTTP fetch, a tool call, or an MCP request. The closure returns
158    /// `Ok(value)` on success or `Err(message)` to record an error.
159    pub fn fetch<F, Fut>(name: impl Into<String>, f: F) -> Self
160    where
161        F: Fn(State) -> Fut + Send + Sync + 'static,
162        Fut: Future<Output = Result<Value, String>> + Send + 'static,
163    {
164        let f = Arc::new(f);
165        Self {
166            name: name.into(),
167            source: Source::Fetch(Arc::new(move |state| {
168                let f = f.clone();
169                Box::pin(async move { f(state).await })
170            })),
171        }
172    }
173
174    /// Resolve by running a one-shot OOB LLM over a `State`-interpolated prompt
175    /// (`{key}` placeholders). The completion text becomes the result.
176    pub fn llm(name: impl Into<String>, llm: Arc<dyn BaseLlm>, prompt: impl Into<String>) -> Self {
177        Self {
178            name: name.into(),
179            source: Source::Llm {
180                llm,
181                prompt: prompt.into(),
182            },
183        }
184    }
185
186    /// The resolver's name (the `{name}:result` prefix it writes).
187    pub fn name(&self) -> &str {
188        &self.name
189    }
190
191    /// The provenance kind of this resolver's source (`agent`/`fetch`/`llm`).
192    fn source_kind(&self) -> &'static str {
193        match &self.source {
194            Source::Agent(_) => "agent",
195            Source::Fetch(_) => "fetch",
196            Source::Llm { .. } => "llm",
197        }
198    }
199
200    /// Resolve **synchronously** ([`Mode::Call`]): await the source, write its
201    /// value to `{name}:result` (or its error to `{name}:error`), record its
202    /// provenance under `state_meta:{name}:result`, and return it.
203    pub async fn resolve(&self, state: &State) -> Result<Value, String> {
204        let outcome = match &self.source {
205            Source::Agent(a) => a
206                .run(state)
207                .await
208                .map(Value::from)
209                .map_err(|e| e.to_string()),
210            Source::Fetch(f) => f(state.clone()).await,
211            Source::Llm { llm, prompt } => {
212                let rendered = interpolate(prompt, state);
213                llm.generate(LlmRequest::from_text(rendered))
214                    .await
215                    .map(|r| Value::from(r.text()))
216                    .map_err(|e| e.to_string())
217            }
218        };
219        match &outcome {
220            Ok(v) => {
221                let key = result_key(&self.name);
222                let _ = state.set(
223                    format!("state_meta:{key}"),
224                    serde_json::json!({ "source": self.source_kind(), "resolver": self.name }),
225                );
226                let _ = state.set(key, v.clone());
227            }
228            Err(e) => {
229                let _ = state.set(error_key(&self.name), e);
230            }
231        }
232        outcome
233    }
234
235    /// Resolve **detached** ([`Mode::Dispatch`]): spawn the resolution on the
236    /// runtime and return immediately. The conversation does not wait; consumers
237    /// observe completion reactively via `{name}:result`.
238    pub fn dispatch(self, state: State) {
239        tokio::spawn(async move {
240            let _ = self.resolve(&state).await;
241        });
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Agent stub that always succeeds with its fixed payload.
    struct Echo(&'static str);
    #[async_trait]
    impl TextAgent for Echo {
        fn name(&self) -> &str {
            "echo"
        }
        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Ok(String::from(self.0))
        }
    }

    /// Agent stub that always fails.
    struct Boom;
    #[async_trait]
    impl TextAgent for Boom {
        fn name(&self) -> &str {
            "boom"
        }
        async fn run(&self, _state: &State) -> Result<String, AgentError> {
            Err(AgentError::Other("kaboom".into()))
        }
    }

    #[tokio::test]
    async fn call_writes_result_to_state() {
        let state = State::new();
        let agent: Arc<dyn TextAgent> = Arc::new(Echo("ok-123"));

        let returned = call("verify", agent, &state).await.unwrap();

        assert_eq!(returned, "ok-123");
        let stored = state.get::<String>("verify:result");
        assert_eq!(stored.as_deref(), Some("ok-123"));
    }

    #[tokio::test]
    async fn call_writes_error_to_state() {
        let state = State::new();

        let outcome = call("verify", Arc::new(Boom), &state).await;

        assert!(outcome.is_err());
        assert!(state.contains("verify:error"));
        assert!(!state.contains("verify:result"));
    }

    #[tokio::test]
    async fn resolver_fetch_binds_state_and_writes_result() {
        let state = State::new();
        let _ = state.set("slot", "afternoon");
        let resolver = Resolver::fetch("availability", |bound: State| async move {
            // The closure reads its input from State and may return any JSON.
            let slot = bound.get::<String>("slot").unwrap_or_default();
            Ok(json!({ "open": slot == "afternoon" }))
        });
        let expected = json!({ "open": true });

        let resolved = resolver.resolve(&state).await.unwrap();

        assert_eq!(resolved, expected);
        assert_eq!(state.get::<Value>("availability:result"), Some(expected));
        // The resolved value carries a provenance record.
        let origin = provenance(&state, "availability:result");
        assert_eq!(origin.as_deref(), Some("fetch"));
    }

    #[tokio::test]
    async fn resolver_agent_uses_result_convention() {
        let state = State::new();
        // Agent-backed resolvers write to the same `{name}:result` key as `call`.
        let resolver = Resolver::agent("verify", Arc::new(Echo("ok-9")));

        resolver.resolve(&state).await.unwrap();

        let stored = state.get::<String>("verify:result");
        assert_eq!(stored.as_deref(), Some("ok-9"));
    }

    #[tokio::test]
    async fn resolver_fetch_records_error() {
        let state = State::new();
        let failing = Resolver::fetch("lookup", |_s: State| async move {
            Err::<Value, String>(String::from("upstream 503"))
        });

        let outcome = failing.resolve(&state).await;

        assert!(outcome.is_err());
        let recorded = state.get::<String>("lookup:error");
        assert_eq!(recorded.as_deref(), Some("upstream 503"));
        assert!(!state.contains("lookup:result"));
    }

    #[tokio::test]
    async fn resolver_llm_interpolates_prompt_and_stores_text() {
        use crate::llm::{LlmError, LlmResponse};
        use gemini_genai_rs::prelude::{Content, Part};

        /// LLM stub whose completion is the first text part of the request.
        struct EchoLlm;
        #[async_trait]
        impl BaseLlm for EchoLlm {
            fn model_id(&self) -> &str {
                "echo"
            }
            async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
                // Hand the already-interpolated prompt back as the completion.
                let echoed = request.contents[0]
                    .parts
                    .iter()
                    .find_map(|part| {
                        if let Part::Text { text } = part {
                            Some(text.clone())
                        } else {
                            None
                        }
                    })
                    .unwrap_or_default();
                Ok(LlmResponse {
                    content: Content::model(echoed),
                    finish_reason: None,
                    usage: None,
                })
            }
        }

        let state = State::new();
        let _ = state.set("topic", "billing");
        let resolver = Resolver::llm("summary", Arc::new(EchoLlm), "Summarize the {topic} issue");

        let resolved = resolver.resolve(&state).await.unwrap();

        assert_eq!(resolved, json!("Summarize the billing issue"));
        let stored = state.get::<String>("summary:result");
        assert_eq!(stored.as_deref(), Some("Summarize the billing issue"));
    }

    #[tokio::test]
    async fn resolver_dispatch_runs_detached() {
        let state = State::new();
        let resolver = Resolver::fetch("ping", |_s: State| async move { Ok(json!("pong")) });
        resolver.dispatch(state.clone());

        // The write happens on the spawned task; poll until it shows up
        // (at most 100 checks, 5 ms apart).
        let mut attempts = 0;
        while attempts < 100 && !state.contains("ping:result") {
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
            attempts += 1;
        }

        let stored = state.get::<String>("ping:result");
        assert_eq!(stored.as_deref(), Some("pong"));
    }
}