gemini_adk_rs/live/
needs.rs

//! Conversation repair protocol — tracks need fulfillment and nudges.
//!
//! When a phase declares `needs` (state keys that must be gathered), this
//! module tracks whether the conversation is making progress. After
//! `nudge_after` turns without progress it recommends a nudge (intended to be
//! delivered to the model via context injection); after `escalate_after` turns
//! it recommends escalation, so the caller can set a flag for phase guards to
//! pick up.

8use std::collections::HashMap;
9
10use crate::state::State;
11
/// Default number of consecutive stalled turns before the first nudge.
pub const DEFAULT_NUDGE_AFTER: u32 = 3;
/// Default number of consecutive stalled turns before escalation.
pub const DEFAULT_ESCALATE_AFTER: u32 = 6;

/// Configuration for the conversation repair system.
///
/// Both thresholds count consecutive turns in which at least one declared need
/// is still missing. Build one with [`RepairConfig::new`] (or `Default`) and the
/// chainable setters:
///
/// ```ignore
/// let cfg = RepairConfig::new().nudge_after(2).escalate_after(5);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RepairConfig {
    /// Turns without progress before the first nudge.
    pub nudge_after: u32,
    /// Turns without progress before escalation is recommended.
    pub escalate_after: u32,
}

impl Default for RepairConfig {
    /// Thresholds from [`DEFAULT_NUDGE_AFTER`] and [`DEFAULT_ESCALATE_AFTER`].
    fn default() -> Self {
        Self {
            nudge_after: DEFAULT_NUDGE_AFTER,
            escalate_after: DEFAULT_ESCALATE_AFTER,
        }
    }
}

impl RepairConfig {
    /// Create a config with the default thresholds
    /// ([`DEFAULT_NUDGE_AFTER`], [`DEFAULT_ESCALATE_AFTER`]).
    ///
    /// Customise them with the chainable `.nudge_after(n)` and
    /// `.escalate_after(n)` setters.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the nudge threshold (turns without progress before the first nudge).
    ///
    /// Consumes and returns `self`; the result must be used.
    #[must_use]
    pub fn nudge_after(mut self, n: u32) -> Self {
        self.nudge_after = n;
        self
    }

