// gemini_adk_rs/live/computed.rs
//! Computed (derived) state variables with dependency-ordered evaluation.
//!
//! Computed variables are pure functions of other state keys. The [`ComputedRegistry`]
//! maintains a topologically sorted list of [`ComputedVar`]s so that dependencies are
//! always evaluated before the variables that depend on them.

use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::sync::Arc;

use serde_json::Value;

use crate::state::State;

use super::contract::ComputedContract;

16/// Closure computing a derived value from current state.
17pub type ComputeFn = Arc<dyn Fn(&State) -> Option<Value> + Send + Sync>;
18
19/// A computed state variable: a pure function of other state keys.
20///
21/// The `compute` closure receives the full [`State`] and returns an optional
22/// [`Value`]. When it returns `Some(value)`, the result is written to
23/// `derived:{key}` in state. When it returns `None`, the key is skipped
24/// (no write, no change detection).
25pub struct ComputedVar {
26    /// The state key this computed variable writes to (prefixed with `derived:`).
27    pub key: String,
28    /// State keys this variable depends on.
29    pub dependencies: Vec<String>,
30    /// Closure that computes the derived value from current state.
31    pub compute: ComputeFn,
32}
33
34/// Registry of computed variables with dependency-ordered evaluation.
35///
36/// Variables are kept in topological order: if var A depends on var B, then B
37/// appears before A in the internal list. This invariant is maintained at
38/// registration time using Kahn's algorithm.
39pub struct ComputedRegistry {
40    /// Topologically sorted computed variables.
41    vars: Vec<ComputedVar>,
42    /// Maps a state key to the indices (into `vars`) of computed variables
43    /// that list that key as a dependency.
44    dep_index: HashMap<String, Vec<usize>>,
45}
46
47impl Default for ComputedRegistry {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl ComputedRegistry {
54    /// Create an empty registry.
55    pub fn new() -> Self {
56        Self {
57            vars: Vec::new(),
58            dep_index: HashMap::new(),
59        }
60    }
61
62    /// Register a computed variable. Re-sorts the internal list and rebuilds
63    /// the dependency index. **Panics** if the new variable introduces a cycle.
64    pub fn register(&mut self, var: ComputedVar) {
65        if let Some(pos) = self.vars.iter().position(|v| v.key == var.key) {
66            self.vars[pos] = var; // replace existing
67        } else {
68            self.vars.push(var);
69        }
70        self.topo_sort_or_panic();
71        self.rebuild_dep_index();
72    }
73
74    /// Recompute all variables in dependency order. Returns the keys whose
75    /// derived values actually changed (old != new).
76    pub fn recompute(&self, state: &State) -> Vec<String> {
77        let mut changed = Vec::new();
78        for var in &self.vars {
79            if let Some(new_val) = (var.compute)(state) {
80                let derived_key = format!("derived:{}", var.key);
81                let old_val = state.get_raw(&derived_key);
82                let did_change = old_val.as_ref() != Some(&new_val);
83                let _ = state.set(&derived_key, new_val);
84                if did_change {
85                    changed.push(var.key.clone());
86                }
87            }
88        }
89        changed
90    }
91
92    /// Recompute only the variables affected by the given changed keys.
93    /// Uses the dependency index for O(1) lookup of affected variables, then
94    /// evaluates them in topological order. Transitively propagates: if a
95    /// computed var changes, its dependents are also scheduled for recomputation.
96    /// Returns keys that actually changed.
97    pub fn recompute_affected(&self, state: &State, changed_keys: &[String]) -> Vec<String> {
98        // Collect indices of affected vars transitively (deduplicated via bitmap).
99        let mut visited = vec![false; self.vars.len()];
100        let mut affected_set = Vec::new();
101
102        // Seed the work queue with the initial changed keys.
103        let mut work_keys: Vec<String> = changed_keys.to_vec();
104
105        while let Some(key) = work_keys.pop() {
106            // Look up vars that depend on this key directly.
107            if let Some(indices) = self.dep_index.get(&key) {
108                for &idx in indices {
109                    if !visited[idx] {
110                        visited[idx] = true;
111                        affected_set.push(idx);
112                        // This computed var's output (derived:<key>) might be
113                        // a dependency of other vars, so enqueue it.
114                        work_keys.push(self.vars[idx].key.clone());
115                    }
116                }
117            }
118        }
119
120        // Sort by topological order (indices are already in topo order).
121        affected_set.sort_unstable();
122
123        let mut changed = Vec::new();
124        for idx in affected_set {
125            let var = &self.vars[idx];
126            if let Some(new_val) = (var.compute)(state) {
127                let derived_key = format!("derived:{}", var.key);
128                let old_val = state.get_raw(&derived_key);
129                let did_change = old_val.as_ref() != Some(&new_val);
130                let _ = state.set(&derived_key, new_val);
131                if did_change {
132                    changed.push(var.key.clone());
133                }
134            }
135        }
136        changed
137    }
138
139    /// Validate the dependency graph. Returns `Ok(())` if there are no cycles,
140    /// or `Err(message)` describing the problem.
141    pub fn validate(&self) -> Result<(), String> {
142        // Build adjacency from the current vars and run Kahn's algorithm.
143        let n = self.vars.len();
144        if n == 0 {
145            return Ok(());
146        }
147
148        let key_to_idx: HashMap<&str, usize> = self
149            .vars
150            .iter()
151            .enumerate()
152            .map(|(i, v)| (v.key.as_str(), i))
153            .collect();
154
155        let mut in_degree = vec![0usize; n];
156        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
157
158        for (i, var) in self.vars.iter().enumerate() {
159            for dep in &var.dependencies {
160                if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
161                    adj[dep_idx].push(i);
162                    in_degree[i] += 1;
163                }
164                // External dependencies (not in registry) are fine — ignore them.
165            }
166        }
167
168        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
169        let mut visited = 0usize;
170
171        while let Some(node) = queue.pop() {
172            visited += 1;
173            for &neighbor in &adj[node] {
174                in_degree[neighbor] -= 1;
175                if in_degree[neighbor] == 0 {
176                    queue.push(neighbor);
177                }
178            }
179        }
180
181        if visited == n {
182            Ok(())
183        } else {
184            // Find the vars involved in the cycle.
185            let cycle_vars: Vec<&str> = (0..n)
186                .filter(|&i| in_degree[i] > 0)
187                .map(|i| self.vars[i].key.as_str())
188                .collect();
189            Err(format!(
190                "Cycle detected among computed variables: {:?}",
191                cycle_vars
192            ))
193        }
194    }
195
196    /// Returns the number of registered computed variables.
197    pub fn len(&self) -> usize {
198        self.vars.len()
199    }
200
201    /// Returns true if no computed variables are registered.
202    pub fn is_empty(&self) -> bool {
203        self.vars.is_empty()
204    }
205
206    /// Return serializable contract metadata for all computed variables.
207    pub fn describe(&self) -> Vec<ComputedContract> {
208        self.vars
209            .iter()
210            .map(|var| ComputedContract {
211                key: var.key.clone(),
212                dependencies: var.dependencies.clone(),
213            })
214            .collect()
215    }
216
217    // ── Internal helpers ──────────────────────────────────────────────────
218
219    /// Topologically sort `self.vars` in-place using Kahn's algorithm.
220    /// Panics if a cycle is detected (including self-cycles).
221    fn topo_sort_or_panic(&mut self) {
222        let n = self.vars.len();
223
224        // Check for self-cycles (a var depending on itself).
225        for var in &self.vars {
226            if var.dependencies.contains(&var.key) {
227                panic!(
228                    "Cycle detected among computed variables: {:?}",
229                    vec![var.key.as_str()]
230                );
231            }
232        }
233
234        if n <= 1 {
235            return;
236        }
237
238        // Map computed-var keys to their current index.
239        let key_to_idx: HashMap<&str, usize> = self
240            .vars
241            .iter()
242            .enumerate()
243            .map(|(i, v)| (v.key.as_str(), i))
244            .collect();
245
246        // Build adjacency list and in-degree array.
247        // Edge dep_idx -> i means "dep must come before i".
248        let mut in_degree = vec![0usize; n];
249        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
250
251        for (i, var) in self.vars.iter().enumerate() {
252            for dep in &var.dependencies {
253                if let Some(&dep_idx) = key_to_idx.get(dep.as_str()) {
254                    adj[dep_idx].push(i);
255                    in_degree[i] += 1;
256                }
257            }
258        }
259
260        // Kahn's algorithm.
261        let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
262        let mut order: Vec<usize> = Vec::with_capacity(n);
263
264        while let Some(node) = queue.pop() {
265            order.push(node);
266            for &neighbor in &adj[node] {
267                in_degree[neighbor] -= 1;
268                if in_degree[neighbor] == 0 {
269                    queue.push(neighbor);
270                }
271            }
272        }
273
274        if order.len() != n {
275            let cycle_vars: Vec<&str> = (0..n)
276                .filter(|&i| in_degree[i] > 0)
277                .map(|i| self.vars[i].key.as_str())
278                .collect();
279            panic!("Cycle detected among computed variables: {:?}", cycle_vars);
280        }
281
282        // Reorder vars according to topological sort.
283        // Use Option wrapping for safe index-based extraction.
284        let mut slots: Vec<Option<ComputedVar>> = self.vars.drain(..).map(Some).collect();
285        for &idx in &order {
286            if let Some(var) = slots[idx].take() {
287                self.vars.push(var);
288            }
289        }
290    }
291
292    /// Rebuild the `dep_index` mapping from dependency keys to var indices.
293    fn rebuild_dep_index(&mut self) {
294        self.dep_index.clear();
295        for (i, var) in self.vars.iter().enumerate() {
296            for dep in &var.dependencies {
297                self.dep_index.entry(dep.clone()).or_default().push(i);
298            }
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build the diamond `app:root → d → {a, b} → c` shared by several tests.
    ///
    /// d = root, a = d + 10, b = d + 20, c = a + b.
    fn diamond_registry() -> ComputedRegistry {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "d".into(),
            dependencies: vec!["app:root".into()],
            compute: Arc::new(|state| {
                let root: i64 = state.get("app:root")?;
                Some(json!(root))
            }),
        });

        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["d".into()],
            compute: Arc::new(|state| {
                let d: i64 = state.get("derived:d")?;
                Some(json!(d + 10))
            }),
        });

        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["d".into()],
            compute: Arc::new(|state| {
                let d: i64 = state.get("derived:d")?;
                Some(json!(d + 20))
            }),
        });

        registry.register(ComputedVar {
            key: "c".into(),
            dependencies: vec!["a".into(), "b".into()],
            compute: Arc::new(|state| {
                let a: i64 = state.get("derived:a")?;
                let b: i64 = state.get("derived:b")?;
                Some(json!(a + b))
            }),
        });

        registry
    }

    // ── 1. Single var register + recompute ──────────────────────────────

    #[test]
    fn single_var_register_and_recompute() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "doubled".into(),
            dependencies: vec!["app:count".into()],
            compute: Arc::new(|state| {
                let count: i64 = state.get("app:count")?;
                Some(json!(count * 2))
            }),
        });

        let state = State::new();
        let _ = state.set("app:count", 5);

        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["doubled"]);
        assert_eq!(state.get::<i64>("derived:doubled"), Some(10));
    }

    // ── 2. Dependency ordering (B depends on A) ────────────────────────

    #[test]
    fn dependency_ordering() {
        let mut registry = ComputedRegistry::new();

        // Register B first. It depends on computed var `base`, whose value it
        // reads from `derived:base`.
        registry.register(ComputedVar {
            key: "derived_from_base".into(),
            dependencies: vec!["base".into()],
            compute: Arc::new(|state| {
                let base: i64 = state.get("derived:base")?;
                Some(json!(base + 100))
            }),
        });

        // Register A (base, no internal deps).
        registry.register(ComputedVar {
            key: "base".into(),
            dependencies: vec!["app:input".into()],
            compute: Arc::new(|state| {
                let input: i64 = state.get("app:input")?;
                Some(json!(input * 2))
            }),
        });

        let state = State::new();
        let _ = state.set("app:input", 3);

        let changed = registry.recompute(&state);
        // base should be computed first (6), then derived_from_base (106).
        assert_eq!(state.get::<i64>("derived:base"), Some(6));
        assert_eq!(state.get::<i64>("derived:derived_from_base"), Some(106));
        assert!(changed.contains(&"base".to_string()));
        assert!(changed.contains(&"derived_from_base".to_string()));
    }

    // ── 3. Cycle detection (panic) ─────────────────────────────────────

    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn cycle_detection_panics() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["b".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["a".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });
    }

    // ── 4. Recompute returns only keys that changed ────────────────────

    #[test]
    fn recompute_returns_only_changed_keys() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "level".into(),
            dependencies: vec!["app:score".into()],
            compute: Arc::new(|state| {
                let score: f64 = state.get("app:score")?;
                if score > 0.5 {
                    Some(json!("high"))
                } else {
                    Some(json!("low"))
                }
            }),
        });

        let state = State::new();
        let _ = state.set("app:score", 0.8);

        // First recompute: level is new, so it changed.
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["level"]);

        // Second recompute with same input: no change.
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());

        // Change input so derived value changes.
        let _ = state.set("app:score", 0.2);
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["level"]);
        assert_eq!(
            state.get::<String>("derived:level"),
            Some("low".to_string())
        );
    }

    // ── 5. recompute_affected only recomputes affected vars ────────────

    #[test]
    fn recompute_affected_only_recomputes_affected() {
        let call_count_a = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let call_count_b = Arc::new(std::sync::atomic::AtomicUsize::new(0));

        let cc_a = call_count_a.clone();
        let cc_b = call_count_b.clone();

        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "from_x".into(),
            dependencies: vec!["app:x".into()],
            compute: Arc::new(move |state| {
                cc_a.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let x: i64 = state.get("app:x")?;
                Some(json!(x + 1))
            }),
        });
        registry.register(ComputedVar {
            key: "from_y".into(),
            dependencies: vec!["app:y".into()],
            compute: Arc::new(move |state| {
                cc_b.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                let y: i64 = state.get("app:y")?;
                Some(json!(y + 1))
            }),
        });

        let state = State::new();
        let _ = state.set("app:x", 10);
        let _ = state.set("app:y", 20);

        // Only app:x changed — should only recompute from_x.
        let changed = registry.recompute_affected(&state, &["app:x".into()]);
        assert_eq!(changed, vec!["from_x"]);
        assert_eq!(call_count_a.load(std::sync::atomic::Ordering::SeqCst), 1);
        assert_eq!(call_count_b.load(std::sync::atomic::Ordering::SeqCst), 0);

        assert_eq!(state.get::<i64>("derived:from_x"), Some(11));
        // from_y was not computed, so derived:from_y should not exist.
        assert_eq!(state.get_raw("derived:from_y"), None);
    }

    // ── 6. validate catches cycles ─────────────────────────────────────

    #[test]
    fn validate_catches_cycles() {
        let mut registry = ComputedRegistry::new();
        // Manually push vars without going through register (which would panic).
        registry.vars.push(ComputedVar {
            key: "x".into(),
            dependencies: vec!["y".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.vars.push(ComputedVar {
            key: "y".into(),
            dependencies: vec!["x".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });

        let result = registry.validate();
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(msg.contains("Cycle detected"));
        // The message names the variables involved.
        assert!(msg.contains("\"x\""));
        assert!(msg.contains("\"y\""));
    }

    // ── 7. validate succeeds on valid graph ────────────────────────────

    #[test]
    fn validate_succeeds_on_valid_graph() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["app:input".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["a".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });

        assert!(registry.validate().is_ok());
    }

    // ── 8. Compute returning None skips write ──────────────────────────

    #[test]
    fn compute_returning_none_skips_write() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "maybe".into(),
            dependencies: vec!["app:flag".into()],
            compute: Arc::new(|state| {
                let flag: bool = state.get("app:flag")?;
                if flag {
                    Some(json!("yes"))
                } else {
                    None
                }
            }),
        });

        let state = State::new();
        // app:flag not set → get returns None → compute returns None.
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        // Set flag to false → compute returns None.
        let _ = state.set("app:flag", false);
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
        assert_eq!(state.get_raw("derived:maybe"), None);

        // Set flag to true → compute returns Some.
        let _ = state.set("app:flag", true);
        let changed = registry.recompute(&state);
        assert_eq!(changed, vec!["maybe"]);
        assert_eq!(
            state.get::<String>("derived:maybe"),
            Some("yes".to_string())
        );
    }

    // ── 9. Diamond dependency ──────────────────────────────────────────

    #[test]
    fn diamond_dependency() {
        // D is the root. A and B depend on D. C depends on A and B.
        //
        //     D
        //    / \
        //   A   B
        //    \ /
        //     C
        let registry = diamond_registry();

        let state = State::new();
        let _ = state.set("app:root", 1);

        let changed = registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:d"), Some(1));
        assert_eq!(state.get::<i64>("derived:a"), Some(11));
        assert_eq!(state.get::<i64>("derived:b"), Some(21));
        assert_eq!(state.get::<i64>("derived:c"), Some(32));
        assert_eq!(changed.len(), 4);
    }

    // ── 10. Empty registry recompute returns empty vec ─────────────────

    #[test]
    fn empty_registry_recompute_returns_empty() {
        let registry = ComputedRegistry::new();
        let state = State::new();
        let changed = registry.recompute(&state);
        assert!(changed.is_empty());
    }

    // ── Additional: len / is_empty ─────────────────────────────────────

    #[test]
    fn len_and_is_empty() {
        let mut registry = ComputedRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);

        registry.register(ComputedVar {
            key: "x".into(),
            dependencies: vec![],
            compute: Arc::new(|_| Some(json!(1))),
        });
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    // ── Additional: recompute_affected through a chain ─────────────────

    #[test]
    fn recompute_affected_propagates_through_chain() {
        let mut registry = ComputedRegistry::new();

        registry.register(ComputedVar {
            key: "root_derived".into(),
            dependencies: vec!["app:root".into()],
            compute: Arc::new(|state| {
                let r: i64 = state.get("app:root")?;
                Some(json!(r * 10))
            }),
        });

        registry.register(ComputedVar {
            key: "leaf".into(),
            dependencies: vec!["root_derived".into()],
            compute: Arc::new(|state| {
                let rd: i64 = state.get("derived:root_derived")?;
                Some(json!(rd + 5))
            }),
        });

        let state = State::new();
        let _ = state.set("app:root", 2);

        // First full recompute to populate.
        registry.recompute(&state);
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(20));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(25));

        // Now change root, use recompute_affected.
        let _ = state.set("app:root", 3);
        let changed = registry.recompute_affected(&state, &["app:root".into()]);
        // root_derived should be recomputed (depends on app:root).
        assert!(changed.contains(&"root_derived".to_string()));
        assert_eq!(state.get::<i64>("derived:root_derived"), Some(30));
        // leaf depends on root_derived — it should be picked up via
        // the dep_index entry for "root_derived".
        assert!(changed.contains(&"leaf".to_string()));
        assert_eq!(state.get::<i64>("derived:leaf"), Some(35));
    }

    // ── Additional: recompute_affected through a true diamond ──────────

    #[test]
    fn recompute_affected_diamond() {
        let registry = diamond_registry();

        let state = State::new();
        let _ = state.set("app:root", 1);
        registry.recompute(&state);

        // Changing the root must reach every var, d first and c last.
        let _ = state.set("app:root", 2);
        let changed = registry.recompute_affected(&state, &["app:root".into()]);
        assert_eq!(changed.len(), 4);
        assert_eq!(changed.first().map(String::as_str), Some("d"));
        assert_eq!(changed.last().map(String::as_str), Some("c"));
        assert_eq!(state.get::<i64>("derived:c"), Some(34));
    }

    // ── Additional: recompute_affected ignores unrelated keys ──────────

    #[test]
    fn recompute_affected_ignores_unrelated_keys() {
        let registry = diamond_registry();

        let state = State::new();
        let _ = state.set("app:root", 1);

        let changed = registry.recompute_affected(&state, &["app:unrelated".into()]);
        assert!(changed.is_empty());
        assert_eq!(state.get_raw("derived:d"), None);
    }

    // ── Additional: re-registering a key replaces the var ──────────────

    #[test]
    fn re_registering_a_key_replaces_the_var() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "x".into(),
            dependencies: vec![],
            compute: Arc::new(|_| Some(json!(1))),
        });
        registry.register(ComputedVar {
            key: "x".into(),
            dependencies: vec!["app:y".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });
        assert_eq!(registry.len(), 1);

        let state = State::new();
        assert_eq!(registry.recompute(&state), vec!["x"]);
        assert_eq!(state.get::<i64>("derived:x"), Some(2));
        assert_eq!(registry.describe()[0].dependencies, vec!["app:y".to_string()]);
    }

    // ── Additional: describe lists vars in evaluation order ────────────

    #[test]
    fn describe_lists_vars_in_evaluation_order() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "b".into(),
            dependencies: vec!["a".into()],
            compute: Arc::new(|_| Some(json!(2))),
        });
        registry.register(ComputedVar {
            key: "a".into(),
            dependencies: vec!["app:input".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });

        let contracts = registry.describe();
        let keys: Vec<&str> = contracts.iter().map(|c| c.key.as_str()).collect();
        assert_eq!(keys, vec!["a", "b"]);
        assert_eq!(contracts[1].dependencies, vec!["a".to_string()]);
    }

    // ── Additional: validate on empty registry ─────────────────────────

    #[test]
    fn validate_empty_registry() {
        let registry = ComputedRegistry::new();
        assert!(registry.validate().is_ok());
    }

    // ── Additional: self-cycle ──────────────────────────────────────────

    #[test]
    #[should_panic(expected = "Cycle detected")]
    fn self_cycle_panics() {
        let mut registry = ComputedRegistry::new();
        registry.register(ComputedVar {
            key: "self_ref".into(),
            dependencies: vec!["self_ref".into()],
            compute: Arc::new(|_| Some(json!(1))),
        });
    }
}