gemini_genai_rs/protocol/messages/mod.rs

1//! Client→Server and Server→Client message envelopes for the Gemini Live wire protocol.
2
3pub mod client;
4pub mod server;
5
6pub use client::*;
7pub use server::*;
8
#[cfg(test)]
mod tests {
    //! Wire-format tests for the message envelopes.
    //!
    //! Client-side tests serialize a message and check that the expected
    //! camelCase JSON keys are present (or absent). Server-side tests feed raw
    //! JSON to `ServerMessage::parse` and check which variant, and which
    //! fields, come out.

    use super::*;
    use crate::protocol::types::*;

    #[test]
    fn setup_message_serialization() {
        let config = SessionConfig::new("test-key")
            .model(GeminiModel::Gemini2_0FlashLive)
            .voice(Voice::Kore)
            .system_instruction("You are a helpful assistant.");

        let json = config.to_setup_json();
        assert!(json.contains("\"setup\""));
        assert!(json.contains("\"generationConfig\""));
        assert!(json.contains("\"Kore\""));
        assert!(json.contains("\"systemInstruction\""));
    }

    #[test]
    fn parse_setup_complete() {
        let json = r#"{"setupComplete":{"sessionResumption":{"handle":"abc123"}}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::SetupComplete(sc) => {
                let handle = sc.setup_complete.session_resumption.unwrap().handle;
                assert_eq!(handle, Some("abc123".to_string()));
            }
            _ => panic!("Expected SetupComplete"),
        }
    }

    #[test]
    fn parse_server_content_text() {
        let json = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"text": "Hello! How can I help?"}]
                },
                "turnComplete": true
            }
        }"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::ServerContent(sc) => {
                assert!(sc.server_content.turn_complete.unwrap_or(false));
                let turn = sc.server_content.model_turn.unwrap();
                assert_eq!(turn.parts.len(), 1);
                match &turn.parts[0] {
                    Part::Text { text } => assert_eq!(text, "Hello! How can I help?"),
                    _ => panic!("Expected text part"),
                }
            }
            _ => panic!("Expected ServerContent"),
        }
    }

    #[test]
    fn parse_server_content_audio() {
        let json = r#"{
            "serverContent": {
                "modelTurn": {
                    "parts": [{"inlineData": {"mimeType": "audio/pcm", "data": "AAAA"}}]
                }
            }
        }"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::ServerContent(sc) => {
                let turn = sc.server_content.model_turn.unwrap();
                // Check the part count before indexing, as the text-part test
                // does. A parse regression that drops the part then fails on
                // this assertion instead of an index-out-of-bounds panic.
                assert_eq!(turn.parts.len(), 1);
                match &turn.parts[0] {
                    Part::InlineData { inline_data } => {
                        assert_eq!(inline_data.mime_type, "audio/pcm");
                    }
                    _ => panic!("Expected inline data part"),
                }
            }
            _ => panic!("Expected ServerContent"),
        }
    }

    #[test]
    fn parse_tool_call() {
        let json = r#"{
            "toolCall": {
                "functionCalls": [
                    {"name": "get_weather", "args": {"city": "London"}, "id": "call-1"}
                ]
            }
        }"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::ToolCall(tc) => {
                assert_eq!(tc.tool_call.function_calls.len(), 1);
                assert_eq!(tc.tool_call.function_calls[0].name, "get_weather");
            }
            _ => panic!("Expected ToolCall"),
        }
    }

    #[test]
    fn parse_tool_call_cancellation() {
        let json = r#"{"toolCallCancellation": {"ids": ["call-1", "call-2"]}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::ToolCallCancellation(tc) => {
                assert_eq!(tc.tool_call_cancellation.ids, vec!["call-1", "call-2"]);
            }
            _ => panic!("Expected ToolCallCancellation"),
        }
    }

    #[test]
    fn parse_go_away() {
        let json = r#"{"goAway": {"timeLeft": "30s"}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::GoAway(ga) => {
                assert_eq!(ga.go_away.time_left, Some("30s".to_string()));
            }
            _ => panic!("Expected GoAway"),
        }
    }

    #[test]
    fn parse_interrupted() {
        let json = r#"{"serverContent": {"interrupted": true}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::ServerContent(sc) => {
                assert!(sc.server_content.interrupted.unwrap_or(false));
            }
            _ => panic!("Expected ServerContent"),
        }
    }

    #[test]
    fn parse_unknown_message() {
        // An unrecognised top-level key must still parse successfully and
        // land in the `Unknown` variant, not produce an error.
        let json = r#"{"newFeature": {"value": 42}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        assert!(matches!(msg, ServerMessage::Unknown(_)));
    }

    #[test]
    fn realtime_input_serialization_audio() {
        let msg = RealtimeInputMessage {
            realtime_input: RealtimeInputPayload {
                media_chunks: Vec::new(),
                audio: Some(Blob {
                    mime_type: "audio/pcm".to_string(),
                    data: "AQIDBA==".to_string(),
                }),
                video: None,
                audio_stream_end: None,
                text: None,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"realtimeInput\""));
        assert!(json.contains("\"audio\""));
        assert!(json.contains("\"mimeType\""));
        // Deprecated field should not appear when empty
        assert!(!json.contains("\"mediaChunks\""));
    }

    #[test]
    fn realtime_input_serialization_legacy() {
        let msg = RealtimeInputMessage {
            realtime_input: RealtimeInputPayload {
                media_chunks: vec![MediaChunk {
                    mime_type: "audio/pcm".to_string(),
                    data: "AQIDBA==".to_string(),
                }],
                audio: None,
                video: None,
                audio_stream_end: None,
                text: None,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"mediaChunks\""));
    }

    #[test]
    fn parse_session_resumption_update() {
        let json = r#"{"sessionResumptionUpdate": {"newHandle": "handle-xyz", "resumable": true}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::SessionResumptionUpdate(sru) => {
                assert_eq!(
                    sru.session_resumption_update.new_handle,
                    Some("handle-xyz".to_string())
                );
                assert_eq!(sru.session_resumption_update.resumable, Some(true));
            }
            _ => panic!("Expected SessionResumptionUpdate"),
        }
    }

    #[test]
    fn tool_response_serialization() {
        let msg = ToolResponseMessage {
            tool_response: ToolResponsePayload {
                function_responses: vec![FunctionResponse {
                    name: "get_weather".to_string(),
                    response: serde_json::json!({"temp": 22}),
                    id: Some("call-1".to_string()),
                    scheduling: None,
                }],
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"toolResponse\""));
        assert!(json.contains("\"functionResponses\""));
    }

    #[test]
    fn client_content_serialization() {
        let msg = ClientContentMessage {
            client_content: ClientContentPayload {
                turns: vec![Content::user("Hello")],
                turn_complete: Some(true),
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"clientContent\""));
        assert!(json.contains("\"turnComplete\""));
    }

    #[test]
    fn activity_signal_serialization() {
        let msg = ActivitySignalMessage {
            realtime_input: ActivitySignalPayload {
                activity_start: Some(ActivityStart {}),
                activity_end: None,
            },
        };
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"activityStart\""));
    }

    #[test]
    fn voice_activity_type_serialization() {
        // Each variant must serialize to its SCREAMING_SNAKE_CASE wire name
        // and parse back to the same value.
        let start = VoiceActivityType::VoiceActivityStart;
        let json = serde_json::to_string(&start).unwrap();
        assert_eq!(json, "\"VOICE_ACTIVITY_START\"");
        let parsed: VoiceActivityType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, start);

        let end = VoiceActivityType::VoiceActivityEnd;
        let json = serde_json::to_string(&end).unwrap();
        assert_eq!(json, "\"VOICE_ACTIVITY_END\"");
        let parsed: VoiceActivityType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, end);
    }

    #[test]
    fn parse_voice_activity_message() {
        let json = r#"{"voiceActivity":{"voiceActivityType":"VOICE_ACTIVITY_START"}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::VoiceActivity(va) => {
                assert_eq!(
                    va.voice_activity.voice_activity_type,
                    Some(VoiceActivityType::VoiceActivityStart)
                );
            }
            _ => panic!("Expected VoiceActivity"),
        }

        let json = r#"{"voiceActivity":{"voiceActivityType":"VOICE_ACTIVITY_END"}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::VoiceActivity(va) => {
                assert_eq!(
                    va.voice_activity.voice_activity_type,
                    Some(VoiceActivityType::VoiceActivityEnd)
                );
            }
            _ => panic!("Expected VoiceActivity"),
        }
    }

    #[test]
    fn parse_input_transcription() {
        let json = r#"{
            "serverContent": {
                "inputTranscription": {"text": "Hello world"}
            }
        }"#;
        let msg = ServerMessage::parse(json).unwrap();
        match msg {
            ServerMessage::ServerContent(sc) => {
                let text = sc.server_content.input_transcription.unwrap().text.unwrap();
                assert_eq!(text, "Hello world");
            }
            _ => panic!("Expected ServerContent"),
        }
    }
}