// gemini_adk_rs/live/persistence.rs
use std::collections::HashMap;
use std::path::{Component, Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;

14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SessionSnapshot {
17 pub state: HashMap<String, Value>,
19 pub phase: String,
21 pub turn_count: u32,
23 pub transcript_summary: String,
25 pub resume_handle: Option<String>,
27 pub saved_at: String,
29}
30
31#[async_trait]
35pub trait SessionPersistence: Send + Sync {
36 async fn save(
38 &self,
39 session_id: &str,
40 snapshot: &SessionSnapshot,
41 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
42
43 async fn load(
45 &self,
46 session_id: &str,
47 ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>>;
48
49 async fn delete(
51 &self,
52 session_id: &str,
53 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
54}
55
56pub struct FsPersistence {
58 dir: PathBuf,
59}
60
61impl FsPersistence {
62 pub fn new(dir: impl Into<PathBuf>) -> Self {
66 Self { dir: dir.into() }
67 }
68
69 fn path(&self, session_id: &str) -> PathBuf {
70 self.dir.join(format!("{}.json", session_id))
71 }
72}
73
74#[async_trait]
75impl SessionPersistence for FsPersistence {
76 async fn save(
77 &self,
78 session_id: &str,
79 snapshot: &SessionSnapshot,
80 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
81 tokio::fs::create_dir_all(&self.dir).await?;
82 let json = serde_json::to_string_pretty(snapshot)?;
83 tokio::fs::write(self.path(session_id), json).await?;
84 Ok(())
85 }
86
87 async fn load(
88 &self,
89 session_id: &str,
90 ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
91 let path = self.path(session_id);
92 match tokio::fs::read_to_string(&path).await {
93 Ok(json) => Ok(Some(serde_json::from_str(&json)?)),
94 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
95 Err(e) => Err(e.into()),
96 }
97 }
98
99 async fn delete(
100 &self,
101 session_id: &str,
102 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
103 let path = self.path(session_id);
104 match tokio::fs::remove_file(&path).await {
105 Ok(()) => Ok(()),
106 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
107 Err(e) => Err(e.into()),
108 }
109 }
110}
111
112pub struct MemoryPersistence {
114 store: std::sync::Arc<dashmap::DashMap<String, SessionSnapshot>>,
115}
116
117impl MemoryPersistence {
118 pub fn new() -> Self {
120 Self {
121 store: std::sync::Arc::new(dashmap::DashMap::new()),
122 }
123 }
124}
125
126impl Default for MemoryPersistence {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132#[async_trait]
133impl SessionPersistence for MemoryPersistence {
134 async fn save(
135 &self,
136 session_id: &str,
137 snapshot: &SessionSnapshot,
138 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
139 self.store.insert(session_id.to_string(), snapshot.clone());
140 Ok(())
141 }
142
143 async fn load(
144 &self,
145 session_id: &str,
146 ) -> Result<Option<SessionSnapshot>, Box<dyn std::error::Error + Send + Sync>> {
147 Ok(self.store.get(session_id).map(|v| v.value().clone()))
148 }
149
150 async fn delete(
151 &self,
152 session_id: &str,
153 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
154 self.store.remove(session_id);
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    const SAVED_AT: &str = "2026-03-07T00:00:00Z";

    /// Snapshot with no state or resume handle, for tests that only care about
    /// which snapshot is present.
    fn bare_snapshot(phase: &str) -> SessionSnapshot {
        SessionSnapshot {
            state: HashMap::new(),
            phase: phase.into(),
            turn_count: 0,
            transcript_summary: String::new(),
            resume_handle: None,
            saved_at: SAVED_AT.into(),
        }
    }

    /// Per-process scratch directory, so concurrent test runs on one machine
    /// (e.g. several checkouts on a CI host) do not clobber each other.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("gemini_rs_test_{}_{}", tag, std::process::id()))
    }

    #[tokio::test]
    async fn memory_persistence_round_trip() {
        let p = MemoryPersistence::new();
        let snapshot = SessionSnapshot {
            state: [("name".into(), Value::String("Alice".into()))]
                .into_iter()
                .collect(),
            phase: "greeting".into(),
            turn_count: 5,
            transcript_summary: "User: Hello\nAssistant: Hi!".into(),
            resume_handle: Some("handle-123".into()),
            saved_at: SAVED_AT.into(),
        };

        p.save("session-1", &snapshot).await.unwrap();

        let loaded = p.load("session-1").await.unwrap().unwrap();
        assert_eq!(loaded.state.get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(loaded.phase, "greeting");
        assert_eq!(loaded.turn_count, 5);
        assert_eq!(loaded.transcript_summary, "User: Hello\nAssistant: Hi!");
        assert_eq!(loaded.resume_handle, Some("handle-123".into()));
        assert_eq!(loaded.saved_at, SAVED_AT);
    }

    #[tokio::test]
    async fn memory_persistence_save_overwrites() {
        let p = MemoryPersistence::new();
        p.save("session-1", &bare_snapshot("first")).await.unwrap();
        p.save("session-1", &bare_snapshot("second")).await.unwrap();
        assert_eq!(p.load("session-1").await.unwrap().unwrap().phase, "second");
    }

    #[tokio::test]
    async fn memory_persistence_load_missing() {
        let p = MemoryPersistence::new();
        assert!(p.load("nonexistent").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn memory_persistence_delete() {
        let p = MemoryPersistence::new();

        p.save("session-1", &bare_snapshot("test")).await.unwrap();
        p.delete("session-1").await.unwrap();
        assert!(p.load("session-1").await.unwrap().is_none());

        // Deleting an id that is already gone is not an error.
        p.delete("session-1").await.unwrap();
    }

    #[tokio::test]
    async fn fs_persistence_round_trip() {
        let dir = scratch_dir("persistence");
        // Clear leftovers from an aborted earlier run of this same process id.
        let _ = tokio::fs::remove_dir_all(&dir).await;
        let p = FsPersistence::new(&dir);

        // Directory does not exist yet: a load is simply a miss.
        assert!(p.load("test-session").await.unwrap().is_none());

        let snapshot = SessionSnapshot {
            state: [("key".into(), Value::from(42))].into_iter().collect(),
            phase: "main".into(),
            turn_count: 3,
            transcript_summary: "test".into(),
            resume_handle: None,
            saved_at: SAVED_AT.into(),
        };

        p.save("test-session", &snapshot).await.unwrap();
        let loaded = p.load("test-session").await.unwrap().unwrap();
        assert_eq!(loaded.state.get("key"), Some(&Value::from(42)));
        assert_eq!(loaded.phase, "main");
        assert_eq!(loaded.turn_count, 3);
        assert_eq!(loaded.transcript_summary, "test");
        assert_eq!(loaded.resume_handle, None);
        assert_eq!(loaded.saved_at, SAVED_AT);

        p.save("test-session", &bare_snapshot("later")).await.unwrap();
        assert_eq!(p.load("test-session").await.unwrap().unwrap().phase, "later");

        p.delete("test-session").await.unwrap();
        assert!(p.load("test-session").await.unwrap().is_none());
        p.delete("test-session").await.unwrap();

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }
}