gemini_adk_fluent_rs/compose/
guards.rs1use std::sync::Arc;
/// Signature shared by every guard check: inspect a piece of text and return
/// `Err(message)` to report a violation, `Ok(())` otherwise.
type GuardCheck = dyn Fn(&str) -> Result<(), String> + Send + Sync;

/// A single named validation rule applied to a piece of text.
///
/// The check closure lives behind an [`Arc`], so cloning a `GGuard` is cheap
/// and all clones share the same closure.
#[derive(Clone)]
pub struct GGuard {
    /// Short identifier for the guard (e.g. `"length"`, `"json"`).
    name: &'static str,
    /// The validation closure run on each checked string.
    checker: Arc<GuardCheck>,
}

15impl GGuard {
16 fn new(
17 name: &'static str,
18 f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
19 ) -> Self {
20 Self {
21 name,
22 checker: Arc::new(f),
23 }
24 }
25
26 pub fn name(&self) -> &str {
28 self.name
29 }
30
31 pub fn check(&self, output: &str) -> Result<(), String> {
33 (self.checker)(output)
34 }
35}
36
37impl std::fmt::Debug for GGuard {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("GGuard").field("name", &self.name).finish()
40 }
41}
42
43impl std::ops::BitOr for GGuard {
45 type Output = GComposite;
46
47 fn bitor(self, rhs: GGuard) -> Self::Output {
48 GComposite {
49 guards: vec![self, rhs],
50 }
51 }
52}
53
54#[derive(Clone)]
56pub struct GComposite {
57 pub guards: Vec<GGuard>,
59}
60
61impl GComposite {
62 pub fn check_all(&self, output: &str) -> Vec<String> {
64 self.guards
65 .iter()
66 .filter_map(|g| g.check(output).err())
67 .collect()
68 }
69
70 pub fn len(&self) -> usize {
72 self.guards.len()
73 }
74
75 pub fn is_empty(&self) -> bool {
77 self.guards.is_empty()
78 }
79}
80
81impl std::ops::BitOr<GGuard> for GComposite {
82 type Output = GComposite;
83
84 fn bitor(mut self, rhs: GGuard) -> Self::Output {
85 self.guards.push(rhs);
86 self
87 }
88}
89
/// Namespace for the built-in guard constructors (`G::length`, `G::json`, …).
///
/// It carries no data. The derives exist only so the public type meets the
/// usual `Debug`/`Clone` expectations.
#[derive(Debug, Clone, Copy, Default)]
pub struct G;

