gemini_adk_rs/tools/retrieval/
files.rs

1//! Files retrieval tool — retrieve context from local files.
2//!
3//! Mirrors ADK-Python's `files_retrieval` tool. Provides simple
4//! substring-based retrieval from a collection of text files.
5
6use std::path::PathBuf;
7
8use async_trait::async_trait;
9
10use super::base::{BaseRetrievalTool, RetrievalResult};
11use crate::error::ToolError;
12
/// Retrieval tool that searches through local text files.
///
/// Performs simple substring matching over file contents and returns
/// relevant file chunks as retrieval results.
#[derive(Debug, Clone)]
pub struct FilesRetrievalTool {
    /// Paths to the files to search.
    files: Vec<PathBuf>,
    /// Chunk size in characters for splitting files.
    chunk_size: usize,
    /// Overlap between consecutive chunks in characters.
    chunk_overlap: usize,
}

impl FilesRetrievalTool {
    /// Create a new files retrieval tool over `files`.
    ///
    /// Defaults to 1000-character chunks with a 200-character overlap.
    pub fn new(files: Vec<PathBuf>) -> Self {
        Self {
            files,
            chunk_size: 1000,
            chunk_overlap: 200,
        }
    }

    /// Set the chunk size (in characters) for splitting files.
    ///
    /// A size of `0` is treated as `1` when chunking.
    pub fn with_chunk_size(mut self, size: usize) -> Self {
        self.chunk_size = size;
        self
    }

    /// Set the overlap (in characters) between consecutive chunks.
    ///
    /// An overlap greater than or equal to the chunk size is clamped during
    /// chunking so that each new chunk starts at least one character later.
    pub fn with_chunk_overlap(mut self, overlap: usize) -> Self {
        self.chunk_overlap = overlap;
        self
    }

    /// Split text into overlapping chunks of at most `chunk_size` characters.
    ///
    /// Consecutive chunks start `chunk_size - chunk_overlap` characters apart
    /// (at least 1). Chunks are cut on `char` boundaries, so multi-byte UTF-8
    /// text never triggers a slicing panic. Text that fits in a single chunk
    /// (including empty text) is returned as one chunk.
    fn chunk_text(&self, text: &str) -> Vec<String> {
        // Guard against degenerate configuration: a zero chunk size, or an
        // overlap >= chunk size, would otherwise underflow or loop forever.
        let chunk_size = self.chunk_size.max(1);
        let step = chunk_size.saturating_sub(self.chunk_overlap).max(1);

        // Byte offset of every char boundary plus the end of the text, so that
        // `bounds[i]..bounds[j]` is always a valid slice covering chars `i..j`.
        let bounds: Vec<usize> = text
            .char_indices()
            .map(|(idx, _)| idx)
            .chain(std::iter::once(text.len()))
            .collect();
        let char_count = bounds.len() - 1;

        if char_count <= chunk_size {
            return vec![text.to_string()];
        }

        let mut chunks = Vec::new();
        let mut start = 0;
        loop {
            let end = (start + chunk_size).min(char_count);
            chunks.push(text[bounds[start]..bounds[end]].to_string());
            if end == char_count {
                break;
            }
            // `step <= chunk_size`, so the next start is still < char_count.
            start += step;
        }
        chunks
    }
}
68
69#[async_trait]
70impl BaseRetrievalTool for FilesRetrievalTool {
71    fn name(&self) -> &str {
72        "files_retrieval"
73    }
74
75    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<RetrievalResult>, ToolError> {
76        let query_lower = query.to_lowercase();
77        let mut all_results = Vec::new();
78
79        for path in &self.files {
80            let content = tokio::fs::read_to_string(path).await.map_err(|e| {
81                ToolError::ExecutionFailed(format!("Failed to read {}: {e}", path.display()))
82            })?;
83
84            let chunks = self.chunk_text(&content);
85            let source = path.display().to_string();
86
87            for chunk in &chunks {
88                let chunk_lower = chunk.to_lowercase();
89                // Simple relevance scoring: count query term occurrences
90                let words: Vec<&str> = query_lower.split_whitespace().collect();
91                let matches = words.iter().filter(|w| chunk_lower.contains(*w)).count();
92
93                if matches > 0 {
94                    let score = matches as f64 / words.len().max(1) as f64;
95                    all_results.push(RetrievalResult {
96                        content: chunk.clone(),
97                        source: source.clone(),
98                        score,
99                        metadata: serde_json::Value::Null,
100                    });
101                }
102            }
103        }
104
105        // Sort by score descending and take top_k
106        all_results.sort_by(|a, b| {
107            b.score
108                .partial_cmp(&a.score)
109                .unwrap_or(std::cmp::Ordering::Equal)
110        });
111        all_results.truncate(top_k);
112
113        Ok(all_results)
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn chunk_short_text() {
        let tool = FilesRetrievalTool::new(vec![]);
        let chunks = tool.chunk_text("short text");
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "short text");
    }

    #[test]
    fn chunk_long_text() {
        let tool = FilesRetrievalTool::new(vec![])
            .with_chunk_size(10)
            .with_chunk_overlap(3);
        let text = "abcdefghijklmnopqrstuvwxyz";
        let chunks = tool.chunk_text(text);
        // Windows of 10 characters advancing by 7 (size - overlap); the final
        // window is shorter because it runs into the end of the text.
        assert_eq!(
            chunks,
            vec!["abcdefghij", "hijklmnopq", "opqrstuvwx", "vwxyz"]
        );
    }

    #[tokio::test]
    async fn retrieve_from_nonexistent_file() {
        let tool = FilesRetrievalTool::new(vec![PathBuf::from("/nonexistent/file.txt")]);
        let result = tool.retrieve("test", 5).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn retrieve_scores_by_fraction_of_matching_terms() {
        let path = std::env::temp_dir().join(format!(
            "files_retrieval_test_{}.txt",
            std::process::id()
        ));
        tokio::fs::write(&path, "alpha beta").await.unwrap();

        let tool = FilesRetrievalTool::new(vec![path.clone()]);
        let ranked = tool.retrieve("Alpha gamma", 5).await;
        let limited = tool.retrieve("alpha", 0).await;
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = tokio::fs::remove_file(&path).await;

        let ranked = ranked.unwrap();
        assert_eq!(ranked.len(), 1);
        assert_eq!(ranked[0].content, "alpha beta");
        // One of two query terms matches, case-insensitively.
        assert!((ranked[0].score - 0.5).abs() < f64::EPSILON);

        assert!(limited.unwrap().is_empty());
    }
}