gemini_adk_rs/tools/
load_memory.rs1use std::sync::Arc;
9
10use async_trait::async_trait;
11
12use crate::error::ToolError;
13use crate::memory::MemoryService;
14use crate::tool::ToolFunction;
15
/// Scope parameters forwarded to the memory service on every search.
///
/// Currently this only carries a session identifier, which is an empty
/// string until one is configured.
#[derive(Default, Clone, Debug)]
struct MemoryScope {
    /// Session identifier passed as the first argument to `MemoryService::search`.
    session_id: String,
}
21
22#[derive(Clone, Default)]
30pub struct LoadMemoryTool {
31 service: Option<Arc<dyn MemoryService>>,
32 scope: MemoryScope,
33}
34
35impl std::fmt::Debug for LoadMemoryTool {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("LoadMemoryTool")
38 .field("has_service", &self.service.is_some())
39 .field("session_id", &self.scope.session_id)
40 .finish()
41 }
42}
43
44impl LoadMemoryTool {
45 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn with_memory_service(mut self, service: Arc<dyn MemoryService>) -> Self {
55 self.service = Some(service);
56 self
57 }
58
59 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
61 self.scope.session_id = session_id.into();
62 self
63 }
64}
65
66#[async_trait]
67impl ToolFunction for LoadMemoryTool {
68 fn name(&self) -> &str {
69 "load_memory"
70 }
71
72 fn description(&self) -> &str {
73 "Search and load relevant information from the agent's memory. \
74 Call this function with a query to retrieve previously stored memories."
75 }
76
77 fn parameters(&self) -> Option<serde_json::Value> {
78 Some(serde_json::json!({
79 "type": "object",
80 "properties": {
81 "query": {
82 "type": "string",
83 "description": "The search query to find relevant memories."
84 }
85 },
86 "required": ["query"]
87 }))
88 }
89
90 async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
91 let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
92
93 match &self.service {
94 Some(service) => {
97 let memories = service
98 .search(&self.scope.session_id, query)
99 .await
100 .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;
101
102 Ok(serde_json::json!({
103 "memories": memories,
104 }))
105 }
106 None => Ok(serde_json::json!({
109 "status": "memory_search_requested",
110 "query": query,
111 "results": []
112 })),
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{InMemoryMemoryService, MemoryEntry};
    use serde_json::json;

    /// The tool advertises a stable name, a memory-related description and a
    /// JSON parameter schema.
    #[test]
    fn tool_metadata() {
        let t = LoadMemoryTool::new();
        assert_eq!(t.name(), "load_memory");
        assert!(t.description().contains("memory"));
        assert!(t.parameters().is_some());
    }

    /// Without a memory service the tool echoes the query back with a
    /// "requested" status.
    #[tokio::test]
    async fn call_with_query_no_service() {
        let args = json!({"query": "user preferences"});
        let response = LoadMemoryTool::new().call(args).await.unwrap();
        assert_eq!(response["query"], "user preferences");
        assert_eq!(response["status"], "memory_search_requested");
    }

    /// With a memory service attached, the search is delegated to it using
    /// the configured session id.
    #[tokio::test]
    async fn call_delegates_to_memory_service() {
        let store = Arc::new(InMemoryMemoryService::new());
        let entry = MemoryEntry::new("rust_topic", json!("Rust programming"));
        store.store("s1", entry).await.unwrap();

        let tool = LoadMemoryTool::new()
            .with_session_id("s1")
            .with_memory_service(store);

        let response = tool.call(json!({"query": "rust"})).await.unwrap();
        let found = response["memories"]
            .as_array()
            .expect("memories array");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0]["key"], "rust_topic");
    }
}