gemini_adk_rs/live/
persistence.rs

//! Session persistence — survive process restarts.
//!
//! The Gemini Live API supports session resumption via opaque handles.
//! This module persists the SDK's client-side state (session state, current
//! phase, turn count, transcript summary, and the server's resume handle) so
//! it can be restored on reconnection.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14/// Serializable snapshot of the control plane state.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SessionSnapshot {
17    /// All state key-value pairs.
18    pub state: HashMap<String, Value>,
19    /// Current phase name.
20    pub phase: String,
21    /// Turn count at time of snapshot.
22    pub turn_count: u32,
23    /// Human-readable summary of recent transcript.
24    pub transcript_summary: String,
25    /// Resume handle from the Gemini server.
26    pub resume_handle: Option<String>,
27    /// ISO 8601 timestamp.
28    pub saved_at: String,
29}
30
31/// Trait for persisting session state across process restarts.
32///
33/// Implementations might write to the filesystem, Redis, Firestore, etc.
34#[async_trait]
35pub trait SessionPersistence: Send + Sync {
36    /// Save a session snapshot.
37    async fn save(
38        &self,
39        session_id: &str,
40        snapshot: &SessionSnapshot,
41    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
42
43    /// Load a previously saved session snapshot.
44    async fn load(
45        &self,
46        session_id: &str,
47    ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>>;
48
49    /// Delete a saved session.
50    async fn delete(
51        &self,
52        session_id: &str,
53    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
54}
55
56/// File-system persistence (good for development and single-server deployments).
57pub struct FsPersistence {
58    dir: PathBuf,
59}
60
61impl FsPersistence {
62    /// Create a new file-system persistence backend.
63    ///
64    /// The directory will be created if it doesn't exist.
65    pub fn new(dir: impl Into<PathBuf>) -> Self {
66        Self { dir: dir.into() }
67    }
68
69    fn path(&self, session_id: &str) -> PathBuf {
70        self.dir.join(format!("{}.json", session_id))
71    }
72
73    fn tmp_path(&self, session_id: &str) -> PathBuf {
74        self.dir.join(format!("{}.json.tmp", session_id))
75    }
76}
77
78#[async_trait]
79impl SessionPersistence for FsPersistence {
80    async fn save(
81        &self,
82        session_id: &str,
83        snapshot: &SessionSnapshot,
84    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
85        tokio::fs::create_dir_all(&self.dir).await?;
86        let json = serde_json::to_string_pretty(snapshot)?;
87        // Write to a sibling temp file, then atomically rename over the
88        // destination. `rename(2)` is atomic on the same filesystem, so a
89        // crash mid-write (or a concurrent `load`) only ever observes the
90        // previous complete snapshot or the new complete snapshot — never a
91        // torn half-write.
92        let tmp = self.tmp_path(session_id);
93        tokio::fs::write(&tmp, json).await?;
94        tokio::fs::rename(&tmp, self.path(session_id)).await?;
95        Ok(())
96    }
97
98    async fn load(
99        &self,
100        session_id: &str,
101    ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
102        let path = self.path(session_id);
103        match tokio::fs::read_to_string(&path).await {
104            Ok(json) => Ok(Some(serde_json::from_str(&json)?)),
105            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
106            Err(e) => Err(e.into()),
107        }
108    }
109
110    async fn delete(
111        &self,
112        session_id: &str,
113    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
114        let path = self.path(session_id);
115        match tokio::fs::remove_file(&path).await {
116            Ok(()) => Ok(()),
117            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
118            Err(e) => Err(e.into()),
119        }
120    }
121}
122
123/// In-memory persistence (good for tests).
124pub struct MemoryPersistence {
125    store: std::sync::Arc<dashmap::DashMap<String, SessionSnapshot>>,
126}
127
128impl MemoryPersistence {
129    /// Create a new in-memory persistence backend.
130    pub fn new() -> Self {
131        Self {
132            store: std::sync::Arc::new(dashmap::DashMap::new()),
133        }
134    }
135}
136
137impl Default for MemoryPersistence {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143#[async_trait]
144impl SessionPersistence for MemoryPersistence {
145    async fn save(
146        &self,
147        session_id: &str,
148        snapshot: &SessionSnapshot,
149    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
150        self.store.insert(session_id.to_string(), snapshot.clone());
151        Ok(())
152    }
153
154    async fn load(
155        &self,
156        session_id: &str,
157    ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
158        Ok(self.store.get(session_id).map(|v| v.value().clone()))
159    }
160
161    async fn delete(
162        &self,
163        session_id: &str,
164    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
165        self.store.remove(session_id);
166        Ok(())
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh directory under the system temp dir, unique per call. Parallel
    /// tests and concurrent `cargo test` runs therefore never share, or
    /// `remove_dir_all`, each other's files.
    fn unique_temp_dir(tag: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "gemini_rs_test_persistence_{}_{}",
            tag,
            uuid::Uuid::new_v4()
        ))
    }

    /// Minimal snapshot for tests where the payload itself doesn't matter.
    fn empty_snapshot(phase: &str) -> SessionSnapshot {
        SessionSnapshot {
            state: HashMap::new(),
            phase: phase.into(),
            turn_count: 0,
            transcript_summary: String::new(),
            resume_handle: None,
            saved_at: "2026-03-07T00:00:00Z".into(),
        }
    }

    #[tokio::test]
    async fn memory_persistence_round_trip() {
        let p = MemoryPersistence::new();
        let snapshot = SessionSnapshot {
            state: [("name".into(), Value::String("Alice".into()))]
                .into_iter()
                .collect(),
            phase: "greeting".into(),
            turn_count: 5,
            transcript_summary: "User: Hello\nAssistant: Hi!".into(),
            resume_handle: Some("handle-123".into()),
            saved_at: "2026-03-07T00:00:00Z".into(),
        };

        p.save("session-1", &snapshot).await.unwrap();

        let loaded = p.load("session-1").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "greeting");
        assert_eq!(loaded.turn_count, 5);
        assert_eq!(loaded.resume_handle, Some("handle-123".into()));
        assert_eq!(loaded.state.get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(loaded.transcript_summary, "User: Hello\nAssistant: Hi!");
        assert_eq!(loaded.saved_at, "2026-03-07T00:00:00Z");
    }

    #[tokio::test]
    async fn memory_persistence_load_missing() {
        let p = MemoryPersistence::new();
        assert!(p.load("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_persistence_delete() {
        let p = MemoryPersistence::new();
        p.save("session-1", &empty_snapshot("test")).await.unwrap();
        p.delete("session-1").await.unwrap();
        assert!(p.load("session-1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_persistence_save_overwrites() {
        let p = MemoryPersistence::new();
        let mut snapshot = empty_snapshot("first");
        p.save("session-1", &snapshot).await.unwrap();
        snapshot.phase = "second".into();
        p.save("session-1", &snapshot).await.unwrap();
        assert_eq!(p.load("session-1").await.unwrap().unwrap().phase, "second");
    }

    #[tokio::test]
    async fn fs_persistence_round_trip() {
        let dir = unique_temp_dir("round_trip");
        let p = FsPersistence::new(&dir);
        let snapshot = SessionSnapshot {
            state: [("key".into(), Value::from(42))].into_iter().collect(),
            phase: "main".into(),
            turn_count: 3,
            transcript_summary: "test".into(),
            resume_handle: None,
            saved_at: "2026-03-07T00:00:00Z".into(),
        };

        p.save("test-session", &snapshot).await.unwrap();
        let loaded = p.load("test-session").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "main");
        assert_eq!(loaded.turn_count, 3);
        assert_eq!(loaded.state.get("key"), Some(&Value::from(42)));
        assert_eq!(loaded.resume_handle, None);

        p.delete("test-session").await.unwrap();
        assert!(p.load("test-session").await.unwrap().is_none());

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    #[tokio::test]
    async fn fs_persistence_missing_session_is_none_and_delete_is_idempotent() {
        // The directory is never created: nothing is saved.
        let dir = unique_temp_dir("missing");
        let p = FsPersistence::new(&dir);

        assert!(p.load("never-saved").await.unwrap().is_none());
        p.delete("never-saved").await.unwrap();

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    #[tokio::test]
    async fn fs_persistence_save_is_atomic_and_leaves_no_tmp_file() {
        let dir = unique_temp_dir("atomic");
        let p = FsPersistence::new(&dir);

        p.save("atomic-session", &empty_snapshot("main"))
            .await
            .unwrap();

        // Whatever the temp file is called, only the final snapshot may remain.
        let mut names = Vec::new();
        let mut entries = tokio::fs::read_dir(&dir).await.unwrap();
        while let Some(entry) = entries.next_entry().await.unwrap() {
            names.push(entry.file_name().to_string_lossy().into_owned());
        }
        assert_eq!(
            names,
            vec!["atomic-session.json".to_string()],
            "tmp file must be renamed away after save"
        );

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fs_persistence_concurrent_loads_never_observe_torn_snapshot() {
        // Hammer save() while load()ing concurrently. With the tmp+rename
        // scheme every load parses; with a direct `fs::write` a reader could
        // observe a truncated/partial file mid-write.
        let dir = unique_temp_dir("torn");
        let p = std::sync::Arc::new(FsPersistence::new(&dir));

        // A snapshot large enough that a write is not a single tiny syscall.
        let big = "x".repeat(256 * 1024);
        let snapshot = SessionSnapshot {
            state: [("blob".to_string(), Value::String(big))]
                .into_iter()
                .collect(),
            phase: "main".into(),
            turn_count: 0,
            transcript_summary: String::new(),
            resume_handle: None,
            saved_at: "now".into(),
        };
        p.save("torn", &snapshot).await.unwrap();

        let writer = {
            let p = p.clone();
            let snapshot = snapshot.clone();
            tokio::spawn(async move {
                for _ in 0..50 {
                    p.save("torn", &snapshot).await.unwrap();
                }
            })
        };

        for _ in 0..200 {
            let loaded = p
                .load("torn")
                .await
                .expect("load must never observe a torn snapshot")
                .expect("snapshot must always be present");
            assert_eq!(loaded.state.len(), 1);
        }

        writer.await.unwrap();
        let _ = tokio::fs::remove_dir_all(&dir).await;
    }
}