1use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11
12use serde_json::{json, Value};
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout, Command};
15use tokio::sync::Mutex;
16
/// MCP protocol revision string sent as `protocolVersion` in the
/// `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
19
/// Describes how an MCP server is reached.
#[derive(Debug, Clone)]
pub enum McpConnectionParams {
    /// Run the server as a child process and talk to it over stdin/stdout.
    Stdio {
        /// Program to execute.
        command: String,
        /// Arguments passed to `command`.
        args: Vec<String>,
        /// Optional limit applied to each request; `None` waits indefinitely.
        timeout: Option<Duration>,
    },
    /// Talk to the server over HTTP.
    Sse {
        /// Endpoint that requests are sent to.
        url: String,
        /// Extra headers attached to each request, if any.
        headers: Option<HashMap<String, String>>,
    },
}
40
41struct StdioConnection {
43 #[allow(dead_code)]
45 child: Child,
46 stdin: ChildStdin,
48 stdout: BufReader<ChildStdout>,
50}
51
52pub struct McpSessionManager {
54 params: McpConnectionParams,
55 stdio: Mutex<Option<StdioConnection>>,
57 next_id: AtomicU64,
59}
60
61impl McpSessionManager {
62 pub fn new(params: McpConnectionParams) -> Self {
64 Self {
65 params,
66 stdio: Mutex::new(None),
67 next_id: AtomicU64::new(1),
68 }
69 }
70
71 pub fn params(&self) -> &McpConnectionParams {
73 &self.params
74 }
75
76 fn next_id(&self) -> u64 {
77 self.next_id.fetch_add(1, Ordering::Relaxed)
78 }
79
80 pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
85 match &self.params {
86 McpConnectionParams::Stdio { .. } => self.stdio_list_tools().await,
87 McpConnectionParams::Sse { .. } => self.http_list_tools().await,
88 }
89 }
90
91 pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
97 match &self.params {
98 McpConnectionParams::Stdio { .. } => self.stdio_call_tool(name, args).await,
99 McpConnectionParams::Sse { .. } => self.http_call_tool(name, args).await,
100 }
101 }
102
103 async fn stdio_list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
108 let timeout = self.stdio_timeout();
109 let mut guard = self.stdio.lock().await;
110 self.ensure_connected(&mut guard, timeout).await?;
111 let conn = guard.as_mut().expect("connection established above");
112
113 let id = self.next_id();
114 let req = json!({
115 "jsonrpc": "2.0",
116 "id": id,
117 "method": "tools/list",
118 "params": {},
119 });
120 let result = stdio_request(conn, id, &req, timeout).await?;
121 parse_tools_list(&result)
122 }
123
124 async fn stdio_call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
125 let timeout = self.stdio_timeout();
126 let mut guard = self.stdio.lock().await;
127 self.ensure_connected(&mut guard, timeout).await?;
128 let conn = guard.as_mut().expect("connection established above");
129
130 let id = self.next_id();
131 let arguments = if args.is_null() { json!({}) } else { args };
132 let req = json!({
133 "jsonrpc": "2.0",
134 "id": id,
135 "method": "tools/call",
136 "params": { "name": name, "arguments": arguments },
137 });
138 let result = stdio_request(conn, id, &req, timeout).await?;
139 check_tool_result(&result, name)
140 }
141
142 fn stdio_timeout(&self) -> Option<Duration> {
143 match &self.params {
144 McpConnectionParams::Stdio { timeout, .. } => *timeout,
145 _ => None,
146 }
147 }
148
149 async fn ensure_connected(
151 &self,
152 guard: &mut Option<StdioConnection>,
153 timeout: Option<Duration>,
154 ) -> Result<(), McpError> {
155 if guard.is_some() {
156 return Ok(());
157 }
158 let (command, args) = match &self.params {
159 McpConnectionParams::Stdio { command, args, .. } => (command.clone(), args.clone()),
160 _ => {
161 return Err(McpError::ConnectionFailed(
162 "stdio transport requested for non-stdio params".to_string(),
163 ))
164 }
165 };
166
167 let mut child = Command::new(&command)
168 .args(&args)
169 .stdin(std::process::Stdio::piped())
170 .stdout(std::process::Stdio::piped())
171 .stderr(std::process::Stdio::null())
172 .spawn()
173 .map_err(|e| {
174 McpError::ConnectionFailed(format!("failed to spawn MCP server '{command}': {e}"))
175 })?;
176
177 let stdin = child.stdin.take().ok_or_else(|| {
178 McpError::ConnectionFailed("MCP server stdin not available".to_string())
179 })?;
180 let stdout = child.stdout.take().ok_or_else(|| {
181 McpError::ConnectionFailed("MCP server stdout not available".to_string())
182 })?;
183
184 let mut conn = StdioConnection {
185 child,
186 stdin,
187 stdout: BufReader::new(stdout),
188 };
189
190 let id = self.next_id();
192 let init = json!({
193 "jsonrpc": "2.0",
194 "id": id,
195 "method": "initialize",
196 "params": {
197 "protocolVersion": MCP_PROTOCOL_VERSION,
198 "capabilities": {},
199 "clientInfo": { "name": "gemini-adk-rs", "version": "0.6.0" },
200 },
201 });
202 stdio_request(&mut conn, id, &init, timeout)
203 .await
204 .map_err(|e| McpError::ConnectionFailed(format!("MCP initialize failed: {e}")))?;
205
206 let initialized = json!({
207 "jsonrpc": "2.0",
208 "method": "notifications/initialized",
209 });
210 stdio_write(&mut conn, &initialized).await.map_err(|e| {
211 McpError::ConnectionFailed(format!("MCP initialized notify failed: {e}"))
212 })?;
213
214 *guard = Some(conn);
215 Ok(())
216 }
217
218 #[cfg(feature = "mcp-http")]
223 async fn http_list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
224 let id = self.next_id();
225 let req = json!({
226 "jsonrpc": "2.0",
227 "id": id,
228 "method": "tools/list",
229 "params": {},
230 });
231 let result = self.http_request(id, &req).await?;
232 parse_tools_list(&result)
233 }
234
235 #[cfg(feature = "mcp-http")]
236 async fn http_call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
237 let id = self.next_id();
238 let arguments = if args.is_null() { json!({}) } else { args };
239 let req = json!({
240 "jsonrpc": "2.0",
241 "id": id,
242 "method": "tools/call",
243 "params": { "name": name, "arguments": arguments },
244 });
245 let result = self.http_request(id, &req).await?;
246 check_tool_result(&result, name)
247 }
248
249 #[cfg(feature = "mcp-http")]
250 #[cfg(not(feature = "mcp-http"))]
251 async fn http_request(&self, _id: u64, _req: &Value) -> Result<Value, McpError> {
252 Err(McpError::ConnectionFailed(
253 "SSE/HTTP MCP transport requires the `mcp-http` feature".to_string(),
254 ))
255 }
256
257 #[cfg(feature = "mcp-http")]
258 async fn http_request(&self, id: u64, req: &Value) -> Result<Value, McpError> {
259 let (url, headers) = match &self.params {
260 McpConnectionParams::Sse { url, headers } => (url.clone(), headers.clone()),
261 _ => {
262 return Err(McpError::ConnectionFailed(
263 "HTTP transport requested for non-SSE params".to_string(),
264 ))
265 }
266 };
267
268 let client = reqwest::Client::new();
269 let mut builder = client
270 .post(&url)
271 .header("content-type", "application/json")
272 .header("accept", "application/json")
273 .json(req);
274 if let Some(hdrs) = headers {
275 for (k, v) in hdrs {
276 builder = builder.header(k, v);
277 }
278 }
279
280 let resp = builder
281 .send()
282 .await
283 .map_err(|e| McpError::ConnectionFailed(format!("MCP HTTP request failed: {e}")))?;
284 if !resp.status().is_success() {
285 return Err(McpError::ConnectionFailed(format!(
286 "MCP HTTP request returned status {}",
287 resp.status()
288 )));
289 }
290 let body: Value = resp
291 .json()
292 .await
293 .map_err(|e| McpError::Other(format!("invalid MCP HTTP response body: {e}")))?;
294 extract_result(&body, id)
295 }
296
297 #[cfg(not(feature = "mcp-http"))]
298 async fn http_list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
299 Err(McpError::ConnectionFailed(
300 "mcp-http feature not enabled".to_string(),
301 ))
302 }
303
304 #[cfg(not(feature = "mcp-http"))]
305 async fn http_call_tool(&self, _name: &str, _args: Value) -> Result<Value, McpError> {
306 Err(McpError::ConnectionFailed(
307 "mcp-http feature not enabled".to_string(),
308 ))
309 }
310}
311
312async fn stdio_write(conn: &mut StdioConnection, msg: &Value) -> Result<(), McpError> {
318 let mut line = serde_json::to_string(msg)
319 .map_err(|e| McpError::Other(format!("failed to serialize JSON-RPC: {e}")))?;
320 line.push('\n');
321 conn.stdin
322 .write_all(line.as_bytes())
323 .await
324 .map_err(|e| McpError::ConnectionFailed(format!("failed to write to MCP server: {e}")))?;
325 conn.stdin.flush().await.map_err(|e| {
326 McpError::ConnectionFailed(format!("failed to flush MCP server stdin: {e}"))
327 })?;
328 Ok(())
329}
330
331async fn stdio_request(
334 conn: &mut StdioConnection,
335 id: u64,
336 req: &Value,
337 timeout: Option<Duration>,
338) -> Result<Value, McpError> {
339 let fut = async {
340 stdio_write(conn, req).await?;
341 loop {
342 let mut line = String::new();
343 let n = conn.stdout.read_line(&mut line).await.map_err(|e| {
344 McpError::ConnectionFailed(format!("failed to read from MCP server: {e}"))
345 })?;
346 if n == 0 {
347 return Err(McpError::ConnectionFailed(
349 "MCP server closed connection before responding".to_string(),
350 ));
351 }
352 let trimmed = line.trim();
353 if trimmed.is_empty() {
354 continue;
355 }
356 let msg: Value = match serde_json::from_str(trimmed) {
357 Ok(v) => v,
358 Err(_) => continue,
360 };
361 match msg.get("id").and_then(value_id_as_u64) {
363 Some(resp_id) if resp_id == id => return extract_result(&msg, id),
364 _ => continue,
365 }
366 }
367 };
368
369 match timeout {
370 Some(dur) => match tokio::time::timeout(dur, fut).await {
371 Ok(res) => res,
372 Err(_) => Err(McpError::ConnectionFailed(format!(
373 "MCP request timed out after {dur:?}"
374 ))),
375 },
376 None => fut.await,
377 }
378}
379
380fn value_id_as_u64(v: &Value) -> Option<u64> {
386 if let Some(n) = v.as_u64() {
387 return Some(n);
388 }
389 v.as_str().and_then(|s| s.parse::<u64>().ok())
390}
391
392fn extract_result(msg: &Value, id: u64) -> Result<Value, McpError> {
394 if let Some(err) = msg.get("error") {
395 let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(0);
396 let message = err
397 .get("message")
398 .and_then(|m| m.as_str())
399 .unwrap_or("unknown error");
400 return Err(McpError::ToolCallFailed(format!(
401 "JSON-RPC error {code}: {message}"
402 )));
403 }
404 match msg.get("result") {
405 Some(result) => Ok(result.clone()),
406 None => Err(McpError::Other(format!(
407 "JSON-RPC response (id {id}) has neither result nor error"
408 ))),
409 }
410}
411
412fn parse_tools_list(result: &Value) -> Result<Vec<McpToolInfo>, McpError> {
414 let tools = result
415 .get("tools")
416 .and_then(|t| t.as_array())
417 .ok_or_else(|| McpError::Other("tools/list result missing 'tools' array".to_string()))?;
418
419 let mut out = Vec::with_capacity(tools.len());
420 for t in tools {
421 let name = t
422 .get("name")
423 .and_then(|n| n.as_str())
424 .ok_or_else(|| McpError::Other("tool entry missing 'name'".to_string()))?
425 .to_string();
426 let description = t
427 .get("description")
428 .and_then(|d| d.as_str())
429 .unwrap_or("")
430 .to_string();
431 let input_schema = t
432 .get("inputSchema")
433 .cloned()
434 .unwrap_or_else(|| json!({"type": "object"}));
435 out.push(McpToolInfo {
436 name,
437 description,
438 input_schema,
439 });
440 }
441 Ok(out)
442}
443
444fn check_tool_result(result: &Value, name: &str) -> Result<Value, McpError> {
446 if result
447 .get("isError")
448 .and_then(|e| e.as_bool())
449 .unwrap_or(false)
450 {
451 return Err(McpError::ToolCallFailed(format!(
452 "tool '{name}' reported isError: {result}"
453 )));
454 }
455 Ok(result.clone())
456}
457
458#[derive(Debug, Clone)]
460pub struct McpToolInfo {
461 pub name: String,
463 pub description: String,
465 pub input_schema: serde_json::Value,
467}
468
469#[derive(Debug, thiserror::Error)]
471pub enum McpError {
472 #[error("Connection failed: {0}")]
474 ConnectionFailed(String),
475 #[error("Not connected: {0}")]
477 NotConnected(String),
478 #[error("Tool call failed: {0}")]
480 ToolCallFailed(String),
481 #[error("{0}")]
483 Other(String),
484}