gemini_adk_rs/evaluation/
hallucination_evaluator.rs1use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use super::eval_case::Invocation;
12use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
13use super::evaluator::{EvalError, Evaluator};
14use crate::llm::BaseLlm;
15
16pub struct HallucinationEvaluator {
21 judge_model: Option<String>,
23 evaluate_intermediate: bool,
25 llm: Option<Arc<dyn BaseLlm>>,
27}
28
29impl HallucinationEvaluator {
30 pub fn new() -> Self {
32 Self {
33 judge_model: None,
34 evaluate_intermediate: false,
35 llm: None,
36 }
37 }
38
39 pub fn with_intermediate(mut self, eval: bool) -> Self {
41 self.evaluate_intermediate = eval;
42 self
43 }
44
45 pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
47 self.judge_model = Some(model.into());
48 self
49 }
50
51 pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
53 self.llm = Some(llm);
54 self
55 }
56
57 fn extract_context(inv: &Invocation) -> String {
62 let mut context = String::new();
63
64 for turn in &inv.turns {
65 match turn.role.as_str() {
66 "user" => {
67 context.push_str(&format!("USER INPUT: {}\n", turn.content));
68 }
69 "model" if !turn.tool_results.is_empty() => {
70 for result in &turn.tool_results {
71 context.push_str(&format!("TOOL RESULT: {}\n", result));
72 }
73 }
74 _ => {}
75 }
76 }
77
78 context
79 }
80
81 fn extract_responses(inv: &Invocation, include_intermediate: bool) -> Vec<String> {
83 let model_turns: Vec<&str> = inv
84 .turns
85 .iter()
86 .filter(|t| t.role == "model" && !t.content.is_empty())
87 .map(|t| t.content.as_str())
88 .collect();
89
90 if include_intermediate {
91 model_turns.into_iter().map(String::from).collect()
92 } else {
93 model_turns
95 .last()
96 .map(|s| vec![s.to_string()])
97 .unwrap_or_default()
98 }
99 }
100
101 fn build_prompt(context: &str, response: &str) -> String {
103 format!(
104 "You are an expert evaluator assessing GROUNDEDNESS (absence of hallucination).\n\n\
105 Your task: determine whether the agent's response is fully supported by the \
106 provided context. A response is grounded if every factual claim it makes can \
107 be traced back to information in the context.\n\n\
108 GROUNDING CONTEXT (source of truth):\n\
109 {context}\n\n\
110 AGENT RESPONSE TO EVALUATE:\n\
111 {response}\n\n\
112 Scoring guide:\n\
113 - 1.0: Every claim is directly supported by the context\n\
114 - 0.75: Most claims are supported, minor unsupported details\n\
115 - 0.5: Mix of supported and unsupported claims\n\
116 - 0.25: Mostly unsupported claims with some grounded elements\n\
117 - 0.0: Entirely fabricated or contradicts the context\n\n\
118 Respond with ONLY a JSON object:\n\
119 {{\"score\": <float>, \"hallucinated_claims\": [\"<claim1>\", ...], \"explanation\": \"<text>\"}}"
120 )
121 }
122
123 fn parse_response(text: &str) -> (f64, String) {
125 if let Ok(v) = serde_json::from_str::<serde_json::Value>(text) {
127 return extract_score_and_explanation(&v);
128 }
129
130 if let Some(start) = text.find('{') {
132 if let Some(end) = text[start..].rfind('}') {
133 let json_str = &text[start..=start + end];
134 if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
135 return extract_score_and_explanation(&v);
136 }
137 }
138 }
139
140 (
141 0.0,
142 format!("Failed to parse hallucination judge response: {text}"),
143 )
144 }
145}
146
147impl Default for HallucinationEvaluator {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153fn extract_score_and_explanation(v: &serde_json::Value) -> (f64, String) {
155 let score = v["score"].as_f64().unwrap_or(0.0).clamp(0.0, 1.0);
156
157 let mut explanation = v["explanation"]
158 .as_str()
159 .unwrap_or("No explanation")
160 .to_string();
161
162 if let Some(claims) = v["hallucinated_claims"].as_array() {
164 let claim_strs: Vec<&str> = claims.iter().filter_map(|c| c.as_str()).collect();
165 if !claim_strs.is_empty() {
166 explanation.push_str(&format!(
167 " | Hallucinated claims: {}",
168 claim_strs.join("; ")
169 ));
170 }
171 }
172
173 (score, explanation)
174}
175
176#[async_trait]
177impl Evaluator for HallucinationEvaluator {
178 async fn evaluate(
179 &self,
180 actual: &[Invocation],
181 _expected: Option<&[Invocation]>,
182 ) -> Result<EvalResult, EvalError> {
183 let llm = self
184 .llm
185 .as_ref()
186 .ok_or_else(|| EvalError::Llm("HallucinationEvaluator requires an LLM instance — call .with_llm() before evaluating".into()))?;
187
188 let mut per_invocation = Vec::new();
189 let mut total_score = 0.0;
190
191 for (i, actual_inv) in actual.iter().enumerate() {
192 let context = Self::extract_context(actual_inv);
193 let responses = Self::extract_responses(actual_inv, self.evaluate_intermediate);
194
195 if responses.is_empty() {
196 per_invocation.push(PerInvocationResult {
198 invocation_id: inv_id(actual_inv, i),
199 score: 1.0,
200 explanation: Some("No model responses to evaluate".into()),
201 });
202 total_score += 1.0;
203 continue;
204 }
205
206 let mut resp_total = 0.0;
208 let mut explanations = Vec::new();
209
210 for response in &responses {
211 let prompt = Self::build_prompt(&context, response);
212 let request = crate::llm::LlmRequest::from_text(&prompt);
213 let llm_response = llm
214 .generate(request)
215 .await
216 .map_err(|e| EvalError::Llm(e.to_string()))?;
217
218 let (score, explanation) = Self::parse_response(&llm_response.text());
219 resp_total += score;
220 explanations.push(explanation);
221 }
222
223 let avg_score = resp_total / responses.len() as f64;
224 total_score += avg_score;
225
226 per_invocation.push(PerInvocationResult {
227 invocation_id: inv_id(actual_inv, i),
228 score: avg_score,
229 explanation: Some(explanations.join(" | ")),
230 });
231 }
232
233 let overall_score = if actual.is_empty() {
234 0.0
235 } else {
236 total_score / actual.len() as f64
237 };
238
239 Ok(EvalResult {
240 overall_score,
241 metrics: vec![EvalMetric {
242 name: "groundedness".into(),
243 score: overall_score,
244 per_invocation,
245 }],
246 })
247 }
248}
249
250fn inv_id(inv: &Invocation, index: usize) -> String {
252 if inv.id.is_empty() {
253 format!("inv-{index}")
254 } else {
255 inv.id.clone()
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;

    /// Builds a turn that carries only text (no tool calls or results).
    fn text_turn(role: &str, content: &str) -> InvocationTurn {
        InvocationTurn {
            role: role.into(),
            content: content.into(),
            tool_calls: vec![],
            tool_results: vec![],
        }
    }

    /// Wraps turns in an invocation with id `"test"` and null metadata.
    fn invocation(turns: Vec<InvocationTurn>) -> Invocation {
        Invocation {
            id: "test".into(),
            turns,
            metadata: serde_json::Value::Null,
        }
    }

    /// Two consecutive non-empty model turns: "first" then "second".
    fn two_model_turns() -> Invocation {
        invocation(vec![
            text_turn("model", "first"),
            text_turn("model", "second"),
        ])
    }

    #[test]
    fn extract_context_includes_user_and_tools() {
        let tool_turn = InvocationTurn {
            role: "model".into(),
            content: String::new(),
            tool_calls: vec![serde_json::json!({"name": "get_weather"})],
            tool_results: vec![serde_json::json!({"temp": 22})],
        };
        let inv = invocation(vec![
            text_turn("user", "What is the weather?"),
            tool_turn,
            text_turn("model", "It's 22 degrees."),
        ]);

        let context = HallucinationEvaluator::extract_context(&inv);
        assert!(context.contains("What is the weather?"));
        assert!(context.contains("22"));
    }

    #[test]
    fn extract_responses_final_only() {
        let responses = HallucinationEvaluator::extract_responses(&two_model_turns(), false);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0], "second");
    }

    #[test]
    fn extract_responses_all() {
        let responses = HallucinationEvaluator::extract_responses(&two_model_turns(), true);
        assert_eq!(responses.len(), 2);
    }

    #[test]
    fn parse_valid_response() {
        let json = r#"{"score": 0.9, "hallucinated_claims": [], "explanation": "Well grounded"}"#;
        let (score, explanation) = HallucinationEvaluator::parse_response(json);
        assert!((score - 0.9).abs() < f64::EPSILON);
        assert!(explanation.contains("Well grounded"));
    }

    #[test]
    fn parse_response_with_claims() {
        let json = r#"{"score": 0.5, "hallucinated_claims": ["temp was 25 not 22"], "explanation": "Partial"}"#;
        let (score, explanation) = HallucinationEvaluator::parse_response(json);
        assert!((score - 0.5).abs() < f64::EPSILON);
        assert!(explanation.contains("temp was 25 not 22"));
    }

    #[test]
    fn parse_invalid() {
        let (score, explanation) = HallucinationEvaluator::parse_response("garbage");
        assert!((score - 0.0).abs() < f64::EPSILON);
        assert!(explanation.contains("Failed to parse"));
    }

    #[test]
    fn default_impl() {
        let eval = HallucinationEvaluator::default();
        assert!(!eval.evaluate_intermediate);
        assert!(eval.judge_model.is_none());
    }

    #[test]
    fn builder_methods() {
        let eval = HallucinationEvaluator::new()
            .with_intermediate(true)
            .with_judge_model("gemini-2.0-flash");
        assert!(eval.evaluate_intermediate);
        assert_eq!(eval.judge_model.as_deref(), Some("gemini-2.0-flash"));
    }

    #[test]
    fn build_prompt_structure() {
        let prompt = HallucinationEvaluator::build_prompt(
            "USER INPUT: What is 2+2?\nTOOL RESULT: {\"answer\": 4}",
            "The answer is 4.",
        );
        for expected in [
            "GROUNDEDNESS",
            "GROUNDING CONTEXT",
            "What is 2+2?",
            "The answer is 4.",
        ] {
            assert!(prompt.contains(expected));
        }
    }
}