1use std::collections::HashMap;
8use std::sync::Arc;
9
10use serde_json::Value;
11
12use crate::state::State;
13
14pub struct ComputedVar {
21 pub key: String,
23 pub dependencies: Vec<String>,
25 pub compute: Arc<dyn Fn(&State) -> Option<Value> + Send + Sync>,
27}
28
29pub struct ComputedRegistry {
35 vars: Vec<ComputedVar>,
37 dep_index: HashMap<String, Vec<usize>>,
40}
41
42impl Default for ComputedRegistry {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl ComputedRegistry {
49 pub fn new() -> Self {
51 Self {
52 vars: Vec::new(),
53 dep_index: HashMap::new(),
54 }
55 }
56
57 pub fn register(&mut self, var: ComputedVar) {
60 if let Some(pos) = self.vars.iter().position(|v| v.key == var.key) {
61 self.vars[pos] = var; } else {
63 self.vars.push(var);
64 }
65 self.topo_sort_or_panic();
66 self.rebuild_dep_index();
67 }
68
69 pub fn recompute(&self, state: &State) -> Vec<String> {
72 let mut changed = Vec::new();
73 for var in &self.vars {
74 if let Some(new_val) = (var.compute)(state) {
75 let derived_key = format!("derived:{}", var.key);
76 let old_val = state.get_raw(&derived_key);
77 let did_change = old_val.as_ref() != Some(&new_val);
78 state.set(&derived_key, new_val);
79 if did_change {
80 changed.push(var.key.clone());
81 }
82 }
83 }
84 changed
85 }
86
87 pub fn recompute_affected(&self, state: &State, changed_keys: &[String]) -> Vec<String> {
93 let mut visited = vec![false; self.vars.len()];
95 let mut affected_set = Vec::new();
96
97 let mut work_keys: Vec<String> = changed_keys.to_vec();
99
100 while let Some(key) = work_keys.pop() {
101 if let Some(indices) = self.dep_index.get(&key) {
103 for &idx in indices {
104 if !visited[idx] {
105 visited[idx] = true;
106 affected_set.push(idx);
107 work_keys.push(self.vars[idx].key.clone());
110 }
111 }
112 }
113 }
114
115 affected_set.sort_unstable();
117
118 let mut changed = Vec::new();
119 for idx in affected_set {
120 let var = &self.vars[idx];
121 if let Some(new_val) = (var.compute)(state) {
122 let derived_key = format!("derived:{}", var.key);
123 let old_val = state.get_raw(&derived_key);
124 let did_change = old_val.as_ref() != Some(&new_val);
125 state.set(&derived_key, new_val);
126 if did_change {
127 changed.push(var.key.clone());
128 }
129 }
130 }
131 changed
132 }
133
134 pub fn validate(&self) -> Result<(), String> {
137 let n = self.vars.len();
139 if n == 0 {
140 return Ok(());
141 }
142
143 let key_to_idx: HashMap<&str, usize> = self
144 .vars
145 .iter()
146 .enumerate()
147 .map(|(i, v)| (v.key.as_str(), i))
148 .collect();
149
150 let mut in_degree = vec![0usize; n];
151 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
152
153 for (i, var) in self.vars.iter().enumerate() {
154 for dep in &var.dependencies {
155 if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
156 adj[dep_idx].push(i);
157 in_degree[i] += 1;
158 }
159 }
161 }
162
163 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
164 let mut visited = 0usize;
165
166 while let Some(node) = queue.pop() {
167 visited += 1;
168 for &neighbor in &adj[node] {
169 in_degree[neighbor] -= 1;
170 if in_degree[neighbor] == 0 {
171 queue.push(neighbor);
172 }
173 }
174 }
175
176 if visited == n {
177 Ok(())
178 } else {
179 let cycle_vars: Vec<&str> = (0..n)
181 .filter(|&i| in_degree[i] > 0)
182 .map(|i| self.vars[i].key.as_str())
183 .collect();
184 Err(format!(
185 "Cycle detected among computed variables: {:?}",
186 cycle_vars
187 ))
188 }
189 }
190
191 pub fn len(&self) -> usize {
193 self.vars.len()
194 }
195
196 pub fn is_empty(&self) -> bool {
198 self.vars.is_empty()
199 }
200
201 fn topo_sort_or_panic(&mut self) {
206 let n = self.vars.len();
207
208 for var in &self.vars {
210 if var.dependencies.contains(&var.key) {
211 panic!(
212 "Cycle detected among computed variables: {:?}",
213 vec![var.key.as_str()]
214 );
215 }
216 }
217
218 if n <= 1 {
219 return;
220 }
221
222 let key_to_idx: HashMap<&str, usize> = self
224 .vars
225 .iter()
226 .enumerate()
227 .map(|(i, v)| (v.key.as_str(), i))
228 .collect();
229
230 let mut in_degree = vec![0usize; n];
233 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
234
235 for (i, var) in self.vars.iter().enumerate() {
236 for dep in &var.dependencies {
237 if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
238 adj[dep_idx].push(i);
239 in_degree[i] += 1;
240 }
241 }
242 }
243
244 let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
246 let mut order: Vec<usize> = Vec::with_capacity(n);
247
248 while let Some(node) = queue.pop() {
249 order.push(node);
250 for &neighbor in &adj[node] {
251 in_degree[neighbor] -= 1;
252 if in_degree[neighbor] == 0 {
253 queue.push(neighbor);
254 }
255 }
256 }
257
258 if order.len() != n {
259 let cycle_vars: Vec<&str> = (0..n)
260 .filter(|&i| in_degree[i] > 0)
261 .map(|i| self.vars[i].key.as_str())
262 .collect();
263 panic!("Cycle detected among computed variables: {:?}", cycle_vars);
264 }
265
266 let mut slots: Vec<Option<ComputedVar>> = self.vars.drain(..).map(Some).collect();
269 for &idx in &order {
270 if let Some(var) = slots[idx].take() {
271 self.vars.push(var);
272 }
273 }
274 }
275
276 fn rebuild_dep_index(&mut self) {
278 self.dep_index.clear();
279 for (i, var) in self.vars.iter().enumerate() {
280 for dep in &var.dependencies {
281 self.dep_index.entry(dep.clone()).or_default().push(i);
282 }
283 }
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // A single variable is computed and stored under `derived:<key>`.
    #[test]
    fn single_var_register_and_recompute() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "doubled".into(),
            dependencies: vec!["app:count".into()],
            compute: Arc::new(|state| {
                let count: i64 = state.get("app:count")?;
                Some(json!(count * 2))
            }),
        });

        let state = State::new();
        state.set("app:count", 5);

        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["doubled"]);
        assert_eq!(state.get::<i64>("derived:doubled"), Some(10));
    }

    // A variable registered before its dependency is still evaluated after it.
    #[test]
    fn dependency_ordering() {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "derived_from_base".into(),
            dependencies: vec!["base".into()],
            compute: Arc::new(|state| {
                let base: i64 = state.get("derived:base")?;
                Some(json!(base + 100))
            }),
        });

        registry.register(ComputedVar {
            key: "base".into(),
            dependencies: vec!["app:input".into()],
            compute: Arc::new(|state| {
                let input: i64 = state.get("app:input")?;
                Some(json!(input * 2))
            }),
        });

        let state = State::new();
        state.set("app:input", 3);

        let changed = registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:base"), Some(6));
        assert_eq!(state.get::<i64>("derived:derived_from_base"), Some(106));
        assert!(changed.contains(&"base".to_string()));
        assert!(changed.contains(&"derived_from_base".to_string()));
    }

    // Two variables that depend on each other are rejected at registration.
    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn cycle_detection_panics() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["b".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["a".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });
    }

    // Only keys whose stored value actually changed are reported.
    #[test]
    fn recompute_returns_only_changed_keys() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "level".into(),
            dependencies: vec!["app:score".into()],
            compute: Arc::new(|state| {
                let score: f64 = state.get("app:score")?;
                if score > 0.5 {
                    Some(json!("high"))
                } else {
                    Some(json!("low"))
                }
            }),
        });

        let state = State::new();
        state.set("app:score", 0.8);

        // First evaluation writes a fresh value.
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["level"]);

        // Same inputs, same output: nothing reported.
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());

        // Crossing the threshold flips the value.
        state.set("app:score", 0.2);
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["level"]);
        assert_eq!(
            state.get::<String>("derived:level"),
            Some("low".to_string())
        );
    }

    // Variables not downstream of the changed keys are never evaluated.
    #[test]
    fn recompute_affected_only_recomputes_affected() {
        let call_count_a = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let call_count_b = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let cc_a = call_count_a.clone();
        let cc_b = call_count_b.clone();

        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "from_x".into(),
            dependencies: vec!["app:x".into()],
            compute: Arc::new(move |state| {
                cc_a.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let x: i64 = state.get("app:x")?;
                Some(json!(x + 1))
            }),
        });
        registry.register(ComputedVar {
            key: "from_y".into(),
            dependencies: vec!["app:y".into()],
            compute: Arc::new(move |state| {
                cc_b.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let y: i64 = state.get("app:y")?;
                Some(json!(y + 1))
            }),
        });

        let state = State::new();
        state.set("app:x", 10);
        state.set("app:y", 20);

        let changed = registry.recompute_affected(&state, &["app:x".into()]);
        assert_eq!(changed, vec!["from_x"]);
        assert_eq!(call_count_a.load(std::sync::atomic::Ordering::SeqCst), 1);
        assert_eq!(call_count_b.load(std::sync::atomic::Ordering::SeqCst), 0);

        assert_eq!(state.get::<i64>("derived:from_x"), Some(11));
        assert_eq!(state.get_raw("derived:from_y"), None);
    }

    // Cycles inserted without going through `register` are caught by `validate`.
    #[test]
    fn validate_catches_cycles() {
        let mut registry = ComputedRegistry::new();
        registry.vars.push(ComputedVar {
            key: "x".into(),
            dependencies: vec!["y".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.vars.push(ComputedVar {
            key: "y".into(),
            dependencies: vec!["x".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });

        let result = registry.validate();
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("Cycle detected"));
    }

    // A simple chain built through `register` validates cleanly.
    #[test]
    fn validate_succeeds_on_valid_graph() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["app:input".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["a".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });

        assert!(registry.validate().is_ok());
    }

    // A `None` result leaves the derived slot untouched and is not a change.
    #[test]
    fn compute_returning_none_skips_write() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "maybe".into(),
            dependencies: vec!["app:flag".into()],
            compute: Arc::new(|state| {
                let flag: bool = state.get("app:flag")?;
                if flag {
                    Some(json!("yes"))
                } else {
                    None
                }
            }),
        });

        // Input missing entirely.
        let state = State::new();
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        // Input present but compute declines to produce a value.
        state.set("app:flag", false);
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        state.set("app:flag", true);
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["maybe"]);
        assert_eq!(
            state.get::<String>("derived:maybe"),
            Some("yes".to_string())
        );
    }

    // d -> {a, b} -> c: every node is evaluated after all of its inputs.
    #[test]
    fn diamond_dependency() {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "d".into(),
            dependencies: vec!["app:root".into()],
            compute: Arc::new(|state| {
                let root: i64 = state.get("app:root")?;
                Some(json!(root))
            }),
        });

        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["d".into()],
            compute: Arc::new(|state| {
                let d: i64 = state.get("derived:d")?;
                Some(json!(d + 10))
            }),
        });

        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["d".into()],
            compute: Arc::new(|state| {
                let d: i64 = state.get("derived:d")?;
                Some(json!(d + 20))
            }),
        });

        registry.register(ComputedVar {
            key: "c".into(),
            dependencies: vec!["a".into(), "b".into()],
            compute: Arc::new(|state| {
                let a: i64 = state.get("derived:a")?;
                let b: i64 = state.get("derived:b")?;
                Some(json!(a + b))
            }),
        });

        let state = State::new();
        state.set("app:root", 1);

        let changed = registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:d"), Some(1));
        assert_eq!(state.get::<i64>("derived:a"), Some(11));
        assert_eq!(state.get::<i64>("derived:b"), Some(21));
        assert_eq!(state.get::<i64>("derived:c"), Some(32));
        assert_eq!(changed.len(), 4);
    }

    #[test]
    fn empty_registry_recompute_returns_empty() {
        let registry = ComputedRegistry::new();
        let state = State::new();
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
    }

    #[test]
    fn len_and_is_empty() {
        let mut registry = ComputedRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(ComputedVar {
            key: "x".into(),
            dependencies: vec![],
            compute: Arc::new(|_| Some(json!(1))),
        });
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    // A change to a raw key propagates transitively through a chain of derived keys.
    #[test]
    fn recompute_affected_diamond() {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "root_derived".into(),
            dependencies: vec!["app:root".into()],
            compute: Arc::new(|state| {
                let r: i64 = state.get("app:root")?;
                Some(json!(r * 10))
            }),
        });

        registry.register(ComputedVar {
            key: "leaf".into(),
            dependencies: vec!["root_derived".into()],
            compute: Arc::new(|state| {
                let rd: i64 = state.get("derived:root_derived")?;
                Some(json!(rd + 5))
            }),
        });

        let state = State::new();
        state.set("app:root", 2);

        registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(20));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(25));

        state.set("app:root", 3);
        let changed = registry.recompute_affected(&state, &["app:root".into()]);
        assert!(changed.contains(&"root_derived".to_string()));
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(30));
        assert!(changed.contains(&"leaf".to_string()));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(35));
    }

    #[test]
    fn validate_empty_registry() {
        let registry = ComputedRegistry::new();
        assert!(registry.validate().is_ok());
    }

    // A variable listing itself as a dependency is a cycle of length one.
    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn self_cycle_panics() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "self_ref".into(),
            dependencies: vec!["self_ref".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
    }

    // Re-registering a key replaces both the compute function and its dependency edges.
    #[test]
    fn register_replaces_existing_key() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "value".into(),
            dependencies: vec!["app:old".into()],
            compute: Arc::new(|state| {
                let o: i64 = state.get("app:old")?;
                Some(json!(o))
            }),
        });
        registry.register(ComputedVar {
            key: "value".into(),
            dependencies: vec!["app:new".into()],
            compute: Arc::new(|state| {
                let n: i64 = state.get("app:new")?;
                Some(json!(n))
            }),
        });
        assert_eq!(registry.len(), 1);

        let state = State::new();
        state.set("app:old", 1);
        state.set("app:new", 2);

        // The replaced dependency no longer triggers the variable...
        assert!(registry
            .recompute_affected(&state, &["app:old".into()])
            .is_empty());
        assert_eq!(state.get_raw("derived:value"), None);

        // ...while the new one does, using the new compute function.
        let changed = registry.recompute_affected(&state, &["app:new".into()]);
        assert_eq!(changed, vec!["value"]);
        assert_eq!(state.get::<i64>("derived:value"), Some(2));
    }

    // Keys that no variable depends on (or no keys at all) evaluate nothing.
    #[test]
    fn recompute_affected_ignores_unknown_keys() {
        let calls = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let counter = calls.clone();

        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "tracked".into(),
            dependencies: vec!["app:tracked".into()],
            compute: Arc::new(move |_| {
                counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                Some(json!(1))
            }),
        });

        let state = State::new();
        assert!(registry
            .recompute_affected(&state, &["app:unrelated".into()])
            .is_empty());
        assert!(registry.recompute_affected(&state, &[]).is_empty());
        assert_eq!(calls.load(std::sync::atomic::Ordering::SeqCst), 0);
        assert_eq!(state.get_raw("derived:tracked"), None);
    }
}