// gemini_genai_rs/transport/ws.rs

//! Transport abstraction — bidirectional message transport.
//!
//! The [`Transport`] trait defines a pluggable transport layer for sending and
//! receiving raw bytes. The default implementation [`TungsteniteTransport`] wraps
//! `tokio-tungstenite` for WebSocket connectivity. [`MockTransport`] enables
//! deterministic unit testing without a network.

8use async_trait::async_trait;
9
10/// A bidirectional message transport.
11///
12/// The default is WebSocket ([`TungsteniteTransport`]); [`MockTransport`] enables
13/// unit testing without a real server.
14///
15/// # Implementors
16///
17/// - [`TungsteniteTransport`] -- Production WebSocket transport using `tokio-tungstenite`.
18///   Handles both Text and Binary frames (Vertex AI sends Binary).
19/// - [`MockTransport`] -- Deterministic test transport. Records sent data and replays
20///   scripted responses. When the queue is empty, `recv()` pends indefinitely.
21#[async_trait]
22pub trait Transport: Send + 'static {
23    /// The error type produced by this transport.
24    type Error: std::error::Error + Send + Sync + 'static;
25
26    /// Connect to the given URL with optional headers.
27    async fn connect(
28        &mut self,
29        url: &str,
30        headers: Vec<(String, String)>,
31    ) -> Result<(), Self::Error>;
32
33    /// Send raw bytes.
34    async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error>;
35
36    /// Receive raw bytes. Returns `None` when the connection is closed.
37    async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error>;
38
39    /// Close the transport.
40    async fn close(&mut self) -> Result<(), Self::Error>;
41}
42
// ---------------------------------------------------------------------------
// TungsteniteTransport — WebSocket transport using tokio-tungstenite
// ---------------------------------------------------------------------------

