gemini_adk_rs/tools/mcp/
mod.rs

1//! MCP (Model Context Protocol) toolset — connect to MCP servers and use their tools.
2
3pub mod session_manager;
4pub mod tool;
5pub mod toolset;
6
7pub use session_manager::{McpConnectionParams, McpError, McpSessionManager, McpToolInfo};
8pub use tool::McpTool;
9pub use toolset::McpToolset;
10
// Unit tests for the MCP connection params, session manager, tool wrapper,
// toolset and error `Display` output.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolError;
    use crate::tool::ToolFunction;
    use crate::toolset::Toolset;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;

    // --- Shared fixtures ---

    /// Stdio connection params for an `echo` command with no args and no timeout.
    ///
    /// The command itself is incidental: the tests below only exercise
    /// construction, accessors and the not-connected error paths.
    fn echo_params() -> McpConnectionParams {
        McpConnectionParams::Stdio {
            command: "echo".to_string(),
            args: vec![],
            timeout: None,
        }
    }

    /// `Arc`-wrapped session manager built from [`echo_params`], in the shape
    /// `McpTool::new` and `McpToolset::new` take it.
    fn echo_manager() -> Arc<McpSessionManager> {
        Arc::new(McpSessionManager::new(echo_params()))
    }

    // --- McpConnectionParams tests ---

    #[test]
    fn connection_params_stdio() {
        let params = McpConnectionParams::Stdio {
            command: "node".to_string(),
            args: vec!["server.js".to_string()],
            timeout: Some(Duration::from_secs(10)),
        };
        match &params {
            McpConnectionParams::Stdio {
                command,
                args,
                timeout,
            } => {
                assert_eq!(command, "node");
                assert_eq!(args, &["server.js"]);
                assert_eq!(*timeout, Some(Duration::from_secs(10)));
            }
            _ => panic!("expected Stdio variant"),
        }
    }

    #[test]
    fn connection_params_sse() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        // `headers` is not used again, so it is moved in rather than cloned.
        let params = McpConnectionParams::Sse {
            url: "http://localhost:8080/sse".to_string(),
            headers: Some(headers),
        };
        match &params {
            McpConnectionParams::Sse { url, headers } => {
                assert_eq!(url, "http://localhost:8080/sse");
                let headers = headers.as_ref().expect("headers were provided");
                assert_eq!(headers.get("Authorization").map(String::as_str), Some("Bearer token"));
            }
            _ => panic!("expected Sse variant"),
        }
    }

    #[test]
    fn connection_params_stdio_no_timeout() {
        let params = McpConnectionParams::Stdio {
            command: "python".to_string(),
            args: vec![],
            timeout: None,
        };
        match &params {
            McpConnectionParams::Stdio { timeout, .. } => {
                assert!(timeout.is_none());
            }
            _ => panic!("expected Stdio variant"),
        }
    }

    #[test]
    fn connection_params_sse_no_headers() {
        let params = McpConnectionParams::Sse {
            url: "http://localhost:3000".to_string(),
            headers: None,
        };
        match &params {
            McpConnectionParams::Sse { headers, .. } => {
                assert!(headers.is_none());
            }
            _ => panic!("expected Sse variant"),
        }
    }

    // --- McpSessionManager tests ---

    #[tokio::test]
    async fn session_manager_list_tools_returns_empty() {
        let manager = McpSessionManager::new(echo_params());
        let tools = manager
            .list_tools()
            .await
            .expect("list_tools should succeed");
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn session_manager_call_tool_returns_not_connected() {
        let manager = McpSessionManager::new(echo_params());
        // A single match covers the error variant, any other error, and an
        // unexpected success; no separate `is_err()` + `unwrap_err()` needed.
        match manager.call_tool("some_tool", json!({})).await {
            Err(McpError::NotConnected(msg)) => {
                assert!(msg.contains("some_tool"), "message should name the tool: {msg}");
            }
            Err(other) => panic!("expected McpError::NotConnected, got: {other}"),
            Ok(value) => panic!("expected McpError::NotConnected, got Ok({value:?})"),
        }
    }

    #[test]
    fn session_manager_params_accessor() {
        let params = McpConnectionParams::Sse {
            url: "http://example.com".to_string(),
            headers: None,
        };
        let manager = McpSessionManager::new(params);
        match manager.params() {
            McpConnectionParams::Sse { url, .. } => {
                assert_eq!(url, "http://example.com");
            }
            _ => panic!("expected Sse variant"),
        }
    }

    // --- McpTool tests ---

    #[test]
    fn mcp_tool_name_description_parameters() {
        let schema = json!({"type": "object", "properties": {"query": {"type": "string"}}});
        let tool = McpTool::new("search", "Search for things", Some(schema.clone()), echo_manager());

        assert_eq!(tool.name(), "search");
        assert_eq!(tool.description(), "Search for things");
        assert_eq!(tool.parameters(), Some(schema));
    }

    #[test]
    fn mcp_tool_no_schema() {
        let tool = McpTool::new("ping", "Ping the server", None, echo_manager());

        assert_eq!(tool.name(), "ping");
        assert!(tool.parameters().is_none());
    }

    #[tokio::test]
    async fn mcp_tool_call_delegates_to_session_manager() {
        let tool = McpTool::new("my_tool", "desc", None, echo_manager());

        match tool.call(json!({"key": "value"})).await {
            Err(ToolError::ExecutionFailed(msg)) => {
                assert!(msg.contains("my_tool"), "message should name the tool: {msg}");
                // Case-insensitive so the test does not depend on how the wrapped
                // session-manager error happens to be capitalised.
                assert!(
                    msg.to_lowercase().contains("not connected"),
                    "message should report the missing connection: {msg}"
                );
            }
            Err(other) => panic!("expected ToolError::ExecutionFailed, got: {other:?}"),
            Ok(value) => panic!("expected ToolError::ExecutionFailed, got Ok({value:?})"),
        }
    }

    // --- McpToolset tests ---

    #[test]
    fn mcp_toolset_get_tools_returns_empty() {
        let toolset = McpToolset::new(echo_manager());
        assert!(toolset.get_tools().is_empty());
    }

    #[test]
    fn mcp_toolset_with_filter_stores_filter() {
        let toolset = McpToolset::new(echo_manager())
            .with_filter(vec!["tool_a".to_string(), "tool_b".to_string()]);

        let filter = toolset
            .filter()
            .expect("with_filter should set the filter");
        assert_eq!(filter.len(), 2);
        assert_eq!(filter[0], "tool_a");
        assert_eq!(filter[1], "tool_b");
    }

    #[test]
    fn mcp_toolset_no_filter_by_default() {
        let toolset = McpToolset::new(echo_manager());
        assert!(toolset.filter().is_none());
    }

    #[tokio::test]
    async fn mcp_toolset_close_is_noop() {
        let toolset = McpToolset::new(echo_manager());
        // Only checks that closing completes without panicking.
        toolset.close().await;
    }

    #[test]
    fn mcp_toolset_session_manager_accessor() {
        let manager = Arc::new(McpSessionManager::new(McpConnectionParams::Sse {
            url: "http://localhost:9090".to_string(),
            headers: None,
        }));
        // `manager` is not used again, so it is moved in rather than cloned.
        let toolset = McpToolset::new(manager);
        match toolset.session_manager().params() {
            McpConnectionParams::Sse { url, .. } => {
                assert_eq!(url, "http://localhost:9090");
            }
            _ => panic!("expected Sse variant"),
        }
    }

    // --- McpError display tests ---

    #[test]
    fn mcp_error_display() {
        let cases = [
            (McpError::ConnectionFailed("timeout".to_string()), "Connection failed: timeout"),
            (McpError::NotConnected("no session".to_string()), "Not connected: no session"),
            (McpError::ToolCallFailed("bad args".to_string()), "Tool call failed: bad args"),
            (McpError::Other("something".to_string()), "something"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }
}
269        assert_eq!(err.to_string(), "something");
270    }
271}