gemini_genai_rs/transport/connection/mod.rs

1//! WebSocket connection lifecycle — connect, setup, full-duplex split, reconnection.
2
3mod message_handler;
4mod reconnect;
5mod session_loop;
6
7use std::sync::Arc;
8
9use tokio::sync::{broadcast, mpsc, watch};
10
11use crate::protocol::types::*;
12use crate::session::{SessionHandle, SessionPhase, SessionState};
13use crate::transport::codec::{Codec, JsonCodec};
14use crate::transport::ws::{Transport, TungsteniteTransport};
15use crate::transport::TransportConfig;
16
17/// Connect to the Gemini Multimodal Live API and return a session handle.
18///
19/// This is the main entry point. It uses the default [`TungsteniteTransport`]
20/// and [`JsonCodec`]. For custom transports or codecs (e.g. testing with
21/// [`MockTransport`](crate::transport::ws::MockTransport)), use [`connect_with`].
22pub async fn connect(
23    config: SessionConfig,
24    transport_config: TransportConfig,
25) -> Result<SessionHandle, crate::session::SessionError> {
26    connect_with(
27        config,
28        transport_config,
29        TungsteniteTransport::new(),
30        JsonCodec,
31    )
32    .await
33}
34
35/// Connect with a custom transport and codec.
36///
37/// This is the generic entry point that accepts any [`Transport`] + [`Codec`]
38/// implementation. The default [`connect`] delegates here with
39/// [`TungsteniteTransport`] and [`JsonCodec`].
40pub async fn connect_with<T, C>(
41    config: SessionConfig,
42    transport_config: TransportConfig,
43    transport: T,
44    codec: C,
45) -> Result<SessionHandle, crate::session::SessionError>
46where
47    T: Transport,
48    C: Codec,
49{
50    let (command_tx, command_rx) = mpsc::channel(transport_config.send_queue_depth);
51    let (event_tx, _) = broadcast::channel(transport_config.event_channel_capacity);
52    let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
53
54    let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
55
56    let handle = SessionHandle::new(command_tx, event_tx.clone(), state.clone(), phase_rx);
57
58    let task = tokio::spawn(async move {
59        session_loop::generic_connection_loop(
60            config,
61            transport_config,
62            state,
63            command_rx,
64            event_tx,
65            transport,
66            codec,
67        )
68        .await;
69    });
70    handle.set_task(task);
71
72    Ok(handle)
73}
74
#[cfg(test)]
mod tests {
    use super::message_handler::{handle_server_msg, MessageAction};
    use super::reconnect::reconnect_delay;
    use super::*;

    use std::time::Duration;

    use crate::protocol::messages::ServerMessage;
    use crate::session::{SessionEvent, SessionPhase, SessionState};
    use crate::transport::codec::JsonCodec;
    use crate::transport::ws::MockTransport;

    /// Server frame that completes the setup handshake.
    const SETUP_COMPLETE: &str = r#"{"setupComplete":{}}"#;

    /// TransportConfig that disables reconnection for mock tests.
    fn no_reconnect_config() -> TransportConfig {
        TransportConfig {
            max_reconnect_attempts: 0,
            connect_timeout_secs: 5,
            setup_timeout_secs: 5,
            ..TransportConfig::default()
        }
    }

    /// Session config shared by every mock-transport test.
    fn test_session_config() -> SessionConfig {
        SessionConfig::new("test-key").model(GeminiModel::Gemini2_0FlashLive)
    }

    /// Mock transport that answers the setup handshake, then yields `frames` in order.
    fn scripted_transport(frames: &[&str]) -> MockTransport {
        let mut transport = MockTransport::new();
        transport.script_recv(SETUP_COMPLETE.as_bytes().to_vec());
        for frame in frames {
            transport.script_recv(frame.as_bytes().to_vec());
        }
        transport
    }

    /// Connects over `transport` with reconnection disabled.
    async fn connect_mock(transport: MockTransport) -> SessionHandle {
        connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .unwrap()
    }

    /// Active-phase session state wired to live channels, for driving
    /// `handle_server_msg` directly without a transport.
    struct ActiveState {
        state: Arc<SessionState>,
        event_tx: broadcast::Sender<SessionEvent>,
        event_rx: broadcast::Receiver<SessionEvent>,
        // Held so the phase watch channel keeps a receiver for the whole test.
        _phase_rx: watch::Receiver<SessionPhase>,
    }

    impl ActiveState {
        fn new() -> Self {
            let (phase_tx, phase_rx) = watch::channel(SessionPhase::Active);
            let (event_tx, event_rx) = broadcast::channel(16);
            let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
            Self {
                state,
                event_tx,
                event_rx,
                _phase_rx: phase_rx,
            }
        }

        /// Parses `json` as a server message and feeds it to `handle_server_msg`.
        fn handle(&self, json: &str) -> MessageAction {
            let msg = ServerMessage::parse(json).unwrap();
            handle_server_msg(msg, &self.state, &self.event_tx)
        }
    }

    #[tokio::test]
    async fn connect_with_mock_transport() {
        // A text turn follows setup; only the phase transition is asserted here.
        let handle = connect_mock(scripted_transport(&[
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello!"}]},"turnComplete":true}}"#,
        ]))
        .await;

        // Should reach Active phase after setup completes
        handle.wait_for_phase(SessionPhase::Active).await;
        assert_eq!(handle.phase(), SessionPhase::Active);
    }

