1use std::collections::{HashMap, VecDeque};
7use std::marker::PhantomData;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::SystemTime;
11
12use dashmap::DashMap;
13use serde_json::Value;
14
/// Maximum number of entries kept in a state's mutation journal; once full, the
/// oldest entry is evicted for every new one (1024 entries).
const DEFAULT_MUTATION_JOURNAL_CAPACITY: usize = 1 << 10;
16
/// A typed, `const`-constructible name for a state entry.
///
/// `T` is the value type stored under the key; it only exists at the type level
/// (`PhantomData<fn() -> T>`), so a `StateKey<T>` never owns a `T`.
pub struct StateKey<T> {
    key: &'static str,
    _phantom: PhantomData<fn() -> T>,
}

impl<T> StateKey<T> {
    /// Creates a key for the raw state name `key` (e.g. `"session:turn_count"`).
    pub const fn new(key: &'static str) -> Self {
        Self {
            key,
            _phantom: PhantomData,
        }
    }

    /// Returns the raw state name this key refers to.
    pub const fn key(&self) -> &'static str {
        self.key
    }
}

// Implemented by hand: `#[derive]` would require `T: Clone`/`T: Copy`/`T: Debug`,
// but the key holds no `T`, so it is always cheaply copyable and printable.
impl<T> Clone for StateKey<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for StateKey<T> {}

