gemini_adk_rs/live/
persistence.rs

1//! Session persistence — survive process restarts.
2//!
//! The Gemini Live API supports session resumption via opaque handles.
//! This module persists the SDK's client-side state (State, current phase,
//! turn count, transcript summary, and the latest resume handle) so it can be
//! restored on reconnection.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14/// Serializable snapshot of the control plane state.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SessionSnapshot {
17    /// All state key-value pairs.
18    pub state: HashMap<String, Value>,
19    /// Current phase name.
20    pub phase: String,
21    /// Turn count at time of snapshot.
22    pub turn_count: u32,
23    /// Human-readable summary of recent transcript.
24    pub transcript_summary: String,
25    /// Resume handle from the Gemini server.
26    pub resume_handle: Option<String>,
27    /// ISO 8601 timestamp.
28    pub saved_at: String,
29}
30
/// Trait for persisting session state across process restarts.
///
/// Implementations might write to the filesystem, Redis, Firestore, etc.
/// The `Send + Sync` bound lets one backend instance be shared across threads
/// and tasks. `#[async_trait]` boxes the returned futures, so the trait can
/// also be used as a trait object.
#[async_trait]
pub trait SessionPersistence: Send + Sync {
    /// Save a session snapshot under `session_id`.
    ///
    /// Both backends in this module replace any snapshot previously saved
    /// under the same id.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to store the snapshot (for
    /// example an I/O or serialization failure).
    async fn save(
        &self,
        session_id: &str,
        snapshot: &SessionSnapshot,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;

    /// Load a previously saved session snapshot.
    ///
    /// Returns `Ok(None)` when nothing has been saved for `session_id`. Both
    /// backends in this module report a missing session this way rather than
    /// as an error.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to read the snapshot or the
    /// stored data cannot be decoded.
    async fn load(
        &self,
        session_id: &str,
    ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>>;

    /// Delete a saved session.
    ///
    /// Both backends in this module treat deleting a session that does not
    /// exist (never saved, or already deleted) as success.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to remove an existing snapshot.
    async fn delete(
        &self,
        session_id: &str,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
}
55
56/// File-system persistence (good for development and single-server deployments).
57pub struct FsPersistence {
58    dir: PathBuf,
59}
60
61impl FsPersistence {
62    /// Create a new file-system persistence backend.
63    ///
64    /// The directory will be created if it doesn't exist.
65    pub fn new(dir: impl Into<PathBuf>) -> Self {
66        Self { dir: dir.into() }
67    }
68
69    fn path(&self, session_id: &str) -> PathBuf {
70        self.dir.join(format!("{}.json", session_id))
71    }
72}
73
74#[async_trait]
75impl SessionPersistence for FsPersistence {
76    async fn save(
77        &self,
78        session_id: &str,
79        snapshot: &SessionSnapshot,
80    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
81        tokio::fs::create_dir_all(&self.dir).await?;
82        let json = serde_json::to_string_pretty(snapshot)?;
83        tokio::fs::write(self.path(session_id), json).await?;
84        Ok(())
85    }
86
87    async fn load(
88        &self,
89        session_id: &str,
90    ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
91        let path = self.path(session_id);
92        match tokio::fs::read_to_string(&path).await {
93            Ok(json) => Ok(Some(serde_json::from_str(&json)?)),
94            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
95            Err(e) => Err(e.into()),
96        }
97    }
98
99    async fn delete(
100        &self,
101        session_id: &str,
102    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
103        let path = self.path(session_id);
104        match tokio::fs::remove_file(&path).await {
105            Ok(()) => Ok(()),
106            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
107            Err(e) => Err(e.into()),
108        }
109    }
110}
111
112/// In-memory persistence (good for tests).
113pub struct MemoryPersistence {
114    store: std::sync::Arc<dashmap::DashMap<String, SessionSnapshot>>,
115}
116
117impl MemoryPersistence {
118    /// Create a new in-memory persistence backend.
119    pub fn new() -> Self {
120        Self {
121            store: std::sync::Arc::new(dashmap::DashMap::new()),
122        }
123    }
124}
125
126impl Default for MemoryPersistence {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132#[async_trait]
133impl SessionPersistence for MemoryPersistence {
134    async fn save(
135        &self,
136        session_id: &str,
137        snapshot: &SessionSnapshot,
138    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
139        self.store.insert(session_id.to_string(), snapshot.clone());
140        Ok(())
141    }
142
143    async fn load(
144        &self,
145        session_id: &str,
146    ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
147        Ok(self.store.get(session_id).map(|v| v.value().clone()))
148    }
149
150    async fn delete(
151        &self,
152        session_id: &str,
153    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
154        self.store.remove(session_id);
155        Ok(())
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal snapshot with an empty state map; tests override the fields they check.
    fn sample_snapshot(phase: &str, turn_count: u32) -> SessionSnapshot {
        SessionSnapshot {
            state: HashMap::new(),
            phase: phase.into(),
            turn_count,
            transcript_summary: String::new(),
            resume_handle: None,
            saved_at: "2026-03-07T00:00:00Z".into(),
        }
    }

    /// Scratch directory unique to this process and test, removed on drop.
    ///
    /// A fixed shared path would let concurrent test runs overwrite (and
    /// `remove_dir_all`) each other's files. Because cleanup lives in `Drop`,
    /// it also runs when an assertion panics partway through a test.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(test_name: &str) -> Self {
            Self(std::env::temp_dir().join(format!(
                "gemini_rs_test_persistence_{}_{}",
                std::process::id(),
                test_name
            )))
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Ignore errors: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn memory_persistence_round_trip() {
        let p = MemoryPersistence::new();
        let snapshot = SessionSnapshot {
            state: [("name".into(), Value::String("Alice".into()))]
                .into_iter()
                .collect(),
            transcript_summary: "User: Hello\nAssistant: Hi!".into(),
            resume_handle: Some("handle-123".into()),
            ..sample_snapshot("greeting", 5)
        };

        p.save("session-1", &snapshot).await.unwrap();

        let loaded = p.load("session-1").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "greeting");
        assert_eq!(loaded.turn_count, 5);
        assert_eq!(loaded.resume_handle, Some("handle-123".into()));
        assert_eq!(
            loaded.state.get("name"),
            Some(&Value::String("Alice".into()))
        );
    }

    #[tokio::test]
    async fn memory_persistence_load_missing() {
        let p = MemoryPersistence::new();
        assert!(p.load("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_persistence_delete() {
        let p = MemoryPersistence::new();

        p.save("session-1", &sample_snapshot("test", 0)).await.unwrap();
        p.delete("session-1").await.unwrap();
        assert!(p.load("session-1").await.unwrap().is_none());

        // Deleting an already-deleted session must still succeed.
        p.delete("session-1").await.unwrap();
    }

    #[tokio::test]
    async fn fs_persistence_round_trip() {
        let dir = ScratchDir::new("round_trip");
        let p = FsPersistence::new(&dir.0);
        let snapshot = SessionSnapshot {
            state: [("key".into(), Value::from(42))].into_iter().collect(),
            transcript_summary: "test".into(),
            ..sample_snapshot("main", 3)
        };

        p.save("test-session", &snapshot).await.unwrap();
        let loaded = p.load("test-session").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "main");
        assert_eq!(loaded.turn_count, 3);
        assert_eq!(loaded.transcript_summary, "test");
        assert_eq!(loaded.state.get("key"), Some(&Value::from(42)));
        assert_eq!(loaded.resume_handle, None);

        p.delete("test-session").await.unwrap();
        assert!(p.load("test-session").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn fs_persistence_save_overwrites() {
        let dir = ScratchDir::new("overwrite");
        let p = FsPersistence::new(&dir.0);

        p.save("s", &sample_snapshot("first", 1)).await.unwrap();
        p.save("s", &sample_snapshot("second", 2)).await.unwrap();

        let loaded = p.load("s").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "second");
        assert_eq!(loaded.turn_count, 2);
    }

    #[tokio::test]
    async fn fs_persistence_missing_session() {
        let dir = ScratchDir::new("missing");
        let p = FsPersistence::new(&dir.0);

        // The directory does not exist yet: load reports None and delete is a no-op.
        assert!(p.load("nonexistent").await.unwrap().is_none());
        p.delete("nonexistent").await.unwrap();
    }
}