gemini_adk_rs/evaluation/
evalset_parser.rs

1//! `.evalset.json` parser — load golden evaluation datasets.
2//!
3//! Parses the upstream ADK golden dataset format into typed Rust structures
4//! that can be fed into evaluators.
5
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use super::evaluator::EvalError;
11
12// ---------------------------------------------------------------------------
13// Wire types — match the upstream `.evalset.json` schema
14// ---------------------------------------------------------------------------
15
16/// Top-level structure of a `.evalset.json` file.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct EvalSetFile {
19    /// Name of the evaluation set.
20    pub name: String,
21    /// The evaluation cases.
22    pub eval_cases: Vec<EvalCaseFile>,
23}
24
25/// A single evaluation case within an eval set file.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct EvalCaseFile {
28    /// Unique identifier for this eval case.
29    pub eval_id: String,
30    /// The multi-turn conversation to evaluate.
31    pub conversation: Vec<InvocationFile>,
32}
33
34/// A single invocation (turn pair) in the eval case conversation.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct InvocationFile {
37    /// Unique identifier for this invocation.
38    pub invocation_id: String,
39    /// The user's input content.
40    pub user_content: String,
41    /// Expected tool uses for this invocation.
42    #[serde(default)]
43    pub expected_tool_use: Vec<ExpectedToolUse>,
44    /// Expected final response text (if any).
45    #[serde(default)]
46    pub expected_response: Option<String>,
47    /// Intermediate data recorded during this invocation.
48    #[serde(default)]
49    pub intermediate_data: Option<IntermediateData>,
50}
51
52/// An expected tool call within an invocation.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ExpectedToolUse {
55    /// Name of the tool expected to be called.
56    pub tool_name: String,
57    /// Expected input arguments to the tool.
58    pub tool_input: serde_json::Value,
59}
60
61/// Intermediate data captured during agent execution.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct IntermediateData {
64    /// Tool uses that actually occurred.
65    #[serde(default)]
66    pub tool_uses: Vec<ToolUseRecord>,
67    /// Intermediate text responses from the model.
68    #[serde(default)]
69    pub intermediate_responses: Vec<String>,
70}
71
72/// Record of a tool use that occurred during execution.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct ToolUseRecord {
75    /// Name of the tool that was called.
76    pub tool_name: String,
77    /// Input arguments passed to the tool.
78    pub tool_input: serde_json::Value,
79    /// Output returned by the tool.
80    pub tool_output: serde_json::Value,
81}
82
83// ---------------------------------------------------------------------------
84// Parsing functions
85// ---------------------------------------------------------------------------
86
87/// Parse an `.evalset.json` file from disk.
88///
89/// # Errors
90///
91/// Returns `EvalError::Io` if the file cannot be read, or
92/// `EvalError::Parse` if the JSON is invalid.
93pub fn parse_evalset(path: &Path) -> Result<EvalSetFile, EvalError> {
94    let contents = std::fs::read_to_string(path).map_err(|e| {
95        EvalError::Io(format!(
96            "Failed to read evalset file {}: {e}",
97            path.display()
98        ))
99    })?;
100    parse_evalset_str(&contents)
101}
102
103/// Parse an `.evalset.json` from a raw JSON string.
104///
105/// # Errors
106///
107/// Returns `EvalError::Parse` if the JSON is invalid.
108pub fn parse_evalset_str(json: &str) -> Result<EvalSetFile, EvalError> {
109    serde_json::from_str(json).map_err(|e| EvalError::Parse(format!("Invalid evalset JSON: {e}")))
110}
111
112// ---------------------------------------------------------------------------
113// Conversion helpers — turn file types into evaluator types
114// ---------------------------------------------------------------------------
115
116impl EvalSetFile {
117    /// Convert this file representation into evaluator-compatible [`super::Invocation`] pairs.
118    ///
119    /// Returns `(actual_invocations, expected_invocations)` for each eval case.
120    /// Actual invocations are built from `intermediate_data` when present,
121    /// falling back to user content only. Expected invocations are built from
122    /// `expected_tool_use` and `expected_response`.
123    pub fn to_eval_pairs(&self) -> Vec<(Vec<super::Invocation>, Vec<super::Invocation>)> {
124        self.eval_cases
125            .iter()
126            .map(|case| {
127                let mut actual_invocations = Vec::new();
128                let mut expected_invocations = Vec::new();
129
130                for inv in &case.conversation {
131                    // Build the actual invocation from intermediate data
132                    let mut actual_turns = vec![super::InvocationTurn {
133                        role: "user".into(),
134                        content: inv.user_content.clone(),
135                        tool_calls: vec![],
136                        tool_results: vec![],
137                    }];
138
139                    if let Some(ref data) = inv.intermediate_data {
140                        // Add tool call turns from intermediate data
141                        for tu in &data.tool_uses {
142                            actual_turns.push(super::InvocationTurn {
143                                role: "model".into(),
144                                content: String::new(),
145                                tool_calls: vec![serde_json::json!({
146                                    "name": tu.tool_name,
147                                    "args": tu.tool_input,
148                                })],
149                                tool_results: vec![tu.tool_output.clone()],
150                            });
151                        }
152                        // Add intermediate responses
153                        for resp in &data.intermediate_responses {
154                            actual_turns.push(super::InvocationTurn {
155                                role: "model".into(),
156                                content: resp.clone(),
157                                tool_calls: vec![],
158                                tool_results: vec![],
159                            });
160                        }
161                    }
162
163                    actual_invocations.push(super::Invocation {
164                        id: inv.invocation_id.clone(),
165                        turns: actual_turns,
166                        metadata: serde_json::Value::Null,
167                    });
168
169                    // Build the expected invocation
170                    let mut expected_turns = vec![super::InvocationTurn {
171                        role: "user".into(),
172                        content: inv.user_content.clone(),
173                        tool_calls: vec![],
174                        tool_results: vec![],
175                    }];
176
177                    // Add expected tool calls
178                    if !inv.expected_tool_use.is_empty() {
179                        expected_turns.push(super::InvocationTurn {
180                            role: "model".into(),
181                            content: String::new(),
182                            tool_calls: inv
183                                .expected_tool_use
184                                .iter()
185                                .map(|tu| {
186                                    serde_json::json!({
187                                        "name": tu.tool_name,
188                                        "args": tu.tool_input,
189                                    })
190                                })
191                                .collect(),
192                            tool_results: vec![],
193                        });
194                    }
195
196                    // Add expected response
197                    if let Some(ref resp) = inv.expected_response {
198                        expected_turns.push(super::InvocationTurn {
199                            role: "model".into(),
200                            content: resp.clone(),
201                            tool_calls: vec![],
202                            tool_results: vec![],
203                        });
204                    }
205
206                    expected_invocations.push(super::Invocation {
207                        id: inv.invocation_id.clone(),
208                        turns: expected_turns,
209                        metadata: serde_json::Value::Null,
210                    });
211                }
212
213                (actual_invocations, expected_invocations)
214            })
215            .collect()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-invocation fixture: the first invocation sets every optional key,
    /// the second omits `intermediate_data` and has a `null` response.
    const SAMPLE_EVALSET: &str = r#"{
        "name": "weather-agent-eval",
        "eval_cases": [
            {
                "eval_id": "case-1",
                "conversation": [
                    {
                        "invocation_id": "inv-1",
                        "user_content": "What's the weather in London?",
                        "expected_tool_use": [
                            {
                                "tool_name": "get_weather",
                                "tool_input": {"city": "London"}
                            }
                        ],
                        "expected_response": "The weather in London is 15°C and cloudy.",
                        "intermediate_data": {
                            "tool_uses": [
                                {
                                    "tool_name": "get_weather",
                                    "tool_input": {"city": "London"},
                                    "tool_output": {"temp": 15, "condition": "cloudy"}
                                }
                            ],
                            "intermediate_responses": ["Let me check the weather for London."]
                        }
                    },
                    {
                        "invocation_id": "inv-2",
                        "user_content": "And in Paris?",
                        "expected_tool_use": [
                            {
                                "tool_name": "get_weather",
                                "tool_input": {"city": "Paris"}
                            }
                        ],
                        "expected_response": null
                    }
                ]
            }
        ]
    }"#;

    #[test]
    fn parse_valid_evalset() {
        let evalset = parse_evalset_str(SAMPLE_EVALSET).unwrap();
        assert_eq!(evalset.name, "weather-agent-eval");
        assert_eq!(evalset.eval_cases.len(), 1);

        let case = &evalset.eval_cases[0];
        assert_eq!(case.eval_id, "case-1");
        assert_eq!(case.conversation.len(), 2);

        let inv1 = &case.conversation[0];
        assert_eq!(inv1.invocation_id, "inv-1");
        assert_eq!(inv1.user_content, "What's the weather in London?");
        assert_eq!(inv1.expected_tool_use.len(), 1);
        assert_eq!(inv1.expected_tool_use[0].tool_name, "get_weather");
        assert_eq!(inv1.expected_tool_use[0].tool_input["city"], "London");
        assert!(inv1.expected_response.is_some());
        assert!(inv1.intermediate_data.is_some());

        let data = inv1.intermediate_data.as_ref().unwrap();
        assert_eq!(data.tool_uses.len(), 1);
        assert_eq!(data.intermediate_responses.len(), 1);

        // Second invocation: explicit `null` response and no recorded data.
        let inv2 = &case.conversation[1];
        assert_eq!(inv2.invocation_id, "inv-2");
        assert_eq!(inv2.expected_tool_use.len(), 1);
        assert!(inv2.expected_response.is_none());
        assert!(inv2.intermediate_data.is_none());
    }

    #[test]
    fn parse_invalid_json() {
        let result = parse_evalset_str("not json");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Invalid evalset JSON"));
    }

    #[test]
    fn parse_rejects_missing_required_field() {
        // Well-formed JSON that lacks the required `eval_cases` key.
        let result = parse_evalset_str(r#"{"name": "no-cases"}"#);
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Invalid evalset JSON"));
    }

    #[test]
    fn parse_evalset_missing_file_is_io_error() {
        let missing =
            std::env::temp_dir().join("evalset-parser-test-does-not-exist.evalset.json");
        let result = parse_evalset(&missing);
        assert!(matches!(result, Err(EvalError::Io(_))));
    }

    #[test]
    fn to_eval_pairs_converts_correctly() {
        let evalset = parse_evalset_str(SAMPLE_EVALSET).unwrap();
        let pairs = evalset.to_eval_pairs();
        assert_eq!(pairs.len(), 1);

        let (actual, expected) = &pairs[0];
        assert_eq!(actual.len(), 2);
        assert_eq!(expected.len(), 2);

        // First invocation should have user turn + tool call turn + intermediate response
        assert_eq!(actual[0].id, "inv-1");
        assert_eq!(actual[0].turns.len(), 3); // user + tool call + intermediate response
        assert_eq!(actual[0].turns[0].role, "user");
        assert_eq!(actual[0].turns[1].role, "model");
        assert!(!actual[0].turns[1].tool_calls.is_empty());

        // The recorded tool use carries both the call and its output.
        let tool_turn = &actual[0].turns[1];
        assert_eq!(tool_turn.tool_calls[0]["name"], "get_weather");
        assert_eq!(tool_turn.tool_calls[0]["args"]["city"], "London");
        assert_eq!(tool_turn.tool_results[0]["temp"], 15);
        assert_eq!(
            actual[0].turns[2].content,
            "Let me check the weather for London."
        );

        // Expected should have user turn + tool call turn + response turn
        assert_eq!(expected[0].id, "inv-1");
        assert_eq!(expected[0].turns.len(), 3); // user + expected tool + expected response
        assert_eq!(
            expected[0].turns[2].content,
            "The weather in London is 15°C and cloudy."
        );

        // Second invocation has no intermediate data: only the user turn is
        // emitted on the actual side.
        assert_eq!(actual[1].id, "inv-2");
        assert_eq!(actual[1].turns.len(), 1);
        assert_eq!(actual[1].turns[0].content, "And in Paris?");

        // Its expected response is null: user turn + tool call turn only.
        assert_eq!(expected[1].turns.len(), 2);
        assert_eq!(expected[1].turns[1].tool_calls.len(), 1);
        assert_eq!(expected[1].turns[1].tool_calls[0]["args"]["city"], "Paris");
    }

    #[test]
    fn minimal_evalset() {
        let json = r#"{
            "name": "minimal",
            "eval_cases": [{
                "eval_id": "c1",
                "conversation": [{
                    "invocation_id": "i1",
                    "user_content": "hello"
                }]
            }]
        }"#;
        let evalset = parse_evalset_str(json).unwrap();
        let inv = &evalset.eval_cases[0].conversation[0];
        assert!(inv.expected_tool_use.is_empty());
        assert!(inv.expected_response.is_none());
        assert!(inv.intermediate_data.is_none());

        // With every optional key absent, both sides reduce to the user turn.
        let pairs = evalset.to_eval_pairs();
        let (actual, expected) = &pairs[0];
        assert_eq!(actual[0].turns.len(), 1);
        assert_eq!(expected[0].turns.len(), 1);
        assert_eq!(actual[0].turns[0].role, "user");
        assert_eq!(expected[0].turns[0].content, "hello");
    }
}