Skip to content

Commit

Permalink
feat: response compression
Browse files Browse the repository at this point in the history
  • Loading branch information
merklefruit committed Jan 23, 2024
1 parent 84e2450 commit 8511a32
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 36 deletions.
38 changes: 34 additions & 4 deletions msg-socket/src/rep/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ use tracing::{debug, error, info, warn};

use crate::{rep::SocketState, AuthResult, Authenticator, PubError, RepOptions, Request};
use msg_transport::{PeerAddress, Transport};
use msg_wire::{auth, compression::try_decompress_payload, reqrep};
use msg_wire::{
auth,
compression::{try_decompress_payload, Compressor},
reqrep,
};

pub(crate) struct PeerState<T: AsyncRead + AsyncWrite> {
pending_requests: FuturesUnordered<PendingRequest>,
Expand All @@ -28,6 +32,7 @@ pub(crate) struct PeerState<T: AsyncRead + AsyncWrite> {
egress_queue: VecDeque<reqrep::Message>,
state: Arc<SocketState>,
should_flush: bool,
compressor: Option<Arc<dyn Compressor>>,
}

pub(crate) struct RepDriver<T: Transport> {
Expand All @@ -44,6 +49,9 @@ pub(crate) struct RepDriver<T: Transport> {
pub(crate) to_socket: mpsc::Sender<Request>,
/// Optional connection authenticator.
pub(crate) auth: Option<Arc<dyn Authenticator>>,
/// Optional message compressor. This is shared with the socket to keep
/// the API consistent with other socket types (e.g. `PubSocket`)
pub(crate) compressor: Option<Arc<dyn Compressor>>,
/// A set of pending incoming connections, represented by [`Transport::Accept`].
pub(super) conn_tasks: FuturesUnordered<T::Accept>,
/// A joinset of authentication tasks.
Expand Down Expand Up @@ -94,6 +102,7 @@ where
egress_queue: VecDeque::with_capacity(128),
state: Arc::clone(&this.state),
should_flush: false,
compressor: this.compressor.clone(),
}),
);
}
Expand Down Expand Up @@ -216,6 +225,7 @@ where
egress_queue: VecDeque::with_capacity(128),
state: Arc::clone(&self.state),
should_flush: false,
compressor: self.compressor.clone(),
}),
);
}
Expand Down Expand Up @@ -262,11 +272,31 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Stream for PeerState<T> {
}

// Then we check for completed requests, and push them onto the egress queue.
if let Poll::Ready(Some(Some((id, payload)))) =
if let Poll::Ready(Some(Some((id, mut payload)))) =
this.pending_requests.poll_next_unpin(cx)
{
// TODO: compress the response payload.
let compression_type = 0;
let mut compression_type = 0;
let len_before = payload.len();
if let Some(ref compressor) = this.compressor {
match compressor.compress(&payload) {
Ok(compressed) => {
payload = compressed;
compression_type = compressor.compression_type() as u8;
}
Err(e) => {
tracing::error!("Failed to compress message: {:?}", e);
continue;
}
}

tracing::debug!(
"Compressed message {} from {} to {} bytes",
id,
len_before,
payload.len()
)
}

let msg = reqrep::Message::new(id, compression_type, payload);
this.egress_queue.push_back(msg);

Expand Down
18 changes: 17 additions & 1 deletion msg-socket/src/rep/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,19 @@ pub enum PubError {
Transport(#[from] Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Default)]
pub struct RepOptions {
/// The maximum number of concurrent clients.
max_clients: Option<usize>,
min_compress_size: usize,
}

impl Default for RepOptions {
fn default() -> Self {
Self {
max_clients: None,
min_compress_size: 8192,
}
}
}

impl RepOptions {
Expand All @@ -37,6 +46,13 @@ impl RepOptions {
self.max_clients = Some(max_clients);
self
}

/// Sets the minimum payload size for compression.
/// If the payload is smaller than this value, it will not be compressed.
pub fn min_compress_size(mut self, min_compress_size: usize) -> Self {
self.min_compress_size = min_compress_size;
self
}
}

/// The request socket state, shared between the backend task and the socket.
Expand Down
11 changes: 11 additions & 0 deletions msg-socket/src/rep/socket.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures::{stream::FuturesUnordered, Stream};
use msg_wire::compression::Compressor;
use std::{
io,
net::SocketAddr,
Expand Down Expand Up @@ -36,6 +37,8 @@ pub struct RepSocket<T: Transport> {
auth: Option<Arc<dyn Authenticator>>,
/// The local address this socket is bound to.
local_addr: Option<SocketAddr>,
/// Optional message compressor.
compressor: Option<Arc<dyn Compressor>>,
}

impl<T> RepSocket<T>
Expand All @@ -56,6 +59,7 @@ where
options: Arc::new(options),
state: Arc::new(SocketState::default()),
auth: None,
compressor: None,
}
}

Expand All @@ -65,6 +69,12 @@ where
self
}

/// Sets the message compressor for this socket.
pub fn with_compressor<C: Compressor + 'static>(mut self, compressor: C) -> Self {
self.compressor = Some(Arc::new(compressor));
self
}

