1use std::sync::Arc;
13
14use gemini_adk_rs::llm::BaseLlm;
15use gemini_adk_rs::text::{
16 FallbackTextAgent, LoopTextAgent, ParallelTextAgent, SequentialTextAgent, TextAgent,
17};
18
19use crate::builder::AgentBuilder;
20
21#[derive(Clone, Debug)]
23pub enum Composable {
24 Agent(AgentBuilder),
26 Pipeline(Pipeline),
28 FanOut(FanOut),
30 Loop(Loop),
32 Fallback(Fallback),
34}
35
36#[derive(Clone, Debug)]
38pub struct Pipeline {
39 pub steps: Vec<Composable>,
41}
42
43#[derive(Clone, Debug)]
45pub struct FanOut {
46 pub branches: Vec<Composable>,
48}
49
50#[derive(Clone)]
52pub struct Loop {
53 pub body: Box<Composable>,
55 pub max: u32,
57 pub until: Option<LoopPredicate>,
59}
60
61#[derive(Clone)]
63pub struct LoopPredicate {
64 predicate: std::sync::Arc<dyn Fn(&serde_json::Value) -> bool + Send + Sync>,
65}
66
67impl LoopPredicate {
68 pub fn new(f: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static) -> Self {
70 Self {
71 predicate: std::sync::Arc::new(f),
72 }
73 }
74
75 pub fn check(&self, state: &serde_json::Value) -> bool {
77 (self.predicate)(state)
78 }
79}
80
81impl std::fmt::Debug for LoopPredicate {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.write_str("LoopPredicate(<fn>)")
84 }
85}
86
87impl std::fmt::Debug for Loop {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("Loop")
90 .field("body", &self.body)
91 .field("max", &self.max)
92 .field("until", &self.until)
93 .finish()
94 }
95}
96
97#[derive(Clone, Debug)]
99pub struct Fallback {
100 pub candidates: Vec<Composable>,
102}
103
104pub fn until(
106 predicate: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
107) -> LoopPredicate {
108 LoopPredicate::new(predicate)
109}
110
111impl From<AgentBuilder> for Composable {
114 fn from(b: AgentBuilder) -> Self {
115 Composable::Agent(b)
116 }
117}
118
119impl From<Pipeline> for Composable {
120 fn from(p: Pipeline) -> Self {
121 Composable::Pipeline(p)
122 }
123}
124
125impl From<FanOut> for Composable {
126 fn from(f: FanOut) -> Self {
127 Composable::FanOut(f)
128 }
129}
130
131impl From<Loop> for Composable {
132 fn from(l: Loop) -> Self {
133 Composable::Loop(l)
134 }
135}
136
137impl From<Fallback> for Composable {
138 fn from(f: Fallback) -> Self {
139 Composable::Fallback(f)
140 }
141}
142
143impl Composable {
146 pub fn compile(self, llm: Arc<dyn BaseLlm>) -> Arc<dyn TextAgent> {
161 match self {
162 Composable::Agent(builder) => builder.build(llm),
163
164 Composable::Pipeline(pipeline) => {
165 let children: Vec<Arc<dyn TextAgent>> = pipeline
166 .steps
167 .into_iter()
168 .map(|step| step.compile(llm.clone()))
169 .collect();
170 Arc::new(SequentialTextAgent::new("pipeline", children))
171 }
172
173 Composable::FanOut(fan_out) => {
174 let branches: Vec<Arc<dyn TextAgent>> = fan_out
175 .branches
176 .into_iter()
177 .map(|branch| branch.compile(llm.clone()))
178 .collect();
179 Arc::new(ParallelTextAgent::new("fan_out", branches))
180 }
181
182 Composable::Loop(loop_node) => {
183 let body = loop_node.body.compile(llm);
184 let mut loop_agent = LoopTextAgent::new("loop", body, loop_node.max);
185
186 if let Some(predicate) = loop_node.until {
187 loop_agent = loop_agent.until(move |state: &gemini_adk_rs::State| {
188 let keys = state.keys();
190 let mut map = serde_json::Map::new();
191 for key in keys {
192 if let Some(val) = state.get_raw(&key) {
193 map.insert(key, val);
194 }
195 }
196 predicate.check(&serde_json::Value::Object(map))
197 });
198 }
199
200 Arc::new(loop_agent)
201 }
202
203 Composable::Fallback(fallback) => {
204 let candidates: Vec<Arc<dyn TextAgent>> = fallback
205 .candidates
206 .into_iter()
207 .map(|c| c.compile(llm.clone()))
208 .collect();
209 Arc::new(FallbackTextAgent::new("fallback", candidates))
210 }
211 }
212 }
213}
214
215impl Pipeline {
218 pub fn new(steps: Vec<Composable>) -> Self {
220 Self { steps }
221 }
222
223 pub fn builder(_name: &str) -> Self {
232 Self { steps: Vec::new() }
233 }
234
235 pub fn step(mut self, agent: impl Into<Composable>) -> Self {
237 self.steps.push(agent.into());
238 self
239 }
240
241 pub fn sub_agent(self, agent: AgentBuilder) -> Self {
243 self.step(agent)
244 }
245
246 pub fn describe(self, _desc: &str) -> Self {
248 self
249 }
250
251 fn push_flat(&mut self, step: Composable) {
253 match step {
254 Composable::Pipeline(p) => self.steps.extend(p.steps),
255 other => self.steps.push(other),
256 }
257 }
258}
259
260impl FanOut {
261 pub fn new(branches: Vec<Composable>) -> Self {
263 Self { branches }
264 }
265
266 pub fn builder(_name: &str) -> Self {
274 Self {
275 branches: Vec::new(),
276 }
277 }
278
279 pub fn branch(mut self, agent: impl Into<Composable>) -> Self {
281 self.branches.push(agent.into());
282 self
283 }
284
285 pub fn sub_agent(self, agent: AgentBuilder) -> Self {
287 self.branch(agent)
288 }
289
290 pub fn describe(self, _desc: &str) -> Self {
292 self
293 }
294
295 fn push_flat(&mut self, branch: Composable) {
296 match branch {
297 Composable::FanOut(f) => self.branches.extend(f.branches),
298 other => self.branches.push(other),
299 }
300 }
301}
302
303impl Fallback {
304 pub fn new(candidates: Vec<Composable>) -> Self {
306 Self { candidates }
307 }
308
309 fn push_flat(&mut self, candidate: Composable) {
310 match candidate {
311 Composable::Fallback(f) => self.candidates.extend(f.candidates),
312 other => self.candidates.push(other),
313 }
314 }
315}
316
317impl std::ops::Shr for AgentBuilder {
321 type Output = Composable;
322
323 fn shr(self, rhs: AgentBuilder) -> Self::Output {
324 Composable::Pipeline(Pipeline::new(vec![
325 Composable::Agent(self),
326 Composable::Agent(rhs),
327 ]))
328 }
329}
330
331impl std::ops::Shr<AgentBuilder> for Composable {
333 type Output = Composable;
334
335 fn shr(self, rhs: AgentBuilder) -> Self::Output {
336 let mut pipeline = match self {
337 Composable::Pipeline(p) => p,
338 other => Pipeline::new(vec![other]),
339 };
340 pipeline.push_flat(Composable::Agent(rhs));
341 Composable::Pipeline(pipeline)
342 }
343}
344
345impl std::ops::Shr<Composable> for AgentBuilder {
347 type Output = Composable;
348
349 fn shr(self, rhs: Composable) -> Self::Output {
350 let mut pipeline = Pipeline::new(vec![Composable::Agent(self)]);
351 pipeline.push_flat(rhs);
352 Composable::Pipeline(pipeline)
353 }
354}
355
356impl std::ops::Shr for Composable {
358 type Output = Composable;
359
360 fn shr(self, rhs: Composable) -> Self::Output {
361 let mut pipeline = match self {
362 Composable::Pipeline(p) => p,
363 other => Pipeline::new(vec![other]),
364 };
365 pipeline.push_flat(rhs);
366 Composable::Pipeline(pipeline)
367 }
368}
369
370impl std::ops::BitOr for AgentBuilder {
374 type Output = Composable;
375
376 fn bitor(self, rhs: AgentBuilder) -> Self::Output {
377 Composable::FanOut(FanOut::new(vec![
378 Composable::Agent(self),
379 Composable::Agent(rhs),
380 ]))
381 }
382}
383
384impl std::ops::BitOr<AgentBuilder> for Composable {
386 type Output = Composable;
387
388 fn bitor(self, rhs: AgentBuilder) -> Self::Output {
389 let mut fan_out = match self {
390 Composable::FanOut(f) => f,
391 other => FanOut::new(vec![other]),
392 };
393 fan_out.push_flat(Composable::Agent(rhs));
394 Composable::FanOut(fan_out)
395 }
396}
397
398impl std::ops::BitOr for Composable {
400 type Output = Composable;
401
402 fn bitor(self, rhs: Composable) -> Self::Output {
403 let mut fan_out = match self {
404 Composable::FanOut(f) => f,
405 other => FanOut::new(vec![other]),
406 };
407 fan_out.push_flat(rhs);
408 Composable::FanOut(fan_out)
409 }
410}
411
412impl std::ops::Mul<u32> for AgentBuilder {
416 type Output = Composable;
417
418 fn mul(self, rhs: u32) -> Self::Output {
419 Composable::Loop(Loop {
420 body: Box::new(Composable::Agent(self)),
421 max: rhs,
422 until: None,
423 })
424 }
425}
426
427impl std::ops::Mul<u32> for Composable {
429 type Output = Composable;
430
431 fn mul(self, rhs: u32) -> Self::Output {
432 Composable::Loop(Loop {
433 body: Box::new(self),
434 max: rhs,
435 until: None,
436 })
437 }
438}
439
440impl std::ops::Mul<LoopPredicate> for AgentBuilder {
442 type Output = Composable;
443
444 fn mul(self, rhs: LoopPredicate) -> Self::Output {
445 Composable::Loop(Loop {
446 body: Box::new(Composable::Agent(self)),
447 max: u32::MAX,
448 until: Some(rhs),
449 })
450 }
451}
452
453impl std::ops::Mul<LoopPredicate> for Composable {
455 type Output = Composable;
456
457 fn mul(self, rhs: LoopPredicate) -> Self::Output {
458 Composable::Loop(Loop {
459 body: Box::new(self),
460 max: u32::MAX,
461 until: Some(rhs),
462 })
463 }
464}
465
466impl std::ops::Div for AgentBuilder {
471 type Output = Composable;
472
473 fn div(self, rhs: AgentBuilder) -> Self::Output {
474 Composable::Fallback(Fallback::new(vec![
475 Composable::Agent(self),
476 Composable::Agent(rhs),
477 ]))
478 }
479}
480
481impl std::ops::Div<AgentBuilder> for Composable {
483 type Output = Composable;
484
485 fn div(self, rhs: AgentBuilder) -> Self::Output {
486 let mut fallback = match self {
487 Composable::Fallback(f) => f,
488 other => Fallback::new(vec![other]),
489 };
490 fallback.push_flat(Composable::Agent(rhs));
491 Composable::Fallback(fallback)
492 }
493}
494
495impl std::ops::Div for Composable {
497 type Output = Composable;
498
499 fn div(self, rhs: Composable) -> Self::Output {
500 let mut fallback = match self {
501 Composable::Fallback(f) => f,
502 other => Fallback::new(vec![other]),
503 };
504 fallback.push_flat(rhs);
505 Composable::Fallback(fallback)
506 }
507}
508
509impl Loop {
512 pub fn builder(_name: &str) -> Self {
520 Self {
521 body: Box::new(Composable::Pipeline(Pipeline::new(Vec::new()))),
522 max: 10,
523 until: None,
524 }
525 }
526
527 pub fn step(mut self, agent: impl Into<Composable>) -> Self {
529 self.body = Box::new(agent.into());
530 self
531 }
532
533 pub fn max_iterations(mut self, n: u32) -> Self {
535 self.max = n;
536 self
537 }
538
539 pub fn max(mut self, max: u32) -> Self {
541 self.max = max;
542 self
543 }
544
545 pub fn describe(self, _desc: &str) -> Self {
547 self
548 }
549}
550
#[cfg(test)]
mod tests {
    //! Shape tests for the operator algebra, plus end-to-end `compile` tests
    //! against stub LLMs.
    use super::*;