impl<T> std::fmt::Debug for StateKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateKey").field("key", &self.key).finish()
    }
}
47
/// Identifies which `State` operation produced a journaled mutation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateMutationOrigin {
    /// Written by `State::set` (also used by helpers built on it, such as `modify`,
    /// `rename`, `merge` and `pick`).
    Set,
    /// Written by `State::set_committed`, bypassing any delta overlay.
    SetCommitted,
    /// Deleted by `State::remove`.
    Remove,
    /// Deleted by `State::clear_prefix`.
    ClearPrefix,
    /// Moved from the delta overlay into committed storage by `State::commit`.
    Commit,
}
62
63#[derive(Debug, Clone, PartialEq)]
65pub struct StateMutation {
66 pub sequence: u64,
68 pub key: String,
70 pub old: Option<Value>,
72 pub new: Option<Value>,
74 pub origin: StateMutationOrigin,
76 pub timestamp: SystemTime,
78 pub delta: bool,
80}
81
82#[derive(Debug, Clone)]
88pub struct State {
89 inner: Arc<DashMap<String, Value>>,
90 delta: Arc<DashMap<String, Value>>,
91 mutations: Arc<std::sync::Mutex<VecDeque<StateMutation>>>,
92 next_mutation_sequence: Arc<AtomicU64>,
93 mutation_capacity: usize,
94 track_delta: bool,
95}
96
97impl Default for State {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl State {
104 pub fn new() -> Self {
106 Self {
107 inner: Arc::new(DashMap::new()),
108 delta: Arc::new(DashMap::new()),
109 mutations: Arc::new(std::sync::Mutex::new(VecDeque::new())),
110 next_mutation_sequence: Arc::new(AtomicU64::new(1)),
111 mutation_capacity: DEFAULT_MUTATION_JOURNAL_CAPACITY,
112 track_delta: false,
113 }
114 }
115
116 pub fn with_delta_tracking(&self) -> State {
119 State {
120 inner: self.inner.clone(),
121 delta: Arc::new(DashMap::new()),
122 mutations: self.mutations.clone(),
123 next_mutation_sequence: self.next_mutation_sequence.clone(),
124 mutation_capacity: self.mutation_capacity,
125 track_delta: true,
126 }
127 }
128
129 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
132 self.get_raw(key)
133 .and_then(|v| serde_json::from_value(v).ok())
134 }
135
136 pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
144 where
145 F: FnOnce(&Value) -> R,
146 {
147 if self.track_delta {
148 if let Some(ref_multi) = self.delta.get(key) {
149 return Some(f(ref_multi.value()));
150 }
151 }
152 if let Some(ref_multi) = self.inner.get(key) {
153 return Some(f(ref_multi.value()));
154 }
155 if !key.contains(':') {
156 let mut derived_key = String::with_capacity(8 + key.len());
157 use std::fmt::Write;
158 let _ = write!(derived_key, "derived:{}", key);
159 if self.track_delta {
160 if let Some(ref_multi) = self.delta.get(&derived_key) {
161 return Some(f(ref_multi.value()));
162 }
163 }
164 if let Some(ref_multi) = self.inner.get(&derived_key) {
165 return Some(f(ref_multi.value()));
166 }
167 }
168 None
169 }
170
171 pub fn get_raw(&self, key: &str) -> Option<Value> {
176 if self.track_delta {
177 if let Some(v) = self.delta.get(key) {
178 return Some(v.value().clone());
179 }
180 }
181 if let Some(v) = self.inner.get(key) {
182 return Some(v.value().clone());
183 }
184 if !key.contains(':') {
186 use std::fmt::Write;
187 let mut derived_key = String::with_capacity(8 + key.len());
188 let _ = write!(derived_key, "derived:{}", key);
189 if self.track_delta {
190 if let Some(v) = self.delta.get(&derived_key) {
191 return Some(v.value().clone());
192 }
193 }
194 return self.inner.get(&derived_key).map(|v| v.value().clone());
195 }
196 None
197 }
198
199 pub fn get_key<T: serde::de::DeserializeOwned>(&self, key: &StateKey<T>) -> Option<T> {
201 self.get(key.key())
202 }
203
204 pub fn set_key<T: serde::Serialize>(&self, key: &StateKey<T>, value: T) {
206 self.set(key.key(), value);
207 }
208
209 pub fn with_key<T, F, R>(&self, key: &StateKey<T>, f: F) -> Option<R>
211 where
212 F: FnOnce(&Value) -> R,
213 {
214 self.with(key.key(), f)
215 }
216
217 pub fn set(&self, key: impl Into<String>, value: impl serde::Serialize) {
220 let key = key.into();
221 let v = serde_json::to_value(value).expect("value must be serializable");
222 let old = self.get_raw(&key);
223 if self.track_delta {
224 self.delta.insert(key.clone(), v.clone());
225 self.record_mutation(key, old, Some(v), StateMutationOrigin::Set);
226 } else {
227 self.inner.insert(key.clone(), v.clone());
228 self.record_mutation(key, old, Some(v), StateMutationOrigin::Set);
229 }
230 }
231
232 pub fn set_committed(&self, key: impl Into<String>, value: impl serde::Serialize) {
234 let key = key.into();
235 let v = serde_json::to_value(value).expect("value must be serializable");
236 let old = self.inner.insert(key.clone(), v.clone());
237 self.record_mutation(key, old, Some(v), StateMutationOrigin::SetCommitted);
238 }
239
240 pub fn modify<T, F>(&self, key: &str, default: T, f: F) -> T
246 where
247 T: serde::Serialize + serde::de::DeserializeOwned,
248 F: FnOnce(T) -> T,
249 {
250 let current: T = self.get(key).unwrap_or(default);
252 let new_val = f(current);
253 self.set(key, &new_val);
254 new_val
255 }
256
257 pub fn contains(&self, key: &str) -> bool {
259 if self.track_delta && self.delta.contains_key(key) {
260 return true;
261 }
262 self.inner.contains_key(key)
263 }
264
265 pub fn remove(&self, key: &str) -> Option<Value> {
267 if self.track_delta {
268 let from_delta = self.delta.remove(key).map(|(_, v)| v);
270 let from_inner = self.inner.remove(key).map(|(_, v)| v);
271 let removed = from_delta.or(from_inner);
272 if let Some(ref old) = removed {
273 self.record_mutation(
274 key.to_string(),
275 Some(old.clone()),
276 None,
277 StateMutationOrigin::Remove,
278 );
279 }
280 removed
281 } else {
282 let removed = self.inner.remove(key).map(|(_, v)| v);
283 if let Some(ref old) = removed {
284 self.record_mutation(
285 key.to_string(),
286 Some(old.clone()),
287 None,
288 StateMutationOrigin::Remove,
289 );
290 }
291 removed
292 }
293 }
294
295 pub fn keys(&self) -> Vec<String> {
297 if !self.track_delta || self.delta.is_empty() {
298 return self.inner.iter().map(|r| r.key().clone()).collect();
299 }
300 let mut seen =
301 std::collections::HashSet::with_capacity(self.inner.len() + self.delta.len());
302 let mut keys = Vec::with_capacity(self.inner.len() + self.delta.len());
303 for entry in self.inner.iter() {
304 let key = entry.key().clone();
305 seen.insert(key.clone());
306 keys.push(key);
307 }
308 for entry in self.delta.iter() {
309 let key = entry.key().clone();
310 if seen.insert(key.clone()) {
311 keys.push(key);
312 }
313 }
314 keys
315 }
316
317 pub fn pick(&self, keys: &[&str]) -> State {
319 let new = State::new();
320 for key in keys {
321 if let Some(v) = self.get_raw(key) {
322 new.set(*key, v);
323 }
324 }
325 new
326 }
327
328 pub fn merge(&self, other: &State) {
330 for entry in other.inner.iter() {
331 self.set(entry.key().clone(), entry.value().clone());
332 }
333 }
334
335 pub fn rename(&self, from: &str, to: &str) {
337 if let Some(v) = self.remove(from) {
338 if self.track_delta {
339 self.set(to.to_string(), v);
340 } else {
341 self.set(to.to_string(), v);
342 }
343 }
344 }
345
346 pub fn is_tracking_delta(&self) -> bool {
350 self.track_delta
351 }
352
353 pub fn has_delta(&self) -> bool {
355 self.track_delta && !self.delta.is_empty()
356 }
357
358 pub fn delta(&self) -> HashMap<String, Value> {
360 self.delta
361 .iter()
362 .map(|entry| (entry.key().clone(), entry.value().clone()))
363 .collect()
364 }
365
366 pub fn commit(&self) {
368 for entry in self.delta.iter() {
369 let key = entry.key().clone();
370 let value = entry.value().clone();
371 let old = self.inner.insert(key.clone(), value.clone());
372 self.record_mutation_with_delta(
373 key,
374 old,
375 Some(value),
376 StateMutationOrigin::Commit,
377 false,
378 );
379 }
380 self.delta.clear();
381 }
382
383 pub fn rollback(&self) {
385 self.delta.clear();
386 }
387
388 pub fn app(&self) -> PrefixedState<'_> {
392 PrefixedState {
393 state: self,
394 prefix: "app:",
395 }
396 }
397
398 pub fn user(&self) -> PrefixedState<'_> {
400 PrefixedState {
401 state: self,
402 prefix: "user:",
403 }
404 }
405
406 pub fn temp(&self) -> PrefixedState<'_> {
408 PrefixedState {
409 state: self,
410 prefix: "temp:",
411 }
412 }
413
414 pub fn session(&self) -> PrefixedState<'_> {
416 PrefixedState {
417 state: self,
418 prefix: "session:",
419 }
420 }
421
422 pub fn turn(&self) -> PrefixedState<'_> {
424 PrefixedState {
425 state: self,
426 prefix: "turn:",
427 }
428 }
429
430 pub fn bg(&self) -> PrefixedState<'_> {
432 PrefixedState {
433 state: self,
434 prefix: "bg:",
435 }
436 }
437
438 pub fn derived(&self) -> ReadOnlyPrefixedState<'_> {
440 ReadOnlyPrefixedState {
441 state: self,
442 prefix: "derived:",
443 }
444 }
445
446 pub fn snapshot_values(&self, keys: &[&str]) -> HashMap<String, Value> {
451 keys.iter()
452 .filter_map(|&k| self.get_raw(k).map(|v| (k.to_string(), v)))
453 .collect()
454 }
455
456 pub fn diff_values(
459 &self,
460 prev: &HashMap<String, Value>,
461 keys: &[&str],
462 ) -> Vec<(String, Value, Value)> {
463 keys.iter()
464 .filter_map(|&k| {
465 let old = prev.get(k);
466 let new = self.get_raw(k);
467 match (old, new) {
468 (Some(o), Some(n)) if o != &n => Some((k.to_string(), o.clone(), n)),
469 (None, Some(n)) => Some((k.to_string(), Value::Null, n)),
470 (Some(o), None) => Some((k.to_string(), o.clone(), Value::Null)),
471 _ => None,
472 }
473 })
474 .collect()
475 }
476
477 pub fn to_hashmap(&self) -> std::collections::HashMap<String, serde_json::Value> {
479 self.inner
480 .iter()
481 .map(|entry| (entry.key().clone(), entry.value().clone()))
482 .collect()
483 }
484
485 pub fn from_hashmap(&self, map: std::collections::HashMap<String, serde_json::Value>) {
487 for (key, value) in map {
488 self.set_committed(key, value);
489 }
490 }
491
492 pub fn clear_prefix(&self, prefix: &str) {
494 let keys_to_remove: Vec<String> = self
495 .inner
496 .iter()
497 .filter(|entry| entry.key().starts_with(prefix))
498 .map(|entry| entry.key().clone())
499 .collect();
500 for key in keys_to_remove {
501 if let Some((_, old)) = self.inner.remove(&key) {
502 self.record_mutation(key, Some(old), None, StateMutationOrigin::ClearPrefix);
503 }
504 }
505 if self.track_delta {
506 let delta_keys: Vec<String> = self
507 .delta
508 .iter()
509 .filter(|entry| entry.key().starts_with(prefix))
510 .map(|entry| entry.key().clone())
511 .collect();
512 for key in delta_keys {
513 if let Some((_, old)) = self.delta.remove(&key) {
514 self.record_mutation(key, Some(old), None, StateMutationOrigin::ClearPrefix);
515 }
516 }
517 }
518 }
519
520 pub fn recent_mutations(&self) -> Vec<StateMutation> {
522 self.mutations
523 .lock()
524 .expect("state mutation journal poisoned")
525 .iter()
526 .cloned()
527 .collect()
528 }
529
530 pub fn mutation_cursor(&self) -> u64 {
532 self.next_mutation_sequence.load(Ordering::Relaxed) - 1
533 }
534
535 pub fn mutations_since(&self, cursor: u64) -> Vec<StateMutation> {
537 let mutations = self
538 .mutations
539 .lock()
540 .expect("state mutation journal poisoned");
541 mutations
542 .iter()
543 .filter(|mutation| mutation.sequence > cursor)
544 .cloned()
545 .collect()
546 }
547
548 pub fn drain_mutations(&self) -> Vec<StateMutation> {
550 self.mutations
551 .lock()
552 .expect("state mutation journal poisoned")
553 .drain(..)
554 .collect()
555 }
556
557 fn record_mutation(
558 &self,
559 key: String,
560 old: Option<Value>,
561 new: Option<Value>,
562 origin: StateMutationOrigin,
563 ) {
564 self.record_mutation_with_delta(key, old, new, origin, self.track_delta);
565 }
566
567 fn record_mutation_with_delta(
568 &self,
569 key: String,
570 old: Option<Value>,
571 new: Option<Value>,
572 origin: StateMutationOrigin,
573 delta: bool,
574 ) {
575 let mut mutations = self
576 .mutations
577 .lock()
578 .expect("state mutation journal poisoned");
579 if mutations.len() >= self.mutation_capacity {
580 mutations.pop_front();
581 }
582 let sequence = self.next_mutation_sequence.fetch_add(1, Ordering::Relaxed);
583 mutations.push_back(StateMutation {
584 sequence,
585 key,
586 old,
587 new,
588 origin,
589 timestamp: SystemTime::now(),
590 delta,
591 });
592 }
593}
594
595pub struct PrefixedState<'a> {
597 state: &'a State,
598 prefix: &'static str,
599}
600
601impl<'a> PrefixedState<'a> {
602 fn prefixed_key(&self, key: &str) -> String {
603 format!("{}{}", self.prefix, key)
604 }
605
606 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
608 self.state.get(&self.prefixed_key(key))
609 }
610
611 pub fn get_raw(&self, key: &str) -> Option<Value> {
613 self.state.get_raw(&self.prefixed_key(key))
614 }
615
616 pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
618 where
619 F: FnOnce(&Value) -> R,
620 {
621 self.state.with(&self.prefixed_key(key), f)
622 }
623
624 pub fn set(&self, key: impl AsRef<str>, value: impl serde::Serialize) {
626 self.state.set(self.prefixed_key(key.as_ref()), value);
627 }
628
629 pub fn contains(&self, key: &str) -> bool {
631 self.state.contains(&self.prefixed_key(key))
632 }
633
634 pub fn remove(&self, key: &str) -> Option<Value> {
636 self.state.remove(&self.prefixed_key(key))
637 }
638
639 pub fn keys(&self) -> Vec<String> {
641 self.state
642 .keys()
643 .into_iter()
644 .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
645 .collect()
646 }
647}
648
649pub struct ReadOnlyPrefixedState<'a> {
654 state: &'a State,
655 prefix: &'static str,
656}
657
658impl<'a> ReadOnlyPrefixedState<'a> {
659 fn prefixed_key(&self, key: &str) -> String {
660 format!("{}{}", self.prefix, key)
661 }
662
663 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
665 self.state.get(&self.prefixed_key(key))
666 }
667
668 pub fn get_raw(&self, key: &str) -> Option<Value> {
670 self.state.get_raw(&self.prefixed_key(key))
671 }
672
673 pub fn with<F, R>(&self, key: &str, f: F) -> Option<R>
675 where
676 F: FnOnce(&Value) -> R,
677 {
678 self.state.with(&self.prefixed_key(key), f)
679 }
680
681 pub fn contains(&self, key: &str) -> bool {
683 self.state.contains(&self.prefixed_key(key))
684 }
685
686 pub fn keys(&self) -> Vec<String> {
688 self.state
689 .keys()
690 .into_iter()
691 .filter_map(|k| k.strip_prefix(self.prefix).map(|s| s.to_string()))
692 .collect()
693 }
694}
695
#[cfg(test)]
mod tests {
    // Unit tests for `State`: basic reads/writes, the delta overlay, the
    // mutation journal, prefixed and read-only views, snapshots/diffs,
    // prefix clearing, `modify`, the `derived:` fallback, `with`, and `StateKey`.
    use super::*;

