From 108cab61b50acec32d1a621c74591026d52435fc Mon Sep 17 00:00:00 2001 From: Evan Schwartz <3262610+emschwartz@users.noreply.github.com> Date: Fri, 19 Jul 2024 14:48:54 -0400 Subject: [PATCH 1/2] Support non-Send Futures --- examples/non_send_future.rs | 25 ++++++ src/body.rs | 90 +++++++++++++++++++++- src/lib.rs | 60 ++++++++++++++- src/scope.rs | 146 +++++++++++++++++++++++++++++++++++- src/scope_body.rs | 34 ++++++++- 5 files changed, 349 insertions(+), 6 deletions(-) create mode 100644 examples/non_send_future.rs diff --git a/examples/non_send_future.rs b/examples/non_send_future.rs new file mode 100644 index 0000000..bdabf10 --- /dev/null +++ b/examples/non_send_future.rs @@ -0,0 +1,25 @@ +#![feature(async_closure)] + +use std::{cell::RefCell, rc::Rc}; + +// We are using the current_thread runtime because this version +// of the macro returns a non-Send Future. +#[tokio::main(flavor = "current_thread")] +pub async fn main() { + // This value is not thread-safe + let value = Rc::new(RefCell::new(22)); + + moro::scope_local(async |scope| { + scope.spawn(async { + scope.spawn(async { + *value.borrow_mut() *= 2; // mutate shared state + }); + + *value.borrow_mut() *= 2; + }); + + *value.borrow_mut() *= 2; + }) + .await; + println!("{value:?}"); +} diff --git a/src/body.rs b/src/body.rs index 81959f3..55fd118 100644 --- a/src/body.rs +++ b/src/body.rs @@ -1,9 +1,9 @@ -use std::{pin::Pin, sync::Arc, task::Poll}; +use std::{pin::Pin, rc::Rc, sync::Arc, task::Poll}; -use futures::{Future, FutureExt}; +use futures::Future; use pin_project::{pin_project, pinned_drop}; -use crate::scope::Scope; +use crate::scope::{Scope, ScopeLocal}; /// The future for a scope's "body". /// @@ -96,3 +96,87 @@ where } } } + +/// The future for a scope's "body". +/// +/// It is not considered complete until (a) the body is done and (b) any spawned futures are done. +/// Its result is whatever the body returned. 
+/// +/// # Unsafe contract +/// +/// - `body_future` and `result` will be dropped BEFORE `scope`. +#[pin_project(PinnedDrop)] +pub(crate) struct BodyLocal<'scope, 'env: 'scope, R, F> +where + R: 'env, +{ + #[pin] + body_future: Option, + result: Option, + scope: Rc>, +} + +impl<'scope, 'env, R, F> BodyLocal<'scope, 'env, R, F> { + /// # Unsafe contract + /// + /// - `future` will be dropped BEFORE `scope` + pub(crate) fn new(future: F, scope: Rc>) -> Self { + Self { + body_future: Some(future), + result: None, + scope, + } + } + + fn clear(self: Pin<&mut Self>) { + let mut this = self.project(); + this.body_future.set(None); + this.result.take(); + this.scope.clear(); + } +} + +#[pinned_drop] +impl<'scope, 'env, R, F> PinnedDrop for BodyLocal<'scope, 'env, R, F> { + fn drop(self: Pin<&mut Self>) { + // Fulfill our unsafe contract and ensure we drop other fields + // before we drop scope. + self.clear(); + } +} + +impl<'scope, 'env, R, F> Future for BodyLocal<'scope, 'env, R, F> +where + F: Future, +{ + type Output = R; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + // If the body is not yet finished, poll that. Once it becomes finished, + // we will update `this.result`. + if let Some(body_future) = this.body_future.as_mut().as_pin_mut() { + match body_future.poll(cx) { + Poll::Ready(r) => { + *this.result = Some(r); + this.body_future.set(None); + } + Poll::Pending => {} + } + } + + // Check if the scope is ready. + // + // If polling the scope returns `Some`, then the scope was terminated early, + // so forward that result. Otherwise, the `result` from our body future + // should be available, so return that. 
+ match ready!(this.scope.poll_jobs(cx)) { + Some(v) => return Poll::Ready(v), + None => match this.result.take() { + None => Poll::Pending, + Some(v) => Poll::Ready(v), + }, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 63a7438..e336150 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,8 @@ mod spawned; mod stream; pub use async_iter::{AsyncIterator, IntoAsyncIter}; +use scope::ScopeLocal; +use scope_body::ScopeBodyLocal; pub use stream::Stream; /// Creates an async scope within which you can spawn jobs. @@ -144,7 +146,24 @@ macro_rules! async_scope { }}; } -use futures::future::BoxFuture; +/// See [`async_scope`] for details. +#[macro_export] +macro_rules! async_scope_local { + (|$scope:ident| -> $result:ty { $($body:tt)* }) => {{ + $crate::scope_fn_local::<$result, _>(|$scope| { + let future = async { $($body)* }; + Box::pin(future) + }) + }}; + (|$scope:ident| $body:expr) => {{ + $crate::scope_fn_local(|$scope| { + let future = async { $body }; + Box::pin(future) + }) + }}; +} + +use futures::future::{BoxFuture, LocalBoxFuture}; pub use self::scope::Scope; pub use self::scope_body::ScopeBody; @@ -168,6 +187,25 @@ where ScopeBody::new(body::Body::new(body_future, scope)) } +/// Creates a new moro scope that is not thread-safe. +/// Normally, you invoke this through `moro::async_scope_local!`. +pub fn scope_fn_local<'env, R, B>(body: B) -> ScopeBodyLocal<'env, R, LocalBoxFuture<'env, R>> +where + R: 'env, + for<'scope> B: FnOnce(&'scope ScopeLocal<'scope, 'env, R>) -> LocalBoxFuture<'scope, R>, +{ + let scope = ScopeLocal::new(); + + // Unsafe: We are letting the body use the `Rc` without reference + // counting. The reference is held by `BodyLocal` below. `BodyLocal` will not drop + // the `Rc` until the `body_future` is dropped, and the output `R` has to outlive + // `'env` so it can't reference `scope`, so this should be ok. 
+ let scope_ref: *const ScopeLocal<'_, '_, R> = &*scope; + let body_future = body(unsafe { &*scope_ref }); + + ScopeBodyLocal::new(body::BodyLocal::new(body_future, scope)) +} + /// Creates a new moro scope. pub fn scope<'env, R, B>( body: B, @@ -187,3 +225,23 @@ where ScopeBody::new(body::Body::new(body_future, scope)) } + +/// Creates a new moro scope that is not thread-safe. +pub fn scope_local<'env, R, B>( + body: B, +) -> ScopeBodyLocal<'env, R, ,)>>::CallOnceFuture> +where + R: 'env, + for<'scope> B: async FnOnce(&'scope ScopeLocal<'scope, 'env, R>) -> R, +{ + let scope = ScopeLocal::new(); + + // Unsafe: We are letting the body use the `Rc` without reference + // counting. The reference is held by `BodyLocal` below. `BodyLocal` will not drop + // the `Rc` until the `body_future` is dropped, and the output `R` has to outlive + // `'env` so it can't reference `scope`, so this should be ok. + let scope_ref: *const ScopeLocal<'_, '_, R> = &*scope; + let body_future = body(unsafe { &*scope_ref }); + + ScopeBodyLocal::new(body::BodyLocal::new(body_future, scope)) +} diff --git a/src/scope.rs b/src/scope.rs index 2178802..a84cdc5 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -1,11 +1,17 @@ use std::{ + cell::RefCell, marker::PhantomData, pin::Pin, + rc::Rc, sync::{Arc, Mutex}, task::Poll, }; -use futures::{future::BoxFuture, stream::FuturesUnordered, Future, Stream}; +use futures::{ + future::{BoxFuture, LocalBoxFuture}, + stream::FuturesUnordered, + Future, Stream, +}; use crate::Spawned; @@ -158,3 +164,141 @@ impl<'scope, 'env, R: Send> Scope<'scope, 'env, R> { }) } } + +/// Represents a moro "async scope" that is not thread safe. See the [`async_scope`][crate::async_scope] macro for details. +pub struct ScopeLocal<'scope, 'env: 'scope, R: 'env> { + /// Stores the set of futures that have been spawned. 
+ futures: RefCell>>>>, + enqueued: RefCell>>, + terminated: RefCell>, + phantom: PhantomData<&'scope &'env ()>, +} + +impl<'scope, 'env, R> ScopeLocal<'scope, 'env, R> { + /// Create a scope. + pub(crate) fn new() -> Rc { + Rc::new(Self { + futures: RefCell::new(Box::pin(FuturesUnordered::new())), + enqueued: Default::default(), + terminated: Default::default(), + phantom: Default::default(), + }) + } + + /// Polls the jobs that were spawned thus far. Returns: + /// + /// * `Pending` if there are jobs that cannot complete yet + /// * `Ready(None)` if all jobs are completed + /// * `Ready(Some(r))` if the scope has been terminated with the value `r` + /// + /// Should not be invoked again once `Ready(Some(r))` is returned. + /// + /// It is ok to invoke it again after `Ready(None)` has been returned; + /// if any new jobs have been spawned, they will execute. + pub(crate) fn poll_jobs(&self, cx: &mut std::task::Context<'_>) -> Poll> { + 'outer: loop { + // once we are terminated, we do no more work. + if let Some(r) = self.terminated.take().take() { + return Poll::Ready(Some(r)); + } + + self.futures.borrow_mut().extend(self.enqueued.take()); + + while let Some(()) = ready!(self.futures.borrow_mut().as_mut().poll_next(cx)) { + // once we are terminated, we do no more work. + if self.terminated.borrow().is_some() { + continue 'outer; + } + } + + if self.enqueued.borrow().is_empty() { + return Poll::Ready(None); + } + } + } + + /// Clear out all pending jobs. This is used when dropping the + /// scope body to ensure that any possible references to `ScopeLocal` + /// are removed before we drop it. + /// + /// # Unsafe contract + /// + /// Once this returns, there are no more pending tasks. + pub(crate) fn clear(&self) { + self.futures.borrow_mut().clear(); + self.enqueued.borrow_mut().clear(); + } + + /// Terminate the scope immediately -- all existing jobs will stop at their next await point + /// and never wake up again. Anything on their stacks will be dropped. 
This is most useful + /// for propagating errors, but it can be used to propagate any kind of final value (e.g., + /// perhaps you are searching for something and want to stop once you find it.) + /// + /// This returns a future that you should await, but it will never complete + /// (because you will never be reawoken). Since termination takes effect at the next + /// await point, awaiting the returned future ensures that your current future stops + /// immediately. + /// + /// # Examples + /// + /// ```rust + /// # futures::executor::block_on(async { + /// let result = moro::async_scope!(|scope| { + /// scope.spawn(async { /* ... */ }); + /// + /// // Calling `scope.terminate` here will terminate the async + /// // scope and use the string `"cancellation-value"` as + /// // the final value. + /// let result: () = scope.terminate("cancellation-value").await; + /// unreachable!() // this code never executes + /// }).await; + /// + /// assert_eq!(result, "cancellation-value"); + /// # }); + /// ``` + pub fn terminate(&'scope self, value: R) -> impl Future + 'scope + where + T: 'scope, + { + if self.terminated.borrow().is_none() { + self.terminated.replace(Some(value.into())); + } + + // The code below will never run + self.spawn(async { panic!() }) + } + + /// Spawn a job that will run concurrently with everything else in the scope. + /// The job may access stack fields defined outside the scope. + /// The scope will not terminate until this job completes or the scope is cancelled. + pub fn spawn( + &'scope self, + future: impl Future + 'scope, + ) -> Spawned> + where + T: 'scope, + { + // Use a channel to communicate result from the *actual* future + // (which lives in the futures-unordered) and the caller. + // This is kind of crappy because, ideally, the caller expressing interest + // in the result of the future would let it run, but that would require + // more clever coding and I'm just trying to stand something up quickly + // here. 
What will happen when caller expresses an interest in result + // now is that caller will block which should (eventually) allow the + // futures-unordered to be polled and make progress. Good enough. + + let (tx, rx) = async_channel::bounded(1); + + self.enqueued.borrow_mut().push(Box::pin(async move { + let v = future.await; + let _ = tx.send(v).await; + })); + + Spawned::new(async move { + match rx.recv().await { + Ok(v) => v, + Err(e) => panic!("unexpected error: {e:?}"), + } + }) + } +} diff --git a/src/scope_body.rs b/src/scope_body.rs index bb957ac..b02a3b5 100644 --- a/src/scope_body.rs +++ b/src/scope_body.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use futures::Future; use pin_project::pin_project; -use crate::body::Body; +use crate::body::{Body, BodyLocal}; #[pin_project] pub struct ScopeBody<'env, R: 'env, F> @@ -39,3 +39,35 @@ where Pin::new(&mut self.project().body).poll(cx) } } + +#[pin_project] +pub struct ScopeBodyLocal<'env, R: 'env, F> +where + F: Future, +{ + #[pin] + body: BodyLocal<'env, 'env, R, F>, +} + +impl<'env, R, F> ScopeBodyLocal<'env, R, F> +where + F: Future, +{ + pub(crate) fn new(body: BodyLocal<'env, 'env, R, F>) -> Self { + Self { body } + } +} + +impl<'env, R, F> Future for ScopeBodyLocal<'env, R, F> +where + F: Future, +{ + type Output = R; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + Pin::new(&mut self.project().body).poll(cx) + } +} From fc39ee3430cfeffcf463434edd714b7551883af3 Mon Sep 17 00:00:00 2001 From: Evan Schwartz <3262610+emschwartz@users.noreply.github.com> Date: Wed, 24 Jul 2024 11:47:38 -0400 Subject: [PATCH 2/2] Value doesn't need to be wrapped in an Rc --- examples/non_send_future.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/non_send_future.rs b/examples/non_send_future.rs index bdabf10..5a89a99 100644 --- a/examples/non_send_future.rs +++ b/examples/non_send_future.rs @@ -1,13 +1,13 @@ #![feature(async_closure)] -use 
std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; // We are using the current_thread runtime because this version // of the macro returns a non-Send Future. #[tokio::main(flavor = "current_thread")] pub async fn main() { // This value is not thread-safe - let value = Rc::new(RefCell::new(22)); + let value = RefCell::new(22); moro::scope_local(async |scope| { scope.spawn(async {