    /// Shorthand for a bare, named agent builder.
    fn agent(name: &str) -> AgentBuilder {
        AgentBuilder::new(name)
    }

    #[test]
    fn pipeline_from_shr() {
        let result = agent("a") >> agent("b");
        match result {
            Composable::Pipeline(p) => assert_eq!(p.steps.len(), 2),
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn pipeline_flattens() {
        let result = agent("a") >> agent("b") >> agent("c");
        match result {
            Composable::Pipeline(p) => assert_eq!(p.steps.len(), 3),
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn fan_out_from_bitor() {
        let result = agent("a") | agent("b");
        match result {
            Composable::FanOut(f) => assert_eq!(f.branches.len(), 2),
            _ => panic!("expected FanOut"),
        }
    }

    #[test]
    fn fan_out_flattens() {
        let result = (agent("a") | agent("b")) | agent("c");
        match result {
            Composable::FanOut(f) => assert_eq!(f.branches.len(), 3),
            _ => panic!("expected FanOut"),
        }
    }

    #[test]
    fn fixed_loop_from_mul() {
        let result = agent("a") * 3;
        match result {
            Composable::Loop(l) => {
                assert_eq!(l.max, 3);
                assert!(l.until.is_none());
            }
            _ => panic!("expected Loop"),
        }
    }

    #[test]
    fn conditional_loop_from_mul_until() {
        let pred = until(|_v| true);
        let result = agent("a") * pred;
        match result {
            Composable::Loop(l) => {
                assert_eq!(l.max, u32::MAX);
                assert!(l.until.is_some());
            }
            _ => panic!("expected Loop"),
        }
    }

    #[test]
    fn fallback_from_div() {
        let result = agent("a") / agent("b");
        match result {
            Composable::Fallback(f) => assert_eq!(f.candidates.len(), 2),
            _ => panic!("expected Fallback"),
        }
    }

    #[test]
    fn fallback_flattens() {
        let result = (agent("a") / agent("b")) / agent("c");
        match result {
            Composable::Fallback(f) => assert_eq!(f.candidates.len(), 3),
            _ => panic!("expected Fallback"),
        }
    }

    #[test]
    fn mixed_pipeline_with_fan_out() {
        let result = agent("a") >> (agent("b") | agent("c"));
        match &result {
            Composable::Pipeline(p) => {
                assert_eq!(p.steps.len(), 2);
                assert!(matches!(&p.steps[1], Composable::FanOut(_)));
            }
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn pipeline_then_loop() {
        let result = agent("a") >> (agent("b") * 5);
        match &result {
            Composable::Pipeline(p) => {
                assert_eq!(p.steps.len(), 2);
                assert!(matches!(&p.steps[1], Composable::Loop(_)));
            }
            _ => panic!("expected Pipeline"),
        }
    }

    #[test]
    fn loop_predicate_check() {
        let pred = until(|v| v.get("done").and_then(|v| v.as_bool()).unwrap_or(false));
        assert!(!pred.check(&serde_json::json!({"done": false})));
        assert!(pred.check(&serde_json::json!({"done": true})));
    }

    /// End-to-end tests: compile composition trees and run them against stub LLMs.
    mod compile_tests {
        use super::*;
        use async_trait::async_trait;
        use gemini_adk_rs::llm::{BaseLlm, LlmError, LlmRequest, LlmResponse};
        use gemini_genai_rs::prelude::{Content, Part, Role};

        /// Stub LLM that replies with the request's system instruction
        /// (or `"no-instruction"`), so each agent's output identifies it.
        struct NameEchoLlm;

        #[async_trait]
        impl BaseLlm for NameEchoLlm {
            fn model_id(&self) -> &str {
                "name-echo"
            }

            async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
                let text = req
                    .system_instruction
                    .unwrap_or_else(|| "no-instruction".into());
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::Text { text }],
                    },
                    finish_reason: Some("STOP".into()),
                    usage: None,
                })
            }
        }

        /// A shared handle to the echoing stub.
        fn llm() -> Arc<dyn BaseLlm> {
            Arc::new(NameEchoLlm)
        }

        #[tokio::test]
        async fn compile_single_agent() {
            let composable = Composable::Agent(AgentBuilder::new("solo").instruction("hello"));
            let agent = composable.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = agent.run(&state).await.unwrap();
            assert_eq!(result, "hello");
        }

        #[tokio::test]
        async fn compile_pipeline() {
            let pipeline = agent("a").instruction("step-a") >> agent("b").instruction("step-b");
            let compiled = pipeline.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert_eq!(result, "step-b");
        }

        #[tokio::test]
        async fn compile_fan_out() {
            let fan_out = Composable::Agent(agent("a").instruction("branch-a"))
                | Composable::Agent(agent("b").instruction("branch-b"));
            let compiled = fan_out.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert!(result.contains("branch-a"));
            assert!(result.contains("branch-b"));
        }

        #[tokio::test]
        async fn compile_loop() {
            let looped = agent("counter").instruction("tick") * 3;
            let compiled = looped.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert_eq!(result, "tick");
        }

        #[tokio::test]
        async fn compile_fallback() {
            let fallback = agent("a").instruction("first") / agent("b").instruction("second");
            let compiled = fallback.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert_eq!(result, "first");
        }

        #[tokio::test]
        async fn compile_loop_with_predicate() {
            /// Stub LLM that always answers `"done"` (it does not touch state).
            struct IncrementLlm;

            #[async_trait]
            impl BaseLlm for IncrementLlm {
                fn model_id(&self) -> &str {
                    "incr"
                }

                async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
                    Ok(LlmResponse {
                        content: Content {
                            role: Some(Role::Model),
                            parts: vec![Part::Text {
                                text: "done".into(),
                            }],
                        },
                        finish_reason: Some("STOP".into()),
                        usage: None,
                    })
                }
            }

            let pred = until(|v| v.get("n").and_then(|v| v.as_i64()).unwrap_or(0) >= 3);
            let body = agent("incr").instruction("increment");
            let looped = body * pred;

            let compiled = looped.compile(Arc::new(IncrementLlm));

            // NOTE(review): `n` is seeded above the threshold, so the predicate is
            // satisfied from the start. With a `u32::MAX` cap, finishing at all
            // presumably shows the predicate is consulted, but this does not
            // exercise state-driven iteration counts.
            let state = gemini_adk_rs::State::new();
            state.set("n", 5);
            let result = compiled.run(&state).await.unwrap();
            assert_eq!(result, "done");
        }

        #[tokio::test]
        async fn compile_mixed_pipeline_with_fan_out() {
            let mixed = agent("a").instruction("start")
                >> (Composable::Agent(agent("b").instruction("left"))
                    | Composable::Agent(agent("c").instruction("right")));
            let compiled = mixed.compile(llm());
            let state = gemini_adk_rs::State::new();
            let result = compiled.run(&state).await.unwrap();
            assert!(result.contains("left"));
            assert!(result.contains("right"));
        }
    }
}