93impl G {
94 pub fn length(min: usize, max: usize) -> GGuard {
96 GGuard::new("length", move |output| {
97 let len = output.len();
98 if len < min {
99 Err(format!("Output too short: {} < {}", len, min))
100 } else if len > max {
101 Err(format!("Output too long: {} > {}", len, max))
102 } else {
103 Ok(())
104 }
105 })
106 }
107
108 pub fn regex(pattern: &str) -> GGuard {
110 let pattern = pattern.to_string();
111 GGuard::new("regex", move |output| {
112 if output.contains(&pattern) {
114 Err(format!("Output matches forbidden pattern: {}", pattern))
115 } else {
116 Ok(())
117 }
118 })
119 }
120
121 pub fn budget(max_tokens: usize) -> GGuard {
123 GGuard::new("budget", move |output| {
124 let estimated_tokens = output.len() / 4;
126 if estimated_tokens > max_tokens {
127 Err(format!(
128 "Output exceeds token budget: ~{} > {}",
129 estimated_tokens, max_tokens
130 ))
131 } else {
132 Ok(())
133 }
134 })
135 }
136
137 pub fn json() -> GGuard {
139 GGuard::new("json", |output| {
140 serde_json::from_str::<serde_json::Value>(output)
141 .map(|_| ())
142 .map_err(|e| format!("Invalid JSON: {}", e))
143 })
144 }
145
146 pub fn max_turns(n: u32) -> GGuard {
148 GGuard::new("max_turns", move |_output| {
149 let _ = n;
151 Ok(())
152 })
153 }
154
155 pub fn pii() -> GGuard {
157 GGuard::new("pii", |output| {
158 if output.contains('@') && output.contains('.') {
160 return Err("Output may contain email addresses".to_string());
162 }
163 Ok(())
164 })
165 }
166
167 pub fn topic(deny: &[&str]) -> GGuard {
169 let deny: Vec<String> = deny.iter().map(|s| s.to_lowercase()).collect();
170 GGuard::new("topic", move |output| {
171 let lower = output.to_lowercase();
172 for topic in &deny {
173 if lower.contains(topic.as_str()) {
174 return Err(format!("Output mentions denied topic: {}", topic));
175 }
176 }
177 Ok(())
178 })
179 }
180
181 pub fn custom(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
183 GGuard::new("custom", f)
184 }
185
186 pub fn output(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
188 GGuard::new("output", f)
189 }
190
191 pub fn input(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
193 GGuard::new("input", f)
194 }
195
196 pub fn rate_limit(max_per_minute: u32) -> GGuard {
198 GGuard::new("rate_limit", move |_output| {
199 let _ = max_per_minute;
201 Ok(())
202 })
203 }
204
205 pub fn toxicity() -> GGuard {
207 GGuard::new("toxicity", |_output| {
208 Ok(())
210 })
211 }
212
213 pub fn grounded() -> GGuard {
215 GGuard::new("grounded", |_output| {
216 Ok(())
218 })
219 }
220
221 pub fn hallucination() -> GGuard {
223 GGuard::new("hallucination", |_output| {
224 Ok(())
226 })
227 }
228
229 pub fn when(predicate: impl Fn(&str) -> bool + Send + Sync + 'static, inner: GGuard) -> GGuard {
231 GGuard::new("when", move |output| {
232 if predicate(output) {
233 inner.check(output)
234 } else {
235 Ok(())
236 }
237 })
238 }
239
240 pub fn llm_judge(prompt: &str) -> GGuard {
242 let prompt = prompt.to_string();
243 GGuard::new("llm_judge", move |_output| {
244 let _ = &prompt;
246 Ok(())
247 })
248 }
249
250 pub fn custom_judge(
252 name: &str,
253 f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
254 ) -> GGuard {
255 let name: &'static str = Box::leak(name.to_string().into_boxed_str());
257 GGuard::new(name, f)
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn length_guard_passes() {
        assert!(G::length(1, 100).check("hello").is_ok());
    }

    #[test]
    fn length_guard_too_short() {
        assert!(G::length(10, 100).check("hi").is_err());
    }

    #[test]
    fn length_guard_too_long() {
        assert!(G::length(1, 5).check("too long text").is_err());
    }

    #[test]
    fn length_guard_bounds_are_inclusive() {
        let g = G::length(5, 5);
        assert!(g.check("hello").is_ok());
        assert!(g.check("hell").is_err());
        assert!(g.check("hello!").is_err());
    }

    #[test]
    fn json_guard_valid() {
        assert!(G::json().check(r#"{"key": "value"}"#).is_ok());
    }

    #[test]
    fn json_guard_invalid() {
        assert!(G::json().check("not json").is_err());
    }

    #[test]
    fn regex_guard_blocks() {
        assert!(G::regex("secret").check("this is a secret").is_err());
    }

    #[test]
    fn regex_guard_passes() {
        assert!(G::regex("secret").check("this is public").is_ok());
    }

    #[test]
    fn budget_guard_passes() {
        assert!(G::budget(100).check("short").is_ok());
    }

    #[test]
    fn budget_guard_blocks() {
        // 12 bytes -> ~3 estimated tokens, which exceeds a budget of 1.
        assert!(G::budget(1).check("twelve bytes").is_err());
    }

    #[test]
    fn topic_guard_blocks() {
        assert!(G::topic(&["violence"]).check("There was violence").is_err());
    }

    #[test]
    fn topic_guard_passes() {
        assert!(G::topic(&["violence"]).check("A peaceful day").is_ok());
    }

    #[test]
    fn topic_guard_is_case_insensitive() {
        assert!(G::topic(&["Violence"]).check("VIOLENCE everywhere").is_err());
    }

    #[test]
    fn pii_guard_flags_email_address() {
        assert!(G::pii().check("Reach me at jane.doe@example.com").is_err());
        assert!(G::pii().check("No contact details here").is_ok());
    }

    #[test]
    fn compose_with_bitor() {
        let composite = G::length(1, 1000) | G::json();
        assert_eq!(composite.len(), 2);
    }

    #[test]
    fn check_all_returns_violations() {
        let composite = G::length(1, 5) | G::json();
        let violations = composite.check_all("not json and too long text here");
        assert!(!violations.is_empty());
    }

    #[test]
    fn check_all_preserves_guard_order() {
        let composite = G::length(1, 3) | G::json();
        let violations = composite.check_all("toolong");
        assert_eq!(violations.len(), 2);
        assert!(violations[0].starts_with("Output too long"));
        assert!(violations[1].starts_with("Invalid JSON"));
    }

    #[test]
    fn debug_shows_guard_name() {
        assert_eq!(format!("{:?}", G::json()), r#"GGuard { name: "json" }"#);
    }

    #[test]
    fn custom_guard() {
        let g = G::custom(|output| {
            if output.contains("bad") {
                Err("Contains 'bad'".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("good output").is_ok());
        assert!(g.check("bad output").is_err());
    }

    #[test]
    fn output_guard() {
        let g = G::output(|output| {
            if output.contains("forbidden") {
                Err("Forbidden content".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("safe content").is_ok());
        assert!(g.check("forbidden content").is_err());
        assert_eq!(g.name(), "output");
    }

    #[test]
    fn input_guard() {
        let g = G::input(|input| {
            if input.is_empty() {
                Err("Empty input".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("hello").is_ok());
        assert!(g.check("").is_err());
        assert_eq!(g.name(), "input");
    }

    #[test]
    fn rate_limit_guard() {
        let g = G::rate_limit(60);
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "rate_limit");
    }

    #[test]
    fn max_turns_guard() {
        let g = G::max_turns(3);
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "max_turns");
    }

    #[test]
    fn toxicity_guard() {
        let g = G::toxicity();
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "toxicity");
    }

    #[test]
    fn grounded_guard() {
        let g = G::grounded();
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "grounded");
    }

    #[test]
    fn hallucination_guard() {
        let g = G::hallucination();
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "hallucination");
    }

    #[test]
    fn when_guard_applies() {
        let inner = G::length(1, 5);
        let g = G::when(|output| output.starts_with("check:"), inner);
        // Predicate matches, so the inner length guard runs and fails.
        assert!(g.check("check: this is way too long").is_err());
        // Predicate does not match, so the inner guard is skipped.
        assert!(g.check("skip: this is way too long").is_ok());
        assert_eq!(g.name(), "when");
    }

    #[test]
    fn llm_judge_guard() {
        let g = G::llm_judge("Is this response helpful?");
        assert!(g.check("anything").is_ok());
        assert_eq!(g.name(), "llm_judge");
    }

    #[test]
    fn custom_judge_guard() {
        let g = G::custom_judge("profanity_filter", |output| {
            if output.contains("bad_word") {
                Err("Profanity detected".into())
            } else {
                Ok(())
            }
        });
        assert!(g.check("clean text").is_ok());
        assert!(g.check("has bad_word here").is_err());
        assert_eq!(g.name(), "profanity_filter");
    }

    #[test]
    fn compose_new_guards_with_bitor() {
        let composite = G::toxicity() | G::grounded() | G::hallucination();
        assert_eq!(composite.len(), 3);
        assert!(composite.check_all("test").is_empty());
    }
}
435}