gemini_genai_rs/transport/
codec.rs1use base64::Engine;
4
5use crate::protocol::messages::*;
6use crate::protocol::types::*;
7use crate::session::SessionCommand;
8
9#[derive(Debug, thiserror::Error, Clone)]
11pub enum CodecError {
12 #[error("Serialization error: {0}")]
14 Serialize(String),
15 #[error("Deserialization error: {0}")]
17 Deserialize(String),
18 #[error("Invalid UTF-8")]
20 InvalidUtf8,
21}
22
23pub trait Codec: Send + Sync + 'static {
35 fn encode_setup(&self, config: &SessionConfig) -> Result<Vec<u8>, CodecError>;
37 fn encode_command(
39 &self,
40 cmd: &SessionCommand,
41 config: &SessionConfig,
42 ) -> Result<Vec<u8>, CodecError>;
43 fn decode_message(&self, data: &[u8]) -> Result<ServerMessage, CodecError>;
45}
46
/// Stateless codec that represents messages as JSON bytes.
///
/// The type carries no data, so it is free to copy and can be created either
/// with the unit literal `JsonCodec` or through `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonCodec;
49
50impl Codec for JsonCodec {
51 fn encode_setup(&self, config: &SessionConfig) -> Result<Vec<u8>, CodecError> {
52 serde_json::to_vec(&config.to_setup_message())
53 .map_err(|e| CodecError::Serialize(e.to_string()))
54 }
55
56 fn encode_command(
57 &self,
58 cmd: &SessionCommand,
59 config: &SessionConfig,
60 ) -> Result<Vec<u8>, CodecError> {
61 match cmd {
62 SessionCommand::SendAudio(data) => {
63 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
64 let msg = RealtimeInputMessage {
65 realtime_input: RealtimeInputPayload {
66 media_chunks: Vec::new(),
67 audio: Some(Blob {
68 mime_type: config.input_audio_format.mime_type().to_string(),
69 data: encoded,
70 }),
71 video: None,
72 audio_stream_end: None,
73 text: None,
74 },
75 };
76 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
77 }
78 SessionCommand::SendText(text) => {
79 let msg = ClientContentMessage {
80 client_content: ClientContentPayload {
81 turns: vec![Content::user(text)],
82 turn_complete: Some(true),
83 },
84 };
85 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
86 }
87 SessionCommand::SendToolResponse(responses) => {
88 let function_responses = if config.supports_async_tools() {
89 responses.clone()
90 } else {
91 responses
92 .iter()
93 .map(|r| {
94 let mut r = r.clone();
95 r.scheduling = None;
96 r
97 })
98 .collect()
99 };
100 let msg = ToolResponseMessage {
101 tool_response: ToolResponsePayload { function_responses },
102 };
103 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
104 }
105 SessionCommand::ActivityStart => {
106 let msg = ActivitySignalMessage {
107 realtime_input: ActivitySignalPayload {
108 activity_start: Some(ActivityStart {}),
109 activity_end: None,
110 },
111 };
112 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
113 }
114 SessionCommand::ActivityEnd => {
115 let msg = ActivitySignalMessage {
116 realtime_input: ActivitySignalPayload {
117 activity_start: None,
118 activity_end: Some(ActivityEnd {}),
119 },
120 };
121 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
122 }
123 SessionCommand::SendClientContent {
124 turns,
125 turn_complete,
126 } => {
127 let msg = ClientContentMessage {
128 client_content: ClientContentPayload {
129 turns: turns.clone(),
130 turn_complete: Some(*turn_complete),
131 },
132 };
133 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
134 }
135 SessionCommand::SendVideo(data) => {
136 let encoded = base64::engine::general_purpose::STANDARD.encode(data);
137 let msg = RealtimeInputMessage {
138 realtime_input: RealtimeInputPayload {
139 media_chunks: Vec::new(),
140 audio: None,
141 video: Some(Blob {
142 mime_type: "image/jpeg".to_string(),
143 data: encoded,
144 }),
145 audio_stream_end: None,
146 text: None,
147 },
148 };
149 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
150 }
151 SessionCommand::UpdateInstruction(instruction) => {
152 let msg = ClientContentMessage {
153 client_content: ClientContentPayload {
154 turns: vec![Content {
155 role: Some(Role::System),
156 parts: vec![Part::Text {
157 text: instruction.clone(),
158 }],
159 }],
160 turn_complete: Some(false),
161 },
162 };
163 serde_json::to_vec(&msg).map_err(|e| CodecError::Serialize(e.to_string()))
164 }
165 SessionCommand::Disconnect => Ok(Vec::new()),
166 }
167 }
168
169 fn decode_message(&self, data: &[u8]) -> Result<ServerMessage, CodecError> {
170 let text = std::str::from_utf8(data).map_err(|_| CodecError::InvalidUtf8)?;
171 ServerMessage::parse(text).map_err(|e| CodecError::Deserialize(e.to_string()))
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// API-key based configuration shared by most tests.
    fn test_config() -> SessionConfig {
        SessionConfig::new("test-key")
            .model(GeminiModel::Gemini2_0FlashLive)
            .voice(Voice::Puck)
    }

    /// Encodes `cmd` with [`JsonCodec`] and returns the raw bytes.
    fn encode_bytes(cmd: &SessionCommand, config: &SessionConfig) -> Vec<u8> {
        JsonCodec.encode_command(cmd, config).unwrap()
    }

    /// Encodes `cmd` with [`JsonCodec`] and returns the output as UTF-8 text.
    fn encode_text(cmd: &SessionCommand, config: &SessionConfig) -> String {
        String::from_utf8(encode_bytes(cmd, config)).unwrap()
    }

    /// Encodes `cmd` with [`JsonCodec`] and parses the output back into a JSON value.
    fn encode_value(cmd: &SessionCommand, config: &SessionConfig) -> serde_json::Value {
        serde_json::from_slice(&encode_bytes(cmd, config)).unwrap()
    }

    /// Builds a single-entry tool response command with id `call-1`.
    fn tool_response_cmd(
        name: &str,
        response: serde_json::Value,
        scheduling: Option<FunctionResponseScheduling>,
    ) -> SessionCommand {
        SessionCommand::SendToolResponse(vec![FunctionResponse {
            name: name.to_string(),
            response,
            id: Some("call-1".to_string()),
            scheduling,
        }])
    }

    /// Asserts that `haystack` contains `needle`, failing with `why` otherwise.
    #[track_caller]
    fn assert_contains(haystack: &str, needle: &str, why: &str) {
        assert!(haystack.contains(needle), "{}", why);
    }

    // ---- encoding ----

    #[test]
    fn json_codec_encode_setup() {
        let bytes = JsonCodec.encode_setup(&test_config()).unwrap();
        let json = String::from_utf8(bytes).unwrap();
        assert_contains(&json, "\"setup\"", "should contain setup key");
        assert_contains(&json, "gemini-2.0-flash-live-001", "should contain model name");
    }

    #[test]
    fn json_codec_encode_send_text() {
        let cmd = SessionCommand::SendText("Hello, world!".to_string());
        let json = encode_text(&cmd, &test_config());
        assert_contains(&json, "\"clientContent\"", "should contain clientContent");
        assert_contains(&json, "Hello, world!", "should contain the text payload");
        assert_contains(&json, "\"turnComplete\"", "should contain turnComplete");
    }

    #[test]
    fn json_codec_encode_send_audio() {
        let cmd = SessionCommand::SendAudio(vec![1u8, 2, 3, 4]);
        let json = encode_text(&cmd, &test_config());
        assert_contains(&json, "\"realtimeInput\"", "should contain realtimeInput");
        assert_contains(&json, "\"audio\"", "should contain audio field");
        assert_contains(&json, "audio/pcm", "should contain the audio mime type");
        // "AQIDBA==" is the standard base64 encoding of the bytes 1, 2, 3, 4.
        assert_contains(&json, "AQIDBA==", "should contain base64-encoded data");
    }

    #[test]
    fn json_codec_encode_tool_response() {
        let cmd = tool_response_cmd("get_weather", serde_json::json!({"temp": 22}), None);
        let json = encode_text(&cmd, &test_config());
        assert_contains(&json, "\"toolResponse\"", "should contain toolResponse");
        assert_contains(
            &json,
            "\"functionResponses\"",
            "should contain functionResponses",
        );
        assert_contains(&json, "get_weather", "should contain the function name");
    }

    #[test]
    fn json_codec_strips_scheduling_for_vertex() {
        let vertex_config = SessionConfig::from_vertex("proj", "us-central1", "token")
            .model(GeminiModel::Gemini2_0FlashLive);
        let cmd = tool_response_cmd(
            "search",
            serde_json::json!({"ok": true}),
            Some(FunctionResponseScheduling::WhenIdle),
        );
        let json = encode_text(&cmd, &vertex_config);
        assert!(
            !json.contains("scheduling"),
            "Vertex AI should strip scheduling from tool responses"
        );
    }

    #[test]
    fn json_codec_preserves_scheduling_for_google_ai() {
        let cmd = tool_response_cmd(
            "search",
            serde_json::json!({"ok": true}),
            Some(FunctionResponseScheduling::WhenIdle),
        );
        let json = encode_text(&cmd, &test_config());
        assert_contains(
            &json,
            "WHEN_IDLE",
            "Google AI should preserve scheduling in tool responses",
        );
    }

    #[test]
    fn json_codec_encode_activity_start() {
        let json = encode_text(&SessionCommand::ActivityStart, &test_config());
        assert_contains(&json, "\"activityStart\"", "should contain activityStart");
        assert!(
            !json.contains("\"activityEnd\""),
            "should not contain activityEnd"
        );
    }

    #[test]
    fn json_codec_encode_activity_end() {
        let json = encode_text(&SessionCommand::ActivityEnd, &test_config());
        assert_contains(&json, "\"activityEnd\"", "should contain activityEnd");
        assert!(
            !json.contains("\"activityStart\""),
            "should not contain activityStart"
        );
    }

    #[test]
    fn json_codec_encode_client_content() {
        let cmd = SessionCommand::SendClientContent {
            turns: vec![Content::user("context message")],
            turn_complete: false,
        };
        let json = encode_text(&cmd, &test_config());
        assert_contains(&json, "\"clientContent\"", "should contain clientContent");
        assert_contains(&json, "context message", "should contain the text content");
        assert_contains(
            &json,
            "\"turnComplete\":false",
            "should contain turnComplete set to false",
        );
    }

    #[test]
    fn json_codec_encode_send_video() {
        let cmd = SessionCommand::SendVideo(vec![0xFF, 0xD8, 0xFF]);
        let json = encode_value(&cmd, &test_config());
        let video = &json["realtimeInput"]["video"];
        assert_eq!(video["mimeType"].as_str().unwrap(), "image/jpeg");
        assert!(video["data"].is_string());
    }

    #[test]
    fn json_codec_encode_update_instruction() {
        let cmd = SessionCommand::UpdateInstruction("New instruction".into());
        let json = encode_value(&cmd, &test_config());
        let first_turn = &json["clientContent"]["turns"][0];
        assert_eq!(first_turn["role"], "system");
        assert_eq!(first_turn["parts"][0]["text"], "New instruction");
    }

    #[test]
    fn json_codec_encode_disconnect() {
        let bytes = encode_bytes(&SessionCommand::Disconnect, &test_config());
        assert!(bytes.is_empty(), "Disconnect should produce empty bytes");
    }

    // ---- decoding ----

    #[test]
    fn json_codec_decode_setup_complete() {
        let json = r#"{"setupComplete":{"sessionResumption":{"handle":"abc123"}}}"#;
        let setup = match JsonCodec.decode_message(json.as_bytes()).unwrap() {
            ServerMessage::SetupComplete(sc) => sc.setup_complete,
            _ => panic!("Expected SetupComplete"),
        };
        let handle = setup.session_resumption.unwrap().handle;
        assert_eq!(handle, Some("abc123".to_string()));
    }

    #[test]
    fn json_codec_decode_server_content() {
        let json = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"text": "Hello! How can I help?"}]
                },
                "turnComplete": true
            }
        }"#;
        let content = match JsonCodec.decode_message(json.as_bytes()).unwrap() {
            ServerMessage::ServerContent(sc) => sc.server_content,
            _ => panic!("Expected ServerContent"),
        };
        assert!(content.turn_complete.unwrap_or(false));
        let turn = content.model_turn.unwrap();
        assert_eq!(turn.parts.len(), 1);
        match &turn.parts[0] {
            Part::Text { text } => assert_eq!(text, "Hello! How can I help?"),
            _ => panic!("Expected text part"),
        }
    }

    #[test]
    fn json_codec_decode_tool_call() {
        let json = r#"{
            "toolCall": {
                "functionCalls": [
                    {"name": "get_weather", "args": {"city": "London"}, "id": "call-1"}
                ]
            }
        }"#;
        let calls = match JsonCodec.decode_message(json.as_bytes()).unwrap() {
            ServerMessage::ToolCall(tc) => tc.tool_call.function_calls,
            _ => panic!("Expected ToolCall"),
        };
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
    }

    #[test]
    fn json_codec_decode_invalid_utf8() {
        // 0xFF can never appear in well-formed UTF-8.
        let bad_bytes: &[u8] = &[0xFF, 0xFE, 0xFD];
        let outcome = JsonCodec.decode_message(bad_bytes);
        assert!(
            matches!(outcome, Err(CodecError::InvalidUtf8)),
            "Expected CodecError::InvalidUtf8, got {:?}",
            outcome
        );
    }
}