gemini_adk_fluent_rs/compose/
guards.rs

1//! G — Guard composition.
2//!
3//! Compose output guards with `|` for validation and safety checks.
4//!
5//! ## Wiring
6//!
7//! A [`GComposite`] attached via `AgentBuilder::guard` is installed on the
8//! compiled `LlmTextAgent` as an `after_model` middleware layer (see
9//! [`GComposite::into_middleware`]). Every model response is checked against
10//! all guards; if any guard rejects the output the agent run fails with an
11//! [`AgentError`] enumerating the violations, vetoing the response.
12
13use std::sync::Arc;
14
15use async_trait::async_trait;
16use gemini_adk_rs::error::AgentError;
17use gemini_adk_rs::llm::{BaseLlm, LlmRequest, LlmResponse};
18use gemini_adk_rs::middleware::Middleware;
19
20use crate::compose::judge::{render_contents, LlmJudge};
21
22/// A guard that validates agent output.
23#[derive(Clone)]
24pub struct GGuard {
25    name: &'static str,
26    kind: GuardKind,
27}
28
29/// How a guard decides pass/fail.
30#[derive(Clone)]
31enum GuardKind {
32    /// Synchronous predicate over the output text.
33    Sync(#[allow(clippy::type_complexity)] Arc<dyn Fn(&str) -> Result<(), String> + Send + Sync>),
34    /// LLM-as-judge over the output (and, for grounding, the input context).
35    Judge(LlmJudge),
36}
37
38impl GGuard {
39    fn new(
40        name: &'static str,
41        f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
42    ) -> Self {
43        Self {
44            name,
45            kind: GuardKind::Sync(Arc::new(f)),
46        }
47    }
48
49    fn judge(name: &'static str, judge: LlmJudge) -> Self {
50        Self {
51            name,
52            kind: GuardKind::Judge(judge),
53        }
54    }
55
56    /// Name of this guard.
57    pub fn name(&self) -> &str {
58        self.name
59    }
60
61    /// Synchronously check the output. LLM-judge guards cannot run on the sync
62    /// path and always return `Ok(())` here — use [`GGuard::check_async`] (the
63    /// guard middleware uses the async path).
64    pub fn check(&self, output: &str) -> Result<(), String> {
65        match &self.kind {
66            GuardKind::Sync(f) => f(output),
67            GuardKind::Judge(_) => Ok(()),
68        }
69    }
70
71    /// Check the output, running an LLM judge if this is a judge guard.
72    /// `context` is the model's input history (for grounding/hallucination).
73    pub async fn check_async(&self, output: &str, context: Option<&str>) -> Result<(), String> {
74        match &self.kind {
75            GuardKind::Sync(f) => f(output),
76            GuardKind::Judge(judge) => {
77                let verdict = judge.judge(output, context).await;
78                if verdict.flagged {
79                    Err(verdict.reason)
80                } else {
81                    Ok(())
82                }
83            }
84        }
85    }
86}
87
88impl std::fmt::Debug for GGuard {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("GGuard").field("name", &self.name).finish()
91    }
92}
93
94/// Compose two guards with `|`.
95impl std::ops::BitOr for GGuard {
96    type Output = GComposite;
97
98    fn bitor(self, rhs: GGuard) -> Self::Output {
99        GComposite {
100            guards: vec![self, rhs],
101        }
102    }
103}
104
105/// A composite of guards — all must pass for output to be accepted.
106#[derive(Clone)]
107pub struct GComposite {
108    /// The guards in this composite.
109    pub guards: Vec<GGuard>,
110}
111
112impl GComposite {
113    /// Check all guards against the output (sync path; LLM-judge guards are
114    /// skipped — see [`GComposite::check_all_async`]). Returns all violations.
115    pub fn check_all(&self, output: &str) -> Vec<String> {
116        self.guards
117            .iter()
118            .filter_map(|g| g.check(output).err())
119            .collect()
120    }
121
122    /// Check all guards, running LLM-judge guards against `output` and the
123    /// optional input `context`. Returns all violations as `name: reason`.
124    pub async fn check_all_async(&self, output: &str, context: Option<&str>) -> Vec<String> {
125        let mut violations = Vec::new();
126        for g in &self.guards {
127            if let Err(reason) = g.check_async(output, context).await {
128                violations.push(format!("{}: {}", g.name(), reason));
129            }
130        }
131        violations
132    }
133
134    /// Number of guards.
135    pub fn len(&self) -> usize {
136        self.guards.len()
137    }
138
139    /// Whether empty.
140    pub fn is_empty(&self) -> bool {
141        self.guards.is_empty()
142    }
143}
144
145impl std::ops::BitOr<GGuard> for GComposite {
146    type Output = GComposite;
147
148    fn bitor(mut self, rhs: GGuard) -> Self::Output {
149        self.guards.push(rhs);
150        self
151    }
152}
153
154/// A single guard is a one-element composite, so `.guard(G::pii())` works
155/// without an explicit `| `.
156impl From<GGuard> for GComposite {
157    fn from(guard: GGuard) -> Self {
158        GComposite {
159            guards: vec![guard],
160        }
161    }
162}
163
164impl GComposite {
165    /// Adapt this guard composite into an `after_model` middleware layer that
166    /// vetoes any model response failing one or more guards.
167    pub fn into_middleware(self) -> Arc<dyn Middleware> {
168        Arc::new(GuardMiddleware { guards: self })
169    }
170}
171
172/// Middleware adapter that enforces a [`GComposite`] on every model response.
173struct GuardMiddleware {
174    guards: GComposite,
175}
176
177#[async_trait]
178impl Middleware for GuardMiddleware {
179    fn name(&self) -> &str {
180        "guard"
181    }
182
183    async fn after_model(
184        &self,
185        request: &LlmRequest,
186        response: &LlmResponse,
187    ) -> Result<Option<LlmResponse>, AgentError> {
188        // Render the input history so grounding/hallucination judges can see
189        // what the response is supposed to be consistent with.
190        let context = render_contents(&request.contents);
191        let violations = self
192            .guards
193            .check_all_async(&response.text(), Some(&context))
194            .await;
195        if violations.is_empty() {
196            Ok(None)
197        } else {
198            Err(AgentError::Other(format!(
199                "guard violation: {}",
200                violations.join("; ")
201            )))
202        }
203    }
204}
205
/// Factory namespace for guards — `G::json()`, `G::length(1, 500)`, and so on.
///
/// `G` is a stateless unit struct; it only groups the guard constructors.
pub struct G;
208
209impl G {
210    /// Length guard — output must be within bounds.
211    pub fn length(min: usize, max: usize) -> GGuard {
212        GGuard::new("length", move |output| {
213            let len = output.len();
214            if len < min {
215                Err(format!("Output too short: {} < {}", len, min))
216            } else if len > max {
217                Err(format!("Output too long: {} > {}", len, max))
218            } else {
219                Ok(())
220            }
221        })
222    }
223
224    /// Regex guard — output must match (or not match) a pattern.
225    pub fn regex(pattern: &str) -> GGuard {
226        let pattern = pattern.to_string();
227        GGuard::new("regex", move |output| {
228            // Simple substring check — full regex requires the `regex` crate.
229            if output.contains(&pattern) {
230                Err(format!("Output matches forbidden pattern: {}", pattern))
231            } else {
232                Ok(())
233            }
234        })
235    }
236
237    /// Budget guard — output must not exceed a token estimate.
238    pub fn budget(max_tokens: usize) -> GGuard {
239        GGuard::new("budget", move |output| {
240            // Rough estimate: 4 chars per token.
241            let estimated_tokens = output.len() / 4;
242            if estimated_tokens > max_tokens {
243                Err(format!(
244                    "Output exceeds token budget: ~{} > {}",
245                    estimated_tokens, max_tokens
246                ))
247            } else {
248                Ok(())
249            }
250        })
251    }
252
253    /// JSON guard — output must be valid JSON.
254    pub fn json() -> GGuard {
255        GGuard::new("json", |output| {
256            serde_json::from_str::<serde_json::Value>(output)
257                .map(|_| ())
258                .map_err(|e| format!("Invalid JSON: {}", e))
259        })
260    }
261
262    /// Max turns guard — placeholder for turn limit enforcement.
263    pub fn max_turns(n: u32) -> GGuard {
264        GGuard::new("max_turns", move |_output| {
265            // Turn counting happens at runtime, not at output validation.
266            let _ = n;
267            Ok(())
268        })
269    }
270
271    /// PII guard — checks for common PII patterns (email, phone).
272    pub fn pii() -> GGuard {
273        GGuard::new("pii", |output| {
274            // Simple heuristic checks for common PII patterns.
275            if output.contains('@') && output.contains('.') {
276                // Might be an email — flag it.
277                return Err("Output may contain email addresses".to_string());
278            }
279            Ok(())
280        })
281    }
282
283    /// Topic restriction guard — output must not mention denied topics.
284    pub fn topic(deny: &[&str]) -> GGuard {
285        let deny: Vec<String> = deny.iter().map(|s| s.to_lowercase()).collect();
286        GGuard::new("topic", move |output| {
287            let lower = output.to_lowercase();
288            for topic in &deny {
289                if lower.contains(topic.as_str()) {
290                    return Err(format!("Output mentions denied topic: {}", topic));
291                }
292            }
293            Ok(())
294        })
295    }
296
297    /// Custom guard from a validation function.
298    pub fn custom(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
299        GGuard::new("custom", f)
300    }
301
302    /// Output guard — validates model output content via a predicate function.
303    pub fn output(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
304        GGuard::new("output", f)
305    }
306
307    /// Input guard — validates user input content via a predicate function.
308    pub fn input(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
309        GGuard::new("input", f)
310    }
311
312    /// Rate limiting guard — enforces a maximum number of checks per minute.
313    pub fn rate_limit(max_per_minute: u32) -> GGuard {
314        GGuard::new("rate_limit", move |_output| {
315            // Rate limiting is enforced at runtime by the processor.
316            let _ = max_per_minute;
317            Ok(())
318        })
319    }
320
321    /// Toxicity guard — flags toxic/abusive output using an LLM judge.
322    ///
323    /// Vetoes the response if the judge model decides it contains toxic, hateful,
324    /// harassing, sexual, or abusive content (mirrors ADK's safety evaluation,
325    /// but runs locally against the provided judge LLM).
326    pub fn toxicity(judge: Arc<dyn BaseLlm>) -> GGuard {
327        GGuard::judge(
328            "toxicity",
329            LlmJudge::new(
330                judge,
331                "the RESPONSE contains toxic, hateful, harassing, sexual, or otherwise \
332                 abusive or unsafe content",
333            ),
334        )
335    }
336
337    /// Grounding guard — flags output not supported by the conversation context.
338    ///
339    /// The judge sees the model's input history as CONTEXT and vetoes the
340    /// response if it makes factual claims not supported by that context.
341    pub fn grounded(judge: Arc<dyn BaseLlm>) -> GGuard {
342        GGuard::judge(
343            "grounded",
344            LlmJudge::new(
345                judge,
346                "the RESPONSE asserts facts that are NOT supported by, or that \
347                 contradict, the provided CONTEXT",
348            )
349            .with_context("CONTEXT"),
350        )
351    }
352
353    /// Hallucination guard — flags fabricated/unverifiable claims via an LLM judge.
354    pub fn hallucination(judge: Arc<dyn BaseLlm>) -> GGuard {
355        GGuard::judge(
356            "hallucination",
357            LlmJudge::new(
358                judge,
359                "the RESPONSE contains fabricated, invented, or unverifiable facts \
360                 that are not supported by the CONTEXT",
361            )
362            .with_context("CONTEXT"),
363        )
364    }
365
366    /// Conditional guard — only applies `inner` when `predicate` returns true.
367    pub fn when(predicate: impl Fn(&str) -> bool + Send + Sync + 'static, inner: GGuard) -> GGuard {
368        GGuard::new("when", move |output| {
369            if predicate(output) {
370                inner.check(output)
371            } else {
372                Ok(())
373            }
374        })
375    }
376
377    /// LLM-as-judge content guard.
378    ///
379    /// `rubric` describes the condition that constitutes a *violation*; the judge
380    /// model vetoes the response when that condition holds. Example:
381    /// `G::llm_judge(llm, "the response gives medical advice without a disclaimer")`.
382    pub fn llm_judge(judge: Arc<dyn BaseLlm>, rubric: impl Into<String>) -> GGuard {
383        GGuard::judge("llm_judge", LlmJudge::new(judge, rubric))
384    }
385
386    /// Named custom judge function guard.
387    pub fn custom_judge(
388        name: &str,
389        f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
390    ) -> GGuard {
391        // Leak the name to get a 'static str, matching the GGuard field type.
392        let name: &'static str = Box::leak(name.to_string().into_boxed_str());
393        GGuard::new(name, f)
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// Verdict JSON a judge returns when the response is acceptable.
    const NO_VIOLATION: &str = r#"{"violation": false, "reason": "ok"}"#;
    /// Verdict JSON a judge returns when the response should be vetoed.
    const VIOLATION: &str = r#"{"violation": true, "reason": "bad"}"#;

    /// Build a fake judge LLM that answers every request with `reply`.
    fn scripted_judge(model_id: &'static str, reply: &'static str) -> Arc<dyn BaseLlm> {
        use gemini_adk_rs::llm::{LlmError, LlmResponse};
        use gemini_genai_rs::prelude::{Content, Part, Role};

        struct ScriptedJudge {
            model_id: &'static str,
            reply: &'static str,
        }

        #[async_trait]
        impl BaseLlm for ScriptedJudge {
            fn model_id(&self) -> &str {
                self.model_id
            }
            async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
                Ok(LlmResponse {
                    content: Content {
                        role: Some(Role::Model),
                        parts: vec![Part::Text {
                            text: self.reply.to_string(),
                        }],
                    },
                    finish_reason: Some("STOP".into()),
                    usage: None,
                })
            }
        }

        Arc::new(ScriptedJudge { model_id, reply })
    }

    // A no-op judge LLM for constructing LLM-backed guards in unit tests
    // (these tests exercise composition/naming, not the judge call itself).
    fn judge_llm() -> Arc<dyn BaseLlm> {
        scripted_judge("noop-judge", NO_VIOLATION)
    }

    // A judge that flags every response it sees.
    fn flag_all_llm() -> Arc<dyn BaseLlm> {
        scripted_judge("flag-all", VIOLATION)
    }

    #[test]
    fn length_guard_passes() {
        assert!(G::length(1, 100).check("hello").is_ok());
    }

    #[test]
    fn length_guard_too_short() {
        assert!(G::length(10, 100).check("hi").is_err());
    }

    #[test]
    fn length_guard_too_long() {
        assert!(G::length(1, 5).check("too long text").is_err());
    }

    #[test]
    fn json_guard_valid() {
        assert!(G::json().check(r#"{"key": "value"}"#).is_ok());
    }

    #[test]
    fn json_guard_invalid() {
        assert!(G::json().check("not json").is_err());
    }

    #[test]
    fn regex_guard_blocks() {
        assert!(G::regex("secret").check("this is a secret").is_err());
    }

    #[test]
    fn regex_guard_passes() {
        assert!(G::regex("secret").check("this is public").is_ok());
    }

    #[test]
    fn budget_guard_passes() {
        assert!(G::budget(100).check("short").is_ok());
    }

    #[test]
    fn topic_guard_blocks() {
        assert!(G::topic(&["violence"]).check("There was violence").is_err());
    }

    #[test]
    fn topic_guard_passes() {
        assert!(G::topic(&["violence"]).check("A peaceful day").is_ok());
    }

    #[test]
    fn compose_with_bitor() {
        let composite = G::length(1, 1000) | G::json();
        assert_eq!(composite.len(), 2);
    }

    #[test]
    fn single_guard_converts_to_composite() {
        let composite = GComposite::from(G::json());
        assert_eq!(composite.len(), 1);
        assert!(!composite.is_empty());
        assert_eq!(composite.check_all("not json").len(), 1);
    }

    #[test]
    fn check_all_returns_violations() {
        let composite = G::length(1, 5) | G::json();
        let violations = composite.check_all("not json and too long text here");
        assert!(!violations.is_empty());
    }

    #[test]
    fn custom_guard() {
        let g = G::custom(|output| {
            if output.contains("bad") {
                Err("Contains 'bad'".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("good output").is_ok());
        assert!(g.check("bad output").is_err());
    }

    #[test]
    fn output_guard() {
        let g = G::output(|output| {
            if output.contains("forbidden") {
                Err("Forbidden content".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("safe content").is_ok());
        assert!(g.check("forbidden content").is_err());
        assert_eq!(g.name(), "output");
    }

    #[test]
    fn input_guard() {
        let g = G::input(|input| {
            if input.is_empty() {
                Err("Empty input".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("hello").is_ok());
        assert!(g.check("").is_err());
        assert_eq!(g.name(), "input");
    }

    #[test]
    fn rate_limit_guard() {
        let g = G::rate_limit(60);
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "rate_limit");
    }

    #[test]
    fn toxicity_guard() {
        let g = G::toxicity(judge_llm());
        // Sync path is a no-op for judge guards.
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "toxicity");
    }

    #[test]
    fn grounded_guard() {
        let g = G::grounded(judge_llm());
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "grounded");
    }

    #[test]
    fn hallucination_guard() {
        let g = G::hallucination(judge_llm());
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "hallucination");
    }

    #[tokio::test]
    async fn judge_guard_runs_async() {
        // A judge that flags everything should produce a violation via check_async.
        let g = G::toxicity(flag_all_llm());
        assert!(g.check_async("hello", None).await.is_err());
    }

    #[tokio::test]
    async fn check_all_async_labels_every_violation() {
        // Both a sync and a judge guard fail; each violation is `name: reason`.
        let composite = G::length(1, 5) | G::toxicity(flag_all_llm());
        let violations = composite.check_all_async("too long text", None).await;
        assert_eq!(violations.len(), 2);
        assert!(violations[0].starts_with("length: "));
        assert!(violations[1].starts_with("toxicity: "));
    }

    #[test]
    fn when_guard_applies() {
        let inner = G::length(1, 5);
        let g = G::when(|output| output.starts_with("check:"), inner);
        // Predicate true — inner guard runs and rejects long output.
        assert!(g.check("check: this is way too long").is_err());
        // Predicate false — inner guard skipped.
        assert!(g.check("skip: this is way too long").is_ok());
        assert_eq!(g.name(), "when");
    }

    #[test]
    fn llm_judge_guard() {
        let g = G::llm_judge(judge_llm(), "the response is unhelpful");
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "llm_judge");
    }

    #[test]
    fn custom_judge_guard() {
        let g = G::custom_judge("profanity_filter", |output| {
            if output.contains("bad_word") {
                Err("Profanity detected".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("clean text").is_ok());
        assert!(g.check("has bad_word here").is_err());
        assert_eq!(g.name(), "profanity_filter");
    }

    #[test]
    fn compose_new_guards_with_bitor() {
        let composite =
            G::toxicity(judge_llm()) | G::grounded(judge_llm()) | G::hallucination(judge_llm());
        assert_eq!(composite.len(), 3);
        // Sync path skips judge guards, so no violations surface synchronously.
        assert!(composite.check_all("test").is_empty());
    }
}
630}