gemini_adk_fluent_rs/live/
mod.rs1mod callbacks;
28mod config;
29mod connect;
30mod extraction;
31mod phases;
32
33use std::collections::HashMap;
34use std::sync::Arc;
35use std::time::Duration;
36
37use gemini_adk_rs::live::extractor::TurnExtractor;
38use gemini_adk_rs::live::needs::RepairConfig;
39use gemini_adk_rs::live::persistence::SessionPersistence;
40use gemini_adk_rs::live::steering::{ContextDelivery, SteeringMode};
41use gemini_adk_rs::live::{
42 ComputedRegistry, EventCallbacks, InstructionModifier, Phase, TemporalRegistry,
43 ToolExecutionMode, WatcherRegistry,
44};
45use gemini_adk_rs::llm::BaseLlm;
46use gemini_adk_rs::tool::ToolDispatcher;
47use gemini_genai_rs::prelude::*;
48
49pub(crate) struct DeferredAgentTool {
51 pub(crate) name: String,
52 pub(crate) description: String,
53 pub(crate) agent: Arc<dyn gemini_adk_rs::text::TextAgent>,
54}
55
56pub struct Live {
99 pub(crate) config: SessionConfig,
100 pub(crate) callbacks: EventCallbacks,
101 pub(crate) dispatcher: Option<ToolDispatcher>,
102 pub(crate) extractors: Vec<Arc<dyn TurnExtractor>>,
103 pub(crate) computed: ComputedRegistry,
105 pub(crate) phases: Vec<Phase>,
106 pub(crate) initial_phase: Option<String>,
107 pub(crate) watchers: WatcherRegistry,
108 pub(crate) temporal: TemporalRegistry,
109 pub(crate) greeting: Option<String>,
110 pub(crate) phase_default_modifiers: Vec<InstructionModifier>,
112 pub(crate) phase_default_prompt_on_enter: bool,
113 pub(crate) tool_execution_modes: HashMap<String, ToolExecutionMode>,
115 pub(crate) deferred_agent_tools: Vec<DeferredAgentTool>,
117 pub(crate) warm_up_llms: Vec<Arc<dyn BaseLlm>>,
119 pub(crate) soft_turn_timeout: Option<Duration>,
121 pub(crate) steering_mode: SteeringMode,
122 pub(crate) context_delivery: ContextDelivery,
123 pub(crate) repair_config: Option<RepairConfig>,
124 pub(crate) persistence: Option<Arc<dyn SessionPersistence>>,
125 pub(crate) session_id: Option<String>,
126 pub(crate) tool_advisory: bool,
127 pub(crate) telemetry_interval: Option<Duration>,
128}
129
130impl Live {
131 pub fn builder() -> Self {
172 Self {
173 config: SessionConfig::from_endpoint(ApiEndpoint::google_ai("")),
174 callbacks: EventCallbacks::default(),
175 dispatcher: None,
176 extractors: Vec::new(),
177 computed: ComputedRegistry::new(),
178 phases: Vec::new(),
179 initial_phase: None,
180 watchers: WatcherRegistry::new(),
181 temporal: TemporalRegistry::new(),
182 greeting: None,
183 phase_default_modifiers: Vec::new(),
184 phase_default_prompt_on_enter: false,
185 tool_execution_modes: HashMap::new(),
186 deferred_agent_tools: Vec::new(),
187 warm_up_llms: Vec::new(),
188 soft_turn_timeout: None,
189 steering_mode: SteeringMode::default(),
190 context_delivery: ContextDelivery::default(),
191 repair_config: None,
192 persistence: None,
193 session_id: None,
194 tool_advisory: true,
195 telemetry_interval: None,
196 }
197 }
198
199 pub fn telemetry_interval(mut self, interval: Duration) -> Self {
204 self.telemetry_interval = Some(interval);
205 self
206 }
207}
208
// Compile-only smoke tests for the `Live` fluent builder. Each test just checks that
// a builder chain type-checks. None of them connects or runs a session; the result
// is bound to `_live` and dropped.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    // Basic session-config setters plus the top-level event callbacks.
    // NOTE(review): the closures mix sync (`|_data| {}`) and async (`|| async {}`)
    // forms, presumably matching what each `on_*` method expects.
    #[test]
    fn builder_chain_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .voice(Voice::Kore)
            .instruction("Test")
            .temperature(0.7)
            .google_search()
            .transcription(true, true)
            .affective_dialog(true)
            .session_resume(true)
            .context_compression(4000, 2000)
            .on_audio(|_data| {})
            .on_text(|_t| {})
            .on_vad_start(|| {})
            .on_interrupted(|| async {})
            .on_turn_complete(|| async {})
            .on_go_away(|_d| async {})
            .on_connected(|_writer| async {})
            .on_disconnected(|_r| async {})
            .on_error(|_e| async {});
    }

    // Turn extraction into a typed struct, together with tool-response,
    // turn-boundary and instruction-template hooks.
    #[test]
    fn builder_with_extraction_compiles() {
        use gemini_adk_rs::llm::{BaseLlm, LlmError, LlmRequest, LlmResponse};
        use schemars::JsonSchema;

        // Target type passed to `extract_turns::<OrderState>`.
        #[derive(serde::Deserialize, serde::Serialize, JsonSchema)]
        struct OrderState {
            phase: String,
            items: Vec<String>,
        }

        // Stub LLM that only has to satisfy `BaseLlm`. `generate` is never reached,
        // because this test never runs the session.
        struct FakeLlm;

        #[async_trait::async_trait]
        impl BaseLlm for FakeLlm {
            fn model_id(&self) -> &str {
                "fake"
            }
            async fn generate(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
                unimplemented!()
            }
        }

        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .instruction("Restaurant order assistant")
            .extract_turns::<OrderState>(
                Arc::new(FakeLlm),
                "Extract order state: items, quantities, phase",
            )
            .on_extracted(|name, value| async move {
                let _ = (name, value);
            })
            // Pass-through hook: hands the responses back unchanged.
            .before_tool_response(|responses, _state| async move { responses })
            .on_turn_boundary(|_state, _writer| async move {})
            // Yields `Some(..)` only for the "ordering" phase and `None` otherwise.
            // A missing "phase" key falls back to the empty string.
            .instruction_template(|state| {
                let phase: String = state.get("phase").unwrap_or_default();
                match phase.as_str() {
                    "ordering" => Some("Take orders accurately.".into()),
                    _ => None,
                }
            });
    }

    // Computed values. Each closure reads its dependency with `state.get`; the `?`
    // makes the closure return `None` whenever `state.get` returns `None`.
    #[test]
    fn builder_with_computed_state_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .instruction("Test computed state")
            .computed("doubled", &["app:count"], |state| {
                let count: i64 = state.get("app:count")?;
                Some(serde_json::json!(count * 2))
            })
            .computed("level", &["app:score"], |state| {
                let score: f64 = state.get("app:score")?;
                if score > 0.5 {
                    Some(serde_json::json!("high"))
                } else {
                    Some(serde_json::json!("low"))
                }
            });
    }

    // Three phases: greeting -> main -> farewell. Covers static and dynamic
    // instructions, a per-phase tool list, an `on_enter` hook and a terminal phase.
    // Each `.done()` returns to the top-level builder.
    #[test]
    fn builder_with_phases_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .phase("greeting")
            .instruction("Welcome the user warmly")
            .transition("main", |s| s.get::<bool>("greeted").unwrap_or(false))
            .on_enter(|state, _writer| async move {
                state.set("entered_greeting", true);
            })
            .done()
            .phase("main")
            .dynamic_instruction(|s| {
                let topic: String = s.get("topic").unwrap_or_default();
                format!("Discuss {topic}")
            })
            .tools(vec!["search".into(), "lookup".into()])
            .transition("farewell", |s| s.get::<bool>("done").unwrap_or(false))
            .done()
            .phase("farewell")
            .instruction("Say goodbye")
            .terminal()
            .done()
            .initial_phase("greeting");
    }

    // "start" has an always-true transition to "secure". "secure" declares a
    // `guard` on the "verified" flag (missing key -> false) and an `on_exit` hook.
    #[test]
    fn builder_with_phase_guard_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .phase("start")
            .instruction("Begin")
            .transition("secure", |_| true)
            .done()
            .phase("secure")
            .instruction("Secure area")
            .guard(|s| s.get::<bool>("verified").unwrap_or(false))
            .on_exit(|state, _writer| async move {
                state.set("left_secure", true);
            })
            .terminal()
            .done()
            .initial_phase("start");
    }

    // Three watchers, one per trigger method: `crossed_above`, `changed_to` (also
    // marked `.blocking()`), and `became_true`. Each `.then(..)` returns to the
    // top-level builder.
    #[test]
    fn builder_with_watchers_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .watch("app:score")
            .crossed_above(0.9)
            .then(|_old, _new, state| async move {
                state.set("high_score_alert", true);
            })
            .watch("app:status")
            .changed_to(serde_json::json!("complete"))
            .blocking()
            .then(|_old, _new, _state| async move {})
            .watch("app:flag")
            .became_true()
            .then(|_old, _new, _state| async move {});
    }

    // One temporal pattern per builder method: `when_sustained`, `when_rate` and
    // `when_turns`.
    // NOTE(review): the pattern named "rapid_errors" actually matches
    // `SessionEvent::TextDelta` events. That is harmless for a compile test, but
    // the name is misleading.
    #[test]
    fn builder_with_temporal_patterns_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .when_sustained(
                "user_confused",
                |s| s.get::<bool>("confused").unwrap_or(false),
                Duration::from_secs(30),
                |_state, _writer| async move {},
            )
            .when_rate(
                "rapid_errors",
                |evt| matches!(evt, SessionEvent::TextDelta(_)),
                5,
                Duration::from_secs(10),
                |_state, _writer| async move {},
            )
            .when_turns(
                "stuck_in_loop",
                |s| s.get::<bool>("repeating").unwrap_or(false),
                3,
                |_state, _writer| async move {},
            );
    }

    // A single chain mixing config, computed state, phases, a watcher, a temporal
    // pattern and callbacks. It checks that these builder sections compose in
    // one expression.
    #[test]
    fn builder_full_l1_chain_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .voice(Voice::Kore)
            .instruction("Full featured agent")
            .computed("sentiment_level", &["app:sentiment_score"], |state| {
                let score: f64 = state.get("app:sentiment_score")?;
                if score > 0.7 {
                    Some(serde_json::json!("positive"))
                } else if score < 0.3 {
                    Some(serde_json::json!("negative"))
                } else {
                    Some(serde_json::json!("neutral"))
                }
            })
            .phase("greeting")
            .instruction("Greet the user")
            .transition("help", |s| s.get::<bool>("needs_help").unwrap_or(false))
            .done()
            .phase("help")
            .instruction("Help the user")
            .terminal()
            .done()
            .initial_phase("greeting")
            .watch("app:sentiment_score")
            .crossed_below(0.2)
            .then(|_old, _new, state| async move {
                state.set("alert:low_sentiment", true);
            })
            .when_turns(
                "repeated_confusion",
                |s| s.get::<bool>("confused").unwrap_or(false),
                3,
                |_state, _writer| async move {},
            )
            .on_audio(|_data| {})
            .on_text(|_t| {})
            .on_turn_complete(|| async {});
    }

    // Every `*_concurrent` callback variant exercised here.
    #[test]
    fn builder_with_callback_modes_compiles() {
        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .on_turn_complete_concurrent(|| async {})
            .on_error_concurrent(|_e| async {})
            .on_extracted_concurrent(|_name, _val| async {})
            .on_extraction_error_concurrent(|_name, _err| async {})
            .on_connected_concurrent(|_w| async {})
            .on_disconnected_concurrent(|_r| async {})
            .on_go_away_concurrent(|_d| async {});
    }

    // Background tool registration, with and without an explicit result formatter.
    #[test]
    fn builder_with_background_tools_compiles() {
        use gemini_adk_rs::live::DefaultResultFormatter;

        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .tool_background("search_kb")
            .tool_background_with_formatter("analyze_document", Arc::new(DefaultResultFormatter));
    }

    // Background tools combined with concurrent and regular callbacks in one chain.
    #[test]
    fn builder_mixed_callback_modes_and_bg_tools() {
        use gemini_adk_rs::live::DefaultResultFormatter;

        let _live = Live::builder()
            .model(GeminiModel::Gemini2_0FlashLive)
            .voice(Voice::Kore)
            .instruction("Full featured agent")
            .tool_background("slow_tool")
            .tool_background_with_formatter("kb_search", Arc::new(DefaultResultFormatter))
            .on_turn_complete_concurrent(|| async {})
            .on_extracted_concurrent(|_name, _val| async {})
            .on_audio(|_data| {})
            .on_text(|_t| {})
            .on_interrupted(|| async {});
    }
}