1use std::collections::BTreeMap;
22use std::sync::Arc;
23use std::time::Instant;
24
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27
28use gemini_adk_rs::flow::{Enforcement, FlowExplanation};
29use gemini_adk_rs::live::{TranscriptTurn, TurnExtractor};
30use gemini_adk_rs::state::State;
31
32use crate::conversation::{CompiledConversation, FlowStack};
33
34struct BoundExtractor {
35 extractor: Arc<dyn TurnExtractor>,
36 fields: Vec<(String, String)>,
38}
39
40pub struct Sim {
42 stack: FlowStack,
43 extractors: Vec<BoundExtractor>,
44 state: State,
45 turn_no: u32,
46 pending_tools: Vec<(String, u32)>,
48}
49
50impl Sim {
51 pub fn new(convo: &CompiledConversation, mode: Enforcement) -> Self {
53 let extractors = convo
54 .all_extractors()
55 .into_iter()
56 .map(|e| BoundExtractor {
57 fields: e.field_state_keys(),
58 extractor: e.into_extractor(),
59 })
60 .collect();
61 Self {
62 stack: convo.stack(mode),
63 extractors,
64 state: State::new(),
65 turn_no: 0,
66 pending_tools: Vec::new(),
67 }
68 }
69
70 pub fn set(&self, key: impl Into<String>, value: impl Serialize) -> &Self {
73 let _ = self.state.set(key, value);
74 self
75 }
76
77 pub async fn user(&mut self, utterance: &str) -> &mut Self {
80 let window = [TranscriptTurn {
81 turn_number: self.turn_no,
82 user: utterance.to_string(),
83 model: String::new(),
84 tool_calls: Vec::new(),
85 timestamp: Instant::now(),
86 }];
87 for bound in &self.extractors {
88 if let Ok(Value::Object(obj)) = bound
89 .extractor
90 .extract_with_state(&window, &self.state)
91 .await
92 {
93 for (name, key) in &bound.fields {
94 if let Some(v) = obj.get(name) {
95 if !v.is_null() {
96 let _ = self.state.set(key.clone(), v.clone());
97 }
98 }
99 }
100 }
101 }
102 self.advance();
103 self
104 }
105
106 pub fn turn(&mut self) -> &mut Self {
108 self.advance();
109 self
110 }
111
112 pub fn tool_ok(&mut self, tool: &str) -> &mut Self {
115 self.stack.on_tool_ok(tool, &self.state);
116 self.advance();
117 self
118 }
119
120 pub fn schedule_tool(&mut self, tool: impl Into<String>, after: u32) -> &mut Self {
122 self.pending_tools
123 .push((tool.into(), self.turn_no + after.max(1)));
124 self
125 }
126
127 fn advance(&mut self) {
128 self.turn_no += 1;
129 let due: Vec<String> = self
131 .pending_tools
132 .iter()
133 .filter(|(_, at)| *at <= self.turn_no)
134 .map(|(t, _)| t.clone())
135 .collect();
136 self.pending_tools.retain(|(_, at)| *at > self.turn_no);
137 for tool in due {
138 self.stack.on_tool_ok(&tool, &self.state);
139 }
140 self.stack.on_turn(&self.state);
141 }
142
143 pub fn active(&self) -> Vec<String> {
145 self.stack.explain(&self.state).active
146 }
147
148 pub fn active_overlay(&self) -> Option<&str> {
150 self.stack.active_overlay()
151 }
152
153 pub fn allowed(&self, tool: &str) -> bool {
155 self.stack.admits_tool(tool, &self.state).is_ok()
156 }
157
158 pub fn denied(&self) -> BTreeMap<String, String> {
160 self.stack.explain(&self.state).blocked_tools
161 }
162
163 pub fn is_complete(&self) -> bool {
165 self.stack.is_complete()
166 }
167
168 pub fn slot<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
170 self.state.get(key)
171 }
172
173 pub fn explain(&self) -> FlowExplanation {
175 self.stack.explain(&self.state)
176 }
177
178 pub fn state(&self) -> &State {
180 &self.state
181 }
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
186#[serde(rename_all = "snake_case")]
187pub enum SimStep {
188 User(String),
190 Set {
192 key: String,
194 value: Value,
196 },
197 ToolOk(String),
199 ScheduleTool {
201 tool: String,
203 after: u32,
205 },
206 Turn,
208 ExpectActive(Vec<String>),
210 ExpectDenied(String),
212 ExpectAllowed(String),
214 ExpectSlot {
216 key: String,
218 value: Value,
220 },
221 ExpectComplete,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct Scenario {
229 pub name: String,
231 pub steps: Vec<SimStep>,
233}
234
235impl Scenario {
236 pub async fn run(&self, convo: &CompiledConversation, mode: Enforcement) -> Result<(), String> {
239 let mut sim = Sim::new(convo, mode);
240 for (i, step) in self.steps.iter().enumerate() {
241 let fail = |msg: String| Err(format!("[{}] step {i} ({step:?}): {msg}", self.name));
242 match step {
243 SimStep::User(text) => {
244 sim.user(text).await;
245 }
246 SimStep::Set { key, value } => {
247 sim.set(key.clone(), value.clone());
248 }
249 SimStep::ToolOk(tool) => {
250 sim.tool_ok(tool);
251 }
252 SimStep::ScheduleTool { tool, after } => {
253 sim.schedule_tool(tool.clone(), *after);
254 }
255 SimStep::Turn => {
256 sim.turn();
257 }
258 SimStep::ExpectActive(expected) => {
259 let active = sim.active();
260 for e in expected {
261 if !active.contains(e) {
262 return fail(format!("expected active '{e}', got {active:?}"));
263 }
264 }
265 }
266 SimStep::ExpectDenied(tool) => {
267 if sim.allowed(tool) {
268 return fail(format!("expected '{tool}' denied, but it was admitted"));
269 }
270 }
271 SimStep::ExpectAllowed(tool) => {
272 if !sim.allowed(tool) {
273 let why = sim.denied().get(tool).cloned().unwrap_or_default();
274 return fail(format!("expected '{tool}' allowed, but denied: {why}"));
275 }
276 }
277 SimStep::ExpectSlot { key, value } => {
278 let got = sim.state().get_raw(key);
279 if got.as_ref() != Some(value) {
280 return fail(format!("expected slot '{key}' = {value}, got {got:?}"));
281 }
282 }
283 SimStep::ExpectComplete => {
284 if !sim.is_complete() {
285 return fail("expected conversation complete".into());
286 }
287 }
288 }
289 }
290 Ok(())
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::Conversation;
    use gemini_adk_rs::flow::Guard;
    use gemini_adk_rs::frame::{Frame, FrameSpec, SlotRecognizer, SlotSpec};

    /// Single-slot frame: `party_size`, filled by the `IntegerNear(["people"])` recognizer.
    struct Booking;

    impl Frame for Booking {
        fn frame() -> FrameSpec {
            let party_size = SlotSpec {
                recognizer: Some(SlotRecognizer::IntegerNear(vec!["people".into()])),
                ..SlotSpec::new("party_size")
            };
            FrameSpec {
                name: "booking".into(),
                slots: vec![party_size],
            }
        }
    }

    /// collect -> confirm -> done, with a `book` commit gated on `user_confirmed`.
    fn booking() -> CompiledConversation {
        Conversation::new("booking")
            .stage("collect")
            .collect_frame::<Booking>()
            .next("confirm", Guard::captured(["party_size"]))
            .stage("confirm")
            .commit("book", Guard::is_true("user_confirmed"))
            .next("done", Guard::called_ok("book"))
            .stage("done")
            .terminal()
            .require(["done"])
            .compile()
            .expect("compiles")
    }

    /// True when `name` appears in the simulator's active list.
    fn is_active(sim: &Sim, name: &str) -> bool {
        sim.active().iter().any(|a| a == name)
    }

    #[tokio::test]
    async fn fake_user_fills_slots_and_gates_commit() {
        let compiled = booking();
        let mut sim = Sim::new(&compiled, Enforcement::Enforce);

        assert!(is_active(&sim, "collect"));
        assert!(!sim.allowed("book"));

        sim.user("a table for 4 people").await;
        assert_eq!(sim.slot::<u32>("party_size"), Some(4));
        assert!(is_active(&sim, "confirm"));

        // The commit stays gated until the user confirms.
        assert!(!sim.allowed("book"));
        sim.set("user_confirmed", true);
        sim.turn();
        assert!(sim.allowed("book"));

        sim.tool_ok("book");
        assert!(sim.is_complete());
    }

    #[tokio::test]
    async fn scenario_runs_and_round_trips() {
        let steps = vec![
            SimStep::ExpectActive(vec!["collect".into()]),
            SimStep::ExpectDenied("book".into()),
            SimStep::User("party of 4 people".into()),
            SimStep::ExpectSlot {
                key: "party_size".into(),
                value: serde_json::json!(4),
            },
            SimStep::ExpectActive(vec!["confirm".into()]),
            SimStep::Set {
                key: "user_confirmed".into(),
                value: serde_json::json!(true),
            },
            SimStep::Turn,
            SimStep::ExpectAllowed("book".into()),
            SimStep::ToolOk("book".into()),
            SimStep::ExpectComplete,
        ];
        let scenario = Scenario {
            name: "happy_path".into(),
            steps,
        };

        scenario
            .run(&booking(), Enforcement::Enforce)
            .await
            .expect("scenario passes");

        let encoded = serde_json::to_string(&scenario).unwrap();
        let decoded: Scenario = serde_json::from_str(&encoded).unwrap();
        decoded
            .run(&booking(), Enforcement::Enforce)
            .await
            .expect("round-tripped scenario passes");
    }

    #[tokio::test]
    async fn scenario_reports_failed_expectation() {
        let scenario = Scenario {
            name: "bad".into(),
            steps: vec![SimStep::ExpectComplete],
        };
        let message = scenario
            .run(&booking(), Enforcement::Enforce)
            .await
            .unwrap_err();
        assert!(message.contains("expected conversation complete"));
    }

    #[tokio::test]
    async fn tool_latency_resolves_after_delay() {
        let compiled = booking();
        let mut sim = Sim::new(&compiled, Enforcement::Enforce);
        sim.user("4 people").await;
        sim.set("user_confirmed", true);
        sim.turn();

        // The success lands two turns later, not immediately.
        sim.schedule_tool("book", 2);
        assert!(!sim.is_complete());
        sim.turn();
        assert!(!sim.is_complete());
        sim.turn();
        assert!(sim.is_complete());
    }
}