47use futures_util::{SinkExt, StreamExt};
48use tokio_tungstenite::tungstenite::client::IntoClientRequest;
49use tokio_tungstenite::tungstenite::Message;
50
51type WsStream =
52    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
53
54/// WebSocket transport using `tokio-tungstenite`.
55pub struct TungsteniteTransport {
56    ws_write: Option<futures_util::stream::SplitSink<WsStream, Message>>,
57    ws_read: Option<futures_util::stream::SplitStream<WsStream>>,
58}
59
60impl TungsteniteTransport {
61    /// Create a new, disconnected transport.
62    pub fn new() -> Self {
63        Self {
64            ws_write: None,
65            ws_read: None,
66        }
67    }
68}
69
70impl Default for TungsteniteTransport {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76/// Errors from the [`TungsteniteTransport`].
77#[derive(Debug, thiserror::Error)]
78pub enum TungsteniteError {
79    /// The transport is not connected.
80    #[error("Not connected")]
81    NotConnected,
82
83    /// WebSocket protocol error from tungstenite.
84    #[error("WebSocket error: {0}")]
85    WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
86
87    /// Failed to construct the HTTP request (e.g. bad URL or header).
88    #[error("Request error: {0}")]
89    Request(String),
90}
91
92#[async_trait]
93impl Transport for TungsteniteTransport {
94    type Error = TungsteniteError;
95
96    async fn connect(
97        &mut self,
98        url: &str,
99        headers: Vec<(String, String)>,
100    ) -> Result<(), Self::Error> {
101        let mut request = url
102            .into_client_request()
103            .map_err(|e| TungsteniteError::Request(e.to_string()))?;
104
105        for (name, value) in headers {
106            let header_name: tokio_tungstenite::tungstenite::http::HeaderName =
107                name.parse().map_err(
108                    |e: tokio_tungstenite::tungstenite::http::header::InvalidHeaderName| {
109                        TungsteniteError::Request(format!("invalid header name: {e}"))
110                    },
111                )?;
112            let header_value: tokio_tungstenite::tungstenite::http::HeaderValue =
113                value.parse().map_err(
114                    |e: tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue| {
115                        TungsteniteError::Request(format!("invalid header value: {e}"))
116                    },
117                )?;
118            request.headers_mut().insert(header_name, header_value);
119        }
120
121        let (ws_stream, _response) = tokio_tungstenite::connect_async(request).await?;
122        let (ws_write, ws_read) = ws_stream.split();
123        self.ws_write = Some(ws_write);
124        self.ws_read = Some(ws_read);
125        Ok(())
126    }
127
128    async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error> {
129        let ws_write = self
130            .ws_write
131            .as_mut()
132            .ok_or(TungsteniteError::NotConnected)?;
133        // Convert bytes to a UTF-8 text frame. The wire protocol sends JSON as text.
134        let text = String::from_utf8(data)
135            .map_err(|e| TungsteniteError::Request(format!("invalid UTF-8: {e}")))?;
136        ws_write.send(Message::Text(text)).await?;
137        Ok(())
138    }
139
140    async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error> {
141        let ws_read = self
142            .ws_read
143            .as_mut()
144            .ok_or(TungsteniteError::NotConnected)?;
145        loop {
146            match ws_read.next().await {
147                Some(Ok(Message::Text(t))) => return Ok(Some(t.into_bytes())),
148                // IMPORTANT: Vertex AI sends JSON in Binary frames.
149                Some(Ok(Message::Binary(b))) => return Ok(Some(b)),
150                Some(Ok(Message::Close(frame))) => {
151                    if let Some(ref cf) = frame {
152                        tracing::warn!(code = %cf.code, reason = %cf.reason, "WebSocket close frame received");
153                    }
154                    return Ok(None);
155                }
156                // Ping/Pong are handled internally by tungstenite; skip them.
157                Some(Ok(Message::Ping(_) | Message::Pong(_))) => continue,
158                // Frame is a low-level variant; skip.
159                Some(Ok(Message::Frame(_))) => continue,
160                Some(Err(e)) => return Err(TungsteniteError::WebSocket(e)),
161                None => return Ok(None),
162            }
163        }
164    }
165
166    async fn close(&mut self) -> Result<(), Self::Error> {
167        if let Some(ref mut ws_write) = self.ws_write {
168            ws_write.send(Message::Close(None)).await?;
169        }
170        self.ws_write = None;
171        self.ws_read = None;
172        Ok(())
173    }
174}
175
// ---------------------------------------------------------------------------
// MockTransport — for unit testing
// ---------------------------------------------------------------------------

/// Mock transport for unit testing.
///
/// Records every payload passed to [`send`](Transport::send) and replays scripted
/// responses from a FIFO queue via [`recv`](Transport::recv). When the queue is empty and
/// the transport is connected, `recv` pends indefinitely — simulating a
/// connected-but-idle transport.
///
/// [`close`](Transport::close) only marks the mock as disconnected: afterwards `send` and
/// `recv` fail with [`MockTransportError::NotConnected`], exactly as they do before
/// [`connect`](Transport::connect). `recv` never yields `Ok(None)`, so the mock cannot
/// simulate the peer closing the connection.
#[derive(Debug, Default)]
pub struct MockTransport {
    /// Payloads passed to `send`, oldest first.
    sent: Vec<Vec<u8>>,
    /// Scripted payloads handed out by `recv`, front first.
    recv_queue: std::collections::VecDeque<Vec<u8>>,
    /// Whether connect() has been called (and close() has not).
    connected: bool,
}

impl MockTransport {
    /// Create a new, disconnected mock transport with nothing sent and nothing scripted.
    pub fn new() -> Self {
        Self::default()
    }

    /// Queue a message to be returned by [`Transport::recv`].
    ///
    /// Messages are returned in the order they were queued. Queuing is allowed whether or
    /// not the mock is connected.
    pub fn script_recv(&mut self, data: Vec<u8>) {
        self.recv_queue.push_back(data);
    }

    /// Take all sent data (for assertions), oldest first. Drains the internal buffer.
    pub fn take_sent(&mut self) -> Vec<Vec<u8>> {
        std::mem::take(&mut self.sent)
    }
}

220/// Errors from the [`MockTransport`].
221#[derive(Debug, thiserror::Error)]
222pub enum MockTransportError {
223    /// Operation attempted while not connected.
224    #[error("Not connected")]
225    NotConnected,
226
227    /// A custom error injected for testing.
228    #[error("Mock error: {0}")]
229    Custom(String),
230}
231
232#[async_trait]
233impl Transport for MockTransport {
234    type Error = MockTransportError;
235
236    async fn connect(
237        &mut self,
238        _url: &str,
239        _headers: Vec<(String, String)>,
240    ) -> Result<(), Self::Error> {
241        self.connected = true;
242        Ok(())
243    }
244
245    async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error> {
246        if !self.connected {
247            return Err(MockTransportError::NotConnected);
248        }
249        self.sent.push(data);
250        Ok(())
251    }
252
253    async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error> {
254        if !self.connected {
255            return Err(MockTransportError::NotConnected);
256        }
257        // Yield to the scheduler so tests can observe intermediate states
258        // (phase transitions, events) before the next message is processed.
259        tokio::task::yield_now().await;
260
261        if let Some(data) = self.recv_queue.pop_front() {
262            return Ok(Some(data));
263        }
264
265        // Queue is empty: pend indefinitely, simulating a connected-but-idle
266        // transport waiting for the next message from the server.
267        // The connection loop uses `tokio::select!` so this future is dropped
268        // when a command (e.g., Disconnect) arrives on the command channel.
269        std::future::pending().await
270    }
271
272    async fn close(&mut self) -> Result<(), Self::Error> {
273        self.connected = false;
274        Ok(())
275    }
276}
277
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_transport_round_trip() {
        let mut transport = MockTransport::new();
        transport.script_recv(br#"{"setupComplete":{}}"#.to_vec());

        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        transport.send(b"hello".to_vec()).await.unwrap();
        let data = transport
            .recv()
            .await
            .unwrap()
            .expect("scripted message should be returned");
        let text = String::from_utf8(data).unwrap();
        assert!(text.contains("setupComplete"));
    }

    #[tokio::test]
    async fn mock_transport_records_sent() {
        let mut transport = MockTransport::new();
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        transport.send(b"msg1".to_vec()).await.unwrap();
        transport.send(b"msg2".to_vec()).await.unwrap();
        let sent = transport.take_sent();
        assert_eq!(sent, vec![b"msg1".to_vec(), b"msg2".to_vec()]);
        // take_sent drains the buffer, so a second call sees nothing.
        assert!(transport.take_sent().is_empty());
    }

    #[tokio::test]
    async fn mock_transport_recv_preserves_script_order() {
        let mut transport = MockTransport::new();
        transport.script_recv(b"first".to_vec());
        transport.script_recv(b"second".to_vec());
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        assert_eq!(transport.recv().await.unwrap(), Some(b"first".to_vec()));
        assert_eq!(transport.recv().await.unwrap(), Some(b"second".to_vec()));
    }

    #[tokio::test]
    async fn mock_transport_recv_pends_when_queue_empty() {
        let mut transport = MockTransport::new();
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        // recv() should pend when queue is empty (simulating idle transport)
        let result =
            tokio::time::timeout(std::time::Duration::from_millis(50), transport.recv()).await;
        assert!(result.is_err(), "recv should pend when queue is empty");
    }

    #[tokio::test]
    async fn mock_transport_recv_errors_when_not_connected() {
        let mut transport = MockTransport::new();
        // Not connected yet — recv should error
        let result = transport.recv().await;
        assert!(matches!(result, Err(MockTransportError::NotConnected)));
    }

    #[tokio::test]
    async fn mock_transport_not_connected_error() {
        let mut transport = MockTransport::new();
        let result = transport.send(b"data".to_vec()).await;
        assert!(matches!(result, Err(MockTransportError::NotConnected)));
    }

    #[tokio::test]
    async fn mock_transport_close_disconnects() {
        let mut transport = MockTransport::new();
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        transport.close().await.unwrap();

        let send_result = transport.send(b"data".to_vec()).await;
        assert!(matches!(send_result, Err(MockTransportError::NotConnected)));
        let recv_result = transport.recv().await;
        assert!(matches!(recv_result, Err(MockTransportError::NotConnected)));
    }

    #[test]
    fn transport_impls_satisfy_trait_bound() {
        // Compile-time check only: both implementations can be used wherever a
        // `T: Transport` generic bound is required.
        fn assert_transport<T: Transport>() {}
        assert_transport::<MockTransport>();
        assert_transport::<TungsteniteTransport>();
    }
}
351}