1use std::sync::Arc;
7
8use gemini_adk_rs::llm::BaseLlm;
9use gemini_adk_rs::middleware::Middleware;
10use gemini_adk_rs::text::{LlmTextAgent, TextAgent};
11use gemini_adk_rs::tool::{ToolDispatcher, ToolFunction, ToolKind};
12use gemini_genai_rs::prelude::{GeminiModel, Modality, Tool, Voice};
13
14use crate::compose::context::ContextPolicyChain;
15use crate::compose::guards::GComposite;
16use crate::compose::middleware::MiddlewareComposite;
17use crate::compose::tools::ToolComposite;
18
19#[derive(Clone)]
21struct AgentBuilderInner {
22 name: String,
23 model: Option<GeminiModel>,
24 instruction: Option<String>,
25 voice: Option<Voice>,
26 temperature: Option<f32>,
27 top_p: Option<f32>,
28 top_k: Option<u32>,
29 max_output_tokens: Option<u32>,
30 stop_sequences: Vec<String>,
31 response_modalities: Option<Vec<Modality>>,
32 thinking_budget: Option<u32>,
33 tools: Vec<ToolEntry>,
34 built_in_tools: Vec<Tool>,
35 writes: Vec<String>,
36 reads: Vec<String>,
37 sub_agents: Vec<AgentBuilder>,
38 isolate: bool,
39 stay: bool,
40 description: Option<String>,
41 output_schema: Option<serde_json::Value>,
42 output_key: Option<String>,
43 transfer_to_agent: Option<String>,
44 middleware_layers: Vec<Arc<dyn Middleware>>,
46}
47
48#[derive(Clone)]
50pub enum ToolEntry {
51 Runtime(Arc<dyn ToolEntryTrait>),
53 Declaration(Tool),
55}
56
57pub trait ToolEntryTrait: Send + Sync + 'static {
59 fn name(&self) -> &str;
61 fn to_tool_kind(&self) -> ToolKind;
63}
64
65pub type Agent = AgentBuilder;
67
68#[derive(Clone)]
149pub struct AgentBuilder {
150 inner: Arc<AgentBuilderInner>,
151}
152
153impl AgentBuilder {
154 pub fn new(name: impl Into<String>) -> Self {
156 Self {
157 inner: Arc::new(AgentBuilderInner {
158 name: name.into(),
159 model: None,
160 instruction: None,
161 voice: None,
162 temperature: None,
163 top_p: None,
164 top_k: None,
165 max_output_tokens: None,
166 stop_sequences: Vec::new(),
167 response_modalities: None,
168 thinking_budget: None,
169 tools: Vec::new(),
170 built_in_tools: Vec::new(),
171 writes: Vec::new(),
172 reads: Vec::new(),
173 sub_agents: Vec::new(),
174 isolate: false,
175 stay: false,
176 description: None,
177 output_schema: None,
178 output_key: None,
179 transfer_to_agent: None,
180 middleware_layers: Vec::new(),
181 }),
182 }
183 }
184
185 fn mutate(&self) -> AgentBuilderInner {
188 (*self.inner).clone()
189 }
190
191 fn with(inner: AgentBuilderInner) -> Self {
192 Self {
193 inner: Arc::new(inner),
194 }
195 }
196
197 pub fn name(&self) -> &str {
201 &self.inner.name
202 }
203
204 pub fn get_model(&self) -> Option<&GeminiModel> {
206 self.inner.model.as_ref()
207 }
208
209 pub fn get_instruction(&self) -> Option<&str> {
211 self.inner.instruction.as_deref()
212 }
213
214 pub fn get_voice(&self) -> Option<&Voice> {
216 self.inner.voice.as_ref()
217 }
218
219 pub fn get_temperature(&self) -> Option<f32> {
221 self.inner.temperature
222 }
223
224 pub fn is_text_only(&self) -> bool {
226 self.inner
227 .response_modalities
228 .as_ref()
229 .map(|m| m == &[Modality::Text])
230 .unwrap_or(false)
231 }
232
233 pub fn get_thinking_budget(&self) -> Option<u32> {
235 self.inner.thinking_budget
236 }
237
238 pub fn get_writes(&self) -> &[String] {
240 &self.inner.writes
241 }
242
243 pub fn get_reads(&self) -> &[String] {
245 &self.inner.reads
246 }
247
248 pub fn get_sub_agents(&self) -> &[AgentBuilder] {
250 &self.inner.sub_agents
251 }
252
253 pub fn is_isolated(&self) -> bool {
255 self.inner.isolate
256 }
257
258 pub fn is_stay(&self) -> bool {
260 self.inner.stay
261 }
262
263 pub fn tool_count(&self) -> usize {
265 self.inner.tools.len() + self.inner.built_in_tools.len()
266 }
267
268 pub fn get_top_p(&self) -> Option<f32> {
270 self.inner.top_p
271 }
272
273 pub fn get_top_k(&self) -> Option<u32> {
275 self.inner.top_k
276 }
277
278 pub fn get_max_output_tokens(&self) -> Option<u32> {
280 self.inner.max_output_tokens
281 }
282
283 pub fn get_stop_sequences(&self) -> &[String] {
285 &self.inner.stop_sequences
286 }
287
288 pub fn get_description(&self) -> Option<&str> {
290 self.inner.description.as_deref()
291 }
292
293 pub fn get_output_schema(&self) -> Option<&serde_json::Value> {
295 self.inner.output_schema.as_ref()
296 }
297
298 pub fn get_output_key(&self) -> Option<&str> {
300 self.inner.output_key.as_deref()
301 }
302
303 pub fn get_transfer_to(&self) -> Option<&str> {
305 self.inner.transfer_to_agent.as_deref()
306 }
307
308 pub fn middleware_layer_count(&self) -> usize {
310 self.inner.middleware_layers.len()
311 }
312
313 pub fn model(self, model: GeminiModel) -> Self {
317 let mut inner = self.mutate();
318 inner.model = Some(model);
319 Self::with(inner)
320 }
321
322 pub fn instruction(self, inst: impl Into<String>) -> Self {
324 let mut inner = self.mutate();
325 inner.instruction = Some(inst.into());
326 Self::with(inner)
327 }
328
329 pub fn voice(self, voice: Voice) -> Self {
331 let mut inner = self.mutate();
332 inner.voice = Some(voice);
333 Self::with(inner)
334 }
335
336 pub fn temperature(self, t: f32) -> Self {
338 let mut inner = self.mutate();
339 inner.temperature = Some(t);
340 Self::with(inner)
341 }
342
343 pub fn text_only(self) -> Self {
345 let mut inner = self.mutate();
346 inner.response_modalities = Some(vec![Modality::Text]);
347 Self::with(inner)
348 }
349
350 pub fn response_modalities(self, modalities: Vec<Modality>) -> Self {
352 let mut inner = self.mutate();
353 inner.response_modalities = Some(modalities);
354 Self::with(inner)
355 }
356
357 pub fn thinking(self, budget: u32) -> Self {
359 let mut inner = self.mutate();
360 inner.thinking_budget = Some(budget);
361 Self::with(inner)
362 }
363
364 pub fn url_context(self) -> Self {
366 let mut inner = self.mutate();
367 inner.built_in_tools.push(Tool::url_context());
368 Self::with(inner)
369 }
370
371 pub fn google_search(self) -> Self {
373 let mut inner = self.mutate();
374 inner.built_in_tools.push(Tool::google_search());
375 Self::with(inner)
376 }
377
378 pub fn code_execution(self) -> Self {
380 let mut inner = self.mutate();
381 inner.built_in_tools.push(Tool::code_execution());
382 Self::with(inner)
383 }
384
385 pub fn writes(self, key: impl Into<String>) -> Self {
387 let mut inner = self.mutate();
388 inner.writes.push(key.into());
389 Self::with(inner)
390 }
391
392 pub fn reads(self, key: impl Into<String>) -> Self {
394 let mut inner = self.mutate();
395 inner.reads.push(key.into());
396 Self::with(inner)
397 }
398
399 pub fn sub_agent(self, agent: AgentBuilder) -> Self {
401 let mut inner = self.mutate();
402 inner.sub_agents.push(agent);
403 Self::with(inner)
404 }
405
406 pub fn isolate(self) -> Self {
408 let mut inner = self.mutate();
409 inner.isolate = true;
410 Self::with(inner)
411 }
412
413 pub fn stay(self) -> Self {
415 let mut inner = self.mutate();
416 inner.stay = true;
417 Self::with(inner)
418 }
419
420 pub fn top_p(self, p: f32) -> Self {
422 let mut inner = self.mutate();
423 inner.top_p = Some(p);
424 Self::with(inner)
425 }
426
427 pub fn top_k(self, k: u32) -> Self {
429 let mut inner = self.mutate();
430 inner.top_k = Some(k);
431 Self::with(inner)
432 }
433
434 pub fn max_output_tokens(self, n: u32) -> Self {
436 let mut inner = self.mutate();
437 inner.max_output_tokens = Some(n);
438 Self::with(inner)
439 }
440
441 pub fn stop_sequences(self, seqs: Vec<String>) -> Self {
443 let mut inner = self.mutate();
444 inner.stop_sequences = seqs;
445 Self::with(inner)
446 }
447
448 pub fn description(self, desc: impl Into<String>) -> Self {
450 let mut inner = self.mutate();
451 inner.description = Some(desc.into());
452 Self::with(inner)
453 }
454
455 pub fn output_schema(self, schema: serde_json::Value) -> Self {
457 let mut inner = self.mutate();
458 inner.output_schema = Some(schema);
459 Self::with(inner)
460 }
461
462 pub fn output_key(self, key: impl Into<String>) -> Self {
464 let mut inner = self.mutate();
465 inner.output_key = Some(key.into());
466 Self::with(inner)
467 }
468
469 pub fn transfer_to(self, agent_name: impl Into<String>) -> Self {
471 let mut inner = self.mutate();
472 inner.transfer_to_agent = Some(agent_name.into());
473 Self::with(inner)
474 }
475
476 pub fn instruct(self, inst: impl Into<String>) -> Self {
480 self.instruction(inst)
481 }
482
483 pub fn describe(self, desc: impl Into<String>) -> Self {
485 self.description(desc)
486 }
487
488 pub fn tool(self, f: Arc<dyn ToolFunction>) -> Self {
494 let mut inner = self.mutate();
495 inner
496 .tools
497 .push(ToolEntry::Runtime(Arc::new(ToolFunctionEntry(f))));
498 Self::with(inner)
499 }
500
501 pub fn tools(self, composite: ToolComposite) -> Self {
509 use crate::compose::tools::{DeferredTool, ToolResolution};
510 let mut inner = self.mutate();
511 for entry in composite.entries {
512 match entry.classify() {
513 ToolResolution::Runtime(f) => {
514 inner
515 .tools
516 .push(ToolEntry::Runtime(Arc::new(ToolFunctionEntry(f))));
517 }
518 ToolResolution::BuiltIn(t) => {
519 inner.built_in_tools.push(t);
520 }
521 ToolResolution::Agent {
522 name,
523 description,
524 agent,
525 } => {
526 let tool = gemini_adk_rs::TextAgentTool::from_arc(
528 name,
529 description,
530 agent,
531 gemini_adk_rs::State::new(),
532 );
533 inner
534 .tools
535 .push(ToolEntry::Runtime(Arc::new(ToolFunctionEntry(Arc::new(
536 tool,
537 )))));
538 }
539 ToolResolution::Deferred(deferred) => {
540 let kind = match deferred {
545 DeferredTool::Mcp { .. } => "T::mcp",
546 DeferredTool::A2a { .. } => "T::a2a",
547 DeferredTool::OpenApi { .. } => "T::openapi",
548 DeferredTool::Search { .. } => "T::search",
549 };
550 tracing::warn!(
551 tool = kind,
552 "ignoring async-resolved tool on a text AgentBuilder: {kind} \
553 requires a Live session (async connect); attach it via Live::with_tools"
554 );
555 }
556 }
557 }
558 Self::with(inner)
559 }
560
561 pub fn guard(self, guard: impl Into<GComposite>) -> Self {
575 let mut inner = self.mutate();
576 inner.middleware_layers.push(guard.into().into_middleware());
577 Self::with(inner)
578 }
579
580 pub fn context(self, policy: impl Into<ContextPolicyChain>) -> Self {
592 let mut inner = self.mutate();
593 inner
594 .middleware_layers
595 .push(policy.into().into_middleware());
596 Self::with(inner)
597 }
598
599 pub fn no_peers(self) -> Self {
601 self.isolate()
602 }
603
604 pub fn middleware(self, composite: MiddlewareComposite) -> Self {
620 let mut inner = self.mutate();
621 inner.middleware_layers.extend(composite.layers);
622 Self::with(inner)
623 }
624
625 pub fn build(self, llm: Arc<dyn BaseLlm>) -> Arc<dyn TextAgent> {
642 let mut agent = LlmTextAgent::new(&self.inner.name, llm);
643
644 if let Some(inst) = &self.inner.instruction {
645 agent = agent.instruction(inst);
646 }
647 if let Some(t) = self.inner.temperature {
648 agent = agent.temperature(t);
649 }
650 if let Some(n) = self.inner.max_output_tokens {
651 agent = agent.max_output_tokens(n);
652 }
653
654 if !self.inner.tools.is_empty() {
656 let mut dispatcher = ToolDispatcher::new();
657 for entry in &self.inner.tools {
658 match entry {
659 ToolEntry::Runtime(t) => {
660 let kind = t.to_tool_kind();
661 match kind {
662 ToolKind::Function(f) => dispatcher.register_function(f),
663 ToolKind::Streaming(s) => dispatcher.register_streaming(s),
664 ToolKind::InputStream(i) => dispatcher.register_input_streaming(i),
665 }
666 }
667 ToolEntry::Declaration(_) => {
668 }
671 }
672 }
673 if !dispatcher.is_empty() {
674 agent = agent.tools(Arc::new(dispatcher));
675 }
676 }
677
678 for mw in &self.inner.middleware_layers {
680 agent = agent.add_middleware(mw.clone());
681 }
682
683 Arc::new(agent)
684 }
685}
686
687#[derive(Clone)]
689struct ToolFunctionEntry(Arc<dyn ToolFunction>);
690
691impl ToolEntryTrait for ToolFunctionEntry {
692 fn name(&self) -> &str {
693 self.0.name()
694 }
695
696 fn to_tool_kind(&self) -> ToolKind {
697 ToolKind::Function(self.0.clone())
698 }
699}
700
701impl std::fmt::Debug for AgentBuilder {
702 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
703 f.debug_struct("AgentBuilder")
704 .field("name", &self.inner.name)
705 .field("model", &self.inner.model)
706 .field("instruction", &self.inner.instruction)
707 .field("temperature", &self.inner.temperature)
708 .field("text_only", &self.is_text_only())
709 .field("tool_count", &self.tool_count())
710 .field("sub_agents", &self.inner.sub_agents.len())
711 .finish()
712 }
713}
714
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use gemini_adk_rs::llm::{LlmError, LlmRequest, LlmResponse};
    use gemini_genai_rs::prelude::{Content, Part, Role};

    /// Mock LLM that ignores the request and always replies with its fixed text.
    struct MockLlm(String);

    #[async_trait]
    impl BaseLlm for MockLlm {
        fn model_id(&self) -> &str {
            "mock"
        }
        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            Ok(LlmResponse {
                content: Content {
                    role: Some(Role::Model),
                    parts: vec![Part::Text {
                        text: self.0.clone(),
                    }],
                },
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    #[test]
    fn builder_creates_with_name() {
        let b = AgentBuilder::new("test-agent");
        assert_eq!(b.name(), "test-agent");
    }

    #[test]
    fn fluent_chaining_works() {
        let b = AgentBuilder::new("agent")
            .instruction("Be helpful")
            .temperature(0.7)
            .model(GeminiModel::Gemini2_0FlashLive);

        assert_eq!(b.get_instruction(), Some("Be helpful"));
        assert_eq!(b.get_temperature(), Some(0.7));
        assert_eq!(b.get_model(), Some(&GeminiModel::Gemini2_0FlashLive));
    }

    #[test]
    fn copy_on_write_clone_independence() {
        let base = AgentBuilder::new("base").temperature(0.5);
        let variant = base.clone().temperature(0.9);

        // The original builder must be unaffected by setters applied to its clone.
        assert_eq!(base.get_temperature(), Some(0.5));
        // The clone carries its own, updated configuration.
        assert_eq!(variant.get_temperature(), Some(0.9));
    }

    #[test]
    fn text_only_sets_modalities() {
        let b = AgentBuilder::new("text").text_only();
        assert!(b.is_text_only());
    }

    #[test]
    fn url_context_adds_tool() {
        let b = AgentBuilder::new("search").url_context();
        assert_eq!(b.tool_count(), 1);
    }

    #[test]
    fn google_search_adds_tool() {
        let b = AgentBuilder::new("search").google_search();
        assert_eq!(b.tool_count(), 1);
    }

    #[test]
    fn code_execution_adds_tool() {
        let b = AgentBuilder::new("code").code_execution();
        assert_eq!(b.tool_count(), 1);
    }

    #[test]
    fn thinking_sets_budget() {
        let b = AgentBuilder::new("thinker").thinking(2048);
        assert_eq!(b.get_thinking_budget(), Some(2048));
    }

    #[test]
    fn writes_and_reads_keys() {
        let b = AgentBuilder::new("data").writes("output").reads("input");
        assert_eq!(b.get_writes(), &["output"]);
        assert_eq!(b.get_reads(), &["input"]);
    }

    #[test]
    fn sub_agent_registration() {
        let child = AgentBuilder::new("child");
        let parent = AgentBuilder::new("parent").sub_agent(child);
        assert_eq!(parent.get_sub_agents().len(), 1);
        assert_eq!(parent.get_sub_agents()[0].name(), "child");
    }

    #[test]
    fn isolate_and_stay() {
        let b = AgentBuilder::new("agent").isolate().stay();
        assert!(b.is_isolated());
        assert!(b.is_stay());
    }

    #[test]
    fn debug_display() {
        let b = AgentBuilder::new("debug-test");
        let debug = format!("{:?}", b);
        assert!(debug.contains("debug-test"));
    }

    #[test]
    fn top_p_sets_value() {
        let b = AgentBuilder::new("agent").top_p(0.95);
        assert_eq!(b.get_top_p(), Some(0.95));
    }

    #[test]
    fn top_k_sets_value() {
        let b = AgentBuilder::new("agent").top_k(40);
        assert_eq!(b.get_top_k(), Some(40));
    }

    #[test]
    fn max_output_tokens_sets_value() {
        let b = AgentBuilder::new("agent").max_output_tokens(4096);
        assert_eq!(b.get_max_output_tokens(), Some(4096));
    }

    #[test]
    fn stop_sequences_sets_value() {
        let b =
            AgentBuilder::new("agent").stop_sequences(vec!["END".to_string(), "STOP".to_string()]);
        assert_eq!(b.get_stop_sequences().len(), 2);
    }

    #[test]
    fn description_sets_value() {
        let b = AgentBuilder::new("agent").description("A helpful agent");
        assert_eq!(b.get_description(), Some("A helpful agent"));
    }

    #[test]
    fn output_schema_sets_value() {
        let schema = serde_json::json!({"type": "object"});
        let b = AgentBuilder::new("agent").output_schema(schema.clone());
        assert_eq!(b.get_output_schema(), Some(&schema));
    }

    #[test]
    fn transfer_to_sets_value() {
        let b = AgentBuilder::new("agent").transfer_to("target-agent");
        assert_eq!(b.get_transfer_to(), Some("target-agent"));
    }

    #[test]
    fn full_fluent_chain() {
        let b = AgentBuilder::new("full-agent")
            .model(GeminiModel::Gemini2_0FlashLive)
            .instruction("Be helpful")
            .temperature(0.7)
            .top_p(0.95)
            .top_k(40)
            .max_output_tokens(4096)
            .thinking(2048)
            .description("A fully configured agent")
            .google_search()
            .writes("output")
            .reads("input");

        assert_eq!(b.name(), "full-agent");
        assert_eq!(b.get_temperature(), Some(0.7));
        assert_eq!(b.get_top_p(), Some(0.95));
        assert_eq!(b.get_top_k(), Some(40));
        assert_eq!(b.get_max_output_tokens(), Some(4096));
        assert_eq!(b.get_thinking_budget(), Some(2048));
        assert_eq!(b.get_description(), Some("A fully configured agent"));
        assert_eq!(b.tool_count(), 1);
    }

    // `build()` must produce an agent that runs end-to-end against a mock LLM
    // and returns the model's text.
    #[tokio::test]
    async fn build_produces_executable_agent() {
        let llm: Arc<dyn BaseLlm> = Arc::new(MockLlm("built agent output".into()));
        let agent = AgentBuilder::new("test")
            .instruction("Be helpful")
            .temperature(0.5)
            .build(llm);

        assert_eq!(agent.name(), "test");
        let state = gemini_adk_rs::State::new();
        let result = agent.run(&state).await.unwrap();
        assert_eq!(result, "built agent output");
    }

    // The built agent stores its reply under the "output" state key.
    #[tokio::test]
    async fn build_stores_output_in_state() {
        let llm: Arc<dyn BaseLlm> = Arc::new(MockLlm("state output".into()));
        let agent = AgentBuilder::new("test").build(llm);
        let state = gemini_adk_rs::State::new();
        agent.run(&state).await.unwrap();
        assert_eq!(state.get::<String>("output"), Some("state output".into()));
    }

    // The built agent includes the "input" state value in its request to the model.
    #[tokio::test]
    async fn build_reads_input_from_state() {
        use gemini_adk_rs::llm::LlmRequest;

        // Mock LLM that replies with the concatenated text parts of every request content.
        struct EchoLlm;
        #[async_trait]
        impl BaseLlm for EchoLlm {
            fn model_id(&self) -> &str {
                "echo"
            }
            async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
                let text: String = req
                    .contents
                    .iter()
                    .flat_map(|c| &c.parts)
                    .filter_map(|p| match p {
                        Part::Text { text } => Some(text.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("");
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::Text { text }],
                    },
                    finish_reason: Some("STOP".into()),
                    usage: None,
                })
            }
        }

        let agent = AgentBuilder::new("echo").build(Arc::new(EchoLlm));
        let state = gemini_adk_rs::State::new();
        let _ = state.set("input", "hello from state");
        let result = agent.run(&state).await.unwrap();
        assert!(result.contains("hello from state"));
    }

    /// Two-step mock LLM used to exercise the tool loop.
    ///
    /// - While no `FunctionResponse` is present in the request, it asks to call
    ///   `tool_name` with args `{"x": 1}` and id `call-1`.
    /// - Once a `FunctionResponse` appears, it answers with `final_text`.
    struct ToolCallingMockLlm {
        tool_name: &'static str,
        final_text: &'static str,
    }

    #[async_trait]
    impl BaseLlm for ToolCallingMockLlm {
        fn model_id(&self) -> &str {
            "tool-mock"
        }

        async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
            use gemini_genai_rs::prelude::FunctionCall;

            // Has a tool result already been fed back into the conversation?
            let already_responded = req
                .contents
                .iter()
                .flat_map(|c| &c.parts)
                .any(|p| matches!(p, Part::FunctionResponse { .. }));

            if already_responded {
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::Text {
                            text: self.final_text.to_string(),
                        }],
                    },
                    finish_reason: Some("STOP".into()),
                    usage: None,
                })
            } else {
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::FunctionCall {
                            function_call: FunctionCall {
                                name: self.tool_name.to_string(),
                                args: serde_json::json!({"x": 1}),
                                id: Some("call-1".into()),
                            },
                        }],
                    },
                    finish_reason: None,
                    usage: None,
                })
            }
        }
    }

    // Middleware registered via `.middleware(...)` must observe both the model
    // calls and the tool dispatch during a real `run()`.
    #[tokio::test]
    async fn middleware_hooks_fire_end_to_end() {
        use crate::compose::middleware::M;
        use gemini_adk_rs::tool::SimpleTool;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let before_model_count = Arc::new(AtomicUsize::new(0));
        let after_tool_count = Arc::new(AtomicUsize::new(0));

        let bm = before_model_count.clone();
        let at = after_tool_count.clone();

        let mw = M::before_model(move |_req| {
            bm.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }) | M::after_tool(move |_call, _result| {
            at.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });

        let llm: Arc<dyn BaseLlm> = Arc::new(ToolCallingMockLlm {
            tool_name: "echo_tool",
            final_text: "done",
        });

        let agent = AgentBuilder::new("mw-test")
            .middleware(mw)
            .tool(Arc::new(SimpleTool::new(
                "echo_tool",
                "Echo tool",
                None,
                |_args| async move { Ok(serde_json::json!({"echo": true})) },
            )))
            .build(llm);

        let state = gemini_adk_rs::State::new();
        let result = agent.run(&state).await.unwrap();
        assert_eq!(result, "done");

        // Two generate() calls: the first requests the tool, the second returns "done".
        assert_eq!(
            before_model_count.load(Ordering::SeqCst),
            2,
            "before_model should fire for each generate() call"
        );
        // Exactly one tool call was requested by the mock.
        assert_eq!(
            after_tool_count.load(Ordering::SeqCst),
            1,
            "after_tool should fire once for the tool dispatch"
        );
    }

    // Adding middleware to a clone must not change the original builder.
    #[test]
    fn middleware_copy_on_write() {
        use crate::compose::middleware::M;

        let base = AgentBuilder::new("base").instruction("base");
        let with_mw = base.clone().middleware(M::log() | M::latency());

        // The template keeps zero layers...
        assert_eq!(base.middleware_layer_count(), 0);
        // ...while the clone has both composed layers.
        assert_eq!(with_mw.middleware_layer_count(), 2);
    }

    // An LLM failure must surface as an error and trigger `on_error` once.
    #[tokio::test]
    async fn middleware_on_error_fires_on_failure() {
        use crate::compose::middleware::M;
        use gemini_adk_rs::llm::LlmError;
        use std::sync::atomic::{AtomicUsize, Ordering};

        let error_count = Arc::new(AtomicUsize::new(0));
        let ec = error_count.clone();

        let mw = M::on_error(move |_err| {
            ec.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });

        // Mock LLM whose every generate() call fails.
        struct FailLlm;
        #[async_trait]
        impl BaseLlm for FailLlm {
            fn model_id(&self) -> &str {
                "fail"
            }
            async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
                Err(LlmError::RequestFailed("boom".into()))
            }
        }

        let agent = AgentBuilder::new("error-test")
            .middleware(mw)
            .build(Arc::new(FailLlm));

        let state = gemini_adk_rs::State::new();
        let result = agent.run(&state).await;
        assert!(result.is_err(), "agent should fail");
        assert_eq!(
            error_count.load(Ordering::SeqCst),
            1,
            "on_error should fire exactly once"
        );
    }

    /// Mock LLM that replies with fixed `text`.
    ///
    /// It also records into `seen_len` how many contents the most recent request
    /// carried; each call overwrites the value. Used to observe what context
    /// policies did to the history.
    struct RecordingLlm {
        text: &'static str,
        seen_len: Arc<std::sync::atomic::AtomicUsize>,
    }

    #[async_trait]
    impl BaseLlm for RecordingLlm {
        fn model_id(&self) -> &str {
            "recording-mock"
        }

        async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.seen_len
                .store(req.contents.len(), std::sync::atomic::Ordering::SeqCst);
            Ok(LlmResponse {
                content: Content {
                    role: Some(Role::Model),
                    parts: vec![Part::Text {
                        text: self.text.to_string(),
                    }],
                },
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    // A PII guard must veto a response containing an email address.
    #[tokio::test]
    async fn guard_blocks_violating_output() {
        use crate::compose::guards::G;

        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
            text: "you can reach me at agent@example.com",
            seen_len: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        });

        let agent = AgentBuilder::new("guarded").guard(G::pii()).build(llm);

        let state = gemini_adk_rs::State::new();
        let err = agent.run(&state).await.unwrap_err();
        assert!(
            err.to_string().contains("guard violation"),
            "PII guard should veto the response, got: {err}"
        );
    }

    // Composed guards must let a clean response through unchanged.
    #[tokio::test]
    async fn guard_allows_clean_output() {
        use crate::compose::guards::G;

        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
            text: "all clean here",
            seen_len: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
        });

        let agent = AgentBuilder::new("guarded")
            .guard(G::pii() | G::length(1, 1000))
            .build(llm);

        let state = gemini_adk_rs::State::new();
        let result = agent.run(&state).await.unwrap();
        assert_eq!(result, "all clean here");
    }

    // A prepend policy must add one turn to the request the model receives.
    #[tokio::test]
    async fn context_policy_rewrites_request_history() {
        use crate::compose::context::C;

        // Expected: the prepended preamble plus (presumably) one turn for the
        // "input" state value, i.e. 2 contents in total.
        let seen = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
            text: "ok",
            seen_len: seen.clone(),
        });

        let agent = AgentBuilder::new("ctx")
            .context(C::prepend(Content::user("system preamble")))
            .build(llm);

        let state = gemini_adk_rs::State::new();
        let _ = state.set("input", "hello");
        let _ = agent.run(&state).await.unwrap();
        assert_eq!(
            seen.load(std::sync::atomic::Ordering::SeqCst),
            2,
            "context policy should have prepended a turn before the model call"
        );
    }

    // Policies chained with `+` apply in sequence; a trailing window(1) trims the result.
    #[tokio::test]
    async fn context_window_trims_history() {
        use crate::compose::context::C;

        // Two prepends plus the input turn would give 3 contents; window(1) should
        // leave only the last one.
        let seen = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
            text: "ok",
            seen_len: seen.clone(),
        });

        let agent = AgentBuilder::new("ctx")
            .context(C::prepend(Content::user("a")) + C::prepend(Content::user("b")) + C::window(1))
            .build(llm);

        let state = gemini_adk_rs::State::new();
        let _ = state.set("input", "hello");
        let _ = agent.run(&state).await.unwrap();
        assert_eq!(
            seen.load(std::sync::atomic::Ordering::SeqCst),
            1,
            "window(1) should trim history to the last turn"
        );
    }
}