// gemini_adk_rs/live/computed.rs

//! Computed (derived) state variables with dependency-ordered evaluation.
//!
//! Computed variables are pure functions of other state keys; each one's result
//! is stored in state under `derived:{key}`. The [`ComputedRegistry`] maintains a
//! topologically sorted list of [`ComputedVar`]s so that dependencies are always
//! evaluated before the variables that depend on them.
6
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use std::sync::Arc;

use serde_json::Value;

use crate::state::State;
13
14/// A computed state variable: a pure function of other state keys.
15///
16/// The `compute` closure receives the full [`State`] and returns an optional
17/// [`Value`]. When it returns `Some(value)`, the result is written to
18/// `derived:{key}` in state. When it returns `None`, the key is skipped
19/// (no write, no change detection).
20pub struct ComputedVar {
21    /// The state key this computed variable writes to (prefixed with `derived:`).
22    pub key: String,
23    /// State keys this variable depends on.
24    pub dependencies: Vec<String>,
25    /// Closure that computes the derived value from current state.
26    pub compute: Arc<dyn Fn(&State) -> Option<Value> + Send + Sync>,
27}
28
29/// Registry of computed variables with dependency-ordered evaluation.
30///
31/// Variables are kept in topological order: if var A depends on var B, then B
32/// appears before A in the internal list. This invariant is maintained at
33/// registration time using Kahn's algorithm.
34pub struct ComputedRegistry {
35    /// Topologically sorted computed variables.
36    vars: Vec<ComputedVar>,
37    /// Maps a state key to the indices (into `vars`) of computed variables
38    /// that list that key as a dependency.
39    dep_index: HashMap<String, Vec<usize>>,
40}
41
42impl Default for ComputedRegistry {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl ComputedRegistry {
49    /// Create an empty registry.
50    pub fn new() -> Self {
51        Self {
52            vars: Vec::new(),
53            dep_index: HashMap::new(),
54        }
55    }
56
57    /// Register a computed variable. Re-sorts the internal list and rebuilds
58    /// the dependency index. **Panics** if the new variable introduces a cycle.
59    pub fn register(&mut self, var: ComputedVar) {
60        if let Some(pos) = self.vars.iter().position(|v| v.key == var.key) {
61            self.vars[pos] = var; // replace existing
62        } else {
63            self.vars.push(var);
64        }
65        self.topo_sort_or_panic();
66        self.rebuild_dep_index();
67    }
68
69    /// Recompute all variables in dependency order. Returns the keys whose
70    /// derived values actually changed (old != new).
71    pub fn recompute(&self, state: &State) -> Vec<String> {
72        let mut changed = Vec::new();
73        for var in &self.vars {
74            if let Some(new_val) = (var.compute)(state) {
75                let derived_key = format!("derived:{}", var.key);
76                let old_val = state.get_raw(&derived_key);
77                let did_change = old_val.as_ref() != Some(&new_val);
78                state.set(&derived_key, new_val);
79                if did_change {
80                    changed.push(var.key.clone());
81                }
82            }
83        }
84        changed
85    }
86
87    /// Recompute only the variables affected by the given changed keys.
88    /// Uses the dependency index for O(1) lookup of affected variables, then
89    /// evaluates them in topological order. Transitively propagates: if a
90    /// computed var changes, its dependents are also scheduled for recomputation.
91    /// Returns keys that actually changed.
92    pub fn recompute_affected(&self, state: &State, changed_keys: &[String]) -> Vec<String> {
93        // Collect indices of affected vars transitively (deduplicated via bitmap).
94        let mut visited = vec![false; self.vars.len()];
95        let mut affected_set = Vec::new();
96
97        // Seed the work queue with the initial changed keys.
98        let mut work_keys: Vec<String> = changed_keys.to_vec();
99
100        while let Some(key) = work_keys.pop() {
101            // Look up vars that depend on this key directly.
102            if let Some(indices) = self.dep_index.get(&key) {
103                for &idx in indices {
104                    if !visited[idx] {
105                        visited[idx] = true;
106                        affected_set.push(idx);
107                        // This computed var's output (derived:<key>) might be
108                        // a dependency of other vars, so enqueue it.
109                        work_keys.push(self.vars[idx].key.clone());
110                    }
111                }
112            }
113        }
114
115        // Sort by topological order (indices are already in topo order).
116        affected_set.sort_unstable();
117
118        let mut changed = Vec::new();
119        for idx in affected_set {
120            let var = &self.vars[idx];
121            if let Some(new_val) = (var.compute)(state) {
122                let derived_key = format!("derived:{}", var.key);
123                let old_val = state.get_raw(&derived_key);
124                let did_change = old_val.as_ref() != Some(&new_val);
125                state.set(&derived_key, new_val);
126                if did_change {
127                    changed.push(var.key.clone());
128                }
129            }
130        }
131        changed
132    }
133
134    /// Validate the dependency graph. Returns `Ok(())` if there are no cycles,
135    /// or `Err(message)` describing the problem.
136    pub fn validate(&self) -> Result<(), String> {
137        // Build adjacency from the current vars and run Kahn's algorithm.
138        let n = self.vars.len();
139        if n == 0 {
140            return Ok(());
141        }
142
143        let key_to_idx: HashMap<&str, usize> = self
144            .vars
145            .iter()
146            .enumerate()
147            .map(|(i, v)| (v.key.as_str(), i))
148            .collect();
149
150        let mut in_degree = vec![0usize; n];
151        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
152
153        for (i, var) in self.vars.iter().enumerate() {
154            for dep in &var.dependencies {
155                if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
156                    adj[dep_idx].push(i);
157                    in_degree[i] += 1;
158                }
159                // External dependencies (not in registry) are fine — ignore them.
160            }
161        }
162
163        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
164        let mut visited = 0usize;
165
166        while let Some(node) = queue.pop() {
167            visited += 1;
168            for &neighbor in &adj[node] {
169                in_degree[neighbor] -= 1;
170                if in_degree[neighbor] == 0 {
171                    queue.push(neighbor);
172                }
173            }
174        }
175
176        if visited == n {
177            Ok(())
178        } else {
179            // Find the vars involved in the cycle.
180            let cycle_vars: Vec<&str> = (0..n)
181                .filter(|&i| in_degree[i] > 0)
182                .map(|i| self.vars[i].key.as_str())
183                .collect();
184            Err(format!(
185                "Cycle detected among computed variables: {:?}",
186                cycle_vars
187            ))
188        }
189    }
190
191    /// Returns the number of registered computed variables.
192    pub fn len(&self) -> usize {
193        self.vars.len()
194    }
195
196    /// Returns true if no computed variables are registered.
197    pub fn is_empty(&self) -> bool {
198        self.vars.is_empty()
199    }
200
201    // ── Internal helpers ──────────────────────────────────────────────────
202
203    /// Topologically sort `self.vars` in-place using Kahn's algorithm.
204    /// Panics if a cycle is detected (including self-cycles).
205    fn topo_sort_or_panic(&mut self) {
206        let n = self.vars.len();
207
208        // Check for self-cycles (a var depending on itself).
209        for var in &self.vars {
210            if var.dependencies.contains(&var.key) {
211                panic!(
212                    "Cycle detected among computed variables: {:?}",
213                    vec![var.key.as_str()]
214                );
215            }
216        }
217
218        if n <= 1 {
219            return;
220        }
221
222        // Map computed-var keys to their current index.
223        let key_to_idx: HashMap<&str, usize> = self
224            .vars
225            .iter()
226            .enumerate()
227            .map(|(i, v)| (v.key.as_str(), i))
228            .collect();
229
230        // Build adjacency list and in-degree array.
231        // Edge dep_idx -> i means "dep must come before i".
232        let mut in_degree = vec![0usize; n];
233        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
234
235        for (i, var) in self.vars.iter().enumerate() {
236            for dep in &var.dependencies {
237                if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
238                    adj[dep_idx].push(i);
239                    in_degree[i] += 1;
240                }
241            }
242        }
243
244        // Kahn's algorithm.
245        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
246        let mut order: Vec<usize> = Vec::with_capacity(n);
247
248        while let Some(node) = queue.pop() {
249            order.push(node);
250            for &neighbor in &adj[node] {
251                in_degree[neighbor] -= 1;
252                if in_degree[neighbor] == 0 {
253                    queue.push(neighbor);
254                }
255            }
256        }
257
258        if order.len() != n {
259            let cycle_vars: Vec<&str> = (0..n)
260                .filter(|&i| in_degree[i] > 0)
261                .map(|i| self.vars[i].key.as_str())
262                .collect();
263            panic!("Cycle detected among computed variables: {:?}", cycle_vars);
264        }
265
266        // Reorder vars according to topological sort.
267        // Use Option wrapping for safe index-based extraction.
268        let mut slots: Vec<Option<ComputedVar>> = self.vars.drain(..).map(Some).collect();
269        for &idx in &order {
270            if let Some(var) = slots[idx].take() {
271                self.vars.push(var);
272            }
273        }
274    }
275
276    /// Rebuild the `dep_index` mapping from dependency keys to var indices.
277    fn rebuild_dep_index(&mut self) {
278        self.dep_index.clear();
279        for (i, var) in self.vars.iter().enumerate() {
280            for dep in &var.dependencies {
281                self.dep_index.entry(dep.clone()).or_default().push(i);
282            }
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Build a [`ComputedVar`] from borrowed key/dependency names and a closure.
    fn var<F>(key: &str, deps: &[&str], compute: F) -> ComputedVar
    where
        F: Fn(&State) -> Option<Value> + Send + Sync + 'static,
    {
        ComputedVar {
            key: key.to_string(),
            dependencies: deps.iter().map(|d| d.to_string()).collect(),
            compute: Arc::new(compute),
        }
    }

    // ── Single variable: register, then a full recompute ───────────────

    #[test]
    fn single_var_register_and_recompute() {
        let mut reg = ComputedRegistry::new();
        reg.register(var("doubled", &["app:count"], |s| {
            s.get::<i64>("app:count").map(|c| json!(c * 2))
        }));

        let state = State::new();
        state.set("app:count", 5);

        assert_eq!(reg.recompute(&state), vec!["doubled"]);
        assert_eq!(state.get::<i64>("derived:doubled"), Some(10));
    }

    // ── Dependent registered before its dependency ─────────────────────

    #[test]
    fn dependency_ordering() {
        let mut reg = ComputedRegistry::new();
        // Registered ahead of `base` on purpose: the registry must re-sort.
        reg.register(var("derived_from_base", &["base"], |s| {
            s.get::<i64>("derived:base").map(|b| json!(b + 100))
        }));
        reg.register(var("base", &["app:input"], |s| {
            s.get::<i64>("app:input").map(|i| json!(i * 2))
        }));

        let state = State::new();
        state.set("app:input", 3);

        let changed = reg.recompute(&state);
        // `base` must be evaluated first (6) so its dependent sees it (106).
        assert_eq!(state.get::<i64>("derived:base"), Some(6));
        assert_eq!(state.get::<i64>("derived:derived_from_base"), Some(106));
        for key in ["base", "derived_from_base"] {
            assert!(changed.iter().any(|k| k == key));
        }
    }

    // ── A two-variable cycle is rejected at registration ───────────────

    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn cycle_detection_panics() {
        let mut reg = ComputedRegistry::new();
        reg.register(var("a", &["b"], |_| Some(json!(1))));
        reg.register(var("b", &["a"], |_| Some(json!(2))));
    }

    // ── Only keys whose stored value differs are reported ──────────────

    #[test]
    fn recompute_returns_only_changed_keys() {
        let mut reg = ComputedRegistry::new();
        reg.register(var("level", &["app:score"], |s| {
            let score: f64 = s.get("app:score")?;
            let level = if score > 0.5 { "high" } else { "low" };
            Some(json!(level))
        }));

        let state = State::new();
        state.set("app:score", 0.8);

        // First pass: nothing was stored yet, so it counts as a change.
        assert_eq!(reg.recompute(&state), vec!["level"]);
        // Same input again: the stored value is identical.
        assert!(reg.recompute(&state).is_empty());
        // Cross the threshold so the derived value flips.
        state.set("app:score", 0.2);
        assert_eq!(reg.recompute(&state), vec!["level"]);
        assert_eq!(
            state.get::<String>("derived:level"),
            Some("low".to_string())
        );
    }

    // ── recompute_affected leaves unrelated variables alone ────────────

    #[test]
    fn recompute_affected_only_recomputes_affected() {
        let calls_x = Arc::new(AtomicUsize::new(0));
        let calls_y = Arc::new(AtomicUsize::new(0));

        let mut reg = ComputedRegistry::new();
        {
            let calls = Arc::clone(&calls_x);
            reg.register(var("from_x", &["app:x"], move |s| {
                calls.fetch_add(1, Ordering::SeqCst);
                s.get::<i64>("app:x").map(|x| json!(x + 1))
            }));
        }
        {
            let calls = Arc::clone(&calls_y);
            reg.register(var("from_y", &["app:y"], move |s| {
                calls.fetch_add(1, Ordering::SeqCst);
                s.get::<i64>("app:y").map(|y| json!(y + 1))
            }));
        }

        let state = State::new();
        state.set("app:x", 10);
        state.set("app:y", 20);

        let changed = reg.recompute_affected(&state, &["app:x".to_string()]);
        assert_eq!(changed, vec!["from_x"]);
        assert_eq!(calls_x.load(Ordering::SeqCst), 1);
        assert_eq!(calls_y.load(Ordering::SeqCst), 0);

        assert_eq!(state.get::<i64>("derived:from_x"), Some(11));
        // `from_y` never ran, so nothing was written for it.
        assert_eq!(state.get_raw("derived:from_y"), None);
    }

    // ── validate reports a planted cycle ───────────────────────────────

    #[test]
    fn validate_catches_cycles() {
        let mut reg = ComputedRegistry::new();
        // Bypass `register` (which would panic) to plant the cycle directly.
        reg.vars.extend([
            var("x", &["y"], |_| Some(json!(1))),
            var("y", &["x"], |_| Some(json!(2))),
        ]);

        let err = reg.validate().expect_err("a cycle must be reported");
        assert!(err.contains("Cycle detected"));
    }

    // ── validate accepts an acyclic chain ──────────────────────────────

    #[test]
    fn validate_succeeds_on_valid_graph() {
        let mut reg = ComputedRegistry::new();
        reg.register(var("a", &["app:input"], |_| Some(json!(1))));
        reg.register(var("b", &["a"], |_| Some(json!(2))));

        assert_eq!(reg.validate(), Ok(()));
    }

    // ── A compute returning None writes nothing ────────────────────────

    #[test]
    fn compute_returning_none_skips_write() {
        let mut reg = ComputedRegistry::new();
        reg.register(var("maybe", &["app:flag"], |s| {
            s.get::<bool>("app:flag")
                .filter(|&flag| flag)
                .map(|_| json!("yes"))
        }));

        let state = State::new();

        // Flag unset: compute yields None.
        assert!(reg.recompute(&state).is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        // Flag false: still None.
        state.set("app:flag", false);
        assert!(reg.recompute(&state).is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        // Flag true: the value is written and reported.
        state.set("app:flag", true);
        assert_eq!(reg.recompute(&state), vec!["maybe"]);
        assert_eq!(
            state.get::<String>("derived:maybe"),
            Some("yes".to_string())
        );
    }

    // ── Diamond: D feeds A and B, both feed C ──────────────────────────

    #[test]
    fn diamond_dependency() {
        //     D
        //    / \
        //   A   B
        //    \ /
        //     C
        let mut reg = ComputedRegistry::new();
        reg.register(var("d", &["app:root"], |s| {
            s.get::<i64>("app:root").map(|r| json!(r))
        }));
        reg.register(var("a", &["d"], |s| {
            s.get::<i64>("derived:d").map(|d| json!(d + 10))
        }));
        reg.register(var("b", &["d"], |s| {
            s.get::<i64>("derived:d").map(|d| json!(d + 20))
        }));
        reg.register(var("c", &["a", "b"], |s| {
            let a: i64 = s.get("derived:a")?;
            let b: i64 = s.get("derived:b")?;
            Some(json!(a + b))
        }));

        let state = State::new();
        state.set("app:root", 1);

        let changed = reg.recompute(&state);
        let expected = [("d", 1), ("a", 11), ("b", 21), ("c", 32)];
        for (key, value) in expected {
            assert_eq!(
                state.get::<i64>(&format!("derived:{}", key)),
                Some(value)
            );
        }
        assert_eq!(changed.len(), 4);
    }

    // ── Empty registry ─────────────────────────────────────────────────

    #[test]
    fn empty_registry_recompute_returns_empty() {
        let reg = ComputedRegistry::new();
        let state = State::new();
        assert!(reg.recompute(&state).is_empty());
    }

    // ── len / is_empty ─────────────────────────────────────────────────

    #[test]
    fn len_and_is_empty() {
        let mut reg = ComputedRegistry::new();
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);

        reg.register(var("x", &[], |_| Some(json!(1))));
        assert!(!reg.is_empty());
        assert_eq!(reg.len(), 1);
    }

    // ── recompute_affected propagates through computed dependencies ────

    #[test]
    fn recompute_affected_diamond() {
        let mut reg = ComputedRegistry::new();
        reg.register(var("root_derived", &["app:root"], |s| {
            s.get::<i64>("app:root").map(|r| json!(r * 10))
        }));
        reg.register(var("leaf", &["root_derived"], |s| {
            s.get::<i64>("derived:root_derived").map(|rd| json!(rd + 5))
        }));

        let state = State::new();
        state.set("app:root", 2);

        // Populate everything once.
        reg.recompute(&state);
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(20));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(25));

        // A change to the root must flow through `root_derived` to `leaf`.
        state.set("app:root", 3);
        let changed = reg.recompute_affected(&state, &["app:root".to_string()]);
        assert!(changed.iter().any(|k| k == "root_derived"));
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(30));
        assert!(changed.iter().any(|k| k == "leaf"));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(35));
    }

    // ── validate on an empty registry ──────────────────────────────────

    #[test]
    fn validate_empty_registry() {
        let reg = ComputedRegistry::new();
        assert!(reg.validate().is_ok());
    }

    // ── A variable depending on itself is a cycle ──────────────────────

    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn self_cycle_panics() {
        let mut reg = ComputedRegistry::new();
        reg.register(var("self_ref", &["self_ref"], |_| Some(json!(1))));
    }
}