gemini_genai_rs/transport/
ws.rs1use async_trait::async_trait;
9
10#[async_trait]
22pub trait Transport: Send + 'static {
23 type Error: std::error::Error + Send + Sync + 'static;
25
26 async fn connect(
28 &mut self,
29 url: &str,
30 headers: Vec<(String, String)>,
31 ) -> Result<(), Self::Error>;
32
33 async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error>;
35
36 async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error>;
38
39 async fn close(&mut self) -> Result<(), Self::Error>;
41}
42
43use futures_util::{SinkExt, StreamExt};
48use tokio_tungstenite::tungstenite::client::IntoClientRequest;
49use tokio_tungstenite::tungstenite::Message;
50
51type WsStream =
52 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
53
54pub struct TungsteniteTransport {
56 ws_write: Option<futures_util::stream::SplitSink<WsStream, Message>>,
57 ws_read: Option<futures_util::stream::SplitStream<WsStream>>,
58}
59
60impl TungsteniteTransport {
61 pub fn new() -> Self {
63 Self {
64 ws_write: None,
65 ws_read: None,
66 }
67 }
68}
69
70impl Default for TungsteniteTransport {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76#[derive(Debug, thiserror::Error)]
78pub enum TungsteniteError {
79 #[error("Not connected")]
81 NotConnected,
82
83 #[error("WebSocket error: {0}")]
85 WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
86
87 #[error("Request error: {0}")]
89 Request(String),
90}
91
92#[async_trait]
93impl Transport for TungsteniteTransport {
94 type Error = TungsteniteError;
95
96 async fn connect(
97 &mut self,
98 url: &str,
99 headers: Vec<(String, String)>,
100 ) -> Result<(), Self::Error> {
101 let mut request = url
102 .into_client_request()
103 .map_err(|e| TungsteniteError::Request(e.to_string()))?;
104
105 for (name, value) in headers {
106 let header_name: tokio_tungstenite::tungstenite::http::HeaderName =
107 name.parse().map_err(
108 |e: tokio_tungstenite::tungstenite::http::header::InvalidHeaderName| {
109 TungsteniteError::Request(format!("invalid header name: {e}"))
110 },
111 )?;
112 let header_value: tokio_tungstenite::tungstenite::http::HeaderValue =
113 value.parse().map_err(
114 |e: tokio_tungstenite::tungstenite::http::header::InvalidHeaderValue| {
115 TungsteniteError::Request(format!("invalid header value: {e}"))
116 },
117 )?;
118 request.headers_mut().insert(header_name, header_value);
119 }
120
121 let (ws_stream, _response) = tokio_tungstenite::connect_async(request).await?;
122 let (ws_write, ws_read) = ws_stream.split();
123 self.ws_write = Some(ws_write);
124 self.ws_read = Some(ws_read);
125 Ok(())
126 }
127
128 async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error> {
129 let ws_write = self
130 .ws_write
131 .as_mut()
132 .ok_or(TungsteniteError::NotConnected)?;
133 let text = String::from_utf8(data)
135 .map_err(|e| TungsteniteError::Request(format!("invalid UTF-8: {e}")))?;
136 ws_write.send(Message::Text(text)).await?;
137 Ok(())
138 }
139
140 async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error> {
141 let ws_read = self
142 .ws_read
143 .as_mut()
144 .ok_or(TungsteniteError::NotConnected)?;
145 loop {
146 match ws_read.next().await {
147 Some(Ok(Message::Text(t))) => return Ok(Some(t.into_bytes())),
148 Some(Ok(Message::Binary(b))) => return Ok(Some(b)),
150 Some(Ok(Message::Close(frame))) => {
151 if let Some(ref cf) = frame {
152 tracing::warn!(code = %cf.code, reason = %cf.reason, "WebSocket close frame received");
153 }
154 return Ok(None);
155 }
156 Some(Ok(Message::Ping(_) | Message::Pong(_))) => continue,
158 Some(Ok(Message::Frame(_))) => continue,
160 Some(Err(e)) => return Err(TungsteniteError::WebSocket(e)),
161 None => return Ok(None),
162 }
163 }
164 }
165
166 async fn close(&mut self) -> Result<(), Self::Error> {
167 if let Some(ref mut ws_write) = self.ws_write {
168 ws_write.send(Message::Close(None)).await?;
169 }
170 self.ws_write = None;
171 self.ws_read = None;
172 Ok(())
173 }
174}
175
/// In-memory [`Transport`] for tests.
///
/// Payloads passed to `send` are recorded and can be drained with
/// [`MockTransport::take_sent`]. Payloads to hand back from `recv` are
/// pre-loaded with [`MockTransport::script_recv`] and returned in FIFO order.
#[derive(Debug, Default)]
pub struct MockTransport {
    /// Payloads passed to `send`, in call order.
    sent: Vec<Vec<u8>>,
    /// Scripted payloads waiting to be returned by `recv`.
    recv_queue: std::collections::VecDeque<Vec<u8>>,
    /// Set by `connect` and cleared by `close`; `send`/`recv` fail while `false`.
    connected: bool,
}

impl MockTransport {
    /// Creates a disconnected mock with nothing recorded and nothing scripted.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `data` to the queue of payloads that `recv` will return.
    pub fn script_recv(&mut self, data: Vec<u8>) {
        self.recv_queue.push_back(data);
    }

    /// Returns every payload sent so far and clears the record.
    pub fn take_sent(&mut self) -> Vec<Vec<u8>> {
        std::mem::take(&mut self.sent)
    }
}
219
220#[derive(Debug, thiserror::Error)]
222pub enum MockTransportError {
223 #[error("Not connected")]
225 NotConnected,
226
227 #[error("Mock error: {0}")]
229 Custom(String),
230}
231
232#[async_trait]
233impl Transport for MockTransport {
234 type Error = MockTransportError;
235
236 async fn connect(
237 &mut self,
238 _url: &str,
239 _headers: Vec<(String, String)>,
240 ) -> Result<(), Self::Error> {
241 self.connected = true;
242 Ok(())
243 }
244
245 async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error> {
246 if !self.connected {
247 return Err(MockTransportError::NotConnected);
248 }
249 self.sent.push(data);
250 Ok(())
251 }
252
253 async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error> {
254 if !self.connected {
255 return Err(MockTransportError::NotConnected);
256 }
257 tokio::task::yield_now().await;
260
261 if let Some(data) = self.recv_queue.pop_front() {
262 return Ok(Some(data));
263 }
264
265 std::future::pending().await
270 }
271
272 async fn close(&mut self) -> Result<(), Self::Error> {
273 self.connected = false;
274 Ok(())
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_transport_round_trip() {
        let mut transport = MockTransport::new();
        transport.script_recv(br#"{"setupComplete":{}}"#.to_vec());

        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        transport.send(b"hello".to_vec()).await.unwrap();
        let data = transport.recv().await.unwrap();
        let text = String::from_utf8(data.expect("scripted payload should be returned")).unwrap();
        assert!(text.contains("setupComplete"));
    }

    #[tokio::test]
    async fn mock_transport_recv_preserves_script_order() {
        let mut transport = MockTransport::new();
        transport.script_recv(b"first".to_vec());
        transport.script_recv(b"second".to_vec());
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();

        assert_eq!(transport.recv().await.unwrap(), Some(b"first".to_vec()));
        assert_eq!(transport.recv().await.unwrap(), Some(b"second".to_vec()));
    }

    #[tokio::test]
    async fn mock_transport_records_sent() {
        let mut transport = MockTransport::new();
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        transport.send(b"msg1".to_vec()).await.unwrap();
        transport.send(b"msg2".to_vec()).await.unwrap();

        let sent = transport.take_sent();
        assert_eq!(sent, vec![b"msg1".to_vec(), b"msg2".to_vec()]);
        // `take_sent` drains the record.
        assert!(transport.take_sent().is_empty());
    }

    #[tokio::test]
    async fn mock_transport_recv_pends_when_queue_empty() {
        let mut transport = MockTransport::new();
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        let result =
            tokio::time::timeout(std::time::Duration::from_millis(50), transport.recv()).await;
        assert!(result.is_err(), "recv should pend when queue is empty");
    }

    #[tokio::test]
    async fn mock_transport_recv_errors_when_not_connected() {
        let mut transport = MockTransport::new();
        let result = transport.recv().await;
        assert!(matches!(result, Err(MockTransportError::NotConnected)));
    }

    #[tokio::test]
    async fn mock_transport_not_connected_error() {
        let mut transport = MockTransport::new();
        let result = transport.send(b"data".to_vec()).await;
        assert!(matches!(result, Err(MockTransportError::NotConnected)));
    }

    #[tokio::test]
    async fn mock_transport_close_disconnects() {
        let mut transport = MockTransport::new();
        transport
            .connect("wss://example.com", vec![])
            .await
            .unwrap();
        transport.close().await.unwrap();

        let send_result = transport.send(b"late".to_vec()).await;
        assert!(matches!(send_result, Err(MockTransportError::NotConnected)));
        let recv_result = transport.recv().await;
        assert!(matches!(recv_result, Err(MockTransportError::NotConnected)));
    }

    #[test]
    fn transport_trait_is_object_safe_check() {
        fn _assert_transport<T: Transport>() {}
        _assert_transport::<MockTransport>();
        _assert_transport::<TungsteniteTransport>();

        // This line only compiles if `Transport` is dyn-compatible.
        let _boxed: Box<dyn Transport<Error = MockTransportError>> =
            Box::new(MockTransport::new());
    }
}