gemini_adk_fluent_rs/
builder.rs

1//! AgentBuilder — copy-on-write immutable builder for fluent agent construction.
2//!
3//! Every mutation returns a new builder (original unchanged), so builders
4//! are safely shareable as templates.
5
6use std::sync::Arc;
7
8use gemini_adk_rs::llm::BaseLlm;
9use gemini_adk_rs::middleware::Middleware;
10use gemini_adk_rs::text::{LlmTextAgent, TextAgent};
11use gemini_adk_rs::tool::{ToolDispatcher, ToolFunction, ToolKind};
12use gemini_genai_rs::prelude::{GeminiModel, Modality, Tool, Voice};
13
14use crate::compose::context::ContextPolicyChain;
15use crate::compose::guards::GComposite;
16use crate::compose::middleware::MiddlewareComposite;
17use crate::compose::tools::ToolComposite;
18
19/// Inner state of an AgentBuilder — shared via Arc for copy-on-write.
20#[derive(Clone)]
21struct AgentBuilderInner {
22    name: String,
23    model: Option<GeminiModel>,
24    instruction: Option<String>,
25    voice: Option<Voice>,
26    temperature: Option<f32>,
27    top_p: Option<f32>,
28    top_k: Option<u32>,
29    max_output_tokens: Option<u32>,
30    stop_sequences: Vec<String>,
31    response_modalities: Option<Vec<Modality>>,
32    thinking_budget: Option<u32>,
33    tools: Vec<ToolEntry>,
34    built_in_tools: Vec<Tool>,
35    writes: Vec<String>,
36    reads: Vec<String>,
37    sub_agents: Vec<AgentBuilder>,
38    isolate: bool,
39    stay: bool,
40    description: Option<String>,
41    output_schema: Option<serde_json::Value>,
42    output_key: Option<String>,
43    transfer_to_agent: Option<String>,
44    /// Middleware layers to install on the compiled `LlmTextAgent`.
45    middleware_layers: Vec<Arc<dyn Middleware>>,
46}
47
48/// An entry in the builder's tool list — either a runtime ToolKind or a declaration.
49#[derive(Clone)]
50pub enum ToolEntry {
51    /// A runtime tool with a handler function.
52    Runtime(Arc<dyn ToolEntryTrait>),
53    /// A wire-level tool declaration (e.g., built-in tools like Google Search).
54    Declaration(Tool),
55}
56
57/// Trait for tool entries that can provide a name (for dedup/inspection).
58pub trait ToolEntryTrait: Send + Sync + 'static {
59    /// The tool's registered name.
60    fn name(&self) -> &str;
61    /// Convert this entry into the runtime `ToolKind` variant for dispatch.
62    fn to_tool_kind(&self) -> ToolKind;
63}
64
65/// Alias for [`AgentBuilder`] — matches upstream Python `Agent("name")` naming.
66pub type Agent = AgentBuilder;
67
68/// Copy-on-write immutable builder for agent construction.
69///
70/// Every setter returns a new `AgentBuilder`, leaving the original unchanged.
71/// This makes builders safe to share as templates.
72///
73/// # Basic Usage
74///
75/// ```rust
76/// use gemini_adk_fluent_rs::builder::AgentBuilder;
77/// use gemini_genai_rs::prelude::GeminiModel;
78///
79/// let agent = AgentBuilder::new("analyst")
80///     .model(GeminiModel::Gemini2_0FlashLive)
81///     .instruction("Analyze the given topic")
82///     .temperature(0.3);
83///
84/// assert_eq!(agent.name(), "analyst");
85/// assert_eq!(agent.get_temperature(), Some(0.3));
86/// ```
87///
88/// # Copy-on-Write Pattern
89///
90/// Cloning a builder and modifying the clone leaves the original unchanged.
91/// This is useful for creating template builders with shared defaults.
92///
93/// ```rust
94/// use gemini_adk_fluent_rs::builder::AgentBuilder;
95///
96/// let base = AgentBuilder::new("researcher")
97///     .instruction("You are a research assistant.")
98///     .temperature(0.5);
99///
100/// let creative = base.clone().temperature(0.9);
101/// let precise  = base.clone().temperature(0.1);
102///
103/// // Original unchanged
104/// assert_eq!(base.get_temperature(), Some(0.5));
105/// assert_eq!(creative.get_temperature(), Some(0.9));
106/// assert_eq!(precise.get_temperature(), Some(0.1));
107/// ```
108///
109/// # Sampling Parameters
110///
111/// ```rust
112/// use gemini_adk_fluent_rs::builder::AgentBuilder;
113///
114/// let agent = AgentBuilder::new("sampler")
115///     .temperature(0.7)
116///     .top_p(0.95)
117///     .top_k(40)
118///     .max_output_tokens(4096);
119///
120/// assert_eq!(agent.get_top_p(), Some(0.95));
121/// assert_eq!(agent.get_top_k(), Some(40));
122/// assert_eq!(agent.get_max_output_tokens(), Some(4096));
123/// ```
124///
125/// # Built-in Tools
126///
127/// ```rust
128/// use gemini_adk_fluent_rs::builder::AgentBuilder;
129///
130/// let agent = AgentBuilder::new("searcher")
131///     .google_search()
132///     .code_execution()
133///     .url_context();
134///
135/// assert_eq!(agent.tool_count(), 3);
136/// ```
137///
138/// # Thinking Budget
139///
140/// ```rust
141/// use gemini_adk_fluent_rs::builder::AgentBuilder;
142///
143/// let agent = AgentBuilder::new("thinker")
144///     .thinking(2048);
145///
146/// assert_eq!(agent.get_thinking_budget(), Some(2048));
147/// ```
148#[derive(Clone)]
149pub struct AgentBuilder {
150    inner: Arc<AgentBuilderInner>,
151}
152
153impl AgentBuilder {
154    /// Create a new builder with the given agent name.
155    pub fn new(name: impl Into<String>) -> Self {
156        Self {
157            inner: Arc::new(AgentBuilderInner {
158                name: name.into(),
159                model: None,
160                instruction: None,
161                voice: None,
162                temperature: None,
163                top_p: None,
164                top_k: None,
165                max_output_tokens: None,
166                stop_sequences: Vec::new(),
167                response_modalities: None,
168                thinking_budget: None,
169                tools: Vec::new(),
170                built_in_tools: Vec::new(),
171                writes: Vec::new(),
172                reads: Vec::new(),
173                sub_agents: Vec::new(),
174                isolate: false,
175                stay: false,
176                description: None,
177                output_schema: None,
178                output_key: None,
179                transfer_to_agent: None,
180                middleware_layers: Vec::new(),
181            }),
182        }
183    }
184
185    // ── Private helper: clone-on-write ──
186
187    fn mutate(&self) -> AgentBuilderInner {
188        (*self.inner).clone()
189    }
190
191    fn with(inner: AgentBuilderInner) -> Self {
192        Self {
193            inner: Arc::new(inner),
194        }
195    }
196
197    // ── Accessors ──
198
199    /// The agent name.
200    pub fn name(&self) -> &str {
201        &self.inner.name
202    }
203
204    /// Configured model, if any.
205    pub fn get_model(&self) -> Option<&GeminiModel> {
206        self.inner.model.as_ref()
207    }
208
209    /// Configured instruction, if any.
210    pub fn get_instruction(&self) -> Option<&str> {
211        self.inner.instruction.as_deref()
212    }
213
214    /// Configured voice, if any.
215    pub fn get_voice(&self) -> Option<&Voice> {
216        self.inner.voice.as_ref()
217    }
218
219    /// Configured temperature, if any.
220    pub fn get_temperature(&self) -> Option<f32> {
221        self.inner.temperature
222    }
223
224    /// Whether text-only mode is set.
225    pub fn is_text_only(&self) -> bool {
226        self.inner
227            .response_modalities
228            .as_ref()
229            .map(|m| m == &[Modality::Text])
230            .unwrap_or(false)
231    }
232
233    /// Configured thinking budget, if any.
234    pub fn get_thinking_budget(&self) -> Option<u32> {
235        self.inner.thinking_budget
236    }
237
238    /// State keys this agent writes.
239    pub fn get_writes(&self) -> &[String] {
240        &self.inner.writes
241    }
242
243    /// State keys this agent reads.
244    pub fn get_reads(&self) -> &[String] {
245        &self.inner.reads
246    }
247
248    /// Sub-agents registered.
249    pub fn get_sub_agents(&self) -> &[AgentBuilder] {
250        &self.inner.sub_agents
251    }
252
253    /// Whether agent runs in isolated state.
254    pub fn is_isolated(&self) -> bool {
255        self.inner.isolate
256    }
257
258    /// Whether agent stays after transfer.
259    pub fn is_stay(&self) -> bool {
260        self.inner.stay
261    }
262
263    /// Number of tool entries.
264    pub fn tool_count(&self) -> usize {
265        self.inner.tools.len() + self.inner.built_in_tools.len()
266    }
267
268    /// Configured top_p, if any.
269    pub fn get_top_p(&self) -> Option<f32> {
270        self.inner.top_p
271    }
272
273    /// Configured top_k, if any.
274    pub fn get_top_k(&self) -> Option<u32> {
275        self.inner.top_k
276    }
277
278    /// Configured max_output_tokens, if any.
279    pub fn get_max_output_tokens(&self) -> Option<u32> {
280        self.inner.max_output_tokens
281    }
282
283    /// Configured stop sequences.
284    pub fn get_stop_sequences(&self) -> &[String] {
285        &self.inner.stop_sequences
286    }
287
288    /// Configured description, if any.
289    pub fn get_description(&self) -> Option<&str> {
290        self.inner.description.as_deref()
291    }
292
293    /// Configured output schema, if any.
294    pub fn get_output_schema(&self) -> Option<&serde_json::Value> {
295        self.inner.output_schema.as_ref()
296    }
297
298    /// Get the configured output key.
299    pub fn get_output_key(&self) -> Option<&str> {
300        self.inner.output_key.as_deref()
301    }
302
303    /// Configured transfer target agent, if any.
304    pub fn get_transfer_to(&self) -> Option<&str> {
305        self.inner.transfer_to_agent.as_deref()
306    }
307
308    /// Number of registered middleware layers.
309    pub fn middleware_layer_count(&self) -> usize {
310        self.inner.middleware_layers.len()
311    }
312
313    // ── Fluent Setters (copy-on-write) ──
314
315    /// Set the Gemini model.
316    pub fn model(self, model: GeminiModel) -> Self {
317        let mut inner = self.mutate();
318        inner.model = Some(model);
319        Self::with(inner)
320    }
321
322    /// Set the system instruction.
323    pub fn instruction(self, inst: impl Into<String>) -> Self {
324        let mut inner = self.mutate();
325        inner.instruction = Some(inst.into());
326        Self::with(inner)
327    }
328
329    /// Set the output voice.
330    pub fn voice(self, voice: Voice) -> Self {
331        let mut inner = self.mutate();
332        inner.voice = Some(voice);
333        Self::with(inner)
334    }
335
336    /// Set the temperature.
337    pub fn temperature(self, t: f32) -> Self {
338        let mut inner = self.mutate();
339        inner.temperature = Some(t);
340        Self::with(inner)
341    }
342
343    /// Set text-only mode (no audio output).
344    pub fn text_only(self) -> Self {
345        let mut inner = self.mutate();
346        inner.response_modalities = Some(vec![Modality::Text]);
347        Self::with(inner)
348    }
349
350    /// Set response modalities explicitly.
351    pub fn response_modalities(self, modalities: Vec<Modality>) -> Self {
352        let mut inner = self.mutate();
353        inner.response_modalities = Some(modalities);
354        Self::with(inner)
355    }
356
357    /// Enable thinking with a token budget.
358    pub fn thinking(self, budget: u32) -> Self {
359        let mut inner = self.mutate();
360        inner.thinking_budget = Some(budget);
361        Self::with(inner)
362    }
363
364    /// Add a built-in URL context tool.
365    pub fn url_context(self) -> Self {
366        let mut inner = self.mutate();
367        inner.built_in_tools.push(Tool::url_context());
368        Self::with(inner)
369    }
370
371    /// Add a built-in Google Search tool.
372    pub fn google_search(self) -> Self {
373        let mut inner = self.mutate();
374        inner.built_in_tools.push(Tool::google_search());
375        Self::with(inner)
376    }
377
378    /// Add a built-in code execution tool.
379    pub fn code_execution(self) -> Self {
380        let mut inner = self.mutate();
381        inner.built_in_tools.push(Tool::code_execution());
382        Self::with(inner)
383    }
384
385    /// Declare a state key this agent writes.
386    pub fn writes(self, key: impl Into<String>) -> Self {
387        let mut inner = self.mutate();
388        inner.writes.push(key.into());
389        Self::with(inner)
390    }
391
392    /// Declare a state key this agent reads.
393    pub fn reads(self, key: impl Into<String>) -> Self {
394        let mut inner = self.mutate();
395        inner.reads.push(key.into());
396        Self::with(inner)
397    }
398
399    /// Add a sub-agent for transfer.
400    pub fn sub_agent(self, agent: AgentBuilder) -> Self {
401        let mut inner = self.mutate();
402        inner.sub_agents.push(agent);
403        Self::with(inner)
404    }
405
406    /// Run this agent in isolated state (no shared state).
407    pub fn isolate(self) -> Self {
408        let mut inner = self.mutate();
409        inner.isolate = true;
410        Self::with(inner)
411    }
412
413    /// Keep this agent active after transfer (don't tear down).
414    pub fn stay(self) -> Self {
415        let mut inner = self.mutate();
416        inner.stay = true;
417        Self::with(inner)
418    }
419
420    /// Set top_p (nucleus sampling).
421    pub fn top_p(self, p: f32) -> Self {
422        let mut inner = self.mutate();
423        inner.top_p = Some(p);
424        Self::with(inner)
425    }
426
427    /// Set top_k (top-k sampling).
428    pub fn top_k(self, k: u32) -> Self {
429        let mut inner = self.mutate();
430        inner.top_k = Some(k);
431        Self::with(inner)
432    }
433
434    /// Set maximum output tokens.
435    pub fn max_output_tokens(self, n: u32) -> Self {
436        let mut inner = self.mutate();
437        inner.max_output_tokens = Some(n);
438        Self::with(inner)
439    }
440
441    /// Set stop sequences.
442    pub fn stop_sequences(self, seqs: Vec<String>) -> Self {
443        let mut inner = self.mutate();
444        inner.stop_sequences = seqs;
445        Self::with(inner)
446    }
447
448    /// Set a description for this agent (used in tool/agent metadata).
449    pub fn description(self, desc: impl Into<String>) -> Self {
450        let mut inner = self.mutate();
451        inner.description = Some(desc.into());
452        Self::with(inner)
453    }
454
455    /// Set a JSON schema for structured output.
456    pub fn output_schema(self, schema: serde_json::Value) -> Self {
457        let mut inner = self.mutate();
458        inner.output_schema = Some(schema);
459        Self::with(inner)
460    }
461
462    /// Set the output key — agent's final text response is auto-saved to this state key.
463    pub fn output_key(self, key: impl Into<String>) -> Self {
464        let mut inner = self.mutate();
465        inner.output_key = Some(key.into());
466        Self::with(inner)
467    }
468
469    /// Set a default transfer target agent.
470    pub fn transfer_to(self, agent_name: impl Into<String>) -> Self {
471        let mut inner = self.mutate();
472        inner.transfer_to_agent = Some(agent_name.into());
473        Self::with(inner)
474    }
475
476    // ── Upstream naming aliases ──
477
478    /// Alias for [`instruction`](Self::instruction) — matches upstream Python `Agent.instruct()`.
479    pub fn instruct(self, inst: impl Into<String>) -> Self {
480        self.instruction(inst)
481    }
482
483    /// Alias for [`description`](Self::description) — matches upstream Python `Agent.describe()`.
484    pub fn describe(self, desc: impl Into<String>) -> Self {
485        self.description(desc)
486    }
487
488    /// Register a single tool function.
489    ///
490    /// ```ignore
491    /// Agent::new("assistant").tool(Arc::new(my_tool))
492    /// ```
493    pub fn tool(self, f: Arc<dyn ToolFunction>) -> Self {
494        let mut inner = self.mutate();
495        inner
496            .tools
497            .push(ToolEntry::Runtime(Arc::new(ToolFunctionEntry(f))));
498        Self::with(inner)
499    }
500
501    /// Register multiple tools from a [`ToolComposite`].
502    ///
503    /// ```ignore
504    /// let tools = T::simple("greet", "Greet", |_| async { Ok(json!({})) })
505    ///     | T::google_search();
506    /// Agent::new("assistant").tools(tools)
507    /// ```
508    pub fn tools(self, composite: ToolComposite) -> Self {
509        use crate::compose::tools::{DeferredTool, ToolResolution};
510        let mut inner = self.mutate();
511        for entry in composite.entries {
512            match entry.classify() {
513                ToolResolution::Runtime(f) => {
514                    inner
515                        .tools
516                        .push(ToolEntry::Runtime(Arc::new(ToolFunctionEntry(f))));
517                }
518                ToolResolution::BuiltIn(t) => {
519                    inner.built_in_tools.push(t);
520                }
521                ToolResolution::Agent {
522                    name,
523                    description,
524                    agent,
525                } => {
526                    // Expose the sub-agent as a callable tool over a fresh State.
527                    let tool = gemini_adk_rs::TextAgentTool::from_arc(
528                        name,
529                        description,
530                        agent,
531                        gemini_adk_rs::State::new(),
532                    );
533                    inner
534                        .tools
535                        .push(ToolEntry::Runtime(Arc::new(ToolFunctionEntry(Arc::new(
536                            tool,
537                        )))));
538                }
539                ToolResolution::Deferred(deferred) => {
540                    // MCP / A2A / OpenAPI / Search require an async connection,
541                    // which the synchronous text-agent `build()` cannot perform.
542                    // These belong on a `Live` session; surface that rather than
543                    // dropping the tool silently.
544                    let kind = match deferred {
545                        DeferredTool::Mcp { .. } => "T::mcp",
546                        DeferredTool::A2a { .. } => "T::a2a",
547                        DeferredTool::OpenApi { .. } => "T::openapi",
548                        DeferredTool::Search { .. } => "T::search",
549                    };
550                    tracing::warn!(
551                        tool = kind,
552                        "ignoring async-resolved tool on a text AgentBuilder: {kind} \
553                         requires a Live session (async connect); attach it via Live::with_tools"
554                    );
555                }
556            }
557        }
558        Self::with(inner)
559    }
560
561    /// Attach output guards. Each model response is validated against every
562    /// guard; if any rejects the output the agent run fails with an
563    /// [`AgentError`](gemini_adk_rs::error::AgentError) listing the violations.
564    ///
565    /// Accepts a single guard or a `|`-composed [`GComposite`]:
566    ///
567    /// ```rust,ignore
568    /// use gemini_adk_fluent_rs::compose::guards::G;
569    /// Agent::new("writer").guard(G::pii() | G::length(1, 2000))
570    /// ```
571    ///
572    /// The guards are installed as an `after_model` middleware layer, so they
573    /// accumulate with `.middleware(...)` and honor copy-on-write.
574    pub fn guard(self, guard: impl Into<GComposite>) -> Self {
575        let mut inner = self.mutate();
576        inner.middleware_layers.push(guard.into().into_middleware());
577        Self::with(inner)
578    }
579
580    /// Attach a context policy that rewrites conversation history before each
581    /// model call (e.g. windowing, role filtering, tool-result exclusion).
582    ///
583    /// Accepts a single policy or a `+`-composed [`ContextPolicyChain`]:
584    ///
585    /// ```rust,ignore
586    /// use gemini_adk_fluent_rs::compose::context::C;
587    /// Agent::new("chat").context(C::window(10) + C::user_only())
588    /// ```
589    ///
590    /// The policy is installed as a `transform_request` middleware layer.
591    pub fn context(self, policy: impl Into<ContextPolicyChain>) -> Self {
592        let mut inner = self.mutate();
593        inner
594            .middleware_layers
595            .push(policy.into().into_middleware());
596        Self::with(inner)
597    }
598
599    /// Disallow transfer to peer agents.
600    pub fn no_peers(self) -> Self {
601        self.isolate()
602    }
603
604    /// Attach a [`MiddlewareComposite`] — all layers are installed on the
605    /// compiled `LlmTextAgent` in the order they appear in the composite.
606    ///
607    /// Multiple calls to `.middleware()` accumulate: the new layers are
608    /// appended after any previously registered layers, preserving the
609    /// copy-on-write contract.
610    ///
611    /// ```rust,ignore
612    /// use gemini_adk_fluent_rs::compose::middleware::M;
613    ///
614    /// let agent = AgentBuilder::new("analyst")
615    ///     .instruction("Analyze topics")
616    ///     .middleware(M::log() | M::latency())
617    ///     .build(llm);
618    /// ```
619    pub fn middleware(self, composite: MiddlewareComposite) -> Self {
620        let mut inner = self.mutate();
621        inner.middleware_layers.extend(composite.layers);
622        Self::with(inner)
623    }
624
625    // ── Compilation ──
626
627    /// Compile this builder into an executable `TextAgent`.
628    ///
629    /// The LLM is required because `TextAgent` makes `BaseLlm::generate()` calls.
630    /// Builder configuration (instruction, temperature, tools) is transferred to
631    /// the resulting agent.
632    ///
633    /// ```rust,ignore
634    /// let agent = AgentBuilder::new("analyst")
635    ///     .instruction("Analyze the topic")
636    ///     .temperature(0.3)
637    ///     .build(llm);
638    ///
639    /// let result = agent.run(&state).await?;
640    /// ```
641    pub fn build(self, llm: Arc<dyn BaseLlm>) -> Arc<dyn TextAgent> {
642        let mut agent = LlmTextAgent::new(&self.inner.name, llm);
643
644        if let Some(inst) = &self.inner.instruction {
645            agent = agent.instruction(inst);
646        }
647        if let Some(t) = self.inner.temperature {
648            agent = agent.temperature(t);
649        }
650        if let Some(n) = self.inner.max_output_tokens {
651            agent = agent.max_output_tokens(n);
652        }
653
654        // Build ToolDispatcher from registered tools.
655        if !self.inner.tools.is_empty() {
656            let mut dispatcher = ToolDispatcher::new();
657            for entry in &self.inner.tools {
658                match entry {
659                    ToolEntry::Runtime(t) => {
660                        let kind = t.to_tool_kind();
661                        match kind {
662                            ToolKind::Function(f) => dispatcher.register_function(f),
663                            ToolKind::Streaming(s) => dispatcher.register_streaming(s),
664                            ToolKind::InputStream(i) => dispatcher.register_input_streaming(i),
665                        }
666                    }
667                    ToolEntry::Declaration(_) => {
668                        // Built-in tool declarations (google_search, etc.) are sent
669                        // as-is; they don't have runtime handlers for text dispatch.
670                    }
671                }
672            }
673            if !dispatcher.is_empty() {
674                agent = agent.tools(Arc::new(dispatcher));
675            }
676        }
677
678        // Install middleware layers from the builder.
679        for mw in &self.inner.middleware_layers {
680            agent = agent.add_middleware(mw.clone());
681        }
682
683        Arc::new(agent)
684    }
685}
686
687/// Adapter that wraps an `Arc<dyn ToolFunction>` as a `ToolEntryTrait`.
688#[derive(Clone)]
689struct ToolFunctionEntry(Arc<dyn ToolFunction>);
690
691impl ToolEntryTrait for ToolFunctionEntry {
692    fn name(&self) -> &str {
693        self.0.name()
694    }
695
696    fn to_tool_kind(&self) -> ToolKind {
697        ToolKind::Function(self.0.clone())
698    }
699}
700
701impl std::fmt::Debug for AgentBuilder {
702    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
703        f.debug_struct("AgentBuilder")
704            .field("name", &self.inner.name)
705            .field("model", &self.inner.model)
706            .field("instruction", &self.inner.instruction)
707            .field("temperature", &self.inner.temperature)
708            .field("text_only", &self.is_text_only())
709            .field("tool_count", &self.tool_count())
710            .field("sub_agents", &self.inner.sub_agents.len())
711            .finish()
712    }
713}
714
715#[cfg(test)]
716mod tests {
717    use super::*;
718    use async_trait::async_trait;
719    use gemini_adk_rs::llm::{LlmError, LlmRequest, LlmResponse};
720    use gemini_genai_rs::prelude::{Content, Part, Role};
721
722    /// A mock LLM for build() tests.
723    struct MockLlm(String);
724
725    #[async_trait]
726    impl BaseLlm for MockLlm {
727        fn model_id(&self) -> &str {
728            "mock"
729        }
730        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
731            Ok(LlmResponse {
732                content: Content {
733                    role: Some(Role::Model),
734                    parts: vec![Part::Text {
735                        text: self.0.clone(),
736                    }],
737                },
738                finish_reason: Some("STOP".into()),
739                usage: None,
740            })
741        }
742    }
743
744    #[test]
745    fn builder_creates_with_name() {
746        let b = AgentBuilder::new("test-agent");
747        assert_eq!(b.name(), "test-agent");
748    }
749
750    #[test]
751    fn fluent_chaining_works() {
752        let b = AgentBuilder::new("agent")
753            .instruction("Be helpful")
754            .temperature(0.7)
755            .model(GeminiModel::Gemini2_0FlashLive);
756
757        assert_eq!(b.get_instruction(), Some("Be helpful"));
758        assert_eq!(b.get_temperature(), Some(0.7));
759        assert_eq!(b.get_model(), Some(&GeminiModel::Gemini2_0FlashLive));
760    }
761
762    #[test]
763    fn copy_on_write_clone_independence() {
764        let base = AgentBuilder::new("base").temperature(0.5);
765        let variant = base.clone().temperature(0.9);
766
767        // Original unchanged
768        assert_eq!(base.get_temperature(), Some(0.5));
769        // Variant has new value
770        assert_eq!(variant.get_temperature(), Some(0.9));
771    }
772
773    #[test]
774    fn text_only_sets_modalities() {
775        let b = AgentBuilder::new("text").text_only();
776        assert!(b.is_text_only());
777    }
778
779    #[test]
780    fn url_context_adds_tool() {
781        let b = AgentBuilder::new("search").url_context();
782        assert_eq!(b.tool_count(), 1);
783    }
784
785    #[test]
786    fn google_search_adds_tool() {
787        let b = AgentBuilder::new("search").google_search();
788        assert_eq!(b.tool_count(), 1);
789    }
790
791    #[test]
792    fn code_execution_adds_tool() {
793        let b = AgentBuilder::new("code").code_execution();
794        assert_eq!(b.tool_count(), 1);
795    }
796
797    #[test]
798    fn thinking_sets_budget() {
799        let b = AgentBuilder::new("thinker").thinking(2048);
800        assert_eq!(b.get_thinking_budget(), Some(2048));
801    }
802
803    #[test]
804    fn writes_and_reads_keys() {
805        let b = AgentBuilder::new("data").writes("output").reads("input");
806        assert_eq!(b.get_writes(), &["output"]);
807        assert_eq!(b.get_reads(), &["input"]);
808    }
809
810    #[test]
811    fn sub_agent_registration() {
812        let child = AgentBuilder::new("child");
813        let parent = AgentBuilder::new("parent").sub_agent(child);
814        assert_eq!(parent.get_sub_agents().len(), 1);
815        assert_eq!(parent.get_sub_agents()[0].name(), "child");
816    }
817
818    #[test]
819    fn isolate_and_stay() {
820        let b = AgentBuilder::new("agent").isolate().stay();
821        assert!(b.is_isolated());
822        assert!(b.is_stay());
823    }
824
825    #[test]
826    fn debug_display() {
827        let b = AgentBuilder::new("debug-test");
828        let debug = format!("{:?}", b);
829        assert!(debug.contains("debug-test"));
830    }
831
832    #[test]
833    fn top_p_sets_value() {
834        let b = AgentBuilder::new("agent").top_p(0.95);
835        assert_eq!(b.get_top_p(), Some(0.95));
836    }
837
838    #[test]
839    fn top_k_sets_value() {
840        let b = AgentBuilder::new("agent").top_k(40);
841        assert_eq!(b.get_top_k(), Some(40));
842    }
843
844    #[test]
845    fn max_output_tokens_sets_value() {
846        let b = AgentBuilder::new("agent").max_output_tokens(4096);
847        assert_eq!(b.get_max_output_tokens(), Some(4096));
848    }
849
850    #[test]
851    fn stop_sequences_sets_value() {
852        let b =
853            AgentBuilder::new("agent").stop_sequences(vec!["END".to_string(), "STOP".to_string()]);
854        assert_eq!(b.get_stop_sequences().len(), 2);
855    }
856
857    #[test]
858    fn description_sets_value() {
859        let b = AgentBuilder::new("agent").description("A helpful agent");
860        assert_eq!(b.get_description(), Some("A helpful agent"));
861    }
862
863    #[test]
864    fn output_schema_sets_value() {
865        let schema = serde_json::json!({"type": "object"});
866        let b = AgentBuilder::new("agent").output_schema(schema.clone());
867        assert_eq!(b.get_output_schema(), Some(&schema));
868    }
869
870    #[test]
871    fn transfer_to_sets_value() {
872        let b = AgentBuilder::new("agent").transfer_to("target-agent");
873        assert_eq!(b.get_transfer_to(), Some("target-agent"));
874    }
875
876    #[test]
877    fn full_fluent_chain() {
878        let b = AgentBuilder::new("full-agent")
879            .model(GeminiModel::Gemini2_0FlashLive)
880            .instruction("Be helpful")
881            .temperature(0.7)
882            .top_p(0.95)
883            .top_k(40)
884            .max_output_tokens(4096)
885            .thinking(2048)
886            .description("A fully configured agent")
887            .google_search()
888            .writes("output")
889            .reads("input");
890
891        assert_eq!(b.name(), "full-agent");
892        assert_eq!(b.get_temperature(), Some(0.7));
893        assert_eq!(b.get_top_p(), Some(0.95));
894        assert_eq!(b.get_top_k(), Some(40));
895        assert_eq!(b.get_max_output_tokens(), Some(4096));
896        assert_eq!(b.get_thinking_budget(), Some(2048));
897        assert_eq!(b.get_description(), Some("A fully configured agent"));
898        assert_eq!(b.tool_count(), 1);
899    }
900
901    // ── build() tests ──
902
903    #[tokio::test]
904    async fn build_produces_executable_agent() {
905        let llm: Arc<dyn BaseLlm> = Arc::new(MockLlm("built agent output".into()));
906        let agent = AgentBuilder::new("test")
907            .instruction("Be helpful")
908            .temperature(0.5)
909            .build(llm);
910
911        assert_eq!(agent.name(), "test");
912        let state = gemini_adk_rs::State::new();
913        let result = agent.run(&state).await.unwrap();
914        assert_eq!(result, "built agent output");
915    }
916
917    #[tokio::test]
918    async fn build_stores_output_in_state() {
919        let llm: Arc<dyn BaseLlm> = Arc::new(MockLlm("state output".into()));
920        let agent = AgentBuilder::new("test").build(llm);
921        let state = gemini_adk_rs::State::new();
922        agent.run(&state).await.unwrap();
923        assert_eq!(state.get::<String>("output"), Some("state output".into()));
924    }
925
926    #[tokio::test]
927    async fn build_reads_input_from_state() {
928        use gemini_adk_rs::llm::LlmRequest;
929
930        // An LLM that echoes whatever it receives.
931        struct EchoLlm;
932        #[async_trait]
933        impl BaseLlm for EchoLlm {
934            fn model_id(&self) -> &str {
935                "echo"
936            }
937            async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
938                let text: String = req
939                    .contents
940                    .iter()
941                    .flat_map(|c| &c.parts)
942                    .filter_map(|p| match p {
943                        Part::Text { text } => Some(text.as_str()),
944                        _ => None,
945                    })
946                    .collect::<Vec<_>>()
947                    .join("");
948                Ok(LlmResponse {
949                    content: Content {
950                        role: Some(Role::Model),
951                        parts: vec![Part::Text { text }],
952                    },
953                    finish_reason: Some("STOP".into()),
954                    usage: None,
955                })
956            }
957        }
958
959        let agent = AgentBuilder::new("echo").build(Arc::new(EchoLlm));
960        let state = gemini_adk_rs::State::new();
961        let _ = state.set("input", "hello from state");
962        let result = agent.run(&state).await.unwrap();
963        assert!(result.contains("hello from state"));
964    }
965
966    // ── Middleware end-to-end tests ──
967
968    /// A mock LLM that issues one tool call and then returns text.
969    struct ToolCallingMockLlm {
970        tool_name: &'static str,
971        final_text: &'static str,
972    }
973
974    #[async_trait]
975    impl BaseLlm for ToolCallingMockLlm {
976        fn model_id(&self) -> &str {
977            "tool-mock"
978        }
979
980        async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
981            use gemini_genai_rs::prelude::FunctionCall;
982
983            // If any part is a FunctionResponse, we already dispatched — return text.
984            let already_responded = req
985                .contents
986                .iter()
987                .flat_map(|c| &c.parts)
988                .any(|p| matches!(p, Part::FunctionResponse { .. }));
989
990            if already_responded {
991                Ok(LlmResponse {
992                    content: Content {
993                        role: Some(Role::Model),
994                        parts: vec![Part::Text {
995                            text: self.final_text.to_string(),
996                        }],
997                    },
998                    finish_reason: Some("STOP".into()),
999                    usage: None,
1000                })
1001            } else {
1002                Ok(LlmResponse {
1003                    content: Content {
1004                        role: Some(Role::Model),
1005                        parts: vec![Part::FunctionCall {
1006                            function_call: FunctionCall {
1007                                name: self.tool_name.to_string(),
1008                                args: serde_json::json!({"x": 1}),
1009                                id: Some("call-1".into()),
1010                            },
1011                        }],
1012                    },
1013                    finish_reason: None,
1014                    usage: None,
1015                })
1016            }
1017        }
1018    }
1019
1020    /// Verify that `M::before_model` and `M::after_tool` hooks fire when the agent runs.
1021    #[tokio::test]
1022    async fn middleware_hooks_fire_end_to_end() {
1023        use crate::compose::middleware::M;
1024        use gemini_adk_rs::tool::SimpleTool;
1025        use std::sync::atomic::{AtomicUsize, Ordering};
1026
1027        let before_model_count = Arc::new(AtomicUsize::new(0));
1028        let after_tool_count = Arc::new(AtomicUsize::new(0));
1029
1030        let bm = before_model_count.clone();
1031        let at = after_tool_count.clone();
1032
1033        let mw = M::before_model(move |_req| {
1034            bm.fetch_add(1, Ordering::SeqCst);
1035            Ok(())
1036        }) | M::after_tool(move |_call, _result| {
1037            at.fetch_add(1, Ordering::SeqCst);
1038            Ok(())
1039        });
1040
1041        let llm: Arc<dyn BaseLlm> = Arc::new(ToolCallingMockLlm {
1042            tool_name: "echo_tool",
1043            final_text: "done",
1044        });
1045
1046        let agent = AgentBuilder::new("mw-test")
1047            .middleware(mw)
1048            .tool(Arc::new(SimpleTool::new(
1049                "echo_tool",
1050                "Echo tool",
1051                None,
1052                |_args| async move { Ok(serde_json::json!({"echo": true})) },
1053            )))
1054            .build(llm);
1055
1056        let state = gemini_adk_rs::State::new();
1057        let result = agent.run(&state).await.unwrap();
1058        assert_eq!(result, "done");
1059
1060        // before_model fires once per LLM call: first call (tool call) + second call (final text).
1061        assert_eq!(
1062            before_model_count.load(Ordering::SeqCst),
1063            2,
1064            "before_model should fire for each generate() call"
1065        );
1066        // after_tool fires once per successful tool dispatch.
1067        assert_eq!(
1068            after_tool_count.load(Ordering::SeqCst),
1069            1,
1070            "after_tool should fire once for the tool dispatch"
1071        );
1072    }
1073
1074    /// Verify copy-on-write: adding middleware to a clone does not affect the original.
1075    #[test]
1076    fn middleware_copy_on_write() {
1077        use crate::compose::middleware::M;
1078
1079        let base = AgentBuilder::new("base").instruction("base");
1080        let with_mw = base.clone().middleware(M::log() | M::latency());
1081
1082        // Original should have no middleware layers.
1083        assert_eq!(base.middleware_layer_count(), 0);
1084        // Clone with middleware should have 2 layers.
1085        assert_eq!(with_mw.middleware_layer_count(), 2);
1086    }
1087
1088    /// Verify `on_error` hook fires when the agent errors.
1089    #[tokio::test]
1090    async fn middleware_on_error_fires_on_failure() {
1091        use crate::compose::middleware::M;
1092        use gemini_adk_rs::llm::LlmError;
1093        use std::sync::atomic::{AtomicUsize, Ordering};
1094
1095        let error_count = Arc::new(AtomicUsize::new(0));
1096        let ec = error_count.clone();
1097
1098        let mw = M::on_error(move |_err| {
1099            ec.fetch_add(1, Ordering::SeqCst);
1100            Ok(())
1101        });
1102
1103        struct FailLlm;
1104        #[async_trait]
1105        impl BaseLlm for FailLlm {
1106            fn model_id(&self) -> &str {
1107                "fail"
1108            }
1109            async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
1110                Err(LlmError::RequestFailed("boom".into()))
1111            }
1112        }
1113
1114        let agent = AgentBuilder::new("error-test")
1115            .middleware(mw)
1116            .build(Arc::new(FailLlm));
1117
1118        let state = gemini_adk_rs::State::new();
1119        let result = agent.run(&state).await;
1120        assert!(result.is_err(), "agent should fail");
1121        assert_eq!(
1122            error_count.load(Ordering::SeqCst),
1123            1,
1124            "on_error should fire exactly once"
1125        );
1126    }
1127
1128    // ── Guard / context wiring tests ──
1129
1130    /// A mock LLM that echoes a fixed response and records the number of
1131    /// `contents` it was asked to generate from (to observe context rewriting).
1132    struct RecordingLlm {
1133        text: &'static str,
1134        seen_len: Arc<std::sync::atomic::AtomicUsize>,
1135    }
1136
1137    #[async_trait]
1138    impl BaseLlm for RecordingLlm {
1139        fn model_id(&self) -> &str {
1140            "recording-mock"
1141        }
1142
1143        async fn generate(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
1144            self.seen_len
1145                .store(req.contents.len(), std::sync::atomic::Ordering::SeqCst);
1146            Ok(LlmResponse {
1147                content: Content {
1148                    role: Some(Role::Model),
1149                    parts: vec![Part::Text {
1150                        text: self.text.to_string(),
1151                    }],
1152                },
1153                finish_reason: Some("STOP".into()),
1154                usage: None,
1155            })
1156        }
1157    }
1158
1159    #[tokio::test]
1160    async fn guard_blocks_violating_output() {
1161        use crate::compose::guards::G;
1162
1163        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
1164            text: "you can reach me at agent@example.com",
1165            seen_len: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1166        });
1167
1168        let agent = AgentBuilder::new("guarded").guard(G::pii()).build(llm);
1169
1170        let state = gemini_adk_rs::State::new();
1171        let err = agent.run(&state).await.unwrap_err();
1172        assert!(
1173            err.to_string().contains("guard violation"),
1174            "PII guard should veto the response, got: {err}"
1175        );
1176    }
1177
1178    #[tokio::test]
1179    async fn guard_allows_clean_output() {
1180        use crate::compose::guards::G;
1181
1182        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
1183            text: "all clean here",
1184            seen_len: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
1185        });
1186
1187        let agent = AgentBuilder::new("guarded")
1188            .guard(G::pii() | G::length(1, 1000))
1189            .build(llm);
1190
1191        let state = gemini_adk_rs::State::new();
1192        let result = agent.run(&state).await.unwrap();
1193        assert_eq!(result, "all clean here");
1194    }
1195
1196    #[tokio::test]
1197    async fn context_policy_rewrites_request_history() {
1198        use crate::compose::context::C;
1199
1200        // The agent seeds one user turn; a prepend policy injects a second turn,
1201        // so the LLM should see 2 contents — proving transform_request ran.
1202        let seen = Arc::new(std::sync::atomic::AtomicUsize::new(0));
1203        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
1204            text: "ok",
1205            seen_len: seen.clone(),
1206        });
1207
1208        let agent = AgentBuilder::new("ctx")
1209            .context(C::prepend(Content::user("system preamble")))
1210            .build(llm);
1211
1212        let state = gemini_adk_rs::State::new();
1213        let _ = state.set("input", "hello");
1214        let _ = agent.run(&state).await.unwrap();
1215        assert_eq!(
1216            seen.load(std::sync::atomic::Ordering::SeqCst),
1217            2,
1218            "context policy should have prepended a turn before the model call"
1219        );
1220    }
1221
1222    #[tokio::test]
1223    async fn context_window_trims_history() {
1224        use crate::compose::context::C;
1225
1226        // window(1) keeps only the last turn. We seed a single input turn and
1227        // prepend two extra turns, then window down to 1 — the model sees 1.
1228        let seen = Arc::new(std::sync::atomic::AtomicUsize::new(0));
1229        let llm: Arc<dyn BaseLlm> = Arc::new(RecordingLlm {
1230            text: "ok",
1231            seen_len: seen.clone(),
1232        });
1233
1234        let agent = AgentBuilder::new("ctx")
1235            .context(C::prepend(Content::user("a")) + C::prepend(Content::user("b")) + C::window(1))
1236            .build(llm);
1237
1238        let state = gemini_adk_rs::State::new();
1239        let _ = state.set("input", "hello");
1240        let _ = agent.run(&state).await.unwrap();
1241        assert_eq!(
1242            seen.load(std::sync::atomic::Ordering::SeqCst),
1243            1,
1244            "window(1) should trim history to the last turn"
1245        );
1246    }
1247}