// gemini_adk_rs/evaluation/rubric_evaluator.rs

use std::fmt::Write as _;
use std::sync::Arc;

use async_trait::async_trait;

use super::eval_case::Invocation;
use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
use super::evaluator::{EvalError, Evaluator};
use crate::llm::BaseLlm;

/// Selects which aspect of an agent run the rubric judge is asked to score.
///
/// The mode picks both the label used in the judge prompt and the name of the
/// metric reported by the evaluator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RubricMode {
    /// Judge the quality of the agent's final response.
    FinalResponse,
    /// Judge the quality of the agent's tool usage.
    ToolUse,
}

25pub struct RubricEvaluator {
28 rubrics: Vec<String>,
30 judge_model: Option<String>,
32 mode: RubricMode,
34 llm: Option<Arc<dyn BaseLlm>>,
36}
37
38impl RubricEvaluator {
39 pub fn new(rubrics: Vec<String>) -> Self {
41 Self {
42 rubrics,
43 judge_model: None,
44 mode: RubricMode::FinalResponse,
45 llm: None,
46 }
47 }
48
49 pub fn for_response(rubrics: Vec<String>) -> Self {
53 Self {
54 rubrics,
55 judge_model: None,
56 mode: RubricMode::FinalResponse,
57 llm: None,
58 }
59 }
60
61 pub fn for_tool_use(rubrics: Vec<String>) -> Self {
65 Self {
66 rubrics,
67 judge_model: None,
68 mode: RubricMode::ToolUse,
69 llm: None,
70 }
71 }
72
73 pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
75 self.judge_model = Some(model.into());
76 self
77 }
78
79 pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
81 self.llm = Some(llm);
82 self
83 }
84
85 fn build_prompt(&self, actual: &Invocation, expected: Option<&Invocation>) -> String {
87 let mode_label = match self.mode {
88 RubricMode::FinalResponse => "FINAL RESPONSE QUALITY",
89 RubricMode::ToolUse => "TOOL USE QUALITY",
90 };
91
92 let mut prompt = format!(
93 "You are an expert evaluator assessing {mode_label}.\n\n\
94 Score the agent's performance on a scale of 0.0 to 1.0 for EACH rubric criterion.\n\n"
95 );
96
97 prompt.push_str("RUBRIC CRITERIA:\n");
99 for (i, rubric) in self.rubrics.iter().enumerate() {
100 prompt.push_str(&format!("{}. {}\n", i + 1, rubric));
101 }
102 prompt.push('\n');
103
104 prompt.push_str("ACTUAL AGENT CONVERSATION:\n");
106 for turn in &actual.turns {
107 prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
108 if !turn.tool_calls.is_empty() {
109 prompt.push_str(&format!(
110 " Tool calls: {}\n",
111 serde_json::json!(turn.tool_calls)
112 ));
113 }
114 if !turn.tool_results.is_empty() {
115 prompt.push_str(&format!(
116 " Tool results: {}\n",
117 serde_json::json!(turn.tool_results)
118 ));
119 }
120 }
121
122 if let Some(expected) = expected {
124 prompt.push_str("\nEXPECTED CONVERSATION:\n");
125 for turn in &expected.turns {
126 prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
127 if !turn.tool_calls.is_empty() {
128 prompt.push_str(&format!(
129 " Tool calls: {}\n",
130 serde_json::json!(turn.tool_calls)
131 ));
132 }
133 }
134 }
135
136 prompt.push_str(
137 "\nRespond with ONLY a JSON object:\n\
138 {\"scores\": [<float per rubric criterion>], \
139 \"overall_score\": <float average>, \
140 \"explanation\": \"<text>\"}\n",
141 );
142
143 prompt
144 }
145
146 fn parse_response(text: &str, num_rubrics: usize) -> (f64, String) {
148 if let Some((score, explanation)) = try_parse_json(text) {
150 return (score, explanation);
151 }
152
153 if let Some(start) = text.find('{') {
155 if let Some(end) = text[start..].rfind('}') {
156 let json_str = &text[start..=start + end];
157 if let Some((score, explanation)) = try_parse_json(json_str) {
158 return (score, explanation);
159 }
160 }
161 }
162
163 let _ = num_rubrics; (
166 0.0,
167 format!("Failed to parse rubric judge response: {text}"),
168 )
169 }
170}
171
172fn try_parse_json(text: &str) -> Option<(f64, String)> {
174 let v: serde_json::Value = serde_json::from_str(text).ok()?;
175
176 let score = if let Some(overall) = v["overall_score"].as_f64() {
177 overall.clamp(0.0, 1.0)
178 } else if let Some(scores) = v["scores"].as_array() {
179 let sum: f64 = scores
180 .iter()
181 .filter_map(|s| s.as_f64())
182 .map(|s| s.clamp(0.0, 1.0))
183 .sum();
184 let count = scores.len().max(1) as f64;
185 sum / count
186 } else {
187 return None;
188 };
189
190 let explanation = v["explanation"]
191 .as_str()
192 .unwrap_or("No explanation")
193 .to_string();
194
195 Some((score, explanation))
196}
197
198#[async_trait]
199impl Evaluator for RubricEvaluator {
200 async fn evaluate(
201 &self,
202 actual: &[Invocation],
203 expected: Option<&[Invocation]>,
204 ) -> Result<EvalResult, EvalError> {
205 let llm = self.llm.as_ref().ok_or_else(|| {
206 EvalError::Llm(
207 "RubricEvaluator requires an LLM instance — call .with_llm() before evaluating"
208 .into(),
209 )
210 })?;
211
212 let mut per_invocation = Vec::new();
213 let mut total_score = 0.0;
214
215 for (i, actual_inv) in actual.iter().enumerate() {
216 let expected_inv = expected.and_then(|e| e.get(i));
217 let prompt = self.build_prompt(actual_inv, expected_inv);
218
219 let request = crate::llm::LlmRequest::from_text(&prompt);
220 let response = llm
221 .generate(request)
222 .await
223 .map_err(|e| EvalError::Llm(e.to_string()))?;
224
225 let (score, explanation) = Self::parse_response(&response.text(), self.rubrics.len());
226 total_score += score;
227
228 per_invocation.push(PerInvocationResult {
229 invocation_id: if actual_inv.id.is_empty() {
230 format!("inv-{i}")
231 } else {
232 actual_inv.id.clone()
233 },
234 score,
235 explanation: Some(explanation),
236 });
237 }
238
239 let overall_score = if actual.is_empty() {
240 0.0
241 } else {
242 total_score / actual.len() as f64
243 };
244
245 let metric_name = match self.mode {
246 RubricMode::FinalResponse => "rubric_based_final_response_quality_v1",
247 RubricMode::ToolUse => "rubric_based_tool_use_quality_v1",
248 };
249
250 Ok(EvalResult {
251 overall_score,
252 metrics: vec![EvalMetric {
253 name: metric_name.into(),
254 score: overall_score,
255 per_invocation,
256 }],
257 })
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two scores agree up to floating-point rounding noise.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn parse_valid_response() {
        let json = r#"{"scores": [0.8, 0.9], "overall_score": 0.85, "explanation": "Good"}"#;
        let (score, explanation) = RubricEvaluator::parse_response(json, 2);
        assert_close(score, 0.85);
        assert_eq!(explanation, "Good");
    }

    #[test]
    fn parse_scores_only() {
        let json = r#"{"scores": [0.8, 0.6]}"#;
        let (score, _) = RubricEvaluator::parse_response(json, 2);
        assert_close(score, 0.7);
    }

    #[test]
    fn parse_embedded_json() {
        let text = r#"Here is my evaluation: {"overall_score": 0.9, "explanation": "Great"}"#;
        let (score, _) = RubricEvaluator::parse_response(text, 1);
        assert_close(score, 0.9);
    }

    #[test]
    fn parse_fenced_json() {
        let text = "```json\n{\"overall_score\": 0.5, \"explanation\": \"ok\"}\n```";
        let (score, explanation) = RubricEvaluator::parse_response(text, 1);
        assert_close(score, 0.5);
        assert_eq!(explanation, "ok");
    }

    #[test]
    fn parse_invalid() {
        let (score, explanation) = RubricEvaluator::parse_response("no json here", 1);
        assert_close(score, 0.0);
        assert!(explanation.contains("Failed to parse"));
    }

    #[test]
    fn for_response_mode() {
        let eval = RubricEvaluator::for_response(vec!["Accuracy".into()]);
        assert_eq!(eval.mode, RubricMode::FinalResponse);
    }

    #[test]
    fn for_tool_use_mode() {
        let eval = RubricEvaluator::for_tool_use(vec!["Tool selection".into()]);
        assert_eq!(eval.mode, RubricMode::ToolUse);
    }

    #[test]
    fn build_prompt_includes_rubrics() {
        use crate::evaluation::eval_case::InvocationTurn;

        let eval = RubricEvaluator::new(vec![
            "Is the response accurate?".into(),
            "Is it well-formatted?".into(),
        ]);
        let inv = Invocation {
            id: "test".into(),
            turns: vec![InvocationTurn {
                role: "user".into(),
                content: "Hello".into(),
                tool_calls: vec![],
                tool_results: vec![],
            }],
            metadata: serde_json::Value::Null,
        };
        let prompt = eval.build_prompt(&inv, None);
        assert!(prompt.contains("1. Is the response accurate?"));
        assert!(prompt.contains("2. Is it well-formatted?"));
        assert!(prompt.contains("FINAL RESPONSE QUALITY"));
        assert!(!prompt.contains("EXPECTED CONVERSATION"));
    }

    #[test]
    fn build_prompt_uses_tool_use_label() {
        let eval = RubricEvaluator::for_tool_use(vec!["Tool selection".into()]);
        let inv = Invocation {
            id: "test".into(),
            turns: vec![],
            metadata: serde_json::Value::Null,
        };
        let prompt = eval.build_prompt(&inv, None);
        assert!(prompt.contains("TOOL USE QUALITY"));
    }

    #[test]
    fn with_judge_model() {
        let eval = RubricEvaluator::new(vec!["test".into()]).with_judge_model("gemini-2.0-flash");
        assert_eq!(eval.judge_model.as_deref(), Some("gemini-2.0-flash"));
    }

    #[test]
    fn score_clamped() {
        let json = r#"{"overall_score": 1.5, "explanation": "Over"}"#;
        let (score, _) = RubricEvaluator::parse_response(json, 1);
        assert_close(score, 1.0);
    }

    #[test]
    fn negative_score_clamped() {
        let json = r#"{"overall_score": -0.3}"#;
        let (score, explanation) = RubricEvaluator::parse_response(json, 1);
        assert_close(score, 0.0);
        assert_eq!(explanation, "No explanation");
    }
}