gemini_adk_rs/live/
needs.rs1use std::collections::HashMap;
9
10use crate::state::State;
11
/// Default stall count at which [`NeedsFulfillment::evaluate`] starts returning
/// [`RepairAction::Nudge`] for a phase whose needs are still unmet.
pub const DEFAULT_NUDGE_AFTER: u32 = 3;
/// Default stall count at which [`NeedsFulfillment::evaluate`] returns
/// [`RepairAction::Escalate`] instead of nudging.
pub const DEFAULT_ESCALATE_AFTER: u32 = 6;

// Compile-time guard: escalation is checked before nudging, so if the defaults
// ever crossed, the default configuration would never produce a nudge.
const _: () = assert!(
    DEFAULT_NUDGE_AFTER < DEFAULT_ESCALATE_AFTER,
    "DEFAULT_NUDGE_AFTER must be smaller than DEFAULT_ESCALATE_AFTER"
);
16
/// Thresholds that decide when [`NeedsFulfillment`] intervenes in a stalled phase.
///
/// Both values are compared against a per-phase stall count. That count is
/// incremented on every evaluation that finds at least one need unfulfilled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RepairConfig {
    /// Stall count at which nudging begins.
    pub nudge_after: u32,
    /// Stall count at which escalation replaces nudging. It is checked first,
    /// so it wins whenever both thresholds are reached.
    pub escalate_after: u32,
}
25
26impl Default for RepairConfig {
27 fn default() -> Self {
28 Self {
29 nudge_after: DEFAULT_NUDGE_AFTER,
30 escalate_after: DEFAULT_ESCALATE_AFTER,
31 }
32 }
33}
34
35impl RepairConfig {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn nudge_after(mut self, n: u32) -> Self {
43 self.nudge_after = n;
44 self
45 }
46
47 pub fn escalate_after(mut self, n: u32) -> Self {
49 self.escalate_after = n;
50 self
51 }
52}
53
/// Result of a single [`NeedsFulfillment::evaluate`] call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RepairAction {
    /// No intervention: either every need is met, or the stall count is still
    /// below the nudge threshold.
    None,
    /// The stall count has reached the nudge threshold but not the escalation threshold.
    Nudge {
        /// Needs still missing from state, in the order they were requested.
        unfulfilled: Vec<String>,
        /// 1-based nudge number: 1 on the first nudge of the current stall.
        attempt: u32,
    },
    /// The stall count has reached the escalation threshold.
    Escalate {
        /// Needs still missing from state, in the order they were requested.
        unfulfilled: Vec<String>,
    },
}
72
73pub struct NeedsFulfillment {
75 stall_count: HashMap<String, u32>,
77 config: RepairConfig,
79}
80
81impl NeedsFulfillment {
82 pub fn new(config: RepairConfig) -> Self {
84 Self {
85 stall_count: HashMap::new(),
86 config,
87 }
88 }
89
90 pub fn evaluate(&mut self, phase: &str, needs: &[String], state: &State) -> RepairAction {
94 let unfulfilled: Vec<String> = needs
95 .iter()
96 .filter(|key| state.get_raw(key).is_none())
97 .cloned()
98 .collect();
99
100 if unfulfilled.is_empty() {
101 self.stall_count.remove(phase);
102 return RepairAction::None;
103 }
104
105 let count = self.stall_count.entry(phase.to_string()).or_insert(0);
106 *count += 1;
107
108 if *count >= self.config.escalate_after {
109 RepairAction::Escalate { unfulfilled }
110 } else if *count >= self.config.nudge_after {
111 RepairAction::Nudge {
112 unfulfilled,
113 attempt: *count - self.config.nudge_after + 1,
114 }
115 } else {
116 RepairAction::None
117 }
118 }
119
120 pub fn reset(&mut self, phase: &str) {
122 self.stall_count.remove(phase);
123 }
124
125 pub fn reset_all(&mut self) {
127 self.stall_count.clear();
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the owned key list that `evaluate` takes from string literals.
    fn keys(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    #[test]
    fn no_action_when_needs_fulfilled() {
        let state = State::new();
        state.set("customer_id", "C123");
        state.set("account_number", "A456");

        let mut nf = NeedsFulfillment::new(RepairConfig::default());
        let action = nf.evaluate("gather", &keys(&["customer_id", "account_number"]), &state);
        assert_eq!(action, RepairAction::None);
    }

    #[test]
    fn no_action_before_threshold() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        for _ in 0..2 {
            let action = nf.evaluate("gather", &keys(&["customer_id"]), &state);
            assert_eq!(action, RepairAction::None);
        }
    }

    #[test]
    fn nudge_at_threshold() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        for _ in 0..2 {
            nf.evaluate("gather", &keys(&["customer_id"]), &state);
        }

        let action = nf.evaluate("gather", &keys(&["customer_id"]), &state);
        assert!(matches!(action, RepairAction::Nudge { attempt: 1, .. }));
    }

    #[test]
    fn escalation_at_threshold() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        for _ in 0..5 {
            nf.evaluate("gather", &keys(&["customer_id"]), &state);
        }

        let action = nf.evaluate("gather", &keys(&["customer_id"]), &state);
        assert!(matches!(action, RepairAction::Escalate { .. }));
    }

    #[test]
    fn nudge_attempts_increment_until_escalation() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());
        let needs = keys(&["customer_id"]);

        for _ in 0..2 {
            assert_eq!(nf.evaluate("gather", &needs, &state), RepairAction::None);
        }
        for attempt in 1..=3 {
            assert_eq!(
                nf.evaluate("gather", &needs, &state),
                RepairAction::Nudge {
                    unfulfilled: needs.clone(),
                    attempt,
                }
            );
        }
        assert_eq!(
            nf.evaluate("gather", &needs, &state),
            RepairAction::Escalate {
                unfulfilled: needs.clone(),
            }
        );
    }

    #[test]
    fn only_missing_needs_are_reported() {
        let state = State::new();
        state.set("customer_id", "C123");
        let mut nf = NeedsFulfillment::new(RepairConfig::new().nudge_after(1));

        let action = nf.evaluate("gather", &keys(&["customer_id", "account_number"]), &state);
        assert_eq!(
            action,
            RepairAction::Nudge {
                unfulfilled: keys(&["account_number"]),
                attempt: 1,
            }
        );
    }

    #[test]
    fn fulfilling_need_resets_counter() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());

        for _ in 0..2 {
            nf.evaluate("gather", &keys(&["customer_id"]), &state);
        }

        state.set("customer_id", "C123");
        let action = nf.evaluate("gather", &keys(&["customer_id"]), &state);
        assert_eq!(action, RepairAction::None);

        // Without the reset this would be the third stall and trigger a nudge.
        state.remove("customer_id");
        let action = nf.evaluate("gather", &keys(&["customer_id"]), &state);
        assert_eq!(action, RepairAction::None);
    }

    #[test]
    fn phases_are_tracked_independently() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());
        let needs = keys(&["customer_id"]);

        for _ in 0..2 {
            nf.evaluate("gather", &needs, &state);
        }

        assert_eq!(nf.evaluate("verify", &needs, &state), RepairAction::None);
        assert!(matches!(
            nf.evaluate("gather", &needs, &state),
            RepairAction::Nudge { attempt: 1, .. }
        ));
    }

    #[test]
    fn reset_clears_a_single_phase() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());
        let needs = keys(&["customer_id"]);

        for _ in 0..2 {
            nf.evaluate("gather", &needs, &state);
            nf.evaluate("verify", &needs, &state);
        }

        nf.reset("gather");
        // "gather" starts over, while "verify" keeps its two stalls.
        assert_eq!(nf.evaluate("gather", &needs, &state), RepairAction::None);
        assert!(matches!(
            nf.evaluate("verify", &needs, &state),
            RepairAction::Nudge { attempt: 1, .. }
        ));
    }

    #[test]
    fn reset_all_clears_every_phase() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::default());
        let needs = keys(&["customer_id"]);

        for _ in 0..2 {
            nf.evaluate("gather", &needs, &state);
            nf.evaluate("verify", &needs, &state);
        }

        nf.reset_all();
        assert_eq!(nf.evaluate("gather", &needs, &state), RepairAction::None);
        assert_eq!(nf.evaluate("verify", &needs, &state), RepairAction::None);
    }

    #[test]
    fn custom_thresholds() {
        let state = State::new();
        let mut nf = NeedsFulfillment::new(RepairConfig::new().nudge_after(1).escalate_after(2));

        let action = nf.evaluate("gather", &keys(&["x"]), &state);
        assert!(matches!(action, RepairAction::Nudge { attempt: 1, .. }));

        let action = nf.evaluate("gather", &keys(&["x"]), &state);
        assert!(matches!(action, RepairAction::Escalate { .. }));
    }
}