gemini_genai_rs/transport/connection/
mod.rs1mod message_handler;
4mod reconnect;
5mod session_loop;
6
7use std::sync::Arc;
8
9use tokio::sync::{broadcast, mpsc, watch};
10
11use crate::protocol::types::*;
12use crate::session::{SessionHandle, SessionPhase, SessionState};
13use crate::transport::codec::{Codec, JsonCodec};
14use crate::transport::ws::{Transport, TungsteniteTransport};
15use crate::transport::TransportConfig;
16
17pub async fn connect(
23 config: SessionConfig,
24 transport_config: TransportConfig,
25) -> Result<SessionHandle, crate::session::SessionError> {
26 connect_with(
27 config,
28 transport_config,
29 TungsteniteTransport::new(),
30 JsonCodec,
31 )
32 .await
33}
34
35pub async fn connect_with<T, C>(
41 config: SessionConfig,
42 transport_config: TransportConfig,
43 transport: T,
44 codec: C,
45) -> Result<SessionHandle, crate::session::SessionError>
46where
47 T: Transport,
48 C: Codec,
49{
50 let (command_tx, command_rx) = mpsc::channel(transport_config.send_queue_depth);
51 let (event_tx, _) = broadcast::channel(transport_config.event_channel_capacity);
52 let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
53
54 let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));
55
56 let handle = SessionHandle::new(command_tx, event_tx.clone(), state.clone(), phase_rx);
57
58 let task = tokio::spawn(async move {
59 session_loop::generic_connection_loop(
60 config,
61 transport_config,
62 state,
63 command_rx,
64 event_tx,
65 transport,
66 codec,
67 )
68 .await;
69 });
70 handle.set_task(task);
71
72 Ok(handle)
73}
74
#[cfg(test)]
mod tests {
    use super::message_handler::{handle_server_msg, MessageAction};
    use super::reconnect::reconnect_delay;
    // Brings in the parent's imports as well (`Arc`, `watch`, `broadcast`,
    // `SessionPhase`, `SessionState`, `JsonCodec`, `SessionConfig`, ...).
    use super::*;

    use std::time::Duration;

    use crate::protocol::messages::ServerMessage;
    use crate::session::SessionEvent;
    use crate::transport::ws::MockTransport;

    /// Scripted server frame that completes setup for a mock session.
    const SETUP_COMPLETE: &str = r#"{"setupComplete":{}}"#;

    /// Transport settings for tests: reconnection disabled, 5s timeouts.
    fn no_reconnect_config() -> TransportConfig {
        TransportConfig {
            max_reconnect_attempts: 0,
            connect_timeout_secs: 5,
            setup_timeout_secs: 5,
            ..TransportConfig::default()
        }
    }

    /// Session configuration shared by every mock-transport test.
    fn test_session_config() -> SessionConfig {
        SessionConfig::new("test-key").model(GeminiModel::Gemini2_0FlashLive)
    }

    /// Scripts `frames` (in order) on a fresh `MockTransport` and connects a
    /// session over it with the JSON codec and no reconnection.
    async fn connect_mock(frames: &[&str]) -> SessionHandle {
        let mut transport = MockTransport::new();
        for frame in frames {
            transport.script_recv(frame.as_bytes().to_vec());
        }
        connect_with(test_session_config(), no_reconnect_config(), transport, JsonCodec)
            .await
            .expect("connect_with over a mock transport should not fail")
    }

    /// Parses `json` as a server message and feeds it to `handle_server_msg`
    /// against a fresh session state in the `Active` phase. Returns the action
    /// plus every event that was broadcast while handling it.
    fn dispatch(json: &str) -> (MessageAction, Vec<SessionEvent>) {
        // `_phase_rx` is bound (not `_`) so the watch channel keeps a live
        // receiver for the whole call.
        let (phase_tx, _phase_rx) = watch::channel(SessionPhase::Active);
        let (event_tx, mut event_rx) = broadcast::channel(16);
        let state = Arc::new(SessionState::with_events(phase_tx, event_tx.clone()));

        let msg = ServerMessage::parse(json).unwrap();
        let action = handle_server_msg(msg, &state, &event_tx);

        let mut emitted = Vec::new();
        while let Ok(evt) = event_rx.try_recv() {
            emitted.push(evt);
        }
        (action, emitted)
    }

