Skip to content

Commit

Permalink
Efficiency improvements
Browse files Browse the repository at this point in the history
- reads now require less copying in most situations
- internal code quality improvements
  • Loading branch information
alfred-hodler committed May 3, 2024
1 parent 9b123d1 commit 0f32170
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 122 deletions.
146 changes: 96 additions & 50 deletions src/message_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ pub struct MessageStream<T: Read + Write> {
last_write: Instant,
}

#[derive(Debug)]
pub enum ReadError {
/// A malformed message was received.
MalformedMessage,
/// End of stream has been reached (closed stream).
EndOfStream,
/// The stream produced an I/O error.
Error(io::Error),
}

impl<T: Read + Write> MessageStream<T> {
pub fn new(stream: T, config: StreamConfig) -> Self {
Self {
Expand All @@ -70,23 +80,65 @@ impl<T: Read + Write> MessageStream<T> {
}
}

/// Reads some bytes from the underlying reader and places them into the internal buffer for
/// future reassembly. Returns the number of bytes read.
pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let read = self.stream.read(buf)?;
self.rx_msg_buf.extend_from_slice(&buf[..read]);
Ok(read)
}
/// Receives as many messages as possible, either until reads would start blocking, or until
/// an error is encountered. Encountering an error means that the stream should be discarded.
pub fn read<M: Message, F: Fn(M)>(
&mut self,
rx_buf: &mut [u8],
on_msg: F,
) -> Result<(), ReadError> {
'read: loop {
match self.stream.read(rx_buf).map(|read| &rx_buf[..read]) {
Ok(&[]) => break 'read Err(ReadError::EndOfStream),

Ok(received) => {
if !self.rx_msg_buf.is_empty() {
self.rx_msg_buf.extend_from_slice(received);
'decode: loop {
if !self.rx_msg_buf.is_empty() {
match M::decode(&self.rx_msg_buf) {
Ok((message, consumed)) => {
self.rx_msg_buf.drain(..consumed);
on_msg(message);
}
Err(DecodeError::NotEnoughData) => break 'decode,
Err(DecodeError::MalformedMessage) => {
break 'read Err(ReadError::MalformedMessage)
}
}
} else {
break 'decode;
}
}
} else {
let mut next_from = 0;
'decode: loop {
let next = &received[next_from..];
if !next.is_empty() {
match M::decode(next) {
Ok((message, consumed)) => {
on_msg(message);
next_from += consumed;
}
Err(DecodeError::NotEnoughData) => {
self.rx_msg_buf.extend_from_slice(next);
break 'decode;
}
Err(DecodeError::MalformedMessage) => {
break 'read Err(ReadError::MalformedMessage);
}
}
} else {
break 'decode;
}
}
}
}

Err(err) if err.kind() == io::ErrorKind::WouldBlock => break 'read Ok(()),

/// Reassembles the next message from the internal buffer and returns it. Reassembly can fail
/// for several reasons: the buffer contains only a partial message, the message is malformed etc.
pub fn receive_message<M: Message>(&mut self) -> Result<M, DecodeError> {
match M::decode(&self.rx_msg_buf) {
Ok((message, consumed)) => {
self.rx_msg_buf.drain(..consumed);
Ok(message)
Err(err) => break 'read Err(ReadError::Error(err)),
}
Err(err) => Err(err),
}
}

Expand Down Expand Up @@ -292,6 +344,7 @@ mod queue_points {

#[cfg(test)]
mod test {
use std::cell::RefCell;
use std::io::Cursor;

use super::*;
Expand Down Expand Up @@ -323,50 +376,43 @@ mod test {
cursor.set_position(0);

let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());
let read = conn.read(&mut buf).unwrap();

assert_eq!(read, 16);
assert_eq!(conn.receive_message(), Ok(Ping(0)));
assert_eq!(conn.receive_message(), Ok(Ping(1)));
assert_eq!(
conn.receive_message::<Ping>(),
Err(DecodeError::NotEnoughData)
);

let received: RefCell<Vec<Ping>> = Default::default();
let err = conn.read(&mut buf, |message| {
received.borrow_mut().push(message);
});

assert_eq!(received.borrow()[0], Ping(0));
assert_eq!(received.borrow()[1], Ping(1));
assert!(matches!(err, Err(ReadError::EndOfStream)));
assert_eq!(conn.stream.position(), 16);
assert!(conn.rx_msg_buf.is_empty());
}

#[test]
fn reassemble_message_partial_reads() {
let mut buf = [0; 1024];
let mut cursor = Cursor::new(Vec::<u8>::new());
let mut cursor = Cursor::new(Vec::new());
let mut conn = MessageStream::new(&mut cursor, StreamConfig::default());
let mut serialized = Vec::new();
Ping(0).encode(&mut serialized);

let pos = conn.stream.position();
conn.stream.write_all(&serialized[..2]).unwrap();
conn.stream.set_position(pos);
assert_eq!(2, conn.read(&mut buf).unwrap());
assert_eq!(
conn.receive_message::<Ping>(),
Err(DecodeError::NotEnoughData)
);

let pos = conn.stream.position();
conn.stream.write_all(&serialized[2..5]).unwrap();
conn.stream.set_position(pos);
assert_eq!(3, conn.read(&mut buf).unwrap());
assert_eq!(
conn.receive_message::<Ping>(),
Err(DecodeError::NotEnoughData)
);

let pos = conn.stream.position();
conn.stream.write_all(&serialized[5..]).unwrap();
conn.stream.set_position(pos);
assert_eq!(3, conn.read(&mut buf).unwrap());
assert_eq!(conn.receive_message(), Ok(Ping(0)));
Ping(u64::MAX - 1).encode(&mut serialized);
Ping(u64::MAX).encode(&mut serialized);

let received: RefCell<Vec<Ping>> = Default::default();

conn.stream.get_mut().extend_from_slice(&serialized[..4]);
let _ = conn.read(&mut buf, |message| {
received.borrow_mut().push(message);
});
assert!(received.borrow().is_empty());
assert_eq!(conn.rx_msg_buf.len(), 4);

conn.stream.get_mut().extend_from_slice(&serialized[4..]);
let _ = conn.read(&mut buf, |message| {
received.borrow_mut().push(message);
});
assert_eq!(received.borrow()[0], Ping(u64::MAX - 1));
assert_eq!(received.borrow()[1], Ping(u64::MAX));
}

#[test]
Expand Down
92 changes: 20 additions & 72 deletions src/reactor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use slab::Slab;

use crate::connector::{self, Connector, IntoTarget};
use crate::message_stream::{self, MessageStream};
use crate::{Config, DecodeError, Message, PeerId};
use crate::{Config, Message, PeerId};

#[cfg(not(feature = "async"))]
use crossbeam_channel::{Receiver, Sender};
Expand Down Expand Up @@ -304,7 +304,7 @@ where

let mut connections: Slab<Connection<T>> = Slab::with_capacity(16);
let mut events = Events::with_capacity(1024);
let mut read_buf = [0; 1024 * 1024];
let mut read_buf = vec![0; 1024 * 1024];
let mut last_maintenance = Instant::now();
let mut token_map: IntMap<Token> = IntMap::new();
let mut next_peer_id: u64 = 0;
Expand Down Expand Up @@ -485,80 +485,30 @@ where
if event.is_readable() {
log::trace!("readable: peer {peer}");

'read: loop {
let read_result = connection.stream.read(&mut read_buf);

'decode: loop {
match connection.stream.receive_message::<M>() {
Ok(message) => {
log::debug!("read: peer {peer}: message={:?}", message);

sender.send(Event::Message { peer, message });
}

Err(DecodeError::MalformedMessage) => {
log::info!("read: peer {peer}: codec violation");

remove_stream(
poll.registry(),
&mut connections,
&mut token_map,
peer,
)?;

sender.send(Event::Disconnected {
peer,
reason: DisconnectReason::CodecViolation,
});

continue 'events;
}

Err(DecodeError::NotEnoughData) => break 'decode,
if let Err(err) = connection.stream.read(&mut read_buf, |message| {
log::debug!("read: peer {peer}: message={:?}", message);
sender.send(Event::Message { peer, message });
}) {
let reason = match err {
message_stream::ReadError::MalformedMessage => {
log::info!("read: peer {peer}: codec violation");
DisconnectReason::CodecViolation
}
}

match read_result {
Ok(0) => {
message_stream::ReadError::EndOfStream => {
log::debug!("peer {peer}: peer left");

remove_stream(
poll.registry(),
&mut connections,
&mut token_map,
peer,
)?;

sender.send(Event::Disconnected {
peer,
reason: DisconnectReason::Left,
});

continue 'events;
DisconnectReason::Left
}
message_stream::ReadError::Error(err) => {
log::debug!("write: peer {peer}: IO error: {err}");
DisconnectReason::Error(err)
}
};

Ok(_) => continue 'read,

Err(err) if would_block(&err) => break 'read,

Err(err) => {
log::debug!("peer {peer}: IO error: {err}");

remove_stream(
poll.registry(),
&mut connections,
&mut token_map,
peer,
)?;
remove_stream(poll.registry(), &mut connections, &mut token_map, peer)?;

sender.send(Event::Disconnected {
peer,
reason: DisconnectReason::Error(err),
});
sender.send(Event::Disconnected { peer, reason });

continue 'events;
}
}
continue 'events;
}
}

Expand All @@ -575,8 +525,6 @@ where
)?;
}

Err(err) if would_block(&err) => {}

Err(err) => {
log::debug!("write: peer {peer}: IO error: {err}");

Expand Down

0 comments on commit 0f32170

Please sign in to comment.