Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add try_write to TcpStream #176

Merged
merged 5 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,23 @@ impl TcpStream {
Ok(TcpStream::new(pair, rx))
}

/// Try to write a buffer to the stream, returning how many bytes were
/// written.
///
/// The function will attempt to write the entire contents of `buf`, but
/// only part of the buffer may be written.
///
/// This function is usually paired with `writable()`.
///
/// # Return
///
/// If data is successfully written, `Ok(n)` is returned, where `n` is the
/// number of bytes written. If the stream is not ready to write data,
/// `Err(io::ErrorKind::WouldBlock)` is returned.
pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
self.write_half.try_write(buf)
}

/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.read_half.pair.local)
Expand All @@ -110,6 +127,21 @@ impl TcpStream {
}
}

/// Waits for the socket to become writable.
///
/// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
/// paired with `try_write()`.
///
/// # Cancel safety
///
/// This method is cancel safe. Once a readiness event occurs, the method
/// will continue to return immediately until the readiness event is
/// consumed by an attempt to write that fails with `WouldBlock` or
/// `Poll::Pending`.
pub async fn writable(&self) -> Result<()> {
Ok(())
}

/// Splits a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
///
Expand Down Expand Up @@ -214,29 +246,28 @@ pub(crate) struct WriteHalf {
}

impl WriteHalf {
fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
fn try_write(&self, buf: &[u8]) -> Result<usize> {
if buf.remaining() == 0 {
return Poll::Ready(Ok(0));
return Ok(0);
}

if self.is_shutdown {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Broken pipe",
)));
return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"));
}

let res = World::current(|world| {
World::current(|world| {
let bytes = Bytes::copy_from_slice(buf);
let len = bytes.len();

let seq = self.seq(world)?;
self.send(world, Segment::Data(seq, bytes))?;

Ok(len)
});
})
}

Poll::Ready(res)
fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
Poll::Ready(self.try_write(buf))
}

fn poll_shutdown_priv(&mut self) -> Poll<Result<()>> {
Expand Down
5 changes: 1 addition & 4 deletions tests/async_send_sync.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
//! Copied over from:
//! https://github.com/tokio-rs/tokio/blob/master/tokio/tests/async_send_sync.rs
#![allow(dead_code)]

#[allow(dead_code)]
fn require_send<T: Send>(_t: &T) {}
#[allow(dead_code)]
fn require_sync<T: Sync>(_t: &T) {}
#[allow(dead_code)]
fn require_unpin<T: Unpin>(_t: &T) {}

#[allow(dead_code)]
struct Invalid;

trait AmbiguousIfSend<A> {
Expand Down
24 changes: 24 additions & 0 deletions tests/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1162,3 +1162,27 @@ fn exhaust_ephemeral_ports() {

_ = sim.run()
}

#[test]
fn try_write() -> Result {
let mut sim = Builder::new().build();
sim.client("client", async move {
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 1234)).await?;

tokio::spawn(async move {
let (socket, _) = listener.accept().await.unwrap();

let written = socket.try_write(b"hello!").unwrap();
assert_eq!(written, 6);
});

let mut socket = TcpStream::connect((Ipv4Addr::LOCALHOST, 1234)).await?;
let mut buf: [u8; 6] = [0; 6];
socket.read_exact(&mut buf).await?;
assert_eq!(&buf, b"hello!");

Ok(())
});

sim.run()
}