1use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use std::collections::{BTreeMap, BTreeSet};
45use std::future::Future;
46use std::pin::Pin;
47use std::sync::Arc;
48use std::time::Duration;
49
50use gemini_adk_rs::extract::Extract;
51use gemini_adk_rs::flow::{
52 CompiledFlow, Enforcement, Flow, FlowErrors, FlowExplanation, FlowMonitor, Guard, Pred,
53};
54use gemini_adk_rs::frame::{Frame, FrameSpec};
55use gemini_adk_rs::state::State;
56
/// Type-erased async slot fetcher stored on a [`StageResolver`].
///
/// It takes a JSON [`Value`] of arguments and returns a boxed, `Send` future
/// resolving to the fetched [`Value`] or a `String` error. The `Arc` lets the
/// same fetcher be cloned into the extractors built at compile time.
type SlotFetch =
    Arc<dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value, String>> + Send>> + Send + Sync>;
60
/// An async slot resolver registered through [`Conversation::resolve_slot`].
///
/// At compile time resolvers are grouped by `stage` and each group becomes one
/// `Extract` record.
#[derive(Clone)]
struct StageResolver {
    /// Id of the stage that was current when the resolver was registered.
    stage: String,
    /// Name of the slot the resolver fills.
    name: String,
    /// Argument names passed to `field_resolve` together with the fetcher.
    args: Vec<String>,
    /// Optional duration passed to `field_resolve`; presumably a cache
    /// time-to-live — TODO confirm against `Extract::field_resolve`.
    ttl: Option<Duration>,
    /// Type-erased fetch function.
    fetch: SlotFetch,
}
72
/// A guarded transition from one stage to another.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransitionSpec {
    /// Id of the target stage; it must name a stage in the same flow.
    pub to: String,
    /// Guard attached to the transition. When lowering, the guards of all
    /// transitions into a stage are combined into that stage's gate.
    pub when: Guard,
}
81
/// A stage's commit tool and the guard attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitSpec {
    /// Name of the commit tool. When lowering it is added to the stage's
    /// allowed tools if it is not already listed.
    pub tool: String,
    /// Guard handed to the flow builder's `commit` together with `tool`.
    pub when: Guard,
}
90
/// Declarative description of one conversation stage.
///
/// Every optional or empty field is skipped when serializing and defaulted
/// when deserializing.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StageSpec {
    /// Stage id; ids must be unique within one flow.
    pub id: String,
    /// Text lowered to the step's `posture`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub say: Option<String>,
    /// Template lowered to the step's `ground`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ground: Option<String>,
    /// Slot names to collect. When `done` is unset and this list is non-empty,
    /// the stage completes once these are captured (`Guard::captured`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub collect: Vec<String>,
    /// Frame set by `Conversation::collect_frame`; an extractor is derived
    /// from it at compile time via `FrameSpec::to_extract`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub frame: Option<FrameSpec>,
    /// Tools passed to the step's `allow` list.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub allow: Vec<String>,
    /// Explicit completion guard; takes precedence over `collect` and `next`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub done: Option<Guard>,
    /// Optional commit tool and its guard.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub commit: Option<CommitSpec>,
    /// Outgoing transitions. With neither `done` nor `collect`, the stage
    /// completes when any of their guards holds.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub next: Vec<TransitionSpec>,
    /// Ids of stages this stage is declared `after`, in addition to the
    /// sources of its incoming transitions.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub after: Vec<String>,
    /// Marks the step terminal; a terminal stage needs no completion guard.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub terminal: bool,
    /// Optional re-prompt / escalation policy for this stage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub repair: Option<RepairPolicy>,
}
131
/// State key of the escalation flag for `stage`: `repair:<stage>:escalate`.
fn escalate_flag(stage: &str) -> String {
    ["repair:", stage, ":escalate"].concat()
}
136
/// State key of the re-prompt flag for `stage`: `repair:<stage>:reprompt`.
fn reprompt_flag(stage: &str) -> String {
    ["repair:", stage, ":reprompt"].concat()
}
141
/// What a [`FlowStack`] does with the main flow once an overlay completes.
///
/// Serialized in `snake_case`; the default is [`Resume::Previous`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Resume {
    /// Leave the main monitor untouched and carry on with it.
    #[default]
    Previous,
    /// Replace the main monitor with a fresh one built from the main flow.
    Restart,
    /// Mark the whole stack as terminated.
    Terminate,
}
154
/// Serde default for [`RepairPolicy::reprompt_after`].
fn default_reprompt_after() -> u32 {
    const DEFAULT_REPROMPT_AFTER: u32 = 2;
    DEFAULT_REPROMPT_AFTER
}
/// Serde default for [`RepairPolicy::escalate_after`].
fn default_escalate_after() -> u32 {
    const DEFAULT_ESCALATE_AFTER: u32 = 4;
    DEFAULT_ESCALATE_AFTER
}
161
/// Per-stage policy for re-prompting and escalating a stage that stays active.
///
/// A [`FlowStack`] counts the turns on which a stage is active in the main
/// flow and raises boolean flags in state once the thresholds are reached.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepairPolicy {
    /// Turn count at which `repair:<stage>:reprompt` is set to `true`.
    /// Defaults to 2 when absent from the serialized form.
    #[serde(default = "default_reprompt_after")]
    pub reprompt_after: u32,
    /// Turn count at which `repair:<stage>:escalate` is set to `true`.
    /// Defaults to 4 when absent from the serialized form.
    #[serde(default = "default_escalate_after")]
    pub escalate_after: u32,
    /// Optional stage to move to on escalation. When set, lowering adds a
    /// transition to it gated on the escalate flag and lets that flag complete
    /// the stuck stage.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub escalate_to: Option<String>,
}
179
180impl Default for RepairPolicy {
181 fn default() -> Self {
182 Self {
183 reprompt_after: default_reprompt_after(),
184 escalate_after: default_escalate_after(),
185 escalate_to: None,
186 }
187 }
188}
189
190impl RepairPolicy {
191 pub fn new(reprompt_after: u32, escalate_after: u32) -> Self {
193 Self {
194 reprompt_after,
195 escalate_after,
196 escalate_to: None,
197 }
198 }
199
200 pub fn escalate_to(mut self, stage: impl Into<String>) -> Self {
202 self.escalate_to = Some(stage.into());
203 self
204 }
205}
206
/// Declarative description of an overlay: a sub-flow that takes over from the
/// main flow while it is active.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlaySpec {
    /// Overlay name, reported by `FlowStack::active_overlay`.
    pub name: String,
    /// Guard evaluated against state to decide whether the overlay starts.
    pub trigger: Guard,
    /// Stages of the overlay's own flow.
    #[serde(default)]
    pub stages: Vec<StageSpec>,
    /// Stage ids the overlay flow requires. When empty, the overlay's terminal
    /// stages are used instead at compile time.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub require: Vec<String>,
    /// What happens to the main flow once the overlay completes.
    #[serde(default)]
    pub resume: Resume,
}
225
/// Serializable description of a whole conversation: main stages, overlays
/// and policies.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConversationSpec {
    /// Conversation name; also used as a prefix for resolver extractor names.
    pub name: String,
    /// Stages of the main flow.
    #[serde(default)]
    pub stages: Vec<StageSpec>,
    /// Stage ids the main flow requires; each must name a main-flow stage.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub require: Vec<String>,
    /// Overlays that can take over from the main flow.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub overlays: Vec<OverlaySpec>,
    /// Policies attached to the conversation.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub policies: Vec<crate::policy::Policy>,
}
245
/// Errors produced while lowering a conversation spec into compiled flows.
#[derive(Debug)]
pub enum ConversationError {
    /// A flow (main or overlay) has no stages.
    Empty,
    /// The spec is inconsistent, e.g. duplicate ids or a reference to an
    /// unknown stage; the message describes the problem.
    Spec(String),
    /// Building the lowered flow failed; carries the builder's error messages.
    Flow(Vec<String>),
    /// Compiling the built flow failed.
    Compile(FlowErrors),
}
258
259impl std::fmt::Display for ConversationError {
260 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
261 match self {
262 ConversationError::Empty => write!(f, "conversation has no stages"),
263 ConversationError::Spec(m) => write!(f, "conversation spec error: {m}"),
264 ConversationError::Flow(errs) => {
265 write!(f, "lowered flow is invalid: {}", errs.join("; "))
266 }
267 ConversationError::Compile(e) => write!(f, "lowered flow failed to compile: {e}"),
268 }
269 }
270}
271
// Marker impl so `ConversationError` can be used as a `std::error::Error`;
// no `source()` is provided.
impl std::error::Error for ConversationError {}
273
/// An overlay after lowering: its trigger plus its own compiled flow.
#[derive(Clone)]
pub struct CompiledOverlay {
    /// Overlay name copied from the spec.
    pub name: String,
    /// Guard that starts the overlay.
    pub trigger: Guard,
    /// Compiled flow lowered from the overlay's stages.
    pub flow: CompiledFlow,
    /// Extractors derived from the frames of the overlay's stages.
    pub extractors: Vec<Extract>,
    /// What happens to the main flow once the overlay completes.
    pub resume: Resume,
}
288
/// Result of compiling a [`ConversationSpec`]: the lowered main flow together
/// with everything needed to run it.
#[derive(Clone)]
pub struct CompiledConversation {
    /// Compiled main flow.
    flow: CompiledFlow,
    /// Main-flow extractors: frame extractors plus one record per stage that
    /// has slot resolvers.
    extractors: Vec<Extract>,
    /// Compiled overlays, in spec order.
    overlays: Vec<CompiledOverlay>,
    /// Repair policies of main-flow stages, keyed by stage id.
    repair: BTreeMap<String, RepairPolicy>,
    /// Policies copied from the spec.
    policies: Vec<crate::policy::Policy>,
    /// The spec this conversation was compiled from.
    spec: ConversationSpec,
}
300
301impl std::fmt::Debug for CompiledConversation {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 f.debug_struct("CompiledConversation")
306 .field("flow", &self.flow)
307 .field("extractors", &self.extractors.len())
308 .field("overlays", &self.overlays.len())
309 .field("policies", &self.policies)
310 .field("spec", &self.spec)
311 .finish()
312 }
313}
314
315impl CompiledConversation {
316 pub fn flow(&self) -> &CompiledFlow {
318 &self.flow
319 }
320 pub fn extractors(&self) -> &[Extract] {
323 &self.extractors
324 }
325 pub fn overlays(&self) -> &[CompiledOverlay] {
327 &self.overlays
328 }
329 pub fn policies(&self) -> &[crate::policy::Policy] {
331 &self.policies
332 }
333 pub fn redacted_fields(&self) -> BTreeSet<String> {
335 self.policies
336 .iter()
337 .flat_map(|p| p.redacted_keys().iter().cloned())
338 .collect()
339 }
340 pub fn all_extractors(&self) -> Vec<Extract> {
343 let mut all = self.extractors.clone();
344 for ov in &self.overlays {
345 all.extend(ov.extractors.iter().cloned());
346 }
347 all
348 }
349 pub fn stack(&self, mode: Enforcement) -> FlowStack {
352 FlowStack::new(self, mode)
353 }
354 pub fn spec(&self) -> &ConversationSpec {
356 &self.spec
357 }
358 pub fn to_mermaid(&self) -> String {
360 self.flow.to_mermaid()
361 }
362 pub fn monitor(&self, mode: Enforcement) -> FlowMonitor {
364 FlowMonitor::compiled(self.flow.clone(), mode)
365 }
366}
367
/// Fluent builder for a [`ConversationSpec`].
///
/// Stage-level calls (`say`, `collect`, `next`, …) configure the most recently
/// added stage of the current target: the main flow, or the overlay opened by
/// `overlay(..)` until `done_overlay()` is called.
#[derive(Clone, Default)]
pub struct Conversation {
    /// The spec being assembled.
    spec: ConversationSpec,
    /// Async slot resolvers; not part of the serializable spec.
    resolvers: Vec<StageResolver>,
    /// Index into `spec.overlays` of the overlay being built, if any.
    current_overlay: Option<usize>,
}
380
381impl Conversation {
382 pub fn new(name: impl Into<String>) -> Self {
384 Self {
385 spec: ConversationSpec {
386 name: name.into(),
387 ..Default::default()
388 },
389 resolvers: Vec::new(),
390 current_overlay: None,
391 }
392 }
393
394 pub fn stage(mut self, id: impl Into<String>) -> Self {
397 let stage = StageSpec {
398 id: id.into(),
399 ..Default::default()
400 };
401 match self.current_overlay {
402 Some(i) => self.spec.overlays[i].stages.push(stage),
403 None => self.spec.stages.push(stage),
404 }
405 self
406 }
407
408 pub fn overlay(mut self, name: impl Into<String>) -> Self {
412 self.spec.overlays.push(OverlaySpec {
413 name: name.into(),
414 trigger: Guard::is_true("__overlay_never_triggers__"),
416 stages: Vec::new(),
417 require: Vec::new(),
418 resume: Resume::Previous,
419 });
420 self.current_overlay = Some(self.spec.overlays.len() - 1);
421 self
422 }
423
424 pub fn trigger(mut self, guard: Guard) -> Self {
426 if let Some(i) = self.current_overlay {
427 self.spec.overlays[i].trigger = guard;
428 }
429 self
430 }
431
432 pub fn resume(mut self, resume: Resume) -> Self {
434 if let Some(i) = self.current_overlay {
435 self.spec.overlays[i].resume = resume;
436 }
437 self
438 }
439
440 pub fn done_overlay(mut self) -> Self {
443 self.current_overlay = None;
444 self
445 }
446
447 pub fn add_stage(mut self, stage: StageSpec) -> Self {
451 match self.current_overlay {
452 Some(i) => self.spec.overlays[i].stages.push(stage),
453 None => self.spec.stages.push(stage),
454 }
455 self
456 }
457
458 pub fn add_overlay(mut self, overlay: OverlaySpec) -> Self {
461 self.spec.overlays.push(overlay);
462 self.current_overlay = None;
463 self
464 }
465
466 pub fn policy(mut self, policy: impl Into<crate::policy::Policy>) -> Self {
468 self.spec.policies.push(policy.into());
469 self
470 }
471
472 fn current(&mut self) -> &mut StageSpec {
473 let stages = match self.current_overlay {
474 Some(i) => &mut self.spec.overlays[i].stages,
475 None => &mut self.spec.stages,
476 };
477 stages
478 .last_mut()
479 .expect("call .stage(..) before configuring a stage")
480 }
481
482 pub fn say(mut self, text: impl Into<String>) -> Self {
484 self.current().say = Some(text.into());
485 self
486 }
487
488 pub fn ground(mut self, template: impl Into<String>) -> Self {
490 self.current().ground = Some(template.into());
491 self
492 }
493
494 pub fn collect<I, S>(mut self, fields: I) -> Self
496 where
497 I: IntoIterator<Item = S>,
498 S: Into<String>,
499 {
500 self.current().collect = fields.into_iter().map(Into::into).collect();
501 self
502 }
503
504 pub fn collect_frame<F: Frame>(mut self) -> Self {
509 let spec = F::frame();
510 let stage = self.current();
511 stage.collect = spec.slot_keys();
512 stage.frame = Some(spec);
513 self
514 }
515
516 pub fn allow<I, S>(mut self, tools: I) -> Self
518 where
519 I: IntoIterator<Item = S>,
520 S: Into<String>,
521 {
522 self.current().allow = tools.into_iter().map(Into::into).collect();
523 self
524 }
525
526 pub fn done(mut self, guard: Guard) -> Self {
528 self.current().done = Some(guard);
529 self
530 }
531
532 pub fn commit(mut self, tool: impl Into<String>, when: Guard) -> Self {
534 self.current().commit = Some(CommitSpec {
535 tool: tool.into(),
536 when,
537 });
538 self
539 }
540
541 pub fn next(mut self, to: impl Into<String>, when: Guard) -> Self {
543 self.current().next.push(TransitionSpec {
544 to: to.into(),
545 when,
546 });
547 self
548 }
549
550 pub fn resolve_slot<I, S, F, Fut>(
559 mut self,
560 name: impl Into<String>,
561 args: I,
562 ttl: Option<Duration>,
563 fetch: F,
564 ) -> Self
565 where
566 I: IntoIterator<Item = S>,
567 S: Into<String>,
568 F: Fn(Value) -> Fut + Send + Sync + 'static,
569 Fut: Future<Output = Result<Value, String>> + Send + 'static,
570 {
571 let name = name.into();
572 let stage = self.current().id.clone();
573 if !self.current().collect.contains(&name) {
574 self.current().collect.push(name.clone());
575 }
576 let fetch = Arc::new(fetch);
577 self.resolvers.push(StageResolver {
578 stage,
579 name,
580 args: args.into_iter().map(Into::into).collect(),
581 ttl,
582 fetch: Arc::new(move |v| {
583 let fetch = fetch.clone();
584 Box::pin(async move { fetch(v).await })
585 }),
586 });
587 self
588 }
589
590 pub fn after(mut self, dep: impl Into<String>) -> Self {
592 self.current().after.push(dep.into());
593 self
594 }
595
596 pub fn terminal(mut self) -> Self {
598 self.current().terminal = true;
599 self
600 }
601
602 pub fn repair(mut self, policy: RepairPolicy) -> Self {
604 self.current().repair = Some(policy);
605 self
606 }
607
608 pub fn require<I, S>(mut self, steps: I) -> Self
611 where
612 I: IntoIterator<Item = S>,
613 S: Into<String>,
614 {
615 let req: Vec<String> = steps.into_iter().map(Into::into).collect();
616 match self.current_overlay {
617 Some(i) => self.spec.overlays[i].require = req,
618 None => self.spec.require = req,
619 }
620 self
621 }
622
623 pub fn spec(&self) -> &ConversationSpec {
625 &self.spec
626 }
627
628 pub fn into_spec(self) -> ConversationSpec {
630 self.spec
631 }
632
633 pub fn from_spec(spec: ConversationSpec) -> Result<CompiledConversation, ConversationError> {
636 compile_spec(spec, Vec::new())
637 }
638
639 pub fn compile(self) -> Result<CompiledConversation, ConversationError> {
641 compile_spec(self.spec, self.resolvers)
642 }
643}
644
645impl crate::live::Live {
646 pub fn converse(self, convo: &CompiledConversation) -> Self {
660 let mut live = self.govern_compiled(convo.flow().clone());
661 for extract in convo.all_extractors() {
662 live = live.extract_record(extract);
663 }
664 live
665 }
666
667 pub fn converse_observe(self, convo: &CompiledConversation) -> Self {
670 let mut live = self.observe_compiled(convo.flow().clone());
671 for extract in convo.all_extractors() {
672 live = live.extract_record(extract);
673 }
674 live
675 }
676}
677
/// Runtime state of the overlay currently running on a [`FlowStack`].
struct ActiveOverlay {
    /// Name of the overlay, as given in its spec.
    name: String,
    /// Monitor over the overlay's compiled flow.
    monitor: FlowMonitor,
    /// What to do with the main flow once this overlay completes.
    resume: Resume,
}
684
/// Runs a conversation's main flow with at most one overlay active on top of it.
///
/// While an overlay is active, turns and tool events go to the overlay's
/// monitor; otherwise they go to the main monitor.
pub struct FlowStack {
    /// Compiled main flow, kept so the main monitor can be rebuilt on
    /// `Resume::Restart`.
    main_flow: CompiledFlow,
    /// Monitor over the main flow.
    main: FlowMonitor,
    /// Enforcement mode used for every monitor this stack creates.
    mode: Enforcement,
    /// Overlays that may trigger, in priority order (first match wins).
    overlays: Vec<CompiledOverlay>,
    /// The overlay currently running, if any.
    active: Option<ActiveOverlay>,
    /// Set once an overlay with `Resume::Terminate` completes.
    terminated: bool,
    /// Repair policies of main-flow stages, keyed by stage id.
    repair: BTreeMap<String, RepairPolicy>,
    /// Number of counted turns each currently-active main stage has been active.
    active_turns: BTreeMap<String, u32>,
}
704
705impl FlowStack {
706 fn new(convo: &CompiledConversation, mode: Enforcement) -> Self {
707 Self {
708 main_flow: convo.flow.clone(),
709 main: FlowMonitor::compiled(convo.flow.clone(), mode),
710 mode,
711 overlays: convo.overlays.clone(),
712 active: None,
713 terminated: false,
714 repair: convo.repair.clone(),
715 active_turns: BTreeMap::new(),
716 }
717 }
718
719 fn apply_repair(&mut self, state: &State) {
723 if self.repair.is_empty() {
724 return;
725 }
726 let active: BTreeSet<String> = self.main.explain(state).active.into_iter().collect();
727 let left: Vec<String> = self
729 .active_turns
730 .keys()
731 .filter(|k| !active.contains(*k))
732 .cloned()
733 .collect();
734 for stage in left {
735 self.active_turns.remove(&stage);
736 let _ = state.set(reprompt_flag(&stage), false);
737 let _ = state.set(escalate_flag(&stage), false);
738 }
739 for stage in &active {
740 let count = self.active_turns.entry(stage.clone()).or_insert(0);
741 *count += 1;
742 if let Some(rp) = self.repair.get(stage) {
743 if *count >= rp.reprompt_after {
744 let _ = state.set(reprompt_flag(stage), true);
745 }
746 if *count >= rp.escalate_after {
747 let _ = state.set(escalate_flag(stage), true);
748 }
749 }
750 }
751 }
752
753 pub fn current(&self) -> &FlowMonitor {
755 self.active.as_ref().map_or(&self.main, |a| &a.monitor)
756 }
757
758 pub fn active_overlay(&self) -> Option<&str> {
760 self.active.as_ref().map(|a| a.name.as_str())
761 }
762
763 pub fn is_complete(&self) -> bool {
766 self.terminated || (self.active.is_none() && self.main.is_complete())
767 }
768
769 fn triggered(&self, state: &State) -> Option<usize> {
771 self.overlays
772 .iter()
773 .position(|ov| self.main.eval(&ov.trigger, state))
774 }
775
776 pub fn on_turn(&mut self, state: &State) {
780 if self.terminated {
781 return;
782 }
783 match &mut self.active {
784 Some(active) => {
785 active.monitor.on_turn(state);
786 if active.monitor.is_complete() {
787 let resume = active.resume;
788 self.active = None;
789 match resume {
790 Resume::Previous => {}
792 Resume::Restart => {
793 self.main = FlowMonitor::compiled(self.main_flow.clone(), self.mode);
794 }
795 Resume::Terminate => self.terminated = true,
796 }
797 }
798 }
799 None => {
800 if let Some(idx) = self.triggered(state) {
801 let ov = &self.overlays[idx];
802 let mut monitor = FlowMonitor::compiled(ov.flow.clone(), self.mode);
803 monitor.on_turn(state);
805 if monitor.is_complete() {
806 match ov.resume {
807 Resume::Previous => {}
808 Resume::Restart => {
809 self.main =
810 FlowMonitor::compiled(self.main_flow.clone(), self.mode);
811 }
812 Resume::Terminate => self.terminated = true,
813 }
814 } else {
815 self.active = Some(ActiveOverlay {
816 name: ov.name.clone(),
817 monitor,
818 resume: ov.resume,
819 });
820 }
821 } else {
822 self.apply_repair(state);
825 self.main.on_turn(state);
826 }
827 }
828 }
829 }
830
831 pub fn on_tool_ok(&mut self, tool: &str, state: &State) {
833 match &mut self.active {
834 Some(active) => active.monitor.on_tool_ok(tool, state),
835 None => self.main.on_tool_ok(tool, state),
836 }
837 }
838
839 pub fn admits_tool(&self, tool: &str, state: &State) -> Result<(), String> {
841 self.current().admits_tool(tool, state)
842 }
843
844 pub fn explain(&self, state: &State) -> FlowExplanation {
846 self.current().explain(state)
847 }
848}
849
850fn is_always(g: &Guard) -> bool {
851 matches!(g, Guard::Spec(Pred::Always))
852}
853
854fn any_of(guards: Vec<Guard>) -> Option<Guard> {
856 if guards.is_empty() {
857 return None;
858 }
859 if guards.iter().any(is_always) {
860 return Some(Guard::always());
861 }
862 if guards.len() == 1 {
863 return guards.into_iter().next();
864 }
865 Some(Guard::any(guards))
866}
867
868fn lower_flow(stages: &[StageSpec], require: &[String]) -> Result<CompiledFlow, ConversationError> {
871 if stages.is_empty() {
872 return Err(ConversationError::Empty);
873 }
874 let ids: BTreeSet<&str> = stages.iter().map(|s| s.id.as_str()).collect();
875 if ids.len() != stages.len() {
876 return Err(ConversationError::Spec("duplicate stage ids".into()));
877 }
878 for s in stages {
879 for t in &s.next {
880 if !ids.contains(t.to.as_str()) {
881 return Err(ConversationError::Spec(format!(
882 "stage '{}' transitions to unknown stage '{}'",
883 s.id, t.to
884 )));
885 }
886 }
887 for d in &s.after {
888 if !ids.contains(d.as_str()) {
889 return Err(ConversationError::Spec(format!(
890 "stage '{}' depends on unknown stage '{}'",
891 s.id, d
892 )));
893 }
894 }
895 if let Some(target) = s.repair.as_ref().and_then(|r| r.escalate_to.as_ref()) {
896 if !ids.contains(target.as_str()) {
897 return Err(ConversationError::Spec(format!(
898 "stage '{}' escalates to unknown stage '{}'",
899 s.id, target
900 )));
901 }
902 }
903 }
904 for r in require {
905 if !ids.contains(r.as_str()) {
906 return Err(ConversationError::Spec(format!(
907 "require references unknown stage '{r}'"
908 )));
909 }
910 }
911
912 let mut incoming: BTreeMap<&str, Vec<(&str, Guard)>> = BTreeMap::new();
914 for s in stages {
915 for t in &s.next {
916 incoming
917 .entry(t.to.as_str())
918 .or_default()
919 .push((s.id.as_str(), t.when.clone()));
920 }
921 if let Some(target) = s.repair.as_ref().and_then(|r| r.escalate_to.as_ref()) {
923 incoming
924 .entry(target.as_str())
925 .or_default()
926 .push((s.id.as_str(), Guard::is_true(escalate_flag(&s.id))));
927 }
928 }
929
930 let mut fb = Flow::new();
931 for s in stages {
932 fb = fb.step(&s.id);
933
934 let mut deps: BTreeSet<&str> = s.after.iter().map(String::as_str).collect();
935 if let Some(inc) = incoming.get(s.id.as_str()) {
936 for (src, _) in inc {
937 deps.insert(src);
938 }
939 }
940 for d in deps {
941 fb = fb.after(d);
942 }
943
944 if let Some(inc) = incoming.get(s.id.as_str()) {
945 if let Some(gate) = any_of(inc.iter().map(|(_, w)| w.clone()).collect()) {
946 fb = fb.gate(gate);
947 }
948 }
949
950 if let Some(say) = &s.say {
951 fb = fb.posture(say.clone());
952 }
953 if let Some(ground) = &s.ground {
954 fb = fb.ground(ground.clone());
955 }
956
957 let mut allow: Vec<String> = s.allow.clone();
958 if let Some(c) = &s.commit {
959 if !allow.contains(&c.tool) {
960 allow.push(c.tool.clone());
961 }
962 }
963 if !allow.is_empty() {
964 fb = fb.allow(allow);
965 }
966 if let Some(c) = &s.commit {
967 fb = fb.commit(&c.tool, c.when.clone());
968 }
969
970 if s.terminal {
971 fb = fb.terminal();
972 } else {
973 let done = stage_completion(s).ok_or_else(|| {
974 ConversationError::Spec(format!(
975 "non-terminal stage '{}' has no completion (add collect, next, or done)",
976 s.id
977 ))
978 })?;
979 fb = fb.done(done);
980 }
981 }
982
983 if !require.is_empty() {
984 fb = fb.require(require.to_vec());
985 }
986
987 let flow = fb.build().map_err(ConversationError::Flow)?;
988 flow.compile().map_err(ConversationError::Compile)
989}
990
991fn frame_extractors(stages: &[StageSpec]) -> Vec<Extract> {
993 stages
994 .iter()
995 .filter_map(|s| s.frame.as_ref().and_then(FrameSpec::to_extract))
996 .collect()
997}
998
999fn compile_spec(
1000 mut spec: ConversationSpec,
1001 resolvers: Vec<StageResolver>,
1002) -> Result<CompiledConversation, ConversationError> {
1003 for policy in spec.policies.clone() {
1006 if let crate::policy::Policy::SafetyHandoff { intents } = policy {
1007 if let Some(trigger) = any_of(
1008 intents
1009 .iter()
1010 .map(|i| Guard::is_true(format!("intent:{i}")))
1011 .collect(),
1012 ) {
1013 spec.overlays.push(OverlaySpec {
1014 name: "safety".into(),
1015 trigger,
1016 stages: vec![StageSpec {
1017 id: "safety_handoff".into(),
1018 say: Some("Safety concern detected — hand off to a human now.".into()),
1019 terminal: true,
1020 ..Default::default()
1021 }],
1022 require: Vec::new(),
1023 resume: Resume::Terminate,
1024 });
1025 }
1026 }
1027 }
1028
1029 let flow = lower_flow(&spec.stages, &spec.require)?;
1031
1032 let ids: BTreeSet<&str> = spec.stages.iter().map(|s| s.id.as_str()).collect();
1034 for r in &resolvers {
1035 if !ids.contains(r.stage.as_str()) {
1036 return Err(ConversationError::Spec(format!(
1037 "resolver for slot '{}' references unknown stage '{}'",
1038 r.name, r.stage
1039 )));
1040 }
1041 }
1042
1043 let mut extractors = frame_extractors(&spec.stages);
1045 let mut by_stage: BTreeMap<&str, Vec<&StageResolver>> = BTreeMap::new();
1046 for r in &resolvers {
1047 by_stage.entry(r.stage.as_str()).or_default().push(r);
1048 }
1049 for (stage, binds) in by_stage {
1050 let mut builder = Extract::record(format!("{}__{}_resolve", spec.name, stage));
1051 for r in binds {
1052 let fetch = r.fetch.clone();
1053 builder = builder.field_resolve(r.name.clone(), r.args.clone(), r.ttl, move |args| {
1054 let fetch = fetch.clone();
1055 async move { fetch(args).await }
1056 });
1057 }
1058 extractors.push(builder.build());
1059 }
1060
1061 let mut overlays = Vec::with_capacity(spec.overlays.len());
1065 for ov in &spec.overlays {
1066 let require = if ov.require.is_empty() {
1067 ov.stages
1068 .iter()
1069 .filter(|s| s.terminal)
1070 .map(|s| s.id.clone())
1071 .collect()
1072 } else {
1073 ov.require.clone()
1074 };
1075 let ov_flow = lower_flow(&ov.stages, &require)?;
1076 overlays.push(CompiledOverlay {
1077 name: ov.name.clone(),
1078 trigger: ov.trigger.clone(),
1079 flow: ov_flow,
1080 extractors: frame_extractors(&ov.stages),
1081 resume: ov.resume,
1082 });
1083 }
1084
1085 let repair = spec
1087 .stages
1088 .iter()
1089 .filter_map(|s| s.repair.clone().map(|p| (s.id.clone(), p)))
1090 .collect();
1091
1092 let policies = spec.policies.clone();
1093
1094 Ok(CompiledConversation {
1095 flow,
1096 extractors,
1097 overlays,
1098 repair,
1099 policies,
1100 spec,
1101 })
1102}
1103
1104fn stage_completion(s: &StageSpec) -> Option<Guard> {
1109 let base = if let Some(g) = &s.done {
1110 Some(g.clone())
1111 } else if !s.collect.is_empty() {
1112 Some(Guard::captured(s.collect.clone()))
1113 } else {
1114 any_of(s.next.iter().map(|t| t.when.clone()).collect())
1115 };
1116 if s.repair
1117 .as_ref()
1118 .and_then(|r| r.escalate_to.as_ref())
1119 .is_some()
1120 {
1121 let esc = Guard::is_true(escalate_flag(&s.id));
1122 return Some(match base {
1123 Some(b) => Guard::any(vec![b, esc]),
1124 None => esc,
1125 });
1126 }
1127 base
1128}
1129
1130#[cfg(test)]
1131mod tests {
1132 use super::*;
1133 use gemini_adk_rs::flow::Enforcement;
1134 use gemini_adk_rs::state::State;
1135
1136 fn booking() -> CompiledConversation {
1137 Conversation::new("booking")
1138 .stage("collect")
1139 .say("Help the user book a table.")
1140 .collect(["party_size", "slot"])
1141 .next("check", Guard::captured(["party_size", "slot"]))
1142 .stage("check")
1143 .ground("Party of {party_size} at {slot}.")
1144 .next("confirm", Guard::is_true("availability_ok"))
1145 .stage("confirm")
1146 .commit("book", Guard::is_true("user_confirmed"))
1147 .next("done", Guard::called_ok("book"))
1148 .stage("done")
1149 .terminal()
1150 .require(["done"])
1151 .compile()
1152 .expect("booking compiles")
1153 }
1154
1155 #[test]
1156 fn compiles_to_a_governed_flow() {
1157 let convo = booking();
1158 assert!(convo.flow().tool_policy().tools.contains("book"));
1160 assert_eq!(convo.flow().flow().steps.len(), 4);
1161 }
1162
1163 #[test]
1164 fn lowered_flow_enforces_stage_order_and_commit() {
1165 let convo = booking();
1166 let mut mon = convo.monitor(Enforcement::Enforce);
1167 let state = State::new();
1168
1169 let ex = mon.explain(&state);
1171 assert!(ex.active.contains(&"collect".to_string()));
1172 assert!(ex.blocked_tools.contains_key("book"));
1173
1174 let _ = state.set("party_size", 4u8);
1176 let _ = state.set("slot", "tomorrow 7pm");
1177 mon.on_turn(&state);
1178 assert!(mon.explain(&state).active.contains(&"check".to_string()));
1179
1180 let _ = state.set("availability_ok", true);
1182 mon.on_turn(&state);
1183 assert!(mon.admits_tool("book", &state).is_err());
1184
1185 let _ = state.set("user_confirmed", true);
1187 assert!(mon.admits_tool("book", &state).is_ok());
1188 mon.on_tool_ok("book", &state);
1189 mon.on_turn(&state);
1190 assert!(mon.is_complete());
1191 }
1192
1193 #[test]
1194 fn spec_round_trips_through_json() {
1195 let spec = booking().spec().clone();
1196 let json = serde_json::to_string(&spec).expect("serialize spec");
1197 let back: ConversationSpec = serde_json::from_str(&json).expect("deserialize spec");
1198 let recompiled = Conversation::from_spec(back).expect("recompile from spec");
1199 assert_eq!(recompiled.flow().flow().steps.len(), 4);
1200 }
1201
1202 #[test]
1203 fn collect_frame_uses_frame_slot_keys() {
1204 use gemini_adk_rs::frame::{Frame, FrameSpec, SlotSpec};
1205
1206 struct Booking;
1207 impl Frame for Booking {
1208 fn frame() -> FrameSpec {
1209 FrameSpec {
1210 name: "booking".into(),
1211 slots: vec![SlotSpec::new("party_size"), SlotSpec::new("slot")],
1212 }
1213 }
1214 }
1215
1216 let convo = Conversation::new("b")
1217 .stage("collect")
1218 .collect_frame::<Booking>()
1219 .next("done", Guard::captured(["party_size", "slot"]))
1220 .stage("done")
1221 .terminal()
1222 .compile()
1223 .expect("compiles");
1224
1225 let mut mon = convo.monitor(Enforcement::Enforce);
1227 let state = State::new();
1228 assert!(mon.explain(&state).active.contains(&"collect".to_string()));
1229 let _ = state.set("party_size", 2u8);
1230 let _ = state.set("slot", "noon");
1231 mon.on_turn(&state);
1232 assert!(mon.marking().done.contains("collect"));
1234 assert!(mon.marking().done.contains("done"));
1235 }
1236
1237 #[tokio::test]
1238 async fn collect_frame_extractor_fills_and_scores_slots() {
1239 use gemini_adk_rs::frame::{Frame, FrameSpec, SlotRecognizer, SlotSpec};
1240 use gemini_adk_rs::live::TranscriptTurn;
1241
1242 struct Order;
1243 impl Frame for Order {
1244 fn frame() -> FrameSpec {
1245 FrameSpec {
1246 name: "order".into(),
1247 slots: vec![SlotSpec {
1248 recognizer: Some(SlotRecognizer::OneOf(vec![
1249 "pizza".into(),
1250 "salad".into(),
1251 ])),
1252 ..SlotSpec::new("item")
1253 }],
1254 }
1255 }
1256 }
1257
1258 let convo = Conversation::new("o")
1259 .stage("collect")
1260 .collect_frame::<Order>()
1261 .next("done", Guard::captured(["item"]))
1262 .stage("done")
1263 .terminal()
1264 .compile()
1265 .expect("compiles");
1266
1267 assert_eq!(convo.extractors().len(), 1);
1269 let extractor = convo.extractors()[0].clone().into_extractor();
1270
1271 let state = State::new();
1274 let window = vec![TranscriptTurn {
1275 turn_number: 0,
1276 user: "I'd like a large PIZZA".into(),
1277 model: String::new(),
1278 tool_calls: Vec::new(),
1279 timestamp: std::time::Instant::now(),
1280 }];
1281 let out = extractor.extract_with_state(&window, &state).await.unwrap();
1282 assert_eq!(out.get("item").and_then(|v| v.as_str()), Some("pizza"));
1283
1284 let ev = state.evidence("item");
1285 assert_eq!(ev.source.as_deref(), Some("extraction"));
1286 assert!(ev.confidence.unwrap() > 0.0);
1287 }
1288
1289 #[tokio::test]
1290 async fn validate_rejects_out_of_range_recognized_values() {
1291 use gemini_adk_rs::frame::{Frame, FrameSpec, SlotRecognizer, SlotSpec, SlotValidator};
1292 use gemini_adk_rs::live::TranscriptTurn;
1293
1294 struct Party;
1295 impl Frame for Party {
1296 fn frame() -> FrameSpec {
1297 FrameSpec {
1298 name: "party".into(),
1299 slots: vec![SlotSpec {
1300 recognizer: Some(SlotRecognizer::Integer),
1301 validate: Some(SlotValidator::Range {
1302 min: Some(1.0),
1303 max: Some(12.0),
1304 }),
1305 ..SlotSpec::new("party_size")
1306 }],
1307 }
1308 }
1309 }
1310
1311 let convo = Conversation::new("p")
1312 .stage("collect")
1313 .collect_frame::<Party>()
1314 .next("done", Guard::captured(["party_size"]))
1315 .stage("done")
1316 .terminal()
1317 .compile()
1318 .expect("compiles");
1319 let extractor = convo.extractors()[0].clone().into_extractor();
1320
1321 let run = |text: &str| {
1322 let extractor = extractor.clone();
1323 let text = text.to_string();
1324 async move {
1325 let state = State::new();
1326 let window = vec![TranscriptTurn {
1327 turn_number: 0,
1328 user: text,
1329 model: String::new(),
1330 tool_calls: Vec::new(),
1331 timestamp: std::time::Instant::now(),
1332 }];
1333 let out = extractor.extract_with_state(&window, &state).await.unwrap();
1334 out.get("party_size").cloned()
1335 }
1336 };
1337
1338 assert_eq!(run("a table for 4").await, Some(serde_json::json!(4)));
1340 assert_eq!(run("a table for 40").await, None);
1341 }
1342
1343 #[tokio::test]
1344 async fn resolve_slot_fills_from_async_fetch() {
1345 use gemini_adk_rs::live::TranscriptTurn;
1346
1347 let convo = Conversation::new("c")
1348 .stage("check")
1349 .resolve_slot("availability", ["party_size"], None, |args| async move {
1350 let n = args.get("party_size").and_then(|v| v.as_i64()).unwrap_or(0);
1352 Ok(serde_json::json!(n <= 8))
1353 })
1354 .next("done", Guard::is_set("availability"))
1355 .stage("done")
1356 .terminal()
1357 .compile()
1358 .expect("compiles");
1359
1360 assert_eq!(convo.extractors().len(), 1);
1362 let extractor = convo.extractors()[0].clone().into_extractor();
1363
1364 let state = State::new();
1365 let _ = state.set("party_size", 4i64);
1366 let window = vec![TranscriptTurn {
1367 turn_number: 0,
1368 user: "any".into(),
1369 model: String::new(),
1370 tool_calls: Vec::new(),
1371 timestamp: std::time::Instant::now(),
1372 }];
1373 let out = extractor.extract_with_state(&window, &state).await.unwrap();
1374 assert_eq!(out.get("availability"), Some(&serde_json::json!(true)));
1375 }
1376
1377 #[test]
1378 fn converse_registers_flow_and_extractors() {
1379 use gemini_adk_rs::frame::{Frame, FrameSpec, SlotRecognizer, SlotSpec};
1381
1382 struct Order;
1383 impl Frame for Order {
1384 fn frame() -> FrameSpec {
1385 FrameSpec {
1386 name: "order".into(),
1387 slots: vec![SlotSpec {
1388 recognizer: Some(SlotRecognizer::OneOf(vec!["pizza".into()])),
1389 ..SlotSpec::new("item")
1390 }],
1391 }
1392 }
1393 }
1394 let convo = Conversation::new("o")
1395 .stage("collect")
1396 .collect_frame::<Order>()
1397 .next("done", Guard::captured(["item"]))
1398 .stage("done")
1399 .terminal()
1400 .compile()
1401 .expect("compiles");
1402
1403 let _live = crate::live::Live::builder().converse(&convo);
1405 }
1406
/// Drives one overlay ("faq", `Resume::Previous`) through a full cycle on top of a
/// two-stage main flow and checks that the main flow is still on stage "a" afterwards
/// and can then make progress.
#[test]
fn overlay_suspends_main_then_resumes_previous() {
    let conversation = Conversation::new("support")
        .stage("a")
        .next("b", Guard::is_true("a_done"))
        .stage("b")
        .terminal()
        .overlay("faq")
        .trigger(Guard::is_true("intent:faq"))
        .stage("answer")
        .done(Guard::is_true("faq_answered"))
        .next("faq_end", Guard::is_true("faq_answered"))
        .stage("faq_end")
        .terminal()
        .resume(Resume::Previous)
        .done_overlay()
        .compile()
        .expect("compiles");
    assert_eq!(conversation.overlays().len(), 1);

    let mut flows = conversation.stack(Enforcement::Enforce);
    let state = State::new();
    // Owned key reused for every membership check against `active`.
    let stage_a = String::from("a");

    // Phase 1 - idle: the main flow reports "a" as active and no overlay is running.
    let before = flows.explain(&state);
    assert!(before.active.contains(&stage_a));
    assert!(flows.active_overlay().is_none());

    // Phase 2 - trigger: once the trigger key is true, one turn activates the overlay.
    let _ = state.set("intent:faq", true);
    flows.on_turn(&state);
    assert_eq!(flows.active_overlay(), Some("faq"));

    // Phase 3 - finish: the overlay's guards are satisfied and the trigger is cleared,
    // so after the next turn no overlay is active any more.
    let _ = state.set("faq_answered", true);
    let _ = state.set("intent:faq", false);
    flows.on_turn(&state);
    assert!(flows.active_overlay().is_none());

    // Phase 4 - resume: the main flow is still on "a".
    let after = flows.explain(&state);
    assert!(after.active.contains(&stage_a));

    // Phase 5 - the main flow advances: "a" ends up in the current marking's done set.
    let _ = state.set("a_done", true);
    flows.on_turn(&state);
    assert!(flows.current().marking().done.contains("a"));
}
1456
/// A spec containing one overlay survives `serde_json` serialisation and
/// deserialisation, and the decoded spec is still accepted by `Conversation::from_spec`.
#[test]
fn overlay_spec_round_trips_through_json() {
    let original = Conversation::new("s")
        .stage("main")
        .terminal()
        .overlay("cancel")
        .trigger(Guard::is_true("intent:cancel"))
        .stage("confirm")
        .terminal()
        .resume(Resume::Terminate)
        .done_overlay()
        .into_spec();

    // Encode, then decode into a fresh spec value.
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: ConversationSpec = serde_json::from_str(&encoded).unwrap();

    // The single overlay and its resume mode must come back unchanged.
    let overlays = &decoded.overlays;
    assert_eq!(overlays.len(), 1);
    assert_eq!(overlays[0].resume, Resume::Terminate);

    // The decoded spec must still build.
    assert!(Conversation::from_spec(decoded).is_ok());
}
1476
1477 #[tokio::test]
1478 async fn safety_policy_terminates_on_intent() {
1479 use crate::policy::Policy;
1480 use crate::simulation::Sim;
1481
1482 let convo = Conversation::new("support")
1483 .policy(Policy::safety_handoff(["self_harm", "abuse"]))
1484 .policy(Policy::redact(["card_number"]))
1485 .stage("triage")
1486 .next("resolve", Guard::is_true("triaged"))
1487 .stage("resolve")
1488 .terminal()
1489 .require(["resolve"])
1490 .compile()
1491 .expect("compiles");
1492
1493 assert!(convo.redacted_fields().contains("card_number"));
1495 assert!(convo.overlays().iter().any(|o| o.name == "safety"));
1497
1498 let mut sim = Sim::new(&convo, Enforcement::Enforce);
1499 assert!(sim.active().contains(&"triage".to_string()));
1500 assert!(!sim.is_complete());
1501
1502 sim.set("intent:abuse", true);
1504 sim.turn();
1505 assert!(sim.is_complete());
1506 }
1507
1508 #[tokio::test]
1509 async fn repair_reprompts_then_escalates_to_handoff() {
1510 use crate::simulation::Sim;
1511
1512 let convo = Conversation::new("support")
1515 .stage("collect")
1516 .done(Guard::is_true("info"))
1517 .next("done", Guard::is_true("info"))
1518 .repair(RepairPolicy::new(2, 3).escalate_to("handoff"))
1519 .stage("done")
1520 .terminal()
1521 .stage("handoff")
1523 .done(Guard::is_true("handoff_complete"))
1524 .compile()
1525 .expect("compiles");
1526
1527 let mut sim = Sim::new(&convo, Enforcement::Enforce);
1528 assert!(sim.active().contains(&"collect".to_string()));
1529
1530 sim.turn();
1532 assert_eq!(sim.slot::<bool>("repair:collect:reprompt"), None);
1533 sim.turn();
1535 assert_eq!(sim.slot::<bool>("repair:collect:reprompt"), Some(true));
1536 assert!(sim.active().contains(&"collect".to_string()));
1537 sim.turn();
1539 assert_eq!(sim.slot::<bool>("repair:collect:escalate"), Some(true));
1540 assert!(sim.active().contains(&"handoff".to_string()));
1541 assert!(!sim.active().contains(&"collect".to_string()));
1542 }
1543
1544 #[tokio::test]
1545 async fn repair_signal_clears_when_stage_satisfied() {
1546 use crate::simulation::Sim;
1547
1548 let convo = Conversation::new("s")
1549 .stage("collect")
1550 .done(Guard::is_true("info"))
1551 .next("done", Guard::is_true("info"))
1552 .repair(RepairPolicy::new(1, 9))
1553 .stage("done")
1554 .terminal()
1555 .require(["done"])
1556 .compile()
1557 .expect("compiles");
1558
1559 let mut sim = Sim::new(&convo, Enforcement::Enforce);
1560 sim.turn(); assert_eq!(sim.slot::<bool>("repair:collect:reprompt"), Some(true));
1562
1563 sim.set("info", true);
1566 sim.turn();
1567 sim.turn();
1568 assert_eq!(sim.slot::<bool>("repair:collect:reprompt"), Some(false));
1569 assert!(sim.is_complete());
1570 }
1571
/// A transition that targets a stage name never declared ("ghost") must make
/// `compile()` fail with `ConversationError::Spec`.
#[test]
fn rejects_transition_to_unknown_stage() {
    let outcome = Conversation::new("x")
        .stage("a")
        .next("ghost", Guard::always())
        .stage("b")
        .terminal()
        .compile();

    let err = outcome.expect_err("unknown target must fail");
    let is_spec_error = matches!(err, ConversationError::Spec(_));
    assert!(is_spec_error);
}
1583
/// A commit declared with `Guard::always()` as its guard must make `compile()` fail
/// with `ConversationError::Compile`.
#[test]
fn rejects_unguarded_commit_via_flow_compile() {
    let outcome = Conversation::new("x")
        .stage("s")
        .commit("pay", Guard::always())
        .done(Guard::called_ok("pay"))
        .next("done", Guard::called_ok("pay"))
        .stage("done")
        .terminal()
        .compile();

    let err = outcome.expect_err("unguarded commit must fail");
    let is_compile_error = matches!(err, ConversationError::Compile(_));
    assert!(is_compile_error);
}
1598}