    #[tokio::test]
    async fn connect_with_mock_receives_text_events() {
        let handle = connect_mock(scripted_transport(&[
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello from mock!"}]},"turnComplete":true}}"#,
        ]))
        .await;

        // Subscribe before waiting on the phase so events emitted once the session
        // is Active are not missed.
        let mut events = handle.subscribe();
        handle.wait_for_phase(SessionPhase::Active).await;

        // Collect events until TurnComplete
        let mut got_text_delta = false;
        let mut got_text_complete = false;
        let mut got_turn_complete = false;

        for _ in 0..20 {
            match tokio::time::timeout(Duration::from_millis(100), events.recv()).await {
                Ok(Ok(SessionEvent::TextDelta(t))) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_delta = true;
                }
                Ok(Ok(SessionEvent::TextComplete(t))) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_complete = true;
                }
                Ok(Ok(SessionEvent::TurnComplete)) => {
                    got_turn_complete = true;
                    break;
                }
                Ok(Ok(_)) => continue,
                Ok(Err(_)) => break,
                Err(_) => break,
            }
        }

        assert!(got_text_delta, "should have received TextDelta");
        assert!(got_text_complete, "should have received TextComplete");
        assert!(got_turn_complete, "should have received TurnComplete");
    }

    #[tokio::test]
    async fn connect_with_mock_tool_call() {
        let handle = connect_mock(scripted_transport(&[
            r#"{"toolCall":{"functionCalls":[{"name":"get_weather","args":{"city":"London"},"id":"call-1"}]}}"#,
        ]))
        .await;

        let mut events = handle.subscribe();
        handle.wait_for_phase(SessionPhase::Active).await;

        // Look for the ToolCall event
        let mut got_tool_call = false;
        for _ in 0..20 {
            match tokio::time::timeout(Duration::from_millis(100), events.recv()).await {
                Ok(Ok(SessionEvent::ToolCall(calls))) => {
                    assert_eq!(calls.len(), 1);
                    assert_eq!(calls[0].name, "get_weather");
                    got_tool_call = true;
                    break;
                }
                Ok(Ok(_)) => continue,
                Ok(Err(_)) => break,
                Err(_) => break,
            }
        }

        assert!(got_tool_call, "should have received ToolCall event");
    }

    #[tokio::test]
    async fn connect_with_mock_graceful_disconnect() {
        // Keep the connection alive with a message that arrives before disconnect
        let handle = connect_mock(scripted_transport(&[
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"hi"}]},"turnComplete":true}}"#,
        ]))
        .await;

        handle.wait_for_phase(SessionPhase::Active).await;
        // Small delay to let the background task process the scripted turn
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Disconnect gracefully
        handle.disconnect().await.unwrap();

        // Wait for disconnected phase
        handle.wait_for_phase(SessionPhase::Disconnected).await;
        assert_eq!(handle.phase(), SessionPhase::Disconnected);
    }

    #[test]
    fn handle_server_msg_preserves_interruption() {
        let mut active = ActiveState::new();
        let action = active.handle(r#"{"serverContent":{"interrupted":true}}"#);

        assert!(matches!(action, MessageAction::Continue));
        // Drain everything emitted and check that Interrupted was among it.
        let mut found_interrupted = false;
        while let Ok(evt) = active.event_rx.try_recv() {
            if matches!(evt, SessionEvent::Interrupted) {
                found_interrupted = true;
            }
        }
        assert!(found_interrupted, "should emit Interrupted event");
    }

    #[test]
    fn handle_server_msg_go_away() {
        let active = ActiveState::new();
        let action = active.handle(r#"{"goAway":{"timeLeft":"30s"}}"#);

        assert!(matches!(action, MessageAction::GoAway(Some(_))));
    }

    #[test]
    fn handle_server_msg_unknown_is_continue() {
        let active = ActiveState::new();
        let action = active.handle(r#"{"unknownField":{"data":"test"}}"#);

        assert!(matches!(action, MessageAction::Continue));
    }

    #[tokio::test]
    async fn session_handle_join_after_disconnect() {
        let handle = connect_mock(scripted_transport(&[])).await;

        handle.wait_for_phase(SessionPhase::Active).await;

        // Disconnect to end the connection loop task
        handle.disconnect().await.unwrap();
        handle.wait_for_phase(SessionPhase::Disconnected).await;

        // join() should return Ok after the task completes
        let result = handle.join().await;
        assert!(result.is_ok(), "join() should succeed after disconnect");
    }

    #[tokio::test]
    async fn session_handle_join_from_clone_after_disconnect() {
        let handle = connect_mock(scripted_transport(&[])).await;

        handle.wait_for_phase(SessionPhase::Active).await;

        // Disconnect through the original handle, then join through a clone taken
        // beforehand. The clone must still be able to await the background task.
        //
        // NOTE: this does not exercise command-channel closure. Both handles stay alive
        // throughout, so (presumably) their command senders keep the channel open.
        let join_handle = handle.clone();
        handle.disconnect().await.unwrap();

        let result = join_handle.join().await;
        assert!(
            result.is_ok(),
            "join() via a clone should succeed after disconnect"
        );
    }

    #[test]
    fn reconnect_delay_exponential_backoff() {
        let config = TransportConfig::default();
        let d1 = reconnect_delay(1, &config);
        let d2 = reconnect_delay(2, &config);
        let d3 = reconnect_delay(3, &config);
        // Each step should roughly double (plus jitter)
        assert!(d2 > d1);
        assert!(d3 > d2);
        // Should not exceed max
        let d_large = reconnect_delay(100, &config);
        let max_with_jitter = Duration::from_millis(
            config.reconnect_max_delay_ms as u64 + config.reconnect_max_delay_ms as u64 / 4,
        );
        assert!(d_large <= max_with_jitter);
    }
}