gemini_adk_fluent_rs/
operators.rs

//! Operator algebra for agent composition.
//!
//! `AgentBuilder` and `Composable` values participate in the algebra:
//!
//! | Operator | Meaning            | Example                    |
//! |----------|--------------------|----------------------------|
//! | `>>`     | Sequential pipeline| `agent_a >> agent_b`       |
//! | `\|`     | Parallel fan-out   | `agent_a \| agent_b`       |
//! | `*`      | Loop (fixed)       | `agent * 3`                |
//! | `/`      | Fallback chain     | `agent_a / agent_b`        |
//!
//! Rust has no `//` operator, so fallback chains use `/` (`Div`).
11
12use std::sync::Arc;
13
14use gemini_adk_rs::llm::BaseLlm;
15use gemini_adk_rs::middleware::{Middleware, MiddlewareChain};
16use gemini_adk_rs::text::{
17    FallbackTextAgent, LoopTextAgent, ParallelTextAgent, SequentialTextAgent, TextAgent,
18};
19
20use crate::builder::AgentBuilder;
21use crate::compose::middleware::MiddlewareComposite;
22
23/// A composable workflow node — can be sequenced, fan-out, looped, etc.
24#[derive(Clone, Debug)]
25pub enum Composable {
26    /// A single agent node.
27    Agent(AgentBuilder),
28    /// A sequential pipeline of steps.
29    Pipeline(Pipeline),
30    /// A parallel fan-out of branches.
31    FanOut(FanOut),
32    /// A loop with optional termination predicate.
33    Loop(Loop),
34    /// A fallback chain (try each until one succeeds).
35    Fallback(Fallback),
36}
37
38/// Sequential pipeline: execute steps in order, passing state between them.
39#[derive(Clone, Debug)]
40pub struct Pipeline {
41    /// Ordered steps to execute sequentially.
42    pub steps: Vec<Composable>,
43}
44
45/// Parallel fan-out: execute branches concurrently, merge results.
46#[derive(Clone, Debug)]
47pub struct FanOut {
48    /// Branches to execute concurrently.
49    pub branches: Vec<Composable>,
50}
51
52/// Loop: repeat an agent or pipeline up to `max` times, or until a predicate.
53#[derive(Clone)]
54pub struct Loop {
55    /// The composable to repeat.
56    pub body: Box<Composable>,
57    /// Maximum number of iterations.
58    pub max: u32,
59    /// Optional early-exit predicate evaluated after each iteration.
60    pub until: Option<LoopPredicate>,
61    /// Middleware attached to the loop agent (e.g. `M::on_loop` observers).
62    /// Set via [`Loop::middleware`] / [`Composable::middleware`]; construct as
63    /// `Vec::new()` in literals.
64    #[doc(hidden)]
65    pub middleware: Vec<Arc<dyn Middleware>>,
66}
67
68/// Predicate for conditional loop termination.
69#[derive(Clone)]
70pub struct LoopPredicate {
71    predicate: std::sync::Arc<dyn Fn(&serde_json::Value) -> bool + Send + Sync>,
72}
73
74impl LoopPredicate {
75    /// Create a new predicate from a closure that checks loop state.
76    pub fn new(f: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static) -> Self {
77        Self {
78            predicate: std::sync::Arc::new(f),
79        }
80    }
81
82    /// Evaluate the predicate against the current state.
83    pub fn check(&self, state: &serde_json::Value) -> bool {
84        (self.predicate)(state)
85    }
86}
87
88impl std::fmt::Debug for LoopPredicate {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.write_str("LoopPredicate(<fn>)")
91    }
92}
93
94impl std::fmt::Debug for Loop {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("Loop")
97            .field("body", &self.body)
98            .field("max", &self.max)
99            .field("until", &self.until)
100            .finish()
101    }
102}
103
104/// Fallback chain: try each agent in sequence until one succeeds.
105#[derive(Clone)]
106pub struct Fallback {
107    /// Candidate composables tried in order until one succeeds.
108    pub candidates: Vec<Composable>,
109    /// Middleware attached to the fallback agent (e.g. `M::on_fallback`).
110    middleware: Vec<Arc<dyn Middleware>>,
111}
112
113impl std::fmt::Debug for Fallback {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("Fallback")
116            .field("candidates", &self.candidates)
117            .finish()
118    }
119}
120
121/// Create a conditional loop predicate.
122pub fn until(
123    predicate: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
124) -> LoopPredicate {
125    LoopPredicate::new(predicate)
126}
127
128// ── Conversions ──
129
130impl From<AgentBuilder> for Composable {
131    fn from(b: AgentBuilder) -> Self {
132        Composable::Agent(b)
133    }
134}
135
136impl From<Pipeline> for Composable {
137    fn from(p: Pipeline) -> Self {
138        Composable::Pipeline(p)
139    }
140}
141
142impl From<FanOut> for Composable {
143    fn from(f: FanOut) -> Self {
144        Composable::FanOut(f)
145    }
146}
147
148impl From<Loop> for Composable {
149    fn from(l: Loop) -> Self {
150        Composable::Loop(l)
151    }
152}
153
154impl From<Fallback> for Composable {
155    fn from(f: Fallback) -> Self {
156        Composable::Fallback(f)
157    }
158}
159
160// ── Compilation: Composable → TextAgent ──
161
162impl Composable {
163    /// Compile this composable tree into an executable `TextAgent`.
164    ///
165    /// Recursively compiles the tree: pipelines become `SequentialTextAgent`,
166    /// fan-outs become `ParallelTextAgent`, loops become `LoopTextAgent`,
167    /// fallbacks become `FallbackTextAgent`, and agents compile via
168    /// `AgentBuilder::build()`.
169    ///
170    /// ```rust,ignore
171    /// let pipeline = AgentBuilder::new("writer").instruction("Write a draft")
172    ///     >> AgentBuilder::new("reviewer").instruction("Review and improve");
173    ///
174    /// let agent = pipeline.compile(llm);
175    /// let result = agent.run(&state).await?;
176    /// ```
177    pub fn compile(self, llm: Arc<dyn BaseLlm>) -> Arc<dyn TextAgent> {
178        match self {
179            Composable::Agent(builder) => builder.build(llm),
180
181            Composable::Pipeline(pipeline) => {
182                let children: Vec<Arc<dyn TextAgent>> = pipeline
183                    .steps
184                    .into_iter()
185                    .map(|step| step.compile(llm.clone()))
186                    .collect();
187                Arc::new(SequentialTextAgent::new("pipeline", children))
188            }
189
190            Composable::FanOut(fan_out) => {
191                let branches: Vec<Arc<dyn TextAgent>> = fan_out
192                    .branches
193                    .into_iter()
194                    .map(|branch| branch.compile(llm.clone()))
195                    .collect();
196                Arc::new(ParallelTextAgent::new("fan_out", branches))
197            }
198
199            Composable::Loop(loop_node) => {
200                let middleware = loop_node.middleware;
201                let body = loop_node.body.compile(llm);
202                let mut loop_agent = LoopTextAgent::new("loop", body, loop_node.max);
203
204                if let Some(predicate) = loop_node.until {
205                    loop_agent = loop_agent.until(move |state: &gemini_adk_rs::State| {
206                        // Convert State to serde_json::Value for LoopPredicate compatibility.
207                        let keys = state.keys();
208                        let mut map = serde_json::Map::new();
209                        for key in keys {
210                            if let Some(val) = state.get_raw(&key) {
211                                map.insert(key, val);
212                            }
213                        }
214                        predicate.check(&serde_json::Value::Object(map))
215                    });
216                }
217
218                if !middleware.is_empty() {
219                    loop_agent = loop_agent.with_middleware_chain(chain_from(middleware));
220                }
221
222                Arc::new(loop_agent)
223            }
224
225            Composable::Fallback(fallback) => {
226                let middleware = fallback.middleware;
227                let candidates: Vec<Arc<dyn TextAgent>> = fallback
228                    .candidates
229                    .into_iter()
230                    .map(|c| c.compile(llm.clone()))
231                    .collect();
232                let mut agent = FallbackTextAgent::new("fallback", candidates);
233                if !middleware.is_empty() {
234                    agent = agent.with_middleware_chain(chain_from(middleware));
235                }
236                Arc::new(agent)
237            }
238        }
239    }
240}
241
242/// Build a [`MiddlewareChain`] from an ordered list of middleware layers.
243fn chain_from(layers: Vec<Arc<dyn Middleware>>) -> MiddlewareChain {
244    let mut chain = MiddlewareChain::new();
245    for layer in layers {
246        chain.add(layer);
247    }
248    chain
249}
250
251impl Composable {
252    /// Attach middleware to a `Loop` or `Fallback` node — the place where
253    /// combinator-level observers (`M::on_loop`, `M::on_fallback`) live.
254    ///
255    /// For other node kinds (single agent, pipeline, fan-out) this is a no-op:
256    /// attach `M::` middleware to the agent itself via
257    /// [`AgentBuilder::middleware`](crate::builder::AgentBuilder::middleware) instead.
258    pub fn middleware(self, composite: MiddlewareComposite) -> Self {
259        match self {
260            Composable::Loop(l) => Composable::Loop(l.middleware(composite)),
261            Composable::Fallback(f) => Composable::Fallback(f.middleware(composite)),
262            other => other,
263        }
264    }
265}
266
267// ── Safe variant accessors ──
268//
269// These inspect a `Composable` for a specific shape and return `None` when the
270// variant does not match, rather than panicking. Callers that want the
271// underlying structure should pattern-match directly; these are convenience
272// accessors for introspection (tests, tooling, debugging).
273
274impl Composable {
275    /// The first step of a [`Pipeline`], or `None` if this is not a pipeline
276    /// (or the pipeline is empty).
277    pub fn first_step(&self) -> Option<&Composable> {
278        match self {
279            Composable::Pipeline(p) => p.steps.first(),
280            _ => None,
281        }
282    }
283
284    /// The last step of a [`Pipeline`], or `None` if this is not a pipeline
285    /// (or the pipeline is empty).
286    pub fn last_step(&self) -> Option<&Composable> {
287        match self {
288            Composable::Pipeline(p) => p.steps.last(),
289            _ => None,
290        }
291    }
292
293    /// The `n`th step of a [`Pipeline`], or `None` if this is not a pipeline
294    /// or the index is out of bounds.
295    pub fn nth_step(&self, n: usize) -> Option<&Composable> {
296        match self {
297            Composable::Pipeline(p) => p.steps.get(n),
298            _ => None,
299        }
300    }
301
302    /// All steps of a [`Pipeline`], or `None` if this is not a pipeline.
303    pub fn pipeline_steps(&self) -> Option<&[Composable]> {
304        match self {
305            Composable::Pipeline(p) => Some(&p.steps),
306            _ => None,
307        }
308    }
309
310    /// The branches of a [`FanOut`], or `None` if this is not a fan-out.
311    pub fn fan_out_branches(&self) -> Option<&[Composable]> {
312        match self {
313            Composable::FanOut(f) => Some(&f.branches),
314            _ => None,
315        }
316    }
317
318    /// The termination predicate of a [`Loop`], or `None` if this is not a loop
319    /// (or the loop has no predicate).
320    pub fn loop_predicate(&self) -> Option<&LoopPredicate> {
321        match self {
322            Composable::Loop(l) => l.until.as_ref(),
323            _ => None,
324        }
325    }
326
327    /// The body of a [`Loop`], or `None` if this is not a loop.
328    pub fn loop_body(&self) -> Option<&Composable> {
329        match self {
330            Composable::Loop(l) => Some(&l.body),
331            _ => None,
332        }
333    }
334
335    /// The candidates of a [`Fallback`] chain, or `None` if this is not a fallback.
336    pub fn fallback_candidates(&self) -> Option<&[Composable]> {
337        match self {
338            Composable::Fallback(f) => Some(&f.candidates),
339            _ => None,
340        }
341    }
342}
343
344// ── Pipeline construction helpers ──
345
346impl Pipeline {
347    /// Create a pipeline from the given steps.
348    pub fn new(steps: Vec<Composable>) -> Self {
349        Self { steps }
350    }
351
352    /// Create an empty named pipeline (fluent builder entry point).
353    ///
354    /// ```ignore
355    /// Pipeline::builder("etl")
356    ///     .step(extract_agent)
357    ///     .step(transform_agent)
358    ///     .step(load_agent)
359    /// ```
360    pub fn builder(_name: &str) -> Self {
361        Self { steps: Vec::new() }
362    }
363
364    /// Add a sequential step to this pipeline (fluent builder).
365    pub fn step(mut self, agent: impl Into<Composable>) -> Self {
366        self.steps.push(agent.into());
367        self
368    }
369
370    /// Add a sub-agent step (alias for `step` — matches upstream naming).
371    pub fn sub_agent(self, agent: AgentBuilder) -> Self {
372        self.step(agent)
373    }
374
375    /// Set a description (metadata, not used at runtime).
376    pub fn describe(self, _desc: &str) -> Self {
377        self
378    }
379
380    /// Flatten: if a step is itself a Pipeline, inline its steps.
381    fn push_flat(&mut self, step: Composable) {
382        match step {
383            Composable::Pipeline(p) => self.steps.extend(p.steps),
384            other => self.steps.push(other),
385        }
386    }
387}
388
389impl FanOut {
390    /// Create a fan-out from the given branches.
391    pub fn new(branches: Vec<Composable>) -> Self {
392        Self { branches }
393    }
394
395    /// Create an empty named fan-out (fluent builder entry point).
396    ///
397    /// ```ignore
398    /// FanOut::builder("research")
399    ///     .branch(web_agent)
400    ///     .branch(db_agent)
401    /// ```
402    pub fn builder(_name: &str) -> Self {
403        Self {
404            branches: Vec::new(),
405        }
406    }
407
408    /// Add a parallel branch (fluent builder).
409    pub fn branch(mut self, agent: impl Into<Composable>) -> Self {
410        self.branches.push(agent.into());
411        self
412    }
413
414    /// Add a sub-agent branch (alias for `branch` — matches upstream naming).
415    pub fn sub_agent(self, agent: AgentBuilder) -> Self {
416        self.branch(agent)
417    }
418
419    /// Set a description (metadata, not used at runtime).
420    pub fn describe(self, _desc: &str) -> Self {
421        self
422    }
423
424    fn push_flat(&mut self, branch: Composable) {
425        match branch {
426            Composable::FanOut(f) => self.branches.extend(f.branches),
427            other => self.branches.push(other),
428        }
429    }
430}
431
432impl Fallback {
433    /// Create a fallback chain from the given candidates.
434    pub fn new(candidates: Vec<Composable>) -> Self {
435        Self {
436            candidates,
437            middleware: Vec::new(),
438        }
439    }
440
441    /// Attach middleware to the fallback agent (e.g. `M::on_fallback(|name| …)`),
442    /// observed when a fallback branch activates.
443    pub fn middleware(mut self, composite: MiddlewareComposite) -> Self {
444        self.middleware.extend(composite.layers);
445        self
446    }
447
448    fn push_flat(&mut self, candidate: Composable) {
449        match candidate {
450            Composable::Fallback(f) => self.candidates.extend(f.candidates),
451            other => self.candidates.push(other),
452        }
453    }
454}
455
456// ── Operator: >> (Shr) = Sequential Pipeline ──
457
458/// AgentBuilder >> AgentBuilder → Pipeline
459impl std::ops::Shr for AgentBuilder {
460    type Output = Composable;
461
462    fn shr(self, rhs: AgentBuilder) -> Self::Output {
463        Composable::Pipeline(Pipeline::new(vec![
464            Composable::Agent(self),
465            Composable::Agent(rhs),
466        ]))
467    }
468}
469
470/// Composable >> AgentBuilder → Pipeline (flattening)
471impl std::ops::Shr<AgentBuilder> for Composable {
472    type Output = Composable;
473
474    fn shr(self, rhs: AgentBuilder) -> Self::Output {
475        let mut pipeline = match self {
476            Composable::Pipeline(p) => p,
477            other => Pipeline::new(vec![other]),
478        };
479        pipeline.push_flat(Composable::Agent(rhs));
480        Composable::Pipeline(pipeline)
481    }
482}
483
484/// AgentBuilder >> Composable → Pipeline (flattening)
485impl std::ops::Shr<Composable> for AgentBuilder {
486    type Output = Composable;
487
488    fn shr(self, rhs: Composable) -> Self::Output {
489        let mut pipeline = Pipeline::new(vec![Composable::Agent(self)]);
490        pipeline.push_flat(rhs);
491        Composable::Pipeline(pipeline)
492    }
493}
494
495/// Composable >> Composable → Pipeline (flattening)
496impl std::ops::Shr for Composable {
497    type Output = Composable;
498
499    fn shr(self, rhs: Composable) -> Self::Output {
500        let mut pipeline = match self {
501            Composable::Pipeline(p) => p,
502            other => Pipeline::new(vec![other]),
503        };
504        pipeline.push_flat(rhs);
505        Composable::Pipeline(pipeline)
506    }
507}
508
509// ── Operator: | (BitOr) = Parallel Fan-Out ──
510
511/// AgentBuilder | AgentBuilder → FanOut
512impl std::ops::BitOr for AgentBuilder {
513    type Output = Composable;
514
515    fn bitor(self, rhs: AgentBuilder) -> Self::Output {
516        Composable::FanOut(FanOut::new(vec![
517            Composable::Agent(self),
518            Composable::Agent(rhs),
519        ]))
520    }
521}
522
523/// Composable | AgentBuilder → FanOut (flattening)
524impl std::ops::BitOr<AgentBuilder> for Composable {
525    type Output = Composable;
526
527    fn bitor(self, rhs: AgentBuilder) -> Self::Output {
528        let mut fan_out = match self {
529            Composable::FanOut(f) => f,
530            other => FanOut::new(vec![other]),
531        };
532        fan_out.push_flat(Composable::Agent(rhs));
533        Composable::FanOut(fan_out)
534    }
535}
536
537/// Composable | Composable → FanOut (flattening)
538impl std::ops::BitOr for Composable {
539    type Output = Composable;
540
541    fn bitor(self, rhs: Composable) -> Self::Output {
542        let mut fan_out = match self {
543            Composable::FanOut(f) => f,
544            other => FanOut::new(vec![other]),
545        };
546        fan_out.push_flat(rhs);
547        Composable::FanOut(fan_out)
548    }
549}
550
551// ── Operator: * (Mul<u32>) = Fixed Loop ──
552
553/// AgentBuilder * 3 → Loop(max=3)
554impl std::ops::Mul<u32> for AgentBuilder {
555    type Output = Composable;
556
557    fn mul(self, rhs: u32) -> Self::Output {
558        Composable::Loop(Loop {
559            body: Box::new(Composable::Agent(self)),
560            max: rhs,
561            until: None,
562            middleware: Vec::new(),
563        })
564    }
565}
566
567/// Composable * 3 → Loop(max=3)
568impl std::ops::Mul<u32> for Composable {
569    type Output = Composable;
570
571    fn mul(self, rhs: u32) -> Self::Output {
572        Composable::Loop(Loop {
573            body: Box::new(self),
574            max: rhs,
575            until: None,
576            middleware: Vec::new(),
577        })
578    }
579}
580
581/// AgentBuilder * until(pred) → conditional Loop
582impl std::ops::Mul<LoopPredicate> for AgentBuilder {
583    type Output = Composable;
584
585    fn mul(self, rhs: LoopPredicate) -> Self::Output {
586        Composable::Loop(Loop {
587            body: Box::new(Composable::Agent(self)),
588            max: u32::MAX,
589            until: Some(rhs),
590            middleware: Vec::new(),
591        })
592    }
593}
594
595/// Composable * until(pred) → conditional Loop
596impl std::ops::Mul<LoopPredicate> for Composable {
597    type Output = Composable;
598
599    fn mul(self, rhs: LoopPredicate) -> Self::Output {
600        Composable::Loop(Loop {
601            body: Box::new(self),
602            max: u32::MAX,
603            until: Some(rhs),
604            middleware: Vec::new(),
605        })
606    }
607}
608
609// ── Operator: / (Div) = Fallback Chain ──
610// Note: Rust doesn't have a `//` operator. We use `/` (Div) instead.
611
612/// AgentBuilder / AgentBuilder → Fallback
613impl std::ops::Div for AgentBuilder {
614    type Output = Composable;
615
616    fn div(self, rhs: AgentBuilder) -> Self::Output {
617        Composable::Fallback(Fallback::new(vec![
618            Composable::Agent(self),
619            Composable::Agent(rhs),
620        ]))
621    }
622}
623
624/// Composable / AgentBuilder → Fallback (flattening)
625impl std::ops::Div<AgentBuilder> for Composable {
626    type Output = Composable;
627
628    fn div(self, rhs: AgentBuilder) -> Self::Output {
629        let mut fallback = match self {
630            Composable::Fallback(f) => f,
631            other => Fallback::new(vec![other]),
632        };
633        fallback.push_flat(Composable::Agent(rhs));
634        Composable::Fallback(fallback)
635    }
636}
637
638/// Composable / Composable → Fallback (flattening)
639impl std::ops::Div for Composable {
640    type Output = Composable;
641
642    fn div(self, rhs: Composable) -> Self::Output {
643        let mut fallback = match self {
644            Composable::Fallback(f) => f,
645            other => Fallback::new(vec![other]),
646        };
647        fallback.push_flat(rhs);
648        Composable::Fallback(fallback)
649    }
650}
651
652// ── Loop builder method (for chaining max on until-loops) ──
653
654impl Loop {
655    /// Create a loop builder with a body agent and default max iterations.
656    ///
657    /// ```ignore
658    /// Loop::builder("refine")
659    ///     .step(refine_agent)
660    ///     .max_iterations(5)
661    /// ```
662    pub fn builder(_name: &str) -> Self {
663        Self {
664            body: Box::new(Composable::Pipeline(Pipeline::new(Vec::new()))),
665            max: 10,
666            until: None,
667            middleware: Vec::new(),
668        }
669    }
670
671    /// Attach middleware to the loop agent (e.g. `M::on_loop(|i| …)`), observed
672    /// on every iteration.
673    pub fn middleware(mut self, composite: MiddlewareComposite) -> Self {
674        self.middleware.extend(composite.layers);
675        self
676    }
677
678    /// Set the body composable to loop over.
679    pub fn step(mut self, agent: impl Into<Composable>) -> Self {
680        self.body = Box::new(agent.into());
681        self
682    }
683
684    /// Set a maximum number of iterations.
685    pub fn max_iterations(mut self, n: u32) -> Self {
686        self.max = n;
687        self
688    }
689
690    /// Set a maximum number of iterations for a conditional loop.
691    pub fn max(mut self, max: u32) -> Self {
692        self.max = max;
693        self
694    }
695
696    /// Set a description (metadata, not used at runtime).
697    pub fn describe(self, _desc: &str) -> Self {
698        self
699    }
700}
701
#[cfg(test)]
mod tests {
    use super::*;

