// gemini_adk_rs/tools/mcp/mod.rs

1//! MCP (Model Context Protocol) toolset — connect to MCP servers and use their tools.
2
3pub mod session_manager;
4pub mod tool;
5pub mod toolset;
6
7pub use session_manager::{McpConnectionParams, McpError, McpSessionManager, McpToolInfo};
8pub use tool::McpTool;
9pub use toolset::McpToolset;
10
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ToolError;
    use crate::tool::ToolFunction;
    use crate::toolset::Toolset;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::time::Duration;

    /// A command name that should never resolve to an executable on PATH.
    const BOGUS_COMMAND: &str = "definitely_not_a_real_mcp_server_binary_xyz";

    /// Build a stdio session manager that spawns `command` with no arguments.
    fn stdio_manager(command: &str, timeout: Option<Duration>) -> McpSessionManager {
        McpSessionManager::new(McpConnectionParams::Stdio {
            command: command.to_string(),
            args: vec![],
            timeout,
        })
    }

    /// Shared manager for tests that only inspect local state. `echo` is not an
    /// MCP server, so these tests must never depend on a successful handshake.
    fn echo_manager() -> Arc<McpSessionManager> {
        Arc::new(stdio_manager("echo", None))
    }

    // --- McpConnectionParams tests ---

    #[test]
    fn connection_params_stdio() {
        let params = McpConnectionParams::Stdio {
            command: "node".to_string(),
            args: vec!["server.js".to_string()],
            timeout: Some(Duration::from_secs(10)),
        };
        match &params {
            McpConnectionParams::Stdio {
                command,
                args,
                timeout,
            } => {
                assert_eq!(command, "node");
                assert_eq!(args, &["server.js"]);
                assert_eq!(*timeout, Some(Duration::from_secs(10)));
            }
            _ => panic!("expected Stdio variant"),
        }
    }

    #[test]
    fn connection_params_sse() {
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), "Bearer token".to_string());
        let params = McpConnectionParams::Sse {
            url: "http://localhost:8080/sse".to_string(),
            headers: Some(headers),
        };
        match &params {
            McpConnectionParams::Sse { url, headers: h } => {
                assert_eq!(url, "http://localhost:8080/sse");
                let h = h.as_ref().expect("headers should be set");
                assert_eq!(h.get("Authorization").unwrap(), "Bearer token");
            }
            _ => panic!("expected Sse variant"),
        }
    }

    #[test]
    fn connection_params_stdio_no_timeout() {
        let params = McpConnectionParams::Stdio {
            command: "python".to_string(),
            args: vec![],
            timeout: None,
        };
        match &params {
            McpConnectionParams::Stdio { timeout, .. } => assert!(timeout.is_none()),
            _ => panic!("expected Stdio variant"),
        }
    }

    #[test]
    fn connection_params_sse_no_headers() {
        let params = McpConnectionParams::Sse {
            url: "http://localhost:3000".to_string(),
            headers: None,
        };
        match &params {
            McpConnectionParams::Sse { headers, .. } => assert!(headers.is_none()),
            _ => panic!("expected Sse variant"),
        }
    }

    // --- McpSessionManager tests ---

    #[tokio::test]
    async fn session_manager_list_tools_unconnectable_server_errors() {
        // A bogus command cannot be spawned, so the lazy connect fails. The real
        // client surfaces this as a ConnectionFailed error rather than returning
        // an empty tool list.
        let manager = stdio_manager(BOGUS_COMMAND, Some(Duration::from_secs(2)));
        match manager.list_tools().await {
            Err(McpError::ConnectionFailed(msg)) => {
                assert!(msg.contains(BOGUS_COMMAND), "unexpected message: {msg}");
            }
            Err(other) => panic!("expected McpError::ConnectionFailed, got: {other}"),
            Ok(_) => panic!("expected McpError::ConnectionFailed, got Ok"),
        }
    }

    #[tokio::test]
    async fn session_manager_call_tool_unconnectable_server_errors() {
        // `echo` ignores stdin and exits, closing its stdout before answering the
        // initialize handshake. The real client reports a connection failure.
        let manager = stdio_manager("echo", Some(Duration::from_secs(2)));
        match manager.call_tool("some_tool", json!({})).await {
            Err(McpError::ConnectionFailed(_)) => {}
            Err(other) => panic!("expected McpError::ConnectionFailed, got: {other}"),
            Ok(_) => panic!("expected McpError::ConnectionFailed, got Ok"),
        }
    }

    #[test]
    fn session_manager_params_accessor() {
        let manager = McpSessionManager::new(McpConnectionParams::Sse {
            url: "http://example.com".to_string(),
            headers: None,
        });
        match manager.params() {
            McpConnectionParams::Sse { url, .. } => assert_eq!(url, "http://example.com"),
            _ => panic!("expected Sse variant"),
        }
    }

    // --- McpTool tests ---

    #[test]
    fn mcp_tool_name_description_parameters() {
        let schema = json!({"type": "object", "properties": {"query": {"type": "string"}}});
        let tool = McpTool::new(
            "search",
            "Search for things",
            Some(schema.clone()),
            echo_manager(),
        );

        assert_eq!(tool.name(), "search");
        assert_eq!(tool.description(), "Search for things");
        assert_eq!(tool.parameters(), Some(schema));
    }

    #[test]
    fn mcp_tool_no_schema() {
        let tool = McpTool::new("ping", "Ping the server", None, echo_manager());

        assert_eq!(tool.name(), "ping");
        assert!(tool.parameters().is_none());
    }

    #[tokio::test]
    async fn mcp_tool_call_delegates_to_session_manager() {
        // `echo` is not a real MCP server, so the handshake fails. The McpTool
        // wraps the session manager's McpError into a ToolError::ExecutionFailed.
        let manager = Arc::new(stdio_manager("echo", Some(Duration::from_secs(2))));
        let tool = McpTool::new("my_tool", "desc", None, manager);

        match tool.call(json!({"key": "value"})).await {
            Err(ToolError::ExecutionFailed(msg)) => assert!(
                msg.contains("Connection failed") || msg.contains("connection"),
                "unexpected message: {msg}"
            ),
            Err(other) => panic!("expected ToolError::ExecutionFailed, got: {other:?}"),
            Ok(_) => panic!("expected ToolError::ExecutionFailed, got Ok"),
        }
    }

    // --- McpToolset tests ---

    #[test]
    fn mcp_toolset_get_tools_returns_empty() {
        let toolset = McpToolset::new(echo_manager());
        assert!(toolset.get_tools().is_empty());
    }

    #[test]
    fn mcp_toolset_with_filter_stores_filter() {
        let toolset = McpToolset::new(echo_manager())
            .with_filter(vec!["tool_a".to_string(), "tool_b".to_string()]);

        let filter = toolset.filter().expect("filter should be set");
        assert_eq!(filter.len(), 2);
        assert_eq!(filter[0], "tool_a");
        assert_eq!(filter[1], "tool_b");
    }

    #[test]
    fn mcp_toolset_no_filter_by_default() {
        let toolset = McpToolset::new(echo_manager());
        assert!(toolset.filter().is_none());
    }

    #[tokio::test]
    async fn mcp_toolset_close_is_noop() {
        let toolset = McpToolset::new(echo_manager());
        toolset.close().await; // Should not panic
    }

    #[test]
    fn mcp_toolset_session_manager_accessor() {
        let manager = Arc::new(McpSessionManager::new(McpConnectionParams::Sse {
            url: "http://localhost:9090".to_string(),
            headers: None,
        }));
        // The toolset takes ownership of the only handle; no clone is needed.
        let toolset = McpToolset::new(manager);
        match toolset.session_manager().params() {
            McpConnectionParams::Sse { url, .. } => assert_eq!(url, "http://localhost:9090"),
            _ => panic!("expected Sse variant"),
        }
    }

    // --- McpError display tests ---

    #[test]
    fn mcp_error_display() {
        let err = McpError::ConnectionFailed("timeout".to_string());
        assert_eq!(err.to_string(), "Connection failed: timeout");

        let err = McpError::NotConnected("no session".to_string());
        assert_eq!(err.to_string(), "Not connected: no session");

        let err = McpError::ToolCallFailed("bad args".to_string());
        assert_eq!(err.to_string(), "Tool call failed: bad args");

        let err = McpError::Other("something".to_string());
        assert_eq!(err.to_string(), "something");
    }

    // --- Real stdio MCP client integration test ---

    /// Return true if `python3` is available on PATH and exits successfully.
    fn python3_available() -> bool {
        std::process::Command::new("python3")
            .arg("--version")
            .stdout(std::process::Stdio::null())
            .stderr(std::process::Stdio::null())
            .status()
            .map(|s| s.success())
            .unwrap_or(false)
    }

    /// A tiny MCP server over stdio. It reads newline-delimited JSON-RPC
    /// messages, silently ignores every notification (any message without an
    /// `id`, including `notifications/initialized`), handles `initialize`,
    /// answers `tools/list` with one tool, and answers `tools/call` by echoing
    /// the arguments back as text content. Unknown requests get a -32601 error.
    const MOCK_MCP_SERVER: &str = r#"