    // --- Basic get/set and key manipulation ---

    #[test]
    fn set_and_get_string() {
        let state = State::new();
        state.set("name", "Alice");
        assert_eq!(state.get::<String>("name"), Some("Alice".to_string()));
    }

    #[test]
    fn set_and_get_json() {
        let state = State::new();
        state.set("data", serde_json::json!({"temp": 22}));
        let v: Value = state.get("data").unwrap();
        assert_eq!(v["temp"], 22);
    }

    #[test]
    fn pick_subset() {
        let state = State::new();
        state.set("a", 1);
        state.set("b", 2);
        state.set("c", 3);
        let picked = state.pick(&["a", "c"]);
        assert!(picked.contains("a"));
        assert!(!picked.contains("b"));
        assert!(picked.contains("c"));
    }

    #[test]
    fn merge_states() {
        let s1 = State::new();
        s1.set("a", 1);
        let s2 = State::new();
        s2.set("b", 2);
        s1.merge(&s2);
        assert!(s1.contains("a"));
        assert!(s1.contains("b"));
    }

    #[test]
    fn rename_key() {
        let state = State::new();
        state.set("old", "value");
        state.rename("old", "new");
        assert!(!state.contains("old"));
        assert_eq!(state.get::<String>("new"), Some("value".to_string()));
    }

