feat(gateway)!: support zstd-stream
Tim Vilgot Mikael Fredenberg committed Jan 31, 2025
1 parent 7c05997 commit 2ebae88
Showing 8 changed files with 277 additions and 96 deletions.
18 changes: 15 additions & 3 deletions book/src/chapter_1_crates/section_3_gateway.md
@@ -58,14 +58,20 @@ This is enabled by default.
The `rustls-webpki-roots` feature enables [`tokio-websockets`]'
`rustls-webpki-roots` feature.

### Zlib
### Compression

#### Stock
`twilight-gateway` supports both Zlib and Zstandard transport compression.

#### Zlib

Zlib allows specifying two different backends.

##### Stock

The `zlib-stock` feature makes [flate2] use the stock Zlib, which is either the
upstream version or the one included with the operating system.

#### SIMD
##### SIMD

`zlib-simd` enables the use of [zlib-ng], a modern fork of zlib that in most
cases performs better. However, this will add an external dependency
@@ -74,6 +80,12 @@ on [cmake].
If both are enabled, or if the `zlib` feature of [flate2] is enabled anywhere in
the dependency tree, stock Zlib is used instead of [zlib-ng].

#### Zstandard

The `zstd` feature uses Facebook's zstd library to decompress incoming messages.

This feature is mutually exclusive with the zlib features.

## Example

Starting a `Shard` and printing the contents of new messages as they come in:
4 changes: 3 additions & 1 deletion twilight-gateway/Cargo.toml
@@ -34,6 +34,7 @@ twilight-model = { default-features = false, path = "../twilight-model", version
flate2 = { default-features = false, optional = true, version = "1.0.24" }
twilight-http = { default-features = false, optional = true, path = "../twilight-http", version = "0.16.0-rc.1" }
simd-json = { default-features = false, features = ["serde_impl", "swar-number-parsing"], optional = true, version = "0.14.0-rc.3" }
zstd-safe = { default-features = false, optional = true, version = "7.2.1" }

[dev-dependencies]
anyhow = { default-features = false, features = ["std"], version = "1" }
@@ -44,7 +45,7 @@ tokio-stream = { default-features = false, version = "0.1" }
tracing-subscriber = { default-features = false, features = ["fmt", "tracing-log"], version = "0.3" }

[features]
default = ["rustls-platform-verifier", "rustls-ring", "twilight-http", "zlib-stock"]
default = ["rustls-platform-verifier", "rustls-ring", "twilight-http", "zstd"]
native-tls = ["tokio-websockets/native-tls", "tokio-websockets/openssl"]
rustls-platform-verifier = ["tokio-websockets/rustls-platform-verifier"]
rustls-native-roots = ["tokio-websockets/rustls-native-roots"]
@@ -54,6 +55,7 @@ rustls-aws_lc_rs = ["tokio-websockets/aws_lc_rs"]
rustls-aws-lc-rs = ["rustls-aws_lc_rs"] # Alias for convenience, underscores are preferred in the rustls stack
zlib-simd = ["dep:flate2", "flate2?/zlib-ng"]
zlib-stock = ["dep:flate2", "flate2?/zlib"]
zstd = ["dep:zstd-safe"]

[package.metadata.docs.rs]
rustdoc-args = ["--cfg", "docsrs"]
7 changes: 3 additions & 4 deletions twilight-gateway/README.md
@@ -35,9 +35,10 @@ from a `Fn(ShardId, ConfigBuilder) -> Config` closure, with the help of the
performance and widely used platforms
* none of the above: install your own via [`CryptoProvider::install_default`]
* `twilight-http` (*default*): enable the `stream::create_recommended` function
* Zlib (mutually exclusive)
* `zlib-stock` (*default*): [`flate2`]'s stock zlib implementation
* Compression (mutually exclusive)
* `zlib-stock`: [`flate2`]'s stock zlib implementation
* `zlib-simd`: use [`zlib-ng`] for zlib, may have better performance
* `zstd` (*default*): enable zstd transport compression

## Example

@@ -114,15 +115,13 @@ There are a few additional examples located in the

[`CryptoProvider::install_default`]: https://docs.rs/rustls/latest/rustls/crypto/struct.CryptoProvider.html#method.install_default
[`aws-lc-rs`]: https://crates.io/crates/aws-lc-rs
[`flate2`]: https://crates.io/crates/flate2
[`native-tls`]: https://crates.io/crates/native-tls
[`ring`]: https://crates.io/crates/ring
[`rustls`]: https://crates.io/crates/rustls
[`rustls-platform-verifier`]: https://crates.io/crates/rustls-platform-verifier
[`serde_json`]: https://crates.io/crates/serde_json
[`simd-json`]: https://crates.io/crates/simd-json
[`webpki-roots`]: https://crates.io/crates/webpki-roots
[`zlib-ng`]: https://github.com/zlib-ng/zlib-ng
[codecov badge]: https://img.shields.io/codecov/c/gh/twilight-rs/twilight?logo=codecov&style=for-the-badge&token=E9ERLJL0L2
[codecov link]: https://app.codecov.io/gh/twilight-rs/twilight/
[discord badge]: https://img.shields.io/discord/745809834183753828?color=%237289DA&label=discord%20server&logo=discord&style=for-the-badge
182 changes: 182 additions & 0 deletions twilight-gateway/src/compression.rs
@@ -0,0 +1,182 @@
//! Efficiently decompress Discord gateway messages.
use std::{error::Error, fmt};

/// An operation relating to compression failed.
#[derive(Debug)]
pub struct CompressionError {
/// Type of error.
kind: CompressionErrorType,
/// Source error if available.
source: Option<Box<dyn Error + Send + Sync>>,
}

impl CompressionError {
/// Immutable reference to the type of error that occurred.
#[must_use = "retrieving the type has no effect if left unused"]
pub const fn kind(&self) -> &CompressionErrorType {
&self.kind
}

/// Consume the error, returning the source error if there is any.
#[must_use = "consuming the error and retrieving the source has no effect if left unused"]
pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
self.source
}

/// Consume the error, returning the owned error type and the source error.
#[must_use = "consuming the error into its parts has no effect if left unused"]
pub fn into_parts(self) -> (CompressionErrorType, Option<Box<dyn Error + Send + Sync>>) {
(self.kind, self.source)
}

/// Shortcut to create a new error for a message that is not valid UTF-8.
pub(crate) fn from_utf8_error(source: std::string::FromUtf8Error) -> Self {
Self {
kind: CompressionErrorType::NotUtf8,
source: Some(Box::new(source)),
}
}

#[cfg(feature = "zstd")]
/// Shortcut to create a new error for an erroneous status code.
pub(crate) fn from_code(code: usize) -> Self {
Self {
kind: CompressionErrorType::Decompressing,
source: Some(zstd_safe::get_error_name(code).into()),
}
}
}

impl fmt::Display for CompressionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
CompressionErrorType::Decompressing => f.write_str("message could not be decompressed"),
CompressionErrorType::NotUtf8 => f.write_str("decompressed message is not UTF-8"),
}
}
}

impl Error for CompressionError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.source
.as_ref()
.map(|source| &**source as &(dyn Error + 'static))
}
}

/// Type of [`CompressionError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum CompressionErrorType {
/// Decompressing a frame failed.
Decompressing,
/// Decompressed message is not UTF-8.
NotUtf8,
}

/// Gateway event decompressor.
#[cfg(feature = "zstd")]
pub struct Decompressor {
/// Common decompressed message buffer.
buffer: Box<[u8]>,
/// Reusable zstd decompression context.
ctx: zstd_safe::DCtx<'static>,
}

#[cfg(feature = "zstd")]
impl Decompressor {
/// [`Self::buffer`]'s size.
const BUFFER_SIZE: usize = 32 * 1024;

/// Decompress a message.
///
/// # Errors
///
/// Returns a [`CompressionErrorType::Decompressing`] error type if the
/// message could not be decompressed.
///
/// Returns a [`CompressionErrorType::NotUtf8`] error type if the
/// decompressed message is not UTF-8.
pub(crate) fn decompress(&mut self, message: &[u8]) -> Result<String, CompressionError> {
let mut input = zstd_safe::InBuffer::around(message);

// Decompressed message. `Vec::extend_from_slice` efficiently allocates
// only what's necessary.
let mut decompressed = Vec::new();

loop {
let mut output = zstd_safe::OutBuffer::around(self.buffer.as_mut());

self.ctx
.decompress_stream(&mut output, &mut input)
.map_err(CompressionError::from_code)?;

decompressed.extend_from_slice(output.as_slice());

// Break once all input is consumed and the output buffer was not filled; a
// full buffer may mean more output is still pending.
if input.pos == input.src.len() && output.pos() != output.capacity() {
break;
}
}

String::from_utf8(decompressed).map_err(CompressionError::from_utf8_error)
}

/// Reset the decompressor's internal state.
pub(crate) fn reset(&mut self) {
self.ctx
.reset(zstd_safe::ResetDirective::SessionOnly)
.expect("resetting session is infallible");
}
}

#[cfg(feature = "zstd")]
impl fmt::Debug for Decompressor {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Decompressor")
.field("buffer", &self.buffer)
.field("ctx", &"<decompression context>")
.finish()
}
}

#[cfg(feature = "zstd")]
impl Default for Decompressor {
fn default() -> Self {
Self {
buffer: vec![0; Decompressor::BUFFER_SIZE].into_boxed_slice(),
ctx: zstd_safe::DCtx::create(),
}
}
}

#[cfg(all(feature = "zstd", test))]
mod tests {
use super::Decompressor;

const MESSAGE: [u8; 117] = [
40, 181, 47, 253, 0, 64, 100, 3, 0, 66, 7, 25, 28, 112, 137, 115, 116, 40, 208, 203, 85,
255, 167, 74, 75, 126, 203, 222, 231, 255, 151, 18, 211, 212, 171, 144, 151, 210, 255, 51,
4, 49, 34, 71, 98, 2, 36, 253, 122, 141, 99, 203, 225, 11, 162, 47, 133, 241, 6, 201, 82,
245, 91, 206, 247, 164, 226, 156, 92, 108, 130, 123, 11, 95, 199, 15, 61, 179, 117, 157,
28, 37, 65, 64, 25, 250, 182, 8, 199, 205, 44, 73, 47, 19, 218, 45, 27, 14, 245, 202, 81,
82, 122, 167, 121, 71, 173, 61, 140, 190, 15, 3, 1, 0, 36, 74, 18,
];
const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-us-east1-c-7s4x\",{\"micros\":0.0}]"]}}"#;

#[test]
fn decompress_single_segment() {
let mut inflator = Decompressor::default();
assert_eq!(inflator.decompress(&MESSAGE).unwrap(), OUTPUT);
}

#[test]
fn reset() {
let mut inflator = Decompressor::default();
inflator.decompress(&MESSAGE[..MESSAGE.len() - 2]).unwrap();

assert!(inflator.decompress(&MESSAGE).is_err());
inflator.reset();
assert_eq!(inflator.decompress(&MESSAGE).unwrap(), OUTPUT);
}
}
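
The drain loop in `Decompressor::decompress` is easier to follow in isolation. The following is a minimal, std-only sketch of the same pattern: a fixed-size scratch buffer is filled repeatedly and appended to a growing `Vec` until the input is consumed and the last pass did not fill the buffer. `ToyDecoder` and `drain` are hypothetical names, and the toy codec (expanding `(count, byte)` pairs) merely stands in for zstd; none of this is part of `twilight-gateway` or `zstd-safe`.

```rust
/// Hypothetical stand-in for a streaming decoder: expands `(count, byte)` pairs.
#[derive(Default)]
struct ToyDecoder {
    /// Byte of the run currently being emitted.
    byte: u8,
    /// Copies of `byte` still owed to the caller.
    remaining: usize,
}

impl ToyDecoder {
    /// Write as much output as fits into `out`, advancing `*pos` through
    /// `input`. Returns the number of bytes written.
    fn step(&mut self, input: &[u8], pos: &mut usize, out: &mut [u8]) -> usize {
        let mut written = 0;
        while written < out.len() {
            if self.remaining == 0 {
                // Start the next run if a full pair is left; a truncated
                // trailing byte is simply treated as consumed in this toy.
                if *pos + 1 >= input.len() {
                    *pos = input.len();
                    break;
                }
                self.remaining = usize::from(input[*pos]);
                self.byte = input[*pos + 1];
                *pos += 2;
                continue;
            }
            out[written] = self.byte;
            written += 1;
            self.remaining -= 1;
        }
        written
    }
}

/// Decode all of `input` through a reusable, fixed-size scratch buffer.
fn drain(decoder: &mut ToyDecoder, input: &[u8], scratch: &mut [u8]) -> Vec<u8> {
    assert!(!scratch.is_empty(), "scratch buffer must not be empty");

    let mut pos = 0;
    let mut decoded = Vec::new();

    loop {
        let written = decoder.step(input, &mut pos, scratch);
        decoded.extend_from_slice(&scratch[..written]);

        // Done once the input is consumed and the scratch buffer was not
        // filled; a full buffer may mean more output is still pending.
        if pos == input.len() && written != scratch.len() {
            break;
        }
    }

    decoded
}
```

The second half of the break condition matters when the decoded output is larger than the scratch buffer: the input can be fully consumed while the decoder still holds pending output, so the loop runs one more pass whenever the buffer came back full.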
10 changes: 5 additions & 5 deletions twilight-gateway/src/error.rs
@@ -1,7 +1,7 @@
//! Errors returned by gateway operations.
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
pub use crate::inflater::{CompressionError, CompressionErrorType};
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd", feature = "zstd"))]
pub use crate::compression::{CompressionError, CompressionErrorType};

use std::{
error::Error,
@@ -168,7 +168,7 @@ impl ReceiveMessageError {
}

/// Shortcut to create a new error for a message compression error.
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd", feature = "zstd"))]
pub(crate) fn from_compression(source: CompressionError) -> Self {
Self {
kind: ReceiveMessageErrorType::Compression,
@@ -180,7 +180,7 @@ impl ReceiveMessageError {
impl Display for ReceiveMessageError {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match &self.kind {
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd", feature = "zstd"))]
ReceiveMessageErrorType::Compression => {
f.write_str("binary message could not be decompressed")
}
@@ -208,7 +208,7 @@ pub enum ReceiveMessageErrorType {
/// Binary message could not be decompressed.
///
/// The associated error downcasts to [`CompressionError`].
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd"))]
#[cfg(any(feature = "zlib-stock", feature = "zlib-simd", feature = "zstd"))]
Compression,
/// Gateway event could not be deserialized.
Deserializing {
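
The `Compression` variant's documentation says the associated error downcasts to `CompressionError`. As a rough illustration of that "kind plus boxed source" pattern, here is a std-only sketch. `DemoError`, `DemoErrorType`, `decode`, and `utf8_source` are hypothetical names that only mirror the shape of the types in this commit; they are not the crate's API.

```rust
use std::{error::Error, fmt, string::FromUtf8Error};

/// Hypothetical error kind, mirroring the shape of `CompressionErrorType`.
#[derive(Debug, PartialEq)]
enum DemoErrorType {
    NotUtf8,
}

/// Hypothetical error carrying a kind and an optional boxed source.
#[derive(Debug)]
struct DemoError {
    kind: DemoErrorType,
    source: Option<Box<dyn Error + Send + Sync>>,
}

impl DemoError {
    /// Consume the error, returning both the kind and the source.
    fn into_parts(self) -> (DemoErrorType, Option<Box<dyn Error + Send + Sync>>) {
        (self.kind, self.source)
    }
}

impl fmt::Display for DemoError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.kind {
            DemoErrorType::NotUtf8 => f.write_str("decompressed message is not UTF-8"),
        }
    }
}

impl Error for DemoError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        self.source
            .as_ref()
            .map(|source| &**source as &(dyn Error + 'static))
    }
}

/// Convert bytes to a `String`, wrapping the std error as the source.
fn decode(bytes: Vec<u8>) -> Result<String, DemoError> {
    String::from_utf8(bytes).map_err(|source| DemoError {
        kind: DemoErrorType::NotUtf8,
        source: Some(Box::new(source)),
    })
}

/// Recover the concrete source error by downcasting the trait object.
fn utf8_source(err: &DemoError) -> Option<&FromUtf8Error> {
    err.source()?.downcast_ref::<FromUtf8Error>()
}
```

Callers that only need to branch on the failure can match on the kind; downcasting the source is only necessary when details of the underlying error are wanted.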