gemini_adk_rs/evaluation/
user_simulator_evaluator.rs1use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use super::eval_case::Invocation;
12use super::eval_result::{EvalMetric, EvalResult, PerInvocationResult};
13use super::evaluator::{EvalError, Evaluator};
14use crate::llm::BaseLlm;
15
16pub struct UserSimulatorEvaluator {
24 judge_model: Option<String>,
26 stop_signal: Option<String>,
28 llm: Option<Arc<dyn BaseLlm>>,
30}
31
32impl UserSimulatorEvaluator {
33 pub fn new() -> Self {
35 Self {
36 judge_model: None,
37 stop_signal: None,
38 llm: None,
39 }
40 }
41
42 pub fn with_stop_signal(mut self, signal: impl Into<String>) -> Self {
44 self.stop_signal = Some(signal.into());
45 self
46 }
47
48 pub fn with_judge_model(mut self, model: impl Into<String>) -> Self {
50 self.judge_model = Some(model.into());
51 self
52 }
53
54 pub fn with_llm(mut self, llm: Arc<dyn BaseLlm>) -> Self {
56 self.llm = Some(llm);
57 self
58 }
59
60 fn build_prompt(&self, inv: &Invocation) -> String {
62 let mut prompt = String::from(
63 "You are an expert evaluator assessing USER SIMULATOR FIDELITY.\n\n\
64 A user simulator was used to generate the user-side of a multi-turn \
65 conversation with an AI agent. Your task is to evaluate the quality \
66 of the simulated user messages.\n\n\
67 Evaluate on these criteria:\n\
68 1. REALISM: Do the simulated user messages sound like a real human?\n\
69 2. COHERENCE: Does the simulated user maintain a consistent persona and goal?\n\
70 3. COVERAGE: Does the simulation adequately exercise the agent's capabilities?\n\
71 4. PROGRESSION: Does the conversation progress naturally toward resolution?\n",
72 );
73
74 if let Some(ref signal) = self.stop_signal {
75 prompt.push_str(&format!(
76 "5. TERMINATION: Was the stop signal \"{signal}\" used appropriately?\n"
77 ));
78 }
79
80 prompt.push_str("\nCONVERSATION:\n");
81 for turn in &inv.turns {
82 prompt.push_str(&format!("[{}]: {}\n", turn.role, turn.content));
83 }
84
85 prompt.push_str(
86 "\nRespond with ONLY a JSON object:\n\
87 {\"realism\": <float 0-1>, \
88 \"coherence\": <float 0-1>, \
89 \"coverage\": <float 0-1>, \
90 \"progression\": <float 0-1>, \
91 \"overall_score\": <float 0-1>, \
92 \"explanation\": \"<text>\"}\n",
93 );
94
95 prompt
96 }
97
98 fn parse_response(text: &str) -> (f64, String) {
100 if let Some(result) = try_parse_response(text) {
101 return result;
102 }
103
104 if let Some(start) = text.find('{') {
106 if let Some(end) = text[start..].rfind('}') {
107 let json_str = &text[start..=start + end];
108 if let Some(result) = try_parse_response(json_str) {
109 return result;
110 }
111 }
112 }
113
114 (
115 0.0,
116 format!("Failed to parse simulator judge response: {text}"),
117 )
118 }
119
120 fn heuristic_score(&self, inv: &Invocation) -> (f64, String) {
125 let mut score = 1.0;
126 let mut issues = Vec::new();
127
128 let user_turns: Vec<&str> = inv
129 .turns
130 .iter()
131 .filter(|t| t.role == "user")
132 .map(|t| t.content.as_str())
133 .collect();
134
135 if user_turns.is_empty() {
136 return (0.0, "No user turns in conversation".into());
137 }
138
139 let empty_count = user_turns.iter().filter(|t| t.trim().is_empty()).count();
141 if empty_count > 0 {
142 score -= 0.2 * empty_count as f64;
143 issues.push(format!("{empty_count} empty user messages"));
144 }
145
146 let mut prev = "";
148 let mut repeat_count = 0;
149 for msg in &user_turns {
150 if *msg == prev && !msg.is_empty() {
151 repeat_count += 1;
152 }
153 prev = msg;
154 }
155 if repeat_count > 0 {
156 score -= 0.15 * repeat_count as f64;
157 issues.push(format!("{repeat_count} consecutive repeated messages"));
158 }
159
160 let mut last_role = "";
162 let mut alternation_violations = 0;
163 for turn in &inv.turns {
164 if turn.role == last_role && turn.role == "user" {
165 alternation_violations += 1;
166 }
167 last_role = &turn.role;
168 }
169 if alternation_violations > 0 {
170 score -= 0.1 * alternation_violations as f64;
171 issues.push(format!(
172 "{alternation_violations} turn alternation violations"
173 ));
174 }
175
176 if let Some(ref signal) = self.stop_signal {
178 let has_stop = user_turns.iter().any(|t| t.contains(signal.as_str()));
179 let last_user_has_stop = user_turns
180 .last()
181 .map(|t| t.contains(signal.as_str()))
182 .unwrap_or(false);
183
184 if has_stop && !last_user_has_stop {
185 score -= 0.2;
186 issues.push("Stop signal used in non-final user turn".into());
187 }
188 }
189
190 score = score.clamp(0.0, 1.0);
191
192 let explanation = if issues.is_empty() {
193 "Heuristic check passed — no structural issues detected".into()
194 } else {
195 format!("Heuristic issues: {}", issues.join("; "))
196 };
197
198 (score, explanation)
199 }
200}
201
202impl Default for UserSimulatorEvaluator {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208fn try_parse_response(text: &str) -> Option<(f64, String)> {
210 let v: serde_json::Value = serde_json::from_str(text).ok()?;
211
212 let score = if let Some(overall) = v["overall_score"].as_f64() {
213 overall.clamp(0.0, 1.0)
214 } else {
215 let sub_scores = ["realism", "coherence", "coverage", "progression"];
217 let (sum, count) = sub_scores
218 .iter()
219 .filter_map(|k| v[k].as_f64())
220 .fold((0.0, 0), |(s, c), v| (s + v.clamp(0.0, 1.0), c + 1));
221 if count == 0 {
222 return None;
223 }
224 sum / count as f64
225 };
226
227 let explanation = v["explanation"]
228 .as_str()
229 .unwrap_or("No explanation")
230 .to_string();
231
232 Some((score, explanation))
233}
234
235#[async_trait]
236impl Evaluator for UserSimulatorEvaluator {
237 async fn evaluate(
238 &self,
239 actual: &[Invocation],
240 _expected: Option<&[Invocation]>,
241 ) -> Result<EvalResult, EvalError> {
242 let mut per_invocation = Vec::new();
243 let mut total_score = 0.0;
244
245 let use_llm = self.llm.is_some();
246
247 for (i, actual_inv) in actual.iter().enumerate() {
248 let (score, explanation) = if use_llm {
249 let llm = self.llm.as_ref().unwrap();
250 let prompt = self.build_prompt(actual_inv);
251 let request = crate::llm::LlmRequest::from_text(&prompt);
252 let response = llm
253 .generate(request)
254 .await
255 .map_err(|e| EvalError::Llm(e.to_string()))?;
256 Self::parse_response(&response.text())
257 } else {
258 self.heuristic_score(actual_inv)
260 };
261
262 total_score += score;
263
264 per_invocation.push(PerInvocationResult {
265 invocation_id: if actual_inv.id.is_empty() {
266 format!("inv-{i}")
267 } else {
268 actual_inv.id.clone()
269 },
270 score,
271 explanation: Some(explanation),
272 });
273 }
274
275 let overall_score = if actual.is_empty() {
276 0.0
277 } else {
278 total_score / actual.len() as f64
279 };
280
281 Ok(EvalResult {
282 overall_score,
283 metrics: vec![EvalMetric {
284 name: "user_simulator_fidelity".into(),
285 score: overall_score,
286 per_invocation,
287 }],
288 })
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evaluation::eval_case::InvocationTurn;

    /// Builds an invocation with an empty id from `(role, content)` pairs.
    fn make_conversation(turns: &[(&str, &str)]) -> Invocation {
        Invocation {
            id: String::new(),
            turns: turns
                .iter()
                .map(|(role, content)| InvocationTurn {
                    role: role.to_string(),
                    content: content.to_string(),
                    tool_calls: vec![],
                    tool_results: vec![],
                })
                .collect(),
            metadata: serde_json::Value::Null,
        }
    }

    #[tokio::test]
    async fn heuristic_good_conversation() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[
            ("user", "What is the weather?"),
            ("model", "It's sunny."),
            ("user", "Thanks!"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!((result.overall_score - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn heuristic_detects_empty_messages() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[
            ("user", ""),
            ("model", "I didn't understand."),
            ("user", "Hello"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!(result.overall_score < 1.0);
    }

    #[tokio::test]
    async fn heuristic_detects_repetition() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[
            ("user", "Hello"),
            ("model", "Hi!"),
            ("user", "Hello"),
            ("model", "Hi again!"),
            ("user", "Hello"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!(result.overall_score < 1.0);
    }

    #[tokio::test]
    async fn heuristic_detects_consecutive_user_turns() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[
            ("user", "Hi"),
            ("user", "Are you there?"),
            ("model", "Yes, I'm here."),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        // Exactly one alternation violation (-0.1) and nothing else.
        assert!((result.overall_score - 0.9).abs() < 1e-9);
    }

    #[tokio::test]
    async fn heuristic_stop_signal_ok() {
        let eval = UserSimulatorEvaluator::new().with_stop_signal("[DONE]");
        let inv = make_conversation(&[
            ("user", "Check the weather"),
            ("model", "It's 22C."),
            ("user", "Thanks [DONE]"),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!((result.overall_score - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn heuristic_stop_signal_misplaced() {
        let eval = UserSimulatorEvaluator::new().with_stop_signal("[DONE]");
        let inv = make_conversation(&[
            ("user", "Check the weather [DONE]"),
            ("model", "It's 22C."),
            ("user", "Wait actually..."),
        ]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!(result.overall_score < 1.0);
    }

    #[tokio::test]
    async fn empty_invocations() {
        let eval = UserSimulatorEvaluator::new();
        let result = eval.evaluate(&[], None).await.unwrap();
        assert!((result.overall_score - 0.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn no_user_turns() {
        let eval = UserSimulatorEvaluator::new();
        let inv = make_conversation(&[("model", "Hello!")]);
        let result = eval.evaluate(&[inv], None).await.unwrap();
        assert!((result.overall_score - 0.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn invocation_ids_default_to_index() {
        let eval = UserSimulatorEvaluator::new();
        let anonymous = make_conversation(&[("user", "Hello")]);
        let mut named = make_conversation(&[("user", "Hello")]);
        named.id = "custom-id".to_string();
        let result = eval.evaluate(&[anonymous, named], None).await.unwrap();
        let per_inv = &result.metrics[0].per_invocation;
        assert_eq!(per_inv[0].invocation_id, "inv-0");
        assert_eq!(per_inv[1].invocation_id, "custom-id");
    }

    #[test]
    fn parse_valid_response() {
        let json = r#"{"realism": 0.9, "coherence": 0.8, "coverage": 0.7, "progression": 0.6, "overall_score": 0.75, "explanation": "Good"}"#;
        let (score, explanation) = UserSimulatorEvaluator::parse_response(json);
        assert!((score - 0.75).abs() < f64::EPSILON);
        assert_eq!(explanation, "Good");
    }

    #[test]
    fn parse_sub_scores_only() {
        let json = r#"{"realism": 0.8, "coherence": 0.6}"#;
        let (score, _) = UserSimulatorEvaluator::parse_response(json);
        assert!((score - 0.7).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_json_wrapped_in_code_fence() {
        // Exercises the `{ ... }` extraction fallback.
        let text = "```json\n{\"overall_score\": 0.4, \"explanation\": \"Stilted\"}\n```";
        let (score, explanation) = UserSimulatorEvaluator::parse_response(text);
        assert!((score - 0.4).abs() < f64::EPSILON);
        assert_eq!(explanation, "Stilted");
    }

    #[test]
    fn parse_clamps_out_of_range_score() {
        let (score, explanation) =
            UserSimulatorEvaluator::parse_response(r#"{"overall_score": 1.7}"#);
        assert!((score - 1.0).abs() < f64::EPSILON);
        assert_eq!(explanation, "No explanation");
    }

    #[test]
    fn parse_invalid() {
        let (score, explanation) = UserSimulatorEvaluator::parse_response("not json");
        assert!((score - 0.0).abs() < f64::EPSILON);
        assert!(explanation.contains("Failed to parse"));
    }

    #[test]
    fn default_impl() {
        let eval = UserSimulatorEvaluator::default();
        assert!(eval.judge_model.is_none());
        assert!(eval.stop_signal.is_none());
    }

    #[test]
    fn builder_methods() {
        let eval = UserSimulatorEvaluator::new()
            .with_judge_model("gemini-2.0-flash")
            .with_stop_signal("[END]");
        assert_eq!(eval.judge_model.as_deref(), Some("gemini-2.0-flash"));
        assert_eq!(eval.stop_signal.as_deref(), Some("[END]"));
    }

    #[test]
    fn build_prompt_includes_conversation() {
        let eval = UserSimulatorEvaluator::new().with_stop_signal("[DONE]");
        let inv = make_conversation(&[("user", "Hello"), ("model", "Hi!")]);
        let prompt = eval.build_prompt(&inv);
        assert!(prompt.contains("USER SIMULATOR FIDELITY"));
        assert!(prompt.contains("[user]: Hello"));
        assert!(prompt.contains("[model]: Hi!"));
        assert!(prompt.contains("[DONE]"));
    }
}