    #[tokio::test]
    async fn connect_with_mock_transport() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello!"}]},"turnComplete":true}}"#,
        ])
        .await;

        handle.wait_for_phase(SessionPhase::Active).await;
        assert_eq!(handle.phase(), SessionPhase::Active);
    }

    #[tokio::test]
    async fn connect_with_mock_receives_text_events() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"Hello from mock!"}]},"turnComplete":true}}"#,
        ])
        .await;

        // Subscribe before waiting on the phase so events emitted once the
        // session is active are observed.
        let mut events = handle.subscribe();

        handle.wait_for_phase(SessionPhase::Active).await;

        let mut got_text_delta = false;
        let mut got_text_complete = false;
        let mut got_turn_complete = false;

        // Bounded drain: at most 20 events, giving up after a 100ms gap, a
        // lag, or channel closure.
        for _ in 0..20 {
            match tokio::time::timeout(Duration::from_millis(100), events.recv()).await {
                Ok(Ok(SessionEvent::TextDelta(t))) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_delta = true;
                }
                Ok(Ok(SessionEvent::TextComplete(t))) => {
                    assert_eq!(t, "Hello from mock!");
                    got_text_complete = true;
                }
                Ok(Ok(SessionEvent::TurnComplete)) => {
                    got_turn_complete = true;
                    break;
                }
                Ok(Ok(_)) => continue,
                Ok(Err(_)) | Err(_) => break,
            }
        }

        assert!(got_text_delta, "should have received TextDelta");
        assert!(got_text_complete, "should have received TextComplete");
        assert!(got_turn_complete, "should have received TurnComplete");
    }

    #[tokio::test]
    async fn connect_with_mock_tool_call() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            r#"{"toolCall":{"functionCalls":[{"name":"get_weather","args":{"city":"London"},"id":"call-1"}]}}"#,
        ])
        .await;

        let mut events = handle.subscribe();
        handle.wait_for_phase(SessionPhase::Active).await;

        let mut got_tool_call = false;
        for _ in 0..20 {
            match tokio::time::timeout(Duration::from_millis(100), events.recv()).await {
                Ok(Ok(SessionEvent::ToolCall(calls))) => {
                    assert_eq!(calls.len(), 1);
                    assert_eq!(calls[0].name, "get_weather");
                    got_tool_call = true;
                    break;
                }
                Ok(Ok(_)) => continue,
                Ok(Err(_)) | Err(_) => break,
            }
        }

        assert!(got_tool_call, "should have received ToolCall event");
    }

    #[tokio::test]
    async fn connect_with_mock_graceful_disconnect() {
        let handle = connect_mock(&[
            SETUP_COMPLETE,
            r#"{"serverContent":{"modelTurn":{"parts":[{"text":"hi"}]},"turnComplete":true}}"#,
        ])
        .await;

        handle.wait_for_phase(SessionPhase::Active).await;
        // Short pause, presumably so the scripted turn is consumed before the
        // disconnect is requested.
        tokio::time::sleep(Duration::from_millis(50)).await;

        handle.disconnect().await.unwrap();

        handle.wait_for_phase(SessionPhase::Disconnected).await;
        assert_eq!(handle.phase(), SessionPhase::Disconnected);
    }

    #[test]
    fn handle_server_msg_preserves_interruption() {
        let (action, events) = dispatch(r#"{"serverContent":{"interrupted":true}}"#);

        assert!(matches!(action, MessageAction::Continue));
        assert!(
            events
                .iter()
                .any(|evt| matches!(evt, SessionEvent::Interrupted)),
            "should emit Interrupted event"
        );
    }

    #[test]
    fn handle_server_msg_go_away() {
        let (action, _events) = dispatch(r#"{"goAway":{"timeLeft":"30s"}}"#);
        assert!(matches!(action, MessageAction::GoAway(Some(_))));
    }

    #[test]
    fn handle_server_msg_unknown_is_continue() {
        let (action, _events) = dispatch(r#"{"unknownField":{"data":"test"}}"#);
        assert!(matches!(action, MessageAction::Continue));
    }

    #[tokio::test]
    async fn session_handle_join_after_disconnect() {
        let handle = connect_mock(&[SETUP_COMPLETE]).await;

        handle.wait_for_phase(SessionPhase::Active).await;

        handle.disconnect().await.unwrap();
        handle.wait_for_phase(SessionPhase::Disconnected).await;

        let result = handle.join().await;
        assert!(result.is_ok(), "join() should succeed after disconnect");
    }

    #[tokio::test]
    async fn session_handle_join_after_command_channel_closed() {
        let handle = connect_mock(&[SETUP_COMPLETE]).await;

        handle.wait_for_phase(SessionPhase::Active).await;

        // Join through a clone to check the task can be awaited from a handle
        // other than the one that requested the disconnect.
        let join_handle = handle.clone();

        handle.disconnect().await.unwrap();

        let result = join_handle.join().await;
        assert!(result.is_ok(), "join() should succeed after channel close");
    }

    #[test]
    fn reconnect_delay_exponential_backoff() {
        let config = TransportConfig::default();
        let d1 = reconnect_delay(1, &config);
        let d2 = reconnect_delay(2, &config);
        let d3 = reconnect_delay(3, &config);
        assert!(d2 > d1);
        assert!(d3 > d2);

        // Far past the cap: the delay must not exceed the max plus 25% jitter.
        let d_large = reconnect_delay(100, &config);
        let max_ms = config.reconnect_max_delay_ms as u64;
        let max_with_jitter = Duration::from_millis(max_ms + max_ms / 4);
        assert!(d_large <= max_with_jitter);
    }
}