gemini_adk_rs/live/
extractor.rs

//! Turn-windowed extraction — out-of-band (OOB) LLM structured data extraction
//! between turns.
//!
//! A `TurnExtractor` takes a window of recent transcript turns and produces a
//! structured JSON value via an out-of-band LLM call. By default it runs after
//! each turn completes; `ExtractionTrigger` selects other firing points.
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use serde_json::Value;
11
12use crate::llm::{BaseLlm, LlmError, LlmRequest};
13use crate::state::State;
14
15use super::phase::Phase;
16use super::transcript::TranscriptTurn;
17
/// Selects the event on which an extractor is run.
///
/// `EveryTurn` is the default and keeps backward-compatible behavior.
/// Pick `AfterToolCall` when tool calls are the primary source of state,
/// `Interval(n)` to extract less often, `OnPhaseChange` to extract only on
/// entry into a new conversation phase, or `OnGenerationComplete` to work
/// from the model's full output before interruption truncation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtractionTrigger {
    /// Fire on every `TurnComplete` event. This is the default.
    EveryTurn,
    /// Fire once every N `TurnComplete` events, N being the wrapped value.
    Interval(u32),
    /// Fire once tool calls have completed.
    AfterToolCall,
    /// Fire when a phase transition occurs.
    OnPhaseChange,
    /// Fire on `GenerationComplete`, i.e. before interruption truncation.
    ///
    /// This lets the extractor see the model's full intended output even
    /// when the user barged in and audio delivery was cut short.
    OnGenerationComplete,
}
40
/// Strategy for merging an extracted field into authoritative session state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergePolicy {
    /// Write only when the target key is absent; an existing value is kept.
    KeepKnown,
    /// Replace whatever is stored under the target key with the extracted value.
    Overwrite,
}
49
50/// Predicate used to decide whether an extracted field may be promoted.
51pub type PromotionPredicate = Arc<dyn Fn(&State, &Value) -> bool + Send + Sync>;
52
53/// Rule for promoting one raw extraction field into authoritative state.
54#[derive(Clone)]
55pub struct FieldPromotion {
56    /// Field name inside the extractor's JSON object.
57    pub field: String,
58    /// State key to write when the field is accepted.
59    pub state_key: String,
60    /// Merge behavior for the target state key.
61    pub merge: MergePolicy,
62    /// Optional acceptance predicate.
63    pub accept: Option<PromotionPredicate>,
64}
65
66impl FieldPromotion {
67    /// Promote `field` into the same state key using [`MergePolicy::KeepKnown`].
68    pub fn keep_known(field: impl Into<String>) -> Self {
69        let field = field.into();
70        Self {
71            state_key: field.clone(),
72            field,
73            merge: MergePolicy::KeepKnown,
74            accept: None,
75        }
76    }
77
78    /// Promote `field` into the same state key using [`MergePolicy::Overwrite`].
79    pub fn overwrite(field: impl Into<String>) -> Self {
80        let field = field.into();
81        Self {
82            state_key: field.clone(),
83            field,
84            merge: MergePolicy::Overwrite,
85            accept: None,
86        }
87    }
88
89    /// Promote a boolean field only when its extracted value is `true`.
90    pub fn true_only(field: impl Into<String>) -> Self {
91        Self::overwrite(field).accept_when(|_, value| value.as_bool() == Some(true))
92    }
93
94    /// Promote a string field only when its extracted value is non-empty.
95    pub fn non_empty(field: impl Into<String>) -> Self {
96        Self::overwrite(field).accept_when(|_, value| {
97            value.as_str().is_some_and(|s| !s.trim().is_empty())
98        })
99    }
100
101    /// Promote into a custom target state key.
102    pub fn to(mut self, state_key: impl Into<String>) -> Self {
103        self.state_key = state_key.into();
104        self
105    }
106
107    /// Only accept this promotion when `predicate` returns true.
108    ///
109    /// This is the escape hatch for application-specific logic:
110    /// `FieldPromotion::overwrite("intent").accept_when(|state, value| ...)`.
111    pub fn accept_when(
112        mut self,
113        predicate: impl Fn(&State, &Value) -> bool + Send + Sync + 'static,
114    ) -> Self {
115        self.accept = Some(Arc::new(predicate));
116        self
117    }
118
119    /// Add an additional acceptance predicate, preserving any existing predicate.
120    pub fn and_accept_when(
121        mut self,
122        predicate: impl Fn(&State, &Value) -> bool + Send + Sync + 'static,
123    ) -> Self {
124        let previous = self.accept.take();
125        self.accept = Some(Arc::new(move |state, value| {
126            previous
127                .as_ref()
128                .map_or(true, |accept| accept(state, value))
129                && predicate(state, value)
130        }));
131        self
132    }
133
134    /// Only promote after the named concept has been presented by a phase.
135    pub fn after_presented(self, concept: impl Into<String>) -> Self {
136        let concept = concept.into();
137        self.and_accept_when(move |state, _| Phase::is_presented(state, &concept))
138    }
139}
140
/// Strip markdown code fences from LLM output.
///
/// Handles `` ```json\n...\n``` ``, `` ```\n...\n``` ``, single-line fences
/// such as `` ```{...}``` `` or `` ```json {...}``` ``, and bare JSON.
///
/// Returns a sub-slice of `text`; nothing is allocated. Input that does not
/// start with a fence (after trimming) is returned trimmed and otherwise
/// untouched.
fn strip_code_fences(text: &str) -> &str {
    let trimmed = text.trim();
    let Some(after_fence) = trimmed.strip_prefix("```") else {
        return trimmed;
    };

    let body = match after_fence.split_once('\n') {
        // Multi-line fence: whatever remains on the opening line is the
        // optional language tag, so the payload starts on the next line.
        Some((_info, body)) => body,
        // Single-line fence: there is no opening line to discard. Dropping
        // "everything up to the first newline" here would swallow the payload
        // itself, so only remove a language tag that is clearly separated
        // from the payload by whitespace.
        None => {
            let untagged = after_fence.trim_start_matches(|c: char| c.is_ascii_alphanumeric());
            let had_tag = untagged.len() < after_fence.len();
            if had_tag && untagged.starts_with(char::is_whitespace) {
                untagged
            } else {
                after_fence
            }
        }
    };

    // Strip the closing fence if present, then any padding around the payload.
    let body = body.trim_end();
    body.strip_suffix("```").unwrap_or(body).trim()
}
157
158/// Trait for between-turn extraction from transcript windows.
159///
160/// Implementations receive a window of recent transcript turns and produce
161/// a structured JSON value. The processor stores the result in `State`
162/// under the extractor's name.
163#[async_trait]
164pub trait TurnExtractor: Send + Sync {
165    /// Name of this extractor (used as the State key).
166    fn name(&self) -> &str;
167
168    /// How many recent turns this extractor needs.
169    fn window_size(&self) -> usize;
170
171    /// Whether this extractor should run for the current turn.
172    ///
173    /// Override to skip extraction on trivial turns (e.g., short utterances,
174    /// turns without user speech). Default returns `true` (always extract).
175    ///
176    /// This is checked before launching the async extraction, so returning
177    /// `false` avoids an LLM round-trip entirely.
178    fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
179        let _ = window;
180        true
181    }
182
183    /// The trigger mode for this extractor.
184    ///
185    /// Controls when the extractor runs. Default is `EveryTurn`.
186    fn trigger(&self) -> ExtractionTrigger {
187        ExtractionTrigger::EveryTurn
188    }
189
190    /// Field promotion rules for this extractor.
191    ///
192    /// When empty, the runtime preserves legacy behavior and auto-flattens
193    /// top-level non-null fields into state. When non-empty, only these rules
194    /// can promote raw extraction fields into authoritative state.
195    fn promotion_rules(&self) -> &[FieldPromotion] {
196        &[]
197    }
198
199    /// Extract structured data from the transcript window.
200    async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError>;
201}
202
203/// LLM-backed turn extractor that sends transcript windows to an OOB LLM
204/// with a structured extraction prompt.
205pub struct LlmExtractor {
206    name: String,
207    llm: Arc<dyn BaseLlm>,
208    prompt: String,
209    window_size: usize,
210    schema: Option<Value>,
211    /// Pre-rendered schema string (computed once at construction)
212    schema_str: Option<String>,
213    /// Minimum word count in the last user utterance to trigger extraction.
214    min_words: usize,
215    /// When this extractor should fire.
216    trigger: ExtractionTrigger,
217    /// Field promotion rules. Empty means legacy auto-flattening.
218    promotion_rules: Vec<FieldPromotion>,
219}
220
221impl LlmExtractor {
222    /// Create a new LLM-backed extractor.
223    ///
224    /// - `name`: key for storing results in State
225    /// - `llm`: the out-of-band LLM to use for extraction
226    /// - `prompt`: system instruction describing what to extract
227    /// - `window_size`: how many recent turns to include
228    pub fn new(
229        name: impl Into<String>,
230        llm: Arc<dyn BaseLlm>,
231        prompt: impl Into<String>,
232        window_size: usize,
233    ) -> Self {
234        Self {
235            name: name.into(),
236            llm,
237            prompt: prompt.into(),
238            window_size,
239            schema: None,
240            schema_str: None,
241            min_words: 0,
242            trigger: ExtractionTrigger::EveryTurn,
243            promotion_rules: Vec::new(),
244        }
245    }
246
247    /// Set the minimum word count in the last user utterance to trigger extraction.
248    ///
249    /// Turns where the user said fewer than `n` words will skip the LLM call.
250    /// Useful for filtering out "uh huh", "ok", "yes" style responses.
251    pub fn with_min_words(mut self, n: usize) -> Self {
252        self.min_words = n;
253        self
254    }
255
256    /// Set a JSON Schema for structured output.
257    ///
258    /// When set, the schema is included in the prompt to guide the LLM
259    /// toward producing valid JSON matching the schema.
260    pub fn with_schema(mut self, schema: Value) -> Self {
261        self.schema_str = serde_json::to_string_pretty(&schema).ok();
262        self.schema = Some(schema);
263        self
264    }
265
266    /// Set the trigger mode for this extractor.
267    pub fn with_trigger(mut self, trigger: ExtractionTrigger) -> Self {
268        self.trigger = trigger;
269        self
270    }
271
272    /// Set explicit field promotion rules.
273    ///
274    /// Once promotion rules are present, top-level fields are no longer
275    /// automatically flattened into state; only accepted rules promote.
276    pub fn with_promotions(mut self, rules: Vec<FieldPromotion>) -> Self {
277        self.promotion_rules = rules;
278        self
279    }
280
281    /// Format transcript turns for the LLM prompt.
282    fn format_transcript(window: &[TranscriptTurn]) -> String {
283        let mut out = String::new();
284        for turn in window {
285            if !turn.user.is_empty() {
286                out.push_str("User: ");
287                out.push_str(turn.user.trim());
288                out.push('\n');
289            }
290            if !turn.model.is_empty() {
291                out.push_str("Assistant: ");
292                out.push_str(turn.model.trim());
293                out.push('\n');
294            }
295            out.push('\n');
296        }
297        out
298    }
299}
300
301#[async_trait]
302impl TurnExtractor for LlmExtractor {
303    fn name(&self) -> &str {
304        &self.name
305    }
306
307    fn window_size(&self) -> usize {
308        self.window_size
309    }
310
311    fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
312        if self.min_words == 0 {
313            return true;
314        }
315        // Check the last user utterance
316        window
317            .iter()
318            .rev()
319            .find(|t| !t.user.is_empty())
320            .is_some_and(|t| t.user.split_whitespace().count() >= self.min_words)
321    }
322
323    fn trigger(&self) -> ExtractionTrigger {
324        self.trigger.clone()
325    }
326
327    fn promotion_rules(&self) -> &[FieldPromotion] {
328        &self.promotion_rules
329    }
330
331    async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError> {
332        let transcript = Self::format_transcript(window);
333
334        let mut request = LlmRequest::from_text(format!(
335            "Transcript:\n{transcript}\nExtract the requested information."
336        ));
337        request.system_instruction = Some(self.prompt.clone());
338
339        // Use native JSON mode when a schema is available — the API constrains
340        // the model to produce valid JSON matching the schema, eliminating
341        // markdown fences and malformed output.
342        if let Some(ref schema) = self.schema {
343            request.response_mime_type = Some("application/json".to_string());
344            request.response_json_schema = Some(schema.clone());
345        } else {
346            request.response_mime_type = Some("application/json".to_string());
347        }
348
349        let response = self.llm.generate(request).await?;
350        let text = response.text();
351
352        // Fallback: strip markdown code fences if the model still wraps output
353        let cleaned = strip_code_fences(&text);
354
355        serde_json::from_str(cleaned).map_err(|e| {
356            LlmError::Other(format!(
357                "Failed to parse extraction result as JSON: {e}. Raw: {text}"
358            ))
359        })
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LlmResponse;
    use gemini_genai_rs::prelude::{Content, Part, Role};
    use std::time::Instant;

    /// Test double whose `generate` always replies with the same canned text.
    struct MockLlm {
        response: String,
    }

    #[async_trait]
    impl BaseLlm for MockLlm {
        fn model_id(&self) -> &str {
            "mock"
        }

        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            let content = Content {
                role: Some(Role::Model),
                parts: vec![Part::Text {
                    text: self.response.clone(),
                }],
            };
            Ok(LlmResponse {
                content,
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    /// Wraps a canned reply in a mock LLM ready to hand to `LlmExtractor::new`.
    fn canned_llm(reply: &str) -> Arc<dyn BaseLlm> {
        Arc::new(MockLlm {
            response: reply.to_string(),
        })
    }

    /// Builds transcript turns from `(user, model)` pairs, numbered from zero.
    fn make_turns(pairs: &[(&str, &str)]) -> Vec<TranscriptTurn> {
        let mut turns = Vec::with_capacity(pairs.len());
        for (index, &(user, model)) in pairs.iter().enumerate() {
            turns.push(TranscriptTurn {
                turn_number: index as u32,
                user: user.to_string(),
                model: model.to_string(),
                tool_calls: Vec::new(),
                timestamp: Instant::now(),
            });
        }
        turns
    }

    #[tokio::test]
    async fn llm_extractor_produces_json() {
        let extractor = LlmExtractor::new(
            "OrderState",
            canned_llm(r#"{"phase": "ordering", "items": ["pizza"]}"#),
            "Extract order state",
            3,
        );
        let turns = make_turns(&[
            ("I'd like a pizza", "Great! What size?"),
            ("Large please", "Coming right up!"),
        ]);

        let extracted = extractor.extract(&turns).await.unwrap();

        assert_eq!(extracted["phase"], "ordering");
        assert_eq!(extracted["items"][0], "pizza");
    }

    #[tokio::test]
    async fn llm_extractor_with_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "sentiment": {"type": "string", "enum": ["positive", "neutral", "negative"]},
                "score": {"type": "number"}
            }
        });
        let extractor = LlmExtractor::new(
            "Sentiment",
            canned_llm(r#"{"sentiment": "positive", "score": 0.9}"#),
            "Rate sentiment",
            1,
        )
        .with_schema(schema);
        let turns = make_turns(&[("This is great!", "Glad you think so!")]);

        let extracted = extractor.extract(&turns).await.unwrap();

        assert_eq!(extracted["sentiment"], "positive");
    }

    #[tokio::test]
    async fn llm_extractor_invalid_json_returns_error() {
        let extractor = LlmExtractor::new("Bad", canned_llm("not json at all"), "Extract", 1);
        let turns = make_turns(&[("hi", "hello")]);

        let outcome = extractor.extract(&turns).await;

        assert!(outcome.is_err());
    }

    #[test]
    fn format_transcript_readable() {
        let turns = make_turns(&[("Hello", "Hi there!"), ("How are you?", "I'm doing well")]);

        let formatted = LlmExtractor::format_transcript(&turns);

        let expected_lines = ["User: Hello", "Assistant: Hi there!", "User: How are you?"];
        for &line in &expected_lines {
            assert!(formatted.contains(line));
        }
    }

    #[tokio::test]
    async fn llm_extractor_handles_markdown_fenced_json() {
        let extractor = LlmExtractor::new(
            "Fenced",
            canned_llm("```json\n{\"status\": \"ok\"}\n```"),
            "Extract",
            1,
        );
        let turns = make_turns(&[("test", "reply")]);

        let extracted = extractor.extract(&turns).await.unwrap();

        assert_eq!(extracted["status"], "ok");
    }

    #[test]
    fn strip_code_fences_variants() {
        let cases = [
            ("```json\n{}\n```", "{}"),
            ("```\n{}\n```", "{}"),
            ("  ```json\n{\"a\":1}\n```  ", "{\"a\":1}"),
            ("{\"bare\":true}", "{\"bare\":true}"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(super::strip_code_fences(input), expected);
        }
    }

    #[test]
    fn extractor_name_and_window_size() {
        let ext = LlmExtractor::new("TestExtractor", canned_llm("{}"), "test", 5);

        assert_eq!(ext.name(), "TestExtractor");
        assert_eq!(ext.window_size(), 5);
    }

    #[test]
    fn extractor_default_trigger_is_every_turn() {
        let ext = LlmExtractor::new("Test", canned_llm("{}"), "test", 5);

        assert_eq!(ext.trigger(), ExtractionTrigger::EveryTurn);
    }

    #[test]
    fn extractor_with_trigger() {
        let ext = LlmExtractor::new("Test", canned_llm("{}"), "test", 5)
            .with_trigger(ExtractionTrigger::AfterToolCall);

        assert_eq!(ext.trigger(), ExtractionTrigger::AfterToolCall);
    }
}