gemini_adk_rs/memory/
mod.rs1mod in_memory;
7mod vertex_ai_memory_bank;
8#[cfg(feature = "vertex-ai-rag")]
9mod vertex_ai_rag;
10
11pub use in_memory::InMemoryMemoryService;
12pub use vertex_ai_memory_bank::{VertexAiMemoryBankConfig, VertexAiMemoryBankService};
13#[cfg(feature = "vertex-ai-rag")]
14pub use vertex_ai_rag::{VertexAiRagMemoryConfig, VertexAiRagMemoryService};
15
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct MemoryEntry {
22 pub key: String,
24 pub value: serde_json::Value,
26 pub created_at: u64,
28 pub updated_at: u64,
30}
31
32impl MemoryEntry {
33 pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
35 let now = now_secs();
36 Self {
37 key: key.into(),
38 value,
39 created_at: now,
40 updated_at: now,
41 }
42 }
43}
44
45#[derive(Debug, thiserror::Error)]
47pub enum MemoryError {
48 #[error("Memory key not found: {0}")]
50 NotFound(String),
51 #[error("Storage error: {0}")]
53 Storage(String),
54}
55
56#[async_trait]
60pub trait MemoryService: Send + Sync {
61 async fn store(&self, session_id: &str, entry: MemoryEntry) -> Result<(), MemoryError>;
63
64 async fn get(&self, session_id: &str, key: &str) -> Result<Option<MemoryEntry>, MemoryError>;
66
67 async fn list(&self, session_id: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
69
70 async fn search(&self, session_id: &str, query: &str) -> Result<Vec<MemoryEntry>, MemoryError>;
72
73 async fn delete(&self, session_id: &str, key: &str) -> Result<(), MemoryError>;
75
76 async fn clear(&self, session_id: &str) -> Result<(), MemoryError>;
78}
79
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this yields `0`
/// rather than failing.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_entry_new() {
        let entry = MemoryEntry::new("topic", serde_json::json!("Rust"));
        assert_eq!(entry.key, "topic");
        assert_eq!(entry.value, serde_json::json!("Rust"));
        assert!(entry.created_at > 0);
        // A brand-new entry has never been modified, so both stamps must agree.
        assert_eq!(entry.created_at, entry.updated_at);
    }

    #[test]
    fn memory_service_is_object_safe() {
        // Compiles only if `MemoryService` can be used as a trait object.
        fn _assert(_: &dyn MemoryService) {}
    }

    #[tokio::test]
    async fn store_and_get() {
        let svc = InMemoryMemoryService::new();
        let entry = MemoryEntry::new("topic", serde_json::json!("AI"));
        svc.store("s1", entry).await.unwrap();

        let fetched = svc
            .get("s1", "topic")
            .await
            .unwrap()
            .expect("stored entry should be retrievable");
        assert_eq!(fetched.key, "topic");
        assert_eq!(fetched.value, serde_json::json!("AI"));
    }

    #[tokio::test]
    async fn get_nonexistent_returns_none() {
        let svc = InMemoryMemoryService::new();
        let fetched = svc.get("s1", "missing").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn list_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.store("s2", MemoryEntry::new("c", serde_json::json!(3)))
            .await
            .unwrap();

        // `list` makes no ordering promise, so compare sorted keys. This also
        // proves the `s2` entry does not leak into `s1`.
        let mut keys: Vec<String> = svc
            .list("s1")
            .await
            .unwrap()
            .into_iter()
            .map(|entry| entry.key)
            .collect();
        keys.sort();
        assert_eq!(keys, ["a", "b"]);
    }

    #[tokio::test]
    async fn search_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store(
            "s1",
            MemoryEntry::new("rust_topic", serde_json::json!("Rust programming")),
        )
        .await
        .unwrap();
        svc.store(
            "s1",
            MemoryEntry::new("python_topic", serde_json::json!("Python scripting")),
        )
        .await
        .unwrap();

        let results = svc.search("s1", "rust").await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "rust_topic");
    }

    #[tokio::test]
    async fn delete_entry() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("k", serde_json::json!(1)))
            .await
            .unwrap();
        svc.delete("s1", "k").await.unwrap();
        let fetched = svc.get("s1", "k").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn clear_session() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.store("s2", MemoryEntry::new("c", serde_json::json!(3)))
            .await
            .unwrap();

        svc.clear("s1").await.unwrap();

        let entries = svc.list("s1").await.unwrap();
        assert!(entries.is_empty());
        // Clearing one session must leave other sessions untouched.
        let others = svc.list("s2").await.unwrap();
        assert_eq!(others.len(), 1);
        assert_eq!(others[0].key, "c");
    }
}