// Source: gemini_genai_rs/transport/replay.rs

//! Replay transport — feed a recorded wire log back through the session loop.
//!
//! [`ReplayTransport`] implements [`Transport`] over a recorded wire log (see
//! [`crate::transport::recording`]): `recv()` yields the recorded **inbound**
//! frames in order (as fast as the session loop consumes them), and `send()`
//! collects outbound frames for later comparison instead of touching a network.
//!
//! Because the session loop broadcasts events as soon as frames arrive, a
//! replay that started streaming before the application subscribed would
//! lose events nondeterministically. The transport is therefore *gated*: the
//! first `ungated_prefix` frames (default 1 — the `setupComplete` handshake)
//! are delivered immediately so the connection can reach `Active`, and the
//! rest are held until [`ReplayControl::release`] is called (or until every
//! `ReplayControl` handle has been dropped). Once the inbound queue is
//! exhausted, the [`ReplayControl::drained`] signal fires and `recv()` stays
//! pending (like [`MockTransport`](super::ws::MockTransport)), keeping the
//! session alive until it is disconnected.
17
18use std::collections::VecDeque;
19use std::sync::Arc;
20
21use async_trait::async_trait;
22use tokio::sync::watch;
23
24use super::recording::{WireDirection, WireEntry};
25use super::ws::Transport;
26
27/// Shared collection of frames "sent" during a replay.
28pub type OutboundFrames = Arc<parking_lot::Mutex<Vec<Vec<u8>>>>;
29
30/// Errors from the [`ReplayTransport`].
31#[derive(Debug, thiserror::Error)]
32pub enum ReplayTransportError {
33    /// Operation attempted while not connected.
34    #[error("Not connected")]
35    NotConnected,
36}
37
38/// Control handle for a [`ReplayTransport`] that has been moved into a
39/// session loop.
40#[derive(Clone)]
41pub struct ReplayControl {
42    gate_tx: Arc<watch::Sender<bool>>,
43    drained_rx: watch::Receiver<bool>,
44    outbound: OutboundFrames,
45}
46
47impl ReplayControl {
48    /// Release the gated frames: inbound replay starts flowing.
49    ///
50    /// Call this after subscribing to the session's events so none are lost.
51    pub fn release(&self) {
52        let _ = self.gate_tx.send(true);
53    }
54
55    /// Wait until every recorded inbound frame has been handed to the session
56    /// loop. Note: the *last* frame may still be in flight through the
57    /// processor when this returns — wait for its observable effects (events,
58    /// state) before asserting.
59    pub async fn drained(&self) {
60        let mut rx = self.drained_rx.clone();
61        while !*rx.borrow() {
62            if rx.changed().await.is_err() {
63                break;
64            }
65        }
66    }
67
68    /// Snapshot the outbound frames collected so far (in send order).
69    pub fn outbound_frames(&self) -> Vec<Vec<u8>> {
70        self.outbound.lock().clone()
71    }
72}
73
74/// A [`Transport`] that replays recorded inbound frames and collects outbound
75/// frames. See the [module docs](self) for gating and drain semantics.
76pub struct ReplayTransport {
77    inbound: VecDeque<Vec<u8>>,
78    ungated_prefix: usize,
79    delivered: usize,
80    gate_rx: watch::Receiver<bool>,
81    drained_tx: watch::Sender<bool>,
82    outbound: OutboundFrames,
83    connected: bool,
84}
85
86impl ReplayTransport {
87    /// Build a replay transport from raw inbound frames.
88    ///
89    /// The first frame should be the `setupComplete` handshake; it is
90    /// delivered ungated so the connection can reach `Active`.
91    pub fn from_frames(frames: Vec<Vec<u8>>) -> (Self, ReplayControl) {
92        let (gate_tx, gate_rx) = watch::channel(false);
93        let (drained_tx, drained_rx) = watch::channel(false);
94        let outbound: OutboundFrames = Arc::new(parking_lot::Mutex::new(Vec::new()));
95        let control = ReplayControl {
96            gate_tx: Arc::new(gate_tx),
97            drained_rx,
98            outbound: outbound.clone(),
99        };
100        (
101            Self {
102                inbound: frames.into(),
103                ungated_prefix: 1,
104                delivered: 0,
105                gate_rx,
106                drained_tx,
107                outbound,
108                connected: false,
109            },
110            control,
111        )
112    }
113
114    /// Build a replay transport from a recorded wire log, keeping only the
115    /// [`WireDirection::Inbound`] entries (in log order).
116    pub fn from_wire_log(entries: &[WireEntry]) -> (Self, ReplayControl) {
117        let frames = entries
118            .iter()
119            .filter(|e| e.dir == WireDirection::Inbound)
120            .map(|e| e.payload.clone())
121            .collect();
122        Self::from_frames(frames)
123    }
124
125    /// Override how many leading frames are delivered before
126    /// [`ReplayControl::release`] (default 1: the setup handshake).
127    pub fn with_ungated_prefix(mut self, n: usize) -> Self {
128        self.ungated_prefix = n;
129        self
130    }
131}
132
133#[async_trait]
134impl Transport for ReplayTransport {
135    type Error = ReplayTransportError;
136
137    async fn connect(
138        &mut self,
139        _url: &str,
140        _headers: Vec<(String, String)>,
141    ) -> Result<(), Self::Error> {
142        self.connected = true;
143        Ok(())
144    }
145
146    async fn send(&mut self, data: Vec<u8>) -> Result<(), Self::Error> {
147        if !self.connected {
148            return Err(ReplayTransportError::NotConnected);
149        }
150        self.outbound.lock().push(data);
151        Ok(())
152    }
153
154    async fn recv(&mut self) -> Result<Option<Vec<u8>>, Self::Error> {
155        if !self.connected {
156            return Err(ReplayTransportError::NotConnected);
157        }
158        // Yield so observers can see intermediate states between frames
159        // (mirrors MockTransport).
160        tokio::task::yield_now().await;
161
162        if self.inbound.is_empty() {
163            let _ = self.drained_tx.send(true);
164            // Stay connected-but-idle; the session loop's `select!` drops this
165            // future when a command (e.g. Disconnect) arrives.
166            std::future::pending::<()>().await;
167            unreachable!("pending() never resolves");
168        }
169
170        if self.delivered >= self.ungated_prefix {
171            let mut gate = self.gate_rx.clone();
172            while !*gate.borrow() {
173                // If the control handle is dropped, proceed ungated rather
174                // than deadlocking the replay.
175                if gate.changed().await.is_err() {
176                    break;
177                }
178            }
179        }
180
181        let frame = self
182            .inbound
183            .pop_front()
184            .expect("checked non-empty inbound queue");
185        self.delivered += 1;
186        if self.inbound.is_empty() {
187            let _ = self.drained_tx.send(true);
188        }
189        Ok(Some(frame))
190    }
191
192    async fn close(&mut self) -> Result<(), Self::Error> {
193        self.connected = false;
194        Ok(())
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[tokio::test]
    async fn replay_delivers_first_frame_ungated_then_waits_for_release() {
        let (mut transport, control) = ReplayTransport::from_frames(vec![
            br#"{"setupComplete":{}}"#.to_vec(),
            br#"{"serverContent":{"turnComplete":true}}"#.to_vec(),
        ]);
        transport.connect("replay://", vec![]).await.unwrap();

        // First frame (handshake) flows without release.
        let first = transport.recv().await.unwrap().unwrap();
        assert!(String::from_utf8(first).unwrap().contains("setupComplete"));

        // Second frame is gated.
        let gated = tokio::time::timeout(Duration::from_millis(50), transport.recv()).await;
        assert!(gated.is_err(), "second frame should be gated");

        control.release();
        let second = transport.recv().await.unwrap().unwrap();
        assert!(String::from_utf8(second).unwrap().contains("turnComplete"));

        // Drained fires once the queue is exhausted.
        tokio::time::timeout(Duration::from_millis(100), control.drained())
            .await
            .expect("drained should be signalled");

        // And recv() pends from then on.
        let idle = tokio::time::timeout(Duration::from_millis(50), transport.recv()).await;
        assert!(idle.is_err(), "recv should pend after drain");
    }

    #[tokio::test]
    async fn replay_with_zero_ungated_prefix_gates_first_frame() {
        let (transport, control) =
            ReplayTransport::from_frames(vec![br#"{"setupComplete":{}}"#.to_vec()]);
        let mut transport = transport.with_ungated_prefix(0);
        transport.connect("replay://", vec![]).await.unwrap();

        let gated = tokio::time::timeout(Duration::from_millis(50), transport.recv()).await;
        assert!(gated.is_err(), "first frame should be gated with a zero prefix");

        control.release();
        let first = transport.recv().await.unwrap().unwrap();
        assert!(String::from_utf8(first).unwrap().contains("setupComplete"));
    }

    #[tokio::test]
    async fn replay_proceeds_ungated_when_control_is_dropped() {
        let (mut transport, control) = ReplayTransport::from_frames(vec![
            br#"{"setupComplete":{}}"#.to_vec(),
            br#"{"serverContent":{"turnComplete":true}}"#.to_vec(),
        ]);
        transport.connect("replay://", vec![]).await.unwrap();
        transport.recv().await.unwrap().unwrap();

        // Dropping the only control handle closes the gate channel.
        drop(control);
        let second = tokio::time::timeout(Duration::from_millis(100), transport.recv())
            .await
            .expect("dropping the control handle should ungate the replay")
            .unwrap()
            .unwrap();
        assert!(String::from_utf8(second).unwrap().contains("turnComplete"));
    }

    #[tokio::test]
    async fn replay_of_empty_log_drains_and_pends() {
        let (mut transport, control) = ReplayTransport::from_frames(vec![]);
        transport.connect("replay://", vec![]).await.unwrap();

        let idle = tokio::time::timeout(Duration::from_millis(50), transport.recv()).await;
        assert!(idle.is_err(), "recv should pend on an empty log");

        tokio::time::timeout(Duration::from_millis(100), control.drained())
            .await
            .expect("drained should be signalled for an empty log");
    }

    #[tokio::test]
    async fn replay_collects_outbound_frames() {
        let (mut transport, control) =
            ReplayTransport::from_frames(vec![br#"{"setupComplete":{}}"#.to_vec()]);
        transport.connect("replay://", vec![]).await.unwrap();
        transport.send(b"{\"setup\":{}}".to_vec()).await.unwrap();
        transport
            .send(b"{\"toolResponse\":{}}".to_vec())
            .await
            .unwrap();

        let sent = control.outbound_frames();
        assert_eq!(sent.len(), 2);
        assert_eq!(sent[0], b"{\"setup\":{}}".to_vec());
        assert_eq!(sent[1], b"{\"toolResponse\":{}}".to_vec());
    }

    #[tokio::test]
    async fn replay_from_wire_log_keeps_inbound_only() {
        let entries = vec![
            WireEntry {
                seq: 1,
                dir: WireDirection::Outbound,
                ts_ms: 1,
                payload: b"{\"setup\":{}}".to_vec(),
            },
            WireEntry {
                seq: 2,
                dir: WireDirection::Inbound,
                ts_ms: 2,
                payload: br#"{"setupComplete":{}}"#.to_vec(),
            },
        ];
        let (mut transport, _control) = ReplayTransport::from_wire_log(&entries);
        transport.connect("replay://", vec![]).await.unwrap();
        let first = transport.recv().await.unwrap().unwrap();
        assert!(String::from_utf8(first).unwrap().contains("setupComplete"));
    }

    #[tokio::test]
    async fn replay_errors_when_not_connected() {
        let (mut transport, _control) = ReplayTransport::from_frames(vec![]);
        assert!(transport.recv().await.is_err());
        assert!(transport.send(vec![1]).await.is_err());
    }

    #[tokio::test]
    async fn replay_errors_after_close() {
        let (mut transport, _control) =
            ReplayTransport::from_frames(vec![br#"{"setupComplete":{}}"#.to_vec()]);
        transport.connect("replay://", vec![]).await.unwrap();
        transport.close().await.unwrap();
        assert!(transport.recv().await.is_err());
        assert!(transport.send(vec![1]).await.is_err());
    }
}