diff --git a/renetcode/src/server.rs b/renetcode/src/server.rs index 068bbcdb..7656458d 100644 --- a/renetcode/src/server.rs +++ b/renetcode/src/server.rs @@ -27,6 +27,7 @@ struct Connection { user_data: [u8; NETCODE_USER_DATA_BYTES], addr: SocketAddr, last_packet_received_time: Duration, + connection_started_time: Duration, last_packet_send_time: Duration, timeout_seconds: i32, sequence: u64, @@ -338,6 +339,7 @@ impl NetcodeServer { client_id: connect_token.client_id, last_packet_received_time: self.current_time, last_packet_send_time: self.current_time, + connection_started_time: self.current_time, addr, state: ConnectionState::PendingResponse, send_key: connect_token.server_to_client_key, @@ -406,38 +408,59 @@ impl NetcodeServer { ); client.last_packet_received_time = self.current_time; - match client.state { - ConnectionState::Connected => match packet { - Packet::Disconnect => { - client.state = ConnectionState::Disconnected; + if client.state != ConnectionState::Connected { + return Ok(ServerResult::None); + } + + match packet { + Packet::Disconnect => { + let client_id = client.client_id; + self.clients[slot] = None; + log::trace!("Client {} requested to disconnect", client_id); + return Ok(ServerResult::ClientDisconnected { + client_id, + addr, + payload: None, + }); + } + Packet::Payload(payload) => { + if !client.confirmed { + log::trace!("Confirmed connection for Client {}", client.client_id); + client.confirmed = true; + } + return Ok(ServerResult::Payload { + client_id: client.client_id, + payload, + }); + } + Packet::KeepAlive { .. } => { + if !client.confirmed { + log::trace!("Confirmed connection for Client {}", client.client_id); + client.confirmed = true; + } + return Ok(ServerResult::None); + } + Packet::ConnectionRequest { .. } => { + // If a ConnectionRequest is received while connected, + // this might be a client trying to rejoin + + // We check the time when the connection started to detect rejoins + const REJOIN_TIMEOUT: Duration = Duration::from_secs(10); + if client.connection_started_time + REJOIN_TIMEOUT < self.current_time { let client_id = client.client_id; self.clients[slot] = None; - log::trace!("Client {} requested to disconnect", client_id); + log::trace!("Client {} is trying to rejoin", client_id); + + // Disconnect client to restart the connection process return Ok(ServerResult::ClientDisconnected { client_id, addr, payload: None, }); } - Packet::Payload(payload) => { - if !client.confirmed { - log::trace!("Confirmed connection for Client {}", client.client_id); - client.confirmed = true; - } - return Ok(ServerResult::Payload { - client_id: client.client_id, - payload, - }); - } - Packet::KeepAlive { .. } => { - if !client.confirmed { - log::trace!("Confirmed connection for Client {}", client.client_id); - client.confirmed = true; - } - return Ok(ServerResult::None); - } - _ => return Ok(ServerResult::None), - }, + + return Ok(ServerResult::None); + } _ => return Ok(ServerResult::None), } } @@ -738,12 +761,44 @@ mod tests { #[test] fn server_connection() { + fn client_connect(client: &mut NetcodeClient, client_addr: SocketAddr, user_data: [u8; 256], server: &mut NetcodeServer) { + assert!(!client.is_connected()); + let (client_packet, _) = client.update(NETCODE_SEND_RATE).unwrap(); + + let result = server.process_packet(client_addr, client_packet); + assert!(matches!(result, ServerResult::PacketToSend { .. })); + match result { + ServerResult::PacketToSend { payload, .. } => client.process_packet(payload), + _ => unreachable!(), + }; + + assert!(!client.is_connected()); + let (client_packet, _) = client.update(Duration::ZERO).unwrap(); + let result = server.process_packet(client_addr, client_packet); + + match result { + ServerResult::ClientConnected { + client_id: r_id, + user_data: r_data, + payload, + .. + } => { + assert_eq!(client.client_id(), r_id); + assert_eq!(user_data, *r_data); + client.process_packet(payload) + } + _ => unreachable!(), + }; + + assert!(client.is_connected()); + } + let mut server = new_server(); let server_addresses: Vec = server.addresses(); let user_data = generate_random_bytes(); - let expire_seconds = 3; - let client_id = 4; - let timeout_seconds = 5; + let expire_seconds = 500; + let client_id: u64 = 4; + let timeout_seconds = 30; let client_addr: SocketAddr = "127.0.0.1:3000".parse().unwrap(); let connect_token = ConnectToken::generate( Duration::ZERO, @@ -757,35 +812,9 @@ mod tests { ) .unwrap(); let client_auth = ClientAuthentication::Secure { connect_token }; - let mut client = NetcodeClient::new(Duration::ZERO, client_auth).unwrap(); - let (client_packet, _) = client.update(Duration::ZERO).unwrap(); - - let result = server.process_packet(client_addr, client_packet); - assert!(matches!(result, ServerResult::PacketToSend { .. })); - match result { - ServerResult::PacketToSend { payload, .. } => client.process_packet(payload), - _ => unreachable!(), - }; - - assert!(!client.is_connected()); - let (client_packet, _) = client.update(Duration::ZERO).unwrap(); - let result = server.process_packet(client_addr, client_packet); + let mut client = NetcodeClient::new(Duration::ZERO, client_auth.clone()).unwrap(); - match result { - ServerResult::ClientConnected { - client_id: r_id, - user_data: r_data, - payload, - .. - } => { - assert_eq!(client_id, r_id); - assert_eq!(user_data, *r_data); - client.process_packet(payload) - } - _ => unreachable!(), - }; - - assert!(client.is_connected()); + client_connect(&mut client, client_addr, user_data, &mut server); for _ in 0..3 { let payload = [7u8; 300]; @@ -818,6 +847,19 @@ mod tests { } assert!(server.is_client_connected(client_id)); + + // Test rejoin + server.update(Duration::from_secs(15)); // Client should be timing out + + // Create new client with same connect token + let mut client = NetcodeClient::new(Duration::ZERO, client_auth).unwrap(); + let (client_packet, _) = client.update(Duration::ZERO).unwrap(); + let result = server.process_packet(client_addr, client_packet); + + // The client should be disconnected now and the connection workflow should restart + assert!(matches!(result, ServerResult::ClientDisconnected { payload: None, .. })); + client_connect(&mut client, client_addr, user_data, &mut server); + let result = server.disconnect(client_id); match result { ServerResult::ClientDisconnected {