gemini_adk_rs/evaluation/
test_config.rs1use std::collections::HashMap;
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use super::evaluator::EvalError;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TestConfig {
19 pub criteria: HashMap<String, CriterionConfig>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(untagged)]
26pub enum CriterionConfig {
27 Threshold(f64),
29 LlmJudge {
31 threshold: f64,
33 #[serde(default)]
35 judge_model: Option<String>,
36 #[serde(default)]
38 num_samples: Option<u32>,
39 },
40}
41
42impl CriterionConfig {
43 pub fn threshold(&self) -> f64 {
45 match self {
46 Self::Threshold(t) => *t,
47 Self::LlmJudge { threshold, .. } => *threshold,
48 }
49 }
50
51 pub fn passes(&self, score: f64) -> bool {
53 score >= self.threshold()
54 }
55}
56
57impl TestConfig {
58 pub fn check_all(&self, scores: &HashMap<String, f64>) -> HashMap<String, (bool, f64, f64)> {
62 self.criteria
63 .iter()
64 .map(|(name, config)| {
65 let score = scores.get(name).copied().unwrap_or(0.0);
66 let threshold = config.threshold();
67 let passed = config.passes(score);
68 (name.clone(), (passed, score, threshold))
69 })
70 .collect()
71 }
72
73 pub fn all_pass(&self, scores: &HashMap<String, f64>) -> bool {
75 self.check_all(scores)
76 .values()
77 .all(|(passed, _, _)| *passed)
78 }
79}
80
81pub fn parse_test_config(path: &Path) -> Result<TestConfig, EvalError> {
92 let contents = std::fs::read_to_string(path).map_err(|e| {
93 EvalError::Io(format!(
94 "Failed to read test config {}: {e}",
95 path.display()
96 ))
97 })?;
98 parse_test_config_str(&contents)
99}
100
101pub fn parse_test_config_str(json: &str) -> Result<TestConfig, EvalError> {
103 serde_json::from_str(json)
104 .map_err(|e| EvalError::Parse(format!("Invalid test config JSON: {e}")))
105}
106
#[cfg(test)]
mod tests {
    // Unit tests for test-config parsing and threshold checking.
    use super::*;

    /// Builds a score map from `(name, score)` pairs.
    fn score_map(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs
            .iter()
            .map(|&(name, score)| (name.to_string(), score))
            .collect()
    }

    #[test]
    fn parse_simple_thresholds() {
        let json = r#"{
            "criteria": {
                "response_quality": 0.8,
                "tool_accuracy": 0.9
            }
        }"#;
        let config = parse_test_config_str(json).unwrap();
        assert_eq!(config.criteria.len(), 2);

        let rq = &config.criteria["response_quality"];
        assert!((rq.threshold() - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_llm_judge_config() {
        let json = r#"{
            "criteria": {
                "coherence": {
                    "threshold": 0.7,
                    "judge_model": "gemini-2.0-flash",
                    "num_samples": 3
                }
            }
        }"#;
        let config = parse_test_config_str(json).unwrap();
        match &config.criteria["coherence"] {
            CriterionConfig::LlmJudge {
                threshold,
                judge_model,
                num_samples,
            } => {
                assert!((threshold - 0.7).abs() < f64::EPSILON);
                assert_eq!(judge_model.as_deref(), Some("gemini-2.0-flash"));
                assert_eq!(*num_samples, Some(3));
            }
            _ => panic!("Expected LlmJudge variant"),
        }
    }

    #[test]
    fn llm_judge_optional_fields_default_to_none() {
        let json = r#"{"criteria": {"coherence": {"threshold": 0.6}}}"#;
        let config = parse_test_config_str(json).unwrap();
        match &config.criteria["coherence"] {
            CriterionConfig::LlmJudge {
                threshold,
                judge_model,
                num_samples,
            } => {
                assert!((threshold - 0.6).abs() < f64::EPSILON);
                assert!(judge_model.is_none());
                assert!(num_samples.is_none());
            }
            CriterionConfig::Threshold(_) => panic!("Expected LlmJudge variant"),
        }
    }

    #[test]
    fn integer_threshold_parses_as_float() {
        let json = r#"{"criteria": {"exact": 1}}"#;
        let config = parse_test_config_str(json).unwrap();
        assert!((config.criteria["exact"].threshold() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn object_without_threshold_is_rejected() {
        let json = r#"{"criteria": {"coherence": {"judge_model": "gemini-2.0-flash"}}}"#;
        assert!(parse_test_config_str(json).is_err());
    }

    #[test]
    fn missing_criteria_key_is_rejected() {
        assert!(parse_test_config_str("{}").is_err());
    }

    #[test]
    fn check_all_passing() {
        let json = r#"{"criteria": {"a": 0.5, "b": 0.8}}"#;
        let config = parse_test_config_str(json).unwrap();
        let scores: HashMap<String, f64> =
            [("a".into(), 0.6), ("b".into(), 0.9)].into_iter().collect();
        assert!(config.all_pass(&scores));
    }

    #[test]
    fn check_all_failing() {
        let json = r#"{"criteria": {"a": 0.5, "b": 0.8}}"#;
        let config = parse_test_config_str(json).unwrap();
        let scores: HashMap<String, f64> =
            [("a".into(), 0.6), ("b".into(), 0.7)].into_iter().collect();
        assert!(!config.all_pass(&scores));
    }

    #[test]
    fn score_equal_to_threshold_passes() {
        // `passes` uses an inclusive comparison.
        let config = parse_test_config_str(r#"{"criteria": {"a": 0.5}}"#).unwrap();
        assert!(config.all_pass(&score_map(&[("a", 0.5)])));
    }

    #[test]
    fn check_all_reports_score_and_threshold() {
        let config = parse_test_config_str(r#"{"criteria": {"a": 0.5, "b": 0.8}}"#).unwrap();
        // "extra" is not a configured criterion and must not appear in the result.
        let results = config.check_all(&score_map(&[("a", 0.6), ("extra", 1.0)]));
        assert_eq!(results.len(), 2);

        let (a_passed, a_score, a_threshold) = results["a"];
        assert!(a_passed);
        assert!((a_score - 0.6).abs() < f64::EPSILON);
        assert!((a_threshold - 0.5).abs() < f64::EPSILON);

        // "b" has no reported score, so it is reported as 0.0 and fails.
        let (b_passed, b_score, b_threshold) = results["b"];
        assert!(!b_passed);
        assert!(b_score.abs() < f64::EPSILON);
        assert!((b_threshold - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn missing_score_defaults_to_zero() {
        let json = r#"{"criteria": {"a": 0.5}}"#;
        let config = parse_test_config_str(json).unwrap();
        let scores: HashMap<String, f64> = HashMap::new();
        assert!(!config.all_pass(&scores));
    }

    #[test]
    fn empty_criteria_pass_vacuously() {
        let config = parse_test_config_str(r#"{"criteria": {}}"#).unwrap();
        assert!(config.all_pass(&HashMap::new()));
        assert!(config.check_all(&HashMap::new()).is_empty());
    }

    #[test]
    fn parse_invalid_json() {
        let result = parse_test_config_str("bad");
        assert!(result.is_err());
    }

    #[test]
    fn parse_missing_file_is_err() {
        let path = Path::new("this/path/does/not/exist/test_config.json");
        assert!(parse_test_config(path).is_err());
    }
}