1pub mod session_manager;
4pub mod tool;
5pub mod toolset;
6
7pub use session_manager::{McpConnectionParams, McpError, McpSessionManager, McpToolInfo};
8pub use tool::McpTool;
9pub use toolset::McpToolset;
10
11#[cfg(test)]
12mod tests {
13 use super::*;
14 use crate::error::ToolError;
15 use crate::tool::ToolFunction;
16 use crate::toolset::Toolset;
17 use serde_json::json;
18 use std::collections::HashMap;
19 use std::sync::Arc;
20 use std::time::Duration;
21
/// Builds a `Stdio` connection with an explicit timeout and checks that the
/// command, arguments and timeout can all be read back unchanged.
#[test]
fn connection_params_stdio() {
    let params = McpConnectionParams::Stdio {
        command: "node".to_string(),
        args: vec!["server.js".to_string()],
        timeout: Some(Duration::from_secs(10)),
    };
    // Match on a borrow so the fields are inspected without moving `params`.
    match &params {
        McpConnectionParams::Stdio {
            command,
            args,
            timeout,
        } => {
            assert_eq!(command, "node");
            assert_eq!(args, &["server.js"]);
            assert_eq!(*timeout, Some(Duration::from_secs(10)));
        }
        _ => panic!("expected Stdio variant"),
    }
}
44
/// Builds an `Sse` connection carrying one custom header and checks that the
/// URL and the header value can be read back unchanged.
#[test]
fn connection_params_sse() {
    let mut headers = HashMap::new();
    headers.insert("Authorization".to_string(), "Bearer token".to_string());
    // `headers` is not used again, so it is moved into the params directly.
    let params = McpConnectionParams::Sse {
        url: "http://localhost:8080/sse".to_string(),
        headers: Some(headers),
    };
    // Match on a borrow so the fields are inspected without moving `params`.
    match &params {
        McpConnectionParams::Sse { url, headers: h } => {
            assert_eq!(url, "http://localhost:8080/sse");
            let h = h.as_ref().unwrap();
            assert_eq!(h.get("Authorization").unwrap(), "Bearer token");
        }
        _ => panic!("expected Sse variant"),
    }
}
62
/// A `Stdio` connection built with `timeout: None` must report no timeout.
#[test]
fn connection_params_stdio_no_timeout() {
    let params = McpConnectionParams::Stdio {
        command: "python".to_string(),
        args: vec![],
        timeout: None,
    };
    // Match on a borrow so the field is inspected without moving `params`.
    match &params {
        McpConnectionParams::Stdio { timeout, .. } => {
            assert!(timeout.is_none());
        }
        _ => panic!("expected Stdio variant"),
    }
}
77
/// An `Sse` connection built with `headers: None` must report no headers.
#[test]
fn connection_params_sse_no_headers() {
    let params = McpConnectionParams::Sse {
        url: "http://localhost:3000".to_string(),
        headers: None,
    };
    // Match on a borrow so the field is inspected without moving `params`.
    match &params {
        McpConnectionParams::Sse { headers, .. } => {
            assert!(headers.is_none());
        }
        _ => panic!("expected Sse variant"),
    }
}
91
92 #[tokio::test]
95 async fn session_manager_list_tools_unconnectable_server_errors() {
96 let manager = McpSessionManager::new(McpConnectionParams::Stdio {
100 command: "definitely_not_a_real_mcp_server_binary_xyz".to_string(),
101 args: vec![],
102 timeout: Some(Duration::from_secs(2)),
103 });
104 let result = manager.list_tools().await;
105 assert!(result.is_err());
106 match result.unwrap_err() {
107 McpError::ConnectionFailed(msg) => {
108 assert!(msg.contains("definitely_not_a_real_mcp_server_binary_xyz"));
109 }
110 other => panic!("expected McpError::ConnectionFailed, got: {other}"),
111 }
112 }
113
114 #[tokio::test]
115 async fn session_manager_call_tool_unconnectable_server_errors() {
116 let manager = McpSessionManager::new(McpConnectionParams::Stdio {
119 command: "echo".to_string(),
120 args: vec![],
121 timeout: Some(Duration::from_secs(2)),
122 });
123 let result = manager.call_tool("some_tool", json!({})).await;
124 assert!(result.is_err());
125 match result.unwrap_err() {
126 McpError::ConnectionFailed(_) => {}
127 other => panic!("expected McpError::ConnectionFailed, got: {other}"),
128 }
129 }
130
/// `params()` must hand back the connection parameters the manager was
/// constructed with.
#[test]
fn session_manager_params_accessor() {
    let manager = McpSessionManager::new(McpConnectionParams::Sse {
        url: "http://example.com".to_string(),
        headers: None,
    });

    if let McpConnectionParams::Sse { url, .. } = manager.params() {
        assert_eq!(url, "http://example.com");
    } else {
        panic!("expected Sse variant");
    }
}
145
/// An `McpTool` built with a name, description and schema must expose all
/// three through `name()`, `description()` and `parameters()`.
#[test]
fn mcp_tool_name_description_parameters() {
    let params = McpConnectionParams::Stdio {
        command: "echo".to_string(),
        args: Vec::new(),
        timeout: None,
    };
    let manager = Arc::new(McpSessionManager::new(params));

    let expected_schema = json!({
        "type": "object",
        "properties": {"query": {"type": "string"}}
    });
    let tool = McpTool::new(
        "search",
        "Search for things",
        Some(expected_schema.clone()),
        manager,
    );

    assert_eq!(tool.name(), "search");
    assert_eq!(tool.description(), "Search for things");
    assert_eq!(tool.parameters(), Some(expected_schema));
}
162
/// An `McpTool` built without a schema must report `None` from
/// `parameters()` while still exposing its name.
#[test]
fn mcp_tool_no_schema() {
    let params = McpConnectionParams::Stdio {
        command: "echo".to_string(),
        args: Vec::new(),
        timeout: None,
    };
    let manager = Arc::new(McpSessionManager::new(params));

    let schemaless = McpTool::new("ping", "Ping the server", None, manager);

    assert_eq!(schemaless.name(), "ping");
    assert!(schemaless.parameters().is_none());
}
175
176 #[tokio::test]
177 async fn mcp_tool_call_delegates_to_session_manager() {
178 let manager = Arc::new(McpSessionManager::new(McpConnectionParams::Stdio {
181 command: "echo".to_string(),
182 args: vec![],
183 timeout: Some(Duration::from_secs(2)),
184 }));
185 let tool = McpTool::new("my_tool", "desc", None, manager);
186
187 let result = tool.call(json!({"key": "value"})).await;
188 assert!(result.is_err());
189 match result.unwrap_err() {
190 ToolError::ExecutionFailed(msg) => {
191 assert!(msg.contains("Connection failed") || msg.contains("connection"));
192 }
193 other => panic!("expected ToolError::ExecutionFailed, got: {other:?}"),
194 }
195 }
196
/// `get_tools()` on a freshly constructed `McpToolset` must return an empty
/// collection.
#[test]
fn mcp_toolset_get_tools_returns_empty() {
    let params = McpConnectionParams::Stdio {
        command: "echo".to_string(),
        args: Vec::new(),
        timeout: None,
    };
    let toolset = McpToolset::new(Arc::new(McpSessionManager::new(params)));

    let tools = toolset.get_tools();
    assert!(tools.is_empty());
}
209
/// `with_filter` must store the given tool names, in order, so that
/// `filter()` returns them afterwards.
#[test]
fn mcp_toolset_with_filter_stores_filter() {
    let params = McpConnectionParams::Stdio {
        command: "echo".to_string(),
        args: Vec::new(),
        timeout: None,
    };
    let manager = Arc::new(McpSessionManager::new(params));

    let allowed = vec!["tool_a".to_string(), "tool_b".to_string()];
    let toolset = McpToolset::new(manager).with_filter(allowed);

    let stored = toolset.filter().unwrap();
    assert_eq!(stored.len(), 2);
    assert_eq!(stored[0], "tool_a");
    assert_eq!(stored[1], "tool_b");
}
225
/// A toolset built with `McpToolset::new` alone must have no filter set.
#[test]
fn mcp_toolset_no_filter_by_default() {
    let params = McpConnectionParams::Stdio {
        command: "echo".to_string(),
        args: Vec::new(),
        timeout: None,
    };
    let toolset = McpToolset::new(Arc::new(McpSessionManager::new(params)));

    let filter = toolset.filter();
    assert!(filter.is_none());
}
236
237 #[tokio::test]
238 async fn mcp_toolset_close_is_noop() {
239 let manager = Arc::new(McpSessionManager::new(McpConnectionParams::Stdio {
240 command: "echo".to_string(),
241 args: vec![],
242 timeout: None,
243 }));
244 let toolset = McpToolset::new(manager);
245 toolset.close().await; }
247
/// `session_manager()` must give access to the manager the toolset was built
/// with, verified here through the manager's connection parameters.
#[test]
fn mcp_toolset_session_manager_accessor() {
    let params = McpConnectionParams::Sse {
        url: "http://localhost:9090".to_string(),
        headers: None,
    };
    let manager = Arc::new(McpSessionManager::new(params));
    let toolset = McpToolset::new(Arc::clone(&manager));

    if let McpConnectionParams::Sse { url, .. } = toolset.session_manager().params() {
        assert_eq!(url, "http://localhost:9090");
    } else {
        panic!("expected Sse variant");
    }
}
263
/// Pins the `Display` output of each `McpError` variant exercised here.
#[test]
fn mcp_error_display() {
    // (error value, expected rendered text)
    let cases = [
        (
            McpError::ConnectionFailed("timeout".to_string()),
            "Connection failed: timeout",
        ),
        (
            McpError::NotConnected("no session".to_string()),
            "Not connected: no session",
        ),
        (
            McpError::ToolCallFailed("bad args".to_string()),
            "Tool call failed: bad args",
        ),
        (McpError::Other("something".to_string()), "something"),
    ];

    for (err, expected) in cases {
        assert_eq!(err.to_string(), expected);
    }
}
280
/// Reports whether `python3 --version` can be spawned and exits successfully.
///
/// The child's stdout and stderr are discarded. A failure to spawn the
/// process is reported as `false` rather than as an error.
fn python3_available() -> bool {
    use std::process::{Command, Stdio};

    let mut probe = Command::new("python3");
    probe
        .arg("--version")
        .stdout(Stdio::null())
        .stderr(Stdio::null());

    match probe.status() {
        Ok(exit) => exit.success(),
        Err(_) => false,
    }
}
293
/// Source of a minimal mock MCP server, written in Python and intended to be
/// run with `python3 -c`.
///
/// The script reads one JSON-RPC message per line from stdin and answers on
/// stdout, flushing after every reply:
/// - `initialize` gets a fixed protocol version, capabilities and server info;
/// - `notifications/initialized` gets no reply;
/// - `tools/list` advertises a single `echo` tool;
/// - `tools/call` returns the call's `arguments` serialized as JSON text;
/// - any other request gets a `-32601` "method not found" error, while any
///   other notification (a message without an `id`) is ignored, because
///   JSON-RPC 2.0 forbids replying to notifications.
///
/// Blank lines and lines that are not valid JSON are skipped.
const MOCK_MCP_SERVER: &str = r#"
import sys, json
for line in sys.stdin:
    line = line.strip()
    if not line:
        continue
    try:
        msg = json.loads(line)
    except Exception:
        continue
    method = msg.get("method")
    mid = msg.get("id")
    if method == "initialize":
        resp = {"jsonrpc": "2.0", "id": mid, "result": {
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {}},
            "serverInfo": {"name": "mock-mcp", "version": "0.1.0"}}}
        sys.stdout.write(json.dumps(resp) + "\n")
        sys.stdout.flush()
    elif method == "notifications/initialized":
        # notification, no response
        continue
    elif method == "tools/list":
        resp = {"jsonrpc": "2.0", "id": mid, "result": {"tools": [
            {"name": "echo", "description": "Echo back the input",
             "inputSchema": {"type": "object",
                             "properties": {"text": {"type": "string"}},
                             "required": ["text"]}}]}}
        sys.stdout.write(json.dumps(resp) + "\n")
        sys.stdout.flush()
    elif method == "tools/call":
        args = msg.get("params", {}).get("arguments", {})
        resp = {"jsonrpc": "2.0", "id": mid, "result": {
            "content": [{"type": "text", "text": json.dumps(args)}],
            "isError": False}}
        sys.stdout.write(json.dumps(resp) + "\n")
        sys.stdout.flush()
    else:
        if mid is None:
            # unknown notification: JSON-RPC forbids replying to these
            continue
        resp = {"jsonrpc": "2.0", "id": mid,
                "error": {"code": -32601, "message": "method not found"}}
        sys.stdout.write(json.dumps(resp) + "\n")
        sys.stdout.flush()
"#;
341
342 #[tokio::test]
343 async fn stdio_initialize_list_and_call_against_mock_server() {
344 if !python3_available() {
345 eprintln!("skipping: python3 not found on PATH");
346 return;
347 }
348
349 let manager = McpSessionManager::new(McpConnectionParams::Stdio {
350 command: "python3".to_string(),
351 args: vec!["-c".to_string(), MOCK_MCP_SERVER.to_string()],
352 timeout: Some(Duration::from_secs(10)),
353 });
354
355 let tools = manager
357 .list_tools()
358 .await
359 .expect("tools/list should succeed");
360 assert_eq!(tools.len(), 1);
361 assert_eq!(tools[0].name, "echo");
362 assert_eq!(tools[0].description, "Echo back the input");
363 assert_eq!(tools[0].input_schema["type"], "object");
364
365 let result = manager
368 .call_tool("echo", json!({"text": "hello mcp"}))
369 .await
370 .expect("tools/call should succeed");
371 assert_eq!(result["isError"], false);
372 let content_text = result["content"][0]["text"].as_str().unwrap();
373 let echoed: serde_json::Value = serde_json::from_str(content_text).unwrap();
374 assert_eq!(echoed["text"], "hello mcp");
375 }
376}