    /// Set the escalation threshold (turns without progress before escalation).
    ///
    /// Consumes and returns `self`; the result must be used.
    #[must_use]
    pub fn escalate_after(mut self, n: u32) -> Self {
        self.escalate_after = n;
        self
    }
}
53
/// The repair step recommended by [`NeedsFulfillment::evaluate`] for one turn.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RepairAction {
    /// Nothing to do this turn.
    None,
    /// Prompt the model to gather the information that is still missing.
    Nudge {
        /// Need keys that have no value in state yet.
        unfulfilled: Vec<String>,
        /// 1-based index of this nudge within the current stall.
        attempt: u32,
    },
    /// The stall has reached the escalation threshold; the caller is expected
    /// to set a state flag that phase guards can act on.
    Escalate {
        /// Need keys that have no value in state yet.
        unfulfilled: Vec<String>,
    },
}
72
73/// Tracks need fulfillment per phase and recommends repair actions.
74pub struct NeedsFulfillment {
75    /// Phase name → consecutive turns without progress.
76    stall_count: HashMap<String, u32>,
77    /// Configuration thresholds.
78    config: RepairConfig,
79}
80
81impl NeedsFulfillment {
82    /// Create with the given configuration.
83    pub fn new(config: RepairConfig) -> Self {
84        Self {
85            stall_count: HashMap::new(),
86            config,
87        }
88    }
89
90    /// Evaluate whether repair action is needed for the current phase.
91    ///
92    /// Call after extractors run. Returns the recommended action.
93    pub fn evaluate(&mut self, phase: &str, needs: &[String], state: &State) -> RepairAction {
94        let unfulfilled: Vec<String> = needs
95            .iter()
96            .filter(|key| state.get_raw(key).is_none())
97            .cloned()
98            .collect();
99
100        if unfulfilled.is_empty() {
101            self.stall_count.remove(phase);
102            return RepairAction::None;
103        }
104
105        let count = self.stall_count.entry(phase.to_string()).or_insert(0);
106        *count += 1;
107
108        if *count >= self.config.escalate_after {
109            RepairAction::Escalate { unfulfilled }
110        } else if *count >= self.config.nudge_after {
111            RepairAction::Nudge {
112                unfulfilled,
113                attempt: *count - self.config.nudge_after + 1,
114            }
115        } else {
116            RepairAction::None
117        }
118    }
119
120    /// Reset tracking for a phase (call on phase transition).
121    pub fn reset(&mut self, phase: &str) {
122        self.stall_count.remove(phase);
123    }
124
125    /// Reset all tracking.
126    pub fn reset_all(&mut self) {
127        self.stall_count.clear();
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_action_when_needs_fulfilled() {
        let state = State::new();
        state.set("customer_id", "C123");
        state.set("account_number", "A456");

        let mut nf = NeedsFulfillment::new(RepairConfig::default());
        let action = nf.evaluate(
            "gather",
            &["customer_id".into(), "account_number".into()],
            &state,
        );
        assert_eq!(action, RepairAction::None);
    }

    #[test]
    fn no_action_before_threshold() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        // First 2 turns: no action (threshold is 3)
        for _ in 0..2 {
            let action = nf.evaluate("gather", &["customer_id".into()], &state);
            assert_eq!(action, RepairAction::None);
        }
    }

    #[test]
    fn nudge_at_threshold() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        // Turns 1-2: no action
        for _ in 0..2 {
            nf.evaluate("gather", &["customer_id".into()], &state);
        }

        // Turn 3: nudge
        let action = nf.evaluate("gather", &["customer_id".into()], &state);
        assert!(matches!(action, RepairAction::Nudge { attempt: 1, .. }));
    }

    #[test]
    fn nudge_attempt_increments_until_escalation() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());
        let needs = vec!["customer_id".to_string()];

        // Turns 1-2: below the nudge threshold.
        for _ in 0..2 {
            assert_eq!(nf.evaluate("gather", &needs, &state), RepairAction::None);
        }
        // Turns 3-5: nudges numbered 1, 2, 3.
        for expected in 1..=3 {
            assert_eq!(
                nf.evaluate("gather", &needs, &state),
                RepairAction::Nudge {
                    unfulfilled: needs.clone(),
                    attempt: expected,
                }
            );
        }
        // Turn 6 onward: escalation persists while the stall continues.
        for _ in 0..2 {
            assert_eq!(
                nf.evaluate("gather", &needs, &state),
                RepairAction::Escalate {
                    unfulfilled: needs.clone(),
                }
            );
        }
    }

    #[test]
    fn escalation_at_threshold() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        for _ in 0..5 {
            nf.evaluate("gather", &["customer_id".into()], &state);
        }

        // Turn 6: escalate
        let action = nf.evaluate("gather", &["customer_id".into()], &state);
        assert!(matches!(action, RepairAction::Escalate { .. }));
    }

    #[test]
    fn unfulfilled_lists_only_missing_keys() {
        let state = State::new();
        state.set("customer_id", "C123");

        let mut nf = NeedsFulfillment::new(RepairConfig::new().nudge_after(1));
        let action = nf.evaluate(
            "gather",
            &["customer_id".into(), "account_number".into()],
            &state,
        );
        assert_eq!(
            action,
            RepairAction::Nudge {
                unfulfilled: vec!["account_number".to_string()],
                attempt: 1,
            }
        );
    }

    #[test]
    fn fulfilling_need_resets_counter() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        // Stall for 2 turns
        for _ in 0..2 {
            nf.evaluate("gather", &["customer_id".into()], &state);
        }

        // Fulfill the need
        state.set("customer_id", "C123");
        let action = nf.evaluate("gather", &["customer_id".into()], &state);
        assert_eq!(action, RepairAction::None);

        // Counter should be reset — unfulfill again
        state.remove("customer_id");
        let action = nf.evaluate("gather", &["customer_id".into()], &state);
        assert_eq!(action, RepairAction::None); // Turn 1 of new stall
    }

    #[test]
    fn phases_are_tracked_independently() {
        let state = State::new();
        let mut nf =
            NeedsFulfillment::new(RepairConfig::new().nudge_after(2).escalate_after(10));
        let needs = vec!["x".to_string()];

        assert_eq!(nf.evaluate("a", &needs, &state), RepairAction::None);
        // Phase "b" keeps its own counter, starting from turn 1.
        assert_eq!(nf.evaluate("b", &needs, &state), RepairAction::None);
        assert!(matches!(
            nf.evaluate("a", &needs, &state),
            RepairAction::Nudge { attempt: 1, .. }
        ));
    }

    #[test]
    fn reset_clears_only_the_given_phase() {
        let state = State::new();
        let mut nf =
            NeedsFulfillment::new(RepairConfig::new().nudge_after(2).escalate_after(10));
        let needs = vec!["x".to_string()];

        nf.evaluate("a", &needs, &state);
        nf.evaluate("b", &needs, &state);
        nf.reset("a");

        // "a" starts over at turn 1; "b" continues to turn 2.
        assert_eq!(nf.evaluate("a", &needs, &state), RepairAction::None);
        assert!(matches!(
            nf.evaluate("b", &needs, &state),
            RepairAction::Nudge { attempt: 1, .. }
        ));
    }

    #[test]
    fn reset_all_clears_every_phase() {
        let state = State::new();
        let mut nf =
            NeedsFulfillment::new(RepairConfig::new().nudge_after(2).escalate_after(10));
        let needs = vec!["x".to_string()];

        nf.evaluate("a", &needs, &state);
        nf.evaluate("b", &needs, &state);
        nf.reset_all();

        assert_eq!(nf.evaluate("a", &needs, &state), RepairAction::None);
        assert_eq!(nf.evaluate("b", &needs, &state), RepairAction::None);
    }

    #[test]
    fn escalation_preempts_nudge_when_thresholds_inverted() {
        let state = State::new();
        let mut nf =
            NeedsFulfillment::new(RepairConfig::new().nudge_after(3).escalate_after(2));
        let needs = vec!["x".to_string()];

        assert_eq!(nf.evaluate("gather", &needs, &state), RepairAction::None);
        // Escalation is checked first, so no nudge is ever produced.
        for _ in 0..2 {
            assert!(matches!(
                nf.evaluate("gather", &needs, &state),
                RepairAction::Escalate { .. }
            ));
        }
    }

    #[test]
    fn custom_thresholds() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::new().nudge_after(1).escalate_after(2));

        // Turn 1: nudge immediately
        let action = nf.evaluate("gather", &["x".into()], &state);
        assert!(matches!(action, RepairAction::Nudge { attempt: 1, .. }));

        // Turn 2: escalate
        let action = nf.evaluate("gather", &["x".into()], &state);
        assert!(matches!(action, RepairAction::Escalate { .. }));
    }
}