gemini_adk_rs/evaluation/
hallucination_evaluator.rs

1//! Hallucination evaluator — check groundedness of agent responses.
2//!
3//! Evaluates whether agent responses are grounded in the provided context
4//! (tool results, user input, conversation history) or contain fabricated
5//! information.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use super::eval_case::Invocation;
12use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
13use super::evaluator::{EvalError, Evaluator};
14use crate::llm::BaseLlm;
15
16/// Evaluates whether agent responses are grounded (not hallucinated).
17///
18/// Uses an LLM-as-judge to assess whether the model's claims are supported
19/// by the conversation context, tool outputs, and provided information.
20pub struct HallucinationEvaluator {
21    /// Optional override for the judge model.
22    judge_model: Option<String>,
23    /// Whether to also evaluate intermediate responses (not just the final one).
24    evaluate_intermediate: bool,
25    /// Optional LLM for performing evaluations.
26    llm: Option<Arc<dyn BaseLlm>>,
27}
28
29impl HallucinationEvaluator {
30    /// Create a new hallucination evaluator.
31    pub fn new() -> Self {
32        Self {
33            judge_model: None,
34            evaluate_intermediate: false,
35            llm: None,
36        }
37    }
38
39    /// Set whether to evaluate intermediate responses in addition to the final response.
40    pub fn with_intermediate(mut self, eval: bool) -> Self {
41        self.evaluate_intermediate = eval;
42        self
43    }
44
45    /// Set an override judge model name.
46    pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
47        self.judge_model = Some(model.into());
48        self
49    }
50
51    /// Provide an LLM instance for performing evaluations.
52    pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
53        self.llm = Some(llm);
54        self
55    }
56
57    /// Extract grounding context from an invocation.
58    ///
59    /// Collects user inputs and tool results as the "source of truth" that
60    /// model responses should be grounded in.
61    fn extract_context(inv: &Invocation) -> String {
62        let mut context = String::new();
63
64        for turn in &inv.turns {
65            match turn.role.as_str() {
66                "user" => {
67                    context.push_str(&format!("USER INPUT: {}\n", turn.content));
68                }
69                "model" if !turn.tool_results.is_empty() => {
70                    for result in &turn.tool_results {
71                        context.push_str(&format!("TOOL RESULT: {}\n", result));
72                    }
73                }
74                _ => {}
75            }
76        }
77
78        context
79    }
80
81    /// Extract model responses to evaluate for groundedness.
82    fn extract_responses(inv: &Invocation, include_intermediate: bool) -> Vec<String> {
83        let model_turns: Vec<&str> = inv
84            .turns
85            .iter()
86            .filter(|t| t.role == "model" && !t.content.is_empty())
87            .map(|t| t.content.as_str())
88            .collect();
89
90        if include_intermediate {
91            model_turns.into_iter().map(String::from).collect()
92        } else {
93            // Only the last model response
94            model_turns
95                .last()
96                .map(|s| vec![s.to_string()])
97                .unwrap_or_default()
98        }
99    }
100
101    /// Build the groundedness evaluation prompt.
102    fn build_prompt(context: &str, response: &str) -> String {
103        format!(
104            "You are an expert evaluator assessing GROUNDEDNESS (absence of hallucination).\n\n\
105             Your task: determine whether the agent's response is fully supported by the \
106             provided context. A response is grounded if every factual claim it makes can \
107             be traced back to information in the context.\n\n\
108             GROUNDING CONTEXT (source of truth):\n\
109             {context}\n\n\
110             AGENT RESPONSE TO EVALUATE:\n\
111             {response}\n\n\
112             Scoring guide:\n\
113             - 1.0: Every claim is directly supported by the context\n\
114             - 0.75: Most claims are supported, minor unsupported details\n\
115             - 0.5: Mix of supported and unsupported claims\n\
116             - 0.25: Mostly unsupported claims with some grounded elements\n\
117             - 0.0: Entirely fabricated or contradicts the context\n\n\
118             Respond with ONLY a JSON object:\n\
119             {{\"score\": <float>, \"hallucinated_claims\": [\"<claim1>\", ...], \"explanation\": \"<text>\"}}"
120        )
121    }
122
123    /// Parse the judge response for a groundedness score.
124    fn parse_response(text: &str) -> (f64, String) {
125        // Try direct JSON parse
126        if let Ok(v) = serde_json::from_str::<serde_json::Value>(text) {
127            return extract_score_and_explanation(&v);
128        }
129
130        // Try finding embedded JSON
131        if let Some(start) = text.find('{') {
132            if let Some(end) = text[start..].rfind('}') {
133                let json_str = &text[start..=start + end];
134                if let Ok(v) = serde_json::from_str::<serde_json::Value>(json_str) {
135                    return extract_score_and_explanation(&v);
136                }
137            }
138        }
139
140        (
141            0.0,
142            format!("Failed to parse hallucination judge response: {text}"),
143        )
144    }
145}
146
147impl Default for HallucinationEvaluator {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153/// Extract score and explanation from a parsed JSON value.
154fn extract_score_and_explanation(v: &serde_json::Value) -> (f64, String) {
155    let score = v["score"].as_f64().unwrap_or(0.0).clamp(0.0, 1.0);
156
157    let mut explanation = v["explanation"]
158        .as_str()
159        .unwrap_or("No explanation")
160        .to_string();
161
162    // Append hallucinated claims if present
163    if let Some(claims) = v["hallucinated_claims"].as_array() {
164        let claim_strs: Vec<&str> = claims.iter().filter_map(|c| c.as_str()).collect();
165        if !claim_strs.is_empty() {
166            explanation.push_str(&format!(
167                " | Hallucinated claims: {}",
168                claim_strs.join("; ")
169            ));
170        }
171    }
172
173    (score, explanation)
174}
175
176#[async_trait]
177impl Evaluator for HallucinationEvaluator {
178    async fn evaluate(
179        &self,
180        actual: &[Invocation],
181        _expected: Option<&[Invocation]>,
182    ) -> Result<EvalResult, EvalError> {
183        let llm = self
184            .llm
185            .as_ref()
186            .ok_or_else(|| EvalError::Llm("HallucinationEvaluator requires an LLM instance — call .with_llm() before evaluating".into()))?;
187
188        let mut per_invocation = Vec::new();
189        let mut total_score = 0.0;
190
191        for (i, actual_inv) in actual.iter().enumerate() {
192            let context = Self::extract_context(actual_inv);
193            let responses = Self::extract_responses(actual_inv, self.evaluate_intermediate);
194
195            if responses.is_empty() {
196                // No model responses to evaluate — trivially grounded
197                per_invocation.push(PerInvocationResult {
198                    invocation_id: inv_id(actual_inv, i),
199                    score: 1.0,
200                    explanation: Some("No model responses to evaluate".into()),
201                });
202                total_score += 1.0;
203                continue;
204            }
205
206            // Evaluate each response and average
207            let mut resp_total = 0.0;
208            let mut explanations = Vec::new();
209
210            for response in &responses {
211                let prompt = Self::build_prompt(&context, response);
212                let request = crate::llm::LlmRequest::from_text(&prompt);
213                let llm_response = llm
214                    .generate(request)
215                    .await
216                    .map_err(|e| EvalError::Llm(e.to_string()))?;
217
218                let (score, explanation) = Self::parse_response(&llm_response.text());
219                resp_total += score;
220                explanations.push(explanation);
221            }
222
223            let avg_score = resp_total / responses.len() as f64;
224            total_score += avg_score;
225
226            per_invocation.push(PerInvocationResult {
227                invocation_id: inv_id(actual_inv, i),
228                score: avg_score,
229                explanation: Some(explanations.join(" | ")),
230            });
231        }
232
233        let overall_score = if actual.is_empty() {
234            0.0
235        } else {
236            total_score / actual.len() as f64
237        };
238
239        Ok(EvalResult {
240            overall_score,
241            metrics: vec![EvalMetric {
242                name: "groundedness".into(),
243                score: overall_score,
244                per_invocation,
245            }],
246        })
247    }
248}
249
250/// Helper to get a meaningful invocation ID.
251fn inv_id(inv: &Invocation, index: usize) -> String {
252    if inv.id.is_empty() {
253        format!("inv-{index}")
254    } else {
255        inv.id.clone()
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;

    /// Build a text-only turn with no tool calls or tool results.
    fn text_turn(role: &str, content: &str) -> InvocationTurn {
        InvocationTurn {
            role: role.into(),
            content: content.into(),
            tool_calls: vec![],
            tool_results: vec![],
        }
    }

    /// Wrap turns in an invocation with a fixed id and null metadata.
    fn invocation(turns: Vec<InvocationTurn>) -> Invocation {
        Invocation {
            id: "test".into(),
            turns,
            metadata: serde_json::Value::Null,
        }
    }

    #[test]
    fn extract_context_includes_user_and_tools() {
        let inv = invocation(vec![
            text_turn("user", "What is the weather?"),
            InvocationTurn {
                role: "model".into(),
                content: String::new(),
                tool_calls: vec![serde_json::json!({"name": "get_weather"})],
                tool_results: vec![serde_json::json!({"temp": 22})],
            },
            text_turn("model", "It's 22 degrees."),
        ]);

        let context = HallucinationEvaluator::extract_context(&inv);
        assert!(context.contains("USER INPUT: What is the weather?"));
        assert!(context.contains("TOOL RESULT: "));
        assert!(context.contains("22"));
        // The model's own wording must not become part of the source of truth.
        assert!(!context.contains("It's 22 degrees."));
    }

    #[test]
    fn extract_responses_final_only() {
        let inv = invocation(vec![
            text_turn("model", "first"),
            text_turn("model", "second"),
        ]);

        let responses = HallucinationEvaluator::extract_responses(&inv, false);
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0], "second");
    }

    #[test]
    fn extract_responses_all() {
        let inv = invocation(vec![
            text_turn("model", "first"),
            text_turn("model", "second"),
        ]);

        let responses = HallucinationEvaluator::extract_responses(&inv, true);
        assert_eq!(responses.len(), 2);
    }

    #[test]
    fn extract_responses_skips_empty_and_non_model_turns() {
        let inv = invocation(vec![
            text_turn("user", "hi"),
            text_turn("model", ""),
            text_turn("model", "answer"),
            text_turn("user", "thanks"),
        ]);

        assert_eq!(
            HallucinationEvaluator::extract_responses(&inv, true),
            vec!["answer".to_string()]
        );
        assert_eq!(
            HallucinationEvaluator::extract_responses(&inv, false),
            vec!["answer".to_string()]
        );
    }

    #[test]
    fn parse_valid_response() {
        let json = r#"{"score": 0.9, "hallucinated_claims": [], "explanation": "Well grounded"}"#;
        let (score, explanation) = HallucinationEvaluator::parse_response(json);
        assert!((score - 0.9).abs() < f64::EPSILON);
        assert!(explanation.contains("Well grounded"));
    }

    #[test]
    fn parse_response_with_claims() {
        let json = r#"{"score": 0.5, "hallucinated_claims": ["temp was 25 not 22"], "explanation": "Partial"}"#;
        let (score, explanation) = HallucinationEvaluator::parse_response(json);
        assert!((score - 0.5).abs() < f64::EPSILON);
        assert!(explanation.contains("temp was 25 not 22"));
    }

    #[test]
    fn parse_embedded_json() {
        let reply = r#"Here is my verdict: {"score": 0.75, "explanation": "Mostly fine"} Thanks."#;
        let (score, explanation) = HallucinationEvaluator::parse_response(reply);
        assert!((score - 0.75).abs() < f64::EPSILON);
        assert!(explanation.contains("Mostly fine"));
    }

    #[test]
    fn parse_clamps_out_of_range_scores() {
        let (high, _) =
            HallucinationEvaluator::parse_response(r#"{"score": 1.7, "explanation": "x"}"#);
        assert!((high - 1.0).abs() < f64::EPSILON);

        let (low, _) =
            HallucinationEvaluator::parse_response(r#"{"score": -2, "explanation": "x"}"#);
        assert!((low - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_invalid() {
        let (score, explanation) = HallucinationEvaluator::parse_response("garbage");
        assert!((score - 0.0).abs() < f64::EPSILON);
        assert!(explanation.contains("Failed to parse"));
    }

    #[test]
    fn default_impl() {
        let eval = HallucinationEvaluator::default();
        assert!(!eval.evaluate_intermediate);
        assert!(eval.judge_model.is_none());
    }

    #[test]
    fn builder_methods() {
        let eval = HallucinationEvaluator::new()
            .with_intermediate(true)
            .with_judge_model("gemini-2.0-flash");
        assert!(eval.evaluate_intermediate);
        assert_eq!(eval.judge_model.as_deref(), Some("gemini-2.0-flash"));
    }

    #[test]
    fn build_prompt_structure() {
        let prompt = HallucinationEvaluator::build_prompt(
            "USER INPUT: What is 2+2?\nTOOL RESULT: {\"answer\": 4}",
            "The answer is 4.",
        );
        assert!(prompt.contains("GROUNDEDNESS"));
        assert!(prompt.contains("GROUNDING CONTEXT"));
        assert!(prompt.contains("What is 2+2?"));
        assert!(prompt.contains("The answer is 4."));
    }

    #[test]
    fn inv_id_falls_back_to_index() {
        let mut inv = invocation(vec![]);
        assert_eq!(inv_id(&inv, 3), "test");

        inv.id = String::new();
        assert_eq!(inv_id(&inv, 3), "inv-3");
    }
}