gemini_genai_rs/transport/connection/
mod.rs1mod message_handler;
4mod reconnect;
5mod session_loop;
6
7use std::sync::Arc;
8
9use tokio::sync::{broadcast, mpsc, watch};
10
11use crate::protocol::types::*;
12use crate::session::{SessionHandle, SessionPhase, SessionState};
13use crate::transport::codec::{Codec, JsonCodec};
14use crate::transport::ws::{Transport, TungsteniteTransport};
15use crate::transport::TransportConfig;
16
17pub async fn connect(
23 config: SessionConfig,
24 transport_config: TransportConfig,
25) -> Result<SessionHandle, crate::session::SessionError> {
26 connect_with(
27 config,
28 transport_config,
29 TungsteniteTransport::new(),
30 JsonCodec,
31 )
32 .await
33}
34
35pub async fn connect_with<T, C>(
41 config: SessionConfig,
42 transport_config: TransportConfig,
43 transport: T,
44 codec: C,
45) -> Result<SessionHandle, crate::session::SessionError>
46where
47 T: Transport,
48 C: Codec,
49{
50 let (command_tx, command_rx) = mpsc::channel(transport_config.send_queue_depth);
51 let (event_tx, _) = broadcast::channel(transport_config.event_channel_capacity);
52 let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
53
54 let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
55
56 let mut handle = SessionHandle::new(command_tx, event_tx.clone(), state.clone(), phase_rx);
57 if let Some(pacing) = config.audio_pacing.clone() {
58 handle = handle.with_audio_pacing(pacing);
59 }
60
61 let task = if let Some(recorder) = config.wire_recorder.clone() {
65 let codec: Box<dyn Codec> = Box::new(crate::transport::recording::RecordingCodec::new(
66 codec,
67 recorder.recorder(),
68 ));
69 tokio::spawn(async move {
70 session_loop::generic_connection_loop(
71 config,
72 transport_config,
73 state,
74 command_rx,
75 event_tx,
76 transport,
77 codec,
78 )
79 .await;
80 })
81 } else {
82 tokio::spawn(async move {
83 session_loop::generic_connection_loop(
84 config,
85 transport_config,
86 state,
87 command_rx,
88 event_tx,
89 transport,
90 codec,
91 )
92 .await;
93 })
94 };
95 handle.set_task(task);
96
97 Ok(handle)
98}
99
#[cfg(test)]
mod tests {
    //! End-to-end tests of `connect_with` against a scripted `MockTransport`,
    //! plus unit tests for `handle_server_msg` and `reconnect_delay`.

    use super::message_handler::{handle_server_msg, MessageAction};
    use super::reconnect::reconnect_delay;
    use super::*;

    use std::time::Duration;

    use crate::protocol::messages::ServerMessage;
    use crate::session::{SessionEvent, SessionPhase, SessionState};
    use crate::transport::codec::JsonCodec;
    use crate::transport::ws::MockTransport;

    /// Upper bound on any single await that depends on the spawned session
    /// task (phase transitions, `join`). Without it, a regression that never
    /// reaches the expected state hangs the whole test run instead of failing.
    const WAIT_TIMEOUT: Duration = Duration::from_secs(5);

    /// Raw `setupComplete` frame scripted as the first server message.
    const SETUP_COMPLETE: &[u8] = br#"{"setupComplete":{}}"#;

    /// Transport settings with reconnection disabled
    /// (`max_reconnect_attempts: 0`) and 5-second connect/setup timeouts.
    fn no_reconnect_config() -> TransportConfig {
        TransportConfig {
            max_reconnect_attempts: 0,
            connect_timeout_secs: 5,
            setup_timeout_secs: 5,
            ..TransportConfig::default()
        }
    }

    /// Session config shared by the end-to-end tests.
    fn test_session_config() -> SessionConfig {
        SessionConfig::new("test-key").model(GeminiModel::Gemini2_0FlashLive)
    }

    /// Builds a `MockTransport` with each of `frames` scripted, in order, via
    /// `script_recv`.
    fn scripted_transport(frames: &[&[u8]]) -> MockTransport {
        let mut transport = MockTransport::new();
        for frame in frames {
            transport.script_recv(frame.to_vec());
        }
        transport
    }

    /// Awaits `fut`. If it does not complete within [`WAIT_TIMEOUT`], panics
    /// with a message naming `what`.
    async fn within_timeout<F>(what: &str, fut: F) -> F::Output
    where
        F: std::future::Future,
    {
        match tokio::time::timeout(WAIT_TIMEOUT, fut).await {
            Ok(output) => output,
            Err(_) => panic!("timed out after {:?} waiting for {}", WAIT_TIMEOUT, what),
        }
    }

    #[tokio::test]
    async fn connect_with_mock_transport() {
        let transport = scripted_transport(&[
            SETUP_COMPLETE,
            br#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello!"}]},"turnComplete":true}}"#,
        ]);

        let handle = connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .unwrap();

        within_timeout("Active phase", handle.wait_for_phase(SessionPhase::Active)).await;
        assert_eq!(handle.phase(), SessionPhase::Active);
    }

