gemini_adk_rs/evaluation/
test_config.rs

1//! Test configuration parser — load `test_config.json` for evaluation thresholds.
2//!
3//! Defines per-criterion pass/fail thresholds and optional LLM-judge configuration.
4
5use std::collections::HashMap;
6use std::path::Path;
7
8use serde::{Deserialize, Serialize};
9
10use super::evaluator::EvalError;
11
12// ---------------------------------------------------------------------------
13// Configuration types
14// ---------------------------------------------------------------------------
15
16/// Top-level test configuration loaded from `test_config.json`.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TestConfig {
19    /// Per-criterion evaluation configuration, keyed by criterion name.
20    pub criteria: HashMap<String, CriterionConfig>,
21}
22
23/// Configuration for a single evaluation criterion.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(untagged)]
26pub enum CriterionConfig {
27    /// Simple threshold: the metric score must be >= this value to pass.
28    Threshold(f64),
29    /// LLM-judge criterion with optional model and sampling configuration.
30    LlmJudge {
31        /// Minimum score threshold to pass.
32        threshold: f64,
33        /// Override the judge model (e.g., "gemini-2.0-flash").
34        #[serde(default)]
35        judge_model: Option<String>,
36        /// Number of LLM samples to average over for more stable scores.
37        #[serde(default)]
38        num_samples: Option<u32>,
39    },
40}
41
42impl CriterionConfig {
43    /// Get the threshold value regardless of variant.
44    pub fn threshold(&self) -> f64 {
45        match self {
46            Self::Threshold(t) => *t,
47            Self::LlmJudge { threshold, .. } => *threshold,
48        }
49    }
50
51    /// Check whether a score passes this criterion.
52    pub fn passes(&self, score: f64) -> bool {
53        score >= self.threshold()
54    }
55}
56
57impl TestConfig {
58    /// Check whether all criteria pass for a set of metric scores.
59    ///
60    /// Returns a map of criterion name -> (passed, score, threshold).
61    pub fn check_all(&self, scores: &HashMap<String, f64>) -> HashMap<String, (bool, f64, f64)> {
62        self.criteria
63            .iter()
64            .map(|(name, config)| {
65                let score = scores.get(name).copied().unwrap_or(0.0);
66                let threshold = config.threshold();
67                let passed = config.passes(score);
68                (name.clone(), (passed, score, threshold))
69            })
70            .collect()
71    }
72
73    /// Returns `true` if all criteria pass.
74    pub fn all_pass(&self, scores: &HashMap<String, f64>) -> bool {
75        self.check_all(scores)
76            .values()
77            .all(|(passed, _, _)| *passed)
78    }
79}
80
81// ---------------------------------------------------------------------------
82// Parsing
83// ---------------------------------------------------------------------------
84
85/// Parse a `test_config.json` file from disk.
86///
87/// # Errors
88///
89/// Returns `EvalError::Io` if the file cannot be read, or
90/// `EvalError::Parse` if the JSON is invalid.
91pub fn parse_test_config(path: &Path) -> Result<TestConfig, EvalError> {
92    let contents = std::fs::read_to_string(path).map_err(|e| {
93        EvalError::Io(format!(
94            "Failed to read test config {}: {e}",
95            path.display()
96        ))
97    })?;
98    parse_test_config_str(&contents)
99}
100
101/// Parse a `test_config.json` from a raw JSON string.
102pub fn parse_test_config_str(json: &str) -> Result<TestConfig, EvalError> {
103    serde_json::from_str(json)
104        .map_err(|e| EvalError::Parse(format!("Invalid test config JSON: {e}")))
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a score map from `(criterion, score)` pairs.
    fn scores(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs
            .iter()
            .map(|&(name, score)| (name.to_string(), score))
            .collect()
    }

    #[test]
    fn parse_simple_thresholds() {
        let json = r#"{
            "criteria": {
                "response_quality": 0.8,
                "tool_accuracy": 0.9
            }
        }"#;
        let config = parse_test_config_str(json).unwrap();
        assert_eq!(config.criteria.len(), 2);

        let rq = &config.criteria["response_quality"];
        assert!((rq.threshold() - 0.8).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_integer_threshold() {
        // A bare JSON integer must still select the `Threshold(f64)` variant.
        let config = parse_test_config_str(r#"{"criteria": {"exact": 1}}"#).unwrap();
        let exact = &config.criteria["exact"];
        assert!(matches!(exact, CriterionConfig::Threshold(_)));
        assert!((exact.threshold() - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_llm_judge_config() {
        let json = r#"{
            "criteria": {
                "coherence": {
                    "threshold": 0.7,
                    "judge_model": "gemini-2.0-flash",
                    "num_samples": 3
                }
            }
        }"#;
        let config = parse_test_config_str(json).unwrap();
        match &config.criteria["coherence"] {
            CriterionConfig::LlmJudge {
                threshold,
                judge_model,
                num_samples,
            } => {
                assert!((threshold - 0.7).abs() < f64::EPSILON);
                assert_eq!(judge_model.as_deref(), Some("gemini-2.0-flash"));
                assert_eq!(*num_samples, Some(3));
            }
            CriterionConfig::Threshold(_) => panic!("Expected LlmJudge variant"),
        }
    }

    #[test]
    fn parse_llm_judge_defaults_optional_fields() {
        let config =
            parse_test_config_str(r#"{"criteria": {"c": {"threshold": 0.6}}}"#).unwrap();
        match &config.criteria["c"] {
            CriterionConfig::LlmJudge {
                threshold,
                judge_model,
                num_samples,
            } => {
                assert!((threshold - 0.6).abs() < f64::EPSILON);
                assert!(judge_model.is_none());
                assert!(num_samples.is_none());
            }
            CriterionConfig::Threshold(_) => panic!("Expected LlmJudge variant"),
        }
    }

    #[test]
    fn check_all_passing() {
        let json = r#"{"criteria": {"a": 0.5, "b": 0.8}}"#;
        let config = parse_test_config_str(json).unwrap();
        assert!(config.all_pass(&scores(&[("a", 0.6), ("b", 0.9)])));
    }

    #[test]
    fn check_all_failing() {
        let json = r#"{"criteria": {"a": 0.5, "b": 0.8}}"#;
        let config = parse_test_config_str(json).unwrap();
        assert!(!config.all_pass(&scores(&[("a", 0.6), ("b", 0.7)])));
    }

    #[test]
    fn score_equal_to_threshold_passes() {
        // The comparison is inclusive (`>=`).
        let config = parse_test_config_str(r#"{"criteria": {"a": 0.5}}"#).unwrap();
        let (passed, _, _) = config.check_all(&scores(&[("a", 0.5)]))["a"];
        assert!(passed);
        assert!(config.all_pass(&scores(&[("a", 0.5)])));
    }

    #[test]
    fn missing_score_defaults_to_zero() {
        let json = r#"{"criteria": {"a": 0.5}}"#;
        let config = parse_test_config_str(json).unwrap();
        let empty: HashMap<String, f64> = HashMap::new();
        assert!(!config.all_pass(&empty));
    }

    #[test]
    fn empty_criteria_trivially_pass() {
        let config = parse_test_config_str(r#"{"criteria": {}}"#).unwrap();
        assert!(config.check_all(&scores(&[("a", 0.1)])).is_empty());
        assert!(config.all_pass(&HashMap::new()));
    }

    #[test]
    fn parse_invalid_json() {
        let result = parse_test_config_str("bad");
        assert!(result.is_err());
    }

    #[test]
    fn parse_rejects_missing_criteria_key() {
        assert!(matches!(parse_test_config_str("{}"), Err(EvalError::Parse(_))));
    }

    #[test]
    fn parse_missing_file_is_io_error() {
        let result = parse_test_config(Path::new("/nonexistent/dir/test_config.json"));
        assert!(matches!(result, Err(EvalError::Io(_))));
    }
}
183}