gemini_adk_rs/
state.rs

1//! Typed key-value state container for agents.
2//!
3//! Supports optional delta tracking for transactional state management
4//! and prefix-scoped accessors for namespace isolation.
5
6use std::collections::{HashMap, VecDeque};
7use std::marker::PhantomData;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::SystemTime;
11
12use dashmap::DashMap;
13use serde_json::Value;
14
/// Number of entries the mutation journal keeps before it starts evicting
/// the oldest entry to make room for a new one.
const DEFAULT_MUTATION_JOURNAL_CAPACITY: usize = 1 << 10;
16
/// A compile-time typed state key that eliminates typo bugs and type mismatches.
///
/// Declare keys as constants and pass them to `State::get_key()` /
/// `State::set_key()`, so that every access to a key agrees on its value type:
///
/// ```rust,ignore
/// const TURN_COUNT: StateKey<u32> = StateKey::new("session:turn_count");
/// const SENTIMENT: StateKey<String> = StateKey::new("derived:sentiment");
///
/// state.set_key(&TURN_COUNT, 5);
/// let count: Option<u32> = state.get_key(&TURN_COUNT);
/// ```
pub struct StateKey<T> {
    /// The string key used in the underlying state store.
    key: &'static str,
    // `fn() -> T` ties the key to `T` without owning a `T`, so the marker is
    // `Send + Sync` whatever `T` is.
    _phantom: PhantomData<fn() -> T>,
}

impl<T> StateKey<T> {
    /// Create a typed key wrapping `key`. Usable in `const` context.
    pub const fn new(key: &'static str) -> Self {
        StateKey {
            _phantom: PhantomData,
            key,
        }
    }

