gemini_adk_fluent_rs/compose/
judge.rs

1//! LLM-as-judge — the shared async evaluation primitive behind the LLM-backed
2//! `G::` guards and `E::` criteria.
3//!
4//! Mirrors ADK Python's `final_response_match_v2` / safety-evaluator approach: a
5//! judge model is prompted to render a structured verdict, and the verdict label
6//! is parsed back out of the model's reply (robust to surrounding prose).
7
8use std::sync::Arc;
9
10use gemini_adk_rs::llm::{BaseLlm, LlmRequest};
11use gemini_genai_rs::prelude::{Content, Part, Role};
12
13/// A configured LLM-as-judge. Cheap to clone (`Arc` inside).
14#[derive(Clone)]
15pub struct LlmJudge {
16    llm: Arc<dyn BaseLlm>,
17    /// Describes the condition that constitutes a *violation* (a `true` verdict).
18    rubric: String,
19    /// Label used for the supplied context/reference block in the prompt.
20    context_label: &'static str,
21    /// Whether to include the context/reference block at all.
22    use_context: bool,
23}
24
/// The judge's verdict on a single output.
///
/// Derives `Debug`/`Clone`/`PartialEq` so verdicts can be logged, stored and
/// compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Verdict {
    /// Whether the judge flagged a violation of the rubric.
    pub flagged: bool,
    /// Short reason / explanation extracted from the judge reply (or a fallback
    /// note when the reply carried none or the judge was unavailable).
    pub reason: String,
}
32
33impl LlmJudge {
34    /// Create a judge for the given rubric (the condition that is a *violation*).
35    pub fn new(llm: Arc<dyn BaseLlm>, rubric: impl Into<String>) -> Self {
36        Self {
37            llm,
38            rubric: rubric.into(),
39            context_label: "CONTEXT",
40            use_context: false,
41        }
42    }
43
44    /// Include a context/reference block in the judge prompt, under `label`
45    /// (e.g. "CONTEXT" for grounding, "REFERENCE ANSWER" for semantic match).
46    pub fn with_context(mut self, label: &'static str) -> Self {
47        self.use_context = true;
48        self.context_label = label;
49        self
50    }
51
52    /// The rubric describing what counts as a violation.
53    pub fn rubric(&self) -> &str {
54        &self.rubric
55    }
56
57    fn prompt(&self, output: &str, context: Option<&str>) -> String {
58        let mut p = String::with_capacity(output.len() + self.rubric.len() + 256);
59        p.push_str(
60            "You are a strict evaluator. Decide whether the RESPONSE below violates the \
61             RULE. Judge only against the RULE; do not invent additional requirements.\n\n",
62        );
63        p.push_str("RULE — a violation is when the following is TRUE:\n");
64        p.push_str(&self.rubric);
65        p.push_str("\n\n");
66        if self.use_context {
67            if let Some(ctx) = context {
68                p.push_str(self.context_label);
69                p.push_str(":\n");
70                p.push_str(ctx);
71                p.push_str("\n\n");
72            }
73        }
74        p.push_str("RESPONSE:\n");
75        p.push_str(output);
76        p.push_str(
77            "\n\nReply with ONLY a single-line JSON object and nothing else:\n\
78             {\"violation\": true|false, \"reason\": \"<at most 20 words>\"}",
79        );
80        p
81    }
82
83    /// Run the judge over an output (and optional context/reference).
84    ///
85    /// Fails open: if the judge LLM errors, the verdict is *not* flagged (so a
86    /// transient judge outage never vetoes a turn) and the error is recorded in
87    /// `reason`.
88    pub async fn judge(&self, output: &str, context: Option<&str>) -> Verdict {
89        let req = LlmRequest::from_contents(vec![Content::user(self.prompt(output, context))]);
90        match self.llm.generate(req).await {
91            Ok(resp) => parse_verdict(&resp.text()),
92            Err(e) => Verdict {
93                flagged: false,
94                reason: format!("judge unavailable: {e}"),
95            },
96        }
97    }
98}
99
100/// Parse a verdict from the judge model's reply. Tolerant of extra prose around
101/// the JSON: it scans for the `violation` field's boolean and the `reason`
102/// string, falling back to common labels (`invalid`, `unsafe`).
103pub fn parse_verdict(text: &str) -> Verdict {
104    let lower = text.to_ascii_lowercase();
105    let flagged = match lower.find("violation") {
106        Some(idx) => {
107            let tail = &lower[idx..];
108            match (tail.find("true"), tail.find("false")) {
109                (Some(t), Some(f)) => t < f,
110                (Some(_), None) => true,
111                (None, Some(_)) => false,
112                (None, None) => lower.contains("invalid") || lower.contains("unsafe"),
113            }
114        }
115        // No explicit `violation` field — fall back to common labels. Note
116        // `contains("invalid")` is false for "valid" but true for "INVALID".
117        None => lower.contains("invalid") || lower.contains("unsafe"),
118    };
119    let reason = extract_reason(text).unwrap_or_else(|| {
120        if flagged {
121            "flagged by judge".to_string()
122        } else {
123            "ok".to_string()
124        }
125    });
126    Verdict { flagged, reason }
127}
128
/// Extract the string value of the first `"reason"` key in `text`.
///
/// JSON escapes are decoded, so an embedded `\"` no longer truncates the reason:
/// `\"`, `\\`, `\/`, `\n`, `\r` and `\t` are decoded; any other escape (e.g.
/// `\uXXXX`) is kept verbatim, backslash included. Returns `None` when the key,
/// its colon, the opening quote, or an unescaped closing quote is missing.
fn extract_reason(text: &str) -> Option<String> {
    const KEY: &str = "\"reason\"";
    let after_key = &text[text.find(KEY)? + KEY.len()..];
    let value = after_key[after_key.find(':')? + 1..]
        .trim_start()
        .strip_prefix('"')?;

    let mut reason = String::new();
    let mut chars = value.chars();
    while let Some(c) = chars.next() {
        match c {
            // An unescaped quote closes the JSON string.
            '"' => return Some(reason),
            '\\' => match chars.next()? {
                'n' => reason.push('\n'),
                'r' => reason.push('\r'),
                't' => reason.push('\t'),
                escaped @ ('"' | '\\' | '/') => reason.push(escaped),
                other => {
                    reason.push('\\');
                    reason.push(other);
                }
            },
            other => reason.push(other),
        }
    }
    // Ran off the end of the reply without a closing quote.
    None
}
139
140/// Render conversation history into a plain-text block for a judge prompt,
141/// keeping role labels so the judge can reason about grounding.
142pub fn render_contents(contents: &[Content]) -> String {
143    let mut out = String::new();
144    for content in contents {
145        let role = match content.role {
146            Some(Role::User) => "user",
147            Some(Role::Model) => "model",
148            _ => "system",
149        };
150        for part in &content.parts {
151            if let Part::Text { text } = part {
152                if !text.is_empty() {
153                    out.push_str(role);
154                    out.push_str(": ");
155                    out.push_str(text);
156                    out.push('\n');
157                }
158            }
159        }
160    }
161    out
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_violation_true() {
        let reply = r#"{"violation": true, "reason": "contains slur"}"#;
        let verdict = parse_verdict(reply);
        assert!(verdict.flagged);
        assert_eq!(verdict.reason.as_str(), "contains slur");
    }

    #[test]
    fn parses_violation_false() {
        let reply = r#"Sure! {"violation": false, "reason": "all good"}"#;
        let verdict = parse_verdict(reply);
        assert!(!verdict.flagged);
        assert_eq!(verdict.reason.as_str(), "all good");
    }

    #[test]
    fn falls_back_to_labels() {
        let cases = [("Verdict: INVALID", true), ("looks valid to me", false)];
        for (reply, expected) in cases {
            assert_eq!(parse_verdict(reply).flagged, expected);
        }
    }
}