gemini_genai_rs/transport/connection/
mod.rs

1//! WebSocket connection lifecycle — connect, setup, full-duplex split, reconnection.
2
3mod message_handler;
4mod reconnect;
5mod session_loop;
6
7use std::sync::Arc;
8
9use tokio::sync::{broadcast, mpsc, watch};
10
11use crate::protocol::types::*;
12use crate::session::{SessionHandle, SessionPhase, SessionState};
13use crate::transport::codec::{Codec, JsonCodec};
14use crate::transport::ws::{Transport, TungsteniteTransport};
15use crate::transport::TransportConfig;
16
17/// Connect to the Gemini Multimodal Live API and return a session handle.
18///
19/// This is the main entry point. It uses the default [`TungsteniteTransport`]
20/// and [`JsonCodec`]. For custom transports or codecs (e.g. testing with
21/// [`MockTransport`](crate::transport::ws::MockTransport)), use [`connect_with`].
22pub async fn connect(
23    config: SessionConfig,
24    transport_config: TransportConfig,
25) -> Result<SessionHandle, crate::session::SessionError> {
26    connect_with(
27        config,
28        transport_config,
29        TungsteniteTransport::new(),
30        JsonCodec,
31    )
32    .await
33}
34
35/// Connect with a custom transport and codec.
36///
37/// This is the generic entry point that accepts any [`Transport`] + [`Codec`]
38/// implementation. The default [`connect`] delegates here with
39/// [`TungsteniteTransport`] and [`JsonCodec`].
40pub async fn connect_with<T, C>(
41    config: SessionConfig,
42    transport_config: TransportConfig,
43    transport: T,
44    codec: C,
45) -> Result<SessionHandle, crate::session::SessionError>
46where
47    T: Transport,
48    C: Codec,
49{
50    let (command_tx, command_rx) = mpsc::channel(transport_config.send_queue_depth);
51    let (event_tx, _) = broadcast::channel(transport_config.event_channel_capacity);
52    let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
53
54    let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
55
56    let mut handle = SessionHandle::new(command_tx, event_tx.clone(), state.clone(), phase_rx);
57    if let Some(pacing) = config.audio_pacing.clone() {
58        handle = handle.with_audio_pacing(pacing);
59    }
60
61    // Honor a config-installed wire recorder by wrapping the codec. The
62    // boxed indirection keeps `connect_with` generic while letting the
63    // recorder be a runtime decision.
64    let task = if let Some(recorder) = config.wire_recorder.clone() {
65        let codec: Box<dyn Codec> = Box::new(crate::transport::recording::RecordingCodec::new(
66            codec,
67            recorder.recorder(),
68        ));
69        tokio::spawn(async move {
70            session_loop::generic_connection_loop(
71                config,
72                transport_config,
73                state,
74                command_rx,
75                event_tx,
76                transport,
77                codec,
78            )
79            .await;
80        })
81    } else {
82        tokio::spawn(async move {
83            session_loop::generic_connection_loop(
84                config,
85                transport_config,
86                state,
87                command_rx,
88                event_tx,
89                transport,
90                codec,
91            )
92            .await;
93        })
94    };
95    handle.set_task(task);
96
97    Ok(handle)
98}
99
#[cfg(test)]
mod tests {
    use super::message_handler::{handle_server_msg, MessageAction};
    use super::reconnect::reconnect_delay;
    use super::*;

    use std::time::Duration;

    use crate::protocol::messages::ServerMessage;
    use crate::session::{SessionEvent, SessionPhase, SessionState};
    use crate::transport::codec::JsonCodec;
    use crate::transport::ws::MockTransport;

    /// Upper bound on any single phase wait or `join()`. Without it, a mock
    /// session that never reaches the expected phase would hang the whole test
    /// run instead of failing the test.
    const PHASE_TIMEOUT: Duration = Duration::from_secs(5);

    /// How long to wait for each event before treating the stream as drained.
    const EVENT_TIMEOUT: Duration = Duration::from_millis(100);

    /// Maximum number of events inspected while looking for an expected one.
    const MAX_EVENTS: usize = 20;

