gemini_adk_fluent_rs/live/phases.rs
1//! Phase machine, instruction templating, computed state, watchers, and
2//! temporal pattern configuration methods for `Live`.
3
4use std::future::Future;
5use std::sync::Arc;
6use std::time::Duration;
7
8use serde_json::Value;
9
10use gemini_adk_rs::live::{
11 ComputedVar, Phase, RateDetector, SustainedDetector, TemporalPattern, TurnCountDetector,
12 Watcher,
13};
14use gemini_adk_rs::State;
15use gemini_genai_rs::prelude::*;
16use gemini_genai_rs::session::SessionWriter;
17
18use crate::live_builders::{PhaseBuilder, PhaseDefaults, WatchBuilder};
19
20use super::Live;
21
22impl Live {
23 /// State-reactive system instruction template.
24 ///
25 /// Called after extractors run on each turn. If it returns `Some(instruction)`,
26 /// the system instruction is updated mid-session (deduped — same instruction
27 /// is not sent twice). Returns `None` to leave the instruction unchanged.
28 ///
29 /// # Example
30 /// ```ignore
31 /// .instruction_template(|state| {
32 /// let phase: String = state.get("phase").unwrap_or_default();
33 /// match phase.as_str() {
34 /// "ordering" => Some("Focus on taking the order accurately.".into()),
35 /// "confirming" => Some("Summarize and confirm the order.".into()),
36 /// _ => None,
37 /// }
38 /// })
39 /// ```
40 pub fn instruction_template(
41 mut self,
42 f: impl Fn(&gemini_adk_rs::State) -> Option<String> + Send + Sync + 'static,
43 ) -> Self {
44 self.callbacks.instruction_template = Some(Arc::new(f));
45 self
46 }
47
48 /// State-reactive instruction amendment (additive, not replacement).
49 ///
50 /// Unlike `instruction_template` (which replaces the entire instruction),
51 /// this appends to the current phase instruction. The developer never needs
52 /// to know or repeat the base instruction.
53 ///
54 /// # Example
55 /// ```ignore
56 /// .instruction_amendment(|state| {
57 /// let risk: String = state.get("derived:risk").unwrap_or_default();
58 /// if risk == "high" {
59 /// Some("[IMPORTANT: Use empathetic language. Do not threaten.]".into())
60 /// } else {
61 /// None
62 /// }
63 /// })
64 /// ```
65 pub fn instruction_amendment(
66 mut self,
67 f: impl Fn(&gemini_adk_rs::State) -> Option<String> + Send + Sync + 'static,
68 ) -> Self {
69 self.callbacks.instruction_amendment = Some(Arc::new(f));
70 self
71 }
72
73 // -- Computed State --
74
75 /// Register a computed (derived) state variable.
76 ///
77 /// The compute function receives the full `State` and returns `Some(value)`
78 /// to write to `derived:{key}`, or `None` to skip.
79 pub fn computed(
80 mut self,
81 key: impl Into<String>,
82 deps: &[&str],
83 f: impl Fn(&State) -> Option<Value> + Send + Sync + 'static,
84 ) -> Self {
85 self.computed.register(ComputedVar {
86 key: key.into(),
87 dependencies: deps.iter().map(|s| s.to_string()).collect(),
88 compute: Arc::new(f),
89 });
90 self
91 }
92
93 // -- Phase Machine --
94
95 /// Set default modifiers and `prompt_on_enter` inherited by all phases.
96 ///
97 /// Phase-specific modifiers are applied *after* defaults, so they extend (not replace).
98 ///
99 /// ```ignore
100 /// Live::builder()
101 /// .phase_defaults(|p| {
102 /// p.with_state(&["emotional_state", "risk_level"])
103 /// .when(risk_is_elevated, "Show extra empathy.")
104 /// .prompt_on_enter(true)
105 /// })
106 /// .phase("greet").instruction("...").done()
107 /// .phase("close").instruction("...").done()
108 /// // Both phases inherit the modifiers and prompt_on_enter.
109 /// ```
110 pub fn phase_defaults(mut self, f: impl FnOnce(PhaseDefaults) -> PhaseDefaults) -> Self {
111 let defaults = f(PhaseDefaults::new());
112 self.phase_default_modifiers = defaults.modifiers;
113 self.phase_default_prompt_on_enter = defaults.prompt_on_enter;
114 self
115 }
116
117 /// Start building a conversation phase.
118 ///
119 /// Returns a [`PhaseBuilder`] that flows back to this `Live` via `.done()`.
120 pub fn phase(self, name: impl Into<String>) -> PhaseBuilder {
121 PhaseBuilder::new(self, name)
122 }
123
124 /// Set the initial phase name (must match a registered phase).
125 pub fn initial_phase(mut self, name: impl Into<String>) -> Self {
126 self.initial_phase = Some(name.into());
127 self
128 }
129
130 /// Internal method called by [`PhaseBuilder::done`].
131 pub(crate) fn add_phase(&mut self, phase: Phase) {
132 self.phases.push(phase);
133 }
134
135 // -- Watchers --
136
137 /// Start building a state watcher.
138 ///
139 /// Returns a [`WatchBuilder`] that flows back to this `Live` via `.then()`.
140 pub fn watch(self, key: impl Into<String>) -> WatchBuilder {
141 WatchBuilder::new(self, key)
142 }
143
144 /// Internal method called by [`WatchBuilder::then`].
145 pub(crate) fn add_watcher(&mut self, watcher: Watcher) {
146 self.watchers.add(watcher);
147 }
148
149 // -- Temporal Patterns --
150
151 /// Register a sustained condition pattern.
152 ///
153 /// Fires when the condition remains true for at least `duration`.
154 pub fn when_sustained<F, Fut>(
155 mut self,
156 name: impl Into<String>,
157 condition: impl Fn(&State) -> bool + Send + Sync + 'static,
158 duration: Duration,
159 action: F,
160 ) -> Self
161 where
162 F: Fn(State, Arc<dyn SessionWriter>) -> Fut + Send + Sync + 'static,
163 Fut: Future<Output = ()> + Send + 'static,
164 {
165 let detector = SustainedDetector::new(Arc::new(condition), duration);
166 self.temporal.add(TemporalPattern::new(
167 name,
168 Box::new(detector),
169 Arc::new(move |s, w| Box::pin(action(s, w))),
170 None,
171 ));
172 self
173 }
174
175 /// Register a rate detection pattern.
176 ///
177 /// Fires when at least `count` matching events occur within `window`.
178 pub fn when_rate<F, Fut>(
179 mut self,
180 name: impl Into<String>,
181 filter: impl Fn(&SessionEvent) -> bool + Send + Sync + 'static,
182 count: u32,
183 window: Duration,
184 action: F,
185 ) -> Self
186 where
187 F: Fn(State, Arc<dyn SessionWriter>) -> Fut + Send + Sync + 'static,
188 Fut: Future<Output = ()> + Send + 'static,
189 {
190 let detector = RateDetector::new(Arc::new(filter), count, window);
191 self.temporal.add(TemporalPattern::new(
192 name,
193 Box::new(detector),
194 Arc::new(move |s, w| Box::pin(action(s, w))),
195 None,
196 ));
197 self
198 }
199
200 /// Register a turn count pattern.
201 ///
202 /// Fires when the condition is true for `turn_count` consecutive turns.
203 pub fn when_turns<F, Fut>(
204 mut self,
205 name: impl Into<String>,
206 condition: impl Fn(&State) -> bool + Send + Sync + 'static,
207 turn_count: u32,
208 action: F,
209 ) -> Self
210 where
211 F: Fn(State, Arc<dyn SessionWriter>) -> Fut + Send + Sync + 'static,
212 Fut: Future<Output = ()> + Send + 'static,
213 {
214 let detector = TurnCountDetector::new(Arc::new(condition), turn_count);
215 self.temporal.add(TemporalPattern::new(
216 name,
217 Box::new(detector),
218 Arc::new(move |s, w| Box::pin(action(s, w))),
219 None,
220 ));
221 self
222 }
223}