Skip to content

Commit

Permalink
Fix race condition in test-utils CloseableCursor implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Diggsey authored and jbr committed Feb 4, 2021
1 parent 4d8bb97 commit 54d8b67
Showing 1 changed file with 44 additions and 34 deletions.
78 changes: 44 additions & 34 deletions tests/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,47 @@ pub struct TestIO {
}

#[derive(Default)]
pub struct CloseableCursor {
data: RwLock<Vec<u8>>,
cursor: RwLock<usize>,
waker: RwLock<Option<Waker>>,
closed: RwLock<bool>,
struct CloseableCursorInner {
data: Vec<u8>,
cursor: usize,
waker: Option<Waker>,
closed: bool,
}

#[derive(Default)]
pub struct CloseableCursor(RwLock<CloseableCursorInner>);

impl CloseableCursor {
fn len(&self) -> usize {
self.data.read().unwrap().len()
pub fn len(&self) -> usize {
self.0.read().unwrap().data.len()
}

pub fn cursor(&self) -> usize {
self.0.read().unwrap().cursor
}

fn cursor(&self) -> usize {
*self.cursor.read().unwrap()
pub fn is_empty(&self) -> bool {
self.len() == 0
}

fn current(&self) -> bool {
self.len() == self.cursor()
pub fn current(&self) -> bool {
let inner = self.0.read().unwrap();
inner.data.len() == inner.cursor
}

fn close(&self) {
*self.closed.write().unwrap() = true;
pub fn close(&self) {
let mut inner = self.0.write().unwrap();
inner.closed = true;
if let Some(waker) = inner.waker.take() {
waker.wake();
}
}
}

impl Display for CloseableCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let data = &*self.data.read().unwrap();
let s = std::str::from_utf8(data).unwrap_or("not utf8");
let inner = self.0.read().unwrap();
let s = std::str::from_utf8(&inner.data).unwrap_or("not utf8");
write!(f, "{}", s)
}
}
Expand Down Expand Up @@ -163,13 +175,14 @@ impl TestIO {

impl Debug for CloseableCursor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let inner = self.0.read().unwrap();
f.debug_struct("CloseableCursor")
.field(
"data",
&std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"),
&std::str::from_utf8(&inner.data).unwrap_or("not utf8"),
)
.field("closed", &*self.closed.read().unwrap())
.field("cursor", &*self.cursor.read().unwrap())
.field("closed", &inner.closed)
.field("cursor", &inner.cursor)
.finish()
}
}
Expand All @@ -180,18 +193,17 @@ impl Read for &CloseableCursor {
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let len = self.len();
let cursor = self.cursor();
if cursor < len {
let data = &*self.data.read().unwrap();
let bytes_to_copy = buf.len().min(len - cursor);
buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]);
*self.cursor.write().unwrap() += bytes_to_copy;
let mut inner = self.0.write().unwrap();
if inner.cursor < inner.data.len() {
let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor);
buf[..bytes_to_copy]
.copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]);
inner.cursor += bytes_to_copy;
Poll::Ready(Ok(bytes_to_copy))
} else if *self.closed.read().unwrap() {
} else if inner.closed {
Poll::Ready(Ok(0))
} else {
*self.waker.write().unwrap() = Some(cx.waker().clone());
inner.waker = Some(cx.waker().clone());
Poll::Pending
}
}
Expand All @@ -203,11 +215,12 @@ impl Write for &CloseableCursor {
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if *self.closed.read().unwrap() {
let mut inner = self.0.write().unwrap();
if inner.closed {
Poll::Ready(Ok(0))
} else {
self.data.write().unwrap().extend_from_slice(buf);
if let Some(waker) = self.waker.write().unwrap().take() {
inner.data.extend_from_slice(buf);
if let Some(waker) = inner.waker.take() {
waker.wake();
}
Poll::Ready(Ok(buf.len()))
Expand All @@ -219,10 +232,7 @@ impl Write for &CloseableCursor {
}

fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if let Some(waker) = self.waker.write().unwrap().take() {
waker.wake();
}
*self.closed.write().unwrap() = true;
self.close();
Poll::Ready(Ok(()))
}
}
Expand Down

0 comments on commit 54d8b67

Please sign in to comment.