gemini_genai_rs/transport/
replay.rs1use std::collections::VecDeque;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use tokio::sync::watch;
23
24use super::recording::{WireDirection, WireEntry};
25use super::ws::Transport;
26
27pub type OutboundFrames = Arc<parking_lot::Mutex<Vec<Vec<u8>>>>;
29
30#[derive(Debug, thiserror::Error)]
32pub enum ReplayTransportError {
33 #[error("Not connected")]
35 NotConnected,
36}
37
38#[derive(Clone)]
41pub struct ReplayControl {
42 gate_tx: Arc<watch::Sender<bool>>,
43 drained_rx: watch::Receiver<bool>,
44 outbound: OutboundFrames,
45}
46
47impl ReplayControl {
48 pub fn release(&self) {
52 let _ = self.gate_tx.send(true);
53 }
54
55 pub async fn drained(&self) {
60 let mut rx = self.drained_rx.clone();
61 while !*rx.borrow() {
62 if rx.changed().await.is_err() {
63 break;
64 }
65 }
66 }
67
68 pub fn outbound_frames(&self) -> Vec<Vec<u8>> {
70 self.outbound.lock().clone()
71 }
72}
73
74pub struct ReplayTransport {
77 inbound: VecDeque<Vec<u8>>,
78 ungated_prefix: usize,
79 delivered: usize,
80 gate_rx: watch::Receiver<bool>,
81 drained_tx: watch::Sender<bool>,
82 outbound: OutboundFrames,
83 connected: bool,
84}
85
86impl ReplayTransport {
87 pub fn from_frames(frames: Vec<Vec<u8>>) -> (Self, ReplayControl) {
92 let (gate_tx, gate_rx) = watch::channel(false);
93 let (drained_tx, drained_rx) = watch::channel(false);
94 let outbound: OutboundFrames = Arc::new(parking_lot::Mutex::new(Vec::new()));
95 let control = ReplayControl {
96 gate_tx: Arc::new(gate_tx),
97 drained_rx,
98 outbound: outbound.clone(),
99 };
100 (
101 Self {
102 inbound: frames.into(),
103 ungated_prefix: 1,
104 delivered: 0,
105 gate_rx,
106 drained_tx,
107 outbound,
108 connected: false,
109 },
110 control,
111 )
112 }
113
114 pub fn from_wire_log(entries: &[WireEntry]) -> (Self, ReplayControl) {
117 let frames = entries
118 .iter()
119 .filter(|e| e.dir == WireDirection::Inbound)
120 .map(|e| e.payload.clone())
121 .collect();
122 Self::from_frames(frames)
123 }
124
125 pub fn with_ungated_prefix(mut self, n: usize) -> Self {
128 self.ungated_prefix = n;
129 self
130 }
131}
132
133#[async_trait]
134impl Transport for ReplayTransport {
135 type Error = ReplayTransportError;
136
137 async fn connect(
138 &mut self,
139 _url: &str,
140 _headers: Vec<(String, String)>,
141 ) -> Result<(), Self::Error> {
142 self.connected = true;
143 Ok(())
144 }
145
146 async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error> {
147 if !self.connected {
148 return Err(ReplayTransportError::NotConnected);
149 }
150 self.outbound.lock().push(data);
151 Ok(())
152 }
153
154 async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error> {
155 if !self.connected {
156 return Err(ReplayTransportError::NotConnected);
157 }
158 tokio::task::yield_now().await;
161
162 if self.inbound.is_empty() {
163 let _ = self.drained_tx.send(true);
164 std::future::pending::<()>().await;
167 unreachable!("pending() never resolves");
168 }
169
170 if self.delivered >= self.ungated_prefix {
171 let mut gate = self.gate_rx.clone();
172 while !*gate.borrow() {
173 if gate.changed().await.is_err() {
176 break;
177 }
178 }
179 }
180
181 let frame = self
182 .inbound
183 .pop_front()
184 .expect("checked non-empty inbound queue");
185 self.delivered += 1;
186 if self.inbound.is_empty() {
187 let _ = self.drained_tx.send(true);
188 }
189 Ok(Some(frame))
190 }
191
192 async fn close(&mut self) -> Result<(), Self::Error> {
193 self.connected = false;
194 Ok(())
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Turns a JSON literal into a wire frame.
    fn frame(text: &str) -> Vec<u8> {
        text.as_bytes().to_vec()
    }

    /// Decodes a received frame as UTF-8, panicking on invalid bytes.
    fn utf8(bytes: Vec<u8>) -> String {
        String::from_utf8(bytes).unwrap()
    }

    #[tokio::test]
    async fn replay_delivers_first_frame_ungated_then_waits_for_release() {
        let (mut transport, control) = ReplayTransport::from_frames(vec![
            frame(r#"{"setupComplete":{}}"#),
            frame(r#"{"serverContent":{"turnComplete":true}}"#),
        ]);
        transport.connect("replay://", vec![]).await.unwrap();

        // The default ungated prefix lets exactly one frame through.
        let setup = utf8(transport.recv().await.unwrap().unwrap());
        assert!(setup.contains("setupComplete"));

        let blocked = tokio::time::timeout(Duration::from_millis(50), transport.recv()).await;
        assert!(blocked.is_err(), "second frame should be gated");

        control.release();
        let turn = utf8(transport.recv().await.unwrap().unwrap());
        assert!(turn.contains("turnComplete"));

        tokio::time::timeout(Duration::from_millis(100), control.drained())
            .await
            .expect("drained should be signalled");

        let after_drain = tokio::time::timeout(Duration::from_millis(50), transport.recv()).await;
        assert!(after_drain.is_err(), "recv should pend after drain");
    }

    #[tokio::test]
    async fn replay_collects_outbound_frames() {
        let (mut transport, control) =
            ReplayTransport::from_frames(vec![frame(r#"{"setupComplete":{}}"#)]);
        transport.connect("replay://", vec![]).await.unwrap();
        for payload in [frame(r#"{"setup":{}}"#), frame(r#"{"toolResponse":{}}"#)] {
            transport.send(payload).await.unwrap();
        }

        let sent = control.outbound_frames();
        assert_eq!(sent.len(), 2);
        assert_eq!(sent[0], frame(r#"{"setup":{}}"#));
    }

    #[tokio::test]
    async fn replay_from_wire_log_keeps_inbound_only() {
        let outbound_setup = WireEntry {
            seq: 1,
            dir: WireDirection::Outbound,
            ts_ms: 1,
            payload: frame(r#"{"setup":{}}"#),
        };
        let inbound_ack = WireEntry {
            seq: 2,
            dir: WireDirection::Inbound,
            ts_ms: 2,
            payload: frame(r#"{"setupComplete":{}}"#),
        };
        let (mut transport, _control) =
            ReplayTransport::from_wire_log(&[outbound_setup, inbound_ack]);
        transport.connect("replay://", vec![]).await.unwrap();

        let first = utf8(transport.recv().await.unwrap().unwrap());
        assert!(first.contains("setupComplete"));
    }

    #[tokio::test]
    async fn replay_errors_when_not_connected() {
        let (mut transport, _control) = ReplayTransport::from_frames(Vec::new());
        assert!(transport.recv().await.is_err());
        assert!(transport.send(vec![1]).await.is_err());
    }
}