// gemini_adk_rs/tools/load_memory.rs

//! Load memory tool — allows agents to search their memory store.
//!
//! Mirrors ADK-Python's `load_memory_tool`, giving the model a tool that
//! searches session memory with a query string. The tool is *local*: it
//! delegates to whatever [`MemoryService`] has been wired into the tool via
//! `LoadMemoryTool::with_memory_service`, just as ADK's `load_memory` calls
//! `tool_context.search_memory(query)`. Without a wired service it returns a
//! placeholder response for the runtime to intercept.

8use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::error::ToolError;
13use crate::memory::MemoryService;
14use crate::tool::ToolFunction;
15
/// Search scope handed to the [`MemoryService`] whenever a query is delegated.
///
/// The default scope carries an empty session ID.
#[derive(Default, Clone, Debug)]
struct MemoryScope {
    /// Session whose memories are searched.
    session_id: String,
}
21
22/// Tool that searches the agent's memory store.
23///
24/// When the model needs to recall previously stored information, it can call
25/// this tool with a search query. If a [`MemoryService`] is wired via
26/// [`with_memory_service`](LoadMemoryTool::with_memory_service), the call is
27/// delegated to it; otherwise a placeholder response is returned (matching the
28/// "runtime intercepts the call" model).
29#[derive(Clone, Default)]
30pub struct LoadMemoryTool {
31    service: Option<Arc<dyn MemoryService>>,
32    scope: MemoryScope,
33}
34
35impl std::fmt::Debug for LoadMemoryTool {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("LoadMemoryTool")
38            .field("has_service", &self.service.is_some())
39            .field("session_id", &self.scope.session_id)
40            .finish()
41    }
42}
43
44impl LoadMemoryTool {
45    /// Create a new load memory tool with no memory service wired.
46    ///
47    /// Without a service, [`call`](ToolFunction::call) returns a placeholder
48    /// indicating the query was received (the runtime is expected to intercept).
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    /// Wire a [`MemoryService`] that this tool delegates searches to.
54    pub fn with_memory_service(mut self, service: Arc<dyn MemoryService>) -> Self {
55        self.service = Some(service);
56        self
57    }
58
59    /// Set the session ID used to scope memory searches.
60    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
61        self.scope.session_id = session_id.into();
62        self
63    }
64}
65
66#[async_trait]
67impl ToolFunction for LoadMemoryTool {
68    fn name(&self) -> &str {
69        "load_memory"
70    }
71
72    fn description(&self) -> &str {
73        "Search and load relevant information from the agent's memory. \
74         Call this function with a query to retrieve previously stored memories."
75    }
76
77    fn parameters(&self) -> Option<serde_json::Value> {
78        Some(serde_json::json!({
79            "type": "object",
80            "properties": {
81                "query": {
82                    "type": "string",
83                    "description": "The search query to find relevant memories."
84                }
85            },
86            "required": ["query"]
87        }))
88    }
89
90    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
91        let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
92
93        match &self.service {
94            // Delegate to the wired MemoryService, mirroring ADK's
95            // `tool_context.search_memory(query)`.
96            Some(service) => {
97                let memories = service
98                    .search(&self.scope.session_id, query)
99                    .await
100                    .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;
101
102                Ok(serde_json::json!({
103                    "memories": memories,
104                }))
105            }
106            // No service wired — the runtime is expected to intercept this call
107            // and route it to the MemoryService.
108            None => Ok(serde_json::json!({
109                "status": "memory_search_requested",
110                "query": query,
111                "results": []
112            })),
113        }
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryMemoryService, MemoryEntry};
    use serde_json::json;

    /// Name, description and schema are exposed as expected.
    #[test]
    fn tool_metadata() {
        let t = LoadMemoryTool::new();

        assert_eq!(t.name(), "load_memory");
        assert!(t.description().contains("memory"));
        assert!(t.parameters().is_some());
    }

    /// Without a wired service, the tool echoes the query back with a
    /// "requested" status.
    #[tokio::test]
    async fn call_with_query_no_service() {
        let out = LoadMemoryTool::new()
            .call(json!({ "query": "user preferences" }))
            .await
            .unwrap();

        assert_eq!(out["status"], "memory_search_requested");
        assert_eq!(out["query"], "user preferences");
    }

    /// With a wired service, stored entries for the scoped session are
    /// returned under `memories`.
    #[tokio::test]
    async fn call_delegates_to_memory_service() {
        let service = Arc::new(InMemoryMemoryService::new());
        let entry = MemoryEntry::new("rust_topic", json!("Rust programming"));
        service.store("s1", entry).await.unwrap();

        let tool = LoadMemoryTool::new()
            .with_session_id("s1")
            .with_memory_service(service);

        let out = tool.call(json!({ "query": "rust" })).await.unwrap();
        let found = out["memories"].as_array().expect("memories array");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0]["key"], "rust_topic");
    }
}