gemini_adk_fluent_rs/live/
phases.rs

1//! Phase machine, instruction templating, computed state, watchers, and
2//! temporal pattern configuration methods for `Live`.
3
4use std::future::Future;
5use std::sync::Arc;
6use std::time::Duration;
7
8use serde_json::Value;
9
10use gemini_adk_rs::live::{
11    ComputedVar, Phase, RateDetector, SustainedDetector, TemporalPattern, TurnCountDetector,
12    Watcher,
13};
14use gemini_adk_rs::State;
15use gemini_genai_rs::prelude::*;
16use gemini_genai_rs::session::SessionWriter;
17
18use crate::live_builders::{PhaseBuilder, PhaseDefaults, WatchBuilder};
19
20use super::Live;
21
22impl Live {
23    /// State-reactive system instruction template.
24    ///
25    /// Called after extractors run on each turn. If it returns `Some(instruction)`,
26    /// the system instruction is updated mid-session (deduped — same instruction
27    /// is not sent twice). Returns `None` to leave the instruction unchanged.
28    ///
29    /// # Example
30    /// ```ignore
31    /// .instruction_template(|state| {
32    ///     let phase: String = state.get("phase").unwrap_or_default();
33    ///     match phase.as_str() {
34    ///         "ordering" => Some("Focus on taking the order accurately.".into()),
35    ///         "confirming" => Some("Summarize and confirm the order.".into()),
36    ///         _ => None,
37    ///     }
38    /// })
39    /// ```
40    pub fn instruction_template(
41        mut self,
42        f: impl Fn(&gemini_adk_rs::State) -> Option<String> + Send + Sync + 'static,
43    ) -> Self {
44        self.callbacks.instruction_template = Some(Arc::new(f));
45        self
46    }
47
48    /// State-reactive instruction amendment (additive, not replacement).
49    ///
50    /// Unlike `instruction_template` (which replaces the entire instruction),
51    /// this appends to the current phase instruction. The developer never needs
52    /// to know or repeat the base instruction.
53    ///
54    /// # Example
55    /// ```ignore
56    /// .instruction_amendment(|state| {
57    ///     let risk: String = state.get("derived:risk").unwrap_or_default();
58    ///     if risk == "high" {
59    ///         Some("[IMPORTANT: Use empathetic language. Do not threaten.]".into())
60    ///     } else {
61    ///         None
62    ///     }
63    /// })
64    /// ```
65    pub fn instruction_amendment(
66        mut self,
67        f: impl Fn(&gemini_adk_rs::State) -> Option<String> + Send + Sync + 'static,
68    ) -> Self {
69        self.callbacks.instruction_amendment = Some(Arc::new(f));
70        self
71    }
72
73    // -- Computed State --
74
75    /// Register a computed (derived) state variable.
76    ///
77    /// The compute function receives the full `State` and returns `Some(value)`
78    /// to write to `derived:{key}`, or `None` to skip.
79    pub fn computed(
80        mut self,
81        key: impl Into<String>,
82        deps: &[&str],
83        f: impl Fn(&State) -> Option<Value> + Send + Sync + 'static,
84    ) -> Self {
85        self.computed.register(ComputedVar {
86            key: key.into(),
87            dependencies: deps.iter().map(|s| s.to_string()).collect(),
88            compute: Arc::new(f),
89        });
90        self
91    }
92
93    // -- Phase Machine --
94
95    /// Set default modifiers and `prompt_on_enter` inherited by all phases.
96    ///
97    /// Phase-specific modifiers are applied *after* defaults, so they extend (not replace).
98    ///
99    /// ```ignore
100    /// Live::builder()
101    ///     .phase_defaults(|p| {
102    ///         p.with_state(&["emotional_state", "risk_level"])
103    ///          .when(risk_is_elevated, "Show extra empathy.")
104    ///          .prompt_on_enter(true)
105    ///     })
106    ///     .phase("greet").instruction("...").done()
107    ///     .phase("close").instruction("...").done()
108    ///     // Both phases inherit the modifiers and prompt_on_enter.
109    /// ```
110    pub fn phase_defaults(mut self, f: impl FnOnce(PhaseDefaults) -> PhaseDefaults) -> Self {
111        let defaults = f(PhaseDefaults::new());
112        self.phase_default_modifiers = defaults.modifiers;
113        self.phase_default_prompt_on_enter = defaults.prompt_on_enter;
114        self
115    }
116
117    /// Start building a conversation phase.
118    ///
119    /// Returns a [`PhaseBuilder`] that flows back to this `Live` via `.done()`.
120    pub fn phase(self, name: impl Into<String>) -> PhaseBuilder {
121        PhaseBuilder::new(self, name)
122    }
123
124    /// Set the initial phase name (must match a registered phase).
125    pub fn initial_phase(mut self, name: impl Into<String>) -> Self {
126        self.initial_phase = Some(name.into());
127        self
128    }
129
130    /// Internal method called by [`PhaseBuilder::done`].
131    pub(crate) fn add_phase(&mut self, phase: Phase) {
132        self.phases.push(phase);
133    }
134
135    // -- Watchers --
136
137    /// Start building a state watcher.
138    ///
139    /// Returns a [`WatchBuilder`] that flows back to this `Live` via `.then()`.
140    pub fn watch(self, key: impl Into<String>) -> WatchBuilder {
141        WatchBuilder::new(self, key)
142    }
143
144    /// Internal method called by [`WatchBuilder::then`].
145    pub(crate) fn add_watcher(&mut self, watcher: Watcher) {
146        self.watchers.add(watcher);
147    }
148
149    // -- Temporal Patterns --
150
151    /// Register a sustained condition pattern.
152    ///
153    /// Fires when the condition remains true for at least `duration`.
154    pub fn when_sustained<F, Fut>(
155        mut self,
156        name: impl Into<String>,
157        condition: impl Fn(&State) -> bool + Send + Sync + 'static,
158        duration: Duration,
159        action: F,
160    ) -> Self
161    where
162        F: Fn(State, Arc<dyn SessionWriter>) -> Fut + Send + Sync + 'static,
163        Fut: Future<Output = ()> + Send + 'static,
164    {
165        let detector = SustainedDetector::new(Arc::new(condition), duration);
166        self.temporal.add(TemporalPattern::new(
167            name,
168            Box::new(detector),
169            Arc::new(move |s, w| Box::pin(action(s, w))),
170            None,
171        ));
172        self
173    }
174
175    /// Register a rate detection pattern.
176    ///
177    /// Fires when at least `count` matching events occur within `window`.
178    pub fn when_rate<F, Fut>(
179        mut self,
180        name: impl Into<String>,
181        filter: impl Fn(&SessionEvent) -> bool + Send + Sync + 'static,
182        count: u32,
183        window: Duration,
184        action: F,
185    ) -> Self
186    where
187        F: Fn(State, Arc<dyn SessionWriter>) -> Fut + Send + Sync + 'static,
188        Fut: Future<Output = ()> + Send + 'static,
189    {
190        let detector = RateDetector::new(Arc::new(filter), count, window);
191        self.temporal.add(TemporalPattern::new(
192            name,
193            Box::new(detector),
194            Arc::new(move |s, w| Box::pin(action(s, w))),
195            None,
196        ));
197        self
198    }
199
200    /// Register a turn count pattern.
201    ///
202    /// Fires when the condition is true for `turn_count` consecutive turns.
203    pub fn when_turns<F, Fut>(
204        mut self,
205        name: impl Into<String>,
206        condition: impl Fn(&State) -> bool + Send + Sync + 'static,
207        turn_count: u32,
208        action: F,
209    ) -> Self
210    where
211        F: Fn(State, Arc<dyn SessionWriter>) -> Fut + Send + Sync + 'static,
212        Fut: Future<Output = ()> + Send + 'static,
213    {
214        let detector = TurnCountDetector::new(Arc::new(condition), turn_count);
215        self.temporal.add(TemporalPattern::new(
216            name,
217            Box::new(detector),
218            Arc::new(move |s, w| Box::pin(action(s, w))),
219            None,
220        ));
221        self
222    }
223}