gemini_adk_fluent_rs/live/extraction.rs

1//! Extraction pipeline configuration methods for `Live`.
2
3use std::future::Future;
4use std::sync::Arc;
5
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8
9use gemini_adk_rs::live::extractor::{ExtractionTrigger, LlmExtractor, TurnExtractor};
10use gemini_adk_rs::llm::BaseLlm;
11
12use super::Live;
13
14impl Live {
15    // -- Turn Extraction Pipeline --
16
17    /// Add a turn extractor that runs an OOB LLM after each turn to extract
18    /// structured data from the transcript window.
19    ///
20    /// Automatically enables both input and output transcription.
21    /// The extraction result is stored in `State` under the type name
22    /// (e.g., `"OrderState"`) and can be read via `handle.extracted::<T>(name)`.
23    ///
24    /// The type `T` must implement `JsonSchema` for schema-guided extraction.
25    /// The window size defaults to 3 turns.
26    pub fn extract_turns<T>(self, llm: Arc<dyn BaseLlm>, prompt: impl Into<String>) -> Self
27    where
28        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
29    {
30        self.extract_turns_windowed::<T>(llm, prompt, 3)
31    }
32
33    /// Like `extract_turns` but with a custom window size.
34    pub fn extract_turns_windowed<T>(
35        mut self,
36        llm: Arc<dyn BaseLlm>,
37        prompt: impl Into<String>,
38        window_size: usize,
39    ) -> Self
40    where
41        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
42    {
43        // Auto-enable transcription
44        self.config = self
45            .config
46            .enable_input_transcription()
47            .enable_output_transcription();
48
49        // Derive name from type
50        let name = std::any::type_name::<T>()
51            .rsplit("::")
52            .next()
53            .unwrap_or("Extraction")
54            .to_string();
55
56        // Generate JSON schema from the type
57        let root_schema = schemars::schema_for!(T);
58        let schema = serde_json::to_value(root_schema).unwrap_or(serde_json::Value::Null);
59
60        // Auto-register LLM for connection warming
61        self.warm_up_llms.push(llm.clone());
62
63        let extractor = LlmExtractor::new(name, llm, prompt, window_size)
64            .with_schema(schema)
65            .with_min_words(3);
66        self.extractors.push(Arc::new(extractor));
67        self
68    }
69
70    /// Like `extract_turns_windowed` but with a custom extraction trigger.
71    ///
72    /// Use `ExtractionTrigger::AfterToolCall` when tool calls are the primary
73    /// state source, `ExtractionTrigger::Interval(n)` to reduce extraction
74    /// frequency, or `ExtractionTrigger::OnPhaseChange` for phase-entry extraction.
75    pub fn extract_turns_triggered<T>(
76        mut self,
77        llm: Arc<dyn BaseLlm>,
78        prompt: impl Into<String>,
79        window_size: usize,
80        trigger: ExtractionTrigger,
81    ) -> Self
82    where
83        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
84    {
85        // Auto-enable transcription
86        self.config = self
87            .config
88            .enable_input_transcription()
89            .enable_output_transcription();
90
91        let name = std::any::type_name::<T>()
92            .rsplit("::")
93            .next()
94            .unwrap_or("Extraction")
95            .to_string();
96
97        let root_schema = schemars::schema_for!(T);
98        let schema = serde_json::to_value(root_schema).unwrap_or(serde_json::Value::Null);
99
100        self.warm_up_llms.push(llm.clone());
101
102        let extractor = LlmExtractor::new(name, llm, prompt, window_size)
103            .with_schema(schema)
104            .with_min_words(3)
105            .with_trigger(trigger);
106        self.extractors.push(Arc::new(extractor));
107        self
108    }
109
110    /// Like [`extract_turns_triggered`](Self::extract_turns_triggered), but lets
111    /// callers configure the underlying [`LlmExtractor`] before registration.
112    ///
113    /// Use this for field promotion rules, custom minimum word counts, or other
114    /// extraction policies that should live at the SDK layer instead of app
115    /// callback glue.
116    pub fn extract_turns_configured<T>(
117        mut self,
118        llm: Arc<dyn BaseLlm>,
119        prompt: impl Into<String>,
120        window_size: usize,
121        trigger: ExtractionTrigger,
122        configure: impl FnOnce(LlmExtractor) -> LlmExtractor,
123    ) -> Self
124    where
125        T: DeserializeOwned + Serialize + schemars::JsonSchema + Send + Sync + 'static,
126    {
127        self.config = self
128            .config
129            .enable_input_transcription()
130            .enable_output_transcription();
131
132        let name = std::any::type_name::<T>()
133            .rsplit("::")
134            .next()
135            .unwrap_or("Extraction")
136            .to_string();
137
138        let root_schema = schemars::schema_for!(T);
139        let schema = serde_json::to_value(root_schema).unwrap_or(serde_json::Value::Null);
140
141        self.warm_up_llms.push(llm.clone());
142
143        let extractor = LlmExtractor::new(name, llm, prompt, window_size)
144            .with_schema(schema)
145            .with_min_words(3)
146            .with_trigger(trigger);
147        self.extractors.push(Arc::new(configure(extractor)));
148        self
149    }
150
151    /// Add a custom `TurnExtractor` implementation.
152    pub fn extractor(mut self, extractor: Arc<dyn TurnExtractor>) -> Self {
153        // Auto-enable transcription
154        self.config = self
155            .config
156            .enable_input_transcription()
157            .enable_output_transcription();
158        self.extractors.push(extractor);
159        self
160    }
161
162    /// Called when a TurnExtractor produces a result.
163    ///
164    /// The callback receives the extractor name and the extracted JSON value.
165    pub fn on_extracted<F, Fut>(mut self, f: F) -> Self
166    where
167        F: Fn(String, serde_json::Value) -> Fut + Send + Sync + 'static,
168        Fut: Future<Output = ()> + Send + 'static,
169    {
170        self.callbacks.on_extracted = Some(Arc::new(move |name, value| Box::pin(f(name, value))));
171        self
172    }
173
174    /// Called when a TurnExtractor fails.
175    ///
176    /// The callback receives the extractor name and error message.
177    /// Use this for custom error handling (alerting, retry logic, etc.).
178    pub fn on_extraction_error<F, Fut>(mut self, f: F) -> Self
179    where
180        F: Fn(String, String) -> Fut + Send + Sync + 'static,
181        Fut: Future<Output = ()> + Send + 'static,
182    {
183        self.callbacks.on_extraction_error =
184            Some(Arc::new(move |name, error| Box::pin(f(name, error))));
185        self
186    }
187}