/// Binds the socket to the given address. This spawns the socket driver task.
pub async fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> Result<(), PubError> {
let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE);
Expand Down Expand Up @@ -103,6 +113,7 @@ where
auth: self.auth.take(),
auth_tasks: JoinSet::new(),
conn_tasks: FuturesUnordered::new(),
compressor: self.compressor.take(),
};

tokio::spawn(backend);
Expand Down
48 changes: 44 additions & 4 deletions msg-socket/src/req/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@ use msg_transport::Transport;
use rustc_hash::FxHashMap;
use std::{
collections::VecDeque,
io,
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tokio::sync::{mpsc, oneshot};
use tokio_util::codec::Framed;

use crate::req::SocketState;
use crate::{req::SocketState, ReqMessage};

use super::{Command, ReqError, ReqOptions};
use msg_wire::reqrep;
use msg_wire::{
compression::{try_decompress_payload, Compressor},
reqrep,
};
use std::time::Instant;
use tokio::time::Interval;

Expand Down Expand Up @@ -42,6 +46,9 @@ pub(crate) struct ReqDriver<T: Transport> {
pub(crate) flush_interval: Option<tokio::time::Interval>,
/// Whether or not the connection should be flushed
pub(crate) should_flush: bool,
/// Optional message compressor. This is shared with the socket to keep
/// the API consistent with other socket types (e.g. `PubSocket`)
pub(crate) compressor: Option<Arc<dyn Compressor>>,
}

/// A pending request that is waiting for a response.
Expand All @@ -57,7 +64,22 @@ impl<T: Transport> ReqDriver<T> {
if let Some(pending) = self.pending_requests.remove(&msg.id()) {
let rtt = pending.start.elapsed().as_micros() as usize;
let size = msg.size();
let _ = pending.sender.send(Ok(msg.into_payload()));
let compression_type = msg.header().compression_type();
let mut payload = msg.into_payload();

// decompress the response
match try_decompress_payload(compression_type, payload) {
Ok(decompressed) => payload = decompressed,
Err(e) => {
tracing::error!("Failed to decompress response payload: {:?}", e);
let _ = pending.sender.send(Err(ReqError::Wire(reqrep::Error::Io(
io::Error::new(io::ErrorKind::Other, "Failed to decompress response"),
))));
return;
}
}

let _ = pending.sender.send(Ok(payload));

// Update stats
self.socket_state.stats.update_rtt(rtt);
Expand Down Expand Up @@ -178,7 +200,25 @@ where
Poll::Ready(Some(Command::Send { message, response })) => {
// Queue the message for sending
let start = std::time::Instant::now();
let msg = message.into_wire(this.id_counter);

let mut msg = ReqMessage::new(message);

let len_before = msg.payload().len();
if len_before > this.options.min_compress_size {
if let Some(ref compressor) = this.compressor {
if let Err(e) = msg.compress(compressor.as_ref()) {
tracing::error!("Failed to compress message: {:?}", e);
}

tracing::debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len()
);
}
}

let msg = msg.into_wire(this.id_counter);
let msg_id = msg.id();
this.id_counter = this.id_counter.wrapping_add(1);
this.egress_queue.push_back(msg);
Expand Down
2 changes: 1 addition & 1 deletion msg-socket/src/req/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub enum ReqError {

pub enum Command {
Send {
message: ReqMessage,
message: Bytes,
response: oneshot::Sender<Result<Bytes, ReqError>>,
},
}
Expand Down
22 changes: 3 additions & 19 deletions msg-socket/src/req/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ use tokio::net::{lookup_host, ToSocketAddrs};
use tokio::sync::{mpsc, oneshot};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
use tracing::debug;

use msg_transport::Transport;
use msg_wire::{auth, reqrep};

use super::{Command, ReqDriver, ReqError, ReqOptions, DEFAULT_BUFFER_SIZE};
use crate::backoff::ExponentialBackoff;
use crate::ReqMessage;
use crate::{req::stats::SocketStats, req::SocketState};

/// The request socket.
Expand All @@ -27,7 +25,7 @@ pub struct ReqSocket<T: Transport> {
options: Arc<ReqOptions>,
/// Socket state. This is shared with the backend task.
state: Arc<SocketState>,
/// Optional message compressor.
/// Optional message compressor. This is shared with the backend task.
// NOTE: for now we're using dynamic dispatch, since using generics here
// complicates the API a lot. We can always change this later for perf reasons.
compressor: Option<Arc<dyn Compressor>>,
Expand Down Expand Up @@ -64,26 +62,11 @@ where
pub async fn request(&self, message: Bytes) -> Result<Bytes, ReqError> {
let (response_tx, response_rx) = oneshot::channel();

let mut msg = ReqMessage::new(message);

let len_before = msg.payload().len();
if len_before > self.options.min_compress_size {
if let Some(ref compressor) = self.compressor {
msg.compress(compressor.as_ref())?;

debug!(
"Compressed message from {} to {} bytes",
len_before,
msg.payload().len()
);
}
}

self.to_driver
.as_ref()
.ok_or(ReqError::SocketClosed)?
.send(Command::Send {
message: msg,
message,
response: response_tx,
})
.await
Expand Down Expand Up @@ -153,6 +136,7 @@ where
)),
flush_interval: self.options.flush_interval.map(tokio::time::interval),
should_flush: false,
compressor: self.compressor.clone(),
};

// Spawn the backend task
Expand Down
10 changes: 6 additions & 4 deletions msg-wire/src/compression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ pub trait Decompressor: Send + Sync + Unpin + 'static {
/// - If the decompression fails
pub fn try_decompress_payload(compression_type: u8, data: Bytes) -> Result<Bytes, io::Error> {
match CompressionType::try_from(compression_type) {
Ok(CompressionType::None) => Ok(data),
Ok(CompressionType::Gzip) => GzipDecompressor.decompress(data.as_ref()),
Ok(CompressionType::Zstd) => ZstdDecompressor.decompress(data.as_ref()),
Ok(CompressionType::Snappy) => SnappyDecompressor.decompress(data.as_ref()),
Ok(supported_compression_type) => match supported_compression_type {
CompressionType::None => Ok(data),
CompressionType::Gzip => GzipDecompressor.decompress(data.as_ref()),
CompressionType::Zstd => ZstdDecompressor.decompress(data.as_ref()),
CompressionType::Snappy => SnappyDecompressor.decompress(data.as_ref()),
},
Err(unsupported_compression_type) => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unsupported compression type: {unsupported_compression_type}"),
Expand Down
13 changes: 10 additions & 3 deletions msg/examples/reqrep_compression.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use bytes::Bytes;
use msg_socket::ReqOptions;
use msg_socket::{RepOptions, ReqOptions};
use msg_wire::compression::GzipCompressor;
use tokio_stream::StreamExt;

Expand All @@ -8,13 +8,20 @@ use msg::{tcp::Tcp, RepSocket, ReqSocket};
#[tokio::main]
async fn main() {
// Initialize the reply socket (server side) with a transport
let mut rep = RepSocket::new(Tcp::default());
// and a minimum compresion size of 0 bytes for responses
let mut rep =
RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0))
// Enable Gzip compression (compression level 6)
.with_compressor(GzipCompressor::new(6));
rep.bind("0.0.0.0:4444").await.unwrap();

// Initialize the request socket (client side) with a transport
// and a minimum compresion size of 0 bytes for requests
let mut req =
ReqSocket::with_options(Tcp::default(), ReqOptions::default().min_compress_size(0))
// Enable Gzip compression (compression level 6)
// Enable Gzip compression (compression level 6).
// The request and response sockets *don't have to*
// use the same compression algorithm or level.
.with_compressor(GzipCompressor::new(6));

req.connect("0.0.0.0:4444").await.unwrap();
Expand Down

0 comments on commit 8511a32

Please sign in to comment.