gemini_adk_rs/memory/
mod.rs

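//! Memory service — session-scoped memory for agents.
//!
//! Mirrors ADK-JS's `BaseMemoryService`. Defines the [`MemoryService`] trait for
//! storing and searching key-value memory entries per session, together with an
//! in-memory default implementation ([`InMemoryMemoryService`]) and Vertex AI
//! services (`VertexAiMemoryBankService`, plus `VertexAiRagMemoryService` when the
//! `vertex-ai-rag` feature is enabled).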
5
6mod in_memory;
7mod vertex_ai_memory_bank;
8#[cfg(feature = "vertex-ai-rag")]
9mod vertex_ai_rag;
10
11pub use in_memory::InMemoryMemoryService;
12pub use vertex_ai_memory_bank::{VertexAiMemoryBankConfig, VertexAiMemoryBankService};
13#[cfg(feature = "vertex-ai-rag")]
14pub use vertex_ai_rag::{VertexAiRagMemoryConfig, VertexAiRagMemoryService};
15
16use async_trait::async_trait;
17use serde::{Deserialize, Serialize};
18
/// A memory entry — a named piece of information stored by an agent.
///
/// Entries are persisted per session through a [`MemoryService`] and are
/// (de)serializable with serde; `value` holds arbitrary JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// Key identifying this entry within its session (the `key` argument of
    /// [`MemoryService::get`] / [`MemoryService::delete`]).
    pub key: String,
    /// The stored value, as an arbitrary JSON value.
    pub value: serde_json::Value,
    /// When this entry was created (whole seconds since the Unix epoch).
    pub created_at: u64,
    /// When this entry was last updated (whole seconds since the Unix epoch).
    ///
    /// [`MemoryEntry::new`] initializes this to the same value as `created_at`.
    pub updated_at: u64,
}
31
32impl MemoryEntry {
33    /// Create a new memory entry.
34    pub fn new(key: impl Into<String>, value: serde_json::Value) -> Self {
35        let now = now_secs();
36        Self {
37            key: key.into(),
38            value,
39            created_at: now,
40            updated_at: now,
41        }
42    }
43}
44
/// Errors from memory service operations.
///
/// This is the error type returned by every [`MemoryService`] method.
#[derive(Debug, thiserror::Error)]
pub enum MemoryError {
    /// The requested memory key was not found.
    ///
    /// The payload is interpolated into the `Display` message and is expected to
    /// be the missing key.
    #[error("Memory key not found: {0}")]
    NotFound(String),
    /// A storage backend error.
    ///
    /// The payload is a backend-supplied description appended to the `Display`
    /// message.
    #[error("Storage error: {0}")]
    Storage(String),
}
55
/// Trait for session-scoped memory persistence.
///
/// Every operation takes a `session_id`. Entries stored under one session are
/// addressed independently of entries stored under any other session.
/// Implementations must be `Send + Sync` so a single service can be shared across
/// threads and tasks. The trait is object-safe and can be used as
/// `dyn MemoryService`.
///
/// # Errors
///
/// Every method returns [`MemoryError`] when the operation cannot be completed
/// by the backend.
#[async_trait]
pub trait MemoryService: Send + Sync {
    /// Store a memory entry for a session.
    ///
    /// NOTE(review): the behavior when `entry.key` already exists in the session
    /// (overwrite vs. error) is not specified by this trait — confirm per backend.
    async fn store(&self, session_id: &str, entry: MemoryEntry) -> Result<(), MemoryError>;

    /// Retrieve a memory entry by key.
    ///
    /// Returns `Ok(None)` when the session has no entry under `key`.
    async fn get(&self, session_id: &str, key: &str) -> Result<Option<MemoryEntry>, MemoryError>;

    /// List all memory entries for a session.
    ///
    /// No ordering of the returned entries is specified by this trait.
    async fn list(&self, session_id: &str) -> Result<Vec<MemoryEntry>, MemoryError>;

    /// Search a session's memory entries with a query string.
    ///
    /// Matching semantics are backend-defined; the in-memory implementation uses
    /// a simple substring match.
    async fn search(&self, session_id: &str, query: &str) -> Result<Vec<MemoryEntry>, MemoryError>;

    /// Delete a memory entry.
    ///
    /// NOTE(review): whether deleting a missing key is an error (e.g.
    /// [`MemoryError::NotFound`]) or a no-op is backend-defined — confirm.
    async fn delete(&self, session_id: &str, key: &str) -> Result<(), MemoryError>;

    /// Clear all memory for a session, leaving other sessions untouched.
    async fn clear(&self, session_id: &str) -> Result<(), MemoryError>;
}
79
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0` rather
/// than failing.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
86
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_entry_new() {
        let entry = MemoryEntry::new("topic", serde_json::json!("Rust"));
        assert_eq!(entry.key, "topic");
        assert_eq!(entry.value, serde_json::json!("Rust"));
        assert!(entry.created_at > 0);
        // A brand-new entry has never been modified, so both timestamps coincide.
        assert_eq!(entry.created_at, entry.updated_at);
    }

    #[test]
    fn memory_service_is_object_safe() {
        // Compiles only if `MemoryService` can be used as a trait object.
        fn _assert(_: &dyn MemoryService) {}
    }

    #[tokio::test]
    async fn store_and_get() {
        let svc = InMemoryMemoryService::new();
        let entry = MemoryEntry::new("topic", serde_json::json!("AI"));
        svc.store("s1", entry).await.unwrap();

        let fetched = svc
            .get("s1", "topic")
            .await
            .unwrap()
            .expect("entry stored under s1/topic should be retrievable");
        assert_eq!(fetched.key, "topic");
        assert_eq!(fetched.value, serde_json::json!("AI"));
    }

    #[tokio::test]
    async fn get_nonexistent_returns_none() {
        let svc = InMemoryMemoryService::new();
        let fetched = svc.get("s1", "missing").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn get_is_session_scoped() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("topic", serde_json::json!("AI")))
            .await
            .unwrap();

        // The same key must not be visible from a different session.
        let fetched = svc.get("s2", "topic").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn list_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.store("s2", MemoryEntry::new("c", serde_json::json!(3)))
            .await
            .unwrap();

        let entries = svc.list("s1").await.unwrap();
        // Listing order is not part of the trait contract, so compare sorted keys.
        let mut keys: Vec<&str> = entries.iter().map(|e| e.key.as_str()).collect();
        keys.sort_unstable();
        assert_eq!(keys, ["a", "b"]);

        let other = svc.list("s2").await.unwrap();
        assert_eq!(other.len(), 1);
        assert_eq!(other[0].key, "c");
    }

    #[tokio::test]
    async fn search_entries() {
        let svc = InMemoryMemoryService::new();
        svc.store(
            "s1",
            MemoryEntry::new("rust_topic", serde_json::json!("Rust programming")),
        )
        .await
        .unwrap();
        svc.store(
            "s1",
            MemoryEntry::new("python_topic", serde_json::json!("Python scripting")),
        )
        .await
        .unwrap();

        let results = svc.search("s1", "rust").await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].key, "rust_topic");
    }

    #[tokio::test]
    async fn delete_entry() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("k", serde_json::json!(1)))
            .await
            .unwrap();
        svc.delete("s1", "k").await.unwrap();
        let fetched = svc.get("s1", "k").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn clear_session() {
        let svc = InMemoryMemoryService::new();
        svc.store("s1", MemoryEntry::new("a", serde_json::json!(1)))
            .await
            .unwrap();
        svc.store("s1", MemoryEntry::new("b", serde_json::json!(2)))
            .await
            .unwrap();
        svc.store("s2", MemoryEntry::new("c", serde_json::json!(3)))
            .await
            .unwrap();

        svc.clear("s1").await.unwrap();

        let entries = svc.list("s1").await.unwrap();
        assert!(entries.is_empty());

        // Clearing one session must leave other sessions intact.
        let survivors = svc.list("s2").await.unwrap();
        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].key, "c");
    }
}