gemini_genai_rs/transport/
recording.rs1use std::io::Write;
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::sync::Arc;
40use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
41
42use crate::protocol::messages::ServerMessage;
43use crate::protocol::types::SessionConfig;
44use crate::session::SessionCommand;
45
46use super::codec::{Codec, CodecError};
47
48#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
50pub enum WireDirection {
51 #[serde(rename = "out")]
53 Outbound,
54 #[serde(rename = "in")]
56 Inbound,
57}
58
59#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
61pub struct WireEntry {
62 pub seq: u64,
64 pub dir: WireDirection,
66 pub ts_ms: u64,
68 #[serde(rename = "payload_b64", with = "base64_bytes")]
70 pub payload: Vec<u8>,
71}
72
73mod base64_bytes {
75 use base64::Engine;
76 use serde::{Deserialize, Deserializer, Serializer};
77
78 pub fn serialize<S: Serializer>(bytes: &[u8], ser: S) -> Result<S::Ok, S::Error> {
79 ser.serialize_str(&base64::engine::general_purpose::STANDARD.encode(bytes))
80 }
81
82 pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<Vec<u8>, D::Error> {
83 let s = String::deserialize(de)?;
84 base64::engine::general_purpose::STANDARD
85 .decode(s.as_bytes())
86 .map_err(serde::de::Error::custom)
87 }
88}
89
90pub trait WireRecorder: Send + Sync {
97 fn record(&self, entry: WireEntry);
99}
100
101#[derive(Clone)]
107pub struct WireRecorderHandle(Arc<dyn WireRecorder>);
108
109impl WireRecorderHandle {
110 pub fn new(recorder: Arc<dyn WireRecorder>) -> Self {
112 Self(recorder)
113 }
114
115 pub fn recorder(&self) -> Arc<dyn WireRecorder> {
117 self.0.clone()
118 }
119}
120
121impl From<Arc<dyn WireRecorder>> for WireRecorderHandle {
122 fn from(recorder: Arc<dyn WireRecorder>) -> Self {
123 Self::new(recorder)
124 }
125}
126
127impl std::fmt::Debug for WireRecorderHandle {
128 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129 f.write_str("WireRecorderHandle(..)")
130 }
131}
132
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
/// `Duration::as_millis` yields a `u128`. That value is saturated to
/// `u64::MAX` rather than silently truncated by an `as` cast.
fn epoch_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}
139
140pub struct RecordingCodec<C> {
156 inner: C,
157 recorder: Arc<dyn WireRecorder>,
158 seq: AtomicU64,
159}
160
161impl<C: Codec> RecordingCodec<C> {
162 pub fn new(inner: C, recorder: Arc<dyn WireRecorder>) -> Self {
164 Self {
165 inner,
166 recorder,
167 seq: AtomicU64::new(1),
168 }
169 }
170
171 fn tap(&self, dir: WireDirection, payload: &[u8]) {
172 let entry = WireEntry {
173 seq: self.seq.fetch_add(1, Ordering::Relaxed),
174 dir,
175 ts_ms: epoch_millis(),
176 payload: payload.to_vec(),
177 };
178 self.recorder.record(entry);
179 }
180}
181
182impl<C: Codec> Codec for RecordingCodec<C> {
183 fn encode_setup(&self, config: &SessionConfig) -> Result<Vec<u8>, CodecError> {
184 let bytes = self.inner.encode_setup(config)?;
185 if !bytes.is_empty() {
186 self.tap(WireDirection::Outbound, &bytes);
187 }
188 Ok(bytes)
189 }
190
191 fn encode_command(
192 &self,
193 cmd: &SessionCommand,
194 config: &SessionConfig,
195 ) -> Result<Vec<u8>, CodecError> {
196 let bytes = self.inner.encode_command(cmd, config)?;
197 if !bytes.is_empty() {
198 self.tap(WireDirection::Outbound, &bytes);
199 }
200 Ok(bytes)
201 }
202
203 fn decode_message(&self, data: &[u8]) -> Result<ServerMessage, CodecError> {
204 self.tap(WireDirection::Inbound, data);
205 self.inner.decode_message(data)
206 }
207}
208
209impl Codec for Box<dyn Codec> {
212 fn encode_setup(&self, config: &SessionConfig) -> Result<Vec<u8>, CodecError> {
213 (**self).encode_setup(config)
214 }
215
216 fn encode_command(
217 &self,
218 cmd: &SessionCommand,
219 config: &SessionConfig,
220 ) -> Result<Vec<u8>, CodecError> {
221 (**self).encode_command(cmd, config)
222 }
223
224 fn decode_message(&self, data: &[u8]) -> Result<ServerMessage, CodecError> {
225 (**self).decode_message(data)
226 }
227}
228
/// Minimum time between flushes that `FileWireRecorder` performs on its write path.
/// A write flushes only once at least this long has passed since the previous flush.
const FILE_FLUSH_INTERVAL: Duration = Duration::from_millis(1_000);
234
/// Mutable state of a `FileWireRecorder`, kept behind its mutex.
struct FileWireRecorderInner {
    /// Buffered writer over the log file.
    writer: std::io::BufWriter<std::fs::File>,
    /// When the writer was last flushed (initially, when the file was created).
    last_flush: Instant,
}
239
240pub struct FileWireRecorder {
248 inner: parking_lot::Mutex<FileWireRecorderInner>,
249}
250
251impl FileWireRecorder {
252 pub fn create(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> {
254 let file = std::fs::File::create(path)?;
255 Ok(Self {
256 inner: parking_lot::Mutex::new(FileWireRecorderInner {
257 writer: std::io::BufWriter::new(file),
258 last_flush: Instant::now(),
259 }),
260 })
261 }
262
263 pub fn flush(&self) {
265 let mut inner = self.inner.lock();
266 if let Err(e) = inner.writer.flush() {
267 tracing::warn!(error = %e, "FileWireRecorder flush failed");
268 }
269 inner.last_flush = Instant::now();
270 }
271}
272
273impl WireRecorder for FileWireRecorder {
274 fn record(&self, entry: WireEntry) {
275 let line = match serde_json::to_string(&entry) {
276 Ok(line) => line,
277 Err(e) => {
278 tracing::warn!(error = %e, "FileWireRecorder serialize failed");
279 return;
280 }
281 };
282 let mut inner = self.inner.lock();
283 if let Err(e) = writeln!(inner.writer, "{line}") {
284 tracing::warn!(error = %e, "FileWireRecorder write failed");
285 return;
286 }
287 if inner.last_flush.elapsed() >= FILE_FLUSH_INTERVAL {
288 if let Err(e) = inner.writer.flush() {
289 tracing::warn!(error = %e, "FileWireRecorder flush failed");
290 }
291 inner.last_flush = Instant::now();
292 }
293 }
294}
295
296impl Drop for FileWireRecorder {
297 fn drop(&mut self) {
298 if let Err(e) = self.inner.lock().writer.flush() {
299 tracing::warn!(error = %e, "FileWireRecorder final flush failed");
300 }
301 }
302}
303
304#[derive(Default)]
310pub struct MemoryWireRecorder {
311 entries: parking_lot::Mutex<Vec<WireEntry>>,
312}
313
314impl MemoryWireRecorder {
315 pub fn new() -> Self {
317 Self::default()
318 }
319
320 pub fn entries(&self) -> Vec<WireEntry> {
322 self.entries.lock().clone()
323 }
324
325 pub fn len(&self) -> usize {
327 self.entries.lock().len()
328 }
329
330 pub fn is_empty(&self) -> bool {
332 self.entries.lock().is_empty()
333 }
334}
335
336impl WireRecorder for MemoryWireRecorder {
337 fn record(&self, entry: WireEntry) {
338 self.entries.lock().push(entry);
339 }
340}
341
342#[derive(Debug, thiserror::Error)]
348pub enum WireLogError {
349 #[error("failed to read wire log: {0}")]
351 Io(#[from] std::io::Error),
352 #[error("invalid wire log entry on line {line}: {source}")]
354 Parse {
355 line: usize,
357 source: serde_json::Error,
359 },
360}
361
362pub fn read_wire_log(path: impl AsRef<std::path::Path>) -> Result<Vec<WireEntry>, WireLogError> {
366 let data = std::fs::read_to_string(path)?;
367 parse_wire_log(&data)
368}
369
370pub fn parse_wire_log(data: &str) -> Result<Vec<WireEntry>, WireLogError> {
372 let mut entries = Vec::new();
373 for (idx, line) in data.lines().enumerate() {
374 let line = line.trim();
375 if line.is_empty() {
376 continue;
377 }
378 let entry: WireEntry =
379 serde_json::from_str(line).map_err(|source| WireLogError::Parse {
380 line: idx + 1,
381 source,
382 })?;
383 entries.push(entry);
384 }
385 Ok(entries)
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::types::GeminiModel;
    use crate::transport::codec::JsonCodec;

    /// Session config shared by the codec tests.
    fn test_config() -> SessionConfig {
        SessionConfig::new("test-key").model(GeminiModel::Gemini2_0FlashLive)
    }

    /// Temporary directory that is removed when dropped, so it is cleaned up
    /// even when an assertion in the owning test panics.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        /// Create a fresh directory under the system temp dir. The name includes
        /// the process id and current time to avoid collisions between runs.
        fn create() -> Self {
            let path = std::env::temp_dir().join(format!(
                "gemini-rs-wire-log-test-{}-{}",
                std::process::id(),
                epoch_millis()
            ));
            std::fs::create_dir_all(&path).unwrap();
            Self(path)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: failing to clean up must not cause a second panic.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn recording_codec_taps_outbound_and_inbound() {
        let recorder = Arc::new(MemoryWireRecorder::new());
        let codec = RecordingCodec::new(JsonCodec, recorder.clone());
        let config = test_config();

        let setup = codec.encode_setup(&config).unwrap();
        let cmd_bytes = codec
            .encode_command(&SessionCommand::SendText("hi".into()), &config)
            .unwrap();
        let inbound = br#"{"setupComplete":{}}"#;
        codec.decode_message(inbound).unwrap();

        let entries = recorder.entries();
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0].seq, 1);
        assert_eq!(entries[0].dir, WireDirection::Outbound);
        assert_eq!(entries[0].payload, setup);
        assert_eq!(entries[1].seq, 2);
        assert_eq!(entries[1].dir, WireDirection::Outbound);
        assert_eq!(entries[1].payload, cmd_bytes);
        assert_eq!(entries[2].seq, 3);
        assert_eq!(entries[2].dir, WireDirection::Inbound);
        assert_eq!(entries[2].payload, inbound.to_vec());
        assert!(entries.iter().all(|e| e.ts_ms > 0));
    }

    #[test]
    fn recording_codec_skips_empty_encodes() {
        let recorder = Arc::new(MemoryWireRecorder::new());
        let codec = RecordingCodec::new(JsonCodec, recorder.clone());
        let config = test_config();

        // Disconnect encodes to an empty frame (asserted below), which must not be recorded.
        let bytes = codec
            .encode_command(&SessionCommand::Disconnect, &config)
            .unwrap();
        assert!(bytes.is_empty());
        assert!(recorder.is_empty());
    }

    #[test]
    fn recording_codec_records_undecodable_inbound() {
        let recorder = Arc::new(MemoryWireRecorder::new());
        let codec = RecordingCodec::new(JsonCodec, recorder.clone());

        let bad: &[u8] = &[0xFF, 0xFE];
        assert!(codec.decode_message(bad).is_err());
        let entries = recorder.entries();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].dir, WireDirection::Inbound);
        assert_eq!(entries[0].payload, bad.to_vec());
    }

    #[test]
    fn wire_entry_jsonl_round_trip() {
        let entry = WireEntry {
            seq: 7,
            dir: WireDirection::Inbound,
            ts_ms: 1_718_000_000_123,
            payload: br#"{"setupComplete":{}}"#.to_vec(),
        };
        let line = serde_json::to_string(&entry).unwrap();
        assert!(line.contains("\"dir\":\"in\""));
        assert!(line.contains("payload_b64"));
        let parsed = parse_wire_log(&format!("{line}\n\n{line}")).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0], entry);
    }

    #[test]
    fn file_wire_recorder_round_trip() {
        let scratch = ScratchDir::create();
        let path = scratch.path().join("session.wire.jsonl");

        {
            // Scope the recorder so its `Drop` flushes the buffer before reading back.
            let recorder = FileWireRecorder::create(&path).unwrap();
            recorder.record(WireEntry {
                seq: 1,
                dir: WireDirection::Outbound,
                ts_ms: 42,
                payload: b"{\"setup\":{}}".to_vec(),
            });
            recorder.record(WireEntry {
                seq: 2,
                dir: WireDirection::Inbound,
                ts_ms: 43,
                payload: b"{\"setupComplete\":{}}".to_vec(),
            });
        }

        let entries = read_wire_log(&path).unwrap();
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].payload, b"{\"setup\":{}}".to_vec());
        assert_eq!(entries[1].dir, WireDirection::Inbound);
    }

    #[test]
    fn parse_wire_log_reports_bad_line() {
        let err = parse_wire_log("not json").unwrap_err();
        match err {
            WireLogError::Parse { line, .. } => assert_eq!(line, 1),
            other => panic!("expected Parse error, got {other:?}"),
        }
    }

    #[test]
    fn boxed_codec_forwards() {
        let codec: Box<dyn Codec> = Box::new(JsonCodec);
        let config = test_config();
        let bytes = codec.encode_setup(&config).unwrap();
        assert!(!bytes.is_empty());
        assert!(codec.decode_message(br#"{"setupComplete":{}}"#).is_ok());
    }
}