gemini_adk_rs/evaluation/
trajectory_evaluator.rs

1//! Trajectory evaluator — evaluates the tool call trajectory of agent invocations.
2//!
3//! Compares the sequence of tool calls made by the agent against expected trajectories.
4
5use async_trait::async_trait;
6
7use super::eval_case::Invocation;
8use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
9use super::evaluator::{EvalError, Evaluator};
10
/// Evaluates the tool-call trajectory of agent invocations.
///
/// For each invocation, the names of the tool calls the agent actually made are
/// compared with the tool calls of the corresponding expected invocation.
/// Whether call *order* matters depends on [`TrajectoryEvaluator::strict_order`]:
/// in strict mode the sequences are compared by longest common subsequence; otherwise
/// only the sets of distinct tool names are compared.
#[derive(Debug, Clone)]
pub struct TrajectoryEvaluator {
    /// When `true`, score by ordered sequence (LCS); when `false`, score by
    /// overlap of distinct tool names, ignoring order and repetitions.
    pub strict_order: bool,
    /// Name under which the metric is reported in the evaluation result.
    metric_name: String,
}
21
22impl TrajectoryEvaluator {
23    /// Create a new trajectory evaluator.
24    pub fn new(strict_order: bool) -> Self {
25        Self {
26            strict_order,
27            metric_name: "trajectory_match".into(),
28        }
29    }
30
31    /// Set a custom metric name.
32    pub fn with_metric_name(mut self, name: impl Into<String>) -> Self {
33        self.metric_name = name.into();
34        self
35    }
36
37    /// Extract tool call names from an invocation's turns.
38    fn extract_tool_names(inv: &Invocation) -> Vec<String> {
39        inv.turns
40            .iter()
41            .flat_map(|turn| {
42                turn.tool_calls
43                    .iter()
44                    .filter_map(|tc| tc.get("name").and_then(|n| n.as_str()).map(String::from))
45            })
46            .collect()
47    }
48
49    /// Score trajectory match between actual and expected tool call sequences.
50    fn score_trajectory(&self, actual: &[String], expected: &[String]) -> (f64, String) {
51        if expected.is_empty() && actual.is_empty() {
52            return (1.0, "Both empty — trivially matching".into());
53        }
54
55        if expected.is_empty() {
56            return (1.0, "No expected tools — any trajectory acceptable".into());
57        }
58
59        if self.strict_order {
60            // Longest common subsequence ratio
61            let lcs_len = lcs_length(actual, expected);
62            let max_len = actual.len().max(expected.len());
63            let score = if max_len == 0 {
64                1.0
65            } else {
66                lcs_len as f64 / max_len as f64
67            };
68            (
69                score,
70                format!(
71                    "Strict order: LCS {lcs_len}/{max_len} (actual={}, expected={})",
72                    actual.len(),
73                    expected.len()
74                ),
75            )
76        } else {
77            // Set-based: how many expected tools were called
78            let expected_set: std::collections::HashSet<&str> =
79                expected.iter().map(|s| s.as_str()).collect();
80            let actual_set: std::collections::HashSet<&str> =
81                actual.iter().map(|s| s.as_str()).collect();
82
83            let intersection = expected_set.intersection(&actual_set).count();
84            let union = expected_set.union(&actual_set).count();
85            let score = if union == 0 {
86                1.0
87            } else {
88                intersection as f64 / union as f64
89            };
90            (
91                score,
92                format!("Set match: {intersection}/{union} tools overlap"),
93            )
94        }
95    }
96}
97
98impl Default for TrajectoryEvaluator {
99    fn default() -> Self {
100        Self::new(true)
101    }
102}
103
104#[async_trait]
105impl Evaluator for TrajectoryEvaluator {
106    async fn evaluate(
107        &self,
108        actual: &[Invocation],
109        expected: Option<&[Invocation]>,
110    ) -> Result<EvalResult, EvalError> {
111        let expected = expected.ok_or_else(|| {
112            EvalError::InvalidInput("TrajectoryEvaluator requires expected invocations".into())
113        })?;
114
115        let mut per_invocation = Vec::new();
116        let mut total_score = 0.0;
117
118        for (i, actual_inv) in actual.iter().enumerate() {
119            let actual_tools = Self::extract_tool_names(actual_inv);
120            let expected_tools = expected
121                .get(i)
122                .map(Self::extract_tool_names)
123                .unwrap_or_default();
124
125            let (score, explanation) = self.score_trajectory(&actual_tools, &expected_tools);
126            total_score += score;
127
128            per_invocation.push(PerInvocationResult {
129                invocation_id: if actual_inv.id.is_empty() {
130                    format!("inv-{i}")
131                } else {
132                    actual_inv.id.clone()
133                },
134                score,
135                explanation: Some(explanation),
136            });
137        }
138
139        let overall_score = if actual.is_empty() {
140            0.0
141        } else {
142            total_score / actual.len() as f64
143        };
144
145        Ok(EvalResult {
146            overall_score,
147            metrics: vec![EvalMetric {
148                name: self.metric_name.clone(),
149                score: overall_score,
150                per_invocation,
151            }],
152        })
153    }
154}
155
/// Return the length of the longest common subsequence of `a` and `b`.
///
/// Dynamic programming over prefixes. `prev[j]` holds the LCS length of the
/// part of `a` consumed so far against `b[..j]`, and `curr` is the row being
/// built for the next element of `a`. Only two rows are kept, and they swap
/// after each element of `a`. Column 0 is never written and stays `0`
/// (an empty prefix of `b`).
fn lcs_length(a: &[String], b: &[String]) -> usize {
    let width = b.len() + 1;
    let mut prev = vec![0usize; width];
    let mut curr = vec![0usize; width];

    for left in a {
        for (j, right) in b.iter().enumerate() {
            curr[j + 1] = if left == right {
                // Extend the diagonal match.
                prev[j] + 1
            } else {
                // Best of dropping the current element from `a` or from `b`.
                prev[j + 1].max(curr[j])
            };
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[b.len()]
}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;
    use serde_json::json;

    /// Build a single-turn invocation (empty id) whose model turn calls
    /// `tool_names` in the given order.
    fn make_invocation_with_tools(tool_names: &[&str]) -> Invocation {
        Invocation {
            id: String::new(),
            turns: vec![InvocationTurn {
                role: "model".into(),
                content: String::new(),
                tool_calls: tool_names
                    .iter()
                    .map(|name| json!({"name": name, "args": {}}))
                    .collect(),
                tool_results: vec![],
            }],
            metadata: serde_json::Value::Null,
        }
    }

    /// Compare fractional scores with a small tolerance.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-9
    }

    #[tokio::test]
    async fn strict_order_perfect_match() {
        let eval = TrajectoryEvaluator::new(true);
        let actual = vec![make_invocation_with_tools(&["search", "lookup"])];
        let expected = vec![make_invocation_with_tools(&["search", "lookup"])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert!((result.overall_score - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn strict_order_penalizes_reordering() {
        // LCS of [lookup, search] vs [search, lookup] is 1 out of 2.
        let eval = TrajectoryEvaluator::new(true);
        let actual = vec![make_invocation_with_tools(&["lookup", "search"])];
        let expected = vec![make_invocation_with_tools(&["search", "lookup"])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert!(approx_eq(result.overall_score, 0.5));
    }

    #[tokio::test]
    async fn set_match_unordered() {
        let eval = TrajectoryEvaluator::new(false);
        let actual = vec![make_invocation_with_tools(&["lookup", "search"])];
        let expected = vec![make_invocation_with_tools(&["search", "lookup"])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert!((result.overall_score - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn partial_match() {
        // Jaccard: {search} ∩ {search, lookup} = 1, union = 2.
        let eval = TrajectoryEvaluator::new(false);
        let actual = vec![make_invocation_with_tools(&["search"])];
        let expected = vec![make_invocation_with_tools(&["search", "lookup"])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert!(approx_eq(result.overall_score, 0.5));
    }

    #[tokio::test]
    async fn set_match_penalizes_extra_tools() {
        // Jaccard: intersection 2, union 3 — extra tools lower the score.
        let eval = TrajectoryEvaluator::new(false);
        let actual = vec![make_invocation_with_tools(&["search", "lookup", "translate"])];
        let expected = vec![make_invocation_with_tools(&["search", "lookup"])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert!(approx_eq(result.overall_score, 2.0 / 3.0));
    }

    #[tokio::test]
    async fn empty_expected_trajectory_accepts_anything() {
        let eval = TrajectoryEvaluator::new(true);
        let actual = vec![make_invocation_with_tools(&["search"])];
        let expected = vec![make_invocation_with_tools(&[])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert!(approx_eq(result.overall_score, 1.0));
    }

    #[tokio::test]
    async fn custom_metric_name_is_reported() {
        let eval = TrajectoryEvaluator::new(true).with_metric_name("tool_path");
        let actual = vec![make_invocation_with_tools(&["search"])];
        let expected = vec![make_invocation_with_tools(&["search"])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert_eq!(result.metrics.len(), 1);
        assert_eq!(result.metrics[0].name, "tool_path");
    }

    #[tokio::test]
    async fn invocation_id_falls_back_to_index() {
        let eval = TrajectoryEvaluator::default();
        let actual = vec![make_invocation_with_tools(&["search"])];
        let expected = vec![make_invocation_with_tools(&["search"])];
        let result = eval.evaluate(&actual, Some(&expected)).await.unwrap();
        assert_eq!(result.metrics[0].per_invocation[0].invocation_id, "inv-0");
    }

    #[tokio::test]
    async fn requires_expected() {
        let eval = TrajectoryEvaluator::default();
        let actual = vec![make_invocation_with_tools(&["search"])];
        assert!(eval.evaluate(&actual, None).await.is_err());
    }

    #[test]
    fn default_is_strict() {
        assert!(TrajectoryEvaluator::default().strict_order);
    }

    #[test]
    fn lcs_identical() {
        let a: Vec<String> = vec!["a".into(), "b".into(), "c".into()];
        let b: Vec<String> = vec!["a".into(), "b".into(), "c".into()];
        assert_eq!(lcs_length(&a, &b), 3);
    }

    #[test]
    fn lcs_different() {
        let a: Vec<String> = vec!["a".into(), "b".into()];
        let b: Vec<String> = vec!["c".into(), "d".into()];
        assert_eq!(lcs_length(&a, &b), 0);
    }

    #[test]
    fn lcs_classic_example() {
        // CLRS example: LCS("ABCBDAB", "BDCABA") has length 4.
        let a: Vec<String> = "ABCBDAB".chars().map(String::from).collect();
        let b: Vec<String> = "BDCABA".chars().map(String::from).collect();
        assert_eq!(lcs_length(&a, &b), 4);
    }

    #[test]
    fn lcs_empty_input() {
        let a: Vec<String> = Vec::new();
        let b: Vec<String> = vec!["a".into()];
        assert_eq!(lcs_length(&a, &b), 0);
        assert_eq!(lcs_length(&b, &a), 0);
    }
}