1use std::sync::Arc;
8
9use async_trait::async_trait;
10use serde_json::Value;
11
12use crate::llm::{BaseLlm, LlmError, LlmRequest};
13use crate::state::State;
14
15use super::phase::Phase;
16use super::transcript::TranscriptTurn;
17
/// When a [`TurnExtractor`] should run.
///
/// The scheduler that interprets these variants is not in this file; the
/// per-variant notes below describe the evident intent and should be confirmed
/// against that scheduler.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ExtractionTrigger {
    /// Run after every turn. This is the default returned by
    /// [`TurnExtractor::trigger`] and used by [`LlmExtractor::new`].
    EveryTurn,
    /// Presumably: run once every N turns (N is the payload) — confirm in the scheduler.
    Interval(u32),
    /// Presumably: run only after a turn that included a tool call — confirm.
    AfterToolCall,
    /// Presumably: run when the conversation phase changes — confirm.
    OnPhaseChange,
    /// Presumably: run when the model finishes generating a response — confirm.
    OnGenerationComplete,
}
40
/// How a promoted field is merged into state.
///
/// Merging happens outside this file; the variant notes below reflect the
/// names and should be confirmed against the code that applies promotions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MergePolicy {
    /// Presumably: keep an already-known state value and only fill it in when absent.
    KeepKnown,
    /// Presumably: always replace the state value with the newly extracted one.
    Overwrite,
}
49
50pub type PromotionPredicate = Arc<dyn Fn(&State, &Value) -> bool + Send + Sync>;
52
53#[derive(Clone)]
55pub struct FieldPromotion {
56 pub field: String,
58 pub state_key: String,
60 pub merge: MergePolicy,
62 pub accept: Option<PromotionPredicate>,
64}
65
66impl FieldPromotion {
67 pub fn keep_known(field: impl Into<String>) -> Self {
69 let field = field.into();
70 Self {
71 state_key: field.clone(),
72 field,
73 merge: MergePolicy::KeepKnown,
74 accept: None,
75 }
76 }
77
78 pub fn overwrite(field: impl Into<String>) -> Self {
80 let field = field.into();
81 Self {
82 state_key: field.clone(),
83 field,
84 merge: MergePolicy::Overwrite,
85 accept: None,
86 }
87 }
88
89 pub fn true_only(field: impl Into<String>) -> Self {
91 Self::overwrite(field).accept_when(|_, value| value.as_bool() == Some(true))
92 }
93
94 pub fn non_empty(field: impl Into<String>) -> Self {
96 Self::overwrite(field)
97 .accept_when(|_, value| value.as_str().is_some_and(|s| !s.trim().is_empty()))
98 }
99
100 pub fn to(mut self, state_key: impl Into<String>) -> Self {
102 self.state_key = state_key.into();
103 self
104 }
105
106 pub fn accept_when(
111 mut self,
112 predicate: impl Fn(&State, &Value) -> bool + Send + Sync + 'static,
113 ) -> Self {
114 self.accept = Some(Arc::new(predicate));
115 self
116 }
117
118 pub fn and_accept_when(
120 mut self,
121 predicate: impl Fn(&State, &Value) -> bool + Send + Sync + 'static,
122 ) -> Self {
123 let previous = self.accept.take();
124 self.accept = Some(Arc::new(move |state, value| {
125 previous.as_ref().is_none_or(|accept| accept(state, value)) && predicate(state, value)
126 }));
127 self
128 }
129
130 pub fn after_presented(self, concept: impl Into<String>) -> Self {
132 let concept = concept.into();
133 self.and_accept_when(move |state, _| Phase::is_presented(state, &concept))
134 }
135}
136
/// Returns the payload of a Markdown code fence, or the trimmed input if the
/// text does not start with a fence.
///
/// Handles:
/// - an optional info string on the opening line (e.g. ```` ```json ````),
/// - fences written on a single line (```` ```{"a":1}``` ````),
/// - trailing prose after the closing fence, which models sometimes append.
///
/// A missing closing fence is tolerated; everything after the opening line is
/// returned in that case.
fn strip_code_fences(text: &str) -> &str {
    let trimmed = text.trim();
    let Some(after_open) = trimmed.strip_prefix("```") else {
        return trimmed;
    };
    // Drop the info string on the opening fence line. When the whole block is
    // on one line there is no info line, so keep everything after the fence.
    let body = match after_open.split_once('\n') {
        Some((_info, body)) => body,
        None => after_open,
    };
    // Cut at the last closing fence so text appended after it is discarded.
    let body = match body.rfind("```") {
        Some(end) => &body[..end],
        None => body,
    };
    body.trim()
}
153
154#[async_trait]
160pub trait TurnExtractor: Send + Sync {
161 fn name(&self) -> &str;
163
164 fn window_size(&self) -> usize;
166
167 fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
175 let _ = window;
176 true
177 }
178
179 fn trigger(&self) -> ExtractionTrigger {
183 ExtractionTrigger::EveryTurn
184 }
185
186 fn promotion_rules(&self) -> &[FieldPromotion] {
192 &[]
193 }
194
195 async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError>;
197
198 async fn extract_with_state(
204 &self,
205 window: &[TranscriptTurn],
206 state: &State,
207 ) -> Result<Value, LlmError> {
208 let _ = state;
209 self.extract(window).await
210 }
211
212 fn on_complete(&self) -> Option<OnComplete> {
216 None
217 }
218}
219
220#[derive(Clone)]
222pub struct OnComplete {
223 pub agent: Arc<dyn crate::text::TextAgent>,
225 pub mode: crate::orchestration::Mode,
227}
228
229pub struct LlmExtractor {
232 name: String,
233 llm: Arc<dyn BaseLlm>,
234 prompt: String,
235 window_size: usize,
236 schema: Option<Value>,
237 schema_str: Option<String>,
239 min_words: usize,
241 trigger: ExtractionTrigger,
243 promotion_rules: Vec<FieldPromotion>,
245}
246
247impl LlmExtractor {
248 pub fn new(
255 name: impl Into<String>,
256 llm: Arc<dyn BaseLlm>,
257 prompt: impl Into<String>,
258 window_size: usize,
259 ) -> Self {
260 Self {
261 name: name.into(),
262 llm,
263 prompt: prompt.into(),
264 window_size,
265 schema: None,
266 schema_str: None,
267 min_words: 0,
268 trigger: ExtractionTrigger::EveryTurn,
269 promotion_rules: Vec::new(),
270 }
271 }
272
273 pub fn with_min_words(mut self, n: usize) -> Self {
278 self.min_words = n;
279 self
280 }
281
282 pub fn with_schema(mut self, schema: Value) -> Self {
287 self.schema_str = serde_json::to_string_pretty(&schema).ok();
288 self.schema = Some(schema);
289 self
290 }
291
292 pub fn with_trigger(mut self, trigger: ExtractionTrigger) -> Self {
294 self.trigger = trigger;
295 self
296 }
297
298 pub fn with_promotions(mut self, rules: Vec<FieldPromotion>) -> Self {
303 self.promotion_rules = rules;
304 self
305 }
306
307 fn format_transcript(window: &[TranscriptTurn]) -> String {
309 let mut out = String::new();
310 for turn in window {
311 if !turn.user.is_empty() {
312 out.push_str("User: ");
313 out.push_str(turn.user.trim());
314 out.push('\n');
315 }
316 if !turn.model.is_empty() {
317 out.push_str("Assistant: ");
318 out.push_str(turn.model.trim());
319 out.push('\n');
320 }
321 out.push('\n');
322 }
323 out
324 }
325}
326
327#[async_trait]
328impl TurnExtractor for LlmExtractor {
329 fn name(&self) -> &str {
330 &self.name
331 }
332
333 fn window_size(&self) -> usize {
334 self.window_size
335 }
336
337 fn should_extract(&self, window: &[TranscriptTurn]) -> bool {
338 if self.min_words == 0 {
339 return true;
340 }
341 window
343 .iter()
344 .rev()
345 .find(|t| !t.user.is_empty())
346 .is_some_and(|t| t.user.split_whitespace().count() >= self.min_words)
347 }
348
349 fn trigger(&self) -> ExtractionTrigger {
350 self.trigger.clone()
351 }
352
353 fn promotion_rules(&self) -> &[FieldPromotion] {
354 &self.promotion_rules
355 }
356
357 async fn extract(&self, window: &[TranscriptTurn]) -> Result<Value, LlmError> {
358 let transcript = Self::format_transcript(window);
359
360 let mut request = LlmRequest::from_text(format!(
361 "Transcript:\n{transcript}\nExtract the requested information."
362 ));
363 request.system_instruction = Some(self.prompt.clone());
364
365 if let Some(ref schema) = self.schema {
369 request.response_mime_type = Some("application/json".to_string());
370 request.response_json_schema = Some(schema.clone());
371 } else {
372 request.response_mime_type = Some("application/json".to_string());
373 }
374
375 let response = self.llm.generate(request).await?;
376 let text = response.text();
377
378 let cleaned = strip_code_fences(&text);
380
381 serde_json::from_str(cleaned).map_err(|e| {
382 LlmError::Other(format!(
383 "Failed to parse extraction result as JSON: {e}. Raw: {text}"
384 ))
385 })
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::LlmResponse;
    use gemini_genai_rs::prelude::{Content, Part, Role};
    use std::time::Instant;

    /// LLM stub that returns the same canned text for every request.
    struct MockLlm {
        response: String,
    }

    #[async_trait]
    impl BaseLlm for MockLlm {
        fn model_id(&self) -> &str {
            "mock"
        }

        async fn generate(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            let reply = Part::Text {
                text: self.response.clone(),
            };
            Ok(LlmResponse {
                content: Content {
                    role: Some(Role::Model),
                    parts: vec![reply],
                },
                finish_reason: Some("STOP".into()),
                usage: None,
            })
        }
    }

    /// Wraps a canned reply in a shareable mock LLM.
    fn mock_llm(response: &str) -> Arc<MockLlm> {
        Arc::new(MockLlm {
            response: response.to_owned(),
        })
    }

    /// Builds numbered transcript turns from (user, model) pairs.
    fn make_turns(pairs: &[(&str, &str)]) -> Vec<TranscriptTurn> {
        let mut turns = Vec::with_capacity(pairs.len());
        for (index, &(user, model)) in pairs.iter().enumerate() {
            turns.push(TranscriptTurn {
                turn_number: index as u32,
                user: user.to_owned(),
                model: model.to_owned(),
                tool_calls: Vec::new(),
                timestamp: Instant::now(),
            });
        }
        turns
    }

    #[tokio::test]
    async fn llm_extractor_produces_json() {
        let extractor = LlmExtractor::new(
            "OrderState",
            mock_llm(r#"{"phase": "ordering", "items": ["pizza"]}"#),
            "Extract order state",
            3,
        );
        let turns = make_turns(&[
            ("I'd like a pizza", "Great! What size?"),
            ("Large please", "Coming right up!"),
        ]);

        let result = extractor.extract(&turns).await.unwrap();
        assert_eq!(result["phase"], "ordering");
        assert_eq!(result["items"][0], "pizza");
    }

    #[tokio::test]
    async fn llm_extractor_with_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "sentiment": {"type": "string", "enum": ["positive", "neutral", "negative"]},
                "score": {"type": "number"}
            }
        });
        let extractor = LlmExtractor::new(
            "Sentiment",
            mock_llm(r#"{"sentiment": "positive", "score": 0.9}"#),
            "Rate sentiment",
            1,
        )
        .with_schema(schema);
        let turns = make_turns(&[("This is great!", "Glad you think so!")]);

        let result = extractor.extract(&turns).await.unwrap();
        assert_eq!(result["sentiment"], "positive");
    }

    #[tokio::test]
    async fn llm_extractor_invalid_json_returns_error() {
        let extractor = LlmExtractor::new("Bad", mock_llm("not json at all"), "Extract", 1);
        let turns = make_turns(&[("hi", "hello")]);

        assert!(extractor.extract(&turns).await.is_err());
    }

    #[test]
    fn format_transcript_readable() {
        let turns = make_turns(&[("Hello", "Hi there!"), ("How are you?", "I'm doing well")]);
        let formatted = LlmExtractor::format_transcript(&turns);

        for expected in ["User: Hello", "Assistant: Hi there!", "User: How are you?"] {
            assert!(formatted.contains(expected), "missing {expected:?} in {formatted:?}");
        }
    }

    #[tokio::test]
    async fn llm_extractor_handles_markdown_fenced_json() {
        let extractor = LlmExtractor::new(
            "Fenced",
            mock_llm("```json\n{\"status\": \"ok\"}\n```"),
            "Extract",
            1,
        );
        let turns = make_turns(&[("test", "reply")]);

        let result = extractor.extract(&turns).await.unwrap();
        assert_eq!(result["status"], "ok");
    }

    #[test]
    fn strip_code_fences_variants() {
        let cases = [
            ("```json\n{}\n```", "{}"),
            ("```\n{}\n```", "{}"),
            (" ```json\n{\"a\":1}\n``` ", "{\"a\":1}"),
            ("{\"bare\":true}", "{\"bare\":true}"),
        ];
        for (input, expected) in cases {
            assert_eq!(super::strip_code_fences(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn extractor_name_and_window_size() {
        let ext = LlmExtractor::new("TestExtractor", mock_llm("{}"), "test", 5);
        assert_eq!(ext.name(), "TestExtractor");
        assert_eq!(ext.window_size(), 5);
    }

    #[test]
    fn extractor_default_trigger_is_every_turn() {
        let ext = LlmExtractor::new("Test", mock_llm("{}"), "test", 5);
        assert_eq!(ext.trigger(), ExtractionTrigger::EveryTurn);
    }

    #[test]
    fn extractor_with_trigger() {
        let ext = LlmExtractor::new("Test", mock_llm("{}"), "test", 5)
            .with_trigger(ExtractionTrigger::AfterToolCall);
        assert_eq!(ext.trigger(), ExtractionTrigger::AfterToolCall);
    }
}