import sys, json

def send(resp):
    sys.stdout.write(json.dumps(resp) + "\n")
    sys.stdout.flush()

for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        msg = json.loads(line)
    except Exception:
        continue
    # Notifications carry no id (e.g. notifications/initialized) and must
    # never be answered, per JSON-RPC 2.0.
    if not isinstance(msg, dict) or "id" not in msg:
        continue
    method = msg.get("method")
    mid = msg["id"]
    if method == "initialize":
        send({"jsonrpc": "2.0", "id": mid, "result": {
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {}},
            "serverInfo": {"name": "mock-mcp", "version": "0.1.0"}}})
    elif method == "tools/list":
        send({"jsonrpc": "2.0", "id": mid, "result": {"tools": [
            {"name": "echo", "description": "Echo back the input",
             "inputSchema": {"type": "object",
                             "properties": {"text": {"type": "string"}},
                             "required": ["text"]}}]}})
    elif method == "tools/call":
        args = (msg.get("params") or {}).get("arguments", {})
        send({"jsonrpc": "2.0", "id": mid, "result": {
            "content": [{"type": "text", "text": json.dumps(args)}],
            "isError": False}})
    else:
        send({"jsonrpc": "2.0", "id": mid,
              "error": {"code": -32601, "message": "method not found"}})
"#;

    #[tokio::test]
    async fn stdio_initialize_list_and_call_against_mock_server() {
        if !python3_available() {
            eprintln!("skipping: python3 not found on PATH");
            return;
        }

        let manager = McpSessionManager::new(McpConnectionParams::Stdio {
            command: "python3".to_string(),
            args: vec!["-c".to_string(), MOCK_MCP_SERVER.to_string()],
            timeout: Some(Duration::from_secs(10)),
        });

        // tools/list (triggers lazy initialize handshake first).
        let tools = manager
            .list_tools()
            .await
            .expect("tools/list should succeed");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[0].description, "Echo back the input");
        assert_eq!(tools[0].input_schema["type"], "object");

        // tools/call echoes the arguments back as text content. The connection is
        // reused (no second handshake).
        let result = manager
            .call_tool("echo", json!({"text": "hello mcp"}))
            .await
            .expect("tools/call should succeed");
        assert_eq!(result["isError"], false);
        let content_text = result["content"][0]["text"]
            .as_str()
            .expect("tool result content should be text");
        let echoed: serde_json::Value =
            serde_json::from_str(content_text).expect("echoed arguments should be JSON");
        assert_eq!(echoed["text"], "hello mcp");
    }
}