gemini_adk_rs/tools/mcp/
session_manager.rs

1//! MCP session management — connection params, tool discovery, and tool invocation.
2//!
3//! Implements a real MCP (Model Context Protocol) client speaking JSON-RPC 2.0.
4//! The primary transport is **stdio** (newline-delimited JSON over a subprocess's
5//! stdin/stdout), which works on default features. An optional **HTTP** transport
6//! (single-shot JSON-RPC POST) is available behind the `mcp-http` feature.
7
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11
12use serde_json::{json, Value};
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout, Command};
15use tokio::sync::Mutex;
16
/// Protocol revision this client sends as `protocolVersion` in its `initialize` request.
const MCP_PROTOCOL_VERSION: &str = "2024-11-05";
19
/// How to reach an MCP server.
#[derive(Debug, Clone)]
pub enum McpConnectionParams {
    /// Spawn the server as a subprocess and exchange newline-delimited
    /// JSON-RPC messages over its stdin/stdout.
    Stdio {
        /// The program to execute.
        command: String,
        /// Arguments passed to the command.
        args: Vec<String>,
        /// Per-request timeout.
        ///
        /// It bounds every individual JSON-RPC exchange (write plus wait for
        /// the reply), including the `initialize` handshake. It is not merely a
        /// connect timeout. `None` waits indefinitely.
        timeout: Option<Duration>,
    },
    /// Reach the server over HTTP (requires the `mcp-http` feature).
    ///
    /// Despite the variant name, each JSON-RPC request is currently sent as a
    /// single POST, and a plain JSON response body is expected. SSE streaming
    /// is not implemented.
    Sse {
        /// Endpoint URL that receives the JSON-RPC POSTs.
        url: String,
        /// Extra HTTP headers added to every request (e.g. for authentication).
        headers: Option<HashMap<String, String>>,
    },
}
40
41/// Live stdio connection state: the child process plus framed I/O handles.
42struct StdioConnection {
43    /// The child process. Kept alive so the pipes stay open; killed on drop.
44    #[allow(dead_code)]
45    child: Child,
46    /// Subprocess stdin (we write requests here).
47    stdin: ChildStdin,
48    /// Buffered subprocess stdout (we read newline-delimited responses here).
49    stdout: BufReader<ChildStdout>,
50}
51
52/// Manages the MCP client session lifecycle.
53pub struct McpSessionManager {
54    params: McpConnectionParams,
55    /// Lazily-established stdio connection (None until first use, then reused).
56    stdio: Mutex<Option<StdioConnection>>,
57    /// Monotonic JSON-RPC request id counter.
58    next_id: AtomicU64,
59}
60
61impl McpSessionManager {
62    /// Create a new MCP session manager with the given connection params.
63    pub fn new(params: McpConnectionParams) -> Self {
64        Self {
65            params,
66            stdio: Mutex::new(None),
67            next_id: AtomicU64::new(1),
68        }
69    }
70
71    /// Get the connection parameters.
72    pub fn params(&self) -> &McpConnectionParams {
73        &self.params
74    }
75
76    fn next_id(&self) -> u64 {
77        self.next_id.fetch_add(1, Ordering::Relaxed)
78    }
79
80    /// List available tools from the MCP server.
81    ///
82    /// Connects (and performs the MCP handshake) lazily on first use, then issues
83    /// a `tools/list` JSON-RPC request and maps the result into [`McpToolInfo`]s.
84    pub async fn list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
85        match &self.params {
86            McpConnectionParams::Stdio { .. } => self.stdio_list_tools().await,
87            McpConnectionParams::Sse { .. } => self.http_list_tools().await,
88        }
89    }
90
91    /// Call a tool on the MCP server via `tools/call`.
92    ///
93    /// Returns the JSON-RPC `result` object on success. Returns
94    /// [`McpError::ToolCallFailed`] on a JSON-RPC error or when the result has
95    /// `isError: true`.
96    pub async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
97        match &self.params {
98            McpConnectionParams::Stdio { .. } => self.stdio_call_tool(name, args).await,
99            McpConnectionParams::Sse { .. } => self.http_call_tool(name, args).await,
100        }
101    }
102
103    // ------------------------------------------------------------------
104    // stdio transport
105    // ------------------------------------------------------------------
106
107    async fn stdio_list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
108        let timeout = self.stdio_timeout();
109        let mut guard = self.stdio.lock().await;
110        self.ensure_connected(&mut guard, timeout).await?;
111        let conn = guard.as_mut().expect("connection established above");
112
113        let id = self.next_id();
114        let req = json!({
115            "jsonrpc": "2.0",
116            "id": id,
117            "method": "tools/list",
118            "params": {},
119        });
120        let result = stdio_request(conn, id, &req, timeout).await?;
121        parse_tools_list(&result)
122    }
123
124    async fn stdio_call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
125        let timeout = self.stdio_timeout();
126        let mut guard = self.stdio.lock().await;
127        self.ensure_connected(&mut guard, timeout).await?;
128        let conn = guard.as_mut().expect("connection established above");
129
130        let id = self.next_id();
131        let arguments = if args.is_null() { json!({}) } else { args };
132        let req = json!({
133            "jsonrpc": "2.0",
134            "id": id,
135            "method": "tools/call",
136            "params": { "name": name, "arguments": arguments },
137        });
138        let result = stdio_request(conn, id, &req, timeout).await?;
139        check_tool_result(&result, name)
140    }
141
142    fn stdio_timeout(&self) -> Option<Duration> {
143        match &self.params {
144            McpConnectionParams::Stdio { timeout, .. } => *timeout,
145            _ => None,
146        }
147    }
148
149    /// Ensure a stdio connection exists and has completed the MCP handshake.
150    async fn ensure_connected(
151        &self,
152        guard: &mut Option<StdioConnection>,
153        timeout: Option<Duration>,
154    ) -> Result<(), McpError> {
155        if guard.is_some() {
156            return Ok(());
157        }
158        let (command, args) = match &self.params {
159            McpConnectionParams::Stdio { command, args, .. } => (command.clone(), args.clone()),
160            _ => {
161                return Err(McpError::ConnectionFailed(
162                    "stdio transport requested for non-stdio params".to_string(),
163                ))
164            }
165        };
166
167        let mut child = Command::new(&command)
168            .args(&args)
169            .stdin(std::process::Stdio::piped())
170            .stdout(std::process::Stdio::piped())
171            .stderr(std::process::Stdio::null())
172            .spawn()
173            .map_err(|e| {
174                McpError::ConnectionFailed(format!("failed to spawn MCP server '{command}': {e}"))
175            })?;
176
177        let stdin = child.stdin.take().ok_or_else(|| {
178            McpError::ConnectionFailed("MCP server stdin not available".to_string())
179        })?;
180        let stdout = child.stdout.take().ok_or_else(|| {
181            McpError::ConnectionFailed("MCP server stdout not available".to_string())
182        })?;
183
184        let mut conn = StdioConnection {
185            child,
186            stdin,
187            stdout: BufReader::new(stdout),
188        };
189
190        // --- Handshake: initialize -> read result -> notifications/initialized ---
191        let id = self.next_id();
192        let init = json!({
193            "jsonrpc": "2.0",
194            "id": id,
195            "method": "initialize",
196            "params": {
197                "protocolVersion": MCP_PROTOCOL_VERSION,
198                "capabilities": {},
199                "clientInfo": { "name": "gemini-adk-rs", "version": "0.6.0" },
200            },
201        });
202        stdio_request(&mut conn, id, &init, timeout)
203            .await
204            .map_err(|e| McpError::ConnectionFailed(format!("MCP initialize failed: {e}")))?;
205
206        let initialized = json!({
207            "jsonrpc": "2.0",
208            "method": "notifications/initialized",
209        });
210        stdio_write(&mut conn, &initialized).await.map_err(|e| {
211            McpError::ConnectionFailed(format!("MCP initialized notify failed: {e}"))
212        })?;
213
214        *guard = Some(conn);
215        Ok(())
216    }
217
218    // ------------------------------------------------------------------
219    // HTTP transport (feature-gated)
220    // ------------------------------------------------------------------
221
222    #[cfg(feature = "mcp-http")]
223    async fn http_list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
224        let id = self.next_id();
225        let req = json!({
226            "jsonrpc": "2.0",
227            "id": id,
228            "method": "tools/list",
229            "params": {},
230        });
231        let result = self.http_request(id, &req).await?;
232        parse_tools_list(&result)
233    }
234
235    #[cfg(feature = "mcp-http")]
236    async fn http_call_tool(&self, name: &str, args: Value) -> Result<Value, McpError> {
237        let id = self.next_id();
238        let arguments = if args.is_null() { json!({}) } else { args };
239        let req = json!({
240            "jsonrpc": "2.0",
241            "id": id,
242            "method": "tools/call",
243            "params": { "name": name, "arguments": arguments },
244        });
245        let result = self.http_request(id, &req).await?;
246        check_tool_result(&result, name)
247    }
248
249    #[cfg(feature = "mcp-http")]
250    #[cfg(not(feature = "mcp-http"))]
251    async fn http_request(&self, _id: u64, _req: &Value) -> Result<Value, McpError> {
252        Err(McpError::ConnectionFailed(
253            "SSE/HTTP MCP transport requires the `mcp-http` feature".to_string(),
254        ))
255    }
256
257    #[cfg(feature = "mcp-http")]
258    async fn http_request(&self, id: u64, req: &Value) -> Result<Value, McpError> {
259        let (url, headers) = match &self.params {
260            McpConnectionParams::Sse { url, headers } => (url.clone(), headers.clone()),
261            _ => {
262                return Err(McpError::ConnectionFailed(
263                    "HTTP transport requested for non-SSE params".to_string(),
264                ))
265            }
266        };
267
268        let client = reqwest::Client::new();
269        let mut builder = client
270            .post(&url)
271            .header("content-type", "application/json")
272            .header("accept", "application/json")
273            .json(req);
274        if let Some(hdrs) = headers {
275            for (k, v) in hdrs {
276                builder = builder.header(k, v);
277            }
278        }
279
280        let resp = builder
281            .send()
282            .await
283            .map_err(|e| McpError::ConnectionFailed(format!("MCP HTTP request failed: {e}")))?;
284        if !resp.status().is_success() {
285            return Err(McpError::ConnectionFailed(format!(
286                "MCP HTTP request returned status {}",
287                resp.status()
288            )));
289        }
290        let body: Value = resp
291            .json()
292            .await
293            .map_err(|e| McpError::Other(format!("invalid MCP HTTP response body: {e}")))?;
294        extract_result(&body, id)
295    }
296
297    #[cfg(not(feature = "mcp-http"))]
298    async fn http_list_tools(&self) -> Result<Vec<McpToolInfo>, McpError> {
299        Err(McpError::ConnectionFailed(
300            "mcp-http feature not enabled".to_string(),
301        ))
302    }
303
304    #[cfg(not(feature = "mcp-http"))]
305    async fn http_call_tool(&self, _name: &str, _args: Value) -> Result<Value, McpError> {
306        Err(McpError::ConnectionFailed(
307            "mcp-http feature not enabled".to_string(),
308        ))
309    }
310}
311
312// ----------------------------------------------------------------------
313// stdio framing helpers
314// ----------------------------------------------------------------------
315
316/// Write a single JSON-RPC message as one compact, newline-terminated line.
317async fn stdio_write(conn: &mut StdioConnection, msg: &Value) -> Result<(), McpError> {
318    let mut line = serde_json::to_string(msg)
319        .map_err(|e| McpError::Other(format!("failed to serialize JSON-RPC: {e}")))?;
320    line.push('\n');
321    conn.stdin
322        .write_all(line.as_bytes())
323        .await
324        .map_err(|e| McpError::ConnectionFailed(format!("failed to write to MCP server: {e}")))?;
325    conn.stdin.flush().await.map_err(|e| {
326        McpError::ConnectionFailed(format!("failed to flush MCP server stdin: {e}"))
327    })?;
328    Ok(())
329}
330
331/// Send a JSON-RPC request and read the matching response by `id`, skipping
332/// notifications and unrelated messages. Honors the optional timeout.
333async fn stdio_request(
334    conn: &mut StdioConnection,
335    id: u64,
336    req: &Value,
337    timeout: Option<Duration>,
338) -> Result<Value, McpError> {
339    let fut = async {
340        stdio_write(conn, req).await?;
341        loop {
342            let mut line = String::new();
343            let n = conn.stdout.read_line(&mut line).await.map_err(|e| {
344                McpError::ConnectionFailed(format!("failed to read from MCP server: {e}"))
345            })?;
346            if n == 0 {
347                // EOF: child closed stdout / exited before responding.
348                return Err(McpError::ConnectionFailed(
349                    "MCP server closed connection before responding".to_string(),
350                ));
351            }
352            let trimmed = line.trim();
353            if trimmed.is_empty() {
354                continue;
355            }
356            let msg: Value = match serde_json::from_str(trimmed) {
357                Ok(v) => v,
358                // Non-JSON line (stray log output) — skip it.
359                Err(_) => continue,
360            };
361            // Skip anything that isn't the response to our request id.
362            match msg.get("id").and_then(value_id_as_u64) {
363                Some(resp_id) if resp_id == id => return extract_result(&msg, id),
364                _ => continue,
365            }
366        }
367    };
368
369    match timeout {
370        Some(dur) => match tokio::time::timeout(dur, fut).await {
371            Ok(res) => res,
372            Err(_) => Err(McpError::ConnectionFailed(format!(
373                "MCP request timed out after {dur:?}"
374            ))),
375        },
376        None => fut.await,
377    }
378}
379
380// ----------------------------------------------------------------------
381// JSON-RPC / MCP result parsing (shared by both transports)
382// ----------------------------------------------------------------------
383
384/// Interpret a JSON-RPC `id` value (number or numeric string) as u64.
385fn value_id_as_u64(v: &Value) -> Option<u64> {
386    if let Some(n) = v.as_u64() {
387        return Some(n);
388    }
389    v.as_str().and_then(|s| s.parse::<u64>().ok())
390}
391
392/// Extract the `result` from a JSON-RPC response, mapping `error` to [`McpError`].
393fn extract_result(msg: &Value, id: u64) -> Result<Value, McpError> {
394    if let Some(err) = msg.get("error") {
395        let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(0);
396        let message = err
397            .get("message")
398            .and_then(|m| m.as_str())
399            .unwrap_or("unknown error");
400        return Err(McpError::ToolCallFailed(format!(
401            "JSON-RPC error {code}: {message}"
402        )));
403    }
404    match msg.get("result") {
405        Some(result) => Ok(result.clone()),
406        None => Err(McpError::Other(format!(
407            "JSON-RPC response (id {id}) has neither result nor error"
408        ))),
409    }
410}
411
412/// Map a `tools/list` result into [`McpToolInfo`]s.
413fn parse_tools_list(result: &Value) -> Result<Vec<McpToolInfo>, McpError> {
414    let tools = result
415        .get("tools")
416        .and_then(|t| t.as_array())
417        .ok_or_else(|| McpError::Other("tools/list result missing 'tools' array".to_string()))?;
418
419    let mut out = Vec::with_capacity(tools.len());
420    for t in tools {
421        let name = t
422            .get("name")
423            .and_then(|n| n.as_str())
424            .ok_or_else(|| McpError::Other("tool entry missing 'name'".to_string()))?
425            .to_string();
426        let description = t
427            .get("description")
428            .and_then(|d| d.as_str())
429            .unwrap_or("")
430            .to_string();
431        let input_schema = t
432            .get("inputSchema")
433            .cloned()
434            .unwrap_or_else(|| json!({"type": "object"}));
435        out.push(McpToolInfo {
436            name,
437            description,
438            input_schema,
439        });
440    }
441    Ok(out)
442}
443
444/// Validate a `tools/call` result, surfacing `isError: true` as a failure.
445fn check_tool_result(result: &Value, name: &str) -> Result<Value, McpError> {
446    if result
447        .get("isError")
448        .and_then(|e| e.as_bool())
449        .unwrap_or(false)
450    {
451        return Err(McpError::ToolCallFailed(format!(
452            "tool '{name}' reported isError: {result}"
453        )));
454    }
455    Ok(result.clone())
456}
457
458/// Information about an MCP tool.
459#[derive(Debug, Clone)]
460pub struct McpToolInfo {
461    /// Tool name.
462    pub name: String,
463    /// Human-readable tool description.
464    pub description: String,
465    /// JSON Schema for the tool's input parameters.
466    pub input_schema: serde_json::Value,
467}
468
469/// MCP-related errors.
470#[derive(Debug, thiserror::Error)]
471pub enum McpError {
472    /// Failed to connect to the MCP server.
473    #[error("Connection failed: {0}")]
474    ConnectionFailed(String),
475    /// The MCP session is not connected.
476    #[error("Not connected: {0}")]
477    NotConnected(String),
478    /// A tool call to the MCP server failed.
479    #[error("Tool call failed: {0}")]
480    ToolCallFailed(String),
481    /// A catch-all for other MCP errors.
482    #[error("{0}")]
483    Other(String),
484}