// gemini_adk_rs/evaluation/evalset_parser.rs
use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use super::evaluator::EvalError;
11
/// Top-level contents of an evalset file.
///
/// This is the type produced by [`parse_evalset`] and [`parse_evalset_str`];
/// both fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalSetFile {
    /// Name of the eval set (e.g. `"weather-agent-eval"` in the unit tests).
    pub name: String,
    /// The eval cases contained in this set, in file order.
    pub eval_cases: Vec<EvalCaseFile>,
}
24
/// One eval case inside an [`EvalSetFile`].
///
/// Both fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalCaseFile {
    /// Identifier of this eval case.
    pub eval_id: String,
    /// The invocations that make up the case, in file order.
    pub conversation: Vec<InvocationFile>,
}
33
/// One invocation within an eval case's conversation.
///
/// `invocation_id` and `user_content` are required; the remaining fields are
/// marked `#[serde(default)]` and may be omitted from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvocationFile {
    /// Identifier of this invocation; copied into the `id` of the
    /// invocations built by [`EvalSetFile::to_eval_pairs`].
    pub invocation_id: String,
    /// Text of the user message; becomes the content of the leading
    /// `"user"` turn in [`EvalSetFile::to_eval_pairs`].
    pub user_content: String,
    /// Tool calls the model is expected to make. Empty when absent.
    #[serde(default)]
    pub expected_tool_use: Vec<ExpectedToolUse>,
    /// Response text the model is expected to give. `None` when absent or
    /// JSON `null`.
    #[serde(default)]
    pub expected_response: Option<String>,
    /// Recorded tool uses and intermediate responses, used by
    /// [`EvalSetFile::to_eval_pairs`] to build the "actual" turns.
    /// `None` when absent.
    #[serde(default)]
    pub intermediate_data: Option<IntermediateData>,
}
51
/// A tool call listed under an invocation's `expected_tool_use`.
///
/// Both fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedToolUse {
    /// Name of the tool.
    pub tool_name: String,
    /// Tool input as arbitrary JSON.
    pub tool_input: serde_json::Value,
}
60
/// Intermediate data attached to an invocation.
///
/// Both lists are marked `#[serde(default)]`, so either may be omitted from
/// the JSON and will then be empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntermediateData {
    /// Tool use records (name, input and output).
    #[serde(default)]
    pub tool_uses: Vec<ToolUseRecord>,
    /// Intermediate response texts.
    #[serde(default)]
    pub intermediate_responses: Vec<String>,
}
71
/// A single tool use listed in [`IntermediateData::tool_uses`].
///
/// All three fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseRecord {
    /// Name of the tool.
    pub tool_name: String,
    /// Tool input as arbitrary JSON.
    pub tool_input: serde_json::Value,
    /// Tool output as arbitrary JSON.
    pub tool_output: serde_json::Value,
}
82
83pub fn parse_evalset(path: &Path) -> Result<EvalSetFile, EvalError> {
94 let contents = std::fs::read_to_string(path).map_err(|e| {
95 EvalError::Io(format!(
96 "Failed to read evalset file {}: {e}",
97 path.display()
98 ))
99 })?;
100 parse_evalset_str(&contents)
101}
102
103pub fn parse_evalset_str(json: &str) -> Result<EvalSetFile, EvalError> {
109 serde_json::from_str(json).map_err(|e| EvalError::Parse(format!("Invalid evalset JSON: {e}")))
110}
111
112impl EvalSetFile {
117 pub fn to_eval_pairs(&self) -> Vec<(Vec<super::Invocation>, Vec<super::Invocation>)> {
124 self.eval_cases
125 .iter()
126 .map(|case| {
127 let mut actual_invocations = Vec::new();
128 let mut expected_invocations = Vec::new();
129
130 for inv in &case.conversation {
131 let mut actual_turns = vec![super::InvocationTurn {
133 role: "user".into(),
134 content: inv.user_content.clone(),
135 tool_calls: vec![],
136 tool_results: vec![],
137 }];
138
139 if let Some(ref data) = inv.intermediate_data {
140 for tu in &data.tool_uses {
142 actual_turns.push(super::InvocationTurn {
143 role: "model".into(),
144 content: String::new(),
145 tool_calls: vec![serde_json::json!({
146 "name": tu.tool_name,
147 "args": tu.tool_input,
148 })],
149 tool_results: vec![tu.tool_output.clone()],
150 });
151 }
152 for resp in &data.intermediate_responses {
154 actual_turns.push(super::InvocationTurn {
155 role: "model".into(),
156 content: resp.clone(),
157 tool_calls: vec![],
158 tool_results: vec![],
159 });
160 }
161 }
162
163 actual_invocations.push(super::Invocation {
164 id: inv.invocation_id.clone(),
165 turns: actual_turns,
166 metadata: serde_json::Value::Null,
167 });
168
169 let mut expected_turns = vec![super::InvocationTurn {
171 role: "user".into(),
172 content: inv.user_content.clone(),
173 tool_calls: vec![],
174 tool_results: vec![],
175 }];
176
177 if !inv.expected_tool_use.is_empty() {
179 expected_turns.push(super::InvocationTurn {
180 role: "model".into(),
181 content: String::new(),
182 tool_calls: inv
183 .expected_tool_use
184 .iter()
185 .map(|tu| {
186 serde_json::json!({
187 "name": tu.tool_name,
188 "args": tu.tool_input,
189 })
190 })
191 .collect(),
192 tool_results: vec![],
193 });
194 }
195
196 if let Some(ref resp) = inv.expected_response {
198 expected_turns.push(super::InvocationTurn {
199 role: "model".into(),
200 content: resp.clone(),
201 tool_calls: vec![],
202 tool_results: vec![],
203 });
204 }
205
206 expected_invocations.push(super::Invocation {
207 id: inv.invocation_id.clone(),
208 turns: expected_turns,
209 metadata: serde_json::Value::Null,
210 });
211 }
212
213 (actual_invocations, expected_invocations)
214 })
215 .collect()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// One case with two invocations: the first is fully populated, the
    /// second has only expected tool use and a `null` expected response.
    const SAMPLE_EVALSET: &str = r#"{
        "name": "weather-agent-eval",
        "eval_cases": [
            {
                "eval_id": "case-1",
                "conversation": [
                    {
                        "invocation_id": "inv-1",
                        "user_content": "What's the weather in London?",
                        "expected_tool_use": [
                            {
                                "tool_name": "get_weather",
                                "tool_input": {"city": "London"}
                            }
                        ],
                        "expected_response": "The weather in London is 15°C and cloudy.",
                        "intermediate_data": {
                            "tool_uses": [
                                {
                                    "tool_name": "get_weather",
                                    "tool_input": {"city": "London"},
                                    "tool_output": {"temp": 15, "condition": "cloudy"}
                                }
                            ],
                            "intermediate_responses": ["Let me check the weather for London."]
                        }
                    },
                    {
                        "invocation_id": "inv-2",
                        "user_content": "And in Paris?",
                        "expected_tool_use": [
                            {
                                "tool_name": "get_weather",
                                "tool_input": {"city": "Paris"}
                            }
                        ],
                        "expected_response": null
                    }
                ]
            }
        ]
    }"#;

    #[test]
    fn parse_valid_evalset() {
        let parsed = parse_evalset_str(SAMPLE_EVALSET).unwrap();
        assert_eq!(parsed.name, "weather-agent-eval");
        assert_eq!(parsed.eval_cases.len(), 1);

        let only_case = &parsed.eval_cases[0];
        assert_eq!(only_case.eval_id, "case-1");
        assert_eq!(only_case.conversation.len(), 2);

        let first = &only_case.conversation[0];
        assert_eq!(first.invocation_id, "inv-1");
        assert_eq!(first.user_content, "What's the weather in London?");

        let tools = &first.expected_tool_use;
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].tool_name, "get_weather");

        assert!(first.expected_response.is_some());
        assert!(first.intermediate_data.is_some());

        let recorded = first.intermediate_data.as_ref().unwrap();
        assert_eq!(recorded.tool_uses.len(), 1);
        assert_eq!(recorded.intermediate_responses.len(), 1);
    }

    #[test]
    fn parse_invalid_json() {
        let outcome = parse_evalset_str("not json");
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Invalid evalset JSON"));
    }

    #[test]
    fn to_eval_pairs_converts_correctly() {
        let parsed = parse_evalset_str(SAMPLE_EVALSET).unwrap();
        let pairs = parsed.to_eval_pairs();
        assert_eq!(pairs.len(), 1);

        let (actual, expected) = &pairs[0];
        assert_eq!(actual.len(), 2);
        assert_eq!(expected.len(), 2);

        let first_actual = &actual[0];
        assert_eq!(first_actual.id, "inv-1");
        // User turn + one tool-use turn + one intermediate-response turn.
        assert_eq!(first_actual.turns.len(), 3);
        assert_eq!(first_actual.turns[0].role, "user");
        assert_eq!(first_actual.turns[1].role, "model");
        assert!(!first_actual.turns[1].tool_calls.is_empty());

        // User turn + expected-tool-use turn + expected-response turn.
        assert_eq!(expected[0].turns.len(), 3);
    }

    #[test]
    fn minimal_evalset() {
        let json = r#"{
            "name": "minimal",
            "eval_cases": [{
                "eval_id": "c1",
                "conversation": [{
                    "invocation_id": "i1",
                    "user_content": "hello"
                }]
            }]
        }"#;
        let parsed = parse_evalset_str(json).unwrap();

        // Every optional field falls back to its default when omitted.
        let only = &parsed.eval_cases[0].conversation[0];
        assert_eq!(only.expected_tool_use.len(), 0);
        assert!(only.expected_response.is_none());
        assert!(only.intermediate_data.is_none());
    }
}