    #[test]
    fn remove_returns_value() {
        let state = State::new();
        state.set("key", 42);
        let removed = state.remove("key");
        assert!(removed.is_some());
        assert!(!state.contains("key"));
    }

    #[test]
    fn get_missing_returns_none() {
        let state = State::new();
        assert_eq!(state.get::<String>("nope"), None);
    }

    // --- Delta tracking: pending overlay, commit, rollback ---

    #[test]
    fn delta_tracking_writes_to_delta() {
        let state = State::new();
        state.set("committed", "yes");

        let tracked = state.with_delta_tracking();
        tracked.set("new_key", "new_value");

        // The tracked view sees its own pending write...
        assert_eq!(
            tracked.get::<String>("new_key"),
            Some("new_value".to_string())
        );
        // ...but the committed state does not.
        assert!(!state.contains("new_key"));
        // Committed values stay visible through the tracked view.
        assert_eq!(tracked.get::<String>("committed"), Some("yes".to_string()));
    }

    #[test]
    fn delta_has_delta_reports_correctly() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        assert!(!tracked.has_delta());

        tracked.set("key", "val");
        assert!(tracked.has_delta());
    }

    #[test]
    fn delta_commit_merges_to_inner() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("key", "val");
        assert!(!state.contains("key"));

        tracked.commit();
        // After commit the value is visible in committed state and the overlay is empty.
        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
        assert!(!tracked.has_delta());
    }

    #[test]
    fn delta_rollback_discards_changes() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("key", "val");
        assert!(tracked.has_delta());

        tracked.rollback();
        assert!(!tracked.has_delta());
        assert!(!state.contains("key"));
        assert!(!tracked.contains("key"));
    }

    #[test]
    fn delta_snapshot() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("a", 1);
        tracked.set("b", 2);

        let snapshot = tracked.delta();
        assert_eq!(snapshot.len(), 2);
        assert!(snapshot.contains_key("a"));
        assert!(snapshot.contains_key("b"));
    }

    #[test]
    fn set_committed_bypasses_delta() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set_committed("direct", "value");

        // Written straight to committed storage...
        assert_eq!(state.get::<String>("direct"), Some("value".to_string()));
        // ...so nothing is pending in the overlay...
        assert!(!tracked.has_delta());
        // ...and the tracked view still reads it.
        assert_eq!(tracked.get::<String>("direct"), Some("value".to_string()));
    }

    // --- Mutation journal ---

    #[test]
    fn mutation_journal_records_set_and_remove() {
        let state = State::new();
        state.set("key", "first");
        state.set("key", "second");
        state.remove("key");

        let mutations = state.recent_mutations();
        assert_eq!(mutations.len(), 3);
        assert_eq!(mutations[0].key, "key");
        assert_eq!(mutations[0].old, None);
        assert_eq!(mutations[0].new, Some(serde_json::json!("first")));
        assert_eq!(mutations[0].origin, StateMutationOrigin::Set);

        assert_eq!(mutations[1].old, Some(serde_json::json!("first")));
        assert_eq!(mutations[1].new, Some(serde_json::json!("second")));

        assert_eq!(mutations[2].old, Some(serde_json::json!("second")));
        assert_eq!(mutations[2].new, None);
        assert_eq!(mutations[2].origin, StateMutationOrigin::Remove);
    }

    #[test]
    fn mutation_journal_is_shared_with_delta_tracking() {
        let state = State::new();
        state.set("committed", "yes");

        let tracked = state.with_delta_tracking();
        tracked.set("committed", "maybe");
        tracked.commit();

        // The tracked view's set and commit land in the original state's journal.
        let mutations = state.recent_mutations();
        assert_eq!(mutations.len(), 3);
        assert_eq!(mutations[1].key, "committed");
        assert_eq!(mutations[1].old, Some(serde_json::json!("yes")));
        assert_eq!(mutations[1].new, Some(serde_json::json!("maybe")));
        assert_eq!(mutations[1].origin, StateMutationOrigin::Set);
        assert!(mutations[1].delta);

        assert_eq!(mutations[2].origin, StateMutationOrigin::Commit);
        assert!(!mutations[2].delta);
    }

    #[test]
    fn drain_mutations_clears_journal() {
        let state = State::new();
        state.set("a", 1);
        state.set("b", 2);

        let drained = state.drain_mutations();
        assert_eq!(drained.len(), 2);
        assert!(state.recent_mutations().is_empty());
    }

    #[test]
    fn mutation_cursor_reads_only_later_changes() {
        let state = State::new();
        state.set("before", 1);
        let cursor = state.mutation_cursor();

        state.set("after", 2);
        state.remove("before");

        let mutations = state.mutations_since(cursor);
        assert_eq!(mutations.len(), 2);
        assert_eq!(mutations[0].key, "after");
        assert_eq!(mutations[1].key, "before");
    }

    #[test]
    fn no_delta_tracking_preserves_existing_behavior() {
        let state = State::new();
        assert!(!state.is_tracking_delta());
        state.set("key", "val");
        assert_eq!(state.get::<String>("key"), Some("val".to_string()));
        assert!(!state.has_delta());
    }

    // --- Prefixed views (app/user/temp/session/turn/bg) ---

    #[test]
    fn prefix_app_set_and_get() {
        let state = State::new();
        state.app().set("flag", true);

        assert_eq!(state.app().get::<bool>("flag"), Some(true));
        // The view stores under the fully-qualified key.
        assert_eq!(state.get::<bool>("app:flag"), Some(true));
    }

    #[test]
    fn prefix_user_set_and_get() {
        let state = State::new();
        state.user().set("name", "Alice");
        assert_eq!(
            state.user().get::<String>("name"),
            Some("Alice".to_string())
        );
        assert_eq!(state.get::<String>("user:name"), Some("Alice".to_string()));
    }

    #[test]
    fn prefix_temp_set_and_get() {
        let state = State::new();
        state.temp().set("scratch", 42);
        assert_eq!(state.temp().get::<i32>("scratch"), Some(42));
    }

    #[test]
    fn prefix_contains_and_remove() {
        let state = State::new();
        state.app().set("x", 1);
        assert!(state.app().contains("x"));
        state.app().remove("x");
        assert!(!state.app().contains("x"));
    }

    #[test]
    fn prefix_keys() {
        let state = State::new();
        state.app().set("a", 1);
        state.app().set("b", 2);
        state.user().set("c", 3);

        let app_keys = state.app().keys();
        assert_eq!(app_keys.len(), 2);
        assert!(app_keys.contains(&"a".to_string()));
        assert!(app_keys.contains(&"b".to_string()));

        let user_keys = state.user().keys();
        assert_eq!(user_keys.len(), 1);
        assert!(user_keys.contains(&"c".to_string()));
    }

    #[test]
    fn prefix_with_delta_tracking() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.app().set("flag", true);

        assert_eq!(tracked.app().get::<bool>("flag"), Some(true));
        // The prefixed write went into the overlay, not committed state.
        assert!(tracked.has_delta());
        assert!(!state.contains("app:flag"));

        tracked.commit();
        assert_eq!(state.get::<bool>("app:flag"), Some(true));
    }

    #[test]
    fn prefix_session_set_and_get() {
        let state = State::new();
        state.session().set("turn_count", 5);
        assert_eq!(state.session().get::<i32>("turn_count"), Some(5));
        assert_eq!(state.get::<i32>("session:turn_count"), Some(5));
    }

    #[test]
    fn prefix_turn_set_and_get() {
        let state = State::new();
        state.turn().set("transcript", "hello");
        assert_eq!(
            state.turn().get::<String>("transcript"),
            Some("hello".to_string())
        );
        assert_eq!(
            state.get::<String>("turn:transcript"),
            Some("hello".to_string())
        );
    }

    #[test]
    fn prefix_bg_set_and_get() {
        let state = State::new();
        state.bg().set("task_id", "abc-123");
        assert_eq!(
            state.bg().get::<String>("task_id"),
            Some("abc-123".to_string())
        );
        assert_eq!(
            state.get::<String>("bg:task_id"),
            Some("abc-123".to_string())
        );
    }

    #[test]
    fn prefix_session_contains_and_remove() {
        let state = State::new();
        state.session().set("x", 1);
        assert!(state.session().contains("x"));
        state.session().remove("x");
        assert!(!state.session().contains("x"));
    }

    #[test]
    fn prefix_turn_keys() {
        let state = State::new();
        state.turn().set("a", 1);
        state.turn().set("b", 2);
        state.session().set("c", 3);

        let turn_keys = state.turn().keys();
        assert_eq!(turn_keys.len(), 2);
        assert!(turn_keys.contains(&"a".to_string()));
        assert!(turn_keys.contains(&"b".to_string()));
    }

    // --- Read-only `derived:` view ---

    #[test]
    fn derived_read_only_get() {
        let state = State::new();
        state.set("derived:sentiment", "positive");
        assert_eq!(
            state.derived().get::<String>("sentiment"),
            Some("positive".to_string())
        );
    }

    #[test]
    fn derived_read_only_get_raw() {
        let state = State::new();
        state.set("derived:score", serde_json::json!(0.95));
        let raw = state.derived().get_raw("score");
        assert!(raw.is_some());
        assert_eq!(raw.unwrap(), serde_json::json!(0.95));
    }

    #[test]
    fn derived_read_only_contains() {
        let state = State::new();
        state.set("derived:exists", true);
        assert!(state.derived().contains("exists"));
        assert!(!state.derived().contains("missing"));
    }

    #[test]
    fn derived_read_only_keys() {
        let state = State::new();
        state.set("derived:a", 1);
        state.set("derived:b", 2);
        state.set("app:c", 3);

        let derived_keys = state.derived().keys();
        assert_eq!(derived_keys.len(), 2);
        assert!(derived_keys.contains(&"a".to_string()));
        assert!(derived_keys.contains(&"b".to_string()));
    }

    #[test]
    fn derived_missing_key_returns_none() {
        let state = State::new();
        assert_eq!(state.derived().get::<String>("nope"), None);
        assert_eq!(state.derived().get_raw("nope"), None);
    }

    // --- snapshot_values ---

    #[test]
    fn snapshot_values_captures_existing_keys() {
        let state = State::new();
        state.set("a", 1);
        state.set("b", "hello");
        state.set("c", true);

        let snap = state.snapshot_values(&["a", "b", "missing"]);
        assert_eq!(snap.len(), 2);
        assert_eq!(snap.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(snap.get("b"), Some(&serde_json::json!("hello")));
        assert!(!snap.contains_key("missing"));
    }

    #[test]
    fn snapshot_values_empty_keys() {
        let state = State::new();
        state.set("a", 1);
        let snap = state.snapshot_values(&[]);
        assert!(snap.is_empty());
    }

    // --- diff_values ---

    #[test]
    fn diff_values_detects_changed_value() {
        let state = State::new();
        state.set("x", 1);
        let snap = state.snapshot_values(&["x"]);

        state.set("x", 2);
        let diffs = state.diff_values(&snap, &["x"]);
        assert_eq!(diffs.len(), 1);
        assert_eq!(diffs[0].0, "x");
        assert_eq!(diffs[0].1, serde_json::json!(1));
        assert_eq!(diffs[0].2, serde_json::json!(2));
    }

    #[test]
    fn diff_values_detects_new_key() {
        let state = State::new();
        let snap = state.snapshot_values(&["y"]);

        state.set("y", "new");
        let diffs = state.diff_values(&snap, &["y"]);
        assert_eq!(diffs.len(), 1);
        assert_eq!(diffs[0].0, "y");
        assert_eq!(diffs[0].1, Value::Null);
        assert_eq!(diffs[0].2, serde_json::json!("new"));
    }

    #[test]
    fn diff_values_detects_removed_key() {
        let state = State::new();
        state.set("z", 42);
        let snap = state.snapshot_values(&["z"]);

        state.remove("z");
        let diffs = state.diff_values(&snap, &["z"]);
        assert_eq!(diffs.len(), 1);
        assert_eq!(diffs[0].0, "z");
        assert_eq!(diffs[0].1, serde_json::json!(42));
        assert_eq!(diffs[0].2, Value::Null);
    }

    #[test]
    fn diff_values_no_change() {
        let state = State::new();
        state.set("stable", 10);
        let snap = state.snapshot_values(&["stable"]);

        let diffs = state.diff_values(&snap, &["stable"]);
        // Unchanged keys produce no diff entries.
        assert!(diffs.is_empty());
    }

    #[test]
    fn diff_values_multiple_keys_mixed_changes() {
        let state = State::new();
        state.set("a", 1);
        state.set("b", 2);
        let snap = state.snapshot_values(&["a", "b", "c"]);

        // "a" changes, "b" stays the same, "c" is newly added.
        state.set("a", 10);
        state.set("c", 3);
        let diffs = state.diff_values(&snap, &["a", "b", "c"]);
        assert_eq!(diffs.len(), 2);
        let diff_keys: Vec<&str> = diffs.iter().map(|(k, _, _)| k.as_str()).collect();
        assert!(diff_keys.contains(&"a"));
        assert!(diff_keys.contains(&"c"));
    }

    // --- clear_prefix ---

    #[test]
    fn clear_prefix_removes_matching_keys() {
        let state = State::new();
        state.set("turn:a", 1);
        state.set("turn:b", 2);
        state.set("app:c", 3);
        state.set("session:d", 4);

        state.clear_prefix("turn:");
        assert!(!state.contains("turn:a"));
        assert!(!state.contains("turn:b"));
        assert!(state.contains("app:c"));
        assert!(state.contains("session:d"));
    }

    #[test]
    fn clear_prefix_no_matching_keys_is_noop() {
        let state = State::new();
        state.set("app:x", 1);
        state.clear_prefix("turn:");
        assert!(state.contains("app:x"));
    }

    #[test]
    fn clear_prefix_also_clears_delta() {
        let state = State::new();
        state.set("turn:committed", 1);
        let tracked = state.with_delta_tracking();
        tracked.set("turn:delta_val", 2);

        // One key lives in committed storage, the other only in the overlay.
        assert!(tracked.contains("turn:committed"));
        assert!(tracked.contains("turn:delta_val"));

        tracked.clear_prefix("turn:");
        assert!(!tracked.contains("turn:committed"));
        assert!(!tracked.contains("turn:delta_val"));
    }

    #[test]
    fn clear_prefix_via_turn_accessor() {
        let state = State::new();
        state.turn().set("x", 1);
        state.turn().set("y", 2);
        state.app().set("z", 3);

        state.clear_prefix("turn:");
        assert!(state.turn().keys().is_empty());
        assert!(state.app().contains("z"));
    }

    // --- modify ---

    #[test]
    fn modify_increment_existing() {
        let state = State::new();
        state.set("count", 5u32);
        let result = state.modify("count", 0u32, |n| n + 1);
        assert_eq!(result, 6);
        assert_eq!(state.get::<u32>("count"), Some(6));
    }

    #[test]
    fn modify_uses_default_when_missing() {
        let state = State::new();
        let result = state.modify("new_count", 0u32, |n| n + 1);
        assert_eq!(result, 1);
        assert_eq!(state.get::<u32>("new_count"), Some(1));
    }

    #[test]
    fn modify_with_delta_tracking() {
        let state = State::new();
        state.set("x", 10u32);
        let tracked = state.with_delta_tracking();
        let result = tracked.modify("x", 0u32, |n| n * 2);
        assert_eq!(result, 20);
        assert_eq!(tracked.get::<u32>("x"), Some(20));
        // The committed value is untouched until commit.
        assert_eq!(state.get::<u32>("x"), Some(10));
    }

    // --- `derived:` read fallback for un-namespaced keys ---

    #[test]
    fn get_falls_back_to_derived_prefix() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        assert_eq!(state.get::<f64>("risk"), Some(0.85));
    }

    #[test]
    fn get_prefers_direct_key_over_derived() {
        let state = State::new();
        state.set("score", 1.0);
        state.set("derived:score", 0.5);
        assert_eq!(state.get::<f64>("score"), Some(1.0));
    }

    #[test]
    fn get_derived_fallback_skipped_for_prefixed_keys() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        // Keys containing ':' never fall back to `derived:`.
        assert_eq!(state.get::<f64>("app:risk"), None);
    }

    #[test]
    fn get_derived_fallback_with_delta_tracking() {
        let state = State::new();
        let tracked = state.with_delta_tracking();
        tracked.set("derived:computed_val", 42);
        assert_eq!(tracked.get::<i32>("computed_val"), Some(42));
    }

    // --- Borrowing reads via `with` ---

    #[test]
    fn with_reads_from_inner() {
        let state = State::new();
        state.set("name", "Alice");
        let len = state.with("name", |v| v.as_str().unwrap().len());
        assert_eq!(len, Some(5));
    }

    #[test]
    fn with_reads_from_delta_first() {
        let state = State::new();
        state.set("x", 1);
        let tracked = state.with_delta_tracking();
        tracked.set("x", 99);
        let val = tracked.with("x", |v| v.as_i64().unwrap());
        assert_eq!(val, Some(99));
    }

    #[test]
    fn with_falls_back_to_inner_when_not_in_delta() {
        let state = State::new();
        state.set("committed", "yes");
        let tracked = state.with_delta_tracking();
        let val = tracked.with("committed", |v| v.as_str().unwrap().to_string());
        assert_eq!(val, Some("yes".to_string()));
    }

    #[test]
    fn with_falls_back_to_derived() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        let val = state.with("risk", |v| v.as_f64().unwrap());
        assert_eq!(val, Some(0.85));
    }

    #[test]
    fn with_derived_fallback_skipped_for_prefixed() {
        let state = State::new();
        state.set("derived:risk", 0.85);
        let val = state.with("app:risk", |v| v.as_f64().unwrap());
        assert_eq!(val, None);
    }

    #[test]
    fn with_returns_none_for_missing() {
        let state = State::new();
        let val = state.with("missing", |v| v.clone());
        assert_eq!(val, None);
    }

    #[test]
    fn with_on_prefixed_state() {
        let state = State::new();
        state.app().set("flag", true);
        let val = state.app().with("flag", |v| v.as_bool().unwrap());
        assert_eq!(val, Some(true));
    }

    #[test]
    fn with_on_read_only_prefixed_state() {
        let state = State::new();
        state.set("derived:score", serde_json::json!(0.95));
        let val = state.derived().with("score", |v| v.as_f64().unwrap());
        assert_eq!(val, Some(0.95));
    }

    // --- Typed access through `StateKey` ---

    const TURN_COUNT: StateKey<u32> = StateKey::new("session:turn_count");
    const NAME: StateKey<String> = StateKey::new("user:name");

    #[test]
    fn state_key_get_and_set() {
        let state = State::new();
        state.set_key(&TURN_COUNT, 5);
        assert_eq!(state.get_key(&TURN_COUNT), Some(5));
    }

    #[test]
    fn state_key_get_missing() {
        let state = State::new();
        assert_eq!(state.get_key(&TURN_COUNT), None);
    }

    #[test]
    fn state_key_string_type() {
        let state = State::new();
        state.set_key(&NAME, "Alice".to_string());
        assert_eq!(state.get_key(&NAME), Some("Alice".to_string()));
    }

    #[test]
    fn state_key_with() {
        let state = State::new();
        state.set_key(&TURN_COUNT, 42);
        let val = state.with_key(&TURN_COUNT, |v| v.as_u64().unwrap());
        assert_eq!(val, Some(42));
    }

    #[test]
    fn state_key_interop_with_raw() {
        let state = State::new();
        state.set_key(&TURN_COUNT, 10);
        // A typed key is just a raw key string underneath.
        assert_eq!(state.get::<u32>("session:turn_count"), Some(10));
    }
}