gemini_adk_fluent_rs/compose/
guards.rs1use std::sync::Arc;
14
15use async_trait::async_trait;
16use gemini_adk_rs::error::AgentError;
17use gemini_adk_rs::llm::{BaseLlm, LlmRequest, LlmResponse};
18use gemini_adk_rs::middleware::Middleware;
19
20use crate::compose::judge::{render_contents, LlmJudge};
21
22#[derive(Clone)]
24pub struct GGuard {
25 name: &'static str,
26 kind: GuardKind,
27}
28
29#[derive(Clone)]
31enum GuardKind {
32 Sync(#[allow(clippy::type_complexity)] Arc<dyn Fn(&str) -> Result<(), String> + Send + Sync>),
34 Judge(LlmJudge),
36}
37
38impl GGuard {
39 fn new(
40 name: &'static str,
41 f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
42 ) -> Self {
43 Self {
44 name,
45 kind: GuardKind::Sync(Arc::new(f)),
46 }
47 }
48
49 fn judge(name: &'static str, judge: LlmJudge) -> Self {
50 Self {
51 name,
52 kind: GuardKind::Judge(judge),
53 }
54 }
55
56 pub fn name(&self) -> &str {
58 self.name
59 }
60
61 pub fn check(&self, output: &str) -> Result<(), String> {
65 match &self.kind {
66 GuardKind::Sync(f) => f(output),
67 GuardKind::Judge(_) => Ok(()),
68 }
69 }
70
71 pub async fn check_async(&self, output: &str, context: Option<&str>) -> Result<(), String> {
74 match &self.kind {
75 GuardKind::Sync(f) => f(output),
76 GuardKind::Judge(judge) => {
77 let verdict = judge.judge(output, context).await;
78 if verdict.flagged {
79 Err(verdict.reason)
80 } else {
81 Ok(())
82 }
83 }
84 }
85 }
86}
87
88impl std::fmt::Debug for GGuard {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("GGuard").field("name", &self.name).finish()
91 }
92}
93
94impl std::ops::BitOr for GGuard {
96 type Output = GComposite;
97
98 fn bitor(self, rhs: GGuard) -> Self::Output {
99 GComposite {
100 guards: vec![self, rhs],
101 }
102 }
103}
104
105#[derive(Clone)]
107pub struct GComposite {
108 pub guards: Vec<GGuard>,
110}
111
112impl GComposite {
113 pub fn check_all(&self, output: &str) -> Vec<String> {
116 self.guards
117 .iter()
118 .filter_map(|g| g.check(output).err())
119 .collect()
120 }
121
122 pub async fn check_all_async(&self, output: &str, context: Option<&str>) -> Vec<String> {
125 let mut violations = Vec::new();
126 for g in &self.guards {
127 if let Err(reason) = g.check_async(output, context).await {
128 violations.push(format!("{}: {}", g.name(), reason));
129 }
130 }
131 violations
132 }
133
134 pub fn len(&self) -> usize {
136 self.guards.len()
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.guards.is_empty()
142 }
143}
144
145impl std::ops::BitOr<GGuard> for GComposite {
146 type Output = GComposite;
147
148 fn bitor(mut self, rhs: GGuard) -> Self::Output {
149 self.guards.push(rhs);
150 self
151 }
152}
153
154impl From<GGuard> for GComposite {
157 fn from(guard: GGuard) -> Self {
158 GComposite {
159 guards: vec![guard],
160 }
161 }
162}
163
164impl GComposite {
165 pub fn into_middleware(self) -> Arc<dyn Middleware> {
168 Arc::new(GuardMiddleware { guards: self })
169 }
170}
171
172struct GuardMiddleware {
174 guards: GComposite,
175}
176
177#[async_trait]
178impl Middleware for GuardMiddleware {
179 fn name(&self) -> &str {
180 "guard"
181 }
182
183 async fn after_model(
184 &self,
185 request: &LlmRequest,
186 response: &LlmResponse,
187 ) -> Result<Option<LlmResponse>, AgentError> {
188 let context = render_contents(&request.contents);
191 let violations = self
192 .guards
193 .check_all_async(&response.text(), Some(&context))
194 .await;
195 if violations.is_empty() {
196 Ok(None)
197 } else {
198 Err(AgentError::Other(format!(
199 "guard violation: {}",
200 violations.join("; ")
201 )))
202 }
203 }
204}
205
/// Namespace for guard constructors (`G::length`, `G::json`, `G::toxicity`, ...).
pub struct G;
208
209impl G {
210 pub fn length(min: usize, max: usize) -> GGuard {
212 GGuard::new("length", move |output| {
213 let len = output.len();
214 if len < min {
215 Err(format!("Output too short: {} < {}", len, min))
216 } else if len > max {
217 Err(format!("Output too long: {} > {}", len, max))
218 } else {
219 Ok(())
220 }
221 })
222 }
223
224 pub fn regex(pattern: &str) -> GGuard {
226 let pattern = pattern.to_string();
227 GGuard::new("regex", move |output| {
228 if output.contains(&pattern) {
230 Err(format!("Output matches forbidden pattern: {}", pattern))
231 } else {
232 Ok(())
233 }
234 })
235 }
236
237 pub fn budget(max_tokens: usize) -> GGuard {
239 GGuard::new("budget", move |output| {
240 let estimated_tokens = output.len() / 4;
242 if estimated_tokens > max_tokens {
243 Err(format!(
244 "Output exceeds token budget: ~{} > {}",
245 estimated_tokens, max_tokens
246 ))
247 } else {
248 Ok(())
249 }
250 })
251 }
252
253 pub fn json() -> GGuard {
255 GGuard::new("json", |output| {
256 serde_json::from_str::<serde_json::Value>(output)
257 .map(|_| ())
258 .map_err(|e| format!("Invalid JSON: {}", e))
259 })
260 }
261
262 pub fn max_turns(n: u32) -> GGuard {
264 GGuard::new("max_turns", move |_output| {
265 let _ = n;
267 Ok(())
268 })
269 }
270
271 pub fn pii() -> GGuard {
273 GGuard::new("pii", |output| {
274 if output.contains('@') && output.contains('.') {
276 return Err("Output may contain email addresses".to_string());
278 }
279 Ok(())
280 })
281 }
282
283 pub fn topic(deny: &[&str]) -> GGuard {
285 let deny: Vec<String> = deny.iter().map(|s| s.to_lowercase()).collect();
286 GGuard::new("topic", move |output| {
287 let lower = output.to_lowercase();
288 for topic in &deny {
289 if lower.contains(topic.as_str()) {
290 return Err(format!("Output mentions denied topic: {}", topic));
291 }
292 }
293 Ok(())
294 })
295 }
296
297 pub fn custom(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
299 GGuard::new("custom", f)
300 }
301
302 pub fn output(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
304 GGuard::new("output", f)
305 }
306
307 pub fn input(f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static) -> GGuard {
309 GGuard::new("input", f)
310 }
311
312 pub fn rate_limit(max_per_minute: u32) -> GGuard {
314 GGuard::new("rate_limit", move |_output| {
315 let _ = max_per_minute;
317 Ok(())
318 })
319 }
320
321 pub fn toxicity(judge: Arc<dyn BaseLlm>) -> GGuard {
327 GGuard::judge(
328 "toxicity",
329 LlmJudge::new(
330 judge,
331 "the RESPONSE contains toxic, hateful, harassing, sexual, or otherwise \
332 abusive or unsafe content",
333 ),
334 )
335 }
336
337 pub fn grounded(judge: Arc<dyn BaseLlm>) -> GGuard {
342 GGuard::judge(
343 "grounded",
344 LlmJudge::new(
345 judge,
346 "the RESPONSE asserts facts that are NOT supported by, or that \
347 contradict, the provided CONTEXT",
348 )
349 .with_context("CONTEXT"),
350 )
351 }
352
353 pub fn hallucination(judge: Arc<dyn BaseLlm>) -> GGuard {
355 GGuard::judge(
356 "hallucination",
357 LlmJudge::new(
358 judge,
359 "the RESPONSE contains fabricated, invented, or unverifiable facts \
360 that are not supported by the CONTEXT",
361 )
362 .with_context("CONTEXT"),
363 )
364 }
365
366 pub fn when(predicate: impl Fn(&str) -> bool + Send + Sync + 'static, inner: GGuard) -> GGuard {
368 GGuard::new("when", move |output| {
369 if predicate(output) {
370 inner.check(output)
371 } else {
372 Ok(())
373 }
374 })
375 }
376
377 pub fn llm_judge(judge: Arc<dyn BaseLlm>, rubric: impl Into<String>) -> GGuard {
383 GGuard::judge("llm_judge", LlmJudge::new(judge, rubric))
384 }
385
386 pub fn custom_judge(
388 name: &str,
389 f: impl Fn(&str) -> Result<(), String> + Send + Sync + 'static,
390 ) -> GGuard {
391 let name: &'static str = Box::leak(name.to_string().into_boxed_str());
393 GGuard::new(name, f)
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use gemini_adk_rs::llm::{LlmError, LlmResponse};
    use gemini_genai_rs::prelude::{Content, Part, Role};

    /// `BaseLlm` test double that always answers with a fixed text reply.
    struct ScriptedJudge {
        model: &'static str,
        reply: &'static str,
    }

    #[async_trait]
    impl BaseLlm for ScriptedJudge {
        fn model_id(&self) -> &str {
            self.model
        }

        async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            let text = self.reply.to_string();
            Ok(LlmResponse {
                content: Content {
                    role: Some(Role::Model),
                    parts: vec![Part::Text { text }],
                },
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    /// A judge that never reports a violation.
    fn judge_llm() -> Arc<dyn BaseLlm> {
        Arc::new(ScriptedJudge {
            model: "noop-judge",
            reply: r#"{"violation": false, "reason": "ok"}"#,
        })
    }

    #[test]
    fn length_guard_passes() {
        let guard = G::length(1, 100);
        assert_eq!(guard.check("hello"), Ok(()));
    }

    #[test]
    fn length_guard_too_short() {
        let outcome = G::length(10, 100).check("hi");
        assert!(outcome.is_err(), "expected rejection, got {:?}", outcome);
    }

    #[test]
    fn length_guard_too_long() {
        let outcome = G::length(1, 5).check("too long text");
        assert!(outcome.is_err(), "expected rejection, got {:?}", outcome);
    }

    #[test]
    fn json_guard_valid() {
        assert_eq!(G::json().check(r#"{"key": "value"}"#), Ok(()));
    }

    #[test]
    fn json_guard_invalid() {
        let outcome = G::json().check("not json");
        assert!(outcome.is_err(), "expected rejection, got {:?}", outcome);
    }

    #[test]
    fn regex_guard_blocks() {
        let outcome = G::regex("secret").check("this is a secret");
        assert!(outcome.is_err(), "expected rejection, got {:?}", outcome);
    }

    #[test]
    fn regex_guard_passes() {
        assert_eq!(G::regex("secret").check("this is public"), Ok(()));
    }

    #[test]
    fn budget_guard_passes() {
        assert_eq!(G::budget(100).check("short"), Ok(()));
    }

    #[test]
    fn topic_guard_blocks() {
        let outcome = G::topic(&["violence"]).check("There was violence");
        assert!(outcome.is_err(), "expected rejection, got {:?}", outcome);
    }

    #[test]
    fn topic_guard_passes() {
        assert_eq!(G::topic(&["violence"]).check("A peaceful day"), Ok(()));
    }

    #[test]
    fn compose_with_bitor() {
        let composite = G::length(1, 1000) | G::json();
        assert_eq!(composite.len(), 2);
    }

    #[test]
    fn check_all_returns_violations() {
        let composite = G::length(1, 5) | G::json();
        let violations = composite.check_all("not json and too long text here");
        assert!(!violations.is_empty());
    }

    #[test]
    fn custom_guard() {
        let guard = G::custom(|text| match text.contains("bad") {
            true => Err("Contains 'bad'".into()),
            false => Ok(()),
        });
        assert_eq!(guard.check("good output"), Ok(()));
        assert!(guard.check("bad output").is_err());
    }

    #[test]
    fn output_guard() {
        let guard = G::output(|text| {
            if !text.contains("forbidden") {
                return Ok(());
            }
            Err("Forbidden content".into())
        });
        assert_eq!(guard.check("safe content"), Ok(()));
        assert!(guard.check("forbidden content").is_err());
        assert_eq!(guard.name(), "output");
    }

    #[test]
    fn input_guard() {
        let guard = G::input(|text| {
            if !text.is_empty() {
                return Ok(());
            }
            Err("Empty input".into())
        });
        assert_eq!(guard.check("hello"), Ok(()));
        assert!(guard.check("").is_err());
        assert_eq!(guard.name(), "input");
    }

    #[test]
    fn rate_limit_guard() {
        let guard = G::rate_limit(60);
        assert_eq!(guard.check("anything"), Ok(()));
        assert_eq!(guard.name(), "rate_limit");
    }

    #[test]
    fn toxicity_guard() {
        let guard = G::toxicity(judge_llm());
        // Judge guards are skipped by the synchronous path.
        assert_eq!(guard.check("anything"), Ok(()));
        assert_eq!(guard.name(), "toxicity");
    }

    #[test]
    fn grounded_guard() {
        let guard = G::grounded(judge_llm());
        assert_eq!(guard.check("anything"), Ok(()));
        assert_eq!(guard.name(), "grounded");
    }

    #[test]
    fn hallucination_guard() {
        let guard = G::hallucination(judge_llm());
        assert_eq!(guard.check("anything"), Ok(()));
        assert_eq!(guard.name(), "hallucination");
    }

    #[tokio::test]
    async fn judge_guard_runs_async() {
        let flag_all = ScriptedJudge {
            model: "flag-all",
            reply: r#"{"violation": true, "reason": "bad"}"#,
        };
        let guard = G::toxicity(Arc::new(flag_all));
        assert!(guard.check_async("hello", None).await.is_err());
    }

    #[test]
    fn when_guard_applies() {
        let guard = G::when(|text| text.starts_with("check:"), G::length(1, 5));
        // Predicate holds, so the inner length guard runs and rejects.
        assert!(guard.check("check: this is way too long").is_err());
        // Predicate fails, so the inner guard is skipped.
        assert_eq!(guard.check("skip: this is way too long"), Ok(()));
        assert_eq!(guard.name(), "when");
    }

    #[test]
    fn llm_judge_guard() {
        let guard = G::llm_judge(judge_llm(), "the response is unhelpful");
        assert_eq!(guard.check("anything"), Ok(()));
        assert_eq!(guard.name(), "llm_judge");
    }

    #[test]
    fn custom_judge_guard() {
        let guard = G::custom_judge("profanity_filter", |text| {
            if !text.contains("bad_word") {
                return Ok(());
            }
            Err("Profanity detected".into())
        });
        assert_eq!(guard.check("clean text"), Ok(()));
        assert!(guard.check("has bad_word here").is_err());
        assert_eq!(guard.name(), "profanity_filter");
    }

    #[test]
    fn compose_new_guards_with_bitor() {
        let composite =
            G::toxicity(judge_llm()) | G::grounded(judge_llm()) | G::hallucination(judge_llm());
        assert_eq!(composite.len(), 3);
        // The sync path skips judges entirely, so nothing is reported.
        assert!(composite.check_all("test").is_empty());
    }
}