    #[tokio::test]
    async fn connect_with_mock_receives_text_events() {
        let transport = scripted_transport(&[
            SETUP_COMPLETE,
            br#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello from mock!"}]},"turnComplete":true}}"#,
        ]);

        let handle = connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .unwrap();

        let mut events = handle.subscribe();

        within_timeout("Active phase", handle.wait_for_phase(SessionPhase::Active)).await;

        let mut got_text_delta = false;
        let mut got_text_complete = false;
        let mut got_turn_complete = false;

        for _ in 0..20 {
            let event = match tokio::time::timeout(Duration::from_millis(100), events.recv()).await {
                Ok(Ok(event)) => event,
                // Receive error, or no event within the poll window.
                Ok(Err(_)) | Err(_) => break,
            };
            match event {
                SessionEvent::TextDelta(t) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_delta = true;
                }
                SessionEvent::TextComplete(t) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_complete = true;
                }
                SessionEvent::TurnComplete => {
                    got_turn_complete = true;
                    break;
                }
                _ => {}
            }
        }

        assert!(got_text_delta, "should have received TextDelta");
        assert!(got_text_complete, "should have received TextComplete");
        assert!(got_turn_complete, "should have received TurnComplete");
    }

    #[tokio::test]
    async fn connect_with_mock_tool_call() {
        let transport = scripted_transport(&[
            SETUP_COMPLETE,
            br#"{"toolCall":{"functionCalls":[{"name":"get_weather","args":{"city":"London"},"id":"call-1"}]}}"#,
        ]);

        let handle = connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .unwrap();

        let mut events = handle.subscribe();
        within_timeout("Active phase", handle.wait_for_phase(SessionPhase::Active)).await;

        let mut got_tool_call = false;
        for _ in 0..20 {
            let event = match tokio::time::timeout(Duration::from_millis(100), events.recv()).await {
                Ok(Ok(event)) => event,
                // Receive error, or no event within the poll window.
                Ok(Err(_)) | Err(_) => break,
            };
            if let SessionEvent::ToolCall(calls) = event {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "get_weather");
                got_tool_call = true;
                break;
            }
        }

        assert!(got_tool_call, "should have received ToolCall event");
    }

    #[tokio::test]
    async fn connect_with_mock_graceful_disconnect() {
        let transport = scripted_transport(&[
            SETUP_COMPLETE,
            br#"{"serverContent":{"modelTurn":{"parts":[{"text":"hi"}]},"turnComplete":true}}"#,
        ]);

        let handle = connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .unwrap();

        within_timeout("Active phase", handle.wait_for_phase(SessionPhase::Active)).await;
        // Presumably gives the session task time to consume the scripted
        // server content before the disconnect command is sent.
        tokio::time::sleep(Duration::from_millis(50)).await;

        handle.disconnect().await.unwrap();

        within_timeout(
            "Disconnected phase",
            handle.wait_for_phase(SessionPhase::Disconnected),
        )
        .await;
        assert_eq!(handle.phase(), SessionPhase::Disconnected);
    }

    #[test]
    fn handle_server_msg_preserves_interruption() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Active);
        let (event_tx, mut event_rx) = broadcast::channel(16);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));

        let json = r#"{"serverContent":{"interrupted":true}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        let action = handle_server_msg(msg, &state, &event_tx);

        assert!(matches!(action, MessageAction::Continue));

        // Drain whatever was emitted synchronously and look for the interruption.
        let found_interrupted = std::iter::from_fn(|| event_rx.try_recv().ok())
            .any(|evt| matches!(evt, SessionEvent::Interrupted));
        assert!(found_interrupted, "should emit Interrupted event");
    }

    #[test]
    fn handle_server_msg_go_away() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Active);
        let (event_tx, _event_rx) = broadcast::channel(16);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));

        let json = r#"{"goAway":{"timeLeft":"30s"}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        let action = handle_server_msg(msg, &state, &event_tx);

        assert!(matches!(action, MessageAction::GoAway(Some(_))));
    }

    #[test]
    fn handle_server_msg_unknown_is_continue() {
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Active);
        let (event_tx, _event_rx) = broadcast::channel(16);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));

        let json = r#"{"unknownField":{"data":"test"}}"#;
        let msg = ServerMessage::parse(json).unwrap();
        let action = handle_server_msg(msg, &state, &event_tx);

        assert!(matches!(action, MessageAction::Continue));
    }

    #[tokio::test]
    async fn session_handle_join_after_disconnect() {
        let transport = scripted_transport(&[SETUP_COMPLETE]);

        let handle = connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .unwrap();

        within_timeout("Active phase", handle.wait_for_phase(SessionPhase::Active)).await;

        handle.disconnect().await.unwrap();
        within_timeout(
            "Disconnected phase",
            handle.wait_for_phase(SessionPhase::Disconnected),
        )
        .await;

        let result = within_timeout("session task join", handle.join()).await;
        assert!(result.is_ok(), "join() should succeed after disconnect");
    }

    #[tokio::test]
    async fn session_handle_join_after_command_channel_closed() {
        let transport = scripted_transport(&[SETUP_COMPLETE]);

        let handle = connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .unwrap();

        within_timeout("Active phase", handle.wait_for_phase(SessionPhase::Active)).await;

        // Join through a clone to check the task handle is reachable from any
        // copy of the session handle.
        let join_handle = handle.clone();

        handle.disconnect().await.unwrap();

        let result = within_timeout("session task join", join_handle.join()).await;
        assert!(result.is_ok(), "join() should succeed after channel close");
    }

    #[test]
    fn reconnect_delay_exponential_backoff() {
        let config = TransportConfig::default();

        let d1 = reconnect_delay(1, &config);
        let d2 = reconnect_delay(2, &config);
        let d3 = reconnect_delay(3, &config);
        assert!(d2 > d1, "attempt 2 delay should exceed attempt 1");
        assert!(d3 > d2, "attempt 3 delay should exceed attempt 2");

        // Far past the cap, the delay must stay within the max plus 25% jitter.
        let d_large = reconnect_delay(100, &config);
        let max_with_jitter = Duration::from_millis(
            config.reconnect_max_delay_ms as u64 + config.reconnect_max_delay_ms as u64 / 4,
        );
        assert!(d_large <= max_with_jitter);
    }
}