1use std::sync::Arc;
8
9use async_trait::async_trait;
10use serde_json::Value;
11
12use crate::llm::{BaseLlm, LlmError, LlmRequest};
13use crate::state::State;
14
15use super::phase::Phase;
16use super::transcript::TranscriptTurn;
17
/// Describes when an extractor wants to be run.
///
/// An extractor reports its trigger through [`TurnExtractor::trigger`]; the
/// code that interprets each variant lives outside this file, so the
/// per-variant notes below are assumptions drawn from the names.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExtractionTrigger {
    /// Default trigger (see [`TurnExtractor::trigger`]); presumably fires after every turn.
    EveryTurn,
    /// Carries a `u32`; presumably "run once every N turns" — TODO confirm with the scheduler.
    Interval(u32),
    /// Presumably fires once a tool call has completed — TODO confirm.
    AfterToolCall,
    /// Presumably fires when the conversation phase changes — TODO confirm.
    OnPhaseChange,
    /// Presumably fires when the model finishes generating — TODO confirm.
    OnGenerationComplete,
}
40
/// Policy stored on a [`FieldPromotion`] rule in its `merge` field.
///
/// The merge itself is performed outside this file; the variant notes are
/// therefore assumptions based on the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergePolicy {
    /// Presumably keeps a value that is already present in state — TODO confirm.
    KeepKnown,
    /// Presumably replaces whatever is currently stored — TODO confirm.
    Overwrite,
}
49
50pub type PromotionPredicate = Arc<dyn Fn(&State, &Value) -> bool + Send + Sync>;
52
53#[derive(Clone)]
55pub struct FieldPromotion {
56 pub field: String,
58 pub state_key: String,
60 pub merge: MergePolicy,
62 pub accept: Option<PromotionPredicate>,
64}
65
66impl FieldPromotion {
67 pub fn keep_known(field: impl Into<String>) -> Self {
69 let field = field.into();
70 Self {
71 state_key: field.clone(),
72 field,
73 merge: MergePolicy::KeepKnown,
74 accept: None,
75 }
76 }
77
78 pub fn overwrite(field: impl Into<String>) -> Self {
80 let field = field.into();
81 Self {
82 state_key: field.clone(),
83 field,
84 merge: MergePolicy::Overwrite,
85 accept: None,
86 }
87 }
88
89 pub fn true_only(field: impl Into<String>) -> Self {
91 Self::overwrite(field).accept_when(|_, value| value.as_bool() == Some(true))
92 }
93
94 pub fn non_empty(field: impl Into<String>) -> Self {
96 Self::overwrite(field).accept_when(|_, value| {
97 value.as_str().is_some_and(|s| !s.trim().is_empty())
98 })
99 }
100
101 pub fn to(mut self, state_key: impl Into<String>) -> Self {
103 self.state_key = state_key.into();
104 self
105 }
106
107 pub fn accept_when(
112 mut self,
113 predicate: impl Fn(&State, &Value) -> bool + Send + Sync + 'static,
114 ) -> Self {
115 self.accept = Some(Arc::new(predicate));
116 self
117 }
118
119 pub fn and_accept_when(
121 mut self,
122 predicate: impl Fn(&State, &Value) -> bool + Send + Sync + 'static,
123 ) -> Self {
124 let previous = self.accept.take();
125 self.accept = Some(Arc::new(move |state, value| {
126 previous
127 .as_ref()
128 .map_or(true, |accept| accept(state, value))
129 && predicate(state, value)
130 }));
131 self
132 }
133
134 pub fn after_presented(self, concept: impl Into<String>) -> Self {
136 let concept = concept.into();
137 self.and_accept_when(move |state, _| Phase::is_presented(state, &concept))
138 }
139}
140
/// Removes a surrounding Markdown code fence from `text`, if there is one.
///
/// The input is trimmed first. If it does not start with three backticks it
/// is returned as-is (trimmed). Otherwise:
///
/// * when the opening fence is followed by a newline, the rest of that first
///   line (the info string, e.g. `json`) is dropped;
/// * when there is no newline at all (a single-line fence such as
///   "```{}```"), everything after the opening backticks is kept as the
///   payload instead of being discarded;
/// * a trailing closing fence is removed when present, and the result is
///   trimmed.
///
/// The returned slice always borrows from `text`; nothing is allocated.
fn strip_code_fences(text: &str) -> &str {
    let trimmed = text.trim();
    let Some(after_fence) = trimmed.strip_prefix("```") else {
        return trimmed;
    };

    // Only treat the first line as an info string when there really is a
    // first line; on a single-line fence the payload sits right after the
    // backticks and must not be thrown away.
    let payload = match after_fence.split_once('\n') {
        Some((_info_string, rest)) => rest,
        None => after_fence,
    };

    let payload = payload.trim_end();
    payload.strip_suffix("```").unwrap_or(payload).trim()
}
157
158#[async_trait]
164pub trait TurnExtractor: Send + Sync {
165 fn name(&self) -> &str;
167
168 fn window_size(&self) -> usize;
170
171 fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
179 let _ = window;
180 true
181 }
182
183 fn trigger(&self) -> ExtractionTrigger {
187 ExtractionTrigger::EveryTurn
188 }
189
190 fn promotion_rules(&self) -> &[FieldPromotion] {
196 &[]
197 }
198
199 async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError>;
201}
202
203pub struct LlmExtractor {
206 name: String,
207 llm: Arc<dyn BaseLlm>,
208 prompt: String,
209 window_size: usize,
210 schema: Option<Value>,
211 schema_str: Option<String>,
213 min_words: usize,
215 trigger: ExtractionTrigger,
217 promotion_rules: Vec<FieldPromotion>,
219}
220
221impl LlmExtractor {
222 pub fn new(
229 name: impl Into<String>,
230 llm: Arc<dyn BaseLlm>,
231 prompt: impl Into<String>,
232 window_size: usize,
233 ) -> Self {
234 Self {
235 name: name.into(),
236 llm,
237 prompt: prompt.into(),
238 window_size,
239 schema: None,
240 schema_str: None,
241 min_words: 0,
242 trigger: ExtractionTrigger::EveryTurn,
243 promotion_rules: Vec::new(),
244 }
245 }
246
247 pub fn with_min_words(mut self, n: usize) -> Self {
252 self.min_words = n;
253 self
254 }
255
256 pub fn with_schema(mut self, schema: Value) -> Self {
261 self.schema_str = serde_json::to_string_pretty(&schema).ok();
262 self.schema = Some(schema);
263 self
264 }
265
266 pub fn with_trigger(mut self, trigger: ExtractionTrigger) -> Self {
268 self.trigger = trigger;
269 self
270 }
271
272 pub fn with_promotions(mut self, rules: Vec<FieldPromotion>) -> Self {
277 self.promotion_rules = rules;
278 self
279 }
280
281 fn format_transcript(window: &[TranscriptTurn]) -> String {
283 let mut out = String::new();
284 for turn in window {
285 if !turn.user.is_empty() {
286 out.push_str("User: ");
287 out.push_str(turn.user.trim());
288 out.push('\n');
289 }
290 if !turn.model.is_empty() {
291 out.push_str("Assistant: ");
292 out.push_str(turn.model.trim());
293 out.push('\n');
294 }
295 out.push('\n');
296 }
297 out
298 }
299}
300
301#[async_trait]
302impl TurnExtractor for LlmExtractor {
303 fn name(&self) -> &str {
304 &self.name
305 }
306
307 fn window_size(&self) -> usize {
308 self.window_size
309 }
310
311 fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
312 if self.min_words == 0 {
313 return true;
314 }
315 window
317 .iter()
318 .rev()
319 .find(|t| !t.user.is_empty())
320 .is_some_and(|t| t.user.split_whitespace().count() >= self.min_words)
321 }
322
323 fn trigger(&self) -> ExtractionTrigger {
324 self.trigger.clone()
325 }
326
327 fn promotion_rules(&self) -> &[FieldPromotion] {
328 &self.promotion_rules
329 }
330
331 async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError> {
332 let transcript = Self::format_transcript(window);
333
334 let mut request = LlmRequest::from_text(format!(
335 "Transcript:\n{transcript}\nExtract the requested information."
336 ));
337 request.system_instruction = Some(self.prompt.clone());
338
339 if let Some(ref schema) = self.schema {
343 request.response_mime_type = Some("application/json".to_string());
344 request.response_json_schema = Some(schema.clone());
345 } else {
346 request.response_mime_type = Some("application/json".to_string());
347 }
348
349 let response = self.llm.generate(request).await?;
350 let text = response.text();
351
352 let cleaned = strip_code_fences(&text);
354
355 serde_json::from_str(cleaned).map_err(|e| {
356 LlmError::Other(format!(
357 "Failed to parse extraction result as JSON: {e}. Raw: {text}"
358 ))
359 })
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LlmResponse;
    use gemini_genai_rs::prelude::{Content, Part, Role};
    use std::time::Instant;

    /// Test double that answers every request with the same canned text.
    struct MockLlm {
        response: String,
    }

    #[async_trait]
    impl BaseLlm for MockLlm {
        fn model_id(&self) -> &str {
            "mock"
        }

        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            let content = Content {
                role: Some(Role::Model),
                parts: vec![Part::Text {
                    text: self.response.clone(),
                }],
            };
            Ok(LlmResponse {
                content,
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    /// Builds a shared mock model that always replies with `response`.
    fn canned(response: &str) -> Arc<MockLlm> {
        Arc::new(MockLlm {
            response: response.to_string(),
        })
    }

    /// Builds numbered turns from `(user, model)` text pairs.
    fn make_turns(pairs: &[(&str, &str)]) -> Vec<TranscriptTurn> {
        let mut turns = Vec::with_capacity(pairs.len());
        for (index, &(user, model)) in pairs.iter().enumerate() {
            turns.push(TranscriptTurn {
                turn_number: index as u32,
                user: user.to_string(),
                model: model.to_string(),
                tool_calls: Vec::new(),
                timestamp: Instant::now(),
            });
        }
        turns
    }

    #[tokio::test]
    async fn llm_extractor_produces_json() {
        let extractor = LlmExtractor::new(
            "OrderState",
            canned(r#"{"phase": "ordering", "items": ["pizza"]}"#),
            "Extract order state",
            3,
        );
        let turns = make_turns(&[
            ("I'd like a pizza", "Great! What size?"),
            ("Large please", "Coming right up!"),
        ]);

        let result = extractor.extract(&turns).await.unwrap();

        assert_eq!(result["phase"], "ordering");
        assert_eq!(result["items"][0], "pizza");
    }

    #[tokio::test]
    async fn llm_extractor_with_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "sentiment": {"type": "string", "enum": ["positive", "neutral", "negative"]},
                "score": {"type": "number"}
            }
        });
        let extractor = LlmExtractor::new(
            "Sentiment",
            canned(r#"{"sentiment": "positive", "score": 0.9}"#),
            "Rate sentiment",
            1,
        )
        .with_schema(schema);
        let turns = make_turns(&[("This is great!", "Glad you think so!")]);

        let result = extractor.extract(&turns).await.unwrap();

        assert_eq!(result["sentiment"], "positive");
    }

    #[tokio::test]
    async fn llm_extractor_invalid_json_returns_error() {
        let extractor = LlmExtractor::new("Bad", canned("not json at all"), "Extract", 1);
        let turns = make_turns(&[("hi", "hello")]);

        let result = extractor.extract(&turns).await;

        assert!(result.is_err());
    }

    #[test]
    fn format_transcript_readable() {
        let turns = make_turns(&[("Hello", "Hi there!"), ("How are you?", "I'm doing well")]);

        let formatted = LlmExtractor::format_transcript(&turns);

        for expected in ["User: Hello", "Assistant: Hi there!", "User: How are you?"] {
            assert!(formatted.contains(expected));
        }
    }

    #[tokio::test]
    async fn llm_extractor_handles_markdown_fenced_json() {
        let extractor = LlmExtractor::new(
            "Fenced",
            canned("```json\n{\"status\": \"ok\"}\n```"),
            "Extract",
            1,
        );
        let turns = make_turns(&[("test", "reply")]);

        let result = extractor.extract(&turns).await.unwrap();

        assert_eq!(result["status"], "ok");
    }

    #[test]
    fn strip_code_fences_variants() {
        let cases = [
            ("```json\n{}\n```", "{}"),
            ("```\n{}\n```", "{}"),
            (" ```json\n{\"a\":1}\n``` ", "{\"a\":1}"),
            ("{\"bare\":true}", "{\"bare\":true}"),
        ];
        for (input, expected) in cases {
            assert_eq!(super::strip_code_fences(input), expected);
        }
    }

    #[test]
    fn extractor_name_and_window_size() {
        let ext = LlmExtractor::new("TestExtractor", canned("{}"), "test", 5);

        assert_eq!(ext.name(), "TestExtractor");
        assert_eq!(ext.window_size(), 5);
    }

    #[test]
    fn extractor_default_trigger_is_every_turn() {
        let ext = LlmExtractor::new("Test", canned("{}"), "test", 5);

        assert_eq!(ext.trigger(), ExtractionTrigger::EveryTurn);
    }

    #[test]
    fn extractor_with_trigger() {
        let ext = LlmExtractor::new("Test", canned("{}"), "test", 5)
            .with_trigger(ExtractionTrigger::AfterToolCall);

        assert_eq!(ext.trigger(), ExtractionTrigger::AfterToolCall);
    }
}