gemini_adk_rs/evaluation/
llm_as_judge.rs1use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use super::eval_case::Invocation;
10use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
11use super::evaluator::{EvalError, Evaluator};
12use crate::llm::BaseLlm;
13
/// Settings that control how [`LlmAsJudge`] prompts the judge model and labels its output.
#[derive(Clone, Debug)]
pub struct LlmAsJudgeConfig {
    /// Grading instructions inserted into the judge prompt after the `Rubric:` label.
    pub rubric: String,
    /// Name given to the metric reported in the evaluation result.
    pub metric_name: String,
}
22
23impl Default for LlmAsJudgeConfig {
24 fn default() -> Self {
25 Self {
26 rubric: "Evaluate the quality and correctness of the agent's response.".into(),
27 metric_name: "llm_judge_score".into(),
28 }
29 }
30}
31
32pub struct LlmAsJudge {
37 llm: Arc<dyn BaseLlm>,
38 config: LlmAsJudgeConfig,
39}
40
41impl LlmAsJudge {
42 pub fn new(llm: Arc<dyn BaseLlm>, config: LlmAsJudgeConfig) -> Self {
44 Self { llm, config }
45 }
46
47 fn build_prompt(&self, actual: &Invocation, expected: Option<&Invocation>) -> String {
49 let mut prompt = format!(
50 "You are an expert evaluator. Score the agent's response on a scale of 0.0 to 1.0.\n\n\
51 Rubric: {}\n\n\
52 Actual conversation:\n",
53 self.config.rubric
54 );
55
56 for turn in &actual.turns {
57 prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
58 }
59
60 if let Some(expected) = expected {
61 prompt.push_str("\nExpected conversation:\n");
62 for turn in &expected.turns {
63 prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
64 }
65 }
66
67 prompt.push_str(
68 "\nRespond with ONLY a JSON object: {\"score\": <float>, \"explanation\": \"<text>\"}",
69 );
70
71 prompt
72 }
73}
74
75#[async_trait]
76impl Evaluator for LlmAsJudge {
77 async fn evaluate(
78 &self,
79 actual: &[Invocation],
80 expected: Option<&[Invocation]>,
81 ) -> Result<EvalResult, EvalError> {
82 let mut per_invocation = Vec::new();
83 let mut total_score = 0.0;
84
85 for (i, actual_inv) in actual.iter().enumerate() {
86 let expected_inv = expected.and_then(|e| e.get(i));
87 let prompt = self.build_prompt(actual_inv, expected_inv);
88
89 let request = crate::llm::LlmRequest::from_text(&prompt);
90 let response = self
91 .llm
92 .generate(request)
93 .await
94 .map_err(|e| EvalError::Llm(e.to_string()))?;
95
96 let (score, explanation) = parse_judge_response(&response.text());
98 total_score += score;
99
100 per_invocation.push(PerInvocationResult {
101 invocation_id: if actual_inv.id.is_empty() {
102 format!("inv-{}", i)
103 } else {
104 actual_inv.id.clone()
105 },
106 score,
107 explanation: Some(explanation),
108 });
109 }
110
111 let overall_score = if actual.is_empty() {
112 0.0
113 } else {
114 total_score / actual.len() as f64
115 };
116
117 Ok(EvalResult {
118 overall_score,
119 metrics: vec![EvalMetric {
120 name: self.config.metric_name.clone(),
121 score: overall_score,
122 per_invocation,
123 }],
124 })
125 }
126}
127
128fn parse_judge_response(text: &str) -> (f64, String) {
130 if let Ok(v) = serde_json::from_str::<serde_json::Value>(text) {
132 let score = v["score"].as_f64().unwrap_or(0.0).clamp(0.0, 1.0);
133 let explanation = v["explanation"]
134 .as_str()
135 .unwrap_or("No explanation")
136 .to_string();
137 return (score, explanation);
138 }
139
140 if let Some(start) = text.find('{') {
142 if let Some(end) = text[start..].rfind('}') {
143 let json_str = &text[start..=start + end];
144 if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
145 let score = v["score"].as_f64().unwrap_or(0.0).clamp(0.0, 1.0);
146 let explanation = v["explanation"]
147 .as_str()
148 .unwrap_or("No explanation")
149 .to_string();
150 return (score, explanation);
151 }
152 }
153 }
154
155 (0.0, format!("Failed to parse judge response: {text}"))
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` equals `wanted` within `f64::EPSILON`.
    fn assert_score(actual: f64, wanted: f64) {
        assert!((actual - wanted).abs() < f64::EPSILON);
    }

    /// A reply that is exactly the JSON verdict yields both fields verbatim.
    #[test]
    fn parse_valid_json_response() {
        let reply = r#"{"score": 0.85, "explanation": "Good response"}"#;
        let (score, explanation) = parse_judge_response(reply);
        assert_score(score, 0.85);
        assert_eq!(explanation, "Good response");
    }

    /// JSON preceded by prose is still located and parsed.
    #[test]
    fn parse_json_in_text() {
        let reply = r#"Here is my evaluation: {"score": 0.7, "explanation": "Decent"}"#;
        let (score, _) = parse_judge_response(reply);
        assert_score(score, 0.7);
    }

    /// A reply without any JSON scores zero and reports the parse failure.
    #[test]
    fn parse_invalid_response() {
        let (score, explanation) = parse_judge_response("This is just text");
        assert_score(score, 0.0);
        assert!(explanation.contains("Failed to parse"));
    }

    /// Scores above the scale are clamped to its upper bound.
    #[test]
    fn score_clamped_to_valid_range() {
        let reply = r#"{"score": 1.5, "explanation": "Over"}"#;
        let (score, _) = parse_judge_response(reply);
        assert_score(score, 1.0);
    }

    /// The default configuration has the expected metric name and a non-empty rubric.
    #[test]
    fn default_config() {
        let defaults = LlmAsJudgeConfig::default();
        assert_eq!(defaults.metric_name, "llm_judge_score");
        assert!(!defaults.rubric.is_empty());
    }

    /// The rendered prompt carries the rubric label and the formatted turn.
    #[test]
    fn build_prompt_includes_rubric() {
        use crate::evaluation::eval_case::InvocationTurn;

        // Stand-in model: `build_prompt` never calls the LLM.
        struct DummyLlm;
        #[async_trait]
        impl BaseLlm for DummyLlm {
            fn model_id(&self) -> &str {
                "dummy"
            }
            async fn generate(
                &self,
                _req: crate::llm::LlmRequest,
            ) -> Result<crate::llm::LlmResponse, crate::llm::LlmError> {
                unreachable!()
            }
        }

        let greeting = InvocationTurn {
            role: "user".into(),
            content: "Hello".into(),
            tool_calls: vec![],
            tool_results: vec![],
        };
        let invocation = Invocation {
            id: "test".into(),
            turns: vec![greeting],
            metadata: serde_json::Value::Null,
        };
        let judge = LlmAsJudge::new(Arc::new(DummyLlm), LlmAsJudgeConfig::default());

        let rendered = judge.build_prompt(&invocation, None);
        assert!(rendered.contains("Rubric:"));
        assert!(rendered.contains("[user]: Hello"));
    }
}