gemini_adk_rs/evaluation/
response_evaluator.rs1use async_trait::async_trait;
7
8use super::eval_case::Invocation;
9use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
10use super::evaluator::{EvalError, Evaluator};
11
/// How a model response is compared against the expected response.
#[derive(Debug, Clone, Copy, Default)]
pub enum MatchStrategy {
    /// The two responses must be equal, character for character.
    Exact,
    /// Case-insensitive containment: one response must contain the other.
    /// This is the default strategy.
    #[default]
    Contains,
    /// Similarity derived from edit distance; the score is the similarity
    /// itself.
    Fuzzy {
        /// Similarity at or above which the explanation reports a fuzzy
        /// match instead of "below threshold".
        threshold: f64,
    },
}
26
27#[derive(Debug, Clone)]
29pub struct ResponseEvaluator {
30 strategy: MatchStrategy,
31 metric_name: String,
32}
33
34impl ResponseEvaluator {
35 pub fn new(strategy: MatchStrategy) -> Self {
37 Self {
38 strategy,
39 metric_name: "response_match".into(),
40 }
41 }
42
43 pub fn with_metric_name(mut self, name: impl Into<String>) -> Self {
45 self.metric_name = name.into();
46 self
47 }
48
49 fn last_model_response(inv: &Invocation) -> Option<&str> {
51 inv.turns
52 .iter()
53 .rev()
54 .find(|t| t.role == "model")
55 .map(|t| t.content.as_str())
56 }
57
58 fn score_pair(&self, actual: &str, expected: &str) -> (f64, String) {
60 match self.strategy {
61 MatchStrategy::Exact => {
62 if actual == expected {
63 (1.0, "Exact match".into())
64 } else {
65 (0.0, "No exact match".into())
66 }
67 }
68 MatchStrategy::Contains => {
69 let actual_lower = actual.to_lowercase();
70 let expected_lower = expected.to_lowercase();
71 if actual_lower.contains(&expected_lower) || expected_lower.contains(&actual_lower)
72 {
73 (1.0, "Contains match".into())
74 } else {
75 (0.0, "No containment match".into())
76 }
77 }
78 MatchStrategy::Fuzzy { threshold } => {
79 let similarity = string_similarity(actual, expected);
80 if similarity >= threshold {
81 (similarity, format!("Fuzzy match: {similarity:.2}"))
82 } else {
83 (
84 similarity,
85 format!("Below threshold {threshold:.2}: {similarity:.2}"),
86 )
87 }
88 }
89 }
90 }
91}
92
93impl Default for ResponseEvaluator {
94 fn default() -> Self {
95 Self::new(MatchStrategy::default())
96 }
97}
98
99#[async_trait]
100impl Evaluator for ResponseEvaluator {
101 async fn evaluate(
102 &self,
103 actual: &[Invocation],
104 expected: Option<&[Invocation]>,
105 ) -> Result<EvalResult, EvalError> {
106 let expected = expected.ok_or_else(|| {
107 EvalError::InvalidInput("ResponseEvaluator requires expected invocations".into())
108 })?;
109
110 let mut per_invocation = Vec::new();
111 let mut total_score = 0.0;
112
113 for (i, actual_inv) in actual.iter().enumerate() {
114 let actual_resp = Self::last_model_response(actual_inv).unwrap_or("");
115 let expected_resp = expected
116 .get(i)
117 .and_then(|e| Self::last_model_response(e))
118 .unwrap_or("");
119
120 let (score, explanation) = self.score_pair(actual_resp, expected_resp);
121 total_score += score;
122
123 per_invocation.push(PerInvocationResult {
124 invocation_id: if actual_inv.id.is_empty() {
125 format!("inv-{i}")
126 } else {
127 actual_inv.id.clone()
128 },
129 score,
130 explanation: Some(explanation),
131 });
132 }
133
134 let overall_score = if actual.is_empty() {
135 0.0
136 } else {
137 total_score / actual.len() as f64
138 };
139
140 Ok(EvalResult {
141 overall_score,
142 metrics: vec![EvalMetric {
143 name: self.metric_name.clone(),
144 score: overall_score,
145 per_invocation,
146 }],
147 })
148 }
149}
150
151fn string_similarity(a: &str, b: &str) -> f64 {
153 if a.is_empty() && b.is_empty() {
154 return 1.0;
155 }
156 let max_len = a.len().max(b.len()) as f64;
157 if max_len == 0.0 {
158 return 1.0;
159 }
160
161 let distance = levenshtein_distance(a, b) as f64;
162 1.0 - (distance / max_len)
163}
164
/// Computes the Levenshtein edit distance between `a` and `b`, counted in
/// `char`s: the minimum number of single-character insertions, deletions
/// and substitutions that turn one string into the other.
///
/// Only two rows of the dynamic-programming table are kept, sized by the
/// shorter input, so memory is linear rather than quadratic in the input
/// length.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();

    // The distance is symmetric, so put the shorter string along the row to
    // keep the two buffers as small as possible.
    let (outer, inner) = if a_chars.len() >= b_chars.len() {
        (&a_chars, &b_chars)
    } else {
        (&b_chars, &a_chars)
    };
    if inner.is_empty() {
        return outer.len();
    }

    // `prev[j]` is the distance between the processed prefix of `outer` and
    // the first `j` characters of `inner`.
    let mut prev: Vec<usize> = (0..=inner.len()).collect();
    let mut curr = vec![0usize; inner.len() + 1];

    for (i, &outer_ch) in outer.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &inner_ch) in inner.iter().enumerate() {
            let substitution = prev[j] + usize::from(outer_ch != inner_ch);
            let deletion = prev[j + 1] + 1;
            let insertion = curr[j] + 1;
            curr[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last computed row lives in `prev`.
    prev[inner.len()]
}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;

    /// Builds a text-only turn with no tool calls or tool results.
    fn text_turn(role: &'static str, content: &str) -> InvocationTurn {
        InvocationTurn {
            role: role.into(),
            content: content.into(),
            tool_calls: vec![],
            tool_results: vec![],
        }
    }

    /// Builds an invocation with a fixed user question followed by the given
    /// model response.
    fn make_invocation(model_response: &str) -> Invocation {
        Invocation {
            id: String::new(),
            turns: vec![
                text_turn("user", "What is 2+2?"),
                text_turn("model", model_response),
            ],
            metadata: serde_json::Value::Null,
        }
    }

    /// Evaluates a single actual/expected response pair and returns the
    /// overall score.
    async fn overall_score(strategy: MatchStrategy, actual: &str, expected: &str) -> f64 {
        let actual = vec![make_invocation(actual)];
        let expected = vec![make_invocation(expected)];
        ResponseEvaluator::new(strategy)
            .evaluate(&actual, Some(&expected))
            .await
            .unwrap()
            .overall_score
    }

    #[tokio::test]
    async fn exact_match() {
        let score = overall_score(MatchStrategy::Exact, "4", "4").await;
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn exact_mismatch() {
        let score = overall_score(MatchStrategy::Exact, "four", "4").await;
        assert!((score - 0.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn contains_match() {
        let score = overall_score(MatchStrategy::Contains, "The answer is 4", "4").await;
        assert!((score - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn fuzzy_match() {
        let strategy = MatchStrategy::Fuzzy { threshold: 0.5 };
        let score = overall_score(strategy, "hello world", "hello worl").await;
        assert!(score > 0.5);
    }

    #[tokio::test]
    async fn requires_expected() {
        let actual = vec![make_invocation("test")];
        let outcome = ResponseEvaluator::default().evaluate(&actual, None).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein_distance("abc", "abc"), 0);
    }

    #[test]
    fn levenshtein_one_edit() {
        assert_eq!(levenshtein_distance("abc", "ab"), 1);
    }

    #[test]
    fn similarity_identical() {
        let similarity = string_similarity("hello", "hello");
        assert!((similarity - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn similarity_empty() {
        let similarity = string_similarity("", "");
        assert!((similarity - 1.0).abs() < f64::EPSILON);
    }
}