gemini_adk_fluent_rs/compose/
guards.rs

//! G — Guard composition.
//!
//! Build output guards with the `G` factory and compose them with `|`; every
//! guard in a composite must pass for the output to be accepted.

5use std::sync::Arc;
6
/// Signature shared by every guard check: `Ok(())` accepts the output and
/// `Err(reason)` rejects it with a human-readable reason.
///
/// Named so the `Arc<dyn Fn…>` field below stays readable without having to
/// silence `clippy::type_complexity`.
type GuardFn = dyn Fn(&str) -> Result<(), String> + Send + Sync;

/// A named guard that validates agent output.
///
/// Cloning is cheap: the check function lives behind an [`Arc`], so clones
/// share the same closure.
#[derive(Clone)]
pub struct GGuard {
    /// Identifier returned by [`GGuard::name`] and shown in `Debug` output.
    name: &'static str,
    /// Validation function invoked by [`GGuard::check`].
    checker: Arc<GuardFn>,
}

impl GGuard {
    /// Wraps `f` as a guard called `name`.
    fn new(
        name: &'static str,
        f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
    ) -> Self {
        Self {
            name,
            checker: Arc::new(f),
        }
    }

    /// Name of this guard.
    ///
    /// The returned string is `'static`, so it may outlive the guard itself.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Runs the guard against `output`.
    ///
    /// # Errors
    ///
    /// Returns `Err(reason)` when the guard rejects `output`, `Ok(())` otherwise.
    pub fn check(&self, output: &str) -> Result<(), String> {
        (self.checker)(output)
    }
}

/// Shows only the guard's name; the check closure has no meaningful `Debug` form.
impl std::fmt::Debug for GGuard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GGuard").field("name", &self.name).finish()
    }
}

43/// Compose two guards with `|`.
44impl std::ops::BitOr for GGuard {
45    type Output = GComposite;
46
47    fn bitor(self, rhs: GGuard) -> Self::Output {
48        GComposite {
49            guards: vec![self, rhs],
50        }
51    }
52}
53
54/// A composite of guards — all must pass for output to be accepted.
55#[derive(Clone)]
56pub struct GComposite {
57    /// The guards in this composite.
58    pub guards: Vec<GGuard>,
59}
60
61impl GComposite {
62    /// Check all guards against the output. Returns all violations.
63    pub fn check_all(&self, output: &str) -> Vec<String> {
64        self.guards
65            .iter()
66            .filter_map(|g| g.check(output).err())
67            .collect()
68    }
69
70    /// Number of guards.
71    pub fn len(&self) -> usize {
72        self.guards.len()
73    }
74
75    /// Whether empty.
76    pub fn is_empty(&self) -> bool {
77        self.guards.is_empty()
78    }
79}
80
81impl std::ops::BitOr<GGuard> for GComposite {
82    type Output = GComposite;
83
84    fn bitor(mut self, rhs: GGuard) -> Self::Output {
85        self.guards.push(rhs);
86        self
87    }
88}
89
/// The `G` namespace — static factory methods for guards.
///
/// `G` carries no data. Call its associated functions to build guards and
/// combine the results with `|`.
#[derive(Debug, Clone, Copy, Default)]
pub struct G;

