diff --git a/msg-common/Cargo.toml b/msg-common/Cargo.toml index b138826..e3458b8 100644 --- a/msg-common/Cargo.toml +++ b/msg-common/Cargo.toml @@ -10,8 +10,6 @@ license.workspace = true homepage.workspace = true repository.workspace = true -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] futures.workspace = true tokio.workspace = true diff --git a/msg-common/src/channel.rs b/msg-common/src/channel.rs new file mode 100644 index 0000000..03ccc74 --- /dev/null +++ b/msg-common/src/channel.rs @@ -0,0 +1,155 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{Sink, SinkExt, Stream}; +use tokio::sync::mpsc::{ + self, + error::{TryRecvError, TrySendError}, + Receiver, +}; +use tokio_util::sync::{PollSendError, PollSender}; + +/// A bounded, bi-directional channel for sending and receiving messages. +/// Relies on Tokio's [`mpsc`] channel. +/// +/// Channel also implements the [`Stream`] and [`Sink`] traits for convenience. +pub struct Channel { + tx: PollSender, + rx: Receiver, +} + +/// Creates a new channel with the given buffer size. This will return a tuple of +/// 2 [`Channel`]s, both of which can be used to send and receive messages. +/// +/// It works with 2 generic types, `S` and `R`, which represent the types of +/// messages that can be sent and received, respectively. The first channel in +/// the tuple can be used to send messages of type `S` and receive messages of +/// type `R`. The second channel can be used to send messages of type `R` and +/// receive messages of type `S`. +pub fn channel(tx_buffer: usize, rx_buffer: usize) -> (Channel, Channel) +where + S: Send, + R: Send, +{ + let (tx1, rx1) = mpsc::channel(tx_buffer); + let (tx2, rx2) = mpsc::channel(rx_buffer); + + let tx1 = PollSender::new(tx1); + let tx2 = PollSender::new(tx2); + + (Channel { tx: tx1, rx: rx2 }, Channel { tx: tx2, rx: rx1 }) +} + +impl Channel { + /// Sends a value, waiting until there is capacity. + /// + /// A successful send occurs when it is determined that the other end of the + /// channel has not hung up already. An unsuccessful send would be one where + /// the corresponding receiver has already been closed. Note that a return + /// value of `Err` means that the data will never be received, but a return + /// value of `Ok` does not mean that the data will be received. It is + /// possible for the corresponding receiver to hang up immediately after + /// this function returns `Ok`. + pub async fn send(&mut self, msg: S) -> Result<(), PollSendError> { + self.tx.send(msg).await + } + + /// Attempts to immediately send a message on this [`Sender`] + /// + /// This method differs from [`send`] by returning immediately if the channel's + /// buffer is full or no receiver is waiting to acquire some data. Compared + /// with [`send`], this function has two failure cases instead of one (one for + /// disconnection, one for a full buffer). + pub fn try_send(&mut self, msg: S) -> Result<(), TrySendError> { + if let Some(tx) = self.tx.get_ref() { + tx.try_send(msg) + } else { + Err(TrySendError::Closed(msg)) + } + } + + /// Receives the next value for this receiver. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. 
+ /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will sleep until a message is sent or + /// the channel is closed. Note that if [`close`] is called, but there are + /// still outstanding [`Permits`] from before it was closed, the channel is + /// not considered closed by `recv` until the permits are released. + pub async fn recv(&mut self) -> Option { + self.rx.recv().await + } + + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`Empty`](TryRecvError::Empty) error if the channel is currently + /// empty, but there are still outstanding [senders] or [permits]. + /// + /// This method returns the [`Disconnected`](TryRecvError::Disconnected) error if the channel is + /// currently empty, and there are no outstanding [senders] or [permits]. + /// + /// Unlike the [`poll_recv`](Self::poll_recv) method, this method will never return an + /// [`Empty`](TryRecvError::Empty) error spuriously. + pub fn try_recv(&mut self) -> Result { + self.rx.try_recv() + } + + /// Polls to receive the next message on this channel. + /// + /// This method returns: + /// + /// * `Poll::Pending` if no messages are available but the channel is not closed, or if a + /// spurious failure happens. + /// * `Poll::Ready(Some(message))` if a message is available. + /// * `Poll::Ready(None)` if the channel has been closed and all messages sent before it was + /// closed have been received. + /// + /// When the method returns `Poll::Pending`, the `Waker` in the provided + /// `Context` is scheduled to receive a wakeup when a message is sent on any + /// receiver, or when the channel is closed. Note that on multiple calls to + /// `poll_recv`, only the `Waker` from the `Context` passed to the most + /// recent call is scheduled to receive a wakeup. + /// + /// If this method returns `Poll::Pending` due to a spurious failure, then + /// the `Waker` will be notified when the situation causing the spurious + /// failure has been resolved. Note that receiving such a wakeup does not + /// guarantee that the next call will succeed — it could fail with another + /// spurious failure. + pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } +} + +impl Stream for Channel { + type Item = R; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } +} + +impl Sink for Channel { + type Error = PollSendError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.tx.poll_ready_unpin(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: S) -> Result<(), Self::Error> { + self.tx.start_send_unpin(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.tx.poll_flush_unpin(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.tx.poll_close_unpin(cx) + } +} diff --git a/msg-common/src/lib.rs b/msg-common/src/lib.rs index f7e4b94..b9b0db1 100644 --- a/msg-common/src/lib.rs +++ b/msg-common/src/lib.rs @@ -1,30 +1,23 @@ +//! Common utilities and types for msg-rs. 
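The new `msg-common/src/channel.rs` above pairs two Tokio mpsc channels into a single bounded, bi-directional `Channel` that also implements `Stream` and `Sink`. A minimal usage sketch of that API; the generic parameters are stripped in the hunks above, so `Channel<S, R>` (send `S`, receive `R`) is inferred from the doc comments:

```rust
use msg_common::{channel, Channel};

#[tokio::main]
async fn main() {
    // The left half sends `String` and receives `u64`; the right half is the mirror image.
    let (mut left, mut right): (Channel<String, u64>, Channel<u64, String>) = channel(8, 8);

    tokio::spawn(async move {
        // Echo the byte length of every received message back to the other side.
        while let Some(msg) = right.recv().await {
            let _ = right.send(msg.len() as u64).await;
        }
    });

    left.send("hello".to_string()).await.unwrap();
    assert_eq!(left.recv().await, Some(5));
}
```

Because both halves implement `Stream` and `Sink`, the same exchange can also be driven through `StreamExt`/`SinkExt` where that is more convenient.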
+ #![doc(issue_tracker_base_url = "https://github.com/chainbound/msg-rs/issues/")] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] #![cfg_attr(not(test), warn(unused_crate_dependencies))] -use std::{ - pin::Pin, - task::{Context, Poll}, - time::SystemTime, -}; +use std::time::SystemTime; + +use futures::future::BoxFuture; -use futures::{future::BoxFuture, Sink, SinkExt, Stream}; -use tokio::sync::mpsc::{ - self, - error::{TryRecvError, TrySendError}, - Receiver, -}; -use tokio_util::sync::{PollSendError, PollSender}; +mod channel; +pub use channel::{channel, Channel}; -pub mod task; +mod task; +pub use task::JoinMap; /// Returns the current UNIX timestamp in microseconds. #[inline] pub fn unix_micros() -> u64 { - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_micros() as u64 + SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_micros() as u64 } /// Wraps the given error in a boxed future. @@ -40,146 +33,3 @@ pub mod constants { pub const MiB: u32 = 1024 * KiB; pub const GiB: u32 = 1024 * MiB; } - -/// A bounded, bi-directional channel for sending and receiving messages. -/// Relies on Tokio's [`mpsc`] channel. -/// -/// Channel also implements the [`Stream`] and [`Sink`] traits for convenience. -pub struct Channel { - tx: PollSender, - rx: Receiver, -} - -/// Creates a new channel with the given buffer size. This will return a tuple of -/// 2 [`Channel`]s, both of which can be used to send and receive messages. -/// -/// It works with 2 generic types, `S` and `R`, which represent the types of -/// messages that can be sent and received, respectively. The first channel in -/// the tuple can be used to send messages of type `S` and receive messages of -/// type `R`. The second channel can be used to send messages of type `R` and -/// receive messages of type `S`. -pub fn channel(tx_buffer: usize, rx_buffer: usize) -> (Channel, Channel) -where - S: Send, - R: Send, -{ - let (tx1, rx1) = mpsc::channel(tx_buffer); - let (tx2, rx2) = mpsc::channel(rx_buffer); - - let tx1 = PollSender::new(tx1); - let tx2 = PollSender::new(tx2); - - (Channel { tx: tx1, rx: rx2 }, Channel { tx: tx2, rx: rx1 }) -} - -impl Channel { - /// Sends a value, waiting until there is capacity. - /// - /// A successful send occurs when it is determined that the other end of the - /// channel has not hung up already. An unsuccessful send would be one where - /// the corresponding receiver has already been closed. Note that a return - /// value of `Err` means that the data will never be received, but a return - /// value of `Ok` does not mean that the data will be received. It is - /// possible for the corresponding receiver to hang up immediately after - /// this function returns `Ok`. - pub async fn send(&mut self, msg: S) -> Result<(), PollSendError> { - self.tx.send(msg).await - } - - /// Attempts to immediately send a message on this [`Sender`] - /// - /// This method differs from [`send`] by returning immediately if the channel's - /// buffer is full or no receiver is waiting to acquire some data. Compared - /// with [`send`], this function has two failure cases instead of one (one for - /// disconnection, one for a full buffer). - pub fn try_send(&mut self, msg: S) -> Result<(), TrySendError> { - if let Some(tx) = self.tx.get_ref() { - tx.try_send(msg) - } else { - Err(TrySendError::Closed(msg)) - } - } - - /// Receives the next value for this receiver. 
- /// - /// This method returns `None` if the channel has been closed and there are - /// no remaining messages in the channel's buffer. This indicates that no - /// further values can ever be received from this `Receiver`. The channel is - /// closed when all senders have been dropped, or when [`close`] is called. - /// - /// If there are no messages in the channel's buffer, but the channel has - /// not yet been closed, this method will sleep until a message is sent or - /// the channel is closed. Note that if [`close`] is called, but there are - /// still outstanding [`Permits`] from before it was closed, the channel is - /// not considered closed by `recv` until the permits are released. - pub async fn recv(&mut self) -> Option { - self.rx.recv().await - } - - /// Tries to receive the next value for this receiver. - /// - /// This method returns the [`Empty`](TryRecvError::Empty) error if the channel is currently - /// empty, but there are still outstanding [senders] or [permits]. - /// - /// This method returns the [`Disconnected`](TryRecvError::Disconnected) error if the channel is - /// currently empty, and there are no outstanding [senders] or [permits]. - /// - /// Unlike the [`poll_recv`](Self::poll_recv) method, this method will never return an - /// [`Empty`](TryRecvError::Empty) error spuriously. - pub fn try_recv(&mut self) -> Result { - self.rx.try_recv() - } - - /// Polls to receive the next message on this channel. - /// - /// This method returns: - /// - /// * `Poll::Pending` if no messages are available but the channel is not - /// closed, or if a spurious failure happens. - /// * `Poll::Ready(Some(message))` if a message is available. - /// * `Poll::Ready(None)` if the channel has been closed and all messages - /// sent before it was closed have been received. - /// - /// When the method returns `Poll::Pending`, the `Waker` in the provided - /// `Context` is scheduled to receive a wakeup when a message is sent on any - /// receiver, or when the channel is closed. Note that on multiple calls to - /// `poll_recv`, only the `Waker` from the `Context` passed to the most - /// recent call is scheduled to receive a wakeup. - /// - /// If this method returns `Poll::Pending` due to a spurious failure, then - /// the `Waker` will be notified when the situation causing the spurious - /// failure has been resolved. Note that receiving such a wakeup does not - /// guarantee that the next call will succeed — it could fail with another - /// spurious failure. - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_recv(cx) - } -} - -impl Stream for Channel { - type Item = R; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_recv(cx) - } -} - -impl Sink for Channel { - type Error = PollSendError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_ready_unpin(cx) - } - - fn start_send(mut self: Pin<&mut Self>, item: S) -> Result<(), Self::Error> { - self.tx.start_send_unpin(item) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_flush_unpin(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_close_unpin(cx) - } -} diff --git a/msg-common/src/task.rs b/msg-common/src/task.rs index b5d62fe..223c554 100644 --- a/msg-common/src/task.rs +++ b/msg-common/src/task.rs @@ -6,8 +6,8 @@ use std::{ use tokio::task::{JoinError, JoinSet}; /// A collection of keyed tasks spawned on a Tokio runtime. 
-/// Hacky implementation of a join set that allows for a key to be associated with each task by having -/// the task return a tuple of (key, value). +/// Hacky implementation of a join set that allows for a key to be associated with each task by +/// having the task return a tuple of (key, value). #[derive(Debug, Default)] pub struct JoinMap { keys: HashSet, @@ -17,10 +17,7 @@ pub struct JoinMap { impl JoinMap { /// Create a new `JoinSet`. pub fn new() -> Self { - Self { - keys: HashSet::new(), - joinset: JoinSet::new(), - } + Self { keys: HashSet::new(), joinset: JoinSet::new() } } /// Returns the number of tasks currently in the `JoinSet`. @@ -71,7 +68,8 @@ where /// Polls for one of the tasks in the set to complete. /// - /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the set. + /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the + /// set. /// /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to @@ -83,11 +81,11 @@ where /// This function returns: /// /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is - /// available right now. - /// * `Poll::Ready(Some(Ok(value)))` if one of the tasks in this `JoinSet` has completed. - /// The `value` is the return value of one of the tasks that completed. + /// available right now. + /// * `Poll::Ready(Some(Ok(value)))` if one of the tasks in this `JoinSet` has completed. The + /// `value` is the return value of one of the tasks that completed. /// * `Poll::Ready(Some(Err(err)))` if one of the tasks in this `JoinSet` has panicked or been - /// aborted. The `err` is the `JoinError` from the panicked/aborted task. + /// aborted. The `err` is the `JoinError` from the panicked/aborted task. /// * `Poll::Ready(None)` if the `JoinSet` is empty. /// /// Note that this method may return `Poll::Pending` even if one of the tasks has completed. diff --git a/msg-sim/src/dummynet.rs b/msg-sim/src/dummynet.rs index f5ea4b4..b9a3e8c 100644 --- a/msg-sim/src/dummynet.rs +++ b/msg-sim/src/dummynet.rs @@ -21,12 +21,7 @@ pub struct Pipe { impl Pipe { /// Creates a new pipe with the given ID. The ID must be unique. pub fn new(id: usize) -> Self { - Self { - id, - bandwidth: None, - delay: None, - plr: None, - } + Self { id, bandwidth: None, delay: None, plr: None } } /// Set the bandwidth cap of the pipe in Kbps. @@ -54,10 +49,7 @@ impl Pipe { fn build_cmd(&self) -> Command { let mut cmd = Command::new("sudo"); - cmd.arg("dnctl") - .arg("pipe") - .arg(self.id.to_string()) - .arg("config"); + cmd.arg("dnctl").arg("pipe").arg(self.id.to_string()).arg("config"); if let Some(bandwidth) = self.bandwidth { let bw = format!("{}Kbit/s", bandwidth); @@ -87,10 +79,7 @@ impl Pipe { fn destroy_cmd(&self) -> Command { let mut cmd = Command::new("sudo"); - cmd.arg("dnctl") - .arg("pipe") - .arg("delete") - .arg(self.id.to_string()); + cmd.arg("dnctl").arg("pipe").arg("delete").arg(self.id.to_string()); cmd } @@ -169,9 +158,7 @@ impl PacketFilter { /// Destroys the packet filter by executing the correct shell commands. 
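Stepping back to `JoinMap` in `msg-common/src/task.rs` above: its doc comment explains that each spawned task returns a `(key, value)` tuple so a completed result can be matched back to the key it was spawned under. A self-contained sketch of that pattern using a plain `tokio::task::JoinSet` (the `JoinMap` spawn API itself is not visible in this diff, so the underlying primitive is used directly):

```rust
use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    // Each task resolves to (key, value), mirroring the JoinMap convention.
    let mut tasks: JoinSet<(String, usize)> = JoinSet::new();
    for name in ["alice", "bob", "carol"] {
        let key = name.to_string();
        tasks.spawn(async move {
            let len = key.len();
            (key, len)
        });
    }

    // Completions arrive in whatever order the tasks finish.
    while let Some(Ok((key, value))) = tasks.join_next().await {
        println!("task {key} finished with {value}");
    }
}
```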
pub fn destroy(self) -> io::Result<()> { - let status = Command::new("sudo") - .args(["pfctl", "-f", "/etc/pf.conf"]) - .status()?; + let status = Command::new("sudo").args(["pfctl", "-f", "/etc/pf.conf"]).status()?; assert_status(status, "Failed to flush packet filter")?; @@ -181,21 +168,15 @@ impl PacketFilter { // Remove the loopback alias let status = Command::new("sudo") - .args([ - "ifconfig", - &self.loopback, - "-alias", - &self.endpoint.unwrap().to_string(), - ]) + .args(["ifconfig", &self.loopback, "-alias", &self.endpoint.unwrap().to_string()]) .status()?; assert_status(status, "Failed to remove the loopback alias")?; // Reset the MTU of the loopback interface - let status = Command::new("sudo") - .args(["ifconfig", &self.loopback, "mtu", "16384"]) - .status()?; + let status = + Command::new("sudo").args(["ifconfig", &self.loopback, "mtu", "16384"]).status()?; assert_status(status, "Failed to reset loopback MTU back to 16384")?; @@ -207,19 +188,13 @@ impl PacketFilter { fn create_loopback_alias(&self) -> io::Result<()> { let status = Command::new("sudo") - .args([ - "ifconfig", - &self.loopback, - "alias", - &self.endpoint.unwrap().to_string(), - ]) + .args(["ifconfig", &self.loopback, "alias", &self.endpoint.unwrap().to_string()]) .status()?; assert_status(status, "Failed to create loopback alias")?; - let status = Command::new("sudo") - .args(["ifconfig", &self.loopback, "mtu", "1500"]) - .status()?; + let status = + Command::new("sudo").args(["ifconfig", &self.loopback, "mtu", "1500"]).status()?; assert_status(status, "Failed to set loopback MTU to 1500")?; @@ -230,31 +205,18 @@ impl PacketFilter { /// `(cat /etc/pf.conf && echo "dummynet-anchor \"msg-sim\"" && /// echo "anchor \"msg-sim\"") | sudo pfctl -f -` fn load_pf_config(&self) -> io::Result<()> { - let echo_cmd = format!( - "dummynet-anchor \"{}\"\nanchor \"{}\"", - self.anchor, self.anchor - ); + let echo_cmd = format!("dummynet-anchor \"{}\"\nanchor \"{}\"", self.anchor, self.anchor); - let mut cat = Command::new("cat") - .arg("/etc/pf.conf") - .stdout(Stdio::piped()) - .spawn()?; + let mut cat = Command::new("cat").arg("/etc/pf.conf").stdout(Stdio::piped()).spawn()?; let cat_stdout = cat.stdout.take().unwrap(); - let mut echo = Command::new("echo") - .arg(echo_cmd) - .stdout(Stdio::piped()) - .spawn()?; + let mut echo = Command::new("echo").arg(echo_cmd).stdout(Stdio::piped()).spawn()?; let echo_stdout = echo.stdout.take().unwrap(); - let mut pfctl = Command::new("sudo") - .arg("pfctl") - .arg("-f") - .arg("-") - .stdin(Stdio::piped()) - .spawn()?; + let mut pfctl = + Command::new("sudo").arg("pfctl").arg("-f").arg("-").stdin(Stdio::piped()).spawn()?; let pfctl_stdin = pfctl.stdin.as_mut().unwrap(); io::copy(&mut cat_stdout.chain(echo_stdout), pfctl_stdin)?; @@ -277,10 +239,7 @@ impl PacketFilter { let echo_command = format!("dummynet in from any to {} pipe {}", endpoint, pipe_id); // Set up the echo command - let mut echo = Command::new("echo") - .arg(echo_command) - .stdout(Stdio::piped()) - .spawn()?; + let mut echo = Command::new("echo").arg(echo_command).stdout(Stdio::piped()).spawn()?; if let Some(echo_stdout) = echo.stdout.take() { // Set up the pfctl command @@ -346,10 +305,7 @@ mod tests { let cmd = Pipe::new(1).bandwidth(10).delay(100).plr(0.1).build_cmd(); let cmd_str = cmd_to_string(&cmd); - assert_eq!( - cmd_str, - "sudo dnctl pipe 1 config bw 10Kbit/s delay 100 plr 0.1" - ); + assert_eq!(cmd_str, "sudo dnctl pipe 1 config bw 10Kbit/s delay 100 plr 0.1"); let cmd = 
Pipe::new(2).delay(1000).plr(10.0).build_cmd(); let cmd_str = cmd_to_string(&cmd); @@ -378,9 +334,7 @@ mod tests { let pipe = Pipe::new(3).bandwidth(100).delay(300); let endpoint = "127.0.0.2".parse().unwrap(); - let pf = PacketFilter::new(pipe) - .endpoint(endpoint) - .anchor("msg-sim-test"); + let pf = PacketFilter::new(pipe).endpoint(endpoint).anchor("msg-sim-test"); pf.enable().unwrap(); diff --git a/msg-sim/src/lib.rs b/msg-sim/src/lib.rs index 0e9a6e8..0e1d3d1 100644 --- a/msg-sim/src/lib.rs +++ b/msg-sim/src/lib.rs @@ -34,10 +34,7 @@ pub struct Simulator { impl Simulator { pub fn new() -> Self { - Self { - active_sims: HashMap::new(), - sim_id: 1, - } + Self { active_sims: HashMap::new(), sim_id: 1 } } /// Starts a new simulation on the given endpoint according to the config. @@ -108,9 +105,8 @@ impl Simulation { pipe = pipe.plr(plr); } - let mut pf = PacketFilter::new(pipe) - .anchor(format!("msg-sim-{}", self.id)) - .endpoint(self.endpoint); + let mut pf = + PacketFilter::new(pipe).anchor(format!("msg-sim-{}", self.id)).endpoint(self.endpoint); if !self.config.protocols.is_empty() { pf = pf.protocols(self.config.protocols.clone()); diff --git a/msg-socket/src/connection/backoff.rs b/msg-socket/src/connection/backoff.rs index 4dde9a7..77d9e3a 100644 --- a/msg-socket/src/connection/backoff.rs +++ b/msg-socket/src/connection/backoff.rs @@ -24,12 +24,7 @@ pub struct ExponentialBackoff { impl ExponentialBackoff { pub fn new(initial: Duration, max_retries: usize) -> Self { - Self { - retry_count: 0, - max_retries, - backoff: initial, - timeout: None, - } + Self { retry_count: 0, max_retries, backoff: initial, timeout: None } } /// (Re)-set the timeout to the current backoff duration. diff --git a/msg-socket/src/pub/driver.rs b/msg-socket/src/pub/driver.rs index e6a5c10..8f8c8ea 100644 --- a/msg-socket/src/pub/driver.rs +++ b/msg-socket/src/pub/driver.rs @@ -33,7 +33,8 @@ pub(crate) struct PubDriver, A: Address> { pub(super) conn_tasks: FuturesUnordered, /// A joinset of authentication tasks. pub(super) auth_tasks: JoinSet, PubError>>, - /// The receiver end of the message broadcast channel. The sender half is stored by [`PubSocket`](super::PubSocket). + /// The receiver end of the message broadcast channel. The sender half is stored by + /// [`PubSocket`](super::PubSocket). pub(super) from_socket_bcast: broadcast::Receiver, } @@ -48,7 +49,8 @@ where let this = self.get_mut(); loop { - // First, poll the joinset of authentication tasks. If a new connection has been handled we spawn a new session for it. + // First, poll the joinset of authentication tasks. If a new connection has been handled + // we spawn a new session for it. if let Poll::Ready(Some(Ok(auth))) = this.auth_tasks.poll_join_next(cx) { match auth { Ok(auth) => { @@ -75,7 +77,7 @@ where this.id_counter = this.id_counter.wrapping_add(1); } Err(e) => { - error!("Error authenticating client: {:?}", e); + error!(err = ?e, "Error authenticating client"); this.state.stats.decrement_active_clients(); } } @@ -83,20 +85,21 @@ where continue; } - // Then poll the incoming connection tasks. If a new connection has been accepted, spawn a new authentication task for it. + // Then poll the incoming connection tasks. If a new connection has been accepted, spawn + // a new authentication task for it. 
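A note on the authentication tasks polled in this loop: they call into the socket's optional `Authenticator`, whose only visible method in this diff is `fn authenticate(&self, id: &Bytes) -> bool` (see the test `Auth` impls further down). A token-checking implementation is therefore tiny; the struct and field names below are illustrative:

```rust
use bytes::Bytes;
use msg_socket::Authenticator;

struct TokenAuth {
    expected: Bytes,
}

impl Authenticator for TokenAuth {
    fn authenticate(&self, id: &Bytes) -> bool {
        // Accept only connections presenting the expected identity token.
        id == &self.expected
    }
}
```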
if let Poll::Ready(Some(incoming)) = this.conn_tasks.poll_next_unpin(cx) { match incoming { Ok(io) => { if let Err(e) = this.on_incoming(io) { - error!("Error accepting incoming connection: {:?}", e); + error!(err = ?e, "Error accepting incoming connection"); this.state.stats.decrement_active_clients(); } } Err(e) => { - error!("Error accepting incoming connection: {:?}", e); + error!(err = ?e, "Error accepting incoming connection"); - // Active clients have already been incremented in the initial call to `poll_accept`, - // so we need to decrement them here. + // Active clients have already been incremented in the initial call to + // `poll_accept`, so we need to decrement them here. this.state.stats.decrement_active_clients(); } } @@ -104,21 +107,18 @@ where continue; } - // Finally, poll the transport for new incoming connection futures and push them to the incoming connection tasks. + // Finally, poll the transport for new incoming connection futures and push them to the + // incoming connection tasks. if let Poll::Ready(accept) = Pin::new(&mut this.transport).poll_accept(cx) { if let Some(max) = this.options.max_clients { if this.state.stats.active_clients() >= max { - warn!( - "Max connections reached ({}), rejecting new incoming connection", - max - ); - + warn!("Max connections reached ({}), rejecting incoming connection", max); continue; } } - // Increment the active clients counter. If the authentication fails, this counter - // will be decremented. + // Increment the active clients counter. If the authentication fails, + // this counter will be decremented. this.state.stats.increment_active_clients(); this.conn_tasks.push(accept); @@ -179,11 +179,7 @@ where conn.send(auth::Message::Ack).await?; conn.flush().await?; - Ok(AuthResult { - id, - addr, - stream: conn.into_inner(), - }) + Ok(AuthResult { id, addr, stream: conn.into_inner() }) }); } else { let mut framed = Framed::new(io, pubsub::Codec::new()); @@ -204,10 +200,7 @@ where tokio::spawn(session); self.id_counter = self.id_counter.wrapping_add(1); - debug!( - "New connection from {:?}, session ID {}", - addr, self.id_counter - ); + debug!("New connection from {:?}, session ID {}", addr, self.id_counter); } Ok(()) diff --git a/msg-socket/src/pub/mod.rs b/msg-socket/src/pub/mod.rs index 3360bf9..0c16925 100644 --- a/msg-socket/src/pub/mod.rs +++ b/msg-socket/src/pub/mod.rs @@ -91,8 +91,8 @@ impl PubOptions { self } - /// Sets the minimum payload size in bytes for compression to be used. If the payload is smaller than - /// this threshold, it will not be compressed. + /// Sets the minimum payload size in bytes for compression to be used. If the payload is smaller + /// than this threshold, it will not be compressed. 
pub fn min_compress_size(mut self, min_compress_size: usize) -> Self { self.min_compress_size = min_compress_size; self @@ -170,6 +170,7 @@ mod tests { use futures::StreamExt; use msg_transport::{quic::Quic, tcp::Tcp}; use msg_wire::compression::GzipCompressor; + use tracing::info; use crate::{Authenticator, SubOptions, SubSocket}; @@ -179,7 +180,7 @@ mod tests { impl Authenticator for Auth { fn authenticate(&self, id: &Bytes) -> bool { - tracing::info!("Auth request from: {:?}", id); + info!("Auth request from: {:?}", id); true } } @@ -199,13 +200,10 @@ mod tests { sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; - pub_socket - .publish("HELLO".to_string(), "WORLD".into()) - .await - .unwrap(); + pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap(); let msg = sub_socket.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!("WORLD", msg.payload()); } @@ -228,13 +226,10 @@ mod tests { sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; - pub_socket - .publish("HELLO".to_string(), "WORLD".into()) - .await - .unwrap(); + pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap(); let msg = sub_socket.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!("WORLD", msg.payload()); } @@ -257,13 +252,10 @@ mod tests { sub_socket.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; - pub_socket - .publish("HELLO".to_string(), "WORLD".into()) - .await - .unwrap(); + pub_socket.publish("HELLO".to_string(), "WORLD".into()).await.unwrap(); let msg = sub_socket.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!("WORLD", msg.payload()); } @@ -287,18 +279,15 @@ mod tests { sub2.subscribe("HELLO".to_string()).await.unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; - pub_socket - .publish("HELLO".to_string(), Bytes::from("WORLD")) - .await - .unwrap(); + pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap(); let msg = sub1.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!("WORLD", msg.payload()); let msg = sub2.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!("WORLD", msg.payload()); } @@ -324,18 +313,15 @@ mod tests { let original_msg = Bytes::from("WOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOORLD"); - pub_socket - .publish("HELLO".to_string(), original_msg.clone()) - .await - .unwrap(); + pub_socket.publish("HELLO".to_string(), original_msg.clone()).await.unwrap(); let msg = sub1.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!(original_msg, msg.payload()); let msg = sub2.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!(original_msg, msg.payload()); } @@ -356,13 +342,10 @@ mod tests { 
pub_socket.bind("0.0.0.0:6662").await.unwrap(); tokio::time::sleep(Duration::from_millis(2000)).await; - pub_socket - .publish("HELLO".to_string(), Bytes::from("WORLD")) - .await - .unwrap(); + pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap(); let msg = sub_socket.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!("WORLD", msg.payload()); } @@ -383,13 +366,10 @@ mod tests { pub_socket.bind("0.0.0.0:6662").await.unwrap(); tokio::time::sleep(Duration::from_millis(2000)).await; - pub_socket - .publish("HELLO".to_string(), Bytes::from("WORLD")) - .await - .unwrap(); + pub_socket.publish("HELLO".to_string(), Bytes::from("WORLD")).await.unwrap(); let msg = sub_socket.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!("HELLO", msg.topic()); assert_eq!("WORLD", msg.payload()); } diff --git a/msg-socket/src/pub/session.rs b/msg-socket/src/pub/session.rs index 8b9d260..f000276 100644 --- a/msg-socket/src/pub/session.rs +++ b/msg-socket/src/pub/session.rs @@ -1,10 +1,11 @@ -use futures::{Future, SinkExt, StreamExt}; use std::{ borrow::Cow, pin::Pin, sync::Arc, task::{Context, Poll}, }; + +use futures::{Future, SinkExt, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio_stream::wrappers::BroadcastStream; use tokio_util::codec::Framed; @@ -22,8 +23,7 @@ pub(super) struct SubscriberSession { pub(super) from_socket_bcast: BroadcastStream, /// Messages queued to be sent on the connection pub(super) pending_egress: Option, - /// Session request sender. - // to_driver: mpsc::Sender, + /// The socket state, shared between the backend task and the socket. pub(super) state: Arc, /// The framed connection. pub(super) conn: Framed, @@ -40,19 +40,13 @@ impl SubscriberSession { fn on_outgoing(&mut self, msg: PubMessage) { // Check if the message matches the topic filter if self.topic_filter.contains(msg.topic()) { - trace!( - topic = msg.topic(), - "Message matches topic filter, adding to egress queue" - ); + trace!(topic = msg.topic(), "Message matches topic filter, adding to egress queue"); // Generate the wire message and increment the sequence number self.pending_egress = Some(msg.into_wire(self.seq)); self.seq = self.seq.wrapping_add(1); } else { - trace!( - topic = msg.topic(), - "Message does not match topic filter, discarding" - ); + trace!(topic = msg.topic(), "Message does not match topic filter, discarding"); } } @@ -69,10 +63,7 @@ impl SubscriberSession { self.topic_filter.remove(&topic) } ControlMsg::Close => { - debug!( - "Closing session after receiving close message {}", - self.session_id - ); + debug!("Closing session after receiving close message {}", self.session_id); } } } @@ -127,10 +118,7 @@ fn msg_to_control(msg: &pubsub::Message) -> ControlMsg { ControlMsg::Close } } else { - tracing::warn!( - "Unkown control message topic, closing session: {:?}", - msg.topic() - ); + warn!("Unkown control message topic, closing session: {:?}", msg.topic()); ControlMsg::Close } } @@ -144,8 +132,8 @@ impl Future for SubscriberSession { loop { // First check if we should flush the connection. We only do this if we have written - // some data and the flush interval has elapsed. Only when we have succesfully flushed the data - // will we reset the `should_flush` flag. + // some data and the flush interval has elapsed. 
Only when we have succesfully flushed + // the data will we reset the `should_flush` flag. if this.should_flush(cx) { if let Poll::Ready(Ok(_)) = this.conn.poll_flush_unpin(cx) { this.should_flush = false; @@ -155,7 +143,7 @@ impl Future for SubscriberSession { // Then, try to drain the egress queue. if this.conn.poll_ready_unpin(cx).is_ready() { if let Some(msg) = this.pending_egress.take() { - tracing::debug!("Sending message: {:?}", msg); + debug!(?msg, "Sending message"); let msg_len = msg.size(); match this.conn.start_send_unpin(msg) { @@ -167,7 +155,7 @@ impl Future for SubscriberSession { continue; } Err(e) => { - tracing::error!("Failed to send message to socket: {:?}", e); + error!(err = ?e, "Failed to send message to socket"); let _ = this.conn.poll_close_unpin(cx); // End this stream as we can't send any more messages return Poll::Ready(()); @@ -186,10 +174,7 @@ impl Future for SubscriberSession { continue; } Some(Err(e)) => { - warn!( - session_id = this.session_id, - "Receiver lagging behind: {:?}", e - ); + warn!(err = ?e, session_id = this.session_id, "Receiver lagging behind"); continue; } None => { @@ -204,23 +189,17 @@ impl Future for SubscriberSession { if let Poll::Ready(item) = this.conn.poll_next_unpin(cx) { match item { Some(Ok(msg)) => { - debug!("Incoming message: {:?}", msg); + debug!(?msg, "Incoming message"); this.on_incoming(msg); continue; } Some(Err(e)) => { - error!( - session_id = this.session_id, - "Error reading from socket: {:?}", e - ); + error!(err = ?e, session_id = this.session_id, "Error reading from socket"); let _ = this.conn.poll_close_unpin(cx); return Poll::Ready(()); } None => { - warn!( - "Connection closed, shutting down session {}", - this.session_id - ); + warn!("Connection closed, shutting down session {}", this.session_id); return Poll::Ready(()); } } diff --git a/msg-socket/src/pub/socket.rs b/msg-socket/src/pub/socket.rs index 8671c96..3b53439 100644 --- a/msg-socket/src/pub/socket.rs +++ b/msg-socket/src/pub/socket.rs @@ -1,6 +1,7 @@ +use std::{io, net::SocketAddr, path::PathBuf, sync::Arc}; + use bytes::Bytes; use futures::stream::FuturesUnordered; -use std::{io, net::SocketAddr, path::PathBuf, sync::Arc}; use tokio::{ net::{lookup_host, ToSocketAddrs}, sync::broadcast, @@ -10,6 +11,7 @@ use tracing::{debug, trace, warn}; use super::{driver::PubDriver, stats::SocketStats, PubError, PubMessage, PubOptions, SocketState}; use crate::Authenticator; + use msg_transport::{Address, Transport}; use msg_wire::compression::Compressor; @@ -23,7 +25,8 @@ pub struct PubSocket, A: Address> { /// The transport used by this socket. This value is temporary and will be moved /// to the driver task once the socket is bound. transport: Option, - /// The broadcast channel to all active [`SubscriberSession`](super::session::SubscriberSession)s. + /// The broadcast channel to all active + /// [`SubscriberSession`](super::session::SubscriberSession)s. to_sessions_bcast: Option>, /// Optional connection authenticator. 
auth: Option>, @@ -104,16 +107,13 @@ where let (to_sessions_bcast, from_socket_bcast) = broadcast::channel(self.options.session_buffer_size); - let mut transport = self - .transport - .take() - .expect("Transport has been moved already"); + let mut transport = self.transport.take().expect("Transport has been moved already"); for addr in addresses { match transport.bind(addr.clone()).await { Ok(_) => break, Err(e) => { - warn!("Failed to bind to {:?}, trying next address: {}", addr, e); + warn!(err = ?e, "Failed to bind to {:?}, trying next address", addr); continue; } } @@ -160,22 +160,12 @@ where if let Some(ref compressor) = self.compressor { msg.compress(compressor.as_ref())?; - trace!( - "Compressed message from {} to {} bytes", - len_before, - msg.payload().len(), - ); + trace!("Compressed message from {} to {} bytes", len_before, msg.payload().len(),); } } // Broadcast the message directly to all active sessions. - if self - .to_sessions_bcast - .as_ref() - .ok_or(PubError::SocketClosed)? - .send(msg) - .is_err() - { + if self.to_sessions_bcast.as_ref().ok_or(PubError::SocketClosed)?.send(msg).is_err() { debug!("No active subscriber sessions"); } @@ -194,21 +184,11 @@ where // For relatively small messages, this takes <100us msg.compress(compressor.as_ref())?; - debug!( - "Compressed message from {} to {} bytes", - len_before, - msg.payload().len(), - ); + debug!("Compressed message from {} to {} bytes", len_before, msg.payload().len(),); } // Broadcast the message directly to all active sessions. - if self - .to_sessions_bcast - .as_ref() - .ok_or(PubError::SocketClosed)? - .send(msg) - .is_err() - { + if self.to_sessions_bcast.as_ref().ok_or(PubError::SocketClosed)?.send(msg).is_err() { debug!("No active subscriber sessions"); } diff --git a/msg-socket/src/pub/trie.rs b/msg-socket/src/pub/trie.rs index 04a02bc..13bb830 100644 --- a/msg-socket/src/pub/trie.rs +++ b/msg-socket/src/pub/trie.rs @@ -10,11 +10,7 @@ struct Node { impl Node { fn new() -> Self { - Self { - children: FxHashMap::default(), - catch_all: false, - topic_end: false, - } + Self { children: FxHashMap::default(), catch_all: false, topic_end: false } } } @@ -36,10 +32,7 @@ impl PrefixTrie { pub fn insert(&mut self, topic: &str) { let mut node = &mut self.root; for token in topic.split('.') { - node = node - .children - .entry(token.to_string()) - .or_insert(Node::new()); + node = node.children.entry(token.to_string()).or_insert(Node::new()); // Check if this is a catch-all wildcard. If so, we mark it as such and break. 
if token == ">" { node.catch_all = true; diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs index 90a9cce..a1596f4 100644 --- a/msg-socket/src/rep/driver.rs +++ b/msg-socket/src/rep/driver.rs @@ -82,7 +82,7 @@ where match try_decompress_payload(request.compression_type, request.msg) { Ok(decompressed) => request.msg = decompressed, Err(e) => { - error!("Failed to decompress message: {:?}", e); + error!(err = ?e, "Failed to decompress message"); continue; } } @@ -91,7 +91,7 @@ where let _ = this.to_socket.try_send(request); } Some(Err(e)) => { - error!("Error receiving message from peer {:?}: {:?}", peer, e); + error!(err = ?e, "Error receiving message from peer {:?}", peer); } None => { warn!("Peer {:?} disconnected", peer); @@ -122,7 +122,7 @@ where ); } Err(e) => { - tracing::error!("Error authenticating client: {:?}", e); + error!(err = ?e, "Error authenticating client"); this.state.stats.decrement_active_clients(); } } @@ -134,15 +134,15 @@ where match incoming { Ok(io) => { if let Err(e) = this.on_incoming(io) { - error!("Error accepting incoming connection: {:?}", e); + error!(err = ?e, "Error accepting incoming connection"); this.state.stats.decrement_active_clients(); } } Err(e) => { - error!("Error accepting incoming connection: {:?}", e); + error!(err = ?e, "Error accepting incoming connection"); - // Active clients have already been incremented in the initial call to `poll_accept`, - // so we need to decrement them here. + // Active clients have already been incremented in the initial call to + // `poll_accept`, so we need to decrement them here. this.state.stats.decrement_active_clients(); } } @@ -150,7 +150,8 @@ where continue; } - // Finally, poll the transport for new incoming connection futures and push them to the incoming connection tasks. + // Finally, poll the transport for new incoming connection futures and push them to the + // incoming connection tasks. 
if let Poll::Ready(accept) = Pin::new(&mut this.transport).poll_accept(cx) { if let Some(max) = this.options.max_clients { if this.state.stats.active_clients() >= max { @@ -225,11 +226,7 @@ where conn.send(auth::Message::Ack).await?; conn.flush().await?; - Ok(AuthResult { - id, - addr, - stream: conn.into_inner(), - }) + Ok(AuthResult { id, addr, stream: conn.into_inner() }) }); } else { self.peer_states.insert( @@ -279,7 +276,7 @@ impl Stream for PeerState } Err(e) => { this.state.stats.increment_failed_requests(); - tracing::error!("Failed to send message to socket: {:?}", e); + error!(err = ?e, "Failed to send message to socket"); // End this stream as we can't send any more messages return Poll::Ready(None); } @@ -300,12 +297,12 @@ impl Stream for PeerState compression_type = compressor.compression_type() as u8; } Err(e) => { - tracing::error!("Failed to compress message: {:?}", e); + error!(err = ?e, "Failed to compress message"); continue; } } - tracing::debug!( + debug!( "Compressed message {} from {} to {} bytes", id, len_before, @@ -328,10 +325,7 @@ impl Stream for PeerState let (tx, rx) = oneshot::channel(); // Add the pending request to the list - this.pending_requests.push(PendingRequest { - msg_id: msg.id(), - response: rx, - }); + this.pending_requests.push(PendingRequest { msg_id: msg.id(), response: rx }); let request = Request { source: this.addr.clone(), diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs index 16f2161..9067311 100644 --- a/msg-socket/src/rep/mod.rs +++ b/msg-socket/src/rep/mod.rs @@ -33,10 +33,7 @@ pub struct RepOptions { impl Default for RepOptions { fn default() -> Self { - Self { - max_clients: None, - min_compress_size: 8192, - } + Self { max_clients: None, min_compress_size: 8192 } } } @@ -86,9 +83,7 @@ impl Request { /// Responds to the request. 
pub fn respond(self, response: Bytes) -> Result<(), PubError> { - self.response - .send(response) - .map_err(|_| PubError::SocketClosed) + self.response.send(response).map_err(|_| PubError::SocketClosed) } } @@ -100,6 +95,7 @@ mod tests { use msg_transport::tcp::Tcp; use msg_wire::compression::{GzipCompressor, SnappyCompressor}; use rand::Rng; + use tracing::{debug, info}; use crate::{req::ReqSocket, Authenticator, ReqOptions}; @@ -142,7 +138,7 @@ mod tests { // println!("Response: {:?} {:?}", _res, req_start.elapsed()); } let elapsed = start.elapsed(); - tracing::info!("{} reqs in {:?}", n_reqs, elapsed); + info!("{} reqs in {:?}", n_reqs, elapsed); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -186,7 +182,7 @@ mod tests { impl Authenticator for Auth { fn authenticate(&self, _id: &Bytes) -> bool { - tracing::info!("{:?}", _id); + info!("{:?}", _id); true } } @@ -203,12 +199,12 @@ mod tests { req.connect(rep.local_addr().unwrap()).await.unwrap(); - tracing::info!("Connected to rep"); + info!("Connected to rep"); tokio::spawn(async move { loop { let req = rep.next().await.unwrap(); - tracing::debug!("Received request"); + debug!("Received request"); req.respond(Bytes::from("hello")).unwrap(); } @@ -229,7 +225,7 @@ mod tests { let _res = req.request(msg).await.unwrap(); } let elapsed = start.elapsed(); - tracing::info!("{} reqs in {:?}", n_reqs, elapsed); + info!("{} reqs in {:?}", n_reqs, elapsed); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs index fa2bbee..bc7f2b4 100644 --- a/msg-socket/src/rep/socket.rs +++ b/msg-socket/src/rep/socket.rs @@ -17,8 +17,7 @@ use tokio_stream::StreamMap; use tracing::{debug, warn}; use crate::{ - rep::{driver::RepDriver, DEFAULT_BUFFER_SIZE}, - rep::{SocketState, SocketStats}, + rep::{driver::RepDriver, SocketState, SocketStats, DEFAULT_BUFFER_SIZE}, Authenticator, PubError, RepOptions, Request, }; @@ -104,16 +103,13 @@ where pub async fn try_bind(&mut self, addresses: Vec) -> Result<(), PubError> { let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE); - let mut transport = self - .transport - .take() - .expect("Transport has been moved already"); + let mut transport = self.transport.take().expect("Transport has been moved already"); for addr in addresses { match transport.bind(addr.clone()).await { Ok(_) => break, Err(e) => { - warn!("Failed to bind to {:?}, trying next address: {}", addr, e); + warn!(err = ?e, "Failed to bind to {:?}, trying next address", addr); continue; } } @@ -162,10 +158,6 @@ impl + Unpin, A: Address> Stream for RepSocket { type Item = Request; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut() - .from_driver - .as_mut() - .expect("Inactive socket") - .poll_recv(cx) + self.get_mut().from_driver.as_mut().expect("Inactive socket").poll_recv(cx) } } diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs index afbae5f..05d133c 100644 --- a/msg-socket/src/req/driver.rs +++ b/msg-socket/src/req/driver.rs @@ -1,6 +1,3 @@ -use bytes::Bytes; -use futures::{Future, FutureExt, SinkExt, StreamExt}; -use rustc_hash::FxHashMap; use std::{ collections::VecDeque, io, @@ -9,6 +6,10 @@ use std::{ task::{ready, Context, Poll}, time::{Duration, Instant}, }; + +use bytes::Bytes; +use futures::{Future, FutureExt, SinkExt, StreamExt}; +use rustc_hash::FxHashMap; use tokio::{ sync::{mpsc, oneshot}, time::Interval, @@ -26,8 +27,12 @@ use msg_wire::{ reqrep, }; +/// A connection task that 
connects to a server and returns the underlying IO object. type ConnectionTask = Pin> + Send>>; +/// A connection controller that manages the connection to a server with an exponential backoff. +type ConnectionCtl = ConnectionState, ExponentialBackoff, Addr>; + /// The request socket driver. Endless future that drives /// the the socket forward. pub(crate) struct ReqDriver, A: Address> { @@ -48,7 +53,7 @@ pub(crate) struct ReqDriver, A: Address> { pub(crate) conn_task: Option>, /// The transport controller, wrapped in a [`ConnectionState`] for backoff. /// The [`Framed`] object can send and receive messages from the socket. - pub(crate) conn_state: ConnectionState, ExponentialBackoff, A>, + pub(crate) conn_state: ConnectionCtl, /// The outgoing message queue. pub(crate) egress_queue: VecDeque, /// The currently pending requests, if any. Uses [`FxHashMap`] for performance. @@ -89,7 +94,7 @@ where let mut io = match connect.await { Ok(io) => io, Err(e) => { - error!("Failed to connect to {:?}: {:?}", addr, e); + error!(err = ?e, "Failed to connect to {:?}", addr); return Err(e); } }; @@ -112,17 +117,15 @@ where Ok(io) } Ok(msg) => { - error!("Unexpected auth ACK result: {:?}", msg); + error!(?msg, "Unexpected auth ACK result"); Err(io::Error::new(io::ErrorKind::PermissionDenied, "rejected").into()) } Err(e) => Err(io::Error::new(io::ErrorKind::PermissionDenied, e).into()), }, None => { error!("Connection closed while waiting for ACK"); - Err( - io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed") - .into(), - ) + Err(io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed") + .into()) } } } else { @@ -144,7 +147,7 @@ where match try_decompress_payload(compression_type, payload) { Ok(decompressed) => payload = decompressed, Err(e) => { - error!("Failed to decompress response payload: {:?}", e); + error!(err = ?e, "Failed to decompress response payload"); let _ = pending.sender.send(Err(ReqError::Wire(reqrep::Error::Io( io::Error::new(io::ErrorKind::Other, "Failed to decompress response"), )))); @@ -163,17 +166,14 @@ where /// Handle an incoming command from the socket frontend. fn on_command(&mut self, cmd: Command) { match cmd { - Command::Send { - mut message, - response, - } => { + Command::Send { mut message, response } => { let start = std::time::Instant::now(); let len_before = message.payload().len(); if len_before > self.options.min_compress_size { if let Some(ref compressor) = self.compressor { if let Err(e) = message.compress(compressor.as_ref()) { - error!("Failed to compress message: {:?}", e); + error!(err = ?e, "Failed to compress message"); } debug!( @@ -188,13 +188,7 @@ where let msg_id = msg.id(); self.id_counter = self.id_counter.wrapping_add(1); self.egress_queue.push_back(msg); - self.pending_requests.insert( - msg_id, - PendingRequest { - start, - sender: response, - }, - ); + self.pending_requests.insert(msg_id, PendingRequest { start, sender: response }); } } } @@ -286,11 +280,7 @@ where // If the connection is inactive, try to connect to the server // or poll the backoff timer if we're already trying to connect. 
- if let ConnectionState::Inactive { - ref mut backoff, - ref addr, - } = this.conn_state - { + if let ConnectionState::Inactive { ref mut backoff, ref addr } = this.conn_state { if let Poll::Ready(item) = backoff.poll_next_unpin(cx) { if let Some(duration) = item { if this.conn_task.is_none() { @@ -326,9 +316,9 @@ where } Poll::Ready(Some(Err(err))) => { if let reqrep::Error::Io(e) = err { - error!("Socket error: {:?}", e); + error!(err = ?e, "Socket error"); if e.kind() == std::io::ErrorKind::Other { - error!("Other error: {:?}", e); + error!(err = ?e, "Other error"); } } @@ -358,7 +348,7 @@ where this.should_flush = true; } Err(e) => { - error!("Failed to send message to socket: {:?}", e); + error!(err = ?e, "Failed to send message to socket"); // set the connection to inactive, so that it will be re-tried this.reset_connection(); diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs index c4e10d6..e53bc25 100644 --- a/msg-socket/src/req/mod.rs +++ b/msg-socket/src/req/mod.rs @@ -35,10 +35,7 @@ pub enum ReqError { } pub enum Command { - Send { - message: ReqMessage, - response: oneshot::Sender>, - }, + Send { message: ReqMessage, response: oneshot::Sender> }, } #[derive(Debug, Clone)] @@ -88,29 +85,32 @@ impl ReqOptions { self } - /// Sets the flush interval for the socket. A higher flush interval will result in higher throughput, - /// but at the cost of higher latency. Note that this behaviour can be completely useless if the - /// `backpressure_boundary` is set too low (which will trigger a flush before the interval is reached). + /// Sets the flush interval for the socket. A higher flush interval will result in higher + /// throughput, but at the cost of higher latency. Note that this behaviour can be + /// completely useless if the `backpressure_boundary` is set too low (which will trigger a + /// flush before the interval is reached). pub fn flush_interval(mut self, flush_interval: Duration) -> Self { self.flush_interval = Some(flush_interval); self } - /// Sets the backpressure boundary for the socket. This is the maximum number of bytes that can be buffered - /// in the session before being flushed. This internally sets [`Framed::set_backpressure_boundary`](tokio_util::codec::Framed). + /// Sets the backpressure boundary for the socket. This is the maximum number of bytes that can + /// be buffered in the session before being flushed. This internally sets + /// [`Framed::set_backpressure_boundary`](tokio_util::codec::Framed). pub fn backpressure_boundary(mut self, backpressure_boundary: usize) -> Self { self.backpressure_boundary = backpressure_boundary; self } - /// Sets the maximum number of retry attempts. If `None`, all connections will be retried indefinitely. + /// Sets the maximum number of retry attempts. If `None`, all connections will be retried + /// indefinitely. pub fn retry_attempts(mut self, retry_attempts: usize) -> Self { self.retry_attempts = Some(retry_attempts); self } - /// Sets the minimum payload size in bytes for compression to be used. If the payload is smaller than - /// this threshold, it will not be compressed. + /// Sets the minimum payload size in bytes for compression to be used. If the payload is smaller + /// than this threshold, it will not be compressed. 
pub fn min_compress_size(mut self, min_compress_size: usize) -> Self { self.min_compress_size = min_compress_size; self diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs index 8d61a0e..75c39eb 100644 --- a/msg-socket/src/req/socket.rs +++ b/msg-socket/src/req/socket.rs @@ -1,11 +1,10 @@ use bytes::Bytes; use rustc_hash::FxHashMap; -use std::marker::PhantomData; -use std::net::SocketAddr; -use std::path::PathBuf; -use std::{io, sync::Arc, time::Duration}; -use tokio::net::{lookup_host, ToSocketAddrs}; -use tokio::sync::{mpsc, oneshot}; +use std::{io, marker::PhantomData, net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; +use tokio::{ + net::{lookup_host, ToSocketAddrs}, + sync::{mpsc, oneshot}, +}; use msg_transport::{Address, Transport}; use msg_wire::compression::Compressor; @@ -42,10 +41,7 @@ where pub async fn connect(&mut self, addr: impl ToSocketAddrs) -> Result<(), ReqError> { let mut addrs = lookup_host(addr).await?; let endpoint = addrs.next().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "could not find any valid address", - ) + io::Error::new(io::ErrorKind::InvalidInput, "could not find any valid address") })?; self.try_connect(endpoint).await @@ -100,10 +96,7 @@ where self.to_driver .as_ref() .ok_or(ReqError::SocketClosed)? - .send(Command::Send { - message: msg, - response: response_tx, - }) + .send(Command::Send { message: msg, response: response_tx }) .await .map_err(|_| ReqError::SocketClosed)?; @@ -116,10 +109,7 @@ where // Initialize communication channels let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE); - let transport = self - .transport - .take() - .expect("Transport has been moved already"); + let transport = self.transport.take().expect("Transport has been moved already"); // We initialize the connection as inactive, and let it be activated // by the backend task as soon as the driver is spawned. @@ -132,8 +122,8 @@ where let flush_interval = self.options.flush_interval.map(tokio::time::interval); - // TODO: we should limit the amount of active outgoing requests, and that should be the capacity. - // If we do this, we'll never have to re-allocate. + // TODO: we should limit the amount of active outgoing requests, and that should be the + // capacity. If we do this, we'll never have to re-allocate. 
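For reference, the `ReqOptions` setters reformatted above (`flush_interval`, `backpressure_boundary`, `retry_attempts`, `min_compress_size`) compose as a builder. A short sketch, assuming `ReqOptions` implements `Default` and is exported at the crate root (neither is shown in these hunks); the values are illustrative:

```rust
use std::time::Duration;

use msg_socket::ReqOptions;

fn main() {
    // Trade a little latency for throughput, cap reconnect attempts, and only
    // compress payloads larger than 4 KiB.
    let _options = ReqOptions::default()
        .flush_interval(Duration::from_micros(500))
        .backpressure_boundary(8192)
        .retry_attempts(5)
        .min_compress_size(4096);
}
```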
let pending_requests = FxHashMap::default(); // Create the socket backend diff --git a/msg-socket/src/sub/driver.rs b/msg-socket/src/sub/driver.rs index b450fa1..4462f79 100644 --- a/msg-socket/src/sub/driver.rs +++ b/msg-socket/src/sub/driver.rs @@ -12,15 +12,14 @@ use tokio::sync::mpsc::{self, error::TrySendError}; use tokio_util::codec::Framed; use tracing::{debug, error, info, warn}; -use super::session::SessionCommand; use super::{ - session::PublisherSession, + session::{PublisherSession, SessionCommand}, stream::{PublisherStream, TopicMessage}, Command, PubMessage, SocketState, SubOptions, }; use crate::{ConnectionState, ExponentialBackoff}; -use msg_common::{channel, task::JoinMap, Channel}; +use msg_common::{channel, Channel, JoinMap}; use msg_transport::{Address, Transport}; use msg_wire::{auth, compression::try_decompress_payload, pubsub}; @@ -79,7 +78,7 @@ where this.on_connection(addr, io); } Err(e) => { - error!(?addr, "Error connecting to publisher: {:?}", e); + error!(err = ?e, ?addr, "Error connecting to publisher"); } } @@ -202,10 +201,7 @@ where } Command::Connect { endpoint } => { if self.is_known(&endpoint) { - debug!( - ?endpoint, - "Publisher already known, ignoring connect command" - ); + debug!(?endpoint, "Publisher already known, ignoring connect command"); return; } @@ -269,10 +265,8 @@ where None => { return ( addr, - Err( - io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed") - .into(), - ), + Err(io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed") + .into()), ) } }; @@ -318,26 +312,20 @@ where tokio::spawn(publisher_session); for topic in self.subscribed_topics.iter() { - if publisher_channel - .try_send(SessionCommand::Subscribe(topic.clone())) - .is_err() - { + if publisher_channel.try_send(SessionCommand::Subscribe(topic.clone())).is_err() { error!(publisher = ?addr, "Error trying to subscribe to topic {topic} on startup: publisher channel closed / full"); } } - self.publishers.insert( - addr.clone(), - ConnectionState::Active { - channel: publisher_channel, - }, - ); + self.publishers + .insert(addr.clone(), ConnectionState::Active { channel: publisher_channel }); self.state.stats.insert(addr, session_stats); } - /// Polls all the publisher channels for new messages. On new messages, forwards them to the socket. - /// If a publisher channel is closed, it will be removed from the list of publishers. + /// Polls all the publisher channels for new messages. On new messages, forwards them to the + /// socket. If a publisher channel is closed, it will be removed from the list of + /// publishers. /// /// Returns `Poll::Ready` if any progress was made and this method should be called again. /// Returns `Poll::Pending` if no progress was made. 
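Both the req driver and the sub driver in this diff keep a disconnected remote in `ConnectionState::Inactive { backoff, addr }` and poll the `ExponentialBackoff` as a stream: a yielded item means the current retry delay has elapsed and a new connection attempt may be spawned, while `None` means the retry budget is exhausted. A standalone sketch of that reconnect loop, under the assumption (suggested by `backoff.poll_next_unpin(cx)` and the `timeout` field in `backoff.rs`) that the stream itself paces the delays:

```rust
use std::time::Duration;

use futures::{Stream, StreamExt};

/// Retry `connect` until it succeeds or the backoff stream runs out of attempts.
async fn connect_with_backoff<S, T, E, Fut>(
    mut backoff: S,
    mut connect: impl FnMut() -> Fut,
) -> Option<T>
where
    S: Stream<Item = Duration> + Unpin,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    // Each item arrives only after the corresponding (growing) delay has elapsed.
    while let Some(_elapsed) = backoff.next().await {
        if let Ok(io) = connect().await {
            return Some(io);
        }
        // Fall through: the next attempt is gated by the next, longer delay.
    }
    None
}
```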
@@ -357,14 +345,14 @@ where match try_decompress_payload(msg.compression_type, msg.payload) { Ok(decompressed) => msg.payload = decompressed, Err(e) => { - error!("Failed to decompress message: {:?}", e); + error!(err = ?e, "Failed to decompress message"); continue; } }; let msg = PubMessage::new(addr.clone(), msg.topic, msg.payload); - debug!(source = ?msg.source, "New message: {:?}", msg); + debug!(source = ?msg.source, ?msg, "New message"); // TODO: queuing if let Err(TrySendError::Full(msg)) = self.to_socket.try_send(msg) { error!( diff --git a/msg-socket/src/sub/mod.rs b/msg-socket/src/sub/mod.rs index b6d7a7d..ccf867e 100644 --- a/msg-socket/src/sub/mod.rs +++ b/msg-socket/src/sub/mod.rs @@ -1,19 +1,23 @@ +use std::{fmt, time::Duration}; + use bytes::Bytes; -use core::fmt; -use msg_transport::Address; -use msg_wire::pubsub; -use std::time::Duration; use thiserror::Error; mod driver; +use driver::SubDriver; + mod session; + mod socket; +pub use socket::*; + mod stats; +use stats::SocketStats; + mod stream; -use driver::SubDriver; -pub use socket::*; -use stats::SocketStats; +use msg_transport::Address; +use msg_wire::pubsub; const DEFAULT_BUFFER_SIZE: usize = 1024; @@ -68,8 +72,9 @@ impl SubOptions { self } - /// Sets the ingress buffer size. This is the maximum amount of incoming messages that will be buffered. - /// If the consumer cannot keep up with the incoming messages, messages will start being dropped. + /// Sets the ingress buffer size. This is the maximum amount of incoming messages that will be + /// buffered. If the consumer cannot keep up with the incoming messages, messages will start + /// being dropped. pub fn ingress_buffer_size(mut self, ingress_buffer_size: usize) -> Self { self.ingress_buffer_size = ingress_buffer_size; self @@ -124,11 +129,7 @@ impl fmt::Debug for PubMessage { impl PubMessage { pub fn new(source: A, topic: String, payload: Bytes) -> Self { - Self { - source, - topic, - payload, - } + Self { source, topic, payload } } #[inline] @@ -160,9 +161,7 @@ pub(crate) struct SocketState { impl SocketState { pub fn new() -> Self { - Self { - stats: SocketStats::new(), - } + Self { stats: SocketStats::new() } } } @@ -176,7 +175,7 @@ mod tests { net::TcpListener, }; use tokio_stream::StreamExt; - use tracing::Instrument; + use tracing::{info, info_span, Instrument}; use super::*; @@ -193,11 +192,11 @@ mod tests { let b = socket.read(&mut buf).await.unwrap(); let read = &buf[..b]; - tracing::info!("Received bytes: {:?}", read); + info!("Received bytes: {:?}", read); socket.write_all(read).await.unwrap(); socket.flush().await.unwrap(); } - .instrument(tracing::info_span!("listener")), + .instrument(info_span!("listener")), ); addr diff --git a/msg-socket/src/sub/session.rs b/msg-socket/src/sub/session.rs index e645338..8c81a80 100644 --- a/msg-socket/src/sub/session.rs +++ b/msg-socket/src/sub/session.rs @@ -35,8 +35,8 @@ pub(super) struct PublisherSession { stream: PublisherStream, /// The session stats stats: Arc, - /// Channel for bi-directional communication with the driver. Sends new messages from the associated - /// publisher and receives subscribe / unsubscribe commands. + /// Channel for bi-directional communication with the driver. Sends new messages from the + /// associated publisher and receives subscribe / unsubscribe commands. driver_channel: Channel, } @@ -63,19 +63,17 @@ impl PublisherSession { /// Queues a subscribe message for this publisher. /// On the next poll, the message will be attempted to be sent. 
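// A self-contained sketch of the queue-then-flush pattern described in the doc comment above;
// the type names are placeholders, not the crate's real session types. Subscribe/unsubscribe
// commands only push a wire message onto a VecDeque, and the session's poll loop drains the
// queue whenever the underlying sink has capacity.
use std::collections::VecDeque;

enum WireMessage {
    Subscribe(String),
    Unsubscribe(String),
}

#[derive(Default)]
struct Egress {
    queue: VecDeque<WireMessage>,
}

impl Egress {
    /// Queues a subscribe message; nothing is written to the wire until the next poll.
    fn subscribe(&mut self, topic: String) {
        self.queue.push_back(WireMessage::Subscribe(topic));
    }

    /// Queues an unsubscribe message, same as above.
    fn unsubscribe(&mut self, topic: String) {
        self.queue.push_back(WireMessage::Unsubscribe(topic));
    }

    /// Called from the poll loop: send queued messages for as long as the sink accepts them.
    fn drain(&mut self, mut try_send: impl FnMut(&WireMessage) -> bool) {
        while let Some(msg) = self.queue.front() {
            if try_send(msg) {
                self.queue.pop_front();
            } else {
                // No capacity right now; keep the message queued for the next poll.
                break;
            }
        }
    }
}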
fn subscribe(&mut self, topic: String) { - self.egress - .push_back(pubsub::Message::new_sub(Bytes::from(topic))); + self.egress.push_back(pubsub::Message::new_sub(Bytes::from(topic))); } /// Queues an unsubscribe message for this publisher. /// On the next poll, the message will be attempted to be sent. fn unsubscribe(&mut self, topic: String) { - self.egress - .push_back(pubsub::Message::new_unsub(Bytes::from(topic))); + self.egress.push_back(pubsub::Message::new_unsub(Bytes::from(topic))); } - /// Handles incoming messages. On a successful message, the session stats are updated and the message - /// is forwarded to the driver. + /// Handles incoming messages. On a successful message, the session stats are updated and the + /// message is forwarded to the driver. fn on_incoming(&mut self, incoming: Result) { match incoming { Ok(msg) => { @@ -84,12 +82,12 @@ impl PublisherSession { self.stats.increment_rx(msg.payload.len()); self.stats.update_latency(now.saturating_sub(msg.timestamp)); - if self.driver_channel.try_send(msg).is_err() { - warn!(addr = ?self.addr, "Failed to send message to driver"); + if let Err(e) = self.driver_channel.try_send(msg) { + warn!(err = ?e, addr = ?self.addr, "Failed to send message to driver"); } } Err(e) => { - error!(addr = ?self.addr, "Error receiving message: {:?}", e); + error!(err = ?e, addr = ?self.addr, "Error receiving message"); } } } diff --git a/msg-socket/src/sub/socket.rs b/msg-socket/src/sub/socket.rs index b767a32..d001b14 100644 --- a/msg-socket/src/sub/socket.rs +++ b/msg-socket/src/sub/socket.rs @@ -1,5 +1,3 @@ -use futures::Stream; -use rustc_hash::FxHashMap; use std::{ collections::HashSet, io, @@ -9,12 +7,15 @@ use std::{ sync::Arc, task::{Context, Poll}, }; + +use futures::Stream; +use rustc_hash::FxHashMap; use tokio::{ net::{lookup_host, ToSocketAddrs}, sync::mpsc, }; -use msg_common::task::JoinMap; +use msg_common::JoinMap; use msg_transport::{Address, Transport}; use super::{ @@ -253,10 +254,7 @@ where /// Sends a command to the driver, returning [`SubError::SocketClosed`] if the /// driver has been dropped. async fn send_command(&self, command: Command) -> Result<(), SubError> { - self.to_driver - .send(command) - .await - .map_err(|_| SubError::SocketClosed)?; + self.to_driver.send(command).await.map_err(|_| SubError::SocketClosed)?; Ok(()) } diff --git a/msg-socket/src/sub/stats.rs b/msg-socket/src/sub/stats.rs index 8f76bf5..acd438f 100644 --- a/msg-socket/src/sub/stats.rs +++ b/msg-socket/src/sub/stats.rs @@ -19,9 +19,7 @@ pub struct SocketStats { impl SocketStats { pub fn new() -> Self { - Self { - session_stats: RwLock::new(HashMap::new()), - } + Self { session_stats: RwLock::new(HashMap::new()) } } } @@ -38,19 +36,13 @@ impl SocketStats { #[inline] pub fn bytes_rx(&self, session_addr: &A) -> Option { - self.session_stats - .read() - .get(session_addr) - .map(|stats| stats.bytes_rx()) + self.session_stats.read().get(session_addr).map(|stats| stats.bytes_rx()) } /// Returns the average latency in microseconds for the given session. #[inline] pub fn avg_latency(&self, session_addr: &A) -> Option { - self.session_stats - .read() - .get(session_addr) - .map(|stats| stats.avg_latency()) + self.session_stats.read().get(session_addr).map(|stats| stats.avg_latency()) } } diff --git a/msg-socket/src/sub/stream.rs b/msg-socket/src/sub/stream.rs index 60fc450..2792403 100644 --- a/msg-socket/src/sub/stream.rs +++ b/msg-socket/src/sub/stream.rs @@ -76,12 +76,7 @@ impl Stream for PublisherStream { // TODO: this will allocate. 
Can we just return the `Cow`? let topic = String::from_utf8_lossy(&topic).to_string(); - TopicMessage { - compression_type, - timestamp, - topic, - payload, - } + TopicMessage { compression_type, timestamp, topic, payload } }))); } diff --git a/msg-socket/tests/README.md b/msg-socket/tests/README.md index 7558fd1..91caa51 100644 --- a/msg-socket/tests/README.md +++ b/msg-socket/tests/README.md @@ -1,6 +1,7 @@ # `msg-socket` tests ## Integration tests -| Test | Description | Status | -| ----------- | ---------------------------- | ------ | -| [`pubsub`](./it/pubsub.rs) | Different messaging patterns with different transports, chaos through simulated network links & random delay injection. | ✅ | + +| Test | Description | Status | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------- | ------ | +| [`pubsub`](./it/pubsub.rs) | Different messaging patterns with different transports, chaos through simulated network links & random delay injection. | ✅ | diff --git a/msg-socket/tests/it/pubsub.rs b/msg-socket/tests/it/pubsub.rs index f89d245..5d517c3 100644 --- a/msg-socket/tests/it/pubsub.rs +++ b/msg-socket/tests/it/pubsub.rs @@ -1,9 +1,11 @@ +use std::{collections::HashSet, net::IpAddr, time::Duration}; + use bytes::Bytes; use msg_sim::{Protocol, SimulationConfig, Simulator}; use rand::Rng; -use std::{collections::HashSet, net::IpAddr, time::Duration}; use tokio::{sync::mpsc, task::JoinSet}; use tokio_stream::StreamExt; +use tracing::info; use msg_socket::{PubSocket, SubSocket}; use msg_transport::{quic::Quic, tcp::Tcp, Address, Transport}; @@ -65,19 +67,17 @@ where publisher.try_bind(vec![addr]).await?; - // Spawn a task to keep sending messages until the subscriber receives one (after connection process) + // Spawn a task to keep sending messages until the subscriber receives one (after connection + // process) tokio::spawn(async move { loop { tokio::time::sleep(Duration::from_millis(500)).await; - publisher - .publish(TOPIC, Bytes::from("WORLD")) - .await - .unwrap(); + publisher.publish(TOPIC, Bytes::from("WORLD")).await.unwrap(); } }); let msg = subscriber.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!(TOPIC, msg.topic()); assert_eq!("WORLD", msg.payload()); @@ -137,7 +137,7 @@ async fn pubsub_fan_out_transport< subscriber.subscribe(TOPIC).await.unwrap(); let msg = subscriber.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!(TOPIC, msg.topic()); assert_eq!("WORLD", msg.payload()); }); @@ -147,14 +147,12 @@ async fn pubsub_fan_out_transport< publisher.try_bind(vec![addr]).await?; - // Spawn a task to keep sending messages until the subscriber receives one (after connection process) + // Spawn a task to keep sending messages until the subscriber receives one (after connection + // process) tokio::spawn(async move { loop { tokio::time::sleep(Duration::from_millis(500)).await; - publisher - .publish(TOPIC, Bytes::from("WORLD")) - .await - .unwrap(); + publisher.publish(TOPIC, Bytes::from("WORLD")).await.unwrap(); } }); @@ -219,14 +217,12 @@ async fn pubsub_fan_in_transport< let local_addr = publisher.local_addr().unwrap().clone(); tx.send(local_addr).await.unwrap(); - // Spawn a task to keep sending messages until the subscriber receives one (after connection process) + // Spawn a task to keep sending messages until the subscriber receives one (after 
+ // connection process) tokio::spawn(async move { loop { tokio::time::sleep(Duration::from_millis(500)).await; - publisher - .publish(TOPIC, Bytes::from("WORLD")) - .await - .unwrap(); + publisher.publish(TOPIC, Bytes::from("WORLD")).await.unwrap(); } }); }); @@ -254,7 +250,7 @@ async fn pubsub_fan_in_transport< } let msg = subscriber.next().await.unwrap(); - tracing::info!("Received message: {:?}", msg); + info!("Received message: {:?}", msg); assert_eq!(TOPIC, msg.topic()); assert_eq!("WORLD", msg.payload()); diff --git a/msg-transport/src/ipc/mod.rs b/msg-transport/src/ipc/mod.rs index ad0ccce..1a72cb8 100644 --- a/msg-transport/src/ipc/mod.rs +++ b/msg-transport/src/ipc/mod.rs @@ -42,11 +42,7 @@ pub struct Ipc { impl Ipc { pub fn new(config: Config) -> Self { - Self { - config, - listener: None, - path: None, - } + Self { config, listener: None, path: None } } } @@ -115,7 +111,7 @@ impl Transport for Ipc { if let Err(e) = std::fs::remove_file(&addr) { return Err(io::Error::new( io::ErrorKind::Other, - format!("Failed to remove existing socket file: {}", e), + format!("Failed to remove existing socket file, {:?}", e), )); } } diff --git a/msg-transport/src/lib.rs b/msg-transport/src/lib.rs index 138c003..f817ab0 100644 --- a/msg-transport/src/lib.rs +++ b/msg-transport/src/lib.rs @@ -59,7 +59,8 @@ pub trait Transport { /// Binds to the given address. async fn bind(&mut self, addr: A) -> Result<(), Self::Error>; - /// Connects to the given address, returning a future representing a pending outbound connection. + /// Connects to the given address, returning a future representing a + /// pending outbound connection. fn connect(&mut self, addr: A) -> Self::Connect; /// Poll for incoming connections. If an inbound connection is received, a future representing @@ -84,10 +85,7 @@ pub struct Acceptor<'a, T, A> { impl<'a, T, A> Acceptor<'a, T, A> { fn new(inner: &'a mut T) -> Self { - Self { - inner, - _marker: PhantomData, - } + Self { inner, _marker: PhantomData } } } diff --git a/msg-transport/src/quic/config.rs b/msg-transport/src/quic/config.rs index 0cce0ff..396668a 100644 --- a/msg-transport/src/quic/config.rs +++ b/msg-transport/src/quic/config.rs @@ -113,11 +113,7 @@ where client_config.transport_config(transport); - Config { - endpoint_config: quinn::EndpointConfig::default(), - client_config, - server_config, - } + Config { endpoint_config: quinn::EndpointConfig::default(), client_config, server_config } } } @@ -170,10 +166,6 @@ impl Default for Config { client_config.transport_config(transport); - Self { - endpoint_config: quinn::EndpointConfig::default(), - client_config, - server_config, - } + Self { endpoint_config: quinn::EndpointConfig::default(), client_config, server_config } } } diff --git a/msg-transport/src/quic/mod.rs b/msg-transport/src/quic/mod.rs index 120f3ce..293ac1e 100644 --- a/msg-transport/src/quic/mod.rs +++ b/msg-transport/src/quic/mod.rs @@ -9,7 +9,7 @@ use std::{ use futures::future::BoxFuture; use thiserror::Error; use tokio::sync::mpsc::{self, Receiver}; -use tracing::error; +use tracing::{debug, error}; use crate::{Acceptor, Transport, TransportExt}; @@ -36,14 +36,16 @@ pub enum Error { ClosedEndpoint, } -/// A QUIC implementation built with [quinn] that implements the [`Transport`] and [`TransportExt`] traits. +/// A QUIC implementation built with [quinn] that implements the [`Transport`] and [`TransportExt`] +/// traits. /// /// # Note on multiplexing -/// This implementation does not yet support multiplexing. 
This means that each connection only supports a single -/// bi-directional stream, which is returned as the I/O object when connecting or accepting. +/// This implementation does not yet support multiplexing. This means that each connection only +/// supports a single bi-directional stream, which is returned as the I/O object when connecting or +/// accepting. /// -/// In a future release, we will add support for multiplexing, which will allow multiple streams per connection based on -/// socket requirements / semantics. +/// In a future release, we will add support for multiplexing, which will allow multiple streams per +/// connection based on socket requirements / semantics. #[derive(Debug, Default)] pub struct Quic { config: Config, @@ -56,15 +58,11 @@ pub struct Quic { impl Quic { /// Creates a new QUIC transport with the given configuration. pub fn new(config: Config) -> Self { - Self { - config, - endpoint: None, - incoming: None, - } + Self { config, endpoint: None, incoming: None } } - /// Creates a new [`quinn::Endpoint`] with the given configuration and a Tokio runtime. If no `addr` is given, - /// the endpoint will be bound to the default address. + /// Creates a new [`quinn::Endpoint`] with the given configuration and a Tokio runtime. If no + /// `addr` is given, the endpoint will be bound to the default address. fn new_endpoint( &self, addr: Option, @@ -106,8 +104,8 @@ impl Transport for Quic { Ok(()) } - /// Connects to the given address, returning a future representing a pending outbound connection. - /// If the endpoint is not bound, it will be bound to the default address. + /// Connects to the given address, returning a future representing a pending outbound + /// connection. If the endpoint is not bound, it will be bound to the default address. fn connect(&mut self, addr: SocketAddr) -> Self::Connect { // If we have an endpoint, use it. Otherwise, create a new one. let endpoint = if let Some(endpoint) = self.endpoint.clone() { @@ -128,22 +126,17 @@ impl Transport for Quic { // This `"l"` seems necessary because an empty string is an invalid domain // name. While we don't use domain names, the underlying rustls library // is based upon the assumption that we do. - let connection = endpoint - .connect_with(client_config, addr, "l")? - .await - .map_err(Error::from)?; + let connection = + endpoint.connect_with(client_config, addr, "l")?.await.map_err(Error::from)?; - tracing::debug!("Connected to {}, opening stream", addr); + debug!("Connected to {}, opening stream", addr); - // Open a bi-directional stream and return it. We'll think about multiplexing per topic later. + // Open a bi-directional stream and return it. We'll think about multiplexing per topic + // later. connection .open_bi() .await - .map(|(send, recv)| QuicStream { - peer: addr, - send, - recv, - }) + .map(|(send, recv)| QuicStream { peer: addr, send, recv }) .map_err(Error::from) }) } @@ -159,17 +152,18 @@ impl Transport for Quic { Some(Ok(connecting)) => { let peer = connecting.remote_address(); - tracing::debug!("New incoming connection from {}", peer); + debug!("New incoming connection from {}", peer); // Return a future that resolves to the output. return Poll::Ready(Box::pin(async move { let connection = connecting.await.map_err(Error::from)?; - tracing::debug!( + debug!( "Accepted connection from {}, opening stream", connection.remote_address() ); - // Accept a bi-directional stream and return it. We'll think about multiplexing per topic later. + // Accept a bi-directional stream and return it. 
We'll think about + // multiplexing per topic later. connection .accept_bi() .await @@ -185,8 +179,8 @@ impl Transport for Quic { } } } else { - // We need to set the incoming channel and spawn a task to accept incoming connections - // on the endpoint. + // We need to set the incoming channel and spawn a task to accept incoming + // connections on the endpoint. // Check if there's an endpoint bound. let Some(endpoint) = this.endpoint.clone() else { @@ -235,6 +229,7 @@ mod tests { io::{AsyncReadExt, AsyncWriteExt}, sync::oneshot, }; + use tracing::info; use super::*; @@ -245,13 +240,10 @@ mod tests { let config = Config::default(); let mut server = Quic::new(config.clone()); - server - .bind(SocketAddr::from(([127, 0, 0, 1], 0))) - .await - .unwrap(); + server.bind(SocketAddr::from(([127, 0, 0, 1], 0))).await.unwrap(); let server_addr = server.local_addr().unwrap(); - tracing::info!("Server bound on {:?}", server_addr); + info!("Server bound on {:?}", server_addr); let (tx, rx) = oneshot::channel(); @@ -260,7 +252,7 @@ mod tests { let mut stream = server.accept().await.unwrap(); - tracing::info!("Accepted connection"); + info!("Accepted connection"); let mut dst = [0u8; 5]; @@ -273,13 +265,13 @@ mod tests { let mut client = Quic::new(config); let mut stream = client.connect(server_addr).await.unwrap(); - tracing::info!("Connected to remote"); + info!("Connected to remote"); let item = b"Hello"; stream.write_all(item).await.unwrap(); stream.flush().await.unwrap(); - tracing::info!("Wrote to remote"); + info!("Wrote to remote"); let rcv = rx.await.unwrap(); assert_eq!(rcv, *item); @@ -299,7 +291,7 @@ mod tests { tokio::spawn(async move { let _stream = client.connect(addr).await.unwrap(); - tracing::info!("Connected to remote"); + info!("Connected to remote"); }); tokio::time::sleep(Duration::from_secs(17)).await; diff --git a/msg-transport/src/tcp/mod.rs b/msg-transport/src/tcp/mod.rs index 79ee39d..28a3ecf 100644 --- a/msg-transport/src/tcp/mod.rs +++ b/msg-transport/src/tcp/mod.rs @@ -5,6 +5,7 @@ use std::{ task::{Context, Poll}, }; use tokio::net::{TcpListener, TcpStream}; +use tracing::debug; use msg_common::async_error; @@ -22,10 +23,7 @@ pub struct Tcp { impl Tcp { pub fn new(config: Config) -> Self { - Self { - config, - listener: None, - } + Self { config, listener: None } } } @@ -74,7 +72,7 @@ impl Transport for Tcp { match listener.poll_accept(cx) { Poll::Ready(Ok((io, addr))) => { - tracing::debug!("Accepted connection from {}", addr); + debug!("Accepted connection from {}", addr); Poll::Ready(Box::pin(async move { io.set_nodelay(true)?; diff --git a/msg-wire/src/auth.rs b/msg-wire/src/auth.rs index 1de19b4..fa53a9f 100644 --- a/msg-wire/src/auth.rs +++ b/msg-wire/src/auth.rs @@ -31,9 +31,7 @@ impl Codec { /// codec in the `AuthReceive` state since it will be waiting for the /// client to send its ID. pub fn new_server() -> Self { - Self { - state: State::AuthReceive, - } + Self { state: State::AuthReceive } } } diff --git a/msg-wire/src/compression/gzip.rs b/msg-wire/src/compression/gzip.rs index afdb4b3..b26361c 100644 --- a/msg-wire/src/compression/gzip.rs +++ b/msg-wire/src/compression/gzip.rs @@ -23,10 +23,8 @@ impl Compressor for GzipCompressor { fn compress(&self, data: &[u8]) -> Result { // Optimistically allocate the compressed buffer to 1/4 of the original size. 
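// Standalone sketch of the pre-allocation pattern in the GzEncoder hunk that follows, written
// against flate2 directly. The 1/4 figure only sets the initial capacity of the output Vec; if
// the data compresses poorly the Vec simply grows, so nothing is truncated or lost.
use std::io::Write;

use flate2::{write::GzEncoder, Compression};

fn gzip_compress(data: &[u8], level: u32) -> std::io::Result<Vec<u8>> {
    // Optimistically reserve ~25% of the input size for the compressed output.
    let mut encoder = GzEncoder::new(Vec::with_capacity(data.len() / 4), Compression::new(level));
    encoder.write_all(data)?;
    encoder.finish()
}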
- let mut encoder = GzEncoder::new( - Vec::with_capacity(data.len() / 4), - Compression::new(self.level), - ); + let mut encoder = + GzEncoder::new(Vec::with_capacity(data.len() / 4), Compression::new(self.level)); encoder.write_all(data)?; diff --git a/msg-wire/src/compression/lz4.rs b/msg-wire/src/compression/lz4.rs index 36f8111..ca86238 100644 --- a/msg-wire/src/compression/lz4.rs +++ b/msg-wire/src/compression/lz4.rs @@ -26,10 +26,7 @@ pub struct Lz4Decompressor; impl Decompressor for Lz4Decompressor { fn decompress(&self, data: &[u8]) -> Result { let bytes = decompress_size_prepended(data).map_err(|e| { - io::Error::new( - io::ErrorKind::InvalidData, - format!("Lz4 decompression failed: {}", e), - ) + io::Error::new(io::ErrorKind::InvalidData, format!("Lz4 decompression failed: {}", e)) })?; Ok(Bytes::from(bytes)) diff --git a/msg-wire/src/compression/mod.rs b/msg-wire/src/compression/mod.rs index e8f0655..5934524 100644 --- a/msg-wire/src/compression/mod.rs +++ b/msg-wire/src/compression/mod.rs @@ -170,31 +170,19 @@ mod tests { let gzip = GzipCompressor::new(6); let (gzip_time, gzip_perf, gzip_comp) = compression_test(&data, gzip); - println!( - "gzip compression shrank the data by {:.2}% in {:?}", - gzip_perf, gzip_time - ); + println!("gzip compression shrank the data by {:.2}% in {:?}", gzip_perf, gzip_time); let zstd = ZstdCompressor::new(6); let (zstd_time, zstd_perf, zstd_comp) = compression_test(&data, zstd); - println!( - "zstd compression shrank the data by {:.2}% in {:?}", - zstd_perf, zstd_time - ); + println!("zstd compression shrank the data by {:.2}% in {:?}", zstd_perf, zstd_time); let snappy = SnappyCompressor; let (snappy_time, snappy_perf, snappy_comp) = compression_test(&data, snappy); - println!( - "snappy compression shrank the data by {:.2}% in {:?}", - snappy_perf, snappy_time - ); + println!("snappy compression shrank the data by {:.2}% in {:?}", snappy_perf, snappy_time); let lz4 = Lz4Compressor; let (lz4_time, lz4_perf, lz4_comp) = compression_test(&data, lz4); - println!( - "lz4 compression shrank the data by {:.2}% in {:?}", - lz4_perf, lz4_time - ); + println!("lz4 compression shrank the data by {:.2}% in {:?}", lz4_perf, lz4_time); println!("------ SSZ BLOCK -------"); @@ -225,31 +213,19 @@ mod tests { let gzip = GzipCompressor::new(6); let (gzip_time, gzip_perf, gzip_comp) = compression_test(&data, gzip); - println!( - "gzip compression shrank the data by {:.2}% in {:?}", - gzip_perf, gzip_time - ); + println!("gzip compression shrank the data by {:.2}% in {:?}", gzip_perf, gzip_time); let zstd = ZstdCompressor::new(6); let (zstd_time, zstd_perf, zstd_comp) = compression_test(&data, zstd); - println!( - "zstd compression shrank the data by {:.2}% in {:?}", - zstd_perf, zstd_time - ); + println!("zstd compression shrank the data by {:.2}% in {:?}", zstd_perf, zstd_time); let snappy = SnappyCompressor; let (snappy_time, snappy_perf, snappy_comp) = compression_test(&data, snappy); - println!( - "snappy compression shrank the data by {:.2}% in {:?}", - snappy_perf, snappy_time - ); + println!("snappy compression shrank the data by {:.2}% in {:?}", snappy_perf, snappy_time); let lz4 = Lz4Compressor; let (lz4_time, lz4_perf, lz4_comp) = compression_test(&data, lz4); - println!( - "lz4 compression shrank the data by {:.2}% in {:?}", - lz4_perf, lz4_time - ); + println!("lz4 compression shrank the data by {:.2}% in {:?}", lz4_perf, lz4_time); println!("------ BLOB TX ------"); diff --git a/msg-wire/src/pubsub.rs b/msg-wire/src/pubsub.rs index 41a0633..47842e8 
100644 --- a/msg-wire/src/pubsub.rs +++ b/msg-wire/src/pubsub.rs @@ -214,7 +214,8 @@ impl Decoder for Codec { cursor += 2; - // We don't have enough bytes to read the topic and the rest of the data (timestamp u64, seq u32, size u32) + // We don't have enough bytes to read the topic and the rest of the data + // (timestamp u64, seq u32, size u32) if src.len() < cursor + topic_size as usize + 8 + 8 { return Ok(None); } @@ -244,10 +245,7 @@ impl Decoder for Codec { let header = header.take().unwrap(); let payload = src.split_to(header.size as usize); - let message = Message { - header, - payload: payload.freeze(), - }; + let message = Message { header, payload: payload.freeze() }; self.state = State::Header; return Ok(Some(message)); diff --git a/msg-wire/src/reqrep.rs b/msg-wire/src/reqrep.rs index fb11cd3..77c4337 100644 --- a/msg-wire/src/reqrep.rs +++ b/msg-wire/src/reqrep.rs @@ -23,14 +23,7 @@ pub struct Message { impl Message { #[inline] pub fn new(id: u32, compression_type: u8, payload: Bytes) -> Self { - Self { - header: Header { - id, - compression_type, - size: payload.len() as u32, - }, - payload, - } + Self { header: Header { id, compression_type, size: payload.len() as u32 }, payload } } #[inline] @@ -151,11 +144,8 @@ impl Decoder for Codec { src.advance(cursor); // Construct the header - let header = Header { - compression_type, - id: src.get_u32(), - size: src.get_u32(), - }; + let header = + Header { compression_type, id: src.get_u32(), size: src.get_u32() }; self.state = State::Payload(header); } @@ -165,10 +155,7 @@ impl Decoder for Codec { } let payload = src.split_to(header.size as usize); - let message = Message { - header, - payload: payload.freeze(), - }; + let message = Message { header, payload: payload.freeze() }; self.state = State::Header; return Ok(Some(message)); diff --git a/msg/benches/pubsub.rs b/msg/benches/pubsub.rs index af5ff7e..92076b8 100644 --- a/msg/benches/pubsub.rs +++ b/msg/benches/pubsub.rs @@ -43,10 +43,7 @@ impl + Send + Sync + Unpin + 'static, A: Address> PairBenchmark< let addr = self.publisher.local_addr().unwrap(); self.subscriber.connect_inner(addr.clone()).await.unwrap(); - self.subscriber - .subscribe("HELLO".to_string()) - .await - .unwrap(); + self.subscriber.subscribe("HELLO".to_string()).await.unwrap(); // Give some time for the background connection process to run tokio::time::sleep(Duration::from_millis(10)).await; @@ -146,10 +143,7 @@ fn generate_messages(n_reqs: usize, msg_size: usize) -> Vec { fn pubsub_single_thread_tcp(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); let buffer_size = 1024 * 64; @@ -163,9 +157,7 @@ fn pubsub_single_thread_tcp(c: &mut Criterion) { let subscriber = SubSocket::with_options( Tcp::default(), - SubOptions::default() - .read_buffer_size(buffer_size) - .ingress_buffer_size(N_REQS * 2), + SubOptions::default().read_buffer_size(buffer_size).ingress_buffer_size(N_REQS * 2), ); let mut bench = PairBenchmark { @@ -190,11 +182,8 @@ fn pubsub_single_thread_tcp(c: &mut Criterion) { fn pubsub_multi_thread_tcp(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_multi_thread() - .worker_threads(4) - .enable_all() - .build() - .unwrap(); + let rt = + tokio::runtime::Builder::new_multi_thread().worker_threads(4).enable_all().build().unwrap(); let buffer_size = 1024 
* 64; @@ -208,9 +197,7 @@ fn pubsub_multi_thread_tcp(c: &mut Criterion) { let subscriber = SubSocket::with_options( Tcp::default(), - SubOptions::default() - .read_buffer_size(buffer_size) - .ingress_buffer_size(N_REQS * 2), + SubOptions::default().read_buffer_size(buffer_size).ingress_buffer_size(N_REQS * 2), ); let mut bench = PairBenchmark { @@ -235,10 +222,7 @@ fn pubsub_multi_thread_tcp(c: &mut Criterion) { fn pubsub_single_thread_quic(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); let buffer_size = 1024 * 64; @@ -252,9 +236,7 @@ fn pubsub_single_thread_quic(c: &mut Criterion) { let subscriber = SubSocket::with_options( Quic::default(), - SubOptions::default() - .read_buffer_size(buffer_size) - .ingress_buffer_size(N_REQS * 2), + SubOptions::default().read_buffer_size(buffer_size).ingress_buffer_size(N_REQS * 2), ); let mut bench = PairBenchmark { @@ -279,11 +261,8 @@ fn pubsub_single_thread_quic(c: &mut Criterion) { fn pubsub_multi_thread_quic(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_multi_thread() - .worker_threads(4) - .enable_all() - .build() - .unwrap(); + let rt = + tokio::runtime::Builder::new_multi_thread().worker_threads(4).enable_all().build().unwrap(); let buffer_size = 1024 * 64; @@ -297,9 +276,7 @@ fn pubsub_multi_thread_quic(c: &mut Criterion) { let subscriber = SubSocket::with_options( Quic::default(), - SubOptions::default() - .read_buffer_size(buffer_size) - .ingress_buffer_size(N_REQS * 2), + SubOptions::default().read_buffer_size(buffer_size).ingress_buffer_size(N_REQS * 2), ); let mut bench = PairBenchmark { @@ -324,10 +301,7 @@ fn pubsub_multi_thread_quic(c: &mut Criterion) { fn pubsub_single_thread_ipc(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); let buffer_size = 1024 * 64; @@ -341,9 +315,7 @@ fn pubsub_single_thread_ipc(c: &mut Criterion) { let subscriber = SubSocket::with_options( Ipc::default(), - SubOptions::default() - .read_buffer_size(buffer_size) - .ingress_buffer_size(N_REQS * 2), + SubOptions::default().read_buffer_size(buffer_size).ingress_buffer_size(N_REQS * 2), ); let mut bench = PairBenchmark { @@ -368,11 +340,8 @@ fn pubsub_single_thread_ipc(c: &mut Criterion) { fn pubsub_multi_thread_ipc(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_multi_thread() - .worker_threads(4) - .enable_all() - .build() - .unwrap(); + let rt = + tokio::runtime::Builder::new_multi_thread().worker_threads(4).enable_all().build().unwrap(); let buffer_size = 1024 * 64; @@ -386,9 +355,7 @@ fn pubsub_multi_thread_ipc(c: &mut Criterion) { let subscriber = SubSocket::with_options( Ipc::default(), - SubOptions::default() - .read_buffer_size(buffer_size) - .ingress_buffer_size(N_REQS * 2), + SubOptions::default().read_buffer_size(buffer_size).ingress_buffer_size(N_REQS * 2), ); let mut bench = PairBenchmark { diff --git a/msg/benches/reqrep.rs b/msg/benches/reqrep.rs index 981b77a..8cd2d60 100644 --- a/msg/benches/reqrep.rs +++ b/msg/benches/reqrep.rs @@ -39,10 +39,7 @@ impl + Send + Sync + Unpin + 'static, A: Address> PairBenchmark< self.rt.block_on(async 
{ rep.try_bind(vec![addr]).await.unwrap(); - self.req - .try_connect(rep.local_addr().unwrap().clone()) - .await - .unwrap(); + self.req.try_connect(rep.local_addr().unwrap().clone()).await.unwrap(); tokio::spawn(async move { rep.map(|req| async move { @@ -117,10 +114,7 @@ fn generate_requests(n_reqs: usize, msg_size: usize) -> Vec { fn reqrep_single_thread_tcp(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); let req = ReqSocket::with_options( Tcp::default(), @@ -150,11 +144,8 @@ fn reqrep_single_thread_tcp(c: &mut Criterion) { fn reqrep_multi_thread_tcp(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_multi_thread() - .worker_threads(4) - .enable_all() - .build() - .unwrap(); + let rt = + tokio::runtime::Builder::new_multi_thread().worker_threads(4).enable_all().build().unwrap(); let req = ReqSocket::with_options( Tcp::default(), @@ -184,10 +175,7 @@ fn reqrep_multi_thread_tcp(c: &mut Criterion) { fn reqrep_single_thread_ipc(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); + let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap(); let req = ReqSocket::new(Ipc::default()); let rep = RepSocket::new(Ipc::default()); @@ -213,11 +201,8 @@ fn reqrep_single_thread_ipc(c: &mut Criterion) { fn reqrep_multi_thread_ipc(c: &mut Criterion) { let _ = tracing_subscriber::fmt::try_init(); - let rt = tokio::runtime::Builder::new_multi_thread() - .worker_threads(4) - .enable_all() - .build() - .unwrap(); + let rt = + tokio::runtime::Builder::new_multi_thread().worker_threads(4).enable_all().build().unwrap(); let req = ReqSocket::new(Ipc::default()); let rep = RepSocket::new(Ipc::default()); diff --git a/msg/examples/durable.rs b/msg/examples/durable.rs index 42cff7e..088a352 100644 --- a/msg/examples/durable.rs +++ b/msg/examples/durable.rs @@ -5,27 +5,27 @@ use tokio::sync::oneshot; use tokio_stream::StreamExt; use msg::{tcp::Tcp, Authenticator, RepSocket, ReqOptions, ReqSocket}; -use tracing::Instrument; +use tracing::{error, info, info_span, instrument, warn, Instrument}; #[derive(Default)] struct Auth; impl Authenticator for Auth { fn authenticate(&self, id: &Bytes) -> bool { - tracing::info!("Auth request from: {:?}, authentication passed.", id); + info!("Auth request from: {:?}, authentication passed.", id); // Custom authentication logic true } } -#[tracing::instrument(name = "RepSocket")] +#[instrument(name = "RepSocket")] async fn start_rep() { // Initialize the reply socket (server side) with a transport // and an authenticator. 
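// A slightly more selective Authenticator than the accept-all `Auth` above, sketching where
// real validation would go: only IDs on a fixed allowlist are accepted. The trait and the
// `with_auth` hook are the ones this example already uses; the allowlist itself is made up.
use bytes::Bytes;
use msg::Authenticator;
use tracing::{info, warn};

struct AllowList {
    allowed: Vec<Bytes>,
}

impl Authenticator for AllowList {
    fn authenticate(&self, id: &Bytes) -> bool {
        if self.allowed.iter().any(|allowed| allowed == id) {
            info!("Client authenticated: {:?}", id);
            true
        } else {
            warn!("Rejecting unknown client: {:?}", id);
            false
        }
    }
}

// Usage mirrors the example that follows, e.g.:
// RepSocket::new(Tcp::default()).with_auth(AllowList { allowed: vec![Bytes::from("client1")] })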
let mut rep = RepSocket::new(Tcp::default()).with_auth(Auth); while rep.bind("0.0.0.0:4444").await.is_err() { rep = RepSocket::new(Tcp::default()).with_auth(Auth); - tracing::warn!("Failed to bind rep socket, retrying..."); + warn!("Failed to bind rep socket, retrying..."); tokio::time::sleep(Duration::from_secs(1)).await; } @@ -35,18 +35,13 @@ async fn start_rep() { loop { let req = rep.next().await.unwrap(); n_reqs += 1; - tracing::info!("Message: {:?}", req.msg()); + info!("Message: {:?}", req.msg()); let msg = String::from_utf8_lossy(req.msg()).to_string(); - let msg_id = msg - .split_whitespace() - .nth(1) - .unwrap() - .parse::() - .unwrap(); + let msg_id = msg.split_whitespace().nth(1).unwrap().parse::().unwrap(); if n_reqs == 5 { - tracing::warn!( + warn!( "RepSocket received the 5th request, dropping the request to trigger a timeout..." ); @@ -66,9 +61,7 @@ async fn main() { // and an identifier. This will implicitly turn on client authentication. let mut req = ReqSocket::with_options( Tcp::default(), - ReqOptions::default() - .timeout(Duration::from_secs(4)) - .auth_token(Bytes::from("client1")), + ReqOptions::default().timeout(Duration::from_secs(4)).auth_token(Bytes::from("client1")), ); let (tx, rx) = oneshot::channel(); @@ -79,9 +72,9 @@ async fn main() { req.connect("0.0.0.0:4444").await.unwrap(); for i in 0..10 { - tracing::info!("Sending request {i}..."); + info!("Sending request {i}..."); if i == 0 { - tracing::warn!("At this point the RepSocket is not running yet, so the request will block while \ + warn!("At this point the RepSocket is not running yet, so the request will block while \ the ReqSocket continues to establish a connection. The RepSocket will be started in 3 seconds."); } @@ -91,30 +84,30 @@ async fn main() { match req.request(Bytes::from(msg.clone())).await { Ok(res) => break res, Err(e) => { - tracing::error!("Request failed: {:?}, retrying...", e); + error!(err = ?e, "Request failed, retrying..."); tokio::time::sleep(Duration::from_millis(1000)).await; } } }; - tracing::info!("Response: {:?}", res); + info!("Response: {:?}", res); tokio::time::sleep(Duration::from_millis(1000)).await; } tx.send(true).unwrap(); } - .instrument(tracing::info_span!("ReqSocket")), + .instrument(info_span!("ReqSocket")), ); tokio::time::sleep(Duration::from_secs(3)).await; - tracing::info!("=========================="); - tracing::info!("Starting the RepSocket now"); - tracing::info!("=========================="); + info!("=========================="); + info!("Starting the RepSocket now"); + info!("=========================="); tokio::spawn(start_rep()); // Wait for the client to finish rx.await.unwrap(); - tracing::info!("DONE. Sent all 10 PINGS and received 10 PONGS."); + info!("DONE. 
Sent all 10 PINGS and received 10 PONGS."); } diff --git a/msg/examples/pubsub.rs b/msg/examples/pubsub.rs index 6b8c54c..13001aa 100644 --- a/msg/examples/pubsub.rs +++ b/msg/examples/pubsub.rs @@ -2,7 +2,7 @@ use bytes::Bytes; use futures::StreamExt; use std::time::Duration; use tokio::time::timeout; -use tracing::Instrument; +use tracing::{info, info_span, warn, Instrument}; use msg::{tcp::Tcp, PubOptions, PubSocket, SubOptions, SubSocket}; @@ -19,10 +19,8 @@ async fn main() { ); // Configure the subscribers with options - let mut sub1 = SubSocket::with_options( - Tcp::default(), - SubOptions::default().ingress_buffer_size(1024), - ); + let mut sub1 = + SubSocket::with_options(Tcp::default(), SubOptions::default().ingress_buffer_size(1024)); let mut sub2 = SubSocket::with_options( // TCP transport with blocking connect, usually connection happens in the background. @@ -34,59 +32,56 @@ async fn main() { pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); - tracing::info!("Publisher listening on: {}", pub_addr); + info!("Publisher listening on: {}", pub_addr); sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); sub2.connect(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); let t1 = tokio::spawn( async move { loop { - // Wait for a message to arrive, or timeout after 2 seconds. If the unsubscription was succesful, - // we should time out after the 10th message. + // Wait for a message to arrive, or timeout after 2 seconds. If the unsubscription + // was succesful, we should time out after the 10th message. 
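// Generic form of the receive loop used in these examples, as a sketch: poll a Stream with a
// per-message timeout and stop cleanly when either the timeout fires or the stream ends. Only
// tokio::time::timeout and StreamExt::next are assumed, both of which the example already uses.
use std::time::Duration;

use tokio::time::timeout;
use tokio_stream::{Stream, StreamExt};

async fn drain_with_timeout<S, T>(mut stream: S, per_msg: Duration, mut on_msg: impl FnMut(T))
where
    S: Stream<Item = T> + Unpin,
{
    loop {
        match timeout(per_msg, stream.next()).await {
            // A message arrived in time.
            Ok(Some(msg)) => on_msg(msg),
            // The stream ended.
            Ok(None) => break,
            // No message within `per_msg`; stop, just as the timeout branch below does.
            Err(_) => break,
        }
    }
}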
let Ok(Some(recv)) = timeout(Duration::from_millis(2000), sub1.next()).await else { - tracing::warn!("Timeout waiting for message, stopping sub1"); + warn!("Timeout waiting for message, stopping sub1"); break; }; let string = bytes_to_string(recv.clone().into_payload()); - tracing::info!("Received message: {}", string); + info!("Received message: {}", string); if string.contains("10") { - tracing::warn!("Received message 10, unsubscribing..."); + warn!("Received message 10, unsubscribing..."); sub1.unsubscribe("HELLO_TOPIC".to_string()).await.unwrap(); } } } - .instrument(tracing::info_span!("sub1")), + .instrument(info_span!("sub1")), ); let t2 = tokio::spawn( async move { loop { let Ok(Some(recv)) = timeout(Duration::from_millis(1000), sub2.next()).await else { - tracing::warn!("Timeout waiting for message, stopping sub2"); + warn!("Timeout waiting for message, stopping sub2"); break; }; let string = bytes_to_string(recv.clone().into_payload()); - tracing::info!("Received message: {}", string); + info!("Received message: {}", string); } } - .instrument(tracing::info_span!("sub2")), + .instrument(info_span!("sub2")), ); for i in 0..20 { tokio::time::sleep(Duration::from_millis(300)).await; - pub_socket - .publish("HELLO_TOPIC".to_string(), format!("Message {i}").into()) - .await - .unwrap(); + pub_socket.publish("HELLO_TOPIC".to_string(), format!("Message {i}").into()).await.unwrap(); } let _ = tokio::join!(t1, t2); diff --git a/msg/examples/pubsub_auth.rs b/msg/examples/pubsub_auth.rs index 59d7208..b29520f 100644 --- a/msg/examples/pubsub_auth.rs +++ b/msg/examples/pubsub_auth.rs @@ -3,7 +3,7 @@ use futures::StreamExt; use msg_socket::SubOptions; use std::time::Duration; use tokio::time::timeout; -use tracing::Instrument; +use tracing::{info, info_span, warn, Instrument}; use msg::{tcp::Tcp, Authenticator, PubSocket, SubSocket}; @@ -12,12 +12,12 @@ struct Auth; impl Authenticator for Auth { fn authenticate(&self, id: &Bytes) -> bool { - tracing::info!("Auth request from: {:?}", id); + info!("Auth request from: {:?}", id); if id.as_ref() == b"client1" { - tracing::info!("Client authenticated: {:?}", id); + info!("Client authenticated: {:?}", id); true } else { - tracing::warn!("Unknown client: {:?}", id); + warn!("Unknown client: {:?}", id); false } } @@ -48,59 +48,56 @@ async fn main() { pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); - tracing::info!("Publisher listening on: {}", pub_addr); + info!("Publisher listening on: {}", pub_addr); sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); sub2.connect(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); let t1 = tokio::spawn( async move { loop { - // Wait for a message to arrive, or timeout after 2 seconds. If the unsubscription was succesful, - // we should time out after the 10th message. + // Wait for a message to arrive, or timeout after 2 seconds. If the unsubscription + // was succesful, we should time out after the 10th message. 
let Ok(Some(recv)) = timeout(Duration::from_millis(2000), sub1.next()).await else { - tracing::warn!("Timeout waiting for message, stopping sub1"); + warn!("Timeout waiting for message, stopping sub1"); break; }; let string = bytes_to_string(recv.clone().into_payload()); - tracing::info!("Received message: {}", string); + info!("Received message: {}", string); if string.contains("10") { - tracing::warn!("Received message 10, unsubscribing..."); + warn!("Received message 10, unsubscribing..."); sub1.unsubscribe("HELLO_TOPIC".to_string()).await.unwrap(); } } } - .instrument(tracing::info_span!("sub1")), + .instrument(info_span!("sub1")), ); let t2 = tokio::spawn( async move { loop { let Ok(Some(recv)) = timeout(Duration::from_millis(1000), sub2.next()).await else { - tracing::warn!("Timeout waiting for message, stopping sub2"); + warn!("Timeout waiting for message, stopping sub2"); break; }; let string = bytes_to_string(recv.clone().into_payload()); - tracing::info!("Received message: {}", string); + info!("Received message: {}", string); } } - .instrument(tracing::info_span!("sub2")), + .instrument(info_span!("sub2")), ); for i in 0..20 { tokio::time::sleep(Duration::from_millis(300)).await; - pub_socket - .publish("HELLO_TOPIC".to_string(), format!("Message {i}").into()) - .await - .unwrap(); + pub_socket.publish("HELLO_TOPIC".to_string(), format!("Message {i}").into()).await.unwrap(); } let _ = tokio::join!(t1, t2); diff --git a/msg/examples/pubsub_compression.rs b/msg/examples/pubsub_compression.rs index 37cfaa9..8d8087c 100644 --- a/msg/examples/pubsub_compression.rs +++ b/msg/examples/pubsub_compression.rs @@ -2,7 +2,7 @@ use bytes::Bytes; use std::time::Duration; use tokio::time::timeout; use tokio_stream::StreamExt; -use tracing::Instrument; +use tracing::{info, info_span, warn, Instrument}; use msg::{compression::GzipCompressor, tcp::Tcp, PubSocket, SubSocket}; @@ -24,59 +24,56 @@ async fn main() { pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); - tracing::info!("Publisher listening on: {}", pub_addr); + info!("Publisher listening on: {}", pub_addr); sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); sub2.connect(pub_addr).await.unwrap(); sub2.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 2 connected and subscribed to HELLO_TOPIC"); let t1 = tokio::spawn( async move { loop { - // Wait for a message to arrive, or timeout after 2 seconds. If the unsubscription was succesful, - // we should time out after the 10th message. + // Wait for a message to arrive, or timeout after 2 seconds. If the unsubscription + // was succesful, we should time out after the 10th message. 
let Ok(Some(recv)) = timeout(Duration::from_millis(2000), sub1.next()).await else { - tracing::warn!("Timeout waiting for message, stopping sub1"); + warn!("Timeout waiting for message, stopping sub1"); break; }; let string = bytes_to_string(recv.clone().into_payload()); - tracing::info!("Received message: {}", string); + info!("Received message: {}", string); if string.contains("10") { - tracing::warn!("Received message 10, unsubscribing..."); + warn!("Received message 10, unsubscribing..."); sub1.unsubscribe("HELLO_TOPIC".to_string()).await.unwrap(); } } } - .instrument(tracing::info_span!("sub1")), + .instrument(info_span!("sub1")), ); let t2 = tokio::spawn( async move { loop { let Ok(Some(recv)) = timeout(Duration::from_millis(1000), sub2.next()).await else { - tracing::warn!("Timeout waiting for message, stopping sub2"); + warn!("Timeout waiting for message, stopping sub2"); break; }; let string = bytes_to_string(recv.clone().into_payload()); - tracing::info!("Received message: {}", string); + info!("Received message: {}", string); } } - .instrument(tracing::info_span!("sub2")), + .instrument(info_span!("sub2")), ); for i in 0..20 { tokio::time::sleep(Duration::from_millis(300)).await; - pub_socket - .publish("HELLO_TOPIC".to_string(), format!("Message {i}").into()) - .await - .unwrap(); + pub_socket.publish("HELLO_TOPIC".to_string(), format!("Message {i}").into()).await.unwrap(); } let _ = tokio::join!(t1, t2); diff --git a/msg/examples/quic_vs_tcp.rs b/msg/examples/quic_vs_tcp.rs index f720f70..f307ed6 100644 --- a/msg/examples/quic_vs_tcp.rs +++ b/msg/examples/quic_vs_tcp.rs @@ -2,6 +2,7 @@ use bytes::Bytes; use futures::StreamExt; use msg_transport::{quic::Quic, Transport}; use std::time::{Duration, Instant}; +use tracing::info; use msg::{tcp::Tcp, Address, PubOptions, PubSocket, SubOptions, SubSocket}; @@ -24,21 +25,19 @@ async fn run_tcp() { ); // Configure the subscribers with options - let mut sub1 = SubSocket::with_options( - Tcp::default(), - SubOptions::default().ingress_buffer_size(1024), - ); + let mut sub1 = + SubSocket::with_options(Tcp::default(), SubOptions::default().ingress_buffer_size(1024)); tracing::info!("Setting up the sockets..."); pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); - tracing::info!("Publisher listening on: {}", pub_addr); + info!("Publisher listening on: {}", pub_addr); sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); tokio::time::sleep(Duration::from_millis(1000)).await; @@ -66,12 +65,12 @@ async fn run_quic() { pub_socket.bind("127.0.0.1:0").await.unwrap(); let pub_addr = pub_socket.local_addr().unwrap(); - tracing::info!("Publisher listening on: {}", pub_addr); + info!("Publisher listening on: {}", pub_addr); sub1.connect(pub_addr).await.unwrap(); sub1.subscribe("HELLO_TOPIC".to_string()).await.unwrap(); - tracing::info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); + info!("Subscriber 1 connected and subscribed to HELLO_TOPIC"); tokio::time::sleep(Duration::from_millis(1000)).await; @@ -90,14 +89,11 @@ async fn run_transfer + Send + Unpin + 'static, A: Address>( for _ in 0..100 { let start = Instant::now(); - pub_socket - .publish("HELLO_TOPIC".to_string(), data.clone()) - .await - .unwrap(); + pub_socket.publish("HELLO_TOPIC".to_string(), data.clone()).await.unwrap(); let recv = 
sub_socket.next().await.unwrap(); let elapsed = start.elapsed(); - tracing::info!("{} transfer took {:?}", transport, elapsed); + info!("{} transfer took {:?}", transport, elapsed); assert_eq!(recv.into_payload(), data); tokio::time::sleep(Duration::from_secs(1)).await; diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..68c3c93 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,11 @@ +reorder_imports = true +imports_granularity = "Crate" +use_small_heuristics = "Max" +comment_width = 100 +wrap_comments = true +binop_separator = "Back" +trailing_comma = "Vertical" +trailing_semicolon = false +use_field_init_shorthand = true +format_code_in_doc_comments = true +doc_comment_code_block_width = 100
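// Illustration (hypothetical code, not from this repository) of the two settings responsible
// for most of the mechanical churn in this diff: `imports_granularity = "Crate"` merges
// imports into one `use` tree per crate, and `use_small_heuristics = "Max"` keeps constructs
// such as struct literals and method chains on a single line whenever they fit the 100-column
// width. The snippet below is already in the shape the new rustfmt.toml produces.
use std::{io, net::SocketAddr, time::Duration};

struct Endpoint {
    addr: SocketAddr,
    timeout: Duration,
}

impl Endpoint {
    fn new(addr: SocketAddr) -> io::Result<Self> {
        // Previously a four-line struct literal; it now stays on one line because it fits.
        Ok(Self { addr, timeout: Duration::from_secs(1) })
    }
}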