gemini_adk_fluent_rs/live/extraction.rs

1//! Extraction pipeline configuration methods for `Live`.
2
3use std::future::Future;
4use std::sync::Arc;
5
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8
9use gemini_adk_rs::live::extractor::{ExtractionTrigger, LlmExtractor, TurnExtractor};
10use gemini_adk_rs::llm::BaseLlm;
11
12use super::Live;
13
14impl Live {
15    // -- Turn Extraction Pipeline --
16
17    /// Add a turn extractor that runs an OOB LLM after each turn to extract
18    /// structured data from the transcript window.
19    ///
20    /// Automatically enables both input and output transcription.
21    /// The extraction result is stored in `State` under the type name
22    /// (e.g., `"OrderState"`) and can be read via `handle.extracted::<T>(name)`.
23    ///
24    /// The type `T` must implement `JsonSchema` for schema-guided extraction.
25    /// The window size defaults to 3 turns.
26    pub fn extract_turns<T>(self, llm: Arc<dyn BaseLlm>, prompt: impl Into<String>) -> Self
27    where
28        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
29    {
30        self.extract_turns_windowed::<T>(llm, prompt, 3)
31    }
32
33    /// Register a deterministic [`Extract`](gemini_adk_rs::extract::Extract)
34    /// record — CPU recognizers over the transcript, no model, no network. The
35    /// recognized fields are promoted into `State`, where `Flow` guards
36    /// (`done(captured([...]))`) and repair read them. Composes with
37    /// `extract_turns` (LLM) on the same session for a cheap-first cascade.
38    pub fn extract_record(mut self, spec: gemini_adk_rs::extract::Extract) -> Self {
39        self.config = self.config.enable_input_transcription();
40        self.extractors.push(spec.into_extractor());
41        self
42    }
43
44    /// Like `extract_turns` but with a custom window size.
45    pub fn extract_turns_windowed<T>(
46        mut self,
47        llm: Arc<dyn BaseLlm>,
48        prompt: impl Into<String>,
49        window_size: usize,
50    ) -> Self
51    where
52        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
53    {
54        // Auto-enable transcription
55        self.config = self
56            .config
57            .enable_input_transcription()
58            .enable_output_transcription();
59
60        // Derive name from type
61        let name = std::any::type_name::<T>()
62            .rsplit("::")
63            .next()
64            .unwrap_or("Extraction")
65            .to_string();
66
67        // Generate JSON schema from the type
68        let root_schema = schemars::schema_for!(T);
69        let schema = serde_json::to_value(root_schema).unwrap_or(serde_json::Value::Null);
70
71        // Auto-register LLM for connection warming
72        self.warm_up_llms.push(llm.clone());
73
74        let extractor = LlmExtractor::new(name, llm, prompt, window_size)
75            .with_schema(schema)
76            .with_min_words(3);
77        self.extractors.push(Arc::new(extractor));
78        self
79    }
80
81    /// Like `extract_turns_windowed` but with a custom extraction trigger.
82    ///
83    /// Use `ExtractionTrigger::AfterToolCall` when tool calls are the primary
84    /// state source, `ExtractionTrigger::Interval(n)` to reduce extraction
85    /// frequency, or `ExtractionTrigger::OnPhaseChange` for phase-entry extraction.
86    pub fn extract_turns_triggered<T>(
87        mut self,
88        llm: Arc<dyn BaseLlm>,
89        prompt: impl Into<String>,
90        window_size: usize,
91        trigger: ExtractionTrigger,
92    ) -> Self
93    where
94        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
95    {
96        // Auto-enable transcription
97        self.config = self
98            .config
99            .enable_input_transcription()
100            .enable_output_transcription();
101
102        let name = std::any::type_name::<T>()
103            .rsplit("::")
104            .next()
105            .unwrap_or("Extraction")
106            .to_string();
107
108        let root_schema = schemars::schema_for!(T);
109        let schema = serde_json::to_value(root_schema).unwrap_or(serde_json::Value::Null);
110
111        self.warm_up_llms.push(llm.clone());
112
113        let extractor = LlmExtractor::new(name, llm, prompt, window_size)
114            .with_schema(schema)
115            .with_min_words(3)
116            .with_trigger(trigger);
117        self.extractors.push(Arc::new(extractor));
118        self
119    }
120
121    /// Like [`extract_turns_triggered`](Self::extract_turns_triggered), but lets
122    /// callers configure the underlying [`LlmExtractor`] before registration.
123    ///
124    /// Use this for field promotion rules, custom minimum word counts, or other
125    /// extraction policies that should live at the SDK layer instead of app
126    /// callback glue.
127    pub fn extract_turns_configured<T>(
128        mut self,
129        llm: Arc<dyn BaseLlm>,
130        prompt: impl Into<String>,
131        window_size: usize,
132        trigger: ExtractionTrigger,
133        configure: impl FnOnce(LlmExtractor) -> LlmExtractor,
134    ) -> Self
135    where
136        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
137    {
138        self.config = self
139            .config
140            .enable_input_transcription()
141            .enable_output_transcription();
142
143        let name = std::any::type_name::<T>()
144            .rsplit("::")
145            .next()
146            .unwrap_or("Extraction")
147            .to_string();
148
149        let root_schema = schemars::schema_for!(T);
150        let schema = serde_json::to_value(root_schema).unwrap_or(serde_json::Value::Null);
151
152        self.warm_up_llms.push(llm.clone());
153
154        let extractor = LlmExtractor::new(name, llm, prompt, window_size)
155            .with_schema(schema)
156            .with_min_words(3)
157            .with_trigger(trigger);
158        self.extractors.push(Arc::new(configure(extractor)));
159        self
160    }
161
162    /// Add a custom `TurnExtractor` implementation.
163    pub fn extractor(mut self, extractor: Arc<dyn TurnExtractor>) -> Self {
164        // Auto-enable transcription
165        self.config = self
166            .config
167            .enable_input_transcription()
168            .enable_output_transcription();
169        self.extractors.push(extractor);
170        self
171    }
172
173    /// Called when a TurnExtractor produces a result.
174    ///
175    /// The callback receives the extractor name and the extracted JSON value.
176    pub fn on_extracted<F, Fut>(mut self, f: F) -> Self
177    where
178        F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
179        Fut: Future<Output = ()> + Send + 'static,
180    {
181        self.callbacks.on_extracted = Some(Arc::new(move |name, value| Box::pin(f(name, value))));
182        self
183    }
184
185    /// Called when a TurnExtractor fails.
186    ///
187    /// The callback receives the extractor name and error message.
188    /// Use this for custom error handling (alerting, retry logic, etc.).
189    pub fn on_extraction_error<F, Fut>(mut self, f: F) -> Self
190    where
191        F: Fn(String, String) -> Fut + Send + Sync + 'static,
192        Fut: Future<Output = ()> + Send + 'static,
193    {
194        self.callbacks.on_extraction_error =
195            Some(Arc::new(move |name, error| Box::pin(f(name, error))));
196        self
197    }
198}