gemini_adk_rs/evaluation/
trajectory_evaluator.rs1use async_trait::async_trait;
6
7use super::eval_case::Invocation;
8use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
9use super::evaluator::{EvalError, Evaluator};
10
/// Scores how closely the tool calls made during an invocation match the
/// tool calls of a reference invocation.
#[derive(Clone, Debug)]
pub struct TrajectoryEvaluator {
    /// Selects the comparison mode: when `true`, call order matters and the
    /// score is based on the longest common subsequence of tool names; when
    /// `false`, the tool names are compared as unordered sets.
    pub strict_order: bool,
    /// Name under which the resulting metric is reported.
    metric_name: String,
}
21
22impl TrajectoryEvaluator {
23 pub fn new(strict_order: bool) -> Self {
25 Self {
26 strict_order,
27 metric_name: "trajectory_match".into(),
28 }
29 }
30
31 pub fn with_metric_name(mut self, name: impl Into<String>) -> Self {
33 self.metric_name = name.into();
34 self
35 }
36
37 fn extract_tool_names(inv: &Invocation) -> Vec<String> {
39 inv.turns
40 .iter()
41 .flat_map(|turn| {
42 turn.tool_calls
43 .iter()
44 .filter_map(|tc| tc.get("name").and_then(|n| n.as_str()).map(String::from))
45 })
46 .collect()
47 }
48
49 fn score_trajectory(&self, actual: &[String], expected: &[String]) -> (f64, String) {
51 if expected.is_empty() && actual.is_empty() {
52 return (1.0, "Both empty — trivially matching".into());
53 }
54
55 if expected.is_empty() {
56 return (1.0, "No expected tools — any trajectory acceptable".into());
57 }
58
59 if self.strict_order {
60 let lcs_len = lcs_length(actual, expected);
62 let max_len = actual.len().max(expected.len());
63 let score = if max_len == 0 {
64 1.0
65 } else {
66 lcs_len as f64 / max_len as f64
67 };
68 (
69 score,
70 format!(
71 "Strict order: LCS {lcs_len}/{max_len} (actual={}, expected={})",
72 actual.len(),
73 expected.len()
74 ),
75 )
76 } else {
77 let expected_set: std::collections::HashSet<&str> =
79 expected.iter().map(|s| s.as_str()).collect();
80 let actual_set: std::collections::HashSet<&str> =
81 actual.iter().map(|s| s.as_str()).collect();
82
83 let intersection = expected_set.intersection(&actual_set).count();
84 let union = expected_set.union(&actual_set).count();
85 let score = if union == 0 {
86 1.0
87 } else {
88 intersection as f64 / union as f64
89 };
90 (
91 score,
92 format!("Set match: {intersection}/{union} tools overlap"),
93 )
94 }
95 }
96}
97
98impl Default for TrajectoryEvaluator {
99 fn default() -> Self {
100 Self::new(true)
101 }
102}
103
104#[async_trait]
105impl Evaluator for TrajectoryEvaluator {
106 async fn evaluate(
107 &self,
108 actual: &[Invocation],
109 expected: Option<&[Invocation]>,
110 ) -> Result<EvalResult, EvalError> {
111 let expected = expected.ok_or_else(|| {
112 EvalError::InvalidInput("TrajectoryEvaluator requires expected invocations".into())
113 })?;
114
115 let mut per_invocation = Vec::new();
116 let mut total_score = 0.0;
117
118 for (i, actual_inv) in actual.iter().enumerate() {
119 let actual_tools = Self::extract_tool_names(actual_inv);
120 let expected_tools = expected
121 .get(i)
122 .map(Self::extract_tool_names)
123 .unwrap_or_default();
124
125 let (score, explanation) = self.score_trajectory(&actual_tools, &expected_tools);
126 total_score += score;
127
128 per_invocation.push(PerInvocationResult {
129 invocation_id: if actual_inv.id.is_empty() {
130 format!("inv-{i}")
131 } else {
132 actual_inv.id.clone()
133 },
134 score,
135 explanation: Some(explanation),
136 });
137 }
138
139 let overall_score = if actual.is_empty() {
140 0.0
141 } else {
142 total_score / actual.len() as f64
143 };
144
145 Ok(EvalResult {
146 overall_score,
147 metrics: vec![EvalMetric {
148 name: self.metric_name.clone(),
149 score: overall_score,
150 per_invocation,
151 }],
152 })
153 }
154}
155
/// Returns the length of the longest common subsequence of `a` and `b`.
///
/// A subsequence keeps the relative order of elements but need not be
/// contiguous. The result is `0` when either slice is empty and is symmetric
/// in its arguments.
///
/// Runs the standard dynamic program in `O(a.len() * b.len())` time, keeping
/// only two rows of the table so that memory is proportional to the shorter
/// input.
fn lcs_length<T: PartialEq>(a: &[T], b: &[T]) -> usize {
    // LCS is symmetric, so index the rows by the shorter slice to keep them small.
    let (outer, inner) = if a.len() >= b.len() { (a, b) } else { (b, a) };
    if inner.is_empty() {
        return 0;
    }

    // `prev[j]` holds the LCS length of the already-processed prefix of
    // `outer` with `inner[..j]`; `curr` is the row being filled in.
    // Index 0 of both rows stays 0 (empty prefix of `inner`).
    let mut prev = vec![0usize; inner.len() + 1];
    let mut curr = vec![0usize; inner.len() + 1];

    for x in outer {
        for (j, y) in inner.iter().enumerate() {
            curr[j + 1] = if x == y {
                prev[j] + 1
            } else {
                prev[j + 1].max(curr[j])
            };
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last completed row lives in `prev`.
    prev[inner.len()]
}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;
    use serde_json::json;

    /// Builds an invocation with an empty id and a single `"model"` turn that
    /// calls each of `tool_names` in order with empty arguments.
    fn make_invocation_with_tools(tool_names: &[&str]) -> Invocation {
        let tool_calls = tool_names
            .iter()
            .map(|name| json!({"name": name, "args": {}}))
            .collect();
        let turn = InvocationTurn {
            role: "model".into(),
            content: String::new(),
            tool_calls,
            tool_results: vec![],
        };
        Invocation {
            id: String::new(),
            turns: vec![turn],
            metadata: serde_json::Value::Null,
        }
    }

    /// Returns `true` when `score` is within `f64::EPSILON` of 1.0.
    fn is_perfect(score: f64) -> bool {
        (score - 1.0).abs() < f64::EPSILON
    }

    /// Converts string literals into the owned names `lcs_length` is fed.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    // Identical sequences score 1.0 in strict-order mode.
    #[tokio::test]
    async fn strict_order_perfect_match() {
        let evaluator = TrajectoryEvaluator::new(true);
        let actual = [make_invocation_with_tools(&["search", "lookup"])];
        let expected = [make_invocation_with_tools(&["search", "lookup"])];

        let outcome = evaluator.evaluate(&actual, Some(&expected)).await;
        assert!(is_perfect(outcome.unwrap().overall_score));
    }

    // Set mode ignores the order in which the tools were called.
    #[tokio::test]
    async fn set_match_unordered() {
        let evaluator = TrajectoryEvaluator::new(false);
        let actual = [make_invocation_with_tools(&["lookup", "search"])];
        let expected = [make_invocation_with_tools(&["search", "lookup"])];

        let outcome = evaluator.evaluate(&actual, Some(&expected)).await;
        assert!(is_perfect(outcome.unwrap().overall_score));
    }

    // Calling only some of the expected tools scores strictly between 0 and 1.
    #[tokio::test]
    async fn partial_match() {
        let evaluator = TrajectoryEvaluator::new(false);
        let actual = [make_invocation_with_tools(&["search"])];
        let expected = [make_invocation_with_tools(&["search", "lookup"])];

        let overall = evaluator
            .evaluate(&actual, Some(&expected))
            .await
            .unwrap()
            .overall_score;
        assert!(overall > 0.0);
        assert!(overall < 1.0);
    }

    // Evaluating without expected invocations is an error.
    #[tokio::test]
    async fn requires_expected() {
        let evaluator = TrajectoryEvaluator::default();
        let actual = [make_invocation_with_tools(&["search"])];

        let outcome = evaluator.evaluate(&actual, None).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn lcs_identical() {
        let left = owned(&["a", "b", "c"]);
        let right = owned(&["a", "b", "c"]);
        assert_eq!(lcs_length(&left, &right), 3);
    }

    #[test]
    fn lcs_different() {
        let left = owned(&["a", "b"]);
        let right = owned(&["c", "d"]);
        assert_eq!(lcs_length(&left, &right), 0);
    }
}