From b6eb058d9d2f92eedae0e6cf70aeb36aa728923d Mon Sep 17 00:00:00 2001 From: Innes Anderson-Morrison Date: Tue, 25 Feb 2025 08:00:37 +0000 Subject: [PATCH] getting the fsm approach to work without async-trait --- Cargo.lock | 12 - crates/ninep/Cargo.toml | 9 +- crates/ninep/examples/async_fsm.rs | 42 +- crates/ninep/src/lib.rs | 2 + crates/ninep/src/sansio/mod.rs | 15 + crates/ninep/src/sansio/protocol.rs | 719 ++++++---------------------- crates/ninep/src/sync/mod.rs | 34 +- crates/ninep/src/tokio/mod.rs | 93 ++-- 8 files changed, 286 insertions(+), 640 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c00f715..7eb56d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,17 +86,6 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" -[[package]] -name = "async-trait" -version = "0.1.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -457,7 +446,6 @@ dependencies = [ name = "ninep" version = "0.3.0" dependencies = [ - "async-trait", "bitflags 2.8.0", "simple_test_case", "tokio", diff --git a/crates/ninep/Cargo.toml b/crates/ninep/Cargo.toml index c71a440..262c2cc 100644 --- a/crates/ninep/Cargo.toml +++ b/crates/ninep/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ninep" -version = "0.3.0" +version = "0.4.0" edition = "2021" authors = ["sminez "] license = "MIT" @@ -15,13 +15,12 @@ include = [ [features] default = ["tokio"] -tokio = ["dep:tokio", "dep:async-trait"] +tokio = ["dep:tokio"] [dependencies] -async-trait = { version = "0.1.86", optional = true } bitflags = "2.6" -tokio = { version = "1.43.0", features = ["net", "io-util", "sync"], optional = true } +tokio = { version = "1.43.0", features = ["macros", "net", "io-util", "rt", "sync"], optional = true } [dev-dependencies] simple_test_case = "1" -tokio = { version = "1.43.0", features = ["time", "macros", "rt-multi-thread", "net", "io-util", "sync"] } +tokio = { version = "1.43.0", features = ["macros", "net", "io-util", "rt", "sync", "time", "rt-multi-thread", "net", "io-util"] } diff --git a/crates/ninep/examples/async_fsm.rs b/crates/ninep/examples/async_fsm.rs index 9179f03..ce0b76d 100644 --- a/crates/ninep/examples/async_fsm.rs +++ b/crates/ninep/examples/async_fsm.rs @@ -1,11 +1,10 @@ //! This is a little exploration of using async/await + a dummy Waker to simplify writing sans-io //! state machine code. use std::{ - cell::RefCell, + cell::UnsafeCell, future::{Future, IntoFuture}, io::{Cursor, Read}, pin::{pin, Pin}, - rc::Rc, sync::Arc, task::{Context, Poll, Wake, Waker}, }; @@ -24,17 +23,17 @@ async fn main() { ]); println!(">> reading using std::io::Read"); - let s: String = read_9p_sync_from_bytes(&mut cur); + let s: String = read_9p_sync(&mut cur); println!(" got val: {s:?}\n"); cur.set_position(0); println!(">> reading using tokio::io::AsyncRead"); - let s: String = read_9p_async_from_bytes(&mut cur).await; + let s: String = read_9p_async(&mut cur).await; println!(" got val: {s:?}"); } -fn read_9p_sync_from_bytes(r: &mut R) -> T +fn read_9p_sync(r: &mut R) -> T where T: Read9p, R: Read, @@ -48,18 +47,18 @@ where loop { match fut.as_mut().poll(&mut context) { Poll::Ready(val) => return val, - Poll::Pending => { - let n = s.0.borrow().n; + Poll::Pending => unsafe { + let n = (*s.0.get()).n; println!("{n} bytes requested"); let mut buf = vec![0; n]; r.read_exact(&mut buf).unwrap(); - s.0.borrow_mut().buf = Some(buf); - } + (*s.0.get()).buf = Some(buf); + }, } } } -async fn read_9p_async_from_bytes(r: &mut R) -> T +async fn read_9p_async(r: &mut R) -> T where T: Read9p, R: AsyncRead + Unpin, @@ -73,13 +72,13 @@ where loop { match fut.as_mut().poll(&mut context) { Poll::Ready(val) => return val, - Poll::Pending => { - let n = s.0.borrow().n; + Poll::Pending => unsafe { + let n = (*s.0.get()).n; println!("{n} bytes requested"); let mut buf = vec![0; n]; r.read_exact(&mut buf).await.unwrap(); - s.0.borrow_mut().buf = Some(buf); - } + (*s.0.get()).buf = Some(buf); + }, } } } @@ -106,8 +105,14 @@ impl Future for Yield { } } +/// Shared state between a [NineP] impl and a parent read loop that is performing IO. #[derive(Default, Debug, Clone)] -struct State(Rc>); +pub struct State(pub(crate) Arc>); + +// SAFETY: StateInner is only accessable in this crate +unsafe impl Send for State {} +// SAFETY: StateInner is only accessable in this crate +unsafe impl Sync for State {} #[derive(Default, Debug)] struct StateInner { @@ -119,20 +124,19 @@ struct StateInner { /// so it can perform IO and provide the requested data. macro_rules! request_bytes { ($s:expr, $n:expr) => {{ - $s.0.borrow_mut().n = $n; + (*$s.0.get()).n = $n; Yield(false).await; - $s.0.borrow_mut().buf.take().unwrap() + (*$s.0.get()).buf.take().unwrap() }}; } /// # Safety /// The read method of this trait requires that you only yield view the [request_bytes] macro. -#[allow(async_fn_in_trait)] unsafe trait Read9p { /// # Safety /// Implementations of `read` need to ensure that the only await points they contain are /// from calls to the [request_bytes] macro. - async unsafe fn read(state: State) -> Self; + unsafe fn read(state: State) -> impl Future + Send; } #[allow(async_fn_in_trait)] diff --git a/crates/ninep/src/lib.rs b/crates/ninep/src/lib.rs index 751cc11..07874ed 100644 --- a/crates/ninep/src/lib.rs +++ b/crates/ninep/src/lib.rs @@ -14,6 +14,8 @@ pub mod fs; pub mod sansio; pub mod sync; + +#[cfg(feature = "tokio")] pub mod tokio; /// A simple result type for errors returned from this crate diff --git a/crates/ninep/src/sansio/mod.rs b/crates/ninep/src/sansio/mod.rs index e71a6aa..1dd296f 100644 --- a/crates/ninep/src/sansio/mod.rs +++ b/crates/ninep/src/sansio/mod.rs @@ -5,6 +5,10 @@ use crate::{ sansio::protocol::{Rdata, Rmessage}, Result, }; +use std::{ + sync::Arc, + task::{Wake, Waker}, +}; pub mod protocol; pub mod server; @@ -17,3 +21,14 @@ impl From<(u16, Result)> for Rmessage { } } } + +struct StubWaker; +impl Wake for StubWaker { + fn wake(self: Arc) {} + fn wake_by_ref(self: &Arc) {} +} + +/// A no-op waker that is just used to create a context for driving a NineP read loop. +pub(crate) fn stub_waker() -> Waker { + Waker::from(Arc::new(StubWaker)) +} diff --git a/crates/ninep/src/sansio/protocol.rs b/crates/ninep/src/sansio/protocol.rs index 1fa997a..1419d77 100644 --- a/crates/ninep/src/sansio/protocol.rs +++ b/crates/ninep/src/sansio/protocol.rs @@ -1,10 +1,16 @@ //! Sans-io 9p protocol implementation //! //! http://man.cat-v.org/plan_9/5/ +use crate::sync::SyncNineP; use std::{ + cell::UnsafeCell, fmt, - io::{self, ErrorKind, Read}, + future::Future, + io::{self, Cursor, ErrorKind}, mem::size_of, + pin::Pin, + sync::Arc, + task::{Context, Poll}, }; /// The size of variable length data is denoted using a u16 so anything longer @@ -45,9 +51,6 @@ impl fmt::Display for WriteError { /// Each message consists of a sequence of bytes. Two-, four-, and eight-byte fields hold /// unsigned integers represented in little-endian order (least significant byte first). pub trait NineP: Sized { - /// The [Read9p] implementation used to decode an instance of this type from a bytestream - type Reader: Read9p; - /// Number of bytes required to encode fn n_bytes(&self) -> usize; @@ -55,20 +58,61 @@ pub trait NineP: Sized { /// ensure is sized to be at least [NineP::n_bytes]. fn write_bytes(&self, buf: &mut [u8]) -> Result<(), WriteError>; - /// Construct a new reader for parsing this type from a source of bytes - fn reader() -> Self::Reader; + /// Serialize into a byte buffer ready for transmission. + fn write_9p_bytes(&self) -> Result, WriteError> { + let mut buf = vec![0; self.n_bytes()]; + self.write_bytes(&mut buf)?; + + Ok(buf) + } + + /// This is not a normal async function. It is used to set up a sans-io state machine that + /// can be driven by a concrete implementation. + /// + /// # Safety + /// Implementations of `read` need to ensure that the only await points they contain are + /// from calls to the [request_bytes] macro. + unsafe fn read(state: &State) -> impl Future> + Send; } -/// A paired helper type for decoding a [Format9p] type from a bytestream. -pub trait Read9p: Sized + Send { - /// The parent [Format9p] type being decoded into - type T: NineP; +/// Helper struct for awaiting a Future that returns pending once so we can return control to the +/// poll loop and perform IO. +struct Yield(bool); +impl Future for Yield { + type Output = (); + fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<()> { + if self.0 { + Poll::Ready(()) + } else { + self.0 = true; + Poll::Pending + } + } +} + +/// Shared state between a [NineP] impl and a parent read loop that is performing IO. +#[derive(Default, Debug, Clone)] +pub struct State(pub(crate) Arc>); - /// The number of bytes that need to be passed to the next call to [Read9p::accept_bytes]. - fn needs_bytes(&self) -> usize; +// SAFETY: StateInner is only accessable in this crate +unsafe impl Send for State {} +// SAFETY: StateInner is only accessable in this crate +unsafe impl Sync for State {} + +#[derive(Default, Debug)] +pub(crate) struct StateInner { + pub(crate) n: usize, + pub(crate) buf: Option>, +} - /// Accept the requested number of bytes and return a new [NinepReader] state machine. - fn accept_bytes(self, bytes: &[u8]) -> io::Result>; +/// Request a specific number of bytes from the parent poll loop and then yield to that poll loop +/// so it can perform IO and provide the requested data. +macro_rules! request_bytes { + ($s:expr, $n:expr) => {{ + (*$s.0.get()).n = $n; + Yield(false).await; + (*$s.0.get()).buf.take().unwrap() + }}; } /// wrapper around uX::from_le_bytes that accepts a slice rather than a fixed size array @@ -79,52 +123,12 @@ macro_rules! from_le_bytes { }; } -/// Attempt to read a [NineP] value from a byte buffer that contains sufficient data without -/// requiring further IO. -pub fn try_read_9p_bytes(mut bytes: &[u8]) -> io::Result { - let mut nr = NinepReader::Pending(T::reader()); - - loop { - match nr { - NinepReader::Pending(r) => { - let n = T::Reader::needs_bytes(&r); - nr = r.accept_bytes(&bytes[0..n])?; - bytes = &bytes[n..]; - } - - NinepReader::Complete(t) => return Ok(t), - } - } -} - -/// Serialize `t` into a byte buffer ready for transmission. -pub fn write_9p_bytes(t: &T) -> Result, WriteError> { - let mut buf = vec![0; t.n_bytes()]; - t.write_bytes(&mut buf)?; - - Ok(buf) -} - -/// State machine enum for writing concrete readers from a [Format9p] type. -#[derive(Debug)] -pub enum NinepReader -where - T: NineP, -{ - /// A reader that requires more input to complete - Pending(T::Reader), - /// A reader that completed successfully - Complete(T), -} - // Unsigned integer types can all be treated the same way so we stamp them out using a macro. // They are written and read in their little-endian byte form. macro_rules! impl_u { ($($ty:ty),+) => { $( impl NineP for $ty { - type Reader = $ty; - fn n_bytes(&self) -> usize { size_of::<$ty>() } @@ -134,20 +138,9 @@ macro_rules! impl_u { Ok(()) } - fn reader() -> $ty { - 0 - } - } - - impl Read9p for $ty { - type T = $ty; - - fn needs_bytes(&self) -> usize { - size_of::<$ty>() - } - - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - Ok(NinepReader::Complete(from_le_bytes!($ty, bytes))) + async unsafe fn read(state: &State) -> io::Result<$ty> { + let buf = request_bytes!(state, size_of::<$ty>()); + Ok(from_le_bytes!($ty, buf)) } } )+ @@ -167,8 +160,6 @@ impl_u!(u8, u16, u32, u64); // which include no final zero byte. The NUL character is illegal in all text strings // in 9P, and is therefore excluded from file names, user names, and so on. impl NineP for String { - type Reader = StringReader; - fn n_bytes(&self) -> usize { size_of::() + self.len() } @@ -185,50 +176,12 @@ impl NineP for String { Ok(()) } - fn reader() -> StringReader { - StringReader::Start - } -} - -#[derive(Debug)] -#[allow(missing_docs)] -pub enum StringReader { - Start, - WithLen(usize), -} - -impl Read9p for StringReader { - type T = String; - - fn needs_bytes(&self) -> usize { - match self { - Self::Start => size_of::(), - Self::WithLen(len) => *len, - } - } - - fn accept_bytes(self, mut bytes: &[u8]) -> io::Result> { - match self { - Self::Start => { - let len = from_le_bytes!(u16, bytes) as usize; - Ok(NinepReader::Pending(Self::WithLen(len))) - } - - Self::WithLen(len) => { - let mut s = String::with_capacity(len); - bytes.read_to_string(&mut s)?; - let actual = s.len(); + async unsafe fn read(state: &State) -> io::Result { + let buf = request_bytes!(state, size_of::()); + let len = from_le_bytes!(u16, buf) as usize; + let buf = request_bytes!(state, len); - if actual < len { - return Err(io::Error::new( - ErrorKind::UnexpectedEof, - format!("unexpected end of string: wanted {len}, got {actual}"), - )); - } - - Ok(NinepReader::Complete(s)) - } - } + String::from_utf8(buf).map_err(|e| io::Error::new(ErrorKind::InvalidData, e.to_string())) } } @@ -238,8 +191,6 @@ impl Read9p for StringReader { // Data items of larger or variable lengths are represented by a two-byte field specifying // a count, n, followed by n bytes of data. impl NineP for Vec { - type Reader = VecReader; - fn n_bytes(&self) -> usize { size_of::() + self.iter().map(|t| t.n_bytes()).sum::() } @@ -261,55 +212,16 @@ impl NineP for Vec { Ok(()) } - fn reader() -> VecReader { - VecReader::Start - } -} - -#[derive(Debug)] -#[allow(missing_docs)] -pub enum VecReader -where - T: NineP + fmt::Debug + Send, -{ - Start, - Reading(usize, T::Reader, Vec), -} + async unsafe fn read(state: &State) -> io::Result { + let buf = request_bytes!(state, size_of::()); + let len = from_le_bytes!(u16, buf) as usize; -impl Read9p for VecReader { - type T = Vec; - - fn needs_bytes(&self) -> usize { - match self { - Self::Start => size_of::(), - Self::Reading(_, r, _) => r.needs_bytes(), + let mut buf = Vec::with_capacity(len); + for _ in 0..len { + buf.push(T::read(state).await?); } - } - - fn accept_bytes(self, bytes: &[u8]) -> io::Result>> { - match self { - Self::Start => { - let len = from_le_bytes!(u16, bytes) as usize; - let buf = Vec::with_capacity(len); - let r = T::reader(); - Ok(NinepReader::Pending(VecReader::Reading(len, r, buf))) - } - - Self::Reading(n, r, mut buf) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(VecReader::Reading(n, r, buf))), - - NinepReader::Complete(t) => { - buf.push(t); - if n == 1 { - Ok(NinepReader::Complete(buf)) - } else { - let r = T::reader(); - Ok(NinepReader::Pending(VecReader::Reading(n - 1, r, buf))) - } - } - }, - } + Ok(buf) } } @@ -350,7 +262,7 @@ impl TryFrom for Vec { let n = size_of::(); loop { - match try_read_9p_bytes::(bytes) { + match RawStat::read_from(&mut bytes) { Ok(rs) => { buf.push(rs); bytes = &bytes[n..]; @@ -365,8 +277,6 @@ impl TryFrom for Vec { } impl NineP for Data { - type Reader = DataReader; - fn n_bytes(&self) -> usize { size_of::() + self.0.len() } @@ -383,54 +293,17 @@ impl NineP for Data { Ok(()) } - fn reader() -> DataReader { - DataReader::Start - } -} - -#[derive(Debug)] -#[allow(missing_docs)] -pub enum DataReader { - Start, - WithLen(usize), -} - -impl Read9p for DataReader { - type T = Data; - - fn needs_bytes(&self) -> usize { - match self { - Self::Start => size_of::(), - Self::WithLen(len) => *len, + async unsafe fn read(state: &State) -> io::Result { + let buf = request_bytes!(state, size_of::()); + let len = from_le_bytes!(u32, buf) as usize; + if len > MAX_DATA_LEN { + return Err(io::Error::new( + ErrorKind::InvalidData, + format!("data field too long: max={MAX_DATA_LEN} len={len}"), + )); } - } - - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - match self { - DataReader::Start => { - let len = from_le_bytes!(u32, bytes) as usize; - if len > MAX_DATA_LEN { - return Err(io::Error::new( - ErrorKind::InvalidData, - format!("data field too long: max={MAX_DATA_LEN} len={len}"), - )); - } - - Ok(NinepReader::Pending(DataReader::WithLen(len))) - } - - DataReader::WithLen(len) => { - let actual = bytes.len(); - if actual < len { - return Err(io::Error::new( - ErrorKind::UnexpectedEof, - format!("unexpected end of data: wanted {len}, got {actual}"), - )); - } - Ok(NinepReader::Complete(Data(bytes.to_vec()))) - } - } + Ok(Data(request_bytes!(state, len))) } } @@ -480,8 +353,6 @@ macro_rules! write_fields { } impl NineP for RawStat { - type Reader = RawStatReader; - fn n_bytes(&self) -> usize { // 2 2 4 13 4 4 4 8 -> 41 41 + self.name.n_bytes() + self.uid.n_bytes() + self.gid.n_bytes() + self.muid.n_bytes() @@ -493,95 +364,37 @@ impl NineP for RawStat { ) } - fn reader() -> RawStatReader { - RawStatReader::Start - } -} - -#[derive(Debug)] -#[allow(missing_docs)] -pub enum RawStatReader { - Start, - Name(RawStat, StringReader), - Uid(RawStat, StringReader), - Gid(RawStat, StringReader), - Muid(RawStat, StringReader), -} - -impl Read9p for RawStatReader { - type T = RawStat; - - fn needs_bytes(&self) -> usize { - match self { - Self::Start => 41, - Self::Name(_, r) => r.needs_bytes(), - Self::Uid(_, r) => r.needs_bytes(), - Self::Gid(_, r) => r.needs_bytes(), - Self::Muid(_, r) => r.needs_bytes(), - } - } - - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - match self { - Self::Start => { - let size = from_le_bytes!(u16, bytes); - let ty = from_le_bytes!(u16, &bytes[2..]); - let dev = from_le_bytes!(u32, &bytes[4..]); - let qid: Qid = try_read_9p_bytes(&bytes[8..])?; - let mode = from_le_bytes!(u32, &bytes[21..]); - let atime = from_le_bytes!(u32, &bytes[25..]); - let mtime = from_le_bytes!(u32, &bytes[29..]); - let length = from_le_bytes!(u64, &bytes[33..]); - let rs = RawStat { - size, - ty, - dev, - qid, - mode, - atime, - mtime, - length, - name: String::default(), - uid: String::default(), - gid: String::default(), - muid: String::default(), - }; - - Ok(NinepReader::Pending(Self::Name(rs, StringReader::Start))) - } - - Self::Name(mut rs, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::Name(rs, r))), - NinepReader::Complete(s) => { - rs.name = s; - Ok(NinepReader::Pending(Self::Uid(rs, StringReader::Start))) - } - }, - - Self::Uid(mut rs, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::Uid(rs, r))), - NinepReader::Complete(s) => { - rs.uid = s; - Ok(NinepReader::Pending(Self::Gid(rs, StringReader::Start))) - } - }, - - Self::Gid(mut rs, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::Gid(rs, r))), - NinepReader::Complete(s) => { - rs.gid = s; - Ok(NinepReader::Pending(Self::Muid(rs, StringReader::Start))) - } - }, - - Self::Muid(mut rs, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::Muid(rs, r))), - NinepReader::Complete(s) => { - rs.muid = s; - Ok(NinepReader::Complete(rs)) - } - }, - } + async unsafe fn read(state: &State) -> io::Result { + let buf = request_bytes!(state, 41); + let bytes = buf.as_slice(); + + let size = from_le_bytes!(u16, bytes); + let ty = from_le_bytes!(u16, &bytes[2..]); + let dev = from_le_bytes!(u32, &bytes[4..]); + let qid = Qid::read_from(&mut &bytes[8..])?; + let mode = from_le_bytes!(u32, &bytes[21..]); + let atime = from_le_bytes!(u32, &bytes[25..]); + let mtime = from_le_bytes!(u32, &bytes[29..]); + let length = from_le_bytes!(u64, &bytes[33..]); + let name = String::read(state).await?; + let uid = String::read(state).await?; + let gid = String::read(state).await?; + let muid = String::read(state).await?; + + Ok(RawStat { + size, + ty, + dev, + qid, + mode, + atime, + mtime, + length, + name, + uid, + gid, + muid, + }) } } @@ -606,201 +419,22 @@ macro_rules! impl_message_datatype { pub $field: $ty, )* } - impl_message_datatype!(@tuple $struct $reader $($field: $ty),*); - }; - - // No fields - (@tuple $struct:ident $reader:ident) => { - impl NineP for $struct { - type Reader = $reader; - fn n_bytes(&self) -> usize { 0 } - fn write_bytes(&self, buf: &mut [u8]) -> Result<(), WriteError> { Ok(()) } - fn reader() -> $reader { $reader($ty::reader()) } - } - - #[derive(Debug)] - #[allow(missing_docs)] - pub struct $reader; - impl Read9p for $reader { - type T = $struct; - fn needs_bytes(&self) -> usize { 0 } - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - debug_assert!(bytes.is_empty()); - Ok(NinepReader::Complete($struct { })) - } - } - }; - - // Single field - (@tuple $struct:ident $reader:ident $field:ident: $ty:ty) => { - impl NineP for $struct { - type Reader = $reader; - fn n_bytes(&self) -> usize { self.$field.n_bytes() } - fn write_bytes(&self, buf: &mut [u8]) -> Result<(), WriteError> { self.$field.write_bytes(buf) } - fn reader() -> $reader { $reader(<$ty as NineP>::reader()) } - } - - #[derive(Debug)] - #[allow(missing_docs)] - pub struct $reader($ty::Reader); - impl Read9p for $reader { - type T = $struct; - fn needs_bytes(&self) -> usize { self.0.needs_bytes() } - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - match self.0.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self(r))), - NinepReader::Complete($field) => Ok(NinepReader::Complete($struct { $field })) - } - } - } - }; - - // Two fields - (@tuple $struct:ident $reader:ident $f1:ident: $t1:ty, $f2:ident: $t2:ty) => { - impl NineP for $struct { - type Reader = $reader; - fn n_bytes(&self) -> usize { - self.$f1.n_bytes() + self.$f2.n_bytes() - } - fn write_bytes(&self, mut buf: &mut [u8]) -> Result<(), WriteError> { - write_fields!(buf, self, $f1, $f2) - } - fn reader() -> $reader { $reader::T1(<$t1 as NineP>::reader()) } - } - - #[derive(Debug)] - #[allow(missing_docs)] - pub enum $reader { - T1($t1::Reader), - T2($t1, $t2::Reader), - } - impl Read9p for $reader { - type T = $struct; - fn needs_bytes(&self) -> usize { - match self { - Self::T1(r) => r.needs_bytes(), - Self::T2(_, r) => r.needs_bytes(), - } - } - fn accept_bytes(self, bytes: &[u8]) -> io::Result { - match self { - Self::T1(r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T1(r))), - NinepReader::Complete(t1) => Ok(NinepReader::Pending(Self::T2(t1, <$t2 as NineP>::reader()))), - }, - Self::T2($f1, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T2($f1, r))), - NinepReader::Complete($f2) => Ok(NinepReader::Complete($struct { $f1, $f2 })), - }, - } - } - } - }; - // Three fields - (@tuple $struct:ident $reader:ident $f1:ident: $t1:ty, $f2:ident: $t2:ty, $f3:ident: $t3:ty) => { impl NineP for $struct { - type Reader = $reader; fn n_bytes(&self) -> usize { - self.$f1.n_bytes() + self.$f2.n_bytes() + self.$f3.n_bytes() - } - fn write_bytes(&self, mut buf: &mut [u8]) -> Result<(), WriteError> { - write_fields!(buf, self, $f1, $f2, $f3) - } - fn reader() -> $reader { $reader::T1(<$t1 as NineP>::reader()) } - } - - #[derive(Debug)] - #[allow(missing_docs)] - pub enum $reader { - T1(<$t1 as NineP>::Reader), - T2($t1, <$t2 as NineP>::Reader), - T3($t1, $t2, <$t3 as NineP>::Reader), - } - impl Read9p for $reader { - type T = $struct; - fn needs_bytes(&self) -> usize { - match self { - Self::T1(r) => r.needs_bytes(), - Self::T2(_, r) => r.needs_bytes(), - Self::T3(_, _, r) => r.needs_bytes(), - } - } - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - match self { - Self::T1(r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T1(r))), - NinepReader::Complete(t1) => Ok(NinepReader::Pending(Self::T2(t1, <$t2 as NineP>::reader()))), - }, - Self::T2(t1, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T2(t1, r))), - NinepReader::Complete(t2) => { - Ok(NinepReader::Pending(Self::T3(t1, t2, <$t3 as NineP>::reader()))) - } - }, - Self::T3($f1, $f2, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T3($f1, $f2, r))), - NinepReader::Complete($f3) => Ok(NinepReader::Complete($struct { $f1, $f2, $f3 })), - }, - } + #[allow(unused_mut)] + let mut n = 0; + $(n += self.$field.n_bytes();)* + n } - } - }; - // Four fields - (@tuple $struct:ident $reader:ident $f1:ident: $t1:ty, $f2:ident: $t2:ty, $f3:ident: $t3:ty, $f4:ident: $t4:ty) => { - impl NineP for $struct { - type Reader = $reader; - fn n_bytes(&self) -> usize { - self.$f1.n_bytes() + self.$f2.n_bytes() + self.$f3.n_bytes() + self.$f4.n_bytes() - } fn write_bytes(&self, mut buf: &mut [u8]) -> Result<(), WriteError> { - write_fields!(buf, self, $f1, $f2, $f3, $f4) + write_fields!(buf, self, $($field),*) } - fn reader() -> $reader { $reader::T1(<$t1 as NineP>::reader()) } - } - #[derive(Debug)] - #[allow(missing_docs)] - pub enum $reader { - T1(T1::Reader), - T2(T1, T2::Reader), - T3(T1, T2, T3::Reader), - T4(T1, T2, T3, T4::Reader), - } - impl Read9p for $reader { - type T = $struct; - fn needs_bytes(&self) -> usize { - match self { - Self::T1(r) => r.needs_bytes(), - Self::T2(_, r) => r.needs_bytes(), - Self::T3(_, _, r) => r.needs_bytes(), - Self::T4(_, _, _, r) => r.needs_bytes(), - } - } - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - match self { - Self::T1(r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T1(r))), - NinepReader::Complete(t1) => Ok(NinepReader::Pending(Self::T2(t1, <$t2 as NineP>::reader()))), - }, - Self::T2(t1, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T2(t1, r))), - NinepReader::Complete(t2) => { - Ok(NinepReader::Pending(Self::T3(t1, t2, <$t3 as NineP>::reader()))) - } - }, - Self::T3(t1, t2, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T3(t1, t2, r))), - NinepReader::Complete(t3) => { - Ok(NinepReader::Pending(Self::T4(t1, t2, t3, <$t4 as NineP>::reader()))) - } - }, - Self::T4($f1, $f2, $f3, r) => match r.accept_bytes(bytes)? { - NinepReader::Pending(r) => Ok(NinepReader::Pending(Self::T4($f1, $f2, $f3, r))), - NinepReader::Complete($f4) => Ok(NinepReader::Complete($struct { $f1, $f2, $f3, $f4 })), - }, - } + async unsafe fn read(state: &State) -> io::Result { + $(let $field = <$ty>::read(&state).await?;)* + Ok($struct { $($field),* }) } } }; @@ -883,14 +517,12 @@ impl MessageType { /// Helper for implementing Tmessage and Rmessage macro_rules! impl_message_format { ( - $message_ty:ident, $reader:ident, $enum_ty:ident, $err:expr; + $message_ty:ident, $enum_ty:ident, $err:expr; $($enum_variant:ident => $message_variant:ident { $($field:ident: $ty:ty,)* })+ ) => { impl NineP for $message_ty { - type Reader = $reader; - fn n_bytes(&self) -> usize { let content_size = match &self.content { $( @@ -932,69 +564,31 @@ macro_rules! impl_message_format { Ok(()) } - fn reader() -> $reader { - $reader::Start - } - } + #[allow(unused_assignments)] + async unsafe fn read(state: &State) -> io::Result { + let buf = request_bytes!(state, size_of::()); + let len = from_le_bytes!(u32, buf) as usize; -// pub trait Read9p: Sized { -// type T: NineP; -// fn needs_bytes(&self) -> usize; -// fn accept_bytes(self, bytes: &[u8]) -> io::Result>; -// } - - #[derive(Debug)] - #[allow(missing_docs)] - pub enum $reader { - Start, - WithSize(usize), - } + let bytes = request_bytes!(state, len-4); + let ty = from_le_bytes!(u8, &bytes); + let tag = from_le_bytes!(u16, &bytes[1..]); + let mut cur = Cursor::new(bytes); + cur.set_position(3); - impl Read9p for $reader { - type T = $message_ty; + let content = match MessageType(ty) { + $( + MessageType::$message_variant => $enum_ty::$enum_variant { + $($field: <$ty>::read_from(&mut cur)?),* + }, + )+ - fn needs_bytes(&self) -> usize { - match self { - Self::Start => size_of::(), - // the size field includes the number of bytes for the field itself so we - // trim that off before decoding the rest of the message - Self::WithSize(n) => n - 4, - } - } + MessageType(ty) => return Err(io::Error::new( + ErrorKind::InvalidData, + format!($err, ty), + )), + }; - #[allow(unused_assignments)] - fn accept_bytes(self, bytes: &[u8]) -> io::Result> { - match self { - Self::Start => { - let size = from_le_bytes!(u32, bytes) as usize; - Ok(NinepReader::Pending($reader::WithSize(size))) - } - Self::WithSize(_) => { - let ty = from_le_bytes!(u8, bytes); - let tag = from_le_bytes!(u16, &bytes[1..]); - let mut offset = 3; - let content = match MessageType(ty) { - $( - MessageType::$message_variant => $enum_ty::$enum_variant { - $( - $field: { - let val: $ty = try_read_9p_bytes(&bytes[offset..])?; - offset += val.n_bytes(); - val - }, - )* - }, - )+ - - MessageType(ty) => return Err(io::Error::new( - ErrorKind::InvalidData, - format!($err, ty), - )), - }; - - Ok(NinepReader::Complete($message_ty { tag, content })) - } - } + Ok($message_ty { tag, content }) } } }; @@ -1043,7 +637,7 @@ macro_rules! impl_tdata { } impl_message_format!( - Tmessage, TmessageReader, Tdata, "invalid message type for t-message: {}"; + Tmessage, Tdata, "invalid message type for t-message: {}"; $($enum_variant => $message_variant { $($field: $ty,)* })+ @@ -1224,7 +818,7 @@ macro_rules! impl_rdata { } impl_message_format!( - Rmessage, RmessageReader, Rdata, "invalid message type for r-message: {}"; + Rmessage, Rdata, "invalid message type for r-message: {}"; $($enum_variant => $message_variant { $($field: $ty,)* })+ @@ -1336,12 +930,15 @@ mod tests { #[test] fn uint_decode() { - let buf: [u8; 8] = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]; - - assert_eq!(0x01, try_read_9p_bytes::(&buf).unwrap()); - assert_eq!(0x2301, try_read_9p_bytes::(&buf).unwrap()); - assert_eq!(0x67452301, try_read_9p_bytes::(&buf).unwrap()); - assert_eq!(0xefcdab8967452301, try_read_9p_bytes::(&buf).unwrap()); + let buf: Vec = vec![0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]; + + assert_eq!(0x01, u8::read_from(&mut buf.as_slice()).unwrap()); + assert_eq!(0x2301, u16::read_from(&mut buf.as_slice()).unwrap()); + assert_eq!(0x67452301, u32::read_from(&mut buf.as_slice()).unwrap()); + assert_eq!( + 0xefcdab8967452301, + u64::read_from(&mut buf.as_slice()).unwrap() + ); } #[test_case("test", &[0x04, 0x00, 0x74, 0x65, 0x73, 0x74]; "single byte chars only")] @@ -1354,7 +951,7 @@ mod tests { #[test] fn string_encode(s: &str, bytes: &[u8]) { let s = s.to_string(); - let buf = write_9p_bytes(&s).unwrap(); + let buf = s.write_9p_bytes().unwrap(); assert_eq!(&buf, bytes); } @@ -1378,8 +975,8 @@ mod tests { where T: NineP + PartialEq + fmt::Debug, { - let buf = write_9p_bytes(&t1).unwrap(); - let t2 = try_read_9p_bytes::(&buf).unwrap(); + let buf = t1.write_9p_bytes().unwrap(); + let t2 = T::read_from(&mut buf.as_slice()).unwrap(); assert_eq!(t1, t2); } diff --git a/crates/ninep/src/sync/mod.rs b/crates/ninep/src/sync/mod.rs index 82c6b28..8a3f40e 100644 --- a/crates/ninep/src/sync/mod.rs +++ b/crates/ninep/src/sync/mod.rs @@ -1,12 +1,18 @@ //! A synchronous implementation of 9p Servers and Clients use crate::{ - sansio::protocol::{NineP, NinepReader, Rdata, Read9p, Rmessage}, + sansio::{ + protocol::{NineP, Rdata, Rmessage, State}, + stub_waker, + }, Result, }; use std::{ + future::Future, io::{self, Read, Write}, net::TcpStream, os::unix::net::UnixStream, + pin::pin, + task::{Context, Poll}, }; pub mod client; @@ -24,23 +30,23 @@ pub trait SyncNineP: NineP { } /// Decode self from 9p protocol bytes coming from the given [SyncStream]. - #[allow(clippy::uninit_vec)] fn read_from(r: &mut R) -> io::Result { - let mut nr = NinepReader::Pending(Self::reader()); - let mut buf = Vec::new(); + let waker = stub_waker(); + let mut context = Context::from_waker(&waker); + let s = State::default(); + // SAFETY: assumes the impl of Read9p is a valid future for us to poll + let mut fut = unsafe { pin!(Self::read(&s)) }; loop { - match nr { - NinepReader::Pending(r9) => { - let n = Self::Reader::needs_bytes(&r9); - buf.reserve(n.saturating_sub(buf.len())); - // SAFETY: we've just reserved sufficient capacity - unsafe { buf.set_len(n) }; + match fut.as_mut().poll(&mut context) { + Poll::Ready(val) => return val, + // SAFETY: s is only shared with the future we're polling + Poll::Pending => unsafe { + let n = (*s.0.get()).n; + let mut buf = vec![0; n]; r.read_exact(&mut buf)?; - nr = r9.accept_bytes(&buf[0..n])?; - } - - NinepReader::Complete(t) => return Ok(t), + (*s.0.get()).buf = Some(buf); + }, } } } diff --git a/crates/ninep/src/tokio/mod.rs b/crates/ninep/src/tokio/mod.rs index 16d2c53..c678c01 100644 --- a/crates/ninep/src/tokio/mod.rs +++ b/crates/ninep/src/tokio/mod.rs @@ -1,9 +1,18 @@ //! Tokio based asynchronous implementation of 9p Servers and Clients use crate::{ - sansio::protocol::{NineP, NinepReader, Rdata, Read9p, Rmessage}, + sansio::{ + protocol::{NineP, Rdata, Rmessage, State}, + stub_waker, + }, Result, }; -use std::{io, marker::Unpin}; +use std::{ + future::Future, + io, + marker::Unpin, + pin::pin, + task::{Context, Poll}, +}; use tokio::{ io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::{TcpStream, UnixStream}, @@ -12,43 +21,69 @@ use tokio::{ pub mod client; pub mod server; -/// Synchronous IO support for reading and writing 9p messages -#[async_trait::async_trait] -pub trait AsyncNineP: NineP { +/// Asynchronous IO support for reading and writing 9p messages +pub trait AsyncNineP: NineP + Send + Sync { /// Encode self as bytes for the 9p protocol and write to the given [SyncStream]. - async fn write_to(&self, w: &mut W) -> io::Result<()> { - let mut buf = vec![0; self.n_bytes()]; - self.write_bytes(&mut buf) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; - - w.write_all(&buf).await + fn write_to(&self, w: &mut W) -> impl Future> + Send + where + W: AsyncWrite + Unpin + Send, + { + write_to(self, w) } /// Decode self from 9p protocol bytes coming from the given [SyncStream]. - #[allow(clippy::uninit_vec)] - async fn read_from(r: &mut R) -> io::Result { - let mut nr = NinepReader::Pending(Self::reader()); - let mut buf = Vec::new(); + fn read_from(r: &mut R) -> impl Future> + Send + where + R: AsyncRead + Unpin + Send, + { + read_from(r) + } +} + +impl AsyncNineP for T where T: NineP + Send + Sync {} - loop { - match nr { - NinepReader::Pending(r9) => { - let n = Self::Reader::needs_bytes(&r9); - buf.reserve(n.saturating_sub(buf.len())); - // SAFETY: we've just reserved sufficient capacity - unsafe { buf.set_len(n) }; - r.read_exact(&mut buf).await?; - nr = r9.accept_bytes(&buf[0..n])?; - } +// write_to and read_from are written as free functions so we can use async/await here while also +// explicitly requiring a Send bound on the methods of the AsyncNineP trait above. - NinepReader::Complete(t) => return Ok(t), - } +#[inline(always)] +async fn write_to(t: &T, w: &mut W) -> io::Result<()> +where + T: NineP + Sync, + W: AsyncWrite + Unpin + Send, +{ + let mut buf = vec![0; t.n_bytes()]; + t.write_bytes(&mut buf) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?; + + w.write_all(&buf).await +} + +#[inline(always)] +async fn read_from(r: &mut R) -> io::Result +where + T: NineP + Send, + R: AsyncRead + Unpin + Send, +{ + let waker = stub_waker(); + let s = State::default(); + + // SAFETY: assumes the impl of Read9p is a valid future for us to poll + let mut fut = unsafe { pin!(T::read(&s)) }; + loop { + let poll = fut.as_mut().poll(&mut Context::from_waker(&waker)); + match poll { + Poll::Ready(val) => return val, + // SAFETY: s is only shared with the future we're polling + Poll::Pending => unsafe { + let n = (*s.0.get()).n; + let mut buf = vec![0; n]; + r.read_exact(&mut buf).await?; + (*s.0.get()).buf = Some(buf); + }, } } } -impl AsyncNineP for T where T: NineP {} - /// A [Stream] that makes use of the standard library [Read] and [Write] traits to perform IO #[allow(async_fn_in_trait)] pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send + Sized + 'static {