gemini_adk_fluent_rs/compose/
judge.rs1use std::sync::Arc;
9
10use gemini_adk_rs::llm::{BaseLlm, LlmRequest};
11use gemini_genai_rs::prelude::{Content, Part, Role};
12
13#[derive(Clone)]
15pub struct LlmJudge {
16 llm: Arc<dyn BaseLlm>,
17 rubric: String,
19 context_label: &'static str,
21 use_context: bool,
23}
24
/// Outcome of judging one response.
///
/// Derives the common value-type traits so verdicts can be logged (`Debug`), stored or passed on
/// (`Clone`), and compared in tests (`PartialEq` / `Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Verdict {
    /// `true` when the judge decided the response violates the rubric.
    pub flagged: bool,
    /// Short explanation: the judge's own reason when one could be extracted, otherwise a
    /// generic placeholder.
    pub reason: String,
}
32
33impl LlmJudge {
34 pub fn new(llm: Arc<dyn BaseLlm>, rubric: impl Into<String>) -> Self {
36 Self {
37 llm,
38 rubric: rubric.into(),
39 context_label: "CONTEXT",
40 use_context: false,
41 }
42 }
43
44 pub fn with_context(mut self, label: &'static str) -> Self {
47 self.use_context = true;
48 self.context_label = label;
49 self
50 }
51
52 pub fn rubric(&self) -> &str {
54 &self.rubric
55 }
56
57 fn prompt(&self, output: &str, context: Option<&str>) -> String {
58 let mut p = String::with_capacity(output.len() + self.rubric.len() + 256);
59 p.push_str(
60 "You are a strict evaluator. Decide whether the RESPONSE below violates the \
61 RULE. Judge only against the RULE; do not invent additional requirements.\n\n",
62 );
63 p.push_str("RULE — a violation is when the following is TRUE:\n");
64 p.push_str(&self.rubric);
65 p.push_str("\n\n");
66 if self.use_context {
67 if let Some(ctx) = context {
68 p.push_str(self.context_label);
69 p.push_str(":\n");
70 p.push_str(ctx);
71 p.push_str("\n\n");
72 }
73 }
74 p.push_str("RESPONSE:\n");
75 p.push_str(output);
76 p.push_str(
77 "\n\nReply with ONLY a single-line JSON object and nothing else:\n\
78 {\"violation\": true|false, \"reason\": \"<at most 20 words>\"}",
79 );
80 p
81 }
82
83 pub async fn judge(&self, output: &str, context: Option<&str>) -> Verdict {
89 let req = LlmRequest::from_contents(vec![Content::user(self.prompt(output, context))]);
90 match self.llm.generate(req).await {
91 Ok(resp) => parse_verdict(&resp.text()),
92 Err(e) => Verdict {
93 flagged: false,
94 reason: format!("judge unavailable: {e}"),
95 },
96 }
97 }
98}
99
100pub fn parse_verdict(text: &str) -> Verdict {
104 let lower = text.to_ascii_lowercase();
105 let flagged = match lower.find("violation") {
106 Some(idx) => {
107 let tail = &lower[idx..];
108 match (tail.find("true"), tail.find("false")) {
109 (Some(t), Some(f)) => t < f,
110 (Some(_), None) => true,
111 (None, Some(_)) => false,
112 (None, None) => lower.contains("invalid") || lower.contains("unsafe"),
113 }
114 }
115 None => lower.contains("invalid") || lower.contains("unsafe"),
118 };
119 let reason = extract_reason(text).unwrap_or_else(|| {
120 if flagged {
121 "flagged by judge".to_string()
122 } else {
123 "ok".to_string()
124 }
125 });
126 Verdict { flagged, reason }
127}
128
/// Pull the string value of the first `"reason"` field out of `text`.
///
/// This is a lenient scan, not a JSON parser: it finds the literal key `"reason"`, skips to the
/// next `:`, and reads the double-quoted string that follows.
///
/// Escaped characters inside the string do not terminate it. `\"`, `\\` and `\/` are unescaped;
/// any other escape sequence (for example `\n` or `\u00e9`) is kept verbatim.
///
/// Returns `None` when the key is missing, the value is not a quoted string, or the string is
/// never closed.
fn extract_reason(text: &str) -> Option<String> {
    const KEY: &str = "\"reason\"";

    let after_key = &text[text.find(KEY)? + KEY.len()..];
    let colon = after_key.find(':')?;
    let body = after_key[colon + 1..].trim_start().strip_prefix('"')?;

    let mut reason = String::new();
    let mut chars = body.chars();
    // `while let` rather than `for`: the escape branch has to consume the following character.
    while let Some(ch) = chars.next() {
        match ch {
            '"' => return Some(reason),
            '\\' => match chars.next()? {
                escaped @ ('"' | '\\' | '/') => reason.push(escaped),
                other => {
                    reason.push('\\');
                    reason.push(other);
                }
            },
            other => reason.push(other),
        }
    }
    // Ran out of input before the closing quote.
    None
}
139
140pub fn render_contents(contents: &[Content]) -> String {
143 let mut out = String::new();
144 for content in contents {
145 let role = match content.role {
146 Some(Role::User) => "user",
147 Some(Role::Model) => "model",
148 _ => "system",
149 };
150 for part in &content.parts {
151 if let Part::Text { text } = part {
152 if !text.is_empty() {
153 out.push_str(role);
154 out.push_str(": ");
155 out.push_str(text);
156 out.push('\n');
157 }
158 }
159 }
160 }
161 out
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare JSON reply with `violation: true` is flagged and keeps its reason.
    #[test]
    fn parses_violation_true() {
        let Verdict { flagged, reason } =
            parse_verdict(r#"{"violation": true, "reason": "contains slur"}"#);
        assert!(flagged);
        assert_eq!(reason, "contains slur");
    }

    /// Leading chatter before the JSON object does not prevent parsing `violation: false`.
    #[test]
    fn parses_violation_false() {
        let Verdict { flagged, reason } =
            parse_verdict(r#"Sure! {"violation": false, "reason": "all good"}"#);
        assert!(!flagged);
        assert_eq!(reason, "all good");
    }

    /// Without a `violation` key the parser falls back to the `invalid` / `unsafe` labels.
    #[test]
    fn falls_back_to_labels() {
        let labelled = parse_verdict("Verdict: INVALID");
        assert!(labelled.flagged);

        let unlabelled = parse_verdict("looks valid to me");
        assert!(!unlabelled.flagged);
    }
}