1use std::collections::HashMap;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use crate::state::State;
13
14use super::contract::ComputedContract;
15
16pub type ComputeFn = Arc<dyn Fn(&State) -> Option<Value> + Send + Sync>;
18
19pub struct ComputedVar {
26 pub key: String,
28 pub dependencies: Vec<String>,
30 pub compute: ComputeFn,
32}
33
34pub struct ComputedRegistry {
40 vars: Vec<ComputedVar>,
42 dep_index: HashMap<String, Vec<usize>>,
45}
46
47impl Default for ComputedRegistry {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl ComputedRegistry {
54 pub fn new() -> Self {
56 Self {
57 vars: Vec::new(),
58 dep_index: HashMap::new(),
59 }
60 }
61
62 pub fn register(&mut self, var: ComputedVar) {
65 if let Some(pos) = self.vars.iter().position(|v| v.key == var.key) {
66 self.vars[pos] = var; } else {
68 self.vars.push(var);
69 }
70 self.topo_sort_or_panic();
71 self.rebuild_dep_index();
72 }
73
74 pub fn recompute(&self, state: &State) -> Vec<String> {
77 let mut changed = Vec::new();
78 for var in &self.vars {
79 if let Some(new_val) = (var.compute)(state) {
80 let derived_key = format!("derived:{}", var.key);
81 let old_val = state.get_raw(&derived_key);
82 let did_change = old_val.as_ref() != Some(&new_val);
83 let _ = state.set(&derived_key, new_val);
84 if did_change {
85 changed.push(var.key.clone());
86 }
87 }
88 }
89 changed
90 }
91
92 pub fn recompute_affected(&self, state: &State, changed_keys: &[String]) -> Vec<String> {
98 let mut visited = vec![false; self.vars.len()];
100 let mut affected_set = Vec::new();
101
102 let mut work_keys: Vec<String> = changed_keys.to_vec();
104
105 while let Some(key) = work_keys.pop() {
106 if let Some(indices) = self.dep_index.get(&key) {
108 for &idx in indices {
109 if !visited[idx] {
110 visited[idx] = true;
111 affected_set.push(idx);
112 work_keys.push(self.vars[idx].key.clone());
115 }
116 }
117 }
118 }
119
120 affected_set.sort_unstable();
122
123 let mut changed = Vec::new();
124 for idx in affected_set {
125 let var = &self.vars[idx];
126 if let Some(new_val) = (var.compute)(state) {
127 let derived_key = format!("derived:{}", var.key);
128 let old_val = state.get_raw(&derived_key);
129 let did_change = old_val.as_ref() != Some(&new_val);
130 let _ = state.set(&derived_key, new_val);
131 if did_change {
132 changed.push(var.key.clone());
133 }
134 }
135 }
136 changed
137 }
138
139 pub fn validate(&self) -> Result<(), String> {
142 let n = self.vars.len();
144 if n == 0 {
145 return Ok(());
146 }
147
148 let key_to_idx: HashMap<&str, usize> = self
149 .vars
150 .iter()
151 .enumerate()
152 .map(|(i, v)| (v.key.as_str(), i))
153 .collect();
154
155 let mut in_degree = vec![0usize; n];
156 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
157
158 for (i, var) in self.vars.iter().enumerate() {
159 for dep in &var.dependencies {
160 if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
161 adj[dep_idx].push(i);
162 in_degree[i] += 1;
163 }
164 }
166 }
167
168 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
169 let mut visited = 0usize;
170
171 while let Some(node) = queue.pop() {
172 visited += 1;
173 for &neighbor in &adj[node] {
174 in_degree[neighbor] -= 1;
175 if in_degree[neighbor] == 0 {
176 queue.push(neighbor);
177 }
178 }
179 }
180
181 if visited == n {
182 Ok(())
183 } else {
184 let cycle_vars: Vec<&str> = (0..n)
186 .filter(|&i| in_degree[i] > 0)
187 .map(|i| self.vars[i].key.as_str())
188 .collect();
189 Err(format!(
190 "Cycle detected among computed variables: {:?}",
191 cycle_vars
192 ))
193 }
194 }
195
196 pub fn len(&self) -> usize {
198 self.vars.len()
199 }
200
201 pub fn is_empty(&self) -> bool {
203 self.vars.is_empty()
204 }
205
206 pub fn describe(&self) -> Vec<ComputedContract> {
208 self.vars
209 .iter()
210 .map(|var| ComputedContract {
211 key: var.key.clone(),
212 dependencies: var.dependencies.clone(),
213 })
214 .collect()
215 }
216
217 fn topo_sort_or_panic(&mut self) {
222 let n = self.vars.len();
223
224 for var in &self.vars {
226 if var.dependencies.contains(&var.key) {
227 panic!(
228 "Cycle detected among computed variables: {:?}",
229 vec![var.key.as_str()]
230 );
231 }
232 }
233
234 if n <= 1 {
235 return;
236 }
237
238 let key_to_idx: HashMap<&str, usize> = self
240 .vars
241 .iter()
242 .enumerate()
243 .map(|(i, v)| (v.key.as_str(), i))
244 .collect();
245
246 let mut in_degree = vec![0usize; n];
249 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
250
251 for (i, var) in self.vars.iter().enumerate() {
252 for dep in &var.dependencies {
253 if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
254 adj[dep_idx].push(i);
255 in_degree[i] += 1;
256 }
257 }
258 }
259
260 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
262 let mut order: Vec<usize> = Vec::with_capacity(n);
263
264 while let Some(node) = queue.pop() {
265 order.push(node);
266 for &neighbor in &adj[node] {
267 in_degree[neighbor] -= 1;
268 if in_degree[neighbor] == 0 {
269 queue.push(neighbor);
270 }
271 }
272 }
273
274 if order.len() != n {
275 let cycle_vars: Vec<&str> = (0..n)
276 .filter(|&i| in_degree[i] > 0)
277 .map(|i| self.vars[i].key.as_str())
278 .collect();
279 panic!("Cycle detected among computed variables: {:?}", cycle_vars);
280 }
281
282 let mut slots: Vec<Option<ComputedVar>> = self.vars.drain(..).map(Some).collect();
285 for &idx in &order {
286 if let Some(var) = slots[idx].take() {
287 self.vars.push(var);
288 }
289 }
290 }
291
292 fn rebuild_dep_index(&mut self) {
294 self.dep_index.clear();
295 for (i, var) in self.vars.iter().enumerate() {
296 for dep in &var.dependencies {
297 self.dep_index.entry(dep.clone()).or_default().push(i);
298 }
299 }
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a diamond graph:
    /// `d <- app:root`, `a <- d`, `b <- d`, `c <- a, b`.
    /// With `app:root = r`, expect d = r, a = r + 10, b = r + 20, c = 2r + 30.
    fn diamond_registry() -> ComputedRegistry {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "d".into(),
            dependencies: vec!["app:root".into()],
            compute: Arc::new(|state| {
                let root: i64 = state.get("app:root")?;
                Some(json!(root))
            }),
        });

        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["d".into()],
            compute: Arc::new(|state| {
                let d: i64 = state.get("derived:d")?;
                Some(json!(d + 10))
            }),
        });

        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["d".into()],
            compute: Arc::new(|state| {
                let d: i64 = state.get("derived:d")?;
                Some(json!(d + 20))
            }),
        });

        registry.register(ComputedVar {
            key: "c".into(),
            dependencies: vec!["a".into(), "b".into()],
            compute: Arc::new(|state| {
                let a: i64 = state.get("derived:a")?;
                let b: i64 = state.get("derived:b")?;
                Some(json!(a + b))
            }),
        });

        registry
    }

    /// A single variable is computed and stored under `derived:<key>`.
    #[test]
    fn single_var_register_and_recompute() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "doubled".into(),
            dependencies: vec!["app:count".into()],
            compute: Arc::new(|state| {
                let count: i64 = state.get("app:count")?;
                Some(json!(count * 2))
            }),
        });

        let state = State::new();
        let _ = state.set("app:count", 5);

        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["doubled"]);
        assert_eq!(state.get::<i64>("derived:doubled"), Some(10));
    }

    /// Registering a dependent before its dependency still evaluates the
    /// dependency first.
    #[test]
    fn dependency_ordering() {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "derived_from_base".into(),
            dependencies: vec!["base".into()],
            compute: Arc::new(|state| {
                let base: i64 = state.get("derived:base")?;
                Some(json!(base + 100))
            }),
        });

        registry.register(ComputedVar {
            key: "base".into(),
            dependencies: vec!["app:input".into()],
            compute: Arc::new(|state| {
                let input: i64 = state.get("app:input")?;
                Some(json!(input * 2))
            }),
        });

        let state = State::new();
        let _ = state.set("app:input", 3);

        let changed = registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:base"), Some(6));
        assert_eq!(state.get::<i64>("derived:derived_from_base"), Some(106));
        assert!(changed.contains(&"base".to_string()));
        assert!(changed.contains(&"derived_from_base".to_string()));
    }

    /// A two-variable cycle is rejected at registration time.
    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn cycle_detection_panics() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["b".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["a".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });
    }

    /// Registering an existing key replaces the old variable instead of
    /// adding a second one.
    #[test]
    fn register_replaces_existing_key() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "v".into(),
            dependencies: vec!["app:n".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "v".into(),
            dependencies: vec!["app:n".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });
        assert_eq!(registry.len(), 1);

        let state = State::new();
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["v"]);
        assert_eq!(state.get::<i64>("derived:v"), Some(2));
    }

    /// Only keys whose stored value actually changes are reported.
    #[test]
    fn recompute_returns_only_changed_keys() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "level".into(),
            dependencies: vec!["app:score".into()],
            compute: Arc::new(|state| {
                let score: f64 = state.get("app:score")?;
                if score > 0.5 {
                    Some(json!("high"))
                } else {
                    Some(json!("low"))
                }
            }),
        });

        let state = State::new();
        let _ = state.set("app:score", 0.8);

        // First evaluation: no previous value, so it counts as a change.
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["level"]);

        // Same inputs: same value, nothing reported.
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());

        // Input crosses the threshold: value flips and is reported.
        let _ = state.set("app:score", 0.2);
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["level"]);
        assert_eq!(
            state.get::<String>("derived:level"),
            Some("low".to_string())
        );
    }

    /// Unrelated variables are not evaluated by `recompute_affected`.
    #[test]
    fn recompute_affected_only_recomputes_affected() {
        let call_count_a = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let call_count_b = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let cc_a = call_count_a.clone();
        let cc_b = call_count_b.clone();

        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "from_x".into(),
            dependencies: vec!["app:x".into()],
            compute: Arc::new(move |state| {
                cc_a.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let x: i64 = state.get("app:x")?;
                Some(json!(x + 1))
            }),
        });
        registry.register(ComputedVar {
            key: "from_y".into(),
            dependencies: vec!["app:y".into()],
            compute: Arc::new(move |state| {
                cc_b.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let y: i64 = state.get("app:y")?;
                Some(json!(y + 1))
            }),
        });

        let state = State::new();
        let _ = state.set("app:x", 10);
        let _ = state.set("app:y", 20);

        let changed = registry.recompute_affected(&state, &["app:x".into()]);
        assert_eq!(changed, vec!["from_x"]);
        assert_eq!(call_count_a.load(std::sync::atomic::Ordering::SeqCst), 1);
        assert_eq!(call_count_b.load(std::sync::atomic::Ordering::SeqCst), 0);

        assert_eq!(state.get::<i64>("derived:from_x"), Some(11));
        assert_eq!(state.get_raw("derived:from_y"), None);
    }

    /// `validate` reports a cycle that was smuggled in past `register`.
    #[test]
    fn validate_catches_cycles() {
        let mut registry = ComputedRegistry::new();
        // Push directly to bypass the cycle check in `register`.
        registry.vars.push(ComputedVar {
            key: "x".into(),
            dependencies: vec!["y".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.vars.push(ComputedVar {
            key: "y".into(),
            dependencies: vec!["x".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });

        let result = registry.validate();
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("Cycle detected"));
    }

    /// An acyclic graph validates cleanly.
    #[test]
    fn validate_succeeds_on_valid_graph() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["app:input".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["a".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });

        assert!(registry.validate().is_ok());
    }

    /// A compute function returning `None` leaves state untouched and reports
    /// no change.
    #[test]
    fn compute_returning_none_skips_write() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "maybe".into(),
            dependencies: vec!["app:flag".into()],
            compute: Arc::new(|state| {
                let flag: bool = state.get("app:flag")?;
                if flag {
                    Some(json!("yes"))
                } else {
                    None
                }
            }),
        });

        // Input missing: `?` short-circuits to `None`.
        let state = State::new();
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        // Input present but the function chooses `None`.
        let _ = state.set("app:flag", false);
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        let _ = state.set("app:flag", true);
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["maybe"]);
        assert_eq!(
            state.get::<String>("derived:maybe"),
            Some("yes".to_string())
        );
    }

    /// Full recomputation evaluates a diamond in a valid order.
    #[test]
    fn diamond_dependency() {
        let registry = diamond_registry();

        let state = State::new();
        let _ = state.set("app:root", 1);

        let changed = registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:d"), Some(1));
        assert_eq!(state.get::<i64>("derived:a"), Some(11));
        assert_eq!(state.get::<i64>("derived:b"), Some(21));
        assert_eq!(state.get::<i64>("derived:c"), Some(32));
        assert_eq!(changed.len(), 4);
    }

    /// An empty registry has nothing to report.
    #[test]
    fn empty_registry_recompute_returns_empty() {
        let registry = ComputedRegistry::new();
        let state = State::new();
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
    }

    /// `len` and `is_empty` track registrations.
    #[test]
    fn len_and_is_empty() {
        let mut registry = ComputedRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(ComputedVar {
            key: "x".into(),
            dependencies: vec![],
            compute: Arc::new(|_| Some(json!(1))),
        });
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    /// A change to a raw key propagates through a chain of computed variables.
    #[test]
    fn recompute_affected_propagates_through_chain() {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "root_derived".into(),
            dependencies: vec!["app:root".into()],
            compute: Arc::new(|state| {
                let r: i64 = state.get("app:root")?;
                Some(json!(r * 10))
            }),
        });

        registry.register(ComputedVar {
            key: "leaf".into(),
            dependencies: vec!["root_derived".into()],
            compute: Arc::new(|state| {
                let rd: i64 = state.get("derived:root_derived")?;
                Some(json!(rd + 5))
            }),
        });

        let state = State::new();
        let _ = state.set("app:root", 2);

        registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(20));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(25));

        let _ = state.set("app:root", 3);
        let changed = registry.recompute_affected(&state, &["app:root".into()]);
        assert!(changed.contains(&"root_derived".to_string()));
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(30));
        assert!(changed.contains(&"leaf".to_string()));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(35));
    }

    /// Incremental recomputation of a diamond evaluates the join node after
    /// both of its inputs.
    #[test]
    fn recompute_affected_diamond() {
        let registry = diamond_registry();

        let state = State::new();
        let _ = state.set("app:root", 1);
        registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:c"), Some(32));

        let _ = state.set("app:root", 2);
        let mut changed = registry.recompute_affected(&state, &["app:root".into()]);
        changed.sort();
        assert_eq!(changed, vec!["a", "b", "c", "d"]);
        assert_eq!(state.get::<i64>("derived:a"), Some(12));
        assert_eq!(state.get::<i64>("derived:b"), Some(22));
        assert_eq!(state.get::<i64>("derived:c"), Some(34));
    }

    /// Validating an empty registry succeeds.
    #[test]
    fn validate_empty_registry() {
        let registry = ComputedRegistry::new();
        assert!(registry.validate().is_ok());
    }

    /// A variable depending on itself is a cycle.
    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn self_cycle_panics() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "self_ref".into(),
            dependencies: vec!["self_ref".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
    }
}