gemini_adk_rs/evaluation/
user_simulator_evaluator.rs

//! User simulator evaluator — assesses multi-turn simulation fidelity.
//!
//! Evaluates how well a user simulator (used in automated multi-turn testing)
//! follows its assigned persona, stays on topic, and produces realistic
//! user messages that effectively exercise the agent under test.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use super::eval_case::Invocation;
12use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
13use super::evaluator::{EvalError, Evaluator};
14use crate::llm::BaseLlm;
15
16/// Evaluates the fidelity of a user simulator in multi-turn conversations.
17///
18/// Assesses whether simulated user messages are:
19/// - Realistic and coherent
20/// - Following the assigned persona/scenario
21/// - Providing adequate coverage of the test scenario
22/// - Properly using the stop signal when the conversation should end
23pub struct UserSimulatorEvaluator {
24    /// Optional override for the judge model.
25    judge_model: Option<String>,
26    /// The stop signal token that ends simulation (e.g., "[DONE]").
27    stop_signal: Option<String>,
28    /// Optional LLM for performing evaluations.
29    llm: Option<Arc<dyn BaseLlm>>,
30}
31
32impl UserSimulatorEvaluator {
33    /// Create a new user simulator evaluator.
34    pub fn new() -> Self {
35        Self {
36            judge_model: None,
37            stop_signal: None,
38            llm: None,
39        }
40    }
41
42    /// Set the stop signal that the simulator uses to end conversations.
43    pub fn with_stop_signal(mut self, signal: impl Into<String>) -> Self {
44        self.stop_signal = Some(signal.into());
45        self
46    }
47
48    /// Set an override judge model name.
49    pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
50        self.judge_model = Some(model.into());
51        self
52    }
53
54    /// Provide an LLM instance for performing evaluations.
55    pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
56        self.llm = Some(llm);
57        self
58    }
59
60    /// Build the evaluation prompt for a simulated conversation.
61    fn build_prompt(&self, inv: &Invocation) -> String {
62        let mut prompt = String::from(
63            "You are an expert evaluator assessing USER SIMULATOR FIDELITY.\n\n\
64             A user simulator was used to generate the user-side of a multi-turn \
65             conversation with an AI agent. Your task is to evaluate the quality \
66             of the simulated user messages.\n\n\
67             Evaluate on these criteria:\n\
68             1. REALISM: Do the simulated user messages sound like a real human?\n\
69             2. COHERENCE: Does the simulated user maintain a consistent persona and goal?\n\
70             3. COVERAGE: Does the simulation adequately exercise the agent's capabilities?\n\
71             4. PROGRESSION: Does the conversation progress naturally toward resolution?\n",
72        );
73
74        if let Some(ref signal) = self.stop_signal {
75            prompt.push_str(&format!(
76                "5. TERMINATION: Was the stop signal \"{signal}\" used appropriately?\n"
77            ));
78        }
79
80        prompt.push_str("\nCONVERSATION:\n");
81        for turn in &inv.turns {
82            prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
83        }
84
85        prompt.push_str(
86            "\nRespond with ONLY a JSON object:\n\
87             {\"realism\": <float 0-1>, \
88             \"coherence\": <float 0-1>, \
89             \"coverage\": <float 0-1>, \
90             \"progression\": <float 0-1>, \
91             \"overall_score\": <float 0-1>, \
92             \"explanation\": \"<text>\"}\n",
93        );
94
95        prompt
96    }
97
98    /// Parse the judge response.
99    fn parse_response(text: &str) -> (f64, String) {
100        if let Some(result) = try_parse_response(text) {
101            return result;
102        }
103
104        // Try to find JSON embedded in text
105        if let Some(start) = text.find('{') {
106            if let Some(end) = text[start..].rfind('}') {
107                let json_str = &text[start..=start + end];
108                if let Some(result) = try_parse_response(json_str) {
109                    return result;
110                }
111            }
112        }
113
114        (
115            0.0,
116            format!("Failed to parse simulator judge response: {text}"),
117        )
118    }
119
120    /// Perform heuristic scoring without an LLM.
121    ///
122    /// Checks basic conversation structure: turn alternation, non-empty
123    /// messages, reasonable lengths, and proper stop signal usage.
124    fn heuristic_score(&self, inv: &Invocation) -> (f64, String) {
125        let mut score = 1.0;
126        let mut issues = Vec::new();
127
128        let user_turns: Vec<&str> = inv
129            .turns
130            .iter()
131            .filter(|t| t.role == "user")
132            .map(|t| t.content.as_str())
133            .collect();
134
135        if user_turns.is_empty() {
136            return (0.0, "No user turns in conversation".into());
137        }
138
139        // Check for empty user messages
140        let empty_count = user_turns.iter().filter(|t| t.trim().is_empty()).count();
141        if empty_count > 0 {
142            score -= 0.2 * empty_count as f64;
143            issues.push(format!("{empty_count} empty user messages"));
144        }
145
146        // Check for very short repetitive messages
147        let mut prev = "";
148        let mut repeat_count = 0;
149        for msg in &user_turns {
150            if *msg == prev && !msg.is_empty() {
151                repeat_count += 1;
152            }
153            prev = msg;
154        }
155        if repeat_count > 0 {
156            score -= 0.15 * repeat_count as f64;
157            issues.push(format!("{repeat_count} consecutive repeated messages"));
158        }
159
160        // Check turn alternation (user/model/user/model)
161        let mut last_role = "";
162        let mut alternation_violations = 0;
163        for turn in &inv.turns {
164            if turn.role == last_role && turn.role == "user" {
165                alternation_violations += 1;
166            }
167            last_role = &turn.role;
168        }
169        if alternation_violations > 0 {
170            score -= 0.1 * alternation_violations as f64;
171            issues.push(format!(
172                "{alternation_violations} turn alternation violations"
173            ));
174        }
175
176        // Check stop signal usage if configured
177        if let Some(ref signal) = self.stop_signal {
178            let has_stop = user_turns.iter().any(|t| t.contains(signal.as_str()));
179            let last_user_has_stop = user_turns
180                .last()
181                .map(|t| t.contains(signal.as_str()))
182                .unwrap_or(false);
183
184            if has_stop && !last_user_has_stop {
185                score -= 0.2;
186                issues.push("Stop signal used in non-final user turn".into());
187            }
188        }
189
190        score = score.clamp(0.0, 1.0);
191
192        let explanation = if issues.is_empty() {
193            "Heuristic check passed — no structural issues detected".into()
194        } else {
195            format!("Heuristic issues: {}", issues.join("; "))
196        };
197
198        (score, explanation)
199    }
200}
201
202impl Default for UserSimulatorEvaluator {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208/// Try to parse a JSON string into a score and explanation.
209fn try_parse_response(text: &str) -> Option<(f64, String)> {
210    let v: serde_json::Value = serde_json::from_str(text).ok()?;
211
212    let score = if let Some(overall) = v["overall_score"].as_f64() {
213        overall.clamp(0.0, 1.0)
214    } else {
215        // Average sub-scores
216        let sub_scores = ["realism", "coherence", "coverage", "progression"];
217        let (sum, count) = sub_scores
218            .iter()
219            .filter_map(|k| v[k].as_f64())
220            .fold((0.0, 0), |(s, c), v| (s + v.clamp(0.0, 1.0), c + 1));
221        if count == 0 {
222            return None;
223        }
224        sum / count as f64
225    };
226
227    let explanation = v["explanation"]
228        .as_str()
229        .unwrap_or("No explanation")
230        .to_string();
231
232    Some((score, explanation))
233}
234
235#[async_trait]
236impl Evaluator for UserSimulatorEvaluator {
237    async fn evaluate(
238        &self,
239        actual: &[Invocation],
240        _expected: Option<&[Invocation]>,
241    ) -> Result<EvalResult, EvalError> {
242        let mut per_invocation = Vec::new();
243        let mut total_score = 0.0;
244
245        let use_llm = self.llm.is_some();
246
247        for (i, actual_inv) in actual.iter().enumerate() {
248            let (score, explanation) = if use_llm {
249                let llm = self.llm.as_ref().unwrap();
250                let prompt = self.build_prompt(actual_inv);
251                let request = crate::llm::LlmRequest::from_text(&prompt);
252                let response = llm
253                    .generate(request)
254                    .await
255                    .map_err(|e| EvalError::Llm(e.to_string()))?;
256                Self::parse_response(&response.text())
257            } else {
258                // Fall back to heuristic scoring
259                self.heuristic_score(actual_inv)
260            };
261
262            total_score += score;
263
264            per_invocation.push(PerInvocationResult {
265                invocation_id: if actual_inv.id.is_empty() {
266                    format!("inv-{i}")
267                } else {
268                    actual_inv.id.clone()
269                },
270                score,
271                explanation: Some(explanation),
272            });
273        }
274
275        let overall_score = if actual.is_empty() {
276            0.0
277        } else {
278            total_score / actual.len() as f64
279        };
280
281        Ok(EvalResult {
282            overall_score,
283            metrics: vec![EvalMetric {
284                name: "user_simulator_fidelity".into(),
285                score: overall_score,
286                per_invocation,
287            }],
288        })
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;

    // Builds an id-less invocation from (role, content) pairs, with no tool
    // calls/results and null metadata.
    fn make_conversation(turns: &[(&str, &str)]) -> Invocation {
        Invocation {
            id: String::new(),
            turns: turns
                .iter()
                .map(|(role, content)| InvocationTurn {
                    role: role.to_string(),
                    content: content.to_string(),
                    tool_calls: vec![],
                    tool_results: vec![],
                })
                .collect(),
            metadata: serde_json::Value::Null,
        }
    }

    // No LLM configured, so these async tests exercise the heuristic path.
    #[tokio::test]
    async fn heuristic_good_conversation() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[
            ("user", "What is the weather?"),
            ("model", "It's sunny."),
            ("user", "Thanks!"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!((result.overall_score - 1.0).abs() < f64::EPSILON);
    }

    // An empty user message must lower the score below a perfect 1.0.
    #[tokio::test]
    async fn heuristic_detects_empty_messages() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[
            ("user", ""),
            ("model", "I didn't understand."),
            ("user", "Hello"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!(result.overall_score < 1.0);
    }

    // The same user message sent repeatedly must be penalized.
    #[tokio::test]
    async fn heuristic_detects_repetition() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[
            ("user", "Hello"),
            ("model", "Hi!"),
            ("user", "Hello"),
            ("model", "Hi again!"),
            ("user", "Hello"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!(result.overall_score < 1.0);
    }

    // A stop signal only in the final user turn is correct usage.
    #[tokio::test]
    async fn heuristic_stop_signal_ok() {
        let eval = UserSimulatorEvaluator::new().with_stop_signal("[DONE]");
        let inv = make_conversation(&[
            ("user", "Check the weather"),
            ("model", "It's 22C."),
            ("user", "Thanks [DONE]"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!((result.overall_score - 1.0).abs() < f64::EPSILON);
    }

    // A stop signal in an earlier user turn must be penalized.
    #[tokio::test]
    async fn heuristic_stop_signal_misplaced() {
        let eval = UserSimulatorEvaluator::new().with_stop_signal("[DONE]");
        let inv = make_conversation(&[
            ("user", "Check the weather [DONE]"),
            ("model", "It's 22C."),
            ("user", "Wait actually..."),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!(result.overall_score < 1.0);
    }

    // With no invocations, the overall score is defined as 0.0.
    #[tokio::test]
    async fn empty_invocations() {
        let eval = UserSimulatorEvaluator::new();
        let result = eval.evaluate(&[], None).await.unwrap();
        assert!((result.overall_score - 0.0).abs() < f64::EPSILON);
    }

    // A conversation containing no user turns scores 0.0.
    #[tokio::test]
    async fn no_user_turns() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[("model", "Hello!")]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!((result.overall_score - 0.0).abs() < f64::EPSILON);
    }

    // When present, `overall_score` takes precedence over the sub-scores.
    #[test]
    fn parse_valid_response() {
        let json = r#"{"realism": 0.9, "coherence": 0.8, "coverage": 0.7, "progression": 0.6, "overall_score": 0.75, "explanation": "Good"}"#;
        let (score, explanation) = UserSimulatorEvaluator::parse_response(json);
        assert!((score - 0.75).abs() < f64::EPSILON);
        assert_eq!(explanation, "Good");
    }

    // Without `overall_score`, the present sub-scores are averaged: (0.8 + 0.6) / 2.
    #[test]
    fn parse_sub_scores_only() {
        let json = r#"{"realism": 0.8, "coherence": 0.6}"#;
        let (score, _) = UserSimulatorEvaluator::parse_response(json);
        assert!((score - 0.7).abs() < f64::EPSILON);
    }

    // Unparseable judge output scores 0.0 with a "Failed to parse" explanation.
    #[test]
    fn parse_invalid() {
        let (score, explanation) = UserSimulatorEvaluator::parse_response("not json");
        assert!((score - 0.0).abs() < f64::EPSILON);
        assert!(explanation.contains("Failed to parse"));
    }

    // `Default` leaves the optional settings unset.
    #[test]
    fn default_impl() {
        let eval = UserSimulatorEvaluator::default();
        assert!(eval.judge_model.is_none());
        assert!(eval.stop_signal.is_none());
    }

    // The builder methods store their arguments.
    #[test]
    fn builder_methods() {
        let eval = UserSimulatorEvaluator::new()
            .with_judge_model("gemini-2.0-flash")
            .with_stop_signal("[END]");
        assert_eq!(eval.judge_model.as_deref(), Some("gemini-2.0-flash"));
        assert_eq!(eval.stop_signal.as_deref(), Some("[END]"));
    }

    // The prompt carries the header, each turn as `[role]: content`, and the stop signal.
    #[test]
    fn build_prompt_includes_conversation() {
        let eval = UserSimulatorEvaluator::new().with_stop_signal("[DONE]");
        let inv = make_conversation(&[("user", "Hello"), ("model", "Hi!")]);
        let prompt = eval.build_prompt(&inv);
        assert!(prompt.contains("USER SIMULATOR FIDELITY"));
        assert!(prompt.contains("[user]: Hello"));
        assert!(prompt.contains("[model]: Hi!"));
        assert!(prompt.contains("[DONE]"));
    }
}