1use std::collections::{HashMap, HashSet};
11use std::sync::Arc;
12
13use serde_json::Value;
14
15use super::BoxFuture;
16use crate::state::{State, StateMutation};
17
18pub type PredicateFn = Arc<dyn Fn(&Value, &Value) -> bool + Send + Sync>;
22
23pub enum WatchPredicate {
25 Changed,
27 ChangedTo(Value),
29 ChangedFrom(Value),
31 CrossedAbove(f64),
33 CrossedBelow(f64),
35 BecameTrue,
37 BecameFalse,
39 Custom(PredicateFn),
41}
42
43impl std::fmt::Debug for WatchPredicate {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 Self::Changed => write!(f, "Changed"),
47 Self::ChangedTo(v) => write!(f, "ChangedTo({v})"),
48 Self::ChangedFrom(v) => write!(f, "ChangedFrom({v})"),
49 Self::CrossedAbove(t) => write!(f, "CrossedAbove({t})"),
50 Self::CrossedBelow(t) => write!(f, "CrossedBelow({t})"),
51 Self::BecameTrue => write!(f, "BecameTrue"),
52 Self::BecameFalse => write!(f, "BecameFalse"),
53 Self::Custom(_) => write!(f, "Custom(<fn>)"),
54 }
55 }
56}
57
58impl WatchPredicate {
59 fn matches(&self, old: &Value, new: &Value) -> bool {
61 match self {
62 WatchPredicate::Changed => true,
63 WatchPredicate::ChangedTo(val) => new == val,
64 WatchPredicate::ChangedFrom(val) => old == val,
65 WatchPredicate::CrossedAbove(threshold) => match (as_f64(old), as_f64(new)) {
66 (Some(o), Some(n)) => o < *threshold && n >= *threshold,
67 _ => false,
68 },
69 WatchPredicate::CrossedBelow(threshold) => match (as_f64(old), as_f64(new)) {
70 (Some(o), Some(n)) => o >= *threshold && n < *threshold,
71 _ => false,
72 },
73 WatchPredicate::BecameTrue => old != &Value::Bool(true) && new == &Value::Bool(true),
74 WatchPredicate::BecameFalse => old == &Value::Bool(true) && new != &Value::Bool(true),
75 WatchPredicate::Custom(f) => f(old, new),
76 }
77 }
78}
79
80pub struct Watcher {
85 pub key: String,
87 pub predicate: WatchPredicate,
89 pub action: Arc<dyn Fn(Value, Value, State) -> BoxFuture<()> + Send + Sync>,
91 pub blocking: bool,
94}
95
96impl std::fmt::Debug for Watcher {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("Watcher")
99 .field("key", &self.key)
100 .field("predicate", &self.predicate)
101 .field("blocking", &self.blocking)
102 .finish_non_exhaustive()
103 }
104}
105
106pub struct WatcherRegistry {
110 watchers: Vec<Watcher>,
111 observed_keys: HashSet<String>,
113}
114
115impl Default for WatcherRegistry {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121impl WatcherRegistry {
122 pub fn new() -> Self {
124 Self {
125 watchers: Vec::new(),
126 observed_keys: HashSet::new(),
127 }
128 }
129
130 pub fn add(&mut self, watcher: Watcher) {
132 self.observed_keys.insert(watcher.key.clone());
133 self.watchers.push(watcher);
134 }
135
136 pub fn observed_keys(&self) -> &HashSet<String> {
141 &self.observed_keys
142 }
143
144 pub fn evaluate(
152 &self,
153 diffs: &[(String, Value, Value)],
154 state: &State,
155 ) -> (Vec<BoxFuture<()>>, Vec<BoxFuture<()>>) {
156 let mut blocking = Vec::new();
157 let mut concurrent = Vec::new();
158
159 for (key, old, new) in diffs {
160 for watcher in &self.watchers {
161 if watcher.key == *key && watcher.predicate.matches(old, new) {
162 let fut = (watcher.action)(old.clone(), new.clone(), state.clone());
163 if watcher.blocking {
164 blocking.push(fut);
165 } else {
166 concurrent.push(fut);
167 }
168 }
169 }
170 }
171
172 (blocking, concurrent)
173 }
174
175 pub fn evaluate_mutations(
181 &self,
182 mutations: &[StateMutation],
183 state: &State,
184 ) -> (Vec<BoxFuture<()>>, Vec<BoxFuture<()>>) {
185 let mut net: HashMap<String, (Option<Value>, Option<Value>)> = HashMap::new();
186
187 for mutation in mutations {
188 if !self.observed_keys.contains(&mutation.key) {
189 continue;
190 }
191
192 net.entry(mutation.key.clone())
193 .and_modify(|(_, new)| {
194 *new = mutation.new.clone();
195 })
196 .or_insert_with(|| (mutation.old.clone(), mutation.new.clone()));
197 }
198
199 let diffs: Vec<(String, Value, Value)> = net
200 .into_iter()
201 .filter_map(|(key, (old, new))| {
202 if old == new {
203 None
204 } else {
205 Some((key, old.unwrap_or(Value::Null), new.unwrap_or(Value::Null)))
206 }
207 })
208 .collect();
209
210 self.evaluate(&diffs, state)
211 }
212}
213
214fn as_f64(v: &Value) -> Option<f64> {
218 match v {
219 Value::Number(n) => n.as_f64(),
220 _ => None,
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Builds a watcher whose action increments `counter` when its future is awaited.
    fn counting_watcher(
        key: &str,
        predicate: WatchPredicate,
        counter: Arc<AtomicU32>,
        blocking: bool,
    ) -> Watcher {
        Watcher {
            key: key.to_string(),
            predicate,
            action: Arc::new(move |_old, _new, _state| {
                let c = counter.clone();
                Box::pin(async move {
                    c.fetch_add(1, Ordering::SeqCst);
                })
            }),
            blocking,
        }
    }

    /// Builds a watcher whose action writes the `old`/`new` values it received
    /// into `state` under `recorded_old` / `recorded_new`.
    fn recording_watcher(key: &str, predicate: WatchPredicate, blocking: bool) -> Watcher {
        Watcher {
            key: key.to_string(),
            predicate,
            action: Arc::new(|old, new, state| {
                Box::pin(async move {
                    state.set("recorded_old", old);
                    state.set("recorded_new", new);
                })
            }),
            blocking,
        }
    }

    #[tokio::test]
    async fn changed_fires_on_any_diff() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(2))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn changed_to_fires_when_new_value_matches() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "status",
            WatchPredicate::ChangedTo(json!("active")),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("status".to_string(), json!("inactive"), json!("active"))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn changed_to_does_not_fire_when_new_value_differs() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "status",
            WatchPredicate::ChangedTo(json!("active")),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("status".to_string(), json!("inactive"), json!("pending"))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn changed_from_fires_when_old_value_matches() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "mode",
            WatchPredicate::ChangedFrom(json!("draft")),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("mode".to_string(), json!("draft"), json!("published"))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Old value is no longer "draft", so the watcher must stay quiet.
        let diffs2 = vec![("mode".to_string(), json!("published"), json!("archived"))];
        let (b, c) = registry.evaluate(&diffs2, &state);
        assert!(b.is_empty());
        assert!(c.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn crossed_above_fires_on_upward_crossing() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "temp",
            WatchPredicate::CrossedAbove(100.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("temp".to_string(), json!(95.0), json!(105.0))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn crossed_above_does_not_fire_when_both_above() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "temp",
            WatchPredicate::CrossedAbove(100.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("temp".to_string(), json!(110.0), json!(120.0))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn crossed_below_fires_on_downward_crossing() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "battery",
            WatchPredicate::CrossedBelow(20.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("battery".to_string(), json!(25.0), json!(15.0))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn became_true_fires_on_false_to_true() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameTrue,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(false), json!(true))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn became_false_fires_on_true_to_false() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameFalse,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(true), json!(false))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn custom_predicate_fires_when_fn_returns_true() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "score",
            // Fires only when the score exactly doubled.
            WatchPredicate::Custom(Arc::new(|old, new| {
                match (as_f64(old), as_f64(new)) {
                    (Some(o), Some(n)) => (n - o * 2.0).abs() < f64::EPSILON,
                    _ => false,
                }
            })),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("score".to_string(), json!(5.0), json!(10.0))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        let diffs2 = vec![("score".to_string(), json!(5.0), json!(11.0))];
        let (b, c) = registry.evaluate(&diffs2, &state);
        assert!(b.is_empty());
        assert!(c.is_empty());
    }

    #[tokio::test]
    async fn evaluate_separates_blocking_and_concurrent() {
        let blocking_counter = Arc::new(AtomicU32::new(0));
        let concurrent_counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();

        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            blocking_counter.clone(),
            true,
        ));

        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            concurrent_counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(2))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(blocking.len(), 1);
        assert_eq!(concurrent.len(), 1);

        for fut in blocking {
            fut.await;
        }
        for fut in concurrent {
            fut.await;
        }

        assert_eq!(blocking_counter.load(Ordering::SeqCst), 1);
        assert_eq!(concurrent_counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn evaluate_with_no_matching_diffs_returns_empty() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("y".to_string(), json!(1), json!(2))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[tokio::test]
    async fn evaluate_mutations_collapses_to_net_diff() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::ChangedTo(json!(3)),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let cursor = state.mutation_cursor();
        state.set("x", 1);
        state.set("x", 2);
        state.set("x", 3);
        state.set("ignored", 10);

        let (_, concurrent) = registry.evaluate_mutations(&state.mutations_since(cursor), &state);
        assert_eq!(concurrent.len(), 1);
        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn evaluate_mutations_ignores_net_noop() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter,
            false,
        ));

        let state = State::new();
        state.set("x", 1);
        let cursor = state.mutation_cursor();
        state.set("x", 2);
        state.set("x", 1);

        let (blocking, concurrent) =
            registry.evaluate_mutations(&state.mutations_since(cursor), &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn observed_keys_tracks_added_watcher_keys() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();

        assert!(registry.observed_keys().is_empty());

        registry.add(counting_watcher(
            "alpha",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "beta",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "alpha",
            WatchPredicate::BecameTrue,
            counter.clone(),
            true,
        ));

        let keys = registry.observed_keys();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains("alpha"));
        assert!(keys.contains("beta"));
    }

    #[tokio::test]
    async fn multiple_watchers_on_same_key() {
        let counter_a = Arc::new(AtomicU32::new(0));
        let counter_b = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();

        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter_a.clone(),
            false,
        ));

        registry.add(counting_watcher(
            "x",
            WatchPredicate::ChangedTo(json!(42)),
            counter_b.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!(1), json!(42))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 2);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter_a.load(Ordering::SeqCst), 1);
        assert_eq!(counter_b.load(Ordering::SeqCst), 1);

        let diffs2 = vec![("x".to_string(), json!(42), json!(99))];
        let (_, concurrent2) = registry.evaluate(&diffs2, &state);
        assert_eq!(concurrent2.len(), 1);

        for fut in concurrent2 {
            fut.await;
        }
        assert_eq!(counter_a.load(Ordering::SeqCst), 2);
        // ChangedTo(42) must not fire for a 42 -> 99 transition.
        assert_eq!(counter_b.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn action_receives_old_new_and_state() {
        let mut registry = WatcherRegistry::new();
        registry.add(recording_watcher("val", WatchPredicate::Changed, false));

        let state = State::new();
        let diffs = vec![("val".to_string(), json!("before"), json!("after"))];

        let (_, concurrent) = registry.evaluate(&diffs, &state);
        assert_eq!(concurrent.len(), 1);

        for fut in concurrent {
            fut.await;
        }

        assert_eq!(state.get_raw("recorded_old"), Some(json!("before")));
        assert_eq!(state.get_raw("recorded_new"), Some(json!("after")));
    }

    #[test]
    fn crossed_above_with_non_numeric_values_does_not_fire() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::CrossedAbove(10.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("x".to_string(), json!("low"), json!("high"))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn became_true_does_not_fire_on_non_bool() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::BecameTrue,
            counter.clone(),
            false,
        ));

        let state = State::new();
        // The string "true" is not JSON `true`.
        let diffs = vec![("x".to_string(), json!(0), json!("true"))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn empty_diffs_produce_no_futures() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs: Vec<(String, Value, Value)> = vec![];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn default_creates_empty_registry() {
        let registry = WatcherRegistry::default();
        assert!(registry.observed_keys().is_empty());
    }

    #[test]
    fn crossed_above_fires_when_landing_exactly_on_threshold() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "temp",
            WatchPredicate::CrossedAbove(100.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        // Integer JSON numbers; the new value sits exactly on the threshold.
        let diffs = vec![("temp".to_string(), json!(99), json!(100))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert_eq!(concurrent.len(), 1);
    }

    #[test]
    fn crossed_below_does_not_fire_when_both_below() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "battery",
            WatchPredicate::CrossedBelow(20.0),
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("battery".to_string(), json!(15.0), json!(10.0))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn became_false_does_not_fire_when_old_was_not_true() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "flag",
            WatchPredicate::BecameFalse,
            counter.clone(),
            false,
        ));

        let state = State::new();
        let diffs = vec![("flag".to_string(), json!(null), json!(false))];

        let (blocking, concurrent) = registry.evaluate(&diffs, &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[test]
    fn evaluate_mutations_skips_unobserved_keys() {
        let counter = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "x",
            WatchPredicate::Changed,
            counter,
            false,
        ));

        let state = State::new();
        let cursor = state.mutation_cursor();
        state.set("y", 1);
        state.set("z", 2);

        let (blocking, concurrent) =
            registry.evaluate_mutations(&state.mutations_since(cursor), &state);
        assert!(blocking.is_empty());
        assert!(concurrent.is_empty());
    }

    #[tokio::test]
    async fn evaluate_mutations_reports_each_observed_key_once() {
        let counter_a = Arc::new(AtomicU32::new(0));
        let counter_b = Arc::new(AtomicU32::new(0));
        let mut registry = WatcherRegistry::new();
        registry.add(counting_watcher(
            "a",
            WatchPredicate::Changed,
            counter_a.clone(),
            false,
        ));
        registry.add(counting_watcher(
            "b",
            WatchPredicate::Changed,
            counter_b.clone(),
            false,
        ));

        let state = State::new();
        let cursor = state.mutation_cursor();
        // Interleaved writes to two keys must still yield one net diff per key.
        state.set("a", 1);
        state.set("b", 1);
        state.set("a", 2);
        state.set("b", 2);

        let (blocking, concurrent) =
            registry.evaluate_mutations(&state.mutations_since(cursor), &state);
        assert!(blocking.is_empty());
        assert_eq!(concurrent.len(), 2);

        for fut in concurrent {
            fut.await;
        }
        assert_eq!(counter_a.load(Ordering::SeqCst), 1);
        assert_eq!(counter_b.load(Ordering::SeqCst), 1);
    }
}