diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs index eeea602..e547ab8 100644 --- a/msg-socket/src/rep/driver.rs +++ b/msg-socket/src/rep/driver.rs @@ -19,7 +19,11 @@ use tracing::{debug, error, info, warn}; use crate::{rep::SocketState, AuthResult, Authenticator, PubError, RepOptions, Request}; use msg_transport::{PeerAddress, Transport}; -use msg_wire::{auth, compression::try_decompress_payload, reqrep}; +use msg_wire::{ + auth, + compression::{try_decompress_payload, Compressor}, + reqrep, +}; pub(crate) struct PeerState { pending_requests: FuturesUnordered, @@ -28,6 +32,7 @@ pub(crate) struct PeerState { egress_queue: VecDeque, state: Arc, should_flush: bool, + compressor: Option>, } pub(crate) struct RepDriver { @@ -44,6 +49,9 @@ pub(crate) struct RepDriver { pub(crate) to_socket: mpsc::Sender, /// Optional connection authenticator. pub(crate) auth: Option>, + /// Optional message compressor. This is shared with the socket to keep + /// the API consistent with other socket types (e.g. `PubSocket`) + pub(crate) compressor: Option>, /// A set of pending incoming connections, represented by [`Transport::Accept`]. pub(super) conn_tasks: FuturesUnordered, /// A joinset of authentication tasks. @@ -94,6 +102,7 @@ where egress_queue: VecDeque::with_capacity(128), state: Arc::clone(&this.state), should_flush: false, + compressor: this.compressor.clone(), }), ); } @@ -216,6 +225,7 @@ where egress_queue: VecDeque::with_capacity(128), state: Arc::clone(&self.state), should_flush: false, + compressor: self.compressor.clone(), }), ); } @@ -262,11 +272,31 @@ impl Stream for PeerState { } // Then we check for completed requests, and push them onto the egress queue. - if let Poll::Ready(Some(Some((id, payload)))) = + if let Poll::Ready(Some(Some((id, mut payload)))) = this.pending_requests.poll_next_unpin(cx) { - // TODO: compress the response payload. - let compression_type = 0; + let mut compression_type = 0; + let len_before = payload.len(); + if let Some(ref compressor) = this.compressor { + match compressor.compress(&payload) { + Ok(compressed) => { + payload = compressed; + compression_type = compressor.compression_type() as u8; + } + Err(e) => { + tracing::error!("Failed to compress message: {:?}", e); + continue; + } + } + + tracing::debug!( + "Compressed message {} from {} to {} bytes", + id, + len_before, + payload.len() + ) + } + let msg = reqrep::Message::new(id, compression_type, payload); this.egress_queue.push_back(msg); diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index 3f96e3f..54b3a52 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -25,10 +25,19 @@ pub enum PubError { Transport(#[from] Box), } -#[derive(Default)] pub struct RepOptions { /// The maximum number of concurrent clients. max_clients: Option, + min_compress_size: usize, +} + +impl Default for RepOptions { + fn default() -> Self { + Self { + max_clients: None, + min_compress_size: 8192, + } + } } impl RepOptions { @@ -37,6 +46,13 @@ impl RepOptions { self.max_clients = Some(max_clients); self } + + /// Sets the minimum payload size for compression. + /// If the payload is smaller than this value, it will not be compressed. + pub fn min_compress_size(mut self, min_compress_size: usize) -> Self { + self.min_compress_size = min_compress_size; + self + } } /// The request socket state, shared between the backend task and the socket. diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index d18d5ad..09ce53b 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -1,4 +1,5 @@ use futures::{stream::FuturesUnordered, Stream}; +use msg_wire::compression::Compressor; use std::{ io, net::SocketAddr, @@ -36,6 +37,8 @@ pub struct RepSocket { auth: Option>, /// The local address this socket is bound to. local_addr: Option, + /// Optional message compressor. + compressor: Option>, } impl RepSocket @@ -56,6 +59,7 @@ where options: Arc::new(options), state: Arc::new(SocketState::default()), auth: None, + compressor: None, } } @@ -65,6 +69,12 @@ where self } + /// Sets the message compressor for this socket. + pub fn with_compressor(mut self, compressor: C) -> Self { + self.compressor = Some(Arc::new(compressor)); + self + } + /// Binds the socket to the given address. This spawns the socket driver task. pub async fn bind(&mut self, addr: A) -> Result<(), PubError> { let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE); @@ -103,6 +113,7 @@ where auth: self.auth.take(), auth_tasks: JoinSet::new(), conn_tasks: FuturesUnordered::new(), + compressor: self.compressor.take(), }; tokio::spawn(backend); diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index 82e2b3e..a4c7689 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -4,6 +4,7 @@ use msg_transport::Transport; use rustc_hash::FxHashMap; use std::{ collections::VecDeque, + io, pin::Pin, sync::Arc, task::{ready, Context, Poll}, @@ -11,10 +12,13 @@ use std::{ use tokio::sync::{mpsc, oneshot}; use tokio_util::codec::Framed; -use crate::req::SocketState; +use crate::{req::SocketState, ReqMessage}; use super::{Command, ReqError, ReqOptions}; -use msg_wire::reqrep; +use msg_wire::{ + compression::{try_decompress_payload, Compressor}, + reqrep, +}; use std::time::Instant; use tokio::time::Interval; @@ -42,6 +46,9 @@ pub(crate) struct ReqDriver { pub(crate) flush_interval: Option, /// Whether or not the connection should be flushed pub(crate) should_flush: bool, + /// Optional message compressor. This is shared with the socket to keep + /// the API consistent with other socket types (e.g. `PubSocket`) + pub(crate) compressor: Option>, } /// A pending request that is waiting for a response. @@ -57,7 +64,22 @@ impl ReqDriver { if let Some(pending) = self.pending_requests.remove(&msg.id()) { let rtt = pending.start.elapsed().as_micros() as usize; let size = msg.size(); - let _ = pending.sender.send(Ok(msg.into_payload())); + let compression_type = msg.header().compression_type(); + let mut payload = msg.into_payload(); + + // decompress the response + match try_decompress_payload(compression_type, payload) { + Ok(decompressed) => payload = decompressed, + Err(e) => { + tracing::error!("Failed to decompress response payload: {:?}", e); + let _ = pending.sender.send(Err(ReqError::Wire(reqrep::Error::Io( + io::Error::new(io::ErrorKind::Other, "Failed to decompress response"), + )))); + return; + } + } + + let _ = pending.sender.send(Ok(payload)); // Update stats self.socket_state.stats.update_rtt(rtt); @@ -178,7 +200,25 @@ where Poll::Ready(Some(Command::Send { message, response })) => { // Queue the message for sending let start = std::time::Instant::now(); - let msg = message.into_wire(this.id_counter); + + let mut msg = ReqMessage::new(message); + + let len_before = msg.payload().len(); + if len_before > this.options.min_compress_size { + if let Some(ref compressor) = this.compressor { + if let Err(e) = msg.compress(compressor.as_ref()) { + tracing::error!("Failed to compress message: {:?}", e); + } + + tracing::debug!( + "Compressed message from {} to {} bytes", + len_before, + msg.payload().len() + ); + } + } + + let msg = msg.into_wire(this.id_counter); let msg_id = msg.id(); this.id_counter = this.id_counter.wrapping_add(1); this.egress_queue.push_back(msg); diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index c4e10d6..ed0a5cc 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -36,7 +36,7 @@ pub enum ReqError { pub enum Command { Send { - message: ReqMessage, + message: Bytes, response: oneshot::Sender>, }, } diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 9b7abb4..376a91d 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -7,14 +7,12 @@ use tokio::net::{lookup_host, ToSocketAddrs}; use tokio::sync::{mpsc, oneshot}; use tokio_stream::StreamExt; use tokio_util::codec::Framed; -use tracing::debug; use msg_transport::Transport; use msg_wire::{auth, reqrep}; use super::{Command, ReqDriver, ReqError, ReqOptions, DEFAULT_BUFFER_SIZE}; use crate::backoff::ExponentialBackoff; -use crate::ReqMessage; use crate::{req::stats::SocketStats, req::SocketState}; /// The request socket. @@ -27,7 +25,7 @@ pub struct ReqSocket { options: Arc, /// Socket state. This is shared with the backend task. state: Arc, - /// Optional message compressor. + /// Optional message compressor. This is shared with the backend task. // NOTE: for now we're using dynamic dispatch, since using generics here // complicates the API a lot. We can always change this later for perf reasons. compressor: Option>, @@ -64,26 +62,11 @@ where pub async fn request(&self, message: Bytes) -> Result { let (response_tx, response_rx) = oneshot::channel(); - let mut msg = ReqMessage::new(message); - - let len_before = msg.payload().len(); - if len_before > self.options.min_compress_size { - if let Some(ref compressor) = self.compressor { - msg.compress(compressor.as_ref())?; - - debug!( - "Compressed message from {} to {} bytes", - len_before, - msg.payload().len() - ); - } - } - self.to_driver .as_ref() .ok_or(ReqError::SocketClosed)? .send(Command::Send { - message: msg, + message, response: response_tx, }) .await @@ -153,6 +136,7 @@ where )), flush_interval: self.options.flush_interval.map(tokio::time::interval), should_flush: false, + compressor: self.compressor.clone(), }; // Spawn the backend task diff --git a/msg-wire/src/compression/mod.rs b/msg-wire/src/compression/mod.rs index 7c9120c..9bfcb4f 100644 --- a/msg-wire/src/compression/mod.rs +++ b/msg-wire/src/compression/mod.rs @@ -58,10 +58,12 @@ pub trait Decompressor: Send + Sync + Unpin + 'static { /// - If the decompression fails pub fn try_decompress_payload(compression_type: u8, data: Bytes) -> Result { match CompressionType::try_from(compression_type) { - Ok(CompressionType::None) => Ok(data), - Ok(CompressionType::Gzip) => GzipDecompressor.decompress(data.as_ref()), - Ok(CompressionType::Zstd) => ZstdDecompressor.decompress(data.as_ref()), - Ok(CompressionType::Snappy) => SnappyDecompressor.decompress(data.as_ref()), + Ok(supported_compression_type) => match supported_compression_type { + CompressionType::None => Ok(data), + CompressionType::Gzip => GzipDecompressor.decompress(data.as_ref()), + CompressionType::Zstd => ZstdDecompressor.decompress(data.as_ref()), + CompressionType::Snappy => SnappyDecompressor.decompress(data.as_ref()), + }, Err(unsupported_compression_type) => Err(io::Error::new( io::ErrorKind::InvalidData, format!("unsupported compression type: {unsupported_compression_type}"), diff --git a/msg/examples/reqrep_compression.rs b/msg/examples/reqrep_compression.rs index 5ac6882..caabc8c 100644 --- a/msg/examples/reqrep_compression.rs +++ b/msg/examples/reqrep_compression.rs @@ -1,5 +1,5 @@ use bytes::Bytes; -use msg_socket::ReqOptions; +use msg_socket::{RepOptions, ReqOptions}; use msg_wire::compression::GzipCompressor; use tokio_stream::StreamExt; @@ -8,13 +8,20 @@ use msg::{tcp::Tcp, RepSocket, ReqSocket}; #[tokio::main] async fn main() { // Initialize the reply socket (server side) with a transport - let mut rep = RepSocket::new(Tcp::default()); + // and a minimum compresion size of 0 bytes for responses + let mut rep = + RepSocket::with_options(Tcp::default(), RepOptions::default().min_compress_size(0)) + // Enable Gzip compression (compression level 6) + .with_compressor(GzipCompressor::new(6)); rep.bind("0.0.0.0:4444").await.unwrap(); // Initialize the request socket (client side) with a transport + // and a minimum compresion size of 0 bytes for requests let mut req = ReqSocket::with_options(Tcp::default(), ReqOptions::default().min_compress_size(0)) - // Enable Gzip compression (compression level 6) + // Enable Gzip compression (compression level 6). + // The request and response sockets *don't have to* + // use the same compression algorithm or level. .with_compressor(GzipCompressor::new(6)); req.connect("0.0.0.0:4444").await.unwrap();