93impl G {
94    /// Length guard — output must be within bounds.
95    pub fn length(min: usize, max: usize) -> GGuard {
96        GGuard::new("length", move |output| {
97            let len = output.len();
98            if len < min {
99                Err(format!("Output too short: {} < {}", len, min))
100            } else if len > max {
101                Err(format!("Output too long: {} > {}", len, max))
102            } else {
103                Ok(())
104            }
105        })
106    }
107
108    /// Regex guard — output must match (or not match) a pattern.
109    pub fn regex(pattern: &str) -> GGuard {
110        let pattern = pattern.to_string();
111        GGuard::new("regex", move |output| {
112            // Simple substring check — full regex requires the `regex` crate.
113            if output.contains(&pattern) {
114                Err(format!("Output matches forbidden pattern: {}", pattern))
115            } else {
116                Ok(())
117            }
118        })
119    }
120
121    /// Budget guard — output must not exceed a token estimate.
122    pub fn budget(max_tokens: usize) -> GGuard {
123        GGuard::new("budget", move |output| {
124            // Rough estimate: 4 chars per token.
125            let estimated_tokens = output.len() / 4;
126            if estimated_tokens > max_tokens {
127                Err(format!(
128                    "Output exceeds token budget: ~{} > {}",
129                    estimated_tokens, max_tokens
130                ))
131            } else {
132                Ok(())
133            }
134        })
135    }
136
137    /// JSON guard — output must be valid JSON.
138    pub fn json() -> GGuard {
139        GGuard::new("json", |output| {
140            serde_json::from_str::<serde_json::Value>(output)
141                .map(|_| ())
142                .map_err(|e| format!("Invalid JSON: {}", e))
143        })
144    }
145
146    /// Max turns guard — placeholder for turn limit enforcement.
147    pub fn max_turns(n: u32) -> GGuard {
148        GGuard::new("max_turns", move |_output| {
149            // Turn counting happens at runtime, not at output validation.
150            let _ = n;
151            Ok(())
152        })
153    }
154
155    /// PII guard — checks for common PII patterns (email, phone).
156    pub fn pii() -> GGuard {
157        GGuard::new("pii", |output| {
158            // Simple heuristic checks for common PII patterns.
159            if output.contains('@') && output.contains('.') {
160                // Might be an email — flag it.
161                return Err("Output may contain email addresses".to_string());
162            }
163            Ok(())
164        })
165    }
166
167    /// Topic restriction guard — output must not mention denied topics.
168    pub fn topic(deny: &[&str]) -> GGuard {
169        let deny: Vec<String> = deny.iter().map(|s| s.to_lowercase()).collect();
170        GGuard::new("topic", move |output| {
171            let lower = output.to_lowercase();
172            for topic in &deny {
173                if lower.contains(topic.as_str()) {
174                    return Err(format!("Output mentions denied topic: {}", topic));
175                }
176            }
177            Ok(())
178        })
179    }
180
181    /// Custom guard from a validation function.
182    pub fn custom(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
183        GGuard::new("custom", f)
184    }
185
186    /// Output guard — validates model output content via a predicate function.
187    pub fn output(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
188        GGuard::new("output", f)
189    }
190
191    /// Input guard — validates user input content via a predicate function.
192    pub fn input(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
193        GGuard::new("input", f)
194    }
195
196    /// Rate limiting guard — enforces a maximum number of checks per minute.
197    pub fn rate_limit(max_per_minute: u32) -> GGuard {
198        GGuard::new("rate_limit", move |_output| {
199            // Rate limiting is enforced at runtime by the processor.
200            let _ = max_per_minute;
201            Ok(())
202        })
203    }
204
205    /// Toxicity detection guard — placeholder for toxicity classification.
206    pub fn toxicity() -> GGuard {
207        GGuard::new("toxicity", |_output| {
208            // Toxicity detection requires an external classifier at runtime.
209            Ok(())
210        })
211    }
212
213    /// Grounding check guard — placeholder for grounding verification.
214    pub fn grounded() -> GGuard {
215        GGuard::new("grounded", |_output| {
216            // Grounding checks require external verification at runtime.
217            Ok(())
218        })
219    }
220
221    /// Hallucination detection guard — placeholder for hallucination detection.
222    pub fn hallucination() -> GGuard {
223        GGuard::new("hallucination", |_output| {
224            // Hallucination detection requires external verification at runtime.
225            Ok(())
226        })
227    }
228
229    /// Conditional guard — only applies `inner` when `predicate` returns true.
230    pub fn when(predicate: impl Fn(&str) -> bool + Send + Sync + 'static, inner: GGuard) -> GGuard {
231        GGuard::new("when", move |output| {
232            if predicate(output) {
233                inner.check(output)
234            } else {
235                Ok(())
236            }
237        })
238    }
239
240    /// LLM-as-judge content guard — stores a prompt for later LLM evaluation.
241    pub fn llm_judge(prompt: &str) -> GGuard {
242        let prompt = prompt.to_string();
243        GGuard::new("llm_judge", move |_output| {
244            // LLM judge evaluation happens at runtime with access to the LLM.
245            let _ = &prompt;
246            Ok(())
247        })
248    }
249
250    /// Named custom judge function guard.
251    pub fn custom_judge(
252        name: &str,
253        f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
254    ) -> GGuard {
255        // Leak the name to get a 'static str, matching the GGuard field type.
256        let name: &'static str = Box::leak(name.to_string().into_boxed_str());
257        GGuard::new(name, f)
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn length_guard_passes() {
        assert!(G::length(1, 100).check("hello").is_ok());
    }

    #[test]
    fn length_guard_too_short() {
        assert!(G::length(10, 100).check("hi").is_err());
    }

    #[test]
    fn length_guard_too_long() {
        assert!(G::length(1, 5).check("too long text").is_err());
    }

    #[test]
    fn length_guard_bounds_are_inclusive() {
        assert!(G::length(5, 5).check("hello").is_ok());
        assert_eq!(
            G::length(10, 100).check("hi"),
            Err("Output too short: 2 < 10".to_string())
        );
    }

    #[test]
    fn json_guard_valid() {
        assert!(G::json().check(r#"{"key": "value"}"#).is_ok());
    }

    #[test]
    fn json_guard_invalid() {
        assert!(G::json().check("not json").is_err());
    }

    #[test]
    fn json_guard_reports_parse_error() {
        let err = G::json().check("not json").unwrap_err();
        assert!(err.starts_with("Invalid JSON"));
    }

    #[test]
    fn regex_guard_blocks() {
        assert!(G::regex("secret").check("this is a secret").is_err());
    }

    #[test]
    fn regex_guard_passes() {
        assert!(G::regex("secret").check("this is public").is_ok());
    }

    #[test]
    fn regex_guard_is_case_sensitive() {
        assert!(G::regex("Secret").check("this is a secret").is_ok());
    }

    #[test]
    fn budget_guard_passes() {
        assert!(G::budget(100).check("short").is_ok());
    }

    #[test]
    fn budget_guard_blocks() {
        // 40 chars ≈ 10 tokens (at the limit); 44 chars ≈ 11 tokens (over it).
        assert!(G::budget(10).check(&"x".repeat(40)).is_ok());
        assert_eq!(
            G::budget(10).check(&"x".repeat(44)),
            Err("Output exceeds token budget: ~11 > 10".to_string())
        );
    }

    #[test]
    fn pii_guard_flags_email() {
        let g = G::pii();
        assert!(g.check("Contact jane.doe@example.com for details").is_err());
        assert!(g.check("hello world").is_ok());
        assert_eq!(g.name(), "pii");
    }

    #[test]
    fn topic_guard_blocks() {
        assert!(G::topic(&["violence"]).check("There was violence").is_err());
    }

    #[test]
    fn topic_guard_passes() {
        assert!(G::topic(&["violence"]).check("A peaceful day").is_ok());
    }

    #[test]
    fn topic_guard_is_case_insensitive() {
        assert_eq!(
            G::topic(&["Violence"]).check("VIOLENCE ahead"),
            Err("Output mentions denied topic: violence".to_string())
        );
    }

    #[test]
    fn max_turns_guard() {
        let g = G::max_turns(3);
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "max_turns");
    }

    #[test]
    fn compose_with_bitor() {
        let composite = G::length(1, 1000) | G::json();
        assert_eq!(composite.len(), 2);
    }

    #[test]
    fn check_all_returns_violations() {
        let composite = G::length(1, 5) | G::json();
        let violations = composite.check_all("not json and too long text here");
        // Both guards fail, reported in composition order.
        assert_eq!(violations.len(), 2);
        assert!(violations[0].starts_with("Output too long"));
        assert!(violations[1].starts_with("Invalid JSON"));
    }

    #[test]
    fn check_all_empty_when_all_pass() {
        let composite = G::length(1, 100) | G::json();
        assert!(composite.check_all("{}").is_empty());
    }

    #[test]
    fn custom_guard() {
        let g = G::custom(|output| {
            if output.contains("bad") {
                Err("Contains 'bad'".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("good output").is_ok());
        assert!(g.check("bad output").is_err());
    }

    #[test]
    fn output_guard() {
        let g = G::output(|output| {
            if output.contains("forbidden") {
                Err("Forbidden content".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("safe content").is_ok());
        assert!(g.check("forbidden content").is_err());
        assert_eq!(g.name(), "output");
    }

    #[test]
    fn input_guard() {
        let g = G::input(|input| {
            if input.is_empty() {
                Err("Empty input".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("hello").is_ok());
        assert!(g.check("").is_err());
        assert_eq!(g.name(), "input");
    }

    #[test]
    fn rate_limit_guard() {
        let g = G::rate_limit(60);
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "rate_limit");
    }

    #[test]
    fn toxicity_guard() {
        let g = G::toxicity();
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "toxicity");
    }

    #[test]
    fn grounded_guard() {
        let g = G::grounded();
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "grounded");
    }

    #[test]
    fn hallucination_guard() {
        let g = G::hallucination();
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "hallucination");
    }

    #[test]
    fn when_guard_applies() {
        let inner = G::length(1, 5);
        let g = G::when(|output| output.starts_with("check:"), inner);
        // Predicate true — inner guard runs and rejects long output.
        assert!(g.check("check: this is way too long").is_err());
        // Predicate false — inner guard skipped.
        assert!(g.check("skip: this is way too long").is_ok());
        assert_eq!(g.name(), "when");
    }

    #[test]
    fn llm_judge_guard() {
        let g = G::llm_judge("Is this response helpful?");
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "llm_judge");
    }

    #[test]
    fn custom_judge_guard() {
        let g = G::custom_judge("profanity_filter", |output| {
            if output.contains("bad_word") {
                Err("Profanity detected".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("clean text").is_ok());
        assert!(g.check("has bad_word here").is_err());
        assert_eq!(g.name(), "profanity_filter");
    }

    #[test]
    fn compose_new_guards_with_bitor() {
        let composite = G::toxicity() | G::grounded() | G::hallucination();
        assert_eq!(composite.len(), 3);
        assert!(composite.check_all("test").is_empty());
    }
}