gemini_adk_rs/text/
route.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::context::AgentEvent;
7use crate::error::AgentError;
8use crate::middleware::MiddlewareChain;
9use crate::state::State;
10
11pub struct RouteRule {
13 predicate: Box<dyn Fn(&State) -> bool + Send + Sync>,
14 agent: Arc<dyn TextAgent>,
15}
16
17impl RouteRule {
18 pub fn new(
20 predicate: impl Fn(&State) -> bool + Send + Sync + 'static,
21 agent: Arc<dyn TextAgent>,
22 ) -> Self {
23 Self {
24 predicate: Box::new(predicate),
25 agent,
26 }
27 }
28}
29
30pub struct RouteTextAgent {
33 name: String,
34 rules: Vec<RouteRule>,
35 default: Arc<dyn TextAgent>,
36 middleware: MiddlewareChain,
37}
38
39impl RouteTextAgent {
40 pub fn new(
42 name: impl Into<String>,
43 rules: Vec<RouteRule>,
44 default: Arc<dyn TextAgent>,
45 ) -> Self {
46 Self {
47 name: name.into(),
48 rules,
49 default,
50 middleware: MiddlewareChain::new(),
51 }
52 }
53
54 pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
57 self.middleware = chain;
58 self
59 }
60}
61
62#[async_trait]
63impl TextAgent for RouteTextAgent {
64 fn name(&self) -> &str {
65 &self.name
66 }
67
68 async fn run(&self, state: &State) -> Result<String, AgentError> {
69 for rule in &self.rules {
70 if (rule.predicate)(state) {
71 let _ = self
72 .middleware
73 .run_on_event(&AgentEvent::RouteSelected {
74 agent_name: rule.agent.name().to_string(),
75 })
76 .await;
77 return rule.agent.run(state).await;
78 }
79 }
80 let _ = self
81 .middleware
82 .run_on_event(&AgentEvent::RouteSelected {
83 agent_name: self.default.name().to_string(),
84 })
85 .await;
86 self.default.run(state).await
87 }
88}