1use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
13use serde_json::Value;
14
15use super::BoxFuture;
16use crate::state::{State, StateMutation};
17
18use super::contract::WatcherContract;
19
20pub type PredicateFn = Arc<dyn Fn(&Value, &Value) -> bool + Send + Sync>;
24
/// Condition under which a [`Watcher`] fires for an `(old, new)` transition of its key.
///
/// Evaluated by `WatchPredicate::matches`.
pub enum WatchPredicate {
    /// Fires for every diff evaluated against the key. It does not itself compare
    /// `old` and `new`.
    Changed,
    /// Fires when the new value equals the given value.
    ChangedTo(Value),
    /// Fires when the old value equals the given value.
    ChangedFrom(Value),
    /// Fires when both values are JSON numbers and the value moves from strictly below
    /// the threshold to at or above it (`old < t && new >= t`).
    /// Non-numeric values never fire.
    CrossedAbove(f64),
    /// Fires when both values are JSON numbers and the value moves from at or above
    /// the threshold to strictly below it (`old >= t && new < t`).
    /// Non-numeric values never fire.
    CrossedBelow(f64),
    /// Fires when the old value is anything other than JSON `true` and the new value
    /// is JSON `true`.
    BecameTrue,
    /// Fires when the old value is JSON `true` and the new value is anything else.
    /// The new value need not be `false`; e.g. `null` also counts.
    BecameFalse,
    /// Fires when the closure returns `true` for `(old, new)`.
    Custom(PredicateFn),
}
44
45impl std::fmt::Debug for WatchPredicate {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match self {
48 Self::Changed => write!(f, "Changed"),
49 Self::ChangedTo(v) => write!(f, "ChangedTo({v})"),
50 Self::ChangedFrom(v) => write!(f, "ChangedFrom({v})"),
51 Self::CrossedAbove(t) => write!(f, "CrossedAbove({t})"),
52 Self::CrossedBelow(t) => write!(f, "CrossedBelow({t})"),
53 Self::BecameTrue => write!(f, "BecameTrue"),
54 Self::BecameFalse => write!(f, "BecameFalse"),
55 Self::Custom(_) => write!(f, "Custom(<fn>)"),
56 }
57 }
58}
59
60impl WatchPredicate {
61 fn matches(&self, old: &Value, new: &Value) -> bool {
63 match self {
64 WatchPredicate::Changed => true,
65 WatchPredicate::ChangedTo(val) => new == val,
66 WatchPredicate::ChangedFrom(val) => old == val,
67 WatchPredicate::CrossedAbove(threshold) => match (as_f64(old), as_f64(new)) {
68 (Some(o), Some(n)) => o < *threshold && n >= *threshold,
69 _ => false,
70 },
71 WatchPredicate::CrossedBelow(threshold) => match (as_f64(old), as_f64(new)) {
72 (Some(o), Some(n)) => o >= *threshold && n < *threshold,
73 _ => false,
74 },
75 WatchPredicate::BecameTrue => old != &Value::Bool(true) && new == &Value::Bool(true),
76 WatchPredicate::BecameFalse => old == &Value::Bool(true) && new != &Value::Bool(true),
77 WatchPredicate::Custom(f) => f(old, new),
78 }
79 }
80}
81
/// A registered reaction to changes of one state key.
pub struct Watcher {
    /// State key this watcher observes. `WatcherRegistry::evaluate` only considers
    /// diffs whose key equals this exactly.
    pub key: String,
    /// Condition that a `(old, new)` transition must satisfy for the action to run.
    pub predicate: WatchPredicate,
    /// Invoked with clones of the old value, new value and state when the predicate
    /// matches. It returns a future that the registry does not await itself.
    pub action: Arc<dyn Fn(Value, Value, State) -> BoxFuture<()> + Send + Sync>,
    /// When `true`, the action's future is returned in the "blocking" list of
    /// `WatcherRegistry::evaluate`; otherwise it goes in the "concurrent" list.
    /// NOTE(review): the caller presumably awaits blocking futures before moving
    /// on — confirm against the call site.
    pub blocking: bool,
}
97
98impl std::fmt::Debug for Watcher {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("Watcher")
101 .field("key", &self.key)
102 .field("predicate", &self.predicate)
103 .field("blocking", &self.blocking)
104 .finish_non_exhaustive()
105 }
106}
107
/// Collection of [`Watcher`]s, evaluated against state diffs to produce action futures.
pub struct WatcherRegistry {
    /// All registered watchers, in registration order.
    watchers: Vec<Watcher>,
    /// Distinct keys observed by at least one watcher. `evaluate_mutations` uses it
    /// to skip mutations nobody watches.
    observed_keys: HashSet<String>,
}
116
117impl Default for WatcherRegistry {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl WatcherRegistry {
124 pub fn new() -> Self {
126 Self {
127 watchers: Vec::new(),
128 observed_keys: HashSet::new(),
129 }
130 }
131
132 pub fn add(&mut self, watcher: Watcher) {
134 self.observed_keys.insert(watcher.key.clone());
135 self.watchers.push(watcher);
136 }
137
138 pub fn observed_keys(&self) -> &HashSet<String> {
143 &self.observed_keys
144 }
145
146 pub fn describe(&self) -> Vec<WatcherContract> {
148 self.watchers
149 .iter()
150 .map(|watcher| WatcherContract {
151 key: watcher.key.clone(),
152 predicate: format!("{:?}", watcher.predicate),
153 blocking: watcher.blocking,
154 })
155 .collect()
156 }
157
158 pub fn evaluate(
166 &self,
167 diffs: &[(String, Value, Value)],
168 state: &State,
169 ) -> (Vec<BoxFuture<()>>, Vec<BoxFuture<()>>) {
170 let mut blocking = Vec::new();
171 let mut concurrent = Vec::new();
172
173 for (key, old, new) in diffs {
174 for watcher in &self.watchers {
175 if watcher.key == *key && watcher.predicate.matches(old, new) {
176 let fut = (watcher.action)(old.clone(), new.clone(), state.clone());
177 if watcher.blocking {
178 blocking.push(fut);
179 } else {
180 concurrent.push(fut);
181 }
182 }
183 }
184 }
185
186 (blocking, concurrent)
187 }
188
189 pub fn evaluate_mutations(
195 &self,
196 mutations: &[StateMutation],
197 state: &State,
198 ) -> (Vec<BoxFuture<()>>, Vec<BoxFuture<()>>) {
199 let mut net: HashMap<String, (Option<Value>, Option<Value>)> = HashMap::new();
200
201 for mutation in mutations {
202 if !self.observed_keys.contains(&mutation.key) {
203 continue;
204 }
205
206 net.entry(mutation.key.clone())
207 .and_modify(|(_, new)| {
208 *new = mutation.new.clone();
209 })
210 .or_insert_with(|| (mutation.old.clone(), mutation.new.clone()));
211 }
212
213 let diffs: Vec<(String, Value, Value)> = net
214 .into_iter()
215 .filter_map(|(key, (old, new))| {
216 if old == new {
217 None
218 } else {
219 Some((key, old.unwrap_or(Value::Null), new.unwrap_or(Value::Null)))
220 }
221 })
222 .collect();
223
224 self.evaluate(&diffs, state)
225 }
226}
227
228fn as_f64(v: &Value) -> Option<f64> {
232 match v {
233 Value::Number(n) => n.as_f64(),
234 _ => None,
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a watcher whose action increments `counter` each time its future runs.
    fn counting_watcher(
        key: &str,
        predicate: WatchPredicate,
        counter: Arc<AtomicU32>,
        blocking: bool,
    ) -> Watcher {
        Watcher {
            key: key.to_string(),
            predicate,
            action: Arc::new(move |_old, _new, _state| {
                let c = counter.clone();
                Box::pin(async move {
                    c.fetch_add(1, Ordering::SeqCst);
                })
            }),
            blocking,
        }
    }

    /// Builds a watcher whose action stores the `old` / `new` values it received in
    /// state under `recorded_old` / `recorded_new`.
    fn recording_watcher(key: &str, predicate: WatchPredicate, blocking: bool) -> Watcher {
        Watcher {
            key: key.to_string(),
            predicate,
            action: Arc::new(|old, new, state| {
                Box::pin(async move {
                    let _ = state.set("recorded_old", old);
                    let _ = state.set("recorded_new", new);
                })
            }),
            blocking,
        }
    }

    #[tokio::test]
    async fn changed_fires_on_any_diff() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(2))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn changed_to_fires_when_new_value_matches() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "status",
            WatchPredicate::ChangedTo(json!("active")),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("status".to_string(), json!("inactive"), json!("active"))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn changed_to_does_not_fire_when_new_value_differs() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "status",
            WatchPredicate::ChangedTo(json!("active")),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("status".to_string(), json!("inactive"), json!("pending"))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn changed_from_fires_when_old_value_matches() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "mode",
            WatchPredicate::ChangedFrom(json!("draft")),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("mode".to_string(), json!("draft"), json!("published"))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let diffs2 = vec![("mode".to_string(), json!("published"), json!("archived"))];
        let (b, c) = registry.evaluate(&diffs2, &state);
        assert!(b.is_empty());
        assert!(c.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn crossed_above_fires_on_upward_crossing() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "temp",
            WatchPredicate::CrossedAbove(100.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("temp".to_string(), json!(95.0), json!(105.0))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn crossed_above_does_not_fire_when_both_above() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "temp",
            WatchPredicate::CrossedAbove(100.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("temp".to_string(), json!(110.0), json!(120.0))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn crossed_below_fires_on_downward_crossing() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "battery",
            WatchPredicate::CrossedBelow(20.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("battery".to_string(), json!(25.0), json!(15.0))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn crossed_below_does_not_fire_when_both_below() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "battery",
            WatchPredicate::CrossedBelow(20.0),
            counter,
            false,
        ));

        let state = State::new();
        let diffs = vec![("battery".to_string(), json!(15.0), json!(10.0))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[tokio::test]
    async fn became_true_fires_on_false_to_true() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameTrue,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(false), json!(true))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn became_false_fires_on_true_to_false() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameFalse,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(true), json!(false))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn became_false_fires_when_leaving_true_for_non_bool() {
        // BecameFalse only requires leaving `true`; the new value need not be `false`.
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameFalse,
            counter,
            false,
        ));

        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(true), json!(null))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert_eq!(concurrent.len(), 1);
    }

    #[tokio::test]
    async fn custom_predicate_fires_when_fn_returns_true() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "score",
            WatchPredicate::Custom(Arc::new(|old, new| {
                // Fires only when the score exactly doubled.
                match (as_f64(old), as_f64(new)) {
                    (Some(o), Some(n)) => (n - o * 2.0).abs() < f64::EPSILON,
                    _ => false,
                }
            })),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("score".to_string(), json!(5.0), json!(10.0))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let diffs2 = vec![("score".to_string(), json!(5.0), json!(11.0))];
        let (b, c) = registry.evaluate(&diffs2, &state);
        assert!(b.is_empty());
        assert!(c.is_empty());
    }

    #[tokio::test]
    async fn evaluate_separates_blocking_and_concurrent() {
        let blocking_counter = Arc::new(AtomicU32::new(0));
        let concurrent_counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();

        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            blocking_counter.clone(),
            true,
        ));

        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            concurrent_counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(2))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(blocking.len(), 1);
        assert_eq!(concurrent.len(), 1);

        for fut in blocking {
            fut.await;
        }
        for fut in concurrent {
            fut.await;
        }

        assert_eq!(blocking_counter.load(Ordering::SeqCst), 1);
        assert_eq!(concurrent_counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn evaluate_with_no_matching_diffs_returns_empty() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("y".to_string(), json!(1), json!(2))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[tokio::test]
    async fn evaluate_mutations_collapses_to_net_diff() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::ChangedTo(json!(3)),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let cursor = state.mutation_cursor();
        let _ = state.set("x", 1);
        let _ = state.set("x", 2);
        let _ = state.set("x", 3);
        let _ = state.set("ignored", 10);

        let (_, concurrent) = registry.evaluate_mutations(&state.mutations_since(cursor), &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn evaluate_mutations_ignores_net_noop() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter,
            false,
        ));

        let state = State::new();
        let _ = state.set("x", 1);
        let cursor = state.mutation_cursor();
        let _ = state.set("x", 2);
        let _ = state.set("x", 1);

        let (blocking, concurrent) =
            registry.evaluate_mutations(&state.mutations_since(cursor), &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn observed_keys_tracks_added_watcher_keys() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();

        assert!(registry.observed_keys().is_empty());

        registry.add(counting_watcher(
            "alpha",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "beta",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "alpha",
            WatchPredicate::BecameTrue,
            counter.clone(),
            true,
        ));

        let keys = registry.observed_keys();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains("alpha"));
        assert!(keys.contains("beta"));
    }

    #[tokio::test]
    async fn multiple_watchers_on_same_key() {
        let counter_a = Arc::new(AtomicU32::new(0));
        let counter_b = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();

        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter_a.clone(),
            false,
        ));

        registry.add(counting_watcher(
            "x",
            WatchPredicate::ChangedTo(json!(42)),
            counter_b.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(42))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 2);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter_a.load(Ordering::SeqCst), 1);
        assert_eq!(counter_b.load(Ordering::SeqCst), 1);

        let diffs2 = vec![("x".to_string(), json!(42), json!(99))];
        let (_, concurrent2) = registry.evaluate(&diffs2, &state);
        assert_eq!(concurrent2.len(), 1);

        for fut in concurrent2 {
            fut.await;
        }
        assert_eq!(counter_a.load(Ordering::SeqCst), 2);
        assert_eq!(counter_b.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn action_receives_old_new_and_state() {
        let mut registry = WatcherRegistry::new();
        registry.add(recording_watcher("val", WatchPredicate::Changed, false));

        let state = State::new();
        let diffs = vec![("val".to_string(), json!("before"), json!("after"))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }

        assert_eq!(state.get_raw("recorded_old"), Some(json!("before")));
        assert_eq!(state.get_raw("recorded_new"), Some(json!("after")));
    }

    #[test]
    fn crossed_above_with_non_numeric_values_does_not_fire() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::CrossedAbove(10.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!("low"), json!("high"))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn became_true_does_not_fire_on_non_bool() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::BecameTrue,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!(0), json!("true"))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn empty_diffs_produce_no_futures() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs: Vec<(String, Value, Value)> = vec![];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn default_creates_empty_registry() {
        let registry = WatcherRegistry::default();
        assert!(registry.observed_keys().is_empty());
    }

    #[test]
    fn describe_reports_each_watcher_in_registration_order() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "a",
            WatchPredicate::ChangedTo(json!("on")),
            counter.clone(),
            true,
        ));
        registry.add(counting_watcher(
            "b",
            WatchPredicate::CrossedBelow(1.5),
            counter,
            false,
        ));

        let contracts = registry.describe();
        assert_eq!(contracts.len(), 2);

        assert_eq!(contracts[0].key, "a");
        assert_eq!(contracts[0].predicate, "ChangedTo(\"on\")");
        assert!(contracts[0].blocking);

        assert_eq!(contracts[1].key, "b");
        assert_eq!(contracts[1].predicate, "CrossedBelow(1.5)");
        assert!(!contracts[1].blocking);
    }

    #[test]
    fn watcher_debug_shows_config_and_elides_action() {
        let watcher = counting_watcher(
            "x",
            WatchPredicate::Custom(Arc::new(|_: &Value, _: &Value| true)),
            Arc::new(AtomicU32::new(0)),
            true,
        );
        assert_eq!(
            format!("{watcher:?}"),
            "Watcher { key: \"x\", predicate: Custom(<fn>), blocking: true, .. }"
        );
    }
}