gemini_adk_rs/tools/retrieval/
files.rs1use std::path::PathBuf;
7
8use async_trait::async_trait;
9
10use super::base::{BaseRetrievalTool, RetrievalResult};
11use crate::error::ToolError;
12
/// Retrieval tool that searches a fixed list of local text files.
///
/// Files are split into overlapping chunks whose sizes are measured in bytes,
/// and chunks are ranked by how many query words they contain.
#[derive(Clone, Debug)]
pub struct FilesRetrievalTool {
    /// Paths of the files to search; they are read on every retrieval.
    files: Vec<PathBuf>,
    /// Maximum chunk length, in bytes.
    chunk_size: usize,
    /// Number of bytes shared by consecutive chunks.
    chunk_overlap: usize,
}
26
27impl FilesRetrievalTool {
28 pub fn new(files: Vec<PathBuf>) -> Self {
30 Self {
31 files,
32 chunk_size: 1000,
33 chunk_overlap: 200,
34 }
35 }
36
37 pub fn with_chunk_size(mut self, size: usize) -> Self {
39 self.chunk_size = size;
40 self
41 }
42
43 pub fn with_chunk_overlap(mut self, overlap: usize) -> Self {
45 self.chunk_overlap = overlap;
46 self
47 }
48
49 fn chunk_text(&self, text: &str) -> Vec<String> {
51 if text.len() <= self.chunk_size {
52 return vec![text.to_string()];
53 }
54
55 let mut chunks = Vec::new();
56 let mut start = 0;
57 while start < text.len() {
58 let end = (start + self.chunk_size).min(text.len());
59 chunks.push(text[start..end].to_string());
60 if end >= text.len() {
61 break;
62 }
63 start += self.chunk_size - self.chunk_overlap;
64 }
65 chunks
66 }
67}
68
69#[async_trait]
70impl BaseRetrievalTool for FilesRetrievalTool {
71 fn name(&self) -> &str {
72 "files_retrieval"
73 }
74
75 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<RetrievalResult>, ToolError> {
76 let query_lower = query.to_lowercase();
77 let mut all_results = Vec::new();
78
79 for path in &self.files {
80 let content = tokio::fs::read_to_string(path).await.map_err(|e| {
81 ToolError::ExecutionFailed(format!("Failed to read {}: {e}", path.display()))
82 })?;
83
84 let chunks = self.chunk_text(&content);
85 let source = path.display().to_string();
86
87 for chunk in &chunks {
88 let chunk_lower = chunk.to_lowercase();
89 let words: Vec<&str> = query_lower.split_whitespace().collect();
91 let matches = words.iter().filter(|w| chunk_lower.contains(*w)).count();
92
93 if matches > 0 {
94 let score = matches as f64 / words.len().max(1) as f64;
95 all_results.push(RetrievalResult {
96 content: chunk.clone(),
97 source: source.clone(),
98 score,
99 metadata: serde_json::Value::Null,
100 });
101 }
102 }
103 }
104
105 all_results.sort_by(|a, b| {
107 b.score
108 .partial_cmp(&a.score)
109 .unwrap_or(std::cmp::Ordering::Equal)
110 });
111 all_results.truncate(top_k);
112
113 Ok(all_results)
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunk_short_text() {
        let tool = FilesRetrievalTool::new(vec![]);
        let chunks = tool.chunk_text("short text");
        assert_eq!(chunks, vec!["short text"]);
    }

    #[test]
    fn chunk_empty_text() {
        let tool = FilesRetrievalTool::new(vec![]);
        assert_eq!(tool.chunk_text(""), vec![""]);
    }

    #[test]
    fn chunk_long_text() {
        let tool = FilesRetrievalTool::new(vec![])
            .with_chunk_size(10)
            .with_chunk_overlap(3);
        let chunks = tool.chunk_text("abcdefghijklmnopqrstuvwxyz");
        // Starts advance by chunk_size - chunk_overlap = 7 bytes, so each chunk
        // repeats the last 3 bytes of its predecessor; the final chunk is short.
        assert_eq!(
            chunks,
            vec!["abcdefghij", "hijklmnopq", "opqrstuvwx", "vwxyz"]
        );
    }

    #[tokio::test]
    async fn retrieve_from_nonexistent_file() {
        let tool = FilesRetrievalTool::new(vec![PathBuf::from("/nonexistent/file.txt")]);
        let result = tool.retrieve("test", 5).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn retrieve_ranks_by_matched_query_words() {
        let dir = std::env::temp_dir();
        let pid = std::process::id();
        let one = dir.join(format!("files_retrieval_one_{pid}.txt"));
        let both = dir.join(format!("files_retrieval_both_{pid}.txt"));
        std::fs::write(&one, "Apple pie").unwrap();
        std::fs::write(&both, "apple and BANANA").unwrap();

        let tool = FilesRetrievalTool::new(vec![one.clone(), both.clone()]);
        let results = tool.retrieve("apple banana", 5).await;
        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_file(&one);
        let _ = std::fs::remove_file(&both);

        let results = results.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].source, both.display().to_string());
        assert_eq!(results[0].score, 1.0);
        assert_eq!(results[1].score, 0.5);
    }
}