1use std::collections::HashMap;
8use std::path::PathBuf;
9
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SessionSnapshot {
17 pub state: HashMap<String, Value>,
19 pub phase: String,
21 pub turn_count: u32,
23 pub transcript_summary: String,
25 pub resume_handle: Option<String>,
27 pub saved_at: String,
29}
30
31#[async_trait]
35pub trait SessionPersistence: Send + Sync {
36 async fn save(
38 &self,
39 session_id: &str,
40 snapshot: &SessionSnapshot,
41 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
42
43 async fn load(
45 &self,
46 session_id: &str,
47 ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>>;
48
49 async fn delete(
51 &self,
52 session_id: &str,
53 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
54}
55
/// [`SessionPersistence`] backend that keeps one pretty-printed JSON file per
/// session inside a single directory.
///
/// The snapshot for session `abc` lives at `<dir>/abc.json`. Saves are staged
/// in a sibling temporary file and then renamed into place.
///
/// NOTE(review): `session_id` is spliced into a file name unchanged, so an id
/// containing path separators or `..` would resolve outside `dir`. Callers
/// presumably pass trusted ids — confirm before exposing this to client input.
pub struct FsPersistence {
    /// Directory that holds the `<session_id>.json` files.
    dir: PathBuf,
}

impl FsPersistence {
    /// Creates a backend rooted at `dir`.
    ///
    /// Nothing is touched on disk here; the directory is created by `save`.
    pub fn new(dir: impl Into<PathBuf>) -> Self {
        let dir: PathBuf = dir.into();
        Self { dir }
    }

    /// Final location of the snapshot for `session_id`: `<dir>/<session_id>.json`.
    fn path(&self, session_id: &str) -> PathBuf {
        let file_name = [session_id, ".json"].concat();
        self.dir.join(file_name)
    }

    /// Temporary staging path for `session_id`: `<dir>/<session_id>.json.tmp`.
    ///
    /// `save` writes to this path (or to a name derived from it) before
    /// renaming the result onto [`Self::path`].
    fn tmp_path(&self, session_id: &str) -> PathBuf {
        let file_name = [session_id, ".json.tmp"].concat();
        self.dir.join(file_name)
    }
}
77
78#[async_trait]
79impl SessionPersistence for FsPersistence {
80 async fn save(
81 &self,
82 session_id: &str,
83 snapshot: &SessionSnapshot,
84 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
85 tokio::fs::create_dir_all(&self.dir).await?;
86 let json = serde_json::to_string_pretty(snapshot)?;
87 let tmp = self.tmp_path(session_id);
93 tokio::fs::write(&tmp, json).await?;
94 tokio::fs::rename(&tmp, self.path(session_id)).await?;
95 Ok(())
96 }
97
98 async fn load(
99 &self,
100 session_id: &str,
101 ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
102 let path = self.path(session_id);
103 match tokio::fs::read_to_string(&path).await {
104 Ok(json) => Ok(Some(serde_json::from_str(&json)?)),
105 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
106 Err(e) => Err(e.into()),
107 }
108 }
109
110 async fn delete(
111 &self,
112 session_id: &str,
113 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
114 let path = self.path(session_id);
115 match tokio::fs::remove_file(&path).await {
116 Ok(()) => Ok(()),
117 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
118 Err(e) => Err(e.into()),
119 }
120 }
121}
122
123pub struct MemoryPersistence {
125 store: std::sync::Arc<dashmap::DashMap<String, SessionSnapshot>>,
126}
127
128impl MemoryPersistence {
129 pub fn new() -> Self {
131 Self {
132 store: std::sync::Arc::new(dashmap::DashMap::new()),
133 }
134 }
135}
136
137impl Default for MemoryPersistence {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143#[async_trait]
144impl SessionPersistence for MemoryPersistence {
145 async fn save(
146 &self,
147 session_id: &str,
148 snapshot: &SessionSnapshot,
149 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
150 self.store.insert(session_id.to_string(), snapshot.clone());
151 Ok(())
152 }
153
154 async fn load(
155 &self,
156 session_id: &str,
157 ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
158 Ok(self.store.get(session_id).map(|v| v.value().clone()))
159 }
160
161 async fn delete(
162 &self,
163 session_id: &str,
164 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
165 self.store.remove(session_id);
166 Ok(())
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Snapshot with empty state/summary and no resume handle; tests override
    /// only the fields they care about via struct-update syntax.
    fn sample_snapshot(phase: &str, turn_count: u32) -> SessionSnapshot {
        SessionSnapshot {
            state: HashMap::new(),
            phase: phase.into(),
            turn_count,
            transcript_summary: String::new(),
            resume_handle: None,
            saved_at: "2026-03-07T00:00:00Z".into(),
        }
    }

    /// Fresh directory path under the system temp dir. The random suffix keeps
    /// tests (and concurrently running test binaries) from sharing, and
    /// `remove_dir_all`-ing, each other's files.
    fn unique_temp_dir(label: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "gemini_rs_test_persistence_{}_{}",
            label,
            uuid::Uuid::new_v4()
        ))
    }

    #[tokio::test]
    async fn memory_persistence_round_trip() {
        let p = MemoryPersistence::new();
        let snapshot = SessionSnapshot {
            state: [("name".into(), Value::String("Alice".into()))]
                .into_iter()
                .collect(),
            transcript_summary: "User: Hello\nAssistant: Hi!".into(),
            resume_handle: Some("handle-123".into()),
            ..sample_snapshot("greeting", 5)
        };

        p.save("session-1", &snapshot).await.unwrap();

        let loaded = p.load("session-1").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "greeting");
        assert_eq!(loaded.turn_count, 5);
        assert_eq!(loaded.resume_handle, Some("handle-123".into()));
        assert_eq!(
            loaded.state.get("name"),
            Some(&Value::String("Alice".into()))
        );
        assert_eq!(loaded.transcript_summary, snapshot.transcript_summary);
        assert_eq!(loaded.saved_at, snapshot.saved_at);
    }

    #[tokio::test]
    async fn memory_persistence_load_missing() {
        let p = MemoryPersistence::new();
        assert!(p.load("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_persistence_save_overwrites() {
        let p = MemoryPersistence::new();
        p.save("session-1", &sample_snapshot("first", 1))
            .await
            .unwrap();
        p.save("session-1", &sample_snapshot("second", 2))
            .await
            .unwrap();

        let loaded = p.load("session-1").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "second");
        assert_eq!(loaded.turn_count, 2);
    }

    #[tokio::test]
    async fn memory_persistence_delete() {
        let p = MemoryPersistence::new();

        p.save("session-1", &sample_snapshot("test", 0))
            .await
            .unwrap();
        p.delete("session-1").await.unwrap();
        assert!(p.load("session-1").await.unwrap().is_none());

        // Deleting an id that is no longer stored is a no-op, not an error.
        p.delete("session-1").await.unwrap();
    }

    #[tokio::test]
    async fn fs_persistence_round_trip() {
        let dir = unique_temp_dir("round_trip");
        let p = FsPersistence::new(&dir);
        let snapshot = SessionSnapshot {
            state: [("key".into(), Value::from(42))].into_iter().collect(),
            transcript_summary: "test".into(),
            ..sample_snapshot("main", 3)
        };

        p.save("test-session", &snapshot).await.unwrap();
        let loaded = p.load("test-session").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "main");
        assert_eq!(loaded.turn_count, 3);
        assert_eq!(loaded.state.get("key"), Some(&Value::from(42)));
        assert_eq!(loaded.transcript_summary, "test");

        p.delete("test-session").await.unwrap();
        assert!(p.load("test-session").await.unwrap().is_none());
        // Deleting again is a no-op rather than an error.
        p.delete("test-session").await.unwrap();

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    #[tokio::test]
    async fn fs_persistence_load_missing() {
        let dir = unique_temp_dir("missing");
        let p = FsPersistence::new(&dir);

        // Neither the directory nor the file exists yet.
        assert!(p.load("nonexistent").await.unwrap().is_none());
        p.delete("nonexistent").await.unwrap();

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    #[tokio::test]
    async fn fs_persistence_save_overwrites() {
        let dir = unique_temp_dir("overwrite");
        let p = FsPersistence::new(&dir);

        p.save("session", &sample_snapshot("first", 1)).await.unwrap();
        p.save("session", &sample_snapshot("second", 2)).await.unwrap();

        let loaded = p.load("session").await.unwrap().unwrap();
        assert_eq!(loaded.phase, "second");
        assert_eq!(loaded.turn_count, 2);

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    #[tokio::test]
    async fn fs_persistence_save_is_atomic_and_leaves_no_tmp_file() {
        let dir = unique_temp_dir("atomic");
        let p = FsPersistence::new(&dir);
        let snapshot = sample_snapshot("main", 1);

        // Save twice so both the create and the replace paths are exercised.
        p.save("atomic-session", &snapshot).await.unwrap();
        p.save("atomic-session", &snapshot).await.unwrap();

        assert!(
            !p.tmp_path("atomic-session").exists(),
            "tmp file must be renamed away after save"
        );
        assert!(p.path("atomic-session").exists());

        // No staging file of any name may survive a successful save.
        let mut names = Vec::new();
        let mut entries = tokio::fs::read_dir(&dir).await.unwrap();
        while let Some(entry) = entries.next_entry().await.unwrap() {
            names.push(entry.file_name());
        }
        assert_eq!(
            names,
            vec![std::ffi::OsString::from("atomic-session.json")]
        );

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fs_persistence_concurrent_loads_never_observe_torn_snapshot() {
        let dir = unique_temp_dir("torn");
        let p = std::sync::Arc::new(FsPersistence::new(&dir));

        // A large payload makes a non-atomic write far more likely to be
        // observed half-finished.
        let big = "x".repeat(256 * 1024);
        let snapshot = SessionSnapshot {
            state: [("blob".to_string(), Value::String(big))]
                .into_iter()
                .collect(),
            ..sample_snapshot("main", 0)
        };
        p.save("torn", &snapshot).await.unwrap();

        let writer = {
            let p = p.clone();
            let snapshot = snapshot.clone();
            tokio::spawn(async move {
                for _ in 0..50 {
                    p.save("torn", &snapshot).await.unwrap();
                }
            })
        };

        for _ in 0..200 {
            let loaded = p
                .load("torn")
                .await
                .expect("load must never observe a torn snapshot");
            assert!(loaded.is_some(), "snapshot must always be present");
        }

        writer.await.unwrap();
        let _ = tokio::fs::remove_dir_all(&dir).await;
    }
}