gemini_adk_rs/evaluation/
llm_as_judge.rs

//! LLM-as-judge evaluator — uses an LLM to grade agent responses.
//!
//! Mirrors ADK-Python's `llm_as_judge` evaluator.
4
5use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use super::eval_case::Invocation;
10use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
11use super::evaluator::{EvalError, Evaluator};
12use crate::llm::BaseLlm;
13
/// Settings that control how the LLM-as-judge evaluator grades responses.
///
/// `rubric` is embedded verbatim in every judge prompt, and `metric_name`
/// is the name under which the aggregated score is reported.
#[derive(Debug, Clone)]
pub struct LlmAsJudgeConfig {
    /// Grading criteria shown to the judge model.
    pub rubric: String,
    /// Name reported for the metric produced by this evaluator.
    pub metric_name: String,
}

impl Default for LlmAsJudgeConfig {
    /// A generic quality/correctness rubric, reported as `llm_judge_score`.
    fn default() -> Self {
        let rubric = String::from("Evaluate the quality and correctness of the agent's response.");
        let metric_name = String::from("llm_judge_score");
        Self {
            rubric,
            metric_name,
        }
    }
}
31
32/// Evaluator that uses an LLM to judge agent responses.
33///
34/// Sends the actual and expected invocations to an LLM along with
35/// a rubric, and parses the score from the LLM's response.
36pub struct LlmAsJudge {
37    llm: Arc<dyn BaseLlm>,
38    config: LlmAsJudgeConfig,
39}
40
41impl LlmAsJudge {
42    /// Create a new LLM-as-judge evaluator.
43    pub fn new(llm: Arc<dyn BaseLlm>, config: LlmAsJudgeConfig) -> Self {
44        Self { llm, config }
45    }
46
47    /// Build the evaluation prompt for a single invocation.
48    fn build_prompt(&self, actual: &Invocation, expected: Option<&Invocation>) -> String {
49        let mut prompt = format!(
50            "You are an expert evaluator. Score the agent's response on a scale of 0.0 to 1.0.\n\n\
51             Rubric: {}\n\n\
52             Actual conversation:\n",
53            self.config.rubric
54        );
55
56        for turn in &actual.turns {
57            prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
58        }
59
60        if let Some(expected) = expected {
61            prompt.push_str("\nExpected conversation:\n");
62            for turn in &expected.turns {
63                prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
64            }
65        }
66
67        prompt.push_str(
68            "\nRespond with ONLY a JSON object: {\"score\": <float>, \"explanation\": \"<text>\"}",
69        );
70
71        prompt
72    }
73}
74
75#[async_trait]
76impl Evaluator for LlmAsJudge {
77    async fn evaluate(
78        &self,
79        actual: &[Invocation],
80        expected: Option<&[Invocation]>,
81    ) -> Result<EvalResult, EvalError> {
82        let mut per_invocation = Vec::new();
83        let mut total_score = 0.0;
84
85        for (i, actual_inv) in actual.iter().enumerate() {
86            let expected_inv = expected.and_then(|e| e.get(i));
87            let prompt = self.build_prompt(actual_inv, expected_inv);
88
89            let request = crate::llm::LlmRequest::from_text(&prompt);
90            let response = self
91                .llm
92                .generate(request)
93                .await
94                .map_err(|e| EvalError::Llm(e.to_string()))?;
95
96            // Try to parse score from response
97            let (score, explanation) = parse_judge_response(&response.text());
98            total_score += score;
99
100            per_invocation.push(PerInvocationResult {
101                invocation_id: if actual_inv.id.is_empty() {
102                    format!("inv-{}", i)
103                } else {
104                    actual_inv.id.clone()
105                },
106                score,
107                explanation: Some(explanation),
108            });
109        }
110
111        let overall_score = if actual.is_empty() {
112            0.0
113        } else {
114            total_score / actual.len() as f64
115        };
116
117        Ok(EvalResult {
118            overall_score,
119            metrics: vec![EvalMetric {
120                name: self.config.metric_name.clone(),
121                score: overall_score,
122                per_invocation,
123            }],
124        })
125    }
126}
127
128/// Parse the LLM judge's response to extract score and explanation.
129fn parse_judge_response(text: &str) -> (f64, String) {
130    // Try to parse JSON response
131    if let Ok(v) = serde_json::from_str::<serde_json::Value>(text) {
132        let score = v["score"].as_f64().unwrap_or(0.0).clamp(0.0, 1.0);
133        let explanation = v["explanation"]
134            .as_str()
135            .unwrap_or("No explanation")
136            .to_string();
137        return (score, explanation);
138    }
139
140    // Try to find JSON in the response text
141    if let Some(start) = text.find('{') {
142        if let Some(end) = text[start..].rfind('}') {
143            let json_str = &text[start..=start + end];
144            if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
145                let score = v["score"].as_f64().unwrap_or(0.0).clamp(0.0, 1.0);
146                let explanation = v["explanation"]
147                    .as_str()
148                    .unwrap_or("No explanation")
149                    .to_string();
150                return (score, explanation);
151            }
152        }
153    }
154
155    (0.0, format!("Failed to parse judge response: {text}"))
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;

    /// Assert that a parsed score matches `expected` to within machine epsilon.
    fn assert_score(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected score {expected}, got {actual}"
        );
    }

    /// Model stub for tests that only build prompts; calling it is a test bug.
    struct DummyLlm;

    #[async_trait]
    impl BaseLlm for DummyLlm {
        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn generate(
            &self,
            _req: crate::llm::LlmRequest,
        ) -> Result<crate::llm::LlmResponse, crate::llm::LlmError> {
            unreachable!()
        }
    }

    #[test]
    fn parse_valid_json_response() {
        let parsed = parse_judge_response(r#"{"score": 0.85, "explanation": "Good response"}"#);
        assert_score(parsed.0, 0.85);
        assert_eq!(parsed.1, "Good response");
    }

    #[test]
    fn parse_json_in_text() {
        let reply = r#"Here is my evaluation: {"score": 0.7, "explanation": "Decent"}"#;
        assert_score(parse_judge_response(reply).0, 0.7);
    }

    #[test]
    fn parse_invalid_response() {
        let (score, explanation) = parse_judge_response("This is just text");
        assert_score(score, 0.0);
        assert!(explanation.contains("Failed to parse"));
    }

    #[test]
    fn score_clamped_to_valid_range() {
        let reply = r#"{"score": 1.5, "explanation": "Over"}"#;
        assert_score(parse_judge_response(reply).0, 1.0);
    }

    #[test]
    fn default_config() {
        let defaults = LlmAsJudgeConfig::default();
        assert!(!defaults.rubric.is_empty());
        assert_eq!(defaults.metric_name, "llm_judge_score");
    }

    #[test]
    fn build_prompt_includes_rubric() {
        let greeting = InvocationTurn {
            role: "user".into(),
            content: "Hello".into(),
            tool_calls: vec![],
            tool_results: vec![],
        };
        let invocation = Invocation {
            id: "test".into(),
            turns: vec![greeting],
            metadata: serde_json::Value::Null,
        };
        let judge = LlmAsJudge::new(Arc::new(DummyLlm), LlmAsJudgeConfig::default());

        let prompt = judge.build_prompt(&invocation, None);

        assert!(prompt.contains("Rubric:"));
        assert!(prompt.contains("[user]: Hello"));
    }
}