gemini_adk_rs/evaluation/
rubric_evaluator.rs

//! Rubric-based evaluator: scores agent responses against rubric criteria.
//!
//! Uses an LLM as judge to score agent outputs against one or more free-text
//! rubric criteria. Two evaluation modes are supported: final-response quality
//! and tool-use quality.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use super::eval_case::Invocation;
12use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
13use super::evaluator::{EvalError, Evaluator};
14use crate::llm::BaseLlm;
15
/// Selects which aspect of an agent run the rubric judge is asked to score.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RubricMode {
    /// Judge the quality of the agent's final response.
    FinalResponse,
    /// Judge the agent's tool use (selection, arguments, sequencing).
    ToolUse,
}
24
25/// Evaluator that scores agent outputs against free-text rubric criteria
26/// using an LLM as judge.
27pub struct RubricEvaluator {
28    /// The rubric criteria to evaluate against.
29    rubrics: Vec<String>,
30    /// Optional override for the judge model.
31    judge_model: Option<String>,
32    /// The evaluation mode (response vs tool use).
33    mode: RubricMode,
34    /// Optional LLM for performing evaluations.
35    llm: Option<Arc<dyn BaseLlm>>,
36}
37
38impl RubricEvaluator {
39    /// Create a new rubric evaluator with the given rubric criteria.
40    pub fn new(rubrics: Vec<String>) -> Self {
41        Self {
42            rubrics,
43            judge_model: None,
44            mode: RubricMode::FinalResponse,
45            llm: None,
46        }
47    }
48
49    /// Create a rubric evaluator for final response quality.
50    ///
51    /// Uses the `rubric_based_final_response_quality_v1` evaluation strategy.
52    pub fn for_response(rubrics: Vec<String>) -> Self {
53        Self {
54            rubrics,
55            judge_model: None,
56            mode: RubricMode::FinalResponse,
57            llm: None,
58        }
59    }
60
61    /// Create a rubric evaluator for tool use quality.
62    ///
63    /// Uses the `rubric_based_tool_use_quality_v1` evaluation strategy.
64    pub fn for_tool_use(rubrics: Vec<String>) -> Self {
65        Self {
66            rubrics,
67            judge_model: None,
68            mode: RubricMode::ToolUse,
69            llm: None,
70        }
71    }
72
73    /// Set an override judge model name.
74    pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
75        self.judge_model = Some(model.into());
76        self
77    }
78
79    /// Provide an LLM instance for performing evaluations.
80    pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
81        self.llm = Some(llm);
82        self
83    }
84
85    /// Build the evaluation prompt for a single invocation.
86    fn build_prompt(&self, actual: &Invocation, expected: Option<&Invocation>) -> String {
87        let mode_label = match self.mode {
88            RubricMode::FinalResponse => "FINAL RESPONSE QUALITY",
89            RubricMode::ToolUse => "TOOL USE QUALITY",
90        };
91
92        let mut prompt = format!(
93            "You are an expert evaluator assessing {mode_label}.\n\n\
94             Score the agent's performance on a scale of 0.0 to 1.0 for EACH rubric criterion.\n\n"
95        );
96
97        // Add rubrics
98        prompt.push_str("RUBRIC CRITERIA:\n");
99        for (i, rubric) in self.rubrics.iter().enumerate() {
100            prompt.push_str(&format!("{}. {}\n", i + 1, rubric));
101        }
102        prompt.push('\n');
103
104        // Add actual conversation
105        prompt.push_str("ACTUAL AGENT CONVERSATION:\n");
106        for turn in &actual.turns {
107            prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
108            if !turn.tool_calls.is_empty() {
109                prompt.push_str(&format!(
110                    "  Tool calls: {}\n",
111                    serde_json::json!(turn.tool_calls)
112                ));
113            }
114            if !turn.tool_results.is_empty() {
115                prompt.push_str(&format!(
116                    "  Tool results: {}\n",
117                    serde_json::json!(turn.tool_results)
118                ));
119            }
120        }
121
122        // Add expected conversation if available
123        if let Some(expected) = expected {
124            prompt.push_str("\nEXPECTED CONVERSATION:\n");
125            for turn in &expected.turns {
126                prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
127                if !turn.tool_calls.is_empty() {
128                    prompt.push_str(&format!(
129                        "  Tool calls: {}\n",
130                        serde_json::json!(turn.tool_calls)
131                    ));
132                }
133            }
134        }
135
136        prompt.push_str(
137            "\nRespond with ONLY a JSON object:\n\
138             {\"scores\": [<float per rubric criterion>], \
139             \"overall_score\": <float average>, \
140             \"explanation\": \"<text>\"}\n",
141        );
142
143        prompt
144    }
145
146    /// Parse the LLM judge response to extract rubric scores.
147    fn parse_response(text: &str, num_rubrics: usize) -> (f64, String) {
148        // Try full JSON parse first
149        if let Some((score, explanation)) = try_parse_json(text) {
150            return (score, explanation);
151        }
152
153        // Try to find JSON embedded in text
154        if let Some(start) = text.find('{') {
155            if let Some(end) = text[start..].rfind('}') {
156                let json_str = &text[start..=start + end];
157                if let Some((score, explanation)) = try_parse_json(json_str) {
158                    return (score, explanation);
159                }
160            }
161        }
162
163        // Fallback: try to find individual scores
164        let _ = num_rubrics; // Used in full implementation
165        (
166            0.0,
167            format!("Failed to parse rubric judge response: {text}"),
168        )
169    }
170}
171
172/// Try to parse a JSON string into a score and explanation.
173fn try_parse_json(text: &str) -> Option<(f64, String)> {
174    let v: serde_json::Value = serde_json::from_str(text).ok()?;
175
176    let score = if let Some(overall) = v["overall_score"].as_f64() {
177        overall.clamp(0.0, 1.0)
178    } else if let Some(scores) = v["scores"].as_array() {
179        let sum: f64 = scores
180            .iter()
181            .filter_map(|s| s.as_f64())
182            .map(|s| s.clamp(0.0, 1.0))
183            .sum();
184        let count = scores.len().max(1) as f64;
185        sum / count
186    } else {
187        return None;
188    };
189
190    let explanation = v["explanation"]
191        .as_str()
192        .unwrap_or("No explanation")
193        .to_string();
194
195    Some((score, explanation))
196}
197
198#[async_trait]
199impl Evaluator for RubricEvaluator {
200    async fn evaluate(
201        &self,
202        actual: &[Invocation],
203        expected: Option<&[Invocation]>,
204    ) -> Result<EvalResult, EvalError> {
205        let llm = self.llm.as_ref().ok_or_else(|| {
206            EvalError::Llm(
207                "RubricEvaluator requires an LLM instance — call .with_llm() before evaluating"
208                    .into(),
209            )
210        })?;
211
212        let mut per_invocation = Vec::new();
213        let mut total_score = 0.0;
214
215        for (i, actual_inv) in actual.iter().enumerate() {
216            let expected_inv = expected.and_then(|e| e.get(i));
217            let prompt = self.build_prompt(actual_inv, expected_inv);
218
219            let request = crate::llm::LlmRequest::from_text(&prompt);
220            let response = llm
221                .generate(request)
222                .await
223                .map_err(|e| EvalError::Llm(e.to_string()))?;
224
225            let (score, explanation) = Self::parse_response(&response.text(), self.rubrics.len());
226            total_score += score;
227
228            per_invocation.push(PerInvocationResult {
229                invocation_id: if actual_inv.id.is_empty() {
230                    format!("inv-{i}")
231                } else {
232                    actual_inv.id.clone()
233                },
234                score,
235                explanation: Some(explanation),
236            });
237        }
238
239        let overall_score = if actual.is_empty() {
240            0.0
241        } else {
242            total_score / actual.len() as f64
243        };
244
245        let metric_name = match self.mode {
246            RubricMode::FinalResponse => "rubric_based_final_response_quality_v1",
247            RubricMode::ToolUse => "rubric_based_tool_use_quality_v1",
248        };
249
250        Ok(EvalResult {
251            overall_score,
252            metrics: vec![EvalMetric {
253                name: metric_name.into(),
254                score: overall_score,
255                per_invocation,
256            }],
257        })
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within `f64::EPSILON` of `wanted`.
    fn assert_close(actual: f64, wanted: f64) {
        assert!((actual - wanted).abs() < f64::EPSILON);
    }

    // --- Response parsing ---

    #[test]
    fn parse_valid_response() {
        let reply = r#"{"scores": [0.8, 0.9], "overall_score": 0.85, "explanation": "Good"}"#;
        let (score, explanation) = RubricEvaluator::parse_response(reply, 2);
        assert_close(score, 0.85);
        assert_eq!(explanation, "Good");
    }

    #[test]
    fn parse_scores_only() {
        // Without `overall_score` the per-criterion scores are averaged.
        let reply = r#"{"scores": [0.8, 0.6]}"#;
        let (score, _) = RubricEvaluator::parse_response(reply, 2);
        assert_close(score, 0.7);
    }

    #[test]
    fn parse_embedded_json() {
        let reply = r#"Here is my evaluation: {"overall_score": 0.9, "explanation": "Great"}"#;
        let (score, _) = RubricEvaluator::parse_response(reply, 1);
        assert_close(score, 0.9);
    }

    #[test]
    fn parse_invalid() {
        let (score, explanation) = RubricEvaluator::parse_response("no json here", 1);
        assert_close(score, 0.0);
        assert!(explanation.contains("Failed to parse"));
    }

    #[test]
    fn score_clamped() {
        // Out-of-range scores are clamped into 0.0..=1.0.
        let reply = r#"{"overall_score": 1.5, "explanation": "Over"}"#;
        let (score, _) = RubricEvaluator::parse_response(reply, 1);
        assert_close(score, 1.0);
    }

    // --- Construction and configuration ---

    #[test]
    fn for_response_mode() {
        let evaluator = RubricEvaluator::for_response(vec!["Accuracy".into()]);
        assert_eq!(evaluator.mode, RubricMode::FinalResponse);
    }

    #[test]
    fn for_tool_use_mode() {
        let evaluator = RubricEvaluator::for_tool_use(vec!["Tool selection".into()]);
        assert_eq!(evaluator.mode, RubricMode::ToolUse);
    }

    #[test]
    fn with_judge_model() {
        let evaluator =
            RubricEvaluator::new(vec!["test".into()]).with_judge_model("gemini-2.0-flash");
        assert_eq!(evaluator.judge_model.as_deref(), Some("gemini-2.0-flash"));
    }

    // --- Prompt construction ---

    #[test]
    fn build_prompt_includes_rubrics() {
        use crate::evaluation::eval_case::InvocationTurn;

        let rubrics = vec![
            "Is the response accurate?".into(),
            "Is it well-formatted?".into(),
        ];
        let evaluator = RubricEvaluator::new(rubrics);

        let greeting = InvocationTurn {
            role: "user".into(),
            content: "Hello".into(),
            tool_calls: vec![],
            tool_results: vec![],
        };
        let invocation = Invocation {
            id: "test".into(),
            turns: vec![greeting],
            metadata: serde_json::Value::Null,
        };

        let prompt = evaluator.build_prompt(&invocation, None);
        for needle in [
            "Is the response accurate?",
            "Is it well-formatted?",
            "FINAL RESPONSE QUALITY",
        ] {
            assert!(prompt.contains(needle));
        }
    }
}