Skip to content

Commit

Permalink
refactor(gateway, http, lavalink): Switch to fastrand and tokio-webso…
Browse files Browse the repository at this point in the history
…ckets (#2239)

This switches our websocket library from tokio-tungstenite (https://github.com/snapview/tokio-tungstenite/) to tokio-websockets (https://github.com/Gelbpunkt/tokio-websockets).

tokio-tungstenite is very slow with releases and has a bloated dependency stack, while tokio-websockets is much more lightweight, more strictly RFC compliant, performs substantially better under heavy load and has support for SIMD. This also allows us to drop the custom TlsConnector logic.

Additionally, this switches out rand for fastrand since tokio-websockets supports both and fastrand is "random enough" for our usage. The public API remains completely untouched by these changes.

Signed-off-by: Jens Reidel <[email protected]>
  • Loading branch information
Gelbpunkt authored Sep 20, 2023
1 parent cbe1d10 commit 0914338
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 477 deletions.
19 changes: 6 additions & 13 deletions twilight-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ version = "0.15.4"

[dependencies]
bitflags = { default-features = false, version = "2" }
futures-util = { default-features = false, features = ["std"], version = "0.3" }
rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" }
fastrand = { default-features = false, features = ["std"], version = "2" }
futures-util = { default-features = false, features = ["sink", "std"], version = "0.3" }
serde = { default-features = false, features = ["derive"], version = "1" }
serde_json = { default-features = false, features = ["std"], version = "1" }
tokio = { default-features = false, features = ["net", "rt", "sync", "time"], version = "1.19" }
tokio-tungstenite = { default-features = false, features = ["connect"], version = "0.19" }
tokio-websockets = { default-features = false, features = ["client", "fastrand", "sha1_smol", "simd"], version = "0.4" }
tracing = { default-features = false, features = ["std", "attributes"], version = "0.1" }
twilight-gateway-queue = { default-features = false, path = "../twilight-gateway-queue", version = "0.15.4" }
twilight-model = { default-features = false, path = "../twilight-model", version = "0.15.4" }
Expand All @@ -34,13 +34,6 @@ flate2 = { default-features = false, optional = true, version = "1.0.24" }
twilight-http = { default-features = false, optional = true, path = "../twilight-http", version = "0.15.4" }
simd-json = { default-features = false, features = ["serde_impl", "swar-number-parsing"], optional = true, version = ">=0.4, <0.11" }

# TLS libraries
# They are needed to track what is used in tokio-tungstenite
native-tls = { default-features = false, optional = true, version = "0.2.8" }
rustls-native-certs = { default-features = false, optional = true, version = "0.6" }
rustls-tls = { default-features = false, optional = true, package = "rustls", version = "0.21" }
webpki-roots = { default-features = false, optional = true, version = "0.23" }

[dev-dependencies]
anyhow = { default-features = false, features = ["std"], version = "1" }
futures = { default-features = false, version = "0.3" }
Expand All @@ -51,9 +44,9 @@ tracing-subscriber = { default-features = false, features = ["fmt", "tracing-log

[features]
default = ["rustls-native-roots", "twilight-http", "zlib-stock"]
native = ["dep:native-tls", "tokio-tungstenite/native-tls"]
rustls-native-roots = ["dep:rustls-tls", "dep:rustls-native-certs", "tokio-tungstenite/rustls-tls-native-roots"]
rustls-webpki-roots = ["dep:rustls-tls", "dep:webpki-roots", "tokio-tungstenite/rustls-tls-webpki-roots"]
native = ["tokio-websockets/native-tls", "tokio-websockets/openssl"]
rustls-native-roots = ["tokio-websockets/rustls-native-roots"]
rustls-webpki-roots = ["tokio-websockets/rustls-webpki-roots"]
zlib-simd = ["dep:flate2", "flate2?/zlib-ng"]
zlib-stock = ["dep:flate2", "flate2?/zlib"]

Expand Down
8 changes: 4 additions & 4 deletions twilight-gateway/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
use crate::{
queue::{InMemoryQueue, Queue},
tls::TlsContainer,
EventTypeFlags, Session,
};
use std::{
fmt::{Debug, Formatter, Result as FmtResult},
sync::Arc,
};
use tokio_websockets::Connector;
use twilight_model::gateway::{
payload::outgoing::{identify::IdentifyProperties, update_presence::UpdatePresencePayload},
Intents,
Expand Down Expand Up @@ -69,7 +69,7 @@ pub struct Config {
/// TLS connector for Websocket connections.
// We need this to be public so [`stream`] can re-use TLS on multiple shards
// if unconfigured.
tls: TlsContainer,
tls: Arc<Connector>,
/// Token used to authenticate when identifying with the gateway.
///
/// The token is prefixed with "Bot ", which is required by Discord for
Expand Down Expand Up @@ -147,7 +147,7 @@ impl Config {
}

/// Immutable reference to the TLS connector in use by the shard.
pub(crate) const fn tls(&self) -> &TlsContainer {
pub(crate) fn tls(&self) -> &Connector {
&self.tls
}

Expand Down Expand Up @@ -195,7 +195,7 @@ impl ConfigBuilder {
queue: Arc::new(InMemoryQueue::default()),
ratelimit_messages: true,
session: None,
tls: TlsContainer::new().unwrap(),
tls: Arc::new(Connector::new().unwrap()),
token: Token::new(token.into_boxed_str()),
},
}
Expand Down
50 changes: 27 additions & 23 deletions twilight-gateway/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
//! Utilities for creating Websocket connections.
use crate::{error::ReceiveMessageError, tls::TlsContainer, API_VERSION};
use crate::{
error::{ReceiveMessageError, ReceiveMessageErrorType},
API_VERSION,
};
use std::fmt::{Display, Formatter, Result as FmtResult};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, MaybeTlsStream, WebSocketStream};
use tokio_websockets::{ClientBuilder, Connector, Limits, MaybeTlsStream, WebsocketStream};

/// Query argument with zlib-stream enabled.
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
Expand All @@ -16,29 +19,12 @@ const COMPRESSION_FEATURES: &str = "";
/// URL of the Discord gateway.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Configuration used for Websocket connections.
///
/// `max_frame_size` and `max_message_queue` limits are disabled because
/// Discord is not a malicious actor and having a limit has caused problems on
/// large [`GuildCreate`] payloads.
///
/// `accept_unmasked_frames` and `max_send_queue` are set to their
/// defaults.
///
/// [`GuildCreate`]: twilight_model::gateway::payload::incoming::GuildCreate
const WEBSOCKET_CONFIG: WebSocketConfig = WebSocketConfig {
accept_unmasked_frames: false,
max_frame_size: None,
max_message_size: None,
max_send_queue: None,
};

/// [`tokio_tungstenite`] library Websocket connection.
/// [`tokio_websockets`] library Websocket connection.
///
/// Connections are used by [`Shard`]s when reconnecting.
///
/// [`Shard`]: crate::Shard
pub type Connection = WebSocketStream<MaybeTlsStream<TcpStream>>;
pub type Connection = WebsocketStream<MaybeTlsStream<TcpStream>>;

/// Formatter for a gateway URL, with the API version and compression features
/// specified.
Expand Down Expand Up @@ -93,12 +79,30 @@ impl Display for ConnectionUrl<'_> {
#[tracing::instrument(skip_all)]
pub async fn connect(
maybe_gateway_url: Option<&str>,
tls: &TlsContainer,
tls: &Connector,
) -> Result<Connection, ReceiveMessageError> {
let url = ConnectionUrl::new(maybe_gateway_url).to_string();

// Limits to impose on Websocket connections.
//
// `max_payload_len` limit is disabled because Discord is not a malicious
// actor and having a limit has caused problems on large `GuildCreate`
// payloads.
let limits = Limits::default().max_payload_len(None);

tracing::debug!(?url, "shaking hands with gateway");
let stream = tls.connect(&url, WEBSOCKET_CONFIG).await?;

let (stream, _) = ClientBuilder::new()
.uri(&url)
.expect("Gateway URL must be valid")
.limits(limits)
.connector(tls)
.connect()
.await
.map_err(|source| ReceiveMessageError {
kind: ReceiveMessageErrorType::Reconnect,
source: Some(Box::new(source)),
})?;

Ok(stream)
}
Expand Down
1 change: 0 additions & 1 deletion twilight-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ mod message;
mod ratelimiter;
mod session;
mod shard;
mod tls;

#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
pub use self::inflater::Inflater;
Expand Down
52 changes: 27 additions & 25 deletions twilight-gateway/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
//! input will not be checked and will be passed directly to the underlying
//! websocket library.
use tokio_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame as TungsteniteCloseFrame},
Message as TungsteniteMessage,
};
use std::borrow::Cow;

use tokio_websockets::{CloseCode, Message as WebsocketMessage};
use twilight_model::gateway::CloseFrame;

/// Message to send over the connection to the remote.
Expand All @@ -25,33 +24,36 @@ pub enum Message {
}

impl Message {
/// Convert a `tungstenite` websocket message into a `twilight` websocket
/// Convert a `tokio-websockets` websocket message into a `twilight` websocket
/// message.
pub(crate) fn from_tungstenite(tungstenite: TungsteniteMessage) -> Option<Self> {
match tungstenite {
TungsteniteMessage::Close(frame) => Some(Self::Close(frame.map(|frame| CloseFrame {
code: frame.code.into(),
reason: frame.reason,
}))),
TungsteniteMessage::Text(string) => Some(Self::Text(string)),
TungsteniteMessage::Binary(_)
| TungsteniteMessage::Frame(_)
| TungsteniteMessage::Ping(_)
| TungsteniteMessage::Pong(_) => None,
pub(crate) fn from_websocket_msg(msg: &WebsocketMessage) -> Option<Self> {
if msg.is_close() {
let (code, reason) = msg.as_close().unwrap();

let frame = (code == CloseCode::NO_STATUS_RECEIVED).then(|| CloseFrame {
code: code.into(),
reason: Cow::Owned(reason.to_string()),
});

Some(Self::Close(frame))
} else if msg.is_text() {
Some(Self::Text(msg.as_text().unwrap().to_owned()))
} else {
None
}
}

/// Convert a `twilight` websocket message into a `tungstenite` websocket
/// Convert a `twilight` websocket message into a `tokio-websockets` websocket
/// message.
pub(crate) fn into_tungstenite(self) -> TungsteniteMessage {
pub(crate) fn into_websocket_msg(self) -> WebsocketMessage {
match self {
Self::Close(frame) => {
TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
code: CloseCode::from(frame.code),
reason: frame.reason,
}))
}
Self::Text(string) => TungsteniteMessage::Text(string),
Self::Close(frame) => WebsocketMessage::close(
frame
.as_ref()
.and_then(|f| CloseCode::try_from(f.code).ok()),
frame.map(|f| f.reason).as_deref().unwrap_or_default(),
),
Self::Text(string) => WebsocketMessage::text(string),
}
}
}
Expand Down
38 changes: 17 additions & 21 deletions twilight-gateway/src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ use tokio::{
sync::oneshot,
time::{self, Duration, Instant, Interval, MissedTickBehavior},
};
use tokio_tungstenite::tungstenite::{Error as TungsteniteError, Message as TungsteniteMessage};
use tokio_websockets::{Error as WebsocketError, Message as WebsocketMessage};
use twilight_model::gateway::{
event::{Event, GatewayEventDeserializer},
payload::{
Expand Down Expand Up @@ -568,18 +568,16 @@ impl Shard {
/// Identify with the gateway.
Identify,
/// Handle this incoming gateway message.
Message(Option<Result<TungsteniteMessage, TungsteniteError>>),
Message(Option<Result<WebsocketMessage, WebsocketError>>),
}

match self.status {
ConnectionStatus::Disconnected {
close_code,
reconnect_attempts,
} => {
// The shard is considered disconnected after having received a
// close frame or encountering a websocket error, but it should
// only reconnect after the underlying TCP connection is closed
// by the server (having returned `Ok(None)`).
// The shard should should only reconnect after the gateway
// closes the underlying TCP connection.
if self.connection.is_none() {
self.reconnect(close_code, reconnect_attempts).await?;
}
Expand Down Expand Up @@ -672,17 +670,17 @@ impl Shard {
match poll_fn(next_action).await {
Action::Message(Some(Ok(message))) => {
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
if let TungsteniteMessage::Binary(bytes) = &message {
if message.is_binary() {
if let Some(decompressed) = self
.inflater
.inflate(bytes)
.inflate(message.as_payload())
.map_err(ReceiveMessageError::from_compression)?
{
tracing::trace!(%decompressed);
break Message::Text(decompressed);
};
}
if let Some(message) = Message::from_tungstenite(message) {
if let Some(message) = Message::from_websocket_msg(&message) {
break message;
}
}
Expand All @@ -696,7 +694,7 @@ impl Shard {
feature = "rustls-native-roots",
feature = "rustls-webpki-roots"
))]
Action::Message(Some(Err(TungsteniteError::Io(e))))
Action::Message(Some(Err(WebsocketError::Io(e))))
if e.kind() == IoErrorKind::UnexpectedEof
// Assert we're directly connected to Discord's gateway.
&& self.config.proxy_url().is_none()
Expand Down Expand Up @@ -726,11 +724,11 @@ impl Shard {
ConnectionStatus::FatallyClosed { close_code } => {
return Err(ReceiveMessageError::from_fatally_closed(close_code))
}
_ => unreachable!(
"stream ended because websocket is closed (received close frame sets \
status to disconnected or fatally closed) or because it errored (which \
also sets status to disconnected)"
),
_ => {
// Abnormal closure without close frame exchange.
self.disconnect(CloseInitiator::None);
self.reconnect(None, 0).await?;
}
};

continue;
Expand Down Expand Up @@ -804,13 +802,11 @@ impl Shard {

match &message {
Message::Close(frame) => {
// Tungstenite automatically replies to the close message.
// tokio-websockets automatically replies to the close message.
tracing::debug!(?frame, "received websocket close message");
// Don't run `disconnect` if we initiated the close.
if !self.status.is_disconnected() {
self.disconnect(CloseInitiator::Gateway(
frame.as_ref().map(|frame| frame.code),
));
self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
}
}
Message::Text(event) => {
Expand Down Expand Up @@ -913,7 +909,7 @@ impl Shard {
kind: SendErrorType::Sending,
source: None,
})?
.send(message.into_tungstenite())
.send(message.into_websocket_msg())
.await
.map_err(|source| SendError {
kind: SendErrorType::Sending,
Expand Down Expand Up @@ -1116,7 +1112,7 @@ impl Shard {
let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
// First heartbeat should have some jitter, see
// https://discord.com/developers/docs/topics/gateway#heartbeat-interval
let jitter = heartbeat_interval.mul_f64(rand::random());
let jitter = heartbeat_interval.mul_f64(fastrand::f64());
tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

if self.config().ratelimit_messages() {
Expand Down
Loading

0 comments on commit 0914338

Please sign in to comment.