gemini_adk_rs/evaluation/
response_evaluator.rs

1//! Response evaluator — evaluates final response quality.
2//!
3//! Compares the agent's final response text against expected output
4//! using configurable matching strategies.
5
6use async_trait::async_trait;
7
8use super::eval_case::Invocation;
9use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
10use super::evaluator::{EvalError, Evaluator};
11
/// Strategy for comparing actual vs. expected responses.
///
/// The default strategy is [`MatchStrategy::Contains`]. Strategies can be
/// compared with `==`, which is convenient in tests and configuration code.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum MatchStrategy {
    /// Exact, case-sensitive string equality.
    Exact,
    /// Case-insensitive containment: one response text contains the other.
    #[default]
    Contains,
    /// Fuzzy match using a normalized Levenshtein edit distance.
    Fuzzy {
        /// Minimum similarity threshold (0.0–1.0).
        threshold: f64,
    },
}
26
/// Evaluates the agent's final response against expected output.
///
/// Holds the comparison strategy and the name under which the resulting
/// metric is reported.
#[derive(Debug, Clone)]
pub struct ResponseEvaluator {
    /// Strategy used to compare actual and expected response text.
    strategy: MatchStrategy,
    /// Name given to the metric emitted in the evaluation result.
    metric_name: String,
}
33
34impl ResponseEvaluator {
35    /// Create a new response evaluator with the given matching strategy.
36    pub fn new(strategy: MatchStrategy) -> Self {
37        Self {
38            strategy,
39            metric_name: "response_match".into(),
40        }
41    }
42
43    /// Set a custom metric name.
44    pub fn with_metric_name(mut self, name: impl Into<String>) -> Self {
45        self.metric_name = name.into();
46        self
47    }
48
49    /// Get the final model response from an invocation.
50    fn last_model_response(inv: &Invocation) -> Option<&str> {
51        inv.turns
52            .iter()
53            .rev()
54            .find(|t| t.role == "model")
55            .map(|t| t.content.as_str())
56    }
57
58    /// Score a single pair of actual/expected responses.
59    fn score_pair(&self, actual: &str, expected: &str) -> (f64, String) {
60        match self.strategy {
61            MatchStrategy::Exact => {
62                if actual == expected {
63                    (1.0, "Exact match".into())
64                } else {
65                    (0.0, "No exact match".into())
66                }
67            }
68            MatchStrategy::Contains => {
69                let actual_lower = actual.to_lowercase();
70                let expected_lower = expected.to_lowercase();
71                if actual_lower.contains(&expected_lower) || expected_lower.contains(&actual_lower)
72                {
73                    (1.0, "Contains match".into())
74                } else {
75                    (0.0, "No containment match".into())
76                }
77            }
78            MatchStrategy::Fuzzy { threshold } => {
79                let similarity = string_similarity(actual, expected);
80                if similarity >= threshold {
81                    (similarity, format!("Fuzzy match: {similarity:.2}"))
82                } else {
83                    (
84                        similarity,
85                        format!("Below threshold {threshold:.2}: {similarity:.2}"),
86                    )
87                }
88            }
89        }
90    }
91}
92
93impl Default for ResponseEvaluator {
94    fn default() -> Self {
95        Self::new(MatchStrategy::default())
96    }
97}
98
99#[async_trait]
100impl Evaluator for ResponseEvaluator {
101    async fn evaluate(
102        &self,
103        actual: &[Invocation],
104        expected: Option<&[Invocation]>,
105    ) -> Result<EvalResult, EvalError> {
106        let expected = expected.ok_or_else(|| {
107            EvalError::InvalidInput("ResponseEvaluator requires expected invocations".into())
108        })?;
109
110        let mut per_invocation = Vec::new();
111        let mut total_score = 0.0;
112
113        for (i, actual_inv) in actual.iter().enumerate() {
114            let actual_resp = Self::last_model_response(actual_inv).unwrap_or("");
115            let expected_resp = expected
116                .get(i)
117                .and_then(|e| Self::last_model_response(e))
118                .unwrap_or("");
119
120            let (score, explanation) = self.score_pair(actual_resp, expected_resp);
121            total_score += score;
122
123            per_invocation.push(PerInvocationResult {
124                invocation_id: if actual_inv.id.is_empty() {
125                    format!("inv-{i}")
126                } else {
127                    actual_inv.id.clone()
128                },
129                score,
130                explanation: Some(explanation),
131            });
132        }
133
134        let overall_score = if actual.is_empty() {
135            0.0
136        } else {
137            total_score / actual.len() as f64
138        };
139
140        Ok(EvalResult {
141            overall_score,
142            metrics: vec![EvalMetric {
143                name: self.metric_name.clone(),
144                score: overall_score,
145                per_invocation,
146            }],
147        })
148    }
149}
150
151/// Simple character-based similarity (normalized Levenshtein-like).
152fn string_similarity(a: &str, b: &str) -> f64 {
153    if a.is_empty() && b.is_empty() {
154        return 1.0;
155    }
156    let max_len = a.len().max(b.len()) as f64;
157    if max_len == 0.0 {
158        return 1.0;
159    }
160
161    let distance = levenshtein_distance(a, b) as f64;
162    1.0 - (distance / max_len)
163}
164
/// Compute the Levenshtein edit distance between `a` and `b`.
///
/// The distance is the minimum number of single-`char` insertions,
/// deletions and substitutions needed to turn one string into the other.
/// Only two rows of the dynamic-programming table are kept, sized by the
/// shorter input, so memory use is linear instead of quadratic.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();

    // The distance is symmetric, so put the shorter string on the row axis
    // to keep the two row buffers as small as possible.
    let (outer, inner) = if a_chars.len() >= b_chars.len() {
        (&a_chars, &b_chars)
    } else {
        (&b_chars, &a_chars)
    };

    // Turning a string into the empty string takes one deletion per char.
    if inner.is_empty() {
        return outer.len();
    }

    // `prev[j]` is the distance between the processed prefix of `outer`
    // and the first `j` chars of `inner`; it starts as the empty prefix.
    let mut prev: Vec<usize> = (0..=inner.len()).collect();
    let mut curr: Vec<usize> = vec![0; inner.len() + 1];

    for (i, &outer_ch) in outer.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &inner_ch) in inner.iter().enumerate() {
            let cost = usize::from(outer_ch != inner_ch);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + cost); // substitution (or match)
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[inner.len()]
}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;

    /// Builds a two-turn invocation: a fixed user question followed by the
    /// given model reply.
    fn make_invocation(model_response: &str) -> Invocation {
        let user_turn = InvocationTurn {
            role: "user".into(),
            content: "What is 2+2?".into(),
            tool_calls: vec![],
            tool_results: vec![],
        };
        let model_turn = InvocationTurn {
            role: "model".into(),
            content: model_response.into(),
            tool_calls: vec![],
            tool_results: vec![],
        };
        Invocation {
            id: String::new(),
            turns: vec![user_turn, model_turn],
            metadata: serde_json::Value::Null,
        }
    }

    /// Evaluates one actual/expected reply pair and returns the overall score.
    async fn overall_score(strategy: MatchStrategy, actual: &str, expected: &str) -> f64 {
        let actual = vec![make_invocation(actual)];
        let expected = vec![make_invocation(expected)];
        let result = ResponseEvaluator::new(strategy)
            .evaluate(&actual, Some(&expected))
            .await
            .unwrap();
        result.overall_score
    }

    /// Asserts that two scores are equal within `f64::EPSILON`.
    fn assert_score(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn exact_match() {
        let score = overall_score(MatchStrategy::Exact, "4", "4").await;
        assert_score(score, 1.0);
    }

    #[tokio::test]
    async fn exact_mismatch() {
        let score = overall_score(MatchStrategy::Exact, "four", "4").await;
        assert_score(score, 0.0);
    }

    #[tokio::test]
    async fn contains_match() {
        let score = overall_score(MatchStrategy::Contains, "The answer is 4", "4").await;
        assert_score(score, 1.0);
    }

    #[tokio::test]
    async fn fuzzy_match() {
        let strategy = MatchStrategy::Fuzzy { threshold: 0.5 };
        let score = overall_score(strategy, "hello world", "hello worl").await;
        assert!(score > 0.5);
    }

    #[tokio::test]
    async fn requires_expected() {
        let actual = vec![make_invocation("test")];
        let outcome = ResponseEvaluator::default().evaluate(&actual, None).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein_distance("abc", "abc"), 0);
    }

    #[test]
    fn levenshtein_one_edit() {
        assert_eq!(levenshtein_distance("abc", "ab"), 1);
    }

    #[test]
    fn similarity_identical() {
        assert_score(string_similarity("hello", "hello"), 1.0);
    }

    #[test]
    fn similarity_empty() {
        assert_score(string_similarity("", ""), 1.0);
    }
}