    fn agent(name: &str) -> AgentBuilder {
        AgentBuilder::new(name)
    }

    #[test]
    fn pipeline_from_shr() {
        let result = agent("a") >> agent("b");
        match result {
            Composable::Pipeline(p) => assert_eq!(p.steps.len(), 2),
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn pipeline_flattens() {
        let result = agent("a") >> agent("b") >> agent("c");
        match result {
            Composable::Pipeline(p) => assert_eq!(p.steps.len(), 3),
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn fan_out_from_bitor() {
        let result = agent("a") | agent("b");
        match result {
            Composable::FanOut(f) => assert_eq!(f.branches.len(), 2),
            _ => panic!("expected FanOut"),
        }
    }

    #[test]
    fn fan_out_flattens() {
        let result = (agent("a") | agent("b")) | agent("c");
        match result {
            Composable::FanOut(f) => assert_eq!(f.branches.len(), 3),
            _ => panic!("expected FanOut"),
        }
    }

    #[test]
    fn fixed_loop_from_mul() {
        let result = agent("a") * 3;
        match result {
            Composable::Loop(l) => {
                assert_eq!(l.max, 3);
                assert!(l.until.is_none());
            }
            _ => panic!("expected Loop"),
        }
    }

    #[test]
    fn conditional_loop_from_mul_until() {
        let pred = until(|_v| true);
        let result = agent("a") * pred;
        match result {
            Composable::Loop(l) => {
                assert_eq!(l.max, u32::MAX);
                assert!(l.until.is_some());
            }
            _ => panic!("expected Loop"),
        }
    }

    #[test]
    fn fallback_from_div() {
        let result = agent("a") / agent("b");
        match result {
            Composable::Fallback(f) => assert_eq!(f.candidates.len(), 2),
            _ => panic!("expected Fallback"),
        }
    }

    #[test]
    fn fallback_flattens() {
        let result = (agent("a") / agent("b")) / agent("c");
        match result {
            Composable::Fallback(f) => assert_eq!(f.candidates.len(), 3),
            _ => panic!("expected Fallback"),
        }
    }

    #[test]
    fn mixed_pipeline_with_fan_out() {
        let result = agent("a") >> (agent("b") | agent("c"));
        match &result {
            Composable::Pipeline(p) => {
                assert_eq!(p.steps.len(), 2);
                assert!(matches!(&p.steps[1], Composable::FanOut(_)));
            }
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn pipeline_then_loop() {
        let result = agent("a") >> (agent("b") * 5);
        match &result {
            Composable::Pipeline(p) => {
                assert_eq!(p.steps.len(), 2);
                assert!(matches!(&p.steps[1], Composable::Loop(_)));
            }
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn safe_accessors_return_some_on_match() {
        let pipeline = agent("a").instruction("x") >> agent("b").instruction("y");
        assert!(pipeline.first_step().is_some());
        assert!(pipeline.last_step().is_some());
        assert!(pipeline.nth_step(1).is_some());
        assert!(pipeline.nth_step(99).is_none());
        assert_eq!(pipeline.pipeline_steps().map(|s| s.len()), Some(2));

        let fan_out = Composable::Agent(agent("a")) | Composable::Agent(agent("b"));
        assert_eq!(fan_out.fan_out_branches().map(|b| b.len()), Some(2));

        let looped = agent("a") * until(|_| true);
        assert!(looped.loop_predicate().is_some());
        assert!(looped.loop_body().is_some());

        let fallback = agent("a") / agent("b");
        assert_eq!(fallback.fallback_candidates().map(|c| c.len()), Some(2));
    }

    #[test]
    fn safe_accessors_return_none_on_mismatch() {
        // Calling a pipeline accessor on a non-Pipeline returns None, not panic.
        let solo = Composable::Agent(agent("solo"));
        assert!(solo.first_step().is_none());
        assert!(solo.last_step().is_none());
        assert!(solo.nth_step(0).is_none());
        assert!(solo.pipeline_steps().is_none());
        assert!(solo.fan_out_branches().is_none());
        assert!(solo.loop_predicate().is_none());
        assert!(solo.loop_body().is_none());
        assert!(solo.fallback_candidates().is_none());

        // A fixed loop (no predicate) returns None for loop_predicate but
        // Some for loop_body.
        let fixed = agent("a") * 3;
        assert!(fixed.loop_predicate().is_none());
        assert!(fixed.loop_body().is_some());
        // And a pipeline accessor on a loop is None.
        assert!(fixed.first_step().is_none());
    }

    #[test]
    fn loop_predicate_check() {
        let pred = until(|v| v.get("done").and_then(|v| v.as_bool()).unwrap_or(false));
        assert!(!pred.check(&serde_json::json!({"done": false})));
        assert!(pred.check(&serde_json::json!({"done": true})));
    }

    // ── compile() tests ──

    mod compile_tests {
        use super::*;
        use async_trait::async_trait;
        use gemini_adk_rs::llm::{BaseLlm, LlmError, LlmRequest, LlmResponse};
        use gemini_genai_rs::prelude::{Content, Part, Role};

        /// A mock LLM that echoes the request's system instruction back as its reply.
        struct NameEchoLlm;

        #[async_trait]
        impl BaseLlm for NameEchoLlm {
            fn model_id(&self) -> &str {
                "name-echo"
            }
            async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
                let text = req
                    .system_instruction
                    .unwrap_or_else(|| "no-instruction".into());
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::Text { text }],
                    },
                    finish_reason: Some("STOP".into()),
                    usage: None,
                })
            }
        }

        fn llm() -> Arc<dyn BaseLlm> {
            Arc::new(NameEchoLlm)
        }

        #[tokio::test]
        async fn compile_single_agent() {
            let composable = Composable::Agent(AgentBuilder::new("solo").instruction("hello"));
            let agent = composable.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = agent.run(&state).await.unwrap();
            assert_eq!(result, "hello");
        }

        #[tokio::test]
        async fn compile_pipeline() {
            let pipeline = agent("a").instruction("step-a") >> agent("b").instruction("step-b");
            let compiled = pipeline.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            // Sequential: last agent's output wins. step-b echoes its instruction.
            assert_eq!(result, "step-b");
        }

        #[tokio::test]
        async fn compile_fan_out() {
            let fan_out = Composable::Agent(agent("a").instruction("branch-a"))
                | Composable::Agent(agent("b").instruction("branch-b"));
            let compiled = fan_out.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert!(result.contains("branch-a"));
            assert!(result.contains("branch-b"));
        }

        #[tokio::test]
        async fn compile_loop() {
            let looped = agent("counter").instruction("tick") * 3;
            let compiled = looped.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert_eq!(result, "tick");
        }

        #[tokio::test]
        async fn compile_fallback() {
            let fallback = agent("a").instruction("first") / agent("b").instruction("second");
            let compiled = fallback.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            // First agent succeeds, so its result is returned.
            assert_eq!(result, "first");
        }

        #[tokio::test]
        async fn on_loop_fires_through_operator() {
            use crate::compose::M;
            use std::sync::atomic::{AtomicU32, Ordering};

            let count = Arc::new(AtomicU32::new(0));
            let c2 = count.clone();
            // Attach the combinator-level observer to the loop node.
            let looped =
                (agent("counter").instruction("tick") * 3).middleware(M::on_loop(move |_i| {
                    c2.fetch_add(1, Ordering::SeqCst);
                }));
            let compiled = looped.compile(llm());
            let state = gemini_adk_rs::State::new();
            compiled.run(&state).await.unwrap();
            // Three iterations → three LoopIteration events observed.
            assert_eq!(count.load(Ordering::SeqCst), 3);
        }

        #[tokio::test]
        async fn compile_loop_with_predicate() {
            use std::sync::atomic::{AtomicU32, Ordering};

            /// A mock LLM that always replies "done" and counts how often it is called.
            /// It never writes to state.
            struct CountingLlm {
                calls: Arc<AtomicU32>,
            }

            #[async_trait]
            impl BaseLlm for CountingLlm {
                fn model_id(&self) -> &str {
                    "incr"
                }
                async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
                    self.calls.fetch_add(1, Ordering::SeqCst);
                    Ok(LlmResponse {
                        content: Content {
                            role: Some(Role::Model),
                            parts: vec![Part::Text {
                                text: "done".into(),
                            }],
                        },
                        finish_reason: Some("STOP".into()),
                        usage: None,
                    })
                }
            }

            // Stop once state["n"] >= 3. The mock never sets "n", so the
            // predicate can only pass because of the value pre-set below.
            let pred = until(|v| v.get("n").and_then(|v| v.as_i64()).unwrap_or(0) >= 3);
            let body = agent("incr").instruction("increment");
            let looped = body * pred;

            let calls = Arc::new(AtomicU32::new(0));
            let compiled = looped.compile(Arc::new(CountingLlm {
                calls: calls.clone(),
            }));
            let state = gemini_adk_rs::State::new();
            let _ = state.set("n", 5); // Pre-set so the predicate passes after the first iteration.
            let result = compiled.run(&state).await.unwrap();
            assert_eq!(result, "done");
            // The cap is u32::MAX, so stopping after a single model call proves the
            // predicate was wired through compile().
            assert_eq!(calls.load(Ordering::SeqCst), 1);
        }

        #[tokio::test]
        async fn compile_mixed_pipeline_with_fan_out() {
            let mixed = agent("a").instruction("start")
                >> (Composable::Agent(agent("b").instruction("left"))
                    | Composable::Agent(agent("c").instruction("right")));
            let compiled = mixed.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert!(result.contains("left"));
            assert!(result.contains("right"));
        }
    }
}