gemini_genai_rs/transport/
codec.rs

1//! Message codec — encode commands, decode server messages.
2
3use base64::Engine;
4
5use crate::protocol::messages::*;
6use crate::protocol::types::*;
7use crate::session::SessionCommand;
8
9/// Error during encoding or decoding.
10#[derive(Debug, thiserror::Error, Clone)]
11pub enum CodecError {
12    /// Failed to serialize a client message to JSON.
13    #[error("Serialization error: {0}")]
14    Serialize(String),
15    /// Failed to deserialize a server message from JSON.
16    #[error("Deserialization error: {0}")]
17    Deserialize(String),
18    /// Server sent bytes that are not valid UTF-8.
19    #[error("Invalid UTF-8")]
20    InvalidUtf8,
21}
22
23/// Encodes client commands into wire bytes and decodes server bytes into messages.
24///
25/// The default implementation is [`JsonCodec`], which serializes commands as JSON
26/// and parses server responses via [`ServerMessage::parse`].
27///
28/// # Implementors
29///
30/// - [`JsonCodec`] -- Standard JSON codec. Encodes setup messages, audio (base64),
31///   text, tool responses, and activity signals. Decodes server JSON into
32///   [`ServerMessage`] variants. Handles platform-specific wire stripping
33///   (e.g., removing `scheduling` fields for Vertex AI).
34pub trait Codec: Send + Sync + 'static {
35    /// Encode the initial setup message for the given session configuration.
36    fn encode_setup(&self, config: &SessionConfig) -> Result<Vec<u8>, CodecError>;
37    /// Encode a session command into wire bytes.
38    fn encode_command(
39        &self,
40        cmd: &SessionCommand,
41        config: &SessionConfig,
42    ) -> Result<Vec<u8>, CodecError>;
43    /// Decode raw bytes from the server into a `ServerMessage`.
44    fn decode_message(&self, data: &[u8]) -> Result<ServerMessage, CodecError>;
45}
46
/// Default [`Codec`]: exchanges JSON text with the server.
///
/// This is the encoding the connection layer used before codecs were made
/// pluggable; it was extracted from `connection.rs` without behavior changes.
/// The codec is stateless, so it is cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonCodec;
49
50impl Codec for JsonCodec {
51    fn encode_setup(&self, config: &SessionConfig) -> Result<Vec<u8>, CodecError> {
52        serde_json::to_vec(&config.to_setup_message())
53            .map_err(|e| CodecError::Serialize(e.to_string()))
54    }
55
56    fn encode_command(
57        &self,
58        cmd: &SessionCommand,
59        config: &SessionConfig,
60    ) -> Result<Vec<u8>, CodecError> {
61        match cmd {
62            SessionCommand::SendAudio(data) => {
63                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
64                let msg = RealtimeInputMessage {
65                    realtime_input: RealtimeInputPayload {
66                        media_chunks: Vec::new(),
67                        audio: Some(Blob {
68                            mime_type: config.input_audio_format.mime_type().to_string(),
69                            data: encoded,
70                        }),
71                        video: None,
72                        audio_stream_end: None,
73                        text: None,
74                    },
75                };
76                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
77            }
78            SessionCommand::SendText(text) => {
79                let msg = ClientContentMessage {
80                    client_content: ClientContentPayload {
81                        turns: vec![Content::user(text)],
82                        turn_complete: Some(true),
83                    },
84                };
85                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
86            }
87            SessionCommand::SendToolResponse(responses) => {
88                let function_responses = if config.supports_async_tools() {
89                    responses.clone()
90                } else {
91                    responses
92                        .iter()
93                        .map(|r| {
94                            let mut r = r.clone();
95                            r.scheduling = None;
96                            r
97                        })
98                        .collect()
99                };
100                let msg = ToolResponseMessage {
101                    tool_response: ToolResponsePayload { function_responses },
102                };
103                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
104            }
105            SessionCommand::ActivityStart => {
106                let msg = ActivitySignalMessage {
107                    realtime_input: ActivitySignalPayload {
108                        activity_start: Some(ActivityStart {}),
109                        activity_end: None,
110                    },
111                };
112                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
113            }
114            SessionCommand::ActivityEnd => {
115                let msg = ActivitySignalMessage {
116                    realtime_input: ActivitySignalPayload {
117                        activity_start: None,
118                        activity_end: Some(ActivityEnd {}),
119                    },
120                };
121                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
122            }
123            SessionCommand::SendClientContent {
124                turns,
125                turn_complete,
126            } => {
127                let msg = ClientContentMessage {
128                    client_content: ClientContentPayload {
129                        turns: turns.clone(),
130                        turn_complete: Some(*turn_complete),
131                    },
132                };
133                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
134            }
135            SessionCommand::SendVideo(data) => {
136                let encoded = base64::engine::general_purpose::STANDARD.encode(data);
137                let msg = RealtimeInputMessage {
138                    realtime_input: RealtimeInputPayload {
139                        media_chunks: Vec::new(),
140                        audio: None,
141                        video: Some(Blob {
142                            mime_type: "image/jpeg".to_string(),
143                            data: encoded,
144                        }),
145                        audio_stream_end: None,
146                        text: None,
147                    },
148                };
149                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
150            }
151            SessionCommand::UpdateInstruction(instruction) => {
152                let msg = ClientContentMessage {
153                    client_content: ClientContentPayload {
154                        turns: vec![Content {
155                            role: Some(Role::System),
156                            parts: vec![Part::Text {
157                                text: instruction.clone(),
158                            }],
159                        }],
160                        turn_complete: Some(false),
161                    },
162                };
163                serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
164            }
165            SessionCommand::Disconnect => Ok(Vec::new()),
166        }
167    }
168
169    fn decode_message(&self, data: &[u8]) -> Result<ServerMessage, CodecError> {
170        let text = std::str::from_utf8(data).map_err(|_| CodecError::InvalidUtf8)?;
171        ServerMessage::parse(text).map_err(|e| CodecError::Deserialize(e.to_string()))
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Google AI session config used by most tests.
    fn test_config() -> SessionConfig {
        SessionConfig::new("test-key")
            .model(GeminiModel::Gemini2_0FlashLive)
            .voice(Voice::Puck)
    }

    // -----------------------------------------------------------------------
    // Encode tests
    // -----------------------------------------------------------------------

    #[test]
    fn json_codec_encode_setup() {
        let codec = JsonCodec;
        let config = test_config();
        let bytes = codec.encode_setup(&config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(json.contains("\"setup\""), "should contain setup key");
        assert!(
            json.contains("gemini-2.0-flash-live-001"),
            "should contain model name"
        );
    }

    #[test]
    fn json_codec_encode_send_text() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::SendText("Hello, world!".to_string());
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            json.contains("\"clientContent\""),
            "should contain clientContent"
        );
        assert!(
            json.contains("Hello, world!"),
            "should contain the text payload"
        );
        // SendText always completes the turn; pin the value, not just the key.
        assert!(
            json.contains("\"turnComplete\":true"),
            "should contain turnComplete set to true"
        );
    }

    #[test]
    fn json_codec_encode_send_audio() {
        let codec = JsonCodec;
        let config = test_config();
        let audio_data = vec![1u8, 2, 3, 4];
        let cmd = SessionCommand::SendAudio(audio_data);
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            json.contains("\"realtimeInput\""),
            "should contain realtimeInput"
        );
        assert!(json.contains("\"audio\""), "should contain audio field");
        assert!(
            json.contains("audio/pcm"),
            "should contain the audio mime type"
        );
        // base64 of [1,2,3,4] is "AQIDBA=="
        assert!(
            json.contains("AQIDBA=="),
            "should contain base64-encoded data"
        );
    }

    #[test]
    fn json_codec_encode_tool_response() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::SendToolResponse(vec![FunctionResponse {
            name: "get_weather".to_string(),
            response: serde_json::json!({"temp": 22}),
            id: Some("call-1".to_string()),
            scheduling: None,
        }]);
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            json.contains("\"toolResponse\""),
            "should contain toolResponse"
        );
        assert!(
            json.contains("\"functionResponses\""),
            "should contain functionResponses"
        );
        assert!(
            json.contains("get_weather"),
            "should contain the function name"
        );
    }

    #[test]
    fn json_codec_strips_scheduling_for_vertex() {
        let codec = JsonCodec;
        let config = SessionConfig::from_vertex("proj", "us-central1", "token")
            .model(GeminiModel::Gemini2_0FlashLive);
        let cmd = SessionCommand::SendToolResponse(vec![FunctionResponse {
            name: "search".to_string(),
            response: serde_json::json!({"ok": true}),
            id: Some("call-1".to_string()),
            scheduling: Some(FunctionResponseScheduling::WhenIdle),
        }]);
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            !json.contains("scheduling"),
            "Vertex AI should strip scheduling from tool responses"
        );
        // Stripping must only touch `scheduling`, not the rest of the response.
        assert!(json.contains("search"), "should keep the function name");
        assert!(json.contains("call-1"), "should keep the call id");
    }

    #[test]
    fn json_codec_preserves_scheduling_for_google_ai() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::SendToolResponse(vec![FunctionResponse {
            name: "search".to_string(),
            response: serde_json::json!({"ok": true}),
            id: Some("call-1".to_string()),
            scheduling: Some(FunctionResponseScheduling::WhenIdle),
        }]);
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            json.contains("WHEN_IDLE"),
            "Google AI should preserve scheduling in tool responses"
        );
    }

    #[test]
    fn json_codec_encode_activity_start() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::ActivityStart;
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            json.contains("\"activityStart\""),
            "should contain activityStart"
        );
        assert!(
            !json.contains("\"activityEnd\""),
            "should not contain activityEnd"
        );
    }

    #[test]
    fn json_codec_encode_activity_end() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::ActivityEnd;
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            json.contains("\"activityEnd\""),
            "should contain activityEnd"
        );
        assert!(
            !json.contains("\"activityStart\""),
            "should not contain activityStart"
        );
    }

    #[test]
    fn json_codec_encode_client_content() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::SendClientContent {
            turns: vec![Content::user("context message")],
            turn_complete: false,
        };
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert!(
            json.contains("\"clientContent\""),
            "should contain clientContent"
        );
        assert!(
            json.contains("context message"),
            "should contain the text content"
        );
        assert!(
            json.contains("\"turnComplete\":false"),
            "should contain turnComplete set to false"
        );
    }

    #[test]
    fn json_codec_encode_send_video() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::SendVideo(vec![0xFF, 0xD8, 0xFF]); // JPEG magic bytes
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(
            json["realtimeInput"]["video"]["mimeType"].as_str().unwrap(),
            "image/jpeg"
        );
        assert!(json["realtimeInput"]["video"]["data"].is_string());
    }

    #[test]
    fn json_codec_encode_update_instruction() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::UpdateInstruction("New instruction".into());
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        let turns = &json["clientContent"]["turns"];
        assert_eq!(turns[0]["role"], "system");
        assert_eq!(turns[0]["parts"][0]["text"], "New instruction");
    }

    #[test]
    fn json_codec_encode_disconnect() {
        let codec = JsonCodec;
        let config = test_config();
        let cmd = SessionCommand::Disconnect;
        let bytes = codec.encode_command(&cmd, &config).unwrap();
        assert!(bytes.is_empty(), "Disconnect should produce empty bytes");
    }

    // -----------------------------------------------------------------------
    // Decode tests
    // -----------------------------------------------------------------------

    #[test]
    fn json_codec_decode_setup_complete() {
        let codec = JsonCodec;
        let json = r#"{"setupComplete":{"sessionResumption":{"handle":"abc123"}}}"#;
        let msg = codec.decode_message(json.as_bytes()).unwrap();
        match msg {
            ServerMessage::SetupComplete(sc) => {
                let handle = sc.setup_complete.session_resumption.unwrap().handle;
                assert_eq!(handle, Some("abc123".to_string()));
            }
            _ => panic!("Expected SetupComplete"),
        }
    }

    #[test]
    fn json_codec_decode_server_content() {
        let codec = JsonCodec;
        let json = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"text": "Hello! How can I help?"}]
                },
                "turnComplete": true
            }
        }"#;
        let msg = codec.decode_message(json.as_bytes()).unwrap();
        match msg {
            ServerMessage::ServerContent(sc) => {
                assert!(sc.server_content.turn_complete.unwrap_or(false));
                let turn = sc.server_content.model_turn.unwrap();
                assert_eq!(turn.parts.len(), 1);
                match &turn.parts[0] {
                    Part::Text { text } => assert_eq!(text, "Hello! How can I help?"),
                    _ => panic!("Expected text part"),
                }
            }
            _ => panic!("Expected ServerContent"),
        }
    }

    #[test]
    fn json_codec_decode_tool_call() {
        let codec = JsonCodec;
        let json = r#"{
            "toolCall": {
                "functionCalls": [
                    {"name": "get_weather", "args": {"city": "London"}, "id": "call-1"}
                ]
            }
        }"#;
        let msg = codec.decode_message(json.as_bytes()).unwrap();
        match msg {
            ServerMessage::ToolCall(tc) => {
                assert_eq!(tc.tool_call.function_calls.len(), 1);
                assert_eq!(tc.tool_call.function_calls[0].name, "get_weather");
            }
            _ => panic!("Expected ToolCall"),
        }
    }

    #[test]
    fn json_codec_decode_invalid_utf8() {
        let codec = JsonCodec;
        let bad_bytes: &[u8] = &[0xFF, 0xFE, 0xFD];
        let result = codec.decode_message(bad_bytes);
        match result {
            Err(CodecError::InvalidUtf8) => {} // expected
            other => panic!("Expected CodecError::InvalidUtf8, got {:?}", other),
        }
    }

    #[test]
    fn json_codec_decode_invalid_json() {
        // Valid UTF-8 but malformed JSON must surface as a Deserialize error,
        // not as InvalidUtf8 or a panic.
        let codec = JsonCodec;
        let result = codec.decode_message(b"{not valid json");
        match result {
            Err(CodecError::Deserialize(_)) => {} // expected
            other => panic!("Expected CodecError::Deserialize, got {:?}", other),
        }
    }
}