1use std::sync::Arc;
6
/// A named mutation over a JSON state value.
///
/// The closure is stored behind an `Arc`, so cloning a `StateTransform`
/// shares the same function object instead of copying it.
#[derive(Clone)]
pub struct StateTransform {
    /// Static label for this step; returned by `name()` and printed by `Debug`.
    name: &'static str,
    /// The mutation to run against the state; bounded by `Send + Sync`.
    transform: Arc<dyn Fn(&mut serde_json::Value) + Send + Sync>,
}
13
14impl StateTransform {
15 fn new(name: &'static str, f: impl Fn(&mut serde_json::Value) + Send + Sync + 'static) -> Self {
16 Self {
17 name,
18 transform: Arc::new(f),
19 }
20 }
21
22 pub fn apply(&self, state: &mut serde_json::Value) {
24 (self.transform)(state);
25 }
26
27 pub fn name(&self) -> &str {
29 self.name
30 }
31}
32
33impl std::fmt::Debug for StateTransform {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 f.debug_struct("StateTransform")
36 .field("name", &self.name)
37 .finish()
38 }
39}
40
41impl std::ops::Shr for StateTransform {
43 type Output = StateTransformChain;
44
45 fn shr(self, rhs: StateTransform) -> Self::Output {
46 StateTransformChain {
47 steps: vec![self, rhs],
48 }
49 }
50}
51
52#[derive(Clone)]
54pub struct StateTransformChain {
55 pub steps: Vec<StateTransform>,
57}
58
59impl StateTransformChain {
60 pub fn apply(&self, state: &mut serde_json::Value) {
62 for step in &self.steps {
63 step.apply(state);
64 }
65 }
66}
67
68impl std::ops::Shr<StateTransform> for StateTransformChain {
70 type Output = StateTransformChain;
71
72 fn shr(mut self, rhs: StateTransform) -> Self::Output {
73 self.steps.push(rhs);
74 self
75 }
76}
77
/// Namespace type: its associated functions build [`StateTransform`] values
/// and predicate closures over `gemini_adk_rs::State`.
pub struct S;
80
81impl S {
82 pub fn pick(keys: &[&str]) -> StateTransform {
84 let keys: Vec<String> = keys.iter().map(|k| k.to_string()).collect();
85 StateTransform::new("pick", move |state| {
86 if let Some(obj) = state.as_object_mut() {
87 obj.retain(|k, _| keys.contains(k));
88 }
89 })
90 }
91
92 pub fn rename(mappings: &[(&str, &str)]) -> StateTransform {
94 let mappings: Vec<(String, String)> = mappings
95 .iter()
96 .map(|(a, b)| (a.to_string(), b.to_string()))
97 .collect();
98 StateTransform::new("rename", move |state| {
99 if let Some(obj) = state.as_object_mut() {
100 for (from, to) in &mappings {
101 if let Some(val) = obj.remove(from) {
102 obj.insert(to.clone(), val);
103 }
104 }
105 }
106 })
107 }
108
109 pub fn merge(keys: &[&str], into: &str) -> StateTransform {
111 let keys: Vec<String> = keys.iter().map(|k| k.to_string()).collect();
112 let into = into.to_string();
113 StateTransform::new("merge", move |state| {
114 if let Some(obj) = state.as_object_mut() {
115 let mut merged = serde_json::Map::new();
116 for key in &keys {
117 if let Some(val) = obj.remove(key) {
118 merged.insert(key.clone(), val);
119 }
120 }
121 obj.insert(into.clone(), serde_json::Value::Object(merged));
122 }
123 })
124 }
125
126 pub fn defaults(defaults: serde_json::Value) -> StateTransform {
128 StateTransform::new("defaults", move |state| {
129 if let (Some(obj), Some(defaults_obj)) = (state.as_object_mut(), defaults.as_object()) {
130 for (k, v) in defaults_obj {
131 obj.entry(k.clone()).or_insert_with(|| v.clone());
132 }
133 }
134 })
135 }
136
137 pub fn map(f: impl Fn(&mut serde_json::Value) + Send + Sync + 'static) -> StateTransform {
139 StateTransform::new("map", f)
140 }
141
142 pub fn flatten(key: &str) -> StateTransform {
144 let key = key.to_string();
145 StateTransform::new("flatten", move |state| {
146 if let Some(obj) = state.as_object_mut() {
147 if let Some(serde_json::Value::Object(nested)) = obj.remove(&key) {
148 for (k, v) in nested {
149 obj.insert(k, v);
150 }
151 }
152 }
153 })
154 }
155
156 pub fn set(key: &str, value: serde_json::Value) -> StateTransform {
158 let key = key.to_string();
159 StateTransform::new("set", move |state| {
160 if let Some(obj) = state.as_object_mut() {
161 obj.insert(key.clone(), value.clone());
162 }
163 })
164 }
165
166 pub fn is_set(key: &str) -> impl Fn(&gemini_adk_rs::State) -> bool + Send + Sync + 'static {
177 let key = key.to_string();
178 move |s: &gemini_adk_rs::State| s.contains(&key)
179 }
180
181 pub fn is_true(key: &str) -> impl Fn(&gemini_adk_rs::State) -> bool + Send + Sync + 'static {
187 let key = key.to_string();
188 move |s: &gemini_adk_rs::State| s.get::<bool>(&key).unwrap_or(false)
189 }
190
191 pub fn eq(
197 key: &str,
198 expected: &str,
199 ) -> impl Fn(&gemini_adk_rs::State) -> bool + Send + Sync + 'static {
200 let key = key.to_string();
201 let expected = expected.to_string();
202 move |s: &gemini_adk_rs::State| {
203 s.get::<String>(&key)
204 .map(|v| v == expected)
205 .unwrap_or(false)
206 }
207 }
208
209 pub fn one_of(
215 key: &str,
216 values: &[&str],
217 ) -> impl Fn(&gemini_adk_rs::State) -> bool + Send + Sync + 'static {
218 let key = key.to_string();
219 let values: Vec<String> = values.iter().map(|v| v.to_string()).collect();
220 move |s: &gemini_adk_rs::State| s.get::<String>(&key).is_some_and(|v| values.contains(&v))
221 }
222
223 pub fn transform(
225 key: &str,
226 f: impl Fn(serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
227 ) -> StateTransform {
228 let key = key.to_string();
229 StateTransform::new("transform", move |state| {
230 if let Some(obj) = state.as_object_mut() {
231 if let Some(val) = obj.remove(&key) {
232 obj.insert(key.clone(), f(val));
233 }
234 }
235 })
236 }
237
238 pub fn guard(
240 predicate: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
241 msg: &str,
242 ) -> StateTransform {
243 let msg = msg.to_string();
244 StateTransform::new("guard", move |state| {
245 assert!(predicate(state), "{}", msg);
246 })
247 }
248
249 pub fn compute(
251 key: &str,
252 f: impl Fn(&serde_json::Value) -> serde_json::Value + Send + Sync + 'static,
253 ) -> StateTransform {
254 let key = key.to_string();
255 StateTransform::new("compute", move |state| {
256 let val = f(state);
257 if let Some(obj) = state.as_object_mut() {
258 obj.insert(key.clone(), val);
259 }
260 })
261 }
262
263 pub fn accumulate(source_key: &str, into: &str) -> StateTransform {
265 let source = source_key.to_string();
266 let into = into.to_string();
267 StateTransform::new("accumulate", move |state| {
268 if let Some(obj) = state.as_object_mut() {
269 if let Some(val) = obj.get(&source).cloned() {
270 let arr = obj
271 .entry(into.clone())
272 .or_insert_with(|| serde_json::Value::Array(Vec::new()));
273 if let Some(arr) = arr.as_array_mut() {
274 arr.push(val);
275 }
276 }
277 }
278 })
279 }
280
281 pub fn counter(key: &str, step: i64) -> StateTransform {
283 let key = key.to_string();
284 StateTransform::new("counter", move |state| {
285 if let Some(obj) = state.as_object_mut() {
286 let current = obj.get(&key).and_then(|v| v.as_i64()).unwrap_or(0);
287 obj.insert(key.clone(), serde_json::json!(current + step));
288 }
289 })
290 }
291
292 pub fn require(keys: &[&str]) -> StateTransform {
294 let keys: Vec<String> = keys.iter().map(|k| k.to_string()).collect();
295 StateTransform::new("require", move |state| {
296 if let Some(obj) = state.as_object() {
297 for key in &keys {
298 assert!(
299 obj.contains_key(key),
300 "Required key '{}' missing from state",
301 key
302 );
303 }
304 }
305 })
306 }
307
308 pub fn identity() -> StateTransform {
310 StateTransform::new("identity", |_| {})
311 }
312
313 pub fn when(
315 predicate: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
316 inner: StateTransform,
317 ) -> StateTransform {
318 StateTransform::new("when", move |state| {
319 if predicate(state) {
320 inner.apply(state);
321 }
322 })
323 }
324
325 pub fn drop(keys: &[&str]) -> StateTransform {
327 let keys: Vec<String> = keys.iter().map(|k| k.to_string()).collect();
328 StateTransform::new("drop", move |state| {
329 if let Some(obj) = state.as_object_mut() {
330 for key in &keys {
331 obj.remove(key);
332 }
333 }
334 })
335 }
336
337 pub fn log(message: &str) -> StateTransform {
346 let message = message.to_string();
347 StateTransform::new("log", move |_state| {
348 eprintln!("[S::log] {}", message);
349 })
350 }
351
352 pub fn unflatten(key: &str) -> StateTransform {
362 let key = key.to_string();
363 StateTransform::new("unflatten", move |state| {
364 if let Some(obj) = state.as_object_mut() {
365 let prefix = format!("{}.", key);
366 let dotted: Vec<(String, serde_json::Value)> = obj
367 .keys()
368 .filter(|k| k.starts_with(&prefix))
369 .cloned()
370 .collect::<Vec<_>>()
371 .into_iter()
372 .filter_map(|k| obj.remove(&k).map(|v| (k, v)))
373 .collect();
374
375 if !dotted.is_empty() {
376 let nested = obj
377 .entry(key.clone())
378 .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
379 if let Some(nested_obj) = nested.as_object_mut() {
380 for (k, v) in dotted {
381 let sub_key = k[prefix.len()..].to_string();
382 nested_obj.insert(sub_key, v);
383 }
384 }
385 }
386 }
387 })
388 }
389
390 pub fn zip(keys: &[&str], into: &str) -> StateTransform {
401 let keys: Vec<String> = keys.iter().map(|k| k.to_string()).collect();
402 let into = into.to_string();
403 StateTransform::new("zip", move |state| {
404 if let Some(obj) = state.as_object_mut() {
405 let arrays: Vec<&Vec<serde_json::Value>> = keys
406 .iter()
407 .filter_map(|k| obj.get(k).and_then(|v| v.as_array()))
408 .collect();
409
410 if arrays.len() == keys.len() {
411 let min_len = arrays.iter().map(|a| a.len()).min().unwrap_or(0);
412 let mut zipped = Vec::with_capacity(min_len);
413 for i in 0..min_len {
414 let tuple: Vec<serde_json::Value> =
415 arrays.iter().map(|a| a[i].clone()).collect();
416 zipped.push(serde_json::Value::Array(tuple));
417 }
418 obj.insert(into.clone(), serde_json::Value::Array(zipped));
419 }
420 }
421 })
422 }
423
424 pub fn group_by(source: &str, key: &str, into: &str) -> StateTransform {
435 let source = source.to_string();
436 let key = key.to_string();
437 let into = into.to_string();
438 StateTransform::new("group_by", move |state| {
439 if let Some(obj) = state.as_object_mut() {
440 if let Some(arr) = obj.get(&source).and_then(|v| v.as_array()) {
441 let mut groups: serde_json::Map<String, serde_json::Value> =
442 serde_json::Map::new();
443 for item in arr {
444 let group_key = item
445 .get(&key)
446 .and_then(|v| v.as_str())
447 .unwrap_or("_unknown")
448 .to_string();
449 let group = groups
450 .entry(group_key)
451 .or_insert_with(|| serde_json::Value::Array(Vec::new()));
452 if let Some(arr) = group.as_array_mut() {
453 arr.push(item.clone());
454 }
455 }
456 obj.insert(into.clone(), serde_json::Value::Object(groups));
457 }
458 }
459 })
460 }
461
462 pub fn history(key: &str, max: usize) -> StateTransform {
471 let key = key.to_string();
472 StateTransform::new("history", move |state| {
473 if let Some(obj) = state.as_object_mut() {
474 let history_key = format!("{}_history", key);
475 if let Some(val) = obj.get(&key).cloned() {
476 let arr = obj
477 .entry(history_key)
478 .or_insert_with(|| serde_json::Value::Array(Vec::new()));
479 if let Some(arr) = arr.as_array_mut() {
480 arr.push(val);
481 while arr.len() > max {
482 arr.remove(0);
483 }
484 }
485 }
486 }
487 })
488 }
489
490 pub fn validate(schema: serde_json::Value) -> StateTransform {
506 StateTransform::new("validate", move |state| {
507 if let Some(obj) = state.as_object() {
508 if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
510 for req in required {
511 if let Some(key) = req.as_str() {
512 assert!(
513 obj.contains_key(key),
514 "Validation failed: required key '{}' missing from state",
515 key
516 );
517 }
518 }
519 }
520 if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
522 for (key, prop_schema) in properties {
523 if let Some(val) = obj.get(key) {
524 if let Some(expected_type) =
525 prop_schema.get("type").and_then(|v| v.as_str())
526 {
527 let actual_ok = match expected_type {
528 "string" => val.is_string(),
529 "number" | "integer" => val.is_number(),
530 "boolean" => val.is_boolean(),
531 "array" => val.is_array(),
532 "object" => val.is_object(),
533 "null" => val.is_null(),
534 _ => true,
535 };
536 assert!(
537 actual_ok,
538 "Validation failed: key '{}' expected type '{}', got {:?}",
539 key, expected_type, val
540 );
541 }
542 }
543 }
544 }
545 }
546 })
547 }
548
549 pub fn branch(
561 predicate: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
562 if_true: StateTransform,
563 if_false: StateTransform,
564 ) -> StateTransform {
565 StateTransform::new("branch", move |state| {
566 if predicate(state) {
567 if_true.apply(state);
568 } else {
569 if_false.apply(state);
570 }
571 })
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Applies `transform` to `state` and returns the mutated value.
    fn run(transform: StateTransform, mut state: serde_json::Value) -> serde_json::Value {
        transform.apply(&mut state);
        state
    }

    /// Predicate shared by the `branch` tests: true when `premium` is `true`.
    fn is_premium(state: &serde_json::Value) -> bool {
        state
            .get("premium")
            .and_then(|v| v.as_bool())
            .unwrap_or(false)
    }

    #[test]
    fn pick_keeps_only_specified_keys() {
        let out = run(S::pick(&["a", "c"]), json!({"a": 1, "b": 2, "c": 3}));
        assert_eq!(out, json!({"a": 1, "c": 3}));
    }

    #[test]
    fn rename_renames_keys() {
        let out = run(
            S::rename(&[("old_name", "new_name")]),
            json!({"old_name": 42}),
        );
        assert_eq!(out, json!({"new_name": 42}));
    }

    #[test]
    fn merge_combines_keys() {
        let out = run(
            S::merge(&["x", "y"], "combined"),
            json!({"x": 1, "y": 2, "z": 3}),
        );
        assert_eq!(out, json!({"z": 3, "combined": {"x": 1, "y": 2}}));
    }

    #[test]
    fn defaults_sets_missing() {
        let out = run(
            S::defaults(json!({"existing": "no", "missing": "added"})),
            json!({"existing": "yes"}),
        );
        assert_eq!(out["existing"], "yes");
        assert_eq!(out["missing"], "added");
    }

    #[test]
    fn drop_removes_keys() {
        let out = run(S::drop(&["remove"]), json!({"keep": 1, "remove": 2}));
        assert_eq!(out, json!({"keep": 1}));
    }

    #[test]
    fn map_custom_transform() {
        let doubler = S::map(|s| {
            if let Some(n) = s.get("count").and_then(|v| v.as_i64()) {
                s["count"] = json!(n * 2);
            }
        });
        let out = run(doubler, json!({"count": 5}));
        assert_eq!(out["count"], 10);
    }

    #[test]
    fn chain_with_shr() {
        let mut state = json!({"a": 1, "b": 2, "c": 3});
        (S::pick(&["a", "b"]) >> S::rename(&[("a", "x")])).apply(&mut state);
        assert_eq!(state, json!({"x": 1, "b": 2}));
    }

    #[test]
    fn flatten_nested_object() {
        let out = run(
            S::flatten("nested"),
            json!({"nested": {"x": 1, "y": 2}, "z": 3}),
        );
        assert_eq!(out, json!({"x": 1, "y": 2, "z": 3}));
    }

    #[test]
    fn flatten_missing_key_is_noop() {
        let out = run(S::flatten("nonexistent"), json!({"a": 1}));
        assert_eq!(out, json!({"a": 1}));
    }

    #[test]
    fn set_inserts_value() {
        let out = run(S::set("b", json!(42)), json!({"a": 1}));
        assert_eq!(out, json!({"a": 1, "b": 42}));
    }

    #[test]
    fn set_overwrites_existing() {
        let out = run(S::set("a", json!("replaced")), json!({"a": 1}));
        assert_eq!(out, json!({"a": "replaced"}));
    }

    #[test]
    fn chain_extends() {
        let pipeline =
            S::pick(&["a"]) >> S::rename(&[("a", "b")]) >> S::defaults(json!({"c": 99}));
        let mut state = json!({"a": 1, "x": 2});
        pipeline.apply(&mut state);
        assert_eq!(state, json!({"b": 1, "c": 99}));
    }

    #[test]
    fn log_is_noop_on_state() {
        let out = run(S::log("debug message"), json!({"a": 1}));
        assert_eq!(out, json!({"a": 1}));
    }

    #[test]
    fn unflatten_groups_dotted_keys() {
        let out = run(
            S::unflatten("addr"),
            json!({"addr.city": "NYC", "addr.zip": "10001", "name": "Alice"}),
        );
        assert_eq!(
            out,
            json!({"name": "Alice", "addr": {"city": "NYC", "zip": "10001"}})
        );
    }

    #[test]
    fn unflatten_missing_prefix_is_noop() {
        let out = run(S::unflatten("addr"), json!({"a": 1}));
        assert_eq!(out, json!({"a": 1}));
    }

    #[test]
    fn zip_combines_arrays() {
        let out = run(
            S::zip(&["names", "scores"], "zipped"),
            json!({"names": ["a", "b", "c"], "scores": [10, 20, 30]}),
        );
        assert_eq!(out["zipped"], json!([["a", 10], ["b", 20], ["c", 30]]));
    }

    #[test]
    fn zip_truncates_to_shortest() {
        let out = run(
            S::zip(&["a", "b"], "z"),
            json!({"a": [1, 2, 3], "b": [10, 20]}),
        );
        assert_eq!(out["z"], json!([[1, 10], [2, 20]]));
    }

    #[test]
    fn group_by_groups_elements() {
        let out = run(
            S::group_by("items", "type", "grouped"),
            json!({
                "items": [
                    {"type": "fruit", "name": "apple"},
                    {"type": "veg", "name": "carrot"},
                    {"type": "fruit", "name": "banana"}
                ]
            }),
        );
        let grouped = &out["grouped"];
        assert_eq!(grouped["fruit"].as_array().unwrap().len(), 2);
        assert_eq!(grouped["veg"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn history_tracks_values() {
        let tracker = S::history("score", 3);
        let mut state = json!({});
        // Four updates with a cap of three: the oldest value must fall off.
        for score in [10, 20, 30, 40] {
            state["score"] = json!(score);
            tracker.apply(&mut state);
        }
        assert_eq!(state["score_history"], json!([20, 30, 40]));
    }

    #[test]
    fn validate_passes_valid_state() {
        let schema = json!({
            "required": ["name", "age"],
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "number"}
            }
        });
        // Passing means simply not panicking.
        let _ = run(S::validate(schema), json!({"name": "Alice", "age": 30}));
    }

    #[test]
    #[should_panic(expected = "required key 'missing' missing from state")]
    fn validate_fails_missing_required() {
        let schema = json!({"required": ["name", "missing"]});
        let _ = run(S::validate(schema), json!({"name": "Alice"}));
    }

    #[test]
    #[should_panic(expected = "expected type 'string'")]
    fn validate_fails_wrong_type() {
        let schema = json!({
            "properties": {"name": {"type": "string"}}
        });
        let _ = run(S::validate(schema), json!({"name": 42}));
    }

    #[test]
    fn branch_takes_true_path() {
        let tiering = S::branch(
            is_premium,
            S::set("tier", json!("gold")),
            S::set("tier", json!("basic")),
        );
        let out = run(tiering, json!({"premium": true}));
        assert_eq!(out["tier"], "gold");
    }

    #[test]
    fn branch_takes_false_path() {
        let tiering = S::branch(
            is_premium,
            S::set("tier", json!("gold")),
            S::set("tier", json!("basic")),
        );
        let out = run(tiering, json!({"premium": false}));
        assert_eq!(out["tier"], "basic");
    }
}