    /// Scripted server reply that completes the setup exchange.
    const SETUP_COMPLETE: &str = r#"{"setupComplete":{}}"#;

    /// TransportConfig that disables reconnection for mock tests.
    fn no_reconnect_config() -> TransportConfig {
        TransportConfig {
            max_reconnect_attempts: 0,
            connect_timeout_secs: 5,
            setup_timeout_secs: 5,
            ..TransportConfig::default()
        }
    }

    /// Connect over a [`MockTransport`] that replays `frames` (raw JSON server
    /// messages) in order, with the test model and reconnection disabled.
    ///
    /// `connect_with` has no await points. Under the default current-thread
    /// `#[tokio::test]` runtime, the spawned connection loop therefore has not
    /// run yet when this returns, so a `subscribe()` made right afterwards sees
    /// every event the loop emits.
    async fn connect_mock(frames: &[&str]) -> SessionHandle {
        let mut transport = MockTransport::new();
        for frame in frames {
            transport.script_recv(frame.as_bytes().to_vec());
        }
        let config = SessionConfig::new("test-key").model(GeminiModel::Gemini2_0FlashLive);
        connect_with(config, no_reconnect_config(), transport, JsonCodec)
            .await
            .expect("connect_with should succeed against a mock transport")
    }

    /// Wait until `handle` reports `phase`, failing the test after
    /// [`PHASE_TIMEOUT`].
    async fn wait_phase(handle: &SessionHandle, phase: SessionPhase) {
        let wanted = format!("{:?}", phase);
        if tokio::time::timeout(PHASE_TIMEOUT, handle.wait_for_phase(phase))
            .await
            .is_err()
        {
            panic!("timed out after {:?} waiting for phase {}", PHASE_TIMEOUT, wanted);
        }
    }