    /// The underlying string key.
    pub const fn key(&self) -> &'static str {
        let StateKey { key, .. } = self;
        key
    }
}
47
/// The operation that produced a recorded state mutation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateMutationOrigin {
    /// A regular `State::set` call, including writes made through a prefixed view.
    Set,
    /// `State::set_committed`: a direct write to the committed store that
    /// does not go through delta tracking.
    SetCommitted,
    /// Removal of a single key via `State::remove`.
    Remove,
    /// A key removed as part of `State::clear_prefix`.
    ClearPrefix,
    /// A delta entry folded into the committed store by `State::commit`.
    Commit,
}
62
63/// A single state mutation recorded in the bounded mutation journal.
64#[derive(Debug, Clone, PartialEq)]
65pub struct StateMutation {
66    /// Monotonic sequence number assigned when the mutation was recorded.
67    pub sequence: u64,
68    /// State key that changed.
69    pub key: String,
70    /// Value before the mutation, or `None` when the key did not exist.
71    pub old: Option<Value>,
72    /// Value after the mutation, or `None` when the key was removed.
73    pub new: Option<Value>,
74    /// Operation that recorded the mutation.
75    pub origin: StateMutationOrigin,
76    /// Wall-clock time at which the mutation was recorded.
77    pub timestamp: SystemTime,
78    /// Whether the mutation was written to a delta-tracked view.
79    pub delta: bool,
80}
81
82/// A concurrent, type-safe state container that agents read from and write to.
83///
84/// By default, `set()` writes directly to the inner store. When delta tracking
85/// is enabled via `with_delta_tracking()`, writes go to a separate delta map
86/// that can be committed or rolled back.
87#[derive(Debug, Clone)]
88pub struct State {
89    inner: Arc<DashMap<String, Value>>,
90    delta: Arc<DashMap<String, Value>>,
91    mutations: Arc<std::sync::Mutex<VecDeque<StateMutation>>>,
92    next_mutation_sequence: Arc<AtomicU64>,
93    mutation_capacity: usize,
94    track_delta: bool,
95}
96
97impl Default for State {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl State {
104    /// Create a new empty state container.
105    pub fn new() -> Self {
106        Self {
107            inner: Arc::new(DashMap::new()),
108            delta: Arc::new(DashMap::new()),
109            mutations: Arc::new(std::sync::Mutex::new(VecDeque::new())),
110            next_mutation_sequence: Arc::new(AtomicU64::new(1)),
111            mutation_capacity: DEFAULT_MUTATION_JOURNAL_CAPACITY,
112            track_delta: false,
113        }
114    }
115
116    /// Create a new State with delta tracking enabled.
117    /// Writes go to the delta map; reads check delta first, then inner.
118    pub fn with_delta_tracking(&self) -> State {
119        State {
120            inner: self.inner.clone(),
121            delta: Arc::new(DashMap::new()),
122            mutations: self.mutations.clone(),
123            next_mutation_sequence: self.next_mutation_sequence.clone(),
124            mutation_capacity: self.mutation_capacity,
125            track_delta: true,
126        }
127    }
128
129    /// Get a value by key, attempting to deserialize to the requested type.
130    /// When delta tracking is enabled, checks delta first, then inner.
131    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
132        self.get_raw(key)
133            .and_then(|v| serde_json::from_value(v).ok())
134    }
135
136    /// Borrow a value by key without cloning, applying `f` to the reference.
137    ///
138    /// This is the zero-copy alternative to `get_raw()`. The closure receives
139    /// a `&Value` directly from the DashMap ref-guard, avoiding the
140    /// `Value::clone()` + `serde_json::from_value()` overhead of `get()`.
141    ///
142    /// Lookup order: delta (if tracking) → inner → derived fallback.
143    pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
144    where
145        F: FnOnce(&Value) -> R,
146    {
147        if self.track_delta {
148            if let Some(ref_multi) = self.delta.get(key) {
149                return Some(f(ref_multi.value()));
150            }
151        }
152        if let Some(ref_multi) = self.inner.get(key) {
153            return Some(f(ref_multi.value()));
154        }
155        if !key.contains(':') {
156            let mut derived_key = String::with_capacity(8 + key.len());
157            use std::fmt::Write;
158            let _ = write!(derived_key, "derived:{}", key);
159            if self.track_delta {
160                if let Some(ref_multi) = self.delta.get(&derived_key) {
161                    return Some(f(ref_multi.value()));
162                }
163            }
164            if let Some(ref_multi) = self.inner.get(&derived_key) {
165                return Some(f(ref_multi.value()));
166            }
167        }
168        None
169    }
170
171    /// Get a raw JSON value by key.
172    /// When delta tracking is enabled, checks delta first, then inner.
173    /// If the key is not found and doesn't contain a prefix, also checks `derived:{key}`
174    /// as a transparent fallback for computed variables.
175    pub fn get_raw(&self, key: &str) -> Option<Value> {
176        if self.track_delta {
177            if let Some(v) = self.delta.get(key) {
178                return Some(v.value().clone());
179            }
180        }
181        if let Some(v) = self.inner.get(key) {
182            return Some(v.value().clone());
183        }
184        // Transparent derived fallback: if key has no prefix, check derived:{key}
185        if !key.contains(':') {
186            use std::fmt::Write;
187            let mut derived_key = String::with_capacity(8 + key.len());
188            let _ = write!(derived_key, "derived:{}", key);
189            if self.track_delta {
190                if let Some(v) = self.delta.get(&derived_key) {
191                    return Some(v.value().clone());
192                }
193            }
194            return self.inner.get(&derived_key).map(|v| v.value().clone());
195        }
196        None
197    }
198
199    /// Get a typed value using a `StateKey<T>`.
200    pub fn get_key<T: serde::de::DeserializeOwned>(&self, key: &StateKey<T>) -> Option<T> {
201        self.get(key.key())
202    }
203
204    /// Set a typed value using a `StateKey<T>`.
205    pub fn set_key<T: serde::Serialize>(&self, key: &StateKey<T>, value: T) {
206        self.set(key.key(), value);
207    }
208
209    /// Zero-copy borrow using a `StateKey<T>`.
210    pub fn with_key<T, F, R>(&self, key: &StateKey<T>, f: F) -> Option<R>
211    where
212        F: FnOnce(&Value) -> R,
213    {
214        self.with(key.key(), f)
215    }
216
217    /// Set a value by key.
218    /// When delta tracking is enabled, writes to delta instead of inner.
219    pub fn set(&self, key: impl Into<String>, value: impl serde::Serialize) {
220        let key = key.into();
221        let v = serde_json::to_value(value).expect("value must be serializable");
222        let old = self.get_raw(&key);
223        if self.track_delta {
224            self.delta.insert(key.clone(), v.clone());
225            self.record_mutation(key, old, Some(v), StateMutationOrigin::Set);
226        } else {
227            self.inner.insert(key.clone(), v.clone());
228            self.record_mutation(key, old, Some(v), StateMutationOrigin::Set);
229        }
230    }
231
232    /// Set a value directly in the committed store, bypassing delta tracking.
233    pub fn set_committed(&self, key: impl Into<String>, value: impl serde::Serialize) {
234        let key = key.into();
235        let v = serde_json::to_value(value).expect("value must be serializable");
236        let old = self.inner.insert(key.clone(), v.clone());
237        self.record_mutation(key, old, Some(v), StateMutationOrigin::SetCommitted);
238    }
239
240    /// Atomically read-modify-write a value.
241    ///
242    /// If the key doesn't exist, `default` is used as the initial value.
243    /// The function `f` receives the current value and returns the new value.
244    /// Returns the new value after modification.
245    pub fn modify<T, F>(&self, key: &str, default: T, f: F) -> T
246    where
247        T: serde::Serialize + serde::de::DeserializeOwned,
248        F: FnOnce(T) -> T,
249    {
250        // Read current value from whichever store has it
251        let current: T = self.get(key).unwrap_or(default);
252        let new_val = f(current);
253        self.set(key, &new_val);
254        new_val
255    }
256
257    /// Check if a key exists (in delta or inner).
258    pub fn contains(&self, key: &str) -> bool {
259        if self.track_delta && self.delta.contains_key(key) {
260            return true;
261        }
262        self.inner.contains_key(key)
263    }
264
265    /// Remove a key.
266    pub fn remove(&self, key: &str) -> Option<Value> {
267        if self.track_delta {
268            // Remove from delta if present, but also check inner
269            let from_delta = self.delta.remove(key).map(|(_, v)| v);
270            let from_inner = self.inner.remove(key).map(|(_, v)| v);
271            let removed = from_delta.or(from_inner);
272            if let Some(ref old) = removed {
273                self.record_mutation(
274                    key.to_string(),
275                    Some(old.clone()),
276                    None,
277                    StateMutationOrigin::Remove,
278                );
279            }
280            removed
281        } else {
282            let removed = self.inner.remove(key).map(|(_, v)| v);
283            if let Some(ref old) = removed {
284                self.record_mutation(
285                    key.to_string(),
286                    Some(old.clone()),
287                    None,
288                    StateMutationOrigin::Remove,
289                );
290            }
291            removed
292        }
293    }
294
295    /// Get all keys (from both inner and delta when tracking).
296    pub fn keys(&self) -> Vec<String> {
297        if !self.track_delta || self.delta.is_empty() {
298            return self.inner.iter().map(|r| r.key().clone()).collect();
299        }
300        let mut seen =
301            std::collections::HashSet::with_capacity(self.inner.len() + self.delta.len());
302        let mut keys = Vec::with_capacity(self.inner.len() + self.delta.len());
303        for entry in self.inner.iter() {
304            let key = entry.key().clone();
305            seen.insert(key.clone());
306            keys.push(key);
307        }
308        for entry in self.delta.iter() {
309            let key = entry.key().clone();
310            if seen.insert(key.clone()) {
311                keys.push(key);
312            }
313        }
314        keys
315    }
316
317    /// Create a new State containing only the specified keys.
318    pub fn pick(&self, keys: &[&str]) -> State {
319        let new = State::new();
320        for key in keys {
321            if let Some(v) = self.get_raw(key) {
322                new.set(*key, v);
323            }
324        }
325        new
326    }
327
328    /// Merge another state into this one (other's values overwrite on conflict).
329    pub fn merge(&self, other: &State) {
330        for entry in other.inner.iter() {
331            self.set(entry.key().clone(), entry.value().clone());
332        }
333    }
334
335    /// Rename a key.
336    pub fn rename(&self, from: &str, to: &str) {
337        if let Some(v) = self.remove(from) {
338            if self.track_delta {
339                self.set(to.to_string(), v);
340            } else {
341                self.set(to.to_string(), v);
342            }
343        }
344    }
345
346    // ── Delta methods ──────────────────────────────────────────────────────
347
348    /// Whether delta tracking is enabled.
349    pub fn is_tracking_delta(&self) -> bool {
350        self.track_delta
351    }
352
353    /// Whether there are uncommitted delta changes.
354    pub fn has_delta(&self) -> bool {
355        self.track_delta && !self.delta.is_empty()
356    }
357
358    /// Get a snapshot of the current delta.
359    pub fn delta(&self) -> HashMap<String, Value> {
360        self.delta
361            .iter()
362            .map(|entry| (entry.key().clone(), entry.value().clone()))
363            .collect()
364    }
365
366    /// Commit delta changes into the inner store, then clear the delta.
367    pub fn commit(&self) {
368        for entry in self.delta.iter() {
369            let key = entry.key().clone();
370            let value = entry.value().clone();
371            let old = self.inner.insert(key.clone(), value.clone());
372            self.record_mutation_with_delta(
373                key,
374                old,
375                Some(value),
376                StateMutationOrigin::Commit,
377                false,
378            );
379        }
380        self.delta.clear();
381    }
382
383    /// Discard all uncommitted delta changes.
384    pub fn rollback(&self) {
385        self.delta.clear();
386    }
387
388    // ── Prefix accessors ───────────────────────────────────────────────────
389
390    /// Access state with the `app:` prefix scope.
391    pub fn app(&self) -> PrefixedState<'_> {
392        PrefixedState {
393            state: self,
394            prefix: "app:",
395        }
396    }
397
398    /// Access state with the `user:` prefix scope.
399    pub fn user(&self) -> PrefixedState<'_> {
400        PrefixedState {
401            state: self,
402            prefix: "user:",
403        }
404    }
405
406    /// Access state with the `temp:` prefix scope.
407    pub fn temp(&self) -> PrefixedState<'_> {
408        PrefixedState {
409            state: self,
410            prefix: "temp:",
411        }
412    }
413
414    /// Access state with the `session:` prefix scope (auto-tracked signals).
415    pub fn session(&self) -> PrefixedState<'_> {
416        PrefixedState {
417            state: self,
418            prefix: "session:",
419        }
420    }
421
422    /// Access state with the `turn:` prefix scope (reset each turn).
423    pub fn turn(&self) -> PrefixedState<'_> {
424        PrefixedState {
425            state: self,
426            prefix: "turn:",
427        }
428    }
429
430    /// Access state with the `bg:` prefix scope (background tasks).
431    pub fn bg(&self) -> PrefixedState<'_> {
432        PrefixedState {
433            state: self,
434            prefix: "bg:",
435        }
436    }
437
438    /// Access read-only state with the `derived:` prefix scope (computed vars only).
439    pub fn derived(&self) -> ReadOnlyPrefixedState<'_> {
440        ReadOnlyPrefixedState {
441            state: self,
442            prefix: "derived:",
443        }
444    }
445
446    // ── Utility methods ───────────────────────────────────────────────────
447
448    /// Snapshot the values of specific keys. Returns HashMap of key -> current value.
449    /// Used by watchers to capture state before mutations.
450    pub fn snapshot_values(&self, keys: &[&str]) -> HashMap<String, Value> {
451        keys.iter()
452            .filter_map(|&k| self.get_raw(k).map(|v| (k.to_string(), v)))
453            .collect()
454    }
455
456    /// Diff current state against a previous snapshot.
457    /// Returns Vec of (key, old_value, new_value) for keys that changed.
458    pub fn diff_values(
459        &self,
460        prev: &HashMap<String, Value>,
461        keys: &[&str],
462    ) -> Vec<(String, Value, Value)> {
463        keys.iter()
464            .filter_map(|&k| {
465                let old = prev.get(k);
466                let new = self.get_raw(k);
467                match (old, new) {
468                    (Some(o), Some(n)) if o != &n => Some((k.to_string(), o.clone(), n)),
469                    (None, Some(n)) => Some((k.to_string(), Value::Null, n)),
470                    (Some(o), None) => Some((k.to_string(), o.clone(), Value::Null)),
471                    _ => None,
472                }
473            })
474            .collect()
475    }
476
477    /// Export all state as a HashMap (for persistence/serialization).
478    pub fn to_hashmap(&self) -> std::collections::HashMap<String, serde_json::Value> {
479        self.inner
480            .iter()
481            .map(|entry| (entry.key().clone(), entry.value().clone()))
482            .collect()
483    }
484
485    /// Restore state from a HashMap (for persistence/deserialization).
486    pub fn from_hashmap(&self, map: std::collections::HashMap<String, serde_json::Value>) {
487        for (key, value) in map {
488            self.set_committed(key, value);
489        }
490    }
491
492    /// Remove all keys with the given prefix.
493    pub fn clear_prefix(&self, prefix: &str) {
494        let keys_to_remove: Vec<String> = self
495            .inner
496            .iter()
497            .filter(|entry| entry.key().starts_with(prefix))
498            .map(|entry| entry.key().clone())
499            .collect();
500        for key in keys_to_remove {
501            if let Some((_, old)) = self.inner.remove(&key) {
502                self.record_mutation(key, Some(old), None, StateMutationOrigin::ClearPrefix);
503            }
504        }
505        if self.track_delta {
506            let delta_keys: Vec<String> = self
507                .delta
508                .iter()
509                .filter(|entry| entry.key().starts_with(prefix))
510                .map(|entry| entry.key().clone())
511                .collect();
512            for key in delta_keys {
513                if let Some((_, old)) = self.delta.remove(&key) {
514                    self.record_mutation(key, Some(old), None, StateMutationOrigin::ClearPrefix);
515                }
516            }
517        }
518    }
519
520    /// Return a snapshot of recent state mutations.
521    pub fn recent_mutations(&self) -> Vec<StateMutation> {
522        self.mutations
523            .lock()
524            .expect("state mutation journal poisoned")
525            .iter()
526            .cloned()
527            .collect()
528    }
529
530    /// Return the current monotonic cursor for the mutation journal.
531    pub fn mutation_cursor(&self) -> u64 {
532        self.next_mutation_sequence.load(Ordering::Relaxed) - 1
533    }
534
535    /// Return mutations appended after a previously captured cursor.
536    pub fn mutations_since(&self, cursor: u64) -> Vec<StateMutation> {
537        let mutations = self
538            .mutations
539            .lock()
540            .expect("state mutation journal poisoned");
541        mutations
542            .iter()
543            .filter(|mutation| mutation.sequence > cursor)
544            .cloned()
545            .collect()
546    }
547
548    /// Drain and return all recorded state mutations.
549    pub fn drain_mutations(&self) -> Vec<StateMutation> {
550        self.mutations
551            .lock()
552            .expect("state mutation journal poisoned")
553            .drain(..)
554            .collect()
555    }
556
557    fn record_mutation(
558        &self,
559        key: String,
560        old: Option<Value>,
561        new: Option<Value>,
562        origin: StateMutationOrigin,
563    ) {
564        self.record_mutation_with_delta(key, old, new, origin, self.track_delta);
565    }
566
567    fn record_mutation_with_delta(
568        &self,
569        key: String,
570        old: Option<Value>,
571        new: Option<Value>,
572        origin: StateMutationOrigin,
573        delta: bool,
574    ) {
575        let mut mutations = self
576            .mutations
577            .lock()
578            .expect("state mutation journal poisoned");
579        if mutations.len() >= self.mutation_capacity {
580            mutations.pop_front();
581        }
582        let sequence = self.next_mutation_sequence.fetch_add(1, Ordering::Relaxed);
583        mutations.push_back(StateMutation {
584            sequence,
585            key,
586            old,
587            new,
588            origin,
589            timestamp: SystemTime::now(),
590            delta,
591        });
592    }
593}
594
595/// A borrowed view of state that automatically prepends a prefix to all keys.
596pub struct PrefixedState<'a> {
597    state: &'a State,
598    prefix: &'static str,
599}
600
601impl<'a> PrefixedState<'a> {
602    fn prefixed_key(&self, key: &str) -> String {
603        format!("{}{}", self.prefix, key)
604    }
605
606    /// Get a value by key (with prefix applied).
607    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
608        self.state.get(&self.prefixed_key(key))
609    }
610
611    /// Get a raw JSON value by key (with prefix applied).
612    pub fn get_raw(&self, key: &str) -> Option<Value> {
613        self.state.get_raw(&self.prefixed_key(key))
614    }
615
616    /// Zero-copy borrow a value by key (with prefix applied).
617    pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
618    where
619        F: FnOnce(&Value) -> R,
620    {
621        self.state.with(&self.prefixed_key(key), f)
622    }
623
624    /// Set a value by key (with prefix applied).
625    pub fn set(&self, key: impl AsRef<str>, value: impl serde::Serialize) {
626        self.state.set(self.prefixed_key(key.as_ref()), value);
627    }
628
629    /// Check if a key exists (with prefix applied).
630    pub fn contains(&self, key: &str) -> bool {
631        self.state.contains(&self.prefixed_key(key))
632    }
633
634    /// Remove a key (with prefix applied).
635    pub fn remove(&self, key: &str) -> Option<Value> {
636        self.state.remove(&self.prefixed_key(key))
637    }
638
639    /// Get all keys within this prefix scope (prefix stripped from results).
640    pub fn keys(&self) -> Vec<String> {
641        self.state
642            .keys()
643            .into_iter()
644            .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
645            .collect()
646    }
647}
648
649/// A borrowed, read-only view of state that automatically prepends a prefix to all keys.
650///
651/// Unlike `PrefixedState`, this does not expose `set()` or `remove()` methods,
652/// making it suitable for computed/derived state that user code should not mutate.
653pub struct ReadOnlyPrefixedState<'a> {
654    state: &'a State,
655    prefix: &'static str,
656}
657
658impl<'a> ReadOnlyPrefixedState<'a> {
659    fn prefixed_key(&self, key: &str) -> String {
660        format!("{}{}", self.prefix, key)
661    }
662
663    /// Get a value by key (with prefix applied).
664    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
665        self.state.get(&self.prefixed_key(key))
666    }
667
668    /// Get a raw JSON value by key (with prefix applied).
669    pub fn get_raw(&self, key: &str) -> Option<Value> {
670        self.state.get_raw(&self.prefixed_key(key))
671    }
672
673    /// Zero-copy borrow a value by key (with prefix applied).
674    pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
675    where
676        F: FnOnce(&Value) -> R,
677    {
678        self.state.with(&self.prefixed_key(key), f)
679    }
680
681    /// Check if a key exists (with prefix applied).
682    pub fn contains(&self, key: &str) -> bool {
683        self.state.contains(&self.prefixed_key(key))
684    }
685
686    /// Get all keys within this prefix scope (prefix stripped from results).
687    pub fn keys(&self) -> Vec<String> {
688        self.state
689            .keys()
690            .into_iter()
691            .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
692            .collect()
693    }
694}
695
696#[cfg(test)]
697mod tests {
698    use super::*;
699
700    #[test]
701    fn set_and_get_string() {
702        let state = State::new();
703        state.set("name", "Alice");
704        assert_eq!(state.get::<String>("name"), Some("Alice".to_string()));
705    }
706
707    #[test]
708    fn set_and_get_json() {
709        let state = State::new();
710        state.set("data", serde_json::json!({"temp": 22}));
711        let v: Value = state.get("data").unwrap();
712        assert_eq!(v["temp"], 22);
713    }
714
715    #[test]
716    fn pick_subset() {
717        let state = State::new();
718        state.set("a", 1);
719        state.set("b", 2);
720        state.set("c", 3);
721        let picked = state.pick(&["a", "c"]);
722        assert!(picked.contains("a"));
723        assert!(!picked.contains("b"));
724        assert!(picked.contains("c"));
725    }
726
727    #[test]
728    fn merge_states() {
729        let s1 = State::new();
730        s1.set("a", 1);
731        let s2 = State::new();
732        s2.set("b", 2);
733        s1.merge(&s2);
734        assert!(s1.contains("a"));
735        assert!(s1.contains("b"));
736    }
737
738    #[test]
739    fn rename_key() {
740        let state = State::new();
741        state.set("old", "value");
742        state.rename("old", "new");
743        assert!(!state.contains("old"));
744        assert_eq!(state.get::<String>("new"), Some("value".to_string()));
745    }
746
747    #[test]
748    fn remove_returns_value() {
749        let state = State::new();
750        state.set("key", 42);
751        let removed = state.remove("key");
752        assert!(removed.is_some());
753        assert!(!state.contains("key"));
754    }
755
756    #[test]
757    fn get_missing_returns_none() {
758        let state = State::new();
759        assert_eq!(state.get::<String>("nope"), None);
760    }
761
762    // ── Delta tracking tests ──────────────────────────────────────────────
763
764    #[test]
765    fn delta_tracking_writes_to_delta() {
766        let state = State::new();
767        state.set("committed", "yes");
768
769        let tracked = state.with_delta_tracking();
770        tracked.set("new_key", "new_value");
771
772        // New key visible through tracked state
773        assert_eq!(
774            tracked.get::<String>("new_key"),
775            Some("new_value".to_string())
776        );
777        // But NOT visible in original (non-delta) state's inner
778        assert!(!state.contains("new_key"));
779        // Committed key still visible through tracked state
780        assert_eq!(tracked.get::<String>("committed"), Some("yes".to_string()));
781    }
782
783    #[test]
784    fn delta_has_delta_reports_correctly() {
785        let state = State::new();
786        let tracked = state.with_delta_tracking();
787        assert!(!tracked.has_delta());
788
789        tracked.set("key", "val");
790        assert!(tracked.has_delta());
791    }
792
793    #[test]
794    fn delta_commit_merges_to_inner() {
795        let state = State::new();
796        let tracked = state.with_delta_tracking();
797        tracked.set("key", "val");
798        assert!(!state.contains("key"));
799
800        tracked.commit();
801        // Now visible in original state
802        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
803        assert!(!tracked.has_delta());
804    }
805
806    #[test]
807    fn delta_rollback_discards_changes() {
808        let state = State::new();
809        let tracked = state.with_delta_tracking();
810        tracked.set("key", "val");
811        assert!(tracked.has_delta());
812
813        tracked.rollback();
814        assert!(!tracked.has_delta());
815        assert!(!state.contains("key"));
816        assert!(!tracked.contains("key"));
817    }
818
819    #[test]
820    fn delta_snapshot() {
821        let state = State::new();
822        let tracked = state.with_delta_tracking();
823        tracked.set("a", 1);
824        tracked.set("b", 2);
825
826        let snapshot = tracked.delta();
827        assert_eq!(snapshot.len(), 2);
828        assert!(snapshot.contains_key("a"));
829        assert!(snapshot.contains_key("b"));
830    }
831
832    #[test]
833    fn set_committed_bypasses_delta() {
834        let state = State::new();
835        let tracked = state.with_delta_tracking();
836        tracked.set_committed("direct", "value");
837
838        // Visible immediately in inner
839        assert_eq!(state.get::<String>("direct"), Some("value".to_string()));
840        // Not in delta
841        assert!(!tracked.has_delta());
842        // Still visible through tracked (reads inner too)
843        assert_eq!(tracked.get::<String>("direct"), Some("value".to_string()));
844    }
845
846    #[test]
847    fn mutation_journal_records_set_and_remove() {
848        let state = State::new();
849        state.set("key", "first");
850        state.set("key", "second");
851        state.remove("key");
852
853        let mutations = state.recent_mutations();
854        assert_eq!(mutations.len(), 3);
855        assert_eq!(mutations[0].key, "key");
856        assert_eq!(mutations[0].old, None);
857        assert_eq!(mutations[0].new, Some(serde_json::json!("first")));
858        assert_eq!(mutations[0].origin, StateMutationOrigin::Set);
859
860        assert_eq!(mutations[1].old, Some(serde_json::json!("first")));
861        assert_eq!(mutations[1].new, Some(serde_json::json!("second")));
862
863        assert_eq!(mutations[2].old, Some(serde_json::json!("second")));
864        assert_eq!(mutations[2].new, None);
865        assert_eq!(mutations[2].origin, StateMutationOrigin::Remove);
866    }
867
868    #[test]
869    fn mutation_journal_is_shared_with_delta_tracking() {
870        let state = State::new();
871        state.set("committed", "yes");
872
873        let tracked = state.with_delta_tracking();
874        tracked.set("committed", "maybe");
875        tracked.commit();
876
877        let mutations = state.recent_mutations();
878        assert_eq!(mutations.len(), 3);
879        assert_eq!(mutations[1].key, "committed");
880        assert_eq!(mutations[1].old, Some(serde_json::json!("yes")));
881        assert_eq!(mutations[1].new, Some(serde_json::json!("maybe")));
882        assert_eq!(mutations[1].origin, StateMutationOrigin::Set);
883        assert!(mutations[1].delta);
884
885        assert_eq!(mutations[2].origin, StateMutationOrigin::Commit);
886        assert!(!mutations[2].delta);
887    }
888
889    #[test]
890    fn drain_mutations_clears_journal() {
891        let state = State::new();
892        state.set("a", 1);
893        state.set("b", 2);
894
895        let drained = state.drain_mutations();
896        assert_eq!(drained.len(), 2);
897        assert!(state.recent_mutations().is_empty());
898    }
899
900    #[test]
901    fn mutation_cursor_reads_only_later_changes() {
902        let state = State::new();
903        state.set("before", 1);
904        let cursor = state.mutation_cursor();
905
906        state.set("after", 2);
907        state.remove("before");
908
909        let mutations = state.mutations_since(cursor);
910        assert_eq!(mutations.len(), 2);
911        assert_eq!(mutations[0].key, "after");
912        assert_eq!(mutations[1].key, "before");
913    }
914
915    #[test]
916    fn no_delta_tracking_preserves_existing_behavior() {
917        let state = State::new();
918        assert!(!state.is_tracking_delta());
919        state.set("key", "val");
920        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
921        assert!(!state.has_delta());
922    }
923
924    // ── Prefix tests ──────────────────────────────────────────────────────
925
926    #[test]
927    fn prefix_app_set_and_get() {
928        let state = State::new();
929        state.app().set("flag", true);
930
931        // Accessible via prefix accessor
932        assert_eq!(state.app().get::<bool>("flag"), Some(true));
933        // Also accessible via raw key
934        assert_eq!(state.get::<bool>("app:flag"), Some(true));
935    }
936
937    #[test]
938    fn prefix_user_set_and_get() {
939        let state = State::new();
940        state.user().set("name", "Alice");
941        assert_eq!(
942            state.user().get::<String>("name"),
943            Some("Alice".to_string())
944        );
945        assert_eq!(state.get::<String>("user:name"), Some("Alice".to_string()));
946    }
947
948    #[test]
949    fn prefix_temp_set_and_get() {
950        let state = State::new();
951        state.temp().set("scratch", 42);
952        assert_eq!(state.temp().get::<i32>("scratch"), Some(42));
953    }
954
955    #[test]
956    fn prefix_contains_and_remove() {
957        let state = State::new();
958        state.app().set("x", 1);
959        assert!(state.app().contains("x"));
960        state.app().remove("x");
961        assert!(!state.app().contains("x"));
962    }
963
964    #[test]
965    fn prefix_keys() {
966        let state = State::new();
967        state.app().set("a", 1);
968        state.app().set("b", 2);
969        state.user().set("c", 3);
970
971        let app_keys = state.app().keys();
972        assert_eq!(app_keys.len(), 2);
973        assert!(app_keys.contains(&"a".to_string()));
974        assert!(app_keys.contains(&"b".to_string()));
975
976        let user_keys = state.user().keys();
977        assert_eq!(user_keys.len(), 1);
978        assert!(user_keys.contains(&"c".to_string()));
979    }
980
981    #[test]
982    fn prefix_with_delta_tracking() {
983        let state = State::new();
984        let tracked = state.with_delta_tracking();
985        tracked.app().set("flag", true);
986
987        // Visible in tracked state via prefix
988        assert_eq!(tracked.app().get::<bool>("flag"), Some(true));
989        // In delta, not committed
990        assert!(tracked.has_delta());
991        assert!(!state.contains("app:flag"));
992
993        tracked.commit();
994        assert_eq!(state.get::<bool>("app:flag"), Some(true));
995    }
996
    // ── session / turn / bg prefix accessor tests ────────────────────────
998
999    #[test]
1000    fn prefix_session_set_and_get() {
1001        let state = State::new();
1002        state.session().set("turn_count", 5);
1003        assert_eq!(state.session().get::<i32>("turn_count"), Some(5));
1004        assert_eq!(state.get::<i32>("session:turn_count"), Some(5));
1005    }
1006
1007    #[test]
1008    fn prefix_turn_set_and_get() {
1009        let state = State::new();
1010        state.turn().set("transcript", "hello");
1011        assert_eq!(
1012            state.turn().get::<String>("transcript"),
1013            Some("hello".to_string())
1014        );
1015        assert_eq!(
1016            state.get::<String>("turn:transcript"),
1017            Some("hello".to_string())
1018        );
1019    }
1020
1021    #[test]
1022    fn prefix_bg_set_and_get() {
1023        let state = State::new();
1024        state.bg().set("task_id", "abc-123");
1025        assert_eq!(
1026            state.bg().get::<String>("task_id"),
1027            Some("abc-123".to_string())
1028        );
1029        assert_eq!(
1030            state.get::<String>("bg:task_id"),
1031            Some("abc-123".to_string())
1032        );
1033    }
1034
1035    #[test]
1036    fn prefix_session_contains_and_remove() {
1037        let state = State::new();
1038        state.session().set("x", 1);
1039        assert!(state.session().contains("x"));
1040        state.session().remove("x");
1041        assert!(!state.session().contains("x"));
1042    }
1043
1044    #[test]
1045    fn prefix_turn_keys() {
1046        let state = State::new();
1047        state.turn().set("a", 1);
1048        state.turn().set("b", 2);
1049        state.session().set("c", 3);
1050
1051        let turn_keys = state.turn().keys();
1052        assert_eq!(turn_keys.len(), 2);
1053        assert!(turn_keys.contains(&"a".to_string()));
1054        assert!(turn_keys.contains(&"b".to_string()));
1055    }
1056
1057    // ── ReadOnlyPrefixedState (derived) tests ────────────────────────────
1058
1059    #[test]
1060    fn derived_read_only_get() {
1061        let state = State::new();
1062        // Write via raw key (simulating ComputedRegistry)
1063        state.set("derived:sentiment", "positive");
1064        assert_eq!(
1065            state.derived().get::<String>("sentiment"),
1066            Some("positive".to_string())
1067        );
1068    }
1069
1070    #[test]
1071    fn derived_read_only_get_raw() {
1072        let state = State::new();
1073        state.set("derived:score", serde_json::json!(0.95));
1074        let raw = state.derived().get_raw("score");
1075        assert!(raw.is_some());
1076        assert_eq!(raw.unwrap(), serde_json::json!(0.95));
1077    }
1078
1079    #[test]
1080    fn derived_read_only_contains() {
1081        let state = State::new();
1082        state.set("derived:exists", true);
1083        assert!(state.derived().contains("exists"));
1084        assert!(!state.derived().contains("missing"));
1085    }
1086
1087    #[test]
1088    fn derived_read_only_keys() {
1089        let state = State::new();
1090        state.set("derived:a", 1);
1091        state.set("derived:b", 2);
1092        state.set("app:c", 3);
1093
1094        let derived_keys = state.derived().keys();
1095        assert_eq!(derived_keys.len(), 2);
1096        assert!(derived_keys.contains(&"a".to_string()));
1097        assert!(derived_keys.contains(&"b".to_string()));
1098    }
1099
1100    #[test]
1101    fn derived_missing_key_returns_none() {
1102        let state = State::new();
1103        assert_eq!(state.derived().get::<String>("nope"), None);
1104        assert_eq!(state.derived().get_raw("nope"), None);
1105    }
1106
1107    // ── snapshot_values tests ────────────────────────────────────────────
1108
1109    #[test]
1110    fn snapshot_values_captures_existing_keys() {
1111        let state = State::new();
1112        state.set("a", 1);
1113        state.set("b", "hello");
1114        state.set("c", true);
1115
1116        let snap = state.snapshot_values(&["a", "b", "missing"]);
1117        assert_eq!(snap.len(), 2);
1118        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
1119        assert_eq!(snap.get("b"), Some(&serde_json::json!("hello")));
1120        assert!(!snap.contains_key("missing"));
1121    }
1122
1123    #[test]
1124    fn snapshot_values_empty_keys() {
1125        let state = State::new();
1126        state.set("a", 1);
1127        let snap = state.snapshot_values(&[]);
1128        assert!(snap.is_empty());
1129    }
1130
1131    // ── diff_values tests ────────────────────────────────────────────────
1132
1133    #[test]
1134    fn diff_values_detects_changed_value() {
1135        let state = State::new();
1136        state.set("x", 1);
1137        let snap = state.snapshot_values(&["x"]);
1138
1139        state.set("x", 2);
1140        let diffs = state.diff_values(&snap, &["x"]);
1141        assert_eq!(diffs.len(), 1);
1142        assert_eq!(diffs[0].0, "x");
1143        assert_eq!(diffs[0].1, serde_json::json!(1));
1144        assert_eq!(diffs[0].2, serde_json::json!(2));
1145    }
1146
1147    #[test]
1148    fn diff_values_detects_new_key() {
1149        let state = State::new();
1150        let snap = state.snapshot_values(&["y"]);
1151
1152        state.set("y", "new");
1153        let diffs = state.diff_values(&snap, &["y"]);
1154        assert_eq!(diffs.len(), 1);
1155        assert_eq!(diffs[0].0, "y");
1156        assert_eq!(diffs[0].1, Value::Null);
1157        assert_eq!(diffs[0].2, serde_json::json!("new"));
1158    }
1159
1160    #[test]
1161    fn diff_values_detects_removed_key() {
1162        let state = State::new();
1163        state.set("z", 42);
1164        let snap = state.snapshot_values(&["z"]);
1165
1166        state.remove("z");
1167        let diffs = state.diff_values(&snap, &["z"]);
1168        assert_eq!(diffs.len(), 1);
1169        assert_eq!(diffs[0].0, "z");
1170        assert_eq!(diffs[0].1, serde_json::json!(42));
1171        assert_eq!(diffs[0].2, Value::Null);
1172    }
1173
1174    #[test]
1175    fn diff_values_no_change() {
1176        let state = State::new();
1177        state.set("stable", 10);
1178        let snap = state.snapshot_values(&["stable"]);
1179
1180        // No mutation
1181        let diffs = state.diff_values(&snap, &["stable"]);
1182        assert!(diffs.is_empty());
1183    }
1184
1185    #[test]
1186    fn diff_values_multiple_keys_mixed_changes() {
1187        let state = State::new();
1188        state.set("a", 1);
1189        state.set("b", 2);
1190        let snap = state.snapshot_values(&["a", "b", "c"]);
1191
1192        state.set("a", 10); // changed
1193                            // b unchanged
1194        state.set("c", 3); // new
1195
1196        let diffs = state.diff_values(&snap, &["a", "b", "c"]);
1197        assert_eq!(diffs.len(), 2); // a changed, c new; b unchanged
1198        let diff_keys: Vec<&str> = diffs.iter().map(|(k, _, _)| k.as_str()).collect();
1199        assert!(diff_keys.contains(&"a"));
1200        assert!(diff_keys.contains(&"c"));
1201    }
1202
1203    // ── clear_prefix tests ───────────────────────────────────────────────
1204
1205    #[test]
1206    fn clear_prefix_removes_matching_keys() {
1207        let state = State::new();
1208        state.set("turn:a", 1);
1209        state.set("turn:b", 2);
1210        state.set("app:c", 3);
1211        state.set("session:d", 4);
1212
1213        state.clear_prefix("turn:");
1214        assert!(!state.contains("turn:a"));
1215        assert!(!state.contains("turn:b"));
1216        assert!(state.contains("app:c"));
1217        assert!(state.contains("session:d"));
1218    }
1219
1220    #[test]
1221    fn clear_prefix_no_matching_keys_is_noop() {
1222        let state = State::new();
1223        state.set("app:x", 1);
1224        state.clear_prefix("turn:");
1225        assert!(state.contains("app:x"));
1226    }
1227
1228    #[test]
1229    fn clear_prefix_also_clears_delta() {
1230        let state = State::new();
1231        state.set("turn:committed", 1);
1232        let tracked = state.with_delta_tracking();
1233        tracked.set("turn:delta_val", 2);
1234
1235        // Both committed and delta have turn: keys
1236        assert!(tracked.contains("turn:committed"));
1237        assert!(tracked.contains("turn:delta_val"));
1238
1239        tracked.clear_prefix("turn:");
1240        assert!(!tracked.contains("turn:committed"));
1241        assert!(!tracked.contains("turn:delta_val"));
1242    }
1243
1244    #[test]
1245    fn clear_prefix_via_turn_accessor() {
1246        let state = State::new();
1247        state.turn().set("x", 1);
1248        state.turn().set("y", 2);
1249        state.app().set("z", 3);
1250
1251        state.clear_prefix("turn:");
1252        assert!(state.turn().keys().is_empty());
1253        assert!(state.app().contains("z"));
1254    }
1255
1256    // ── modify() tests ──────────────────────────────────────────────────
1257
1258    #[test]
1259    fn modify_increment_existing() {
1260        let state = State::new();
1261        state.set("count", 5u32);
1262        let result = state.modify("count", 0u32, |n| n + 1);
1263        assert_eq!(result, 6);
1264        assert_eq!(state.get::<u32>("count"), Some(6));
1265    }
1266
1267    #[test]
1268    fn modify_uses_default_when_missing() {
1269        let state = State::new();
1270        let result = state.modify("new_count", 0u32, |n| n + 1);
1271        assert_eq!(result, 1);
1272        assert_eq!(state.get::<u32>("new_count"), Some(1));
1273    }
1274
1275    #[test]
1276    fn modify_with_delta_tracking() {
1277        let state = State::new();
1278        state.set("x", 10u32);
1279        let tracked = state.with_delta_tracking();
1280        let result = tracked.modify("x", 0u32, |n| n * 2);
1281        assert_eq!(result, 20);
1282        // Written to delta, not committed
1283        assert_eq!(tracked.get::<u32>("x"), Some(20));
1284        assert_eq!(state.get::<u32>("x"), Some(10)); // original unchanged
1285    }
1286
1287    // ── derived fallback tests ──────────────────────────────────────────
1288
1289    #[test]
1290    fn get_falls_back_to_derived_prefix() {
1291        let state = State::new();
1292        state.set("derived:risk", 0.85);
1293        // Access without prefix — should find derived:risk
1294        assert_eq!(state.get::<f64>("risk"), Some(0.85));
1295    }
1296
1297    #[test]
1298    fn get_prefers_direct_key_over_derived() {
1299        let state = State::new();
1300        state.set("score", 1.0);
1301        state.set("derived:score", 0.5);
1302        // Direct key should win
1303        assert_eq!(state.get::<f64>("score"), Some(1.0));
1304    }
1305
1306    #[test]
1307    fn get_derived_fallback_skipped_for_prefixed_keys() {
1308        let state = State::new();
1309        state.set("derived:risk", 0.85);
1310        // Prefixed key should NOT trigger fallback
1311        assert_eq!(state.get::<f64>("app:risk"), None);
1312    }
1313
1314    #[test]
1315    fn get_derived_fallback_with_delta_tracking() {
1316        let state = State::new();
1317        let tracked = state.with_delta_tracking();
1318        tracked.set("derived:computed_val", 42);
1319        assert_eq!(tracked.get::<i32>("computed_val"), Some(42));
1320    }
1321
1322    // ── with() zero-copy borrow tests ──────────────────────────────────
1323
1324    #[test]
1325    fn with_reads_from_inner() {
1326        let state = State::new();
1327        state.set("name", "Alice");
1328        let len = state.with("name", |v| v.as_str().unwrap().len());
1329        assert_eq!(len, Some(5));
1330    }
1331
1332    #[test]
1333    fn with_reads_from_delta_first() {
1334        let state = State::new();
1335        state.set("x", 1);
1336        let tracked = state.with_delta_tracking();
1337        tracked.set("x", 99);
1338        let val = tracked.with("x", |v| v.as_i64().unwrap());
1339        assert_eq!(val, Some(99));
1340    }
1341
1342    #[test]
1343    fn with_falls_back_to_inner_when_not_in_delta() {
1344        let state = State::new();
1345        state.set("committed", "yes");
1346        let tracked = state.with_delta_tracking();
1347        let val = tracked.with("committed", |v| v.as_str().unwrap().to_string());
1348        assert_eq!(val, Some("yes".to_string()));
1349    }
1350
1351    #[test]
1352    fn with_falls_back_to_derived() {
1353        let state = State::new();
1354        state.set("derived:risk", 0.85);
1355        let val = state.with("risk", |v| v.as_f64().unwrap());
1356        assert_eq!(val, Some(0.85));
1357    }
1358
1359    #[test]
1360    fn with_derived_fallback_skipped_for_prefixed() {
1361        let state = State::new();
1362        state.set("derived:risk", 0.85);
1363        let val = state.with("app:risk", |v| v.as_f64().unwrap());
1364        assert_eq!(val, None);
1365    }
1366
1367    #[test]
1368    fn with_returns_none_for_missing() {
1369        let state = State::new();
1370        let val = state.with("missing", |v| v.clone());
1371        assert_eq!(val, None);
1372    }
1373
1374    #[test]
1375    fn with_on_prefixed_state() {
1376        let state = State::new();
1377        state.app().set("flag", true);
1378        let val = state.app().with("flag", |v| v.as_bool().unwrap());
1379        assert_eq!(val, Some(true));
1380    }
1381
1382    #[test]
1383    fn with_on_read_only_prefixed_state() {
1384        let state = State::new();
1385        state.set("derived:score", serde_json::json!(0.95));
1386        let val = state.derived().with("score", |v| v.as_f64().unwrap());
1387        assert_eq!(val, Some(0.95));
1388    }
1389
1390    // ── StateKey typed key tests ───────────────────────────────────────
1391
1392    const TURN_COUNT: StateKey<u32> = StateKey::new("session:turn_count");
1393    const NAME: StateKey<String> = StateKey::new("user:name");
1394
1395    #[test]
1396    fn state_key_get_and_set() {
1397        let state = State::new();
1398        state.set_key(&TURN_COUNT, 5);
1399        assert_eq!(state.get_key(&TURN_COUNT), Some(5));
1400    }
1401
1402    #[test]
1403    fn state_key_get_missing() {
1404        let state = State::new();
1405        assert_eq!(state.get_key(&TURN_COUNT), None);
1406    }
1407
1408    #[test]
1409    fn state_key_string_type() {
1410        let state = State::new();
1411        state.set_key(&NAME, "Alice".to_string());
1412        assert_eq!(state.get_key(&NAME), Some("Alice".to_string()));
1413    }
1414
1415    #[test]
1416    fn state_key_with() {
1417        let state = State::new();
1418        state.set_key(&TURN_COUNT, 42);
1419        let val = state.with_key(&TURN_COUNT, |v| v.as_u64().unwrap());
1420        assert_eq!(val, Some(42));
1421    }
1422
1423    #[test]
1424    fn state_key_interop_with_raw() {
1425        let state = State::new();
1426        state.set_key(&TURN_COUNT, 10);
1427        // Can also read via raw key
1428        assert_eq!(state.get::<u32>("session:turn_count"), Some(10));
1429    }
1430}