    #[tokio::test]
    async fn connect_with_mock_transport() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            // A text response followed by turn complete.
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello!"}]},"turnComplete":true}}"#,
        ])
        .await;

        // Should reach Active phase after setup completes.
        wait_phase(&handle, SessionPhase::Active).await;
        assert_eq!(handle.phase(), SessionPhase::Active);
    }

    #[tokio::test]
    async fn connect_with_mock_receives_text_events() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello from mock!"}]},"turnComplete":true}}"#,
        ])
        .await;
        let mut events = handle.subscribe();

        wait_phase(&handle, SessionPhase::Active).await;

        let mut got_text_delta = false;
        let mut got_text_complete = false;
        let mut got_turn_complete = false;

        // Collect events until TurnComplete, or until the stream goes quiet.
        for _ in 0..MAX_EVENTS {
            let event = match tokio::time::timeout(EVENT_TIMEOUT, events.recv()).await {
                Ok(Ok(event)) => event,
                // Receive error, or no event within EVENT_TIMEOUT.
                Ok(Err(_)) | Err(_) => break,
            };
            match event {
                SessionEvent::TextDelta(t) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_delta = true;
                }
                SessionEvent::TextComplete(t) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_complete = true;
                }
                SessionEvent::TurnComplete => {
                    got_turn_complete = true;
                    break;
                }
                _ => {}
            }
        }

        assert!(got_text_delta, "should have received TextDelta");
        assert!(got_text_complete, "should have received TextComplete");
        assert!(got_turn_complete, "should have received TurnComplete");
    }

    #[tokio::test]
    async fn connect_with_mock_tool_call() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            r#"{"toolCall":{"functionCalls":[{"name":"get_weather","args":{"city":"London"},"id":"call-1"}]}}"#,
        ])
        .await;
        let mut events = handle.subscribe();

        wait_phase(&handle, SessionPhase::Active).await;

        // Look for the ToolCall event.
        let mut got_tool_call = false;
        for _ in 0..MAX_EVENTS {
            let event = match tokio::time::timeout(EVENT_TIMEOUT, events.recv()).await {
                Ok(Ok(event)) => event,
                Ok(Err(_)) | Err(_) => break,
            };
            if let SessionEvent::ToolCall(calls) = event {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "get_weather");
                got_tool_call = true;
                break;
            }
        }

        assert!(got_tool_call, "should have received ToolCall event");
    }

    #[tokio::test]
    async fn connect_with_mock_graceful_disconnect() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            // Keep the connection busy with a message that arrives before disconnect.
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"hi"}]},"turnComplete":true}}"#,
        ])
        .await;

        wait_phase(&handle, SessionPhase::Active).await;
        // Small delay to let the background task process the scripted message.
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Disconnect gracefully.
        handle.disconnect().await.unwrap();

        wait_phase(&handle, SessionPhase::Disconnected).await;
        assert_eq!(handle.phase(), SessionPhase::Disconnected);
    }

    #[test]
    fn handle_server_msg_preserves_interruption() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Active);
        let (event_tx, mut event_rx) = broadcast::channel(16);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));

        let json = r#"{"serverContent":{"interrupted":true}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        let action = handle_server_msg(msg, &state, &event_tx);

        assert!(matches!(action, MessageAction::Continue));
        // Should have emitted an Interrupted event.
        let mut found_interrupted = false;
        while let Ok(evt) = event_rx.try_recv() {
            if matches!(evt, SessionEvent::Interrupted) {
                found_interrupted = true;
            }
        }
        assert!(found_interrupted, "should emit Interrupted event");
    }

    #[test]
    fn handle_server_msg_go_away() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Active);
        let (event_tx, _event_rx) = broadcast::channel(16);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));

        let json = r#"{"goAway":{"timeLeft":"30s"}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        let action = handle_server_msg(msg, &state, &event_tx);

        assert!(matches!(action, MessageAction::GoAway(Some(_))));
    }

    #[test]
    fn handle_server_msg_unknown_is_continue() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Active);
        let (event_tx, _event_rx) = broadcast::channel(16);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));

        let json = r#"{"unknownField":{"data":"test"}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        let action = handle_server_msg(msg, &state, &event_tx);

        assert!(matches!(action, MessageAction::Continue));
    }

    #[tokio::test]
    async fn session_handle_join_after_disconnect() {
        let handle = connect_mock(&[SETUP_COMPLETE]).await;

        wait_phase(&handle, SessionPhase::Active).await;

        // Disconnect to end the connection loop task.
        handle.disconnect().await.unwrap();
        wait_phase(&handle, SessionPhase::Disconnected).await;

        // join() should return Ok after the task completes.
        let result = tokio::time::timeout(PHASE_TIMEOUT, handle.join())
            .await
            .expect("join() timed out");
        assert!(result.is_ok(), "join() should succeed after disconnect");
    }

    /// `join()` works through a cloned handle.
    ///
    /// NOTE: this does not exercise closing the command channel. The clone
    /// presumably still owns a command sender (`SessionHandle::new` takes
    /// `command_tx`), so the loop is ended with an explicit `disconnect()`
    /// instead.
    #[tokio::test]
    async fn session_handle_join_via_clone_after_disconnect() {
        let handle = connect_mock(&[SETUP_COMPLETE]).await;

        wait_phase(&handle, SessionPhase::Active).await;

        let join_handle = handle.clone();
        handle.disconnect().await.unwrap();

        let result = tokio::time::timeout(PHASE_TIMEOUT, join_handle.join())
            .await
            .expect("join() timed out");
        assert!(result.is_ok(), "join() should succeed through a cloned handle");
    }

    #[test]
    fn reconnect_delay_exponential_backoff() {
        let config = TransportConfig::default();
        let d1 = reconnect_delay(1, &config);
        let d2 = reconnect_delay(2, &config);
        let d3 = reconnect_delay(3, &config);
        // Each step should roughly double (plus jitter).
        assert!(d2 > d1);
        assert!(d3 > d2);
        // Should not exceed the configured max plus up to 25% jitter.
        let d_large = reconnect_delay(100, &config);
        let max_with_jitter = Duration::from_millis(
            config.reconnect_max_delay_ms as u64 + config.reconnect_max_delay_ms as u64 / 4,
        );
        assert!(d_large <= max_with_jitter);
    }
}
365}