diff --git a/compile_flags.txt b/compile_flags.txt
index ad0368fca5da..4ef92cfb6c67 100644
--- a/compile_flags.txt
+++ b/compile_flags.txt
@@ -51,6 +51,8 @@
 -isystembazel-bin/src/rust/cxx-integration/_virtual_includes/cxx-integration@cxx
 -isystembazel-bin/src/rust/cxx-integration-test/_virtual_includes/cxx-integration-test@cxx
 -isystembazel-bin/src/rust/dns/_virtual_includes/dns@cxx
+-isystembazel-bin/src/rust/async/_virtual_includes/async@cxx
+-isystembazel-bin/src/rust/async/_virtual_includes/async/
 -D_FORTIFY_SOURCE=1
 -D_LIBCPP_REMOVE_TRANSITIVE_INCLUDES
 -D_LIBCPP_NO_ABI_TAG
diff --git a/src/rust/async/BUILD.bazel b/src/rust/async/BUILD.bazel
new file mode 100644
index 000000000000..1b6c7ddb8bd6
--- /dev/null
+++ b/src/rust/async/BUILD.bazel
@@ -0,0 +1,49 @@
+load("//:build/kj_test.bzl", "kj_test")
+load("//:build/wd_cc_library.bzl", "wd_cc_library")
+load("//:build/wd_rust_crate.bzl", "wd_rust_crate")
+
+wd_rust_crate(
+    name = "async",
+    cxx_bridge_deps = [
+        "@capnp-cpp//src/kj:kj-async",
+    ],
+    cxx_bridge_src = "lib.rs",
+    visibility = ["//visibility:public"],
+    deps = [
+    ],
+)
+
+wd_cc_library(
+    name = "cxx-bridge",
+    srcs = [
+        "await.c++",
+        "future.c++",
+        "promise.c++",
+        "test-promises.c++",
+        "waker.c++",
+    ],
+    hdrs = [
+        "await.h",
+        "future.h",
+        "leak.h",
+        "promise.h",
+        "test-promises.h",
+        "waker.h",
+    ],
+    implementation_deps = [
+        ":async",
+        ":async@cxx",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@capnp-cpp//src/kj:kj-async",
+    ],
+)
+
+kj_test(
+    src = "cxx-bridge-test.c++",
+    deps = [
+        ":async@cxx",
+        ":cxx-bridge",
+    ],
+)
diff --git a/src/rust/async/await.c++ b/src/rust/async/await.c++
new file mode 100644
index 000000000000..797e0a083027
--- /dev/null
+++ b/src/rust/async/await.c++
@@ -0,0 +1,161 @@
+#include <workerd/rust/async/await.h>
+
+namespace workerd::rust::async {
+
+// =======================================================================================
+// ArcWakerAwaiter
+
+ArcWakerAwaiter::ArcWakerAwaiter(
+    FuturePollerBase& futurePoller, OwnPromiseNode node, kj::SourceLocation location)
+    : Event(location),
+      futurePoller(futurePoller),
+      node(kj::mv(node)) {
+  this->node->setSelfPointer(&this->node);
+  this->node->onReady(this);
+  // TODO(perf): If `this->isNext()` is true, can we immediately resume? Or should we check if
+  //   the enclosing coroutine has suspended at least once?
+  futurePoller.beginTrace(this->node);
+}
+
+ArcWakerAwaiter::~ArcWakerAwaiter() noexcept(false) {
+  futurePoller.endTrace(node);
+
+  unwindDetector.catchExceptionsIfUnwinding([this]() {
+    node = nullptr;
+  });
+}
+
+// Validity-check the Promise's result, then fire the FuturePollerBase Event to poll the
+// wrapped Future again.
+kj::Maybe<kj::Own<kj::_::Event>> ArcWakerAwaiter::fire() {
+  futurePoller.endTrace(node);
+
+  kj::_::ExceptionOr<WakeInstruction> result;
+
+  node->get(result);
+  KJ_IF_SOME(exception, kj::runCatchingExceptions([this]() {
+    node = nullptr;
+  })) {
+    result.addException(kj::mv(exception));
+  }
+
+  // We should only ever receive a WakeInstruction, never an exception. But if we do receive one,
+  // propagate it to the coroutine.
+  KJ_IF_SOME(exception, result.exception) {
+    futurePoller.reject(kj::mv(exception));
+    return kj::none;
+  }
+
+  auto value = KJ_ASSERT_NONNULL(result.value);
+
+  if (value == WakeInstruction::WAKE) {
+    // This was an actual wakeup.
+    futurePoller.armDepthFirst();
+  } else {
+    // All of our Wakers were dropped. We are awaiting the Rust equivalent of kj::NEVER_DONE.
+  }
+
+  return kj::none;
+}
+
+void ArcWakerAwaiter::traceEvent(kj::_::TraceBuilder& builder) {
+  if (node.get() != nullptr) {
+    node->tracePromise(builder, true);
+  }
+  futurePoller.traceEvent(builder);
+}
+
+// =================================================================================================
+// RustPromiseAwaiter
+
+RustPromiseAwaiter::RustPromiseAwaiter(
+    const RootWaker& rootWaker, OwnPromiseNode node, kj::SourceLocation location)
+    : Event(location),
+      // TODO(now): const cast
+      futurePoller(const_cast<RootWaker&>(rootWaker).getFuturePoller()),
+      node(kj::mv(node)),
+      done(false) {
+  this->node->setSelfPointer(&this->node);
+  this->node->onReady(this);
+  // TODO(perf): If `this->isNext()` is true, can we immediately resume? Or should we check if
+  //   the enclosing coroutine has suspended at least once?
+  futurePoller.beginTrace(this->node);
+}
+
+RustPromiseAwaiter::~RustPromiseAwaiter() noexcept(false) {
+  futurePoller.endTrace(node);
+
+  unwindDetector.catchExceptionsIfUnwinding([this]() {
+    node = nullptr;
+  });
+}
+
+kj::Maybe<kj::Own<kj::_::Event>> RustPromiseAwaiter::fire() {
+  futurePoller.endTrace(node);
+  done = true;
+  futurePoller.armDepthFirst();
+  return kj::none;
+}
+
+void RustPromiseAwaiter::traceEvent(kj::_::TraceBuilder& builder) {
+  if (node.get() != nullptr) {
+    node->tracePromise(builder, true);
+  }
+  futurePoller.traceEvent(builder);
+}
+
+bool RustPromiseAwaiter::poll(const RootWaker& rootWaker) {
+  // TODO(now): const cast, and can we do something smarter?
+  KJ_ASSERT(&const_cast<RootWaker&>(rootWaker).getFuturePoller() == &futurePoller);
+  return done;
+}
+
+void rust_promise_awaiter_new_in_place(
+    RustPromiseAwaiter* ptr, const RootWaker& rootWaker, OwnPromiseNode node) {
+  kj::ctor(*ptr, rootWaker, kj::mv(node));
+}
+void rust_promise_awaiter_drop_in_place(RustPromiseAwaiter* ptr) {
+  kj::dtor(*ptr);
+}
+
+// =======================================================================================
+// FuturePollerBase
+
+FuturePollerBase::FuturePollerBase(
+    kj::_::Event& next, kj::_::ExceptionOrValue& resultRef, kj::SourceLocation location)
+    : Event(location),
+      next(next),
+      resultRef(resultRef) {}
+
+void FuturePollerBase::beginTrace(OwnPromiseNode& node) {
+  if (promiseNodeForTrace == kj::none) {
+    promiseNodeForTrace = node;
+  }
+}
+
+void FuturePollerBase::endTrace(OwnPromiseNode& node) {
+  KJ_IF_SOME(myNode, promiseNodeForTrace) {
+    if (myNode.get() == node.get()) {
+      promiseNodeForTrace = kj::none;
+    }
+  }
+}
+
+void FuturePollerBase::reject(kj::Exception exception) {
+  resultRef.addException(kj::mv(exception));
+  next.armDepthFirst();
+}
+
+void FuturePollerBase::traceEvent(kj::_::TraceBuilder& builder) {
+  KJ_IF_SOME(node, promiseNodeForTrace) {
+    node->tracePromise(builder, true);
+  }
+  next.traceEvent(builder);
+}
+
+BoxFutureVoidAwaiter operator co_await(kj::_::CoroutineBase::Await<BoxFutureVoid> await) {
+  return BoxFutureVoidAwaiter{await.coroutine, kj::mv(await.awaitable)};
+}
+
+BoxFutureVoidAwaiter operator co_await(kj::_::CoroutineBase::Await<BoxFutureVoid&&> await) {
+  return BoxFutureVoidAwaiter{await.coroutine, kj::mv(await.awaitable)};
+}
+
+} // namespace workerd::rust::async
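A note on the "Rust equivalent of kj::NEVER_DONE" mentioned in ArcWakerAwaiter::fire() above: it is simply a future that always returns Pending without registering any wakeup. A minimal sketch (the standard library's std::future::pending() behaves the same way):

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // A future that is never ready. It deliberately ignores `cx.waker()`, so
    // once every clone of the waker is dropped, nothing will ever poll it
    // again -- exactly the WakeInstruction::IGNORE case handled above.
    struct NeverDone;

    impl Future for NeverDone {
        type Output = ();
        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            Poll::Pending
        }
    }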
diff --git a/src/rust/async/await.h b/src/rust/async/await.h
new file mode 100644
index 000000000000..8c235178870e
--- /dev/null
+++ b/src/rust/async/await.h
@@ -0,0 +1,194 @@
+#pragma once
+
+#include <workerd/rust/async/future.h>
+#include <workerd/rust/async/promise.h>
+
+#include <kj/async.h>
+#include <kj/debug.h>
+
+namespace workerd::rust::async {
+
+// TODO(cleanup): Code duplication with kj::_::PromiseAwaiterBase. If FuturePollerBase could
+//   somehow implement CoroutineBase's interface, we could fold this into one class.
+// TODO(perf): This is only an Event because we need to handle the case where all the Wakers are
+//   dropped and we receive a WakeInstruction::IGNORE. If we could somehow disarm the
+//   CrossThreadPromiseFulfillers inside ArcWaker when it's dropped, we could avoid this
+//   indirection.
+class ArcWakerAwaiter final: public kj::_::Event {
+public:
+  ArcWakerAwaiter(FuturePollerBase& futurePoller, OwnPromiseNode node,
+      kj::SourceLocation location = {});
+  ~ArcWakerAwaiter() noexcept(false);
+
+  kj::Maybe<kj::Own<kj::_::Event>> fire() override;
+  void traceEvent(kj::_::TraceBuilder& builder) override;
+
+private:
+  FuturePollerBase& futurePoller;
+  kj::UnwindDetector unwindDetector;
+  kj::_::OwnPromiseNode node;
+};
+
+// =======================================================================================
+// RustPromiseAwaiter
+
+// RustPromiseAwaiter allows Rust `async` blocks to `.await` KJ promises. Rust code creates one in
+// the block's storage at the point where the `.await` expression is evaluated, similar to how
+// `kj::_::PromiseAwaiter` is created in the KJ coroutine frame when C++ `co_await`s a promise.
+//
+// To initialize the object, Rust needs to know the size and alignment of RustPromiseAwaiter. To
+// that end, I used bindgen to generate an opaque FFI type in await.h.rs using the command below.
+//
+// TODO(now): Automate this?
+
+#if 0
+
+bindgen \
+    --rust-target 1.83.0 \
+    --disable-name-namespacing \
+    --generate "types" \
+    --allowlist-type "workerd::rust::async_::RustPromiseAwaiter" \
+    --opaque-type ".*" \
+    --no-derive-copy \
+    ./await.h \
+    -o ./await.h.rs \
+    -- \
+    -x c++ \
+    -std=c++23 \
+    -stdlib=libc++ \
+    -Wno-pragma-once-outside-header \
+    -I $(bazel info bazel-bin)/external/capnp-cpp/src/kj/_virtual_includes/kj \
+    -I $(bazel info bazel-bin)/external/capnp-cpp/src/kj/_virtual_includes/kj-async \
+    -I $(bazel info bazel-bin)/external/crates_vendor__cxx-1.0.133/_virtual_includes/cxx_cc \
+    -I $(bazel info bazel-bin)/src/rust/async/_virtual_includes/async@cxx
+
+#endif
+
+class RustPromiseAwaiter final: public kj::_::Event {
+public:
+  RustPromiseAwaiter(const RootWaker& rootWaker, OwnPromiseNode node,
+      kj::SourceLocation location = {});
+  ~RustPromiseAwaiter() noexcept(false);
+
+  kj::Maybe<kj::Own<kj::_::Event>> fire() override;
+  void traceEvent(kj::_::TraceBuilder& builder) override;
+
+  // Called by Rust.
+  bool poll(const RootWaker& cx);
+
+private:
+  FuturePollerBase& futurePoller;
+  kj::UnwindDetector unwindDetector;
+  kj::_::OwnPromiseNode node;
+  bool done;
+};
+
+using PtrRustPromiseAwaiter = RustPromiseAwaiter*;
+
+void rust_promise_awaiter_new_in_place(PtrRustPromiseAwaiter, const RootWaker&, OwnPromiseNode);
+void rust_promise_awaiter_drop_in_place(PtrRustPromiseAwaiter);
+
+// =======================================================================================
+// FuturePollerBase
+
+// Base class for the awaitable created by `co_await` when awaiting a Rust Future in a KJ
+// coroutine.
+class FuturePollerBase: public kj::_::Event {
+public:
+  // Initialize `next` with the enclosing coroutine's `Event`.
+  FuturePollerBase(
+      kj::_::Event& next, kj::_::ExceptionOrValue& resultRef, kj::SourceLocation location = {});
+
+  // When we `poll()` a Future, our RootWaker will either be cloned (creating an ArcWaker
+  // promise), or the Future will `.await` some number of KJ promises itself, or both. The awaiter
+  // objects which wrap those two kinds of promises use `beginTrace()` and `endTrace()` to connect
+  // the promise they're wrapping to the enclosing coroutine for tracing purposes.
+  void beginTrace(OwnPromiseNode& node);
+  void endTrace(OwnPromiseNode& node);
+
+  // TODO(now): fulfill()
+
+  // Reject the Future with an exception. Arms the enclosing coroutine's event. The event will
+  // resume the coroutine, which will then rethrow the exception from `await_resume()`.
+  void reject(kj::Exception exception);
+
+  virtual kj::Maybe<kj::Own<kj::_::Event>> fire() override = 0;
+  void traceEvent(kj::_::TraceBuilder& builder) override;
+
+private:
+  kj::_::Event& next;
+  kj::_::ExceptionOrValue& resultRef;
+
+  kj::Maybe<OwnPromiseNode&> promiseNodeForTrace;
+};
+
+class BoxFutureVoidAwaiter: public FuturePollerBase {
+public:
+  BoxFutureVoidAwaiter(kj::_::CoroutineBase& coroutine, BoxFutureVoid&& future,
+      kj::SourceLocation location = {})
+      : FuturePollerBase(coroutine, result),
+        coroutine(coroutine),
+        future(kj::mv(future)) {}
+  ~BoxFutureVoidAwaiter() noexcept(false) {
+    coroutine.awaitEnd();
+  }
+
+  bool await_ready() {
+    // TODO(perf): Check if we already have an ArcWaker from a previous suspension and give it to
+    //   RootWaker for cloning if we have the last reference to it at this point. This could save
+    //   memory allocations, but would depend on making XThreadFulfiller and XThreadPaf resettable
+    //   to really benefit.
+    RootWaker waker(*this);
+
+    if (future.poll(waker)) {
+      // Future is ready, we're done.
+      // TODO(now): Propagate value-or-exception.
+      return true;
+    }
+
+    auto state = waker.reset();
+
+    if (state.wakeCount > 0) {
+      // The future returned Pending, but synchronously called `wake_by_ref()` on the RootWaker,
+      // indicating it wants to be polled again immediately. We should arm our event right now,
+      // which will call `await_ready()` again on the event loop.
+      armDepthFirst();
+    } else KJ_IF_SOME(cloned, state.cloned) {
+      // The future returned Pending and cloned an ArcWaker to notify us later. We'll arrange for
+      // the ArcWaker's promise to arm our event once it's fulfilled.
+      arcWakerAwaiter.emplace(*this, kj::_::PromiseNode::from(kj::mv(cloned.promise)));
+    } else {
+      // The future returned Pending, did not call `wake_by_ref()` on the RootWaker, and did not
+      // clone an ArcWaker. Rust is either awaiting a KJ promise, or the Rust equivalent of
+      // kj::NEVER_DONE.
+    }
+
+    return false;
+  }
+
+  // We already arranged to be scheduled in await_ready(), nothing to do here.
+  void await_suspend(kj::_::stdcoro::coroutine_handle<>) {}
+
+  // Unit futures return void.
+  void await_resume() {
+    KJ_IF_SOME(exception, result.exception) {
+      kj::throwFatalException(kj::mv(exception));
+    }
+  }
+
+  kj::Maybe<kj::Own<kj::_::Event>> fire() override {
+    if (await_ready()) {
+      // TODO(perf): Call `coroutine.fire()` directly?
+      coroutine.armDepthFirst();
+    }
+    return kj::none;
+  }
+
+private:
+  kj::_::CoroutineBase& coroutine;
+  BoxFutureVoid future;
+  kj::_::ExceptionOr<kj::_::Void> result;
+
+  kj::Maybe<ArcWakerAwaiter> arcWakerAwaiter;
+};
+
+BoxFutureVoidAwaiter operator co_await(kj::_::CoroutineBase::Await<BoxFutureVoid> await);
+BoxFutureVoidAwaiter operator co_await(kj::_::CoroutineBase::Await<BoxFutureVoid&&> await);
+
+} // namespace workerd::rust::async
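On the "Automate this?" TODO for the bindgen invocation embedded in await.h above: one plausible approach is a small generator using the bindgen crate's builder API, mirroring the flags of that shell command. This is a sketch only; the include paths (the -I flags above) would still need to be supplied via clang_args, ideally derived from Bazel rather than hardcoded:

    // Hypothetical generator for await.h.rs, assuming the `bindgen` crate is
    // available as a build dependency.
    fn main() -> Result<(), Box<dyn std::error::Error>> {
        bindgen::Builder::default()
            .header("await.h")
            .disable_name_namespacing()
            .allowlist_type("workerd::rust::async_::RustPromiseAwaiter")
            .opaque_type(".*")
            .derive_copy(false)
            .clang_args(["-x", "c++", "-std=c++23", "-stdlib=libc++"])
            .generate()?
            .write_to_file("await.h.rs")?;
        Ok(())
    }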
diff --git a/src/rust/async/await.h.rs b/src/rust/async/await.h.rs
new file mode 100644
index 000000000000..1bcdf15d6934
--- /dev/null
+++ b/src/rust/async/await.h.rs
@@ -0,0 +1,13 @@
+/* automatically generated by rust-bindgen 0.71.1 */
+
+#[repr(C)]
+#[repr(align(8))]
+#[derive(Debug)]
+pub struct RustPromiseAwaiter {
+    pub _bindgen_opaque_blob: [u64; 12usize],
+}
+#[allow(clippy::unnecessary_operation, clippy::identity_op)]
+const _: () = {
+    ["Size of RustPromiseAwaiter"][::std::mem::size_of::<RustPromiseAwaiter>() - 96usize];
+    ["Alignment of RustPromiseAwaiter"][::std::mem::align_of::<RustPromiseAwaiter>() - 8usize];
+};
diff --git a/src/rust/async/await_.rs b/src/rust/async/await_.rs
new file mode 100644
index 000000000000..f51f21f484b4
--- /dev/null
+++ b/src/rust/async/await_.rs
@@ -0,0 +1,113 @@
+use std::future::Future;
+use std::future::IntoFuture;
+
+// use std::mem::MaybeUninit;
+
+use std::pin::Pin;
+
+use std::task::Context;
+use std::task::Poll;
+
+use cxx::ExternType;
+
+use crate::waker::deref_root_waker;
+
+use crate::lazy_pin_init::LazyPinInit;
+
+#[path = "await.h.rs"]
+mod await_h;
+pub use await_h::RustPromiseAwaiter;
+
+#[repr(transparent)]
+pub struct PtrRustPromiseAwaiter(*mut RustPromiseAwaiter);
+
+use crate::ffi::rust_promise_awaiter_drop_in_place;
+use crate::ffi::rust_promise_awaiter_new_in_place;
+
+use crate::OwnPromiseNode;
+
+impl Drop for RustPromiseAwaiter {
+    fn drop(&mut self) {
+        // The pin crate suggests implementing drop traits for address-sensitive types with an
+        // inner function which accepts a `Pin<&mut Type>` parameter, to help uphold pinning
+        // guarantees. However, since our drop function is actually a C++ destructor to which we
+        // must pass a raw pointer, there is no benefit in creating a Pin from `self`.
+        unsafe {
+            rust_promise_awaiter_drop_in_place(PtrRustPromiseAwaiter(self));
+        }
+    }
+}
+
+unsafe impl ExternType for RustPromiseAwaiter {
+    type Id = cxx::type_id!("workerd::rust::async::RustPromiseAwaiter");
+    type Kind = cxx::kind::Opaque;
+}
+
+unsafe impl ExternType for PtrRustPromiseAwaiter {
+    type Id = cxx::type_id!("workerd::rust::async::PtrRustPromiseAwaiter");
+    type Kind = cxx::kind::Trivial;
+}
+
+// =======================================================================================
+// Await syntax for OwnPromiseNode
+
+impl IntoFuture for OwnPromiseNode {
+    type Output = ();
+    type IntoFuture = RustPromiseAwaiterFuture;
+
+    fn into_future(self) -> Self::IntoFuture {
+        RustPromiseAwaiterFuture::new(self)
+    }
+}
+
+pub struct RustPromiseAwaiterFuture {
+    node: Option<OwnPromiseNode>,
+    awaiter: LazyPinInit<RustPromiseAwaiter>,
+}
+
+impl RustPromiseAwaiterFuture {
+    fn new(node: OwnPromiseNode) -> Self {
+        RustPromiseAwaiterFuture {
+            node: Some(node),
+            awaiter: LazyPinInit::uninit(),
+        }
+    }
+}
+
+impl Future for RustPromiseAwaiterFuture {
+    type Output = ();
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
+        if let Some(root_waker) = deref_root_waker(cx.waker()) {
+            // On our first invocation, `node` will be Some, and `awaiter.get()`'s callback will
+            // immediately pass its contents to the RustPromiseAwaiter constructor. On all
+            // subsequent invocations, `node` will be None and the `awaiter.get()` callback will
+            // not fire.
+            let node = self.node.take();
+
+            // Our awaiter is structurally pinned.
+            // TODO(now): Safety comment.
+            let awaiter = unsafe { self.map_unchecked_mut(|s| &mut s.awaiter) };
+
+            let awaiter = awaiter.get(move |ptr: *mut RustPromiseAwaiter| unsafe {
+                rust_promise_awaiter_new_in_place(
+                    PtrRustPromiseAwaiter(ptr),
+                    root_waker,
+                    // `node` is consumed
+                    node.expect("init function only called once"),
+                );
+            });
+
+            if awaiter.poll(root_waker) {
+                Poll::Ready(())
+            } else {
+                Poll::Pending
+            }
+        } else {
+            unreachable!("unimplemented");
+            // TODO(now): Store a clone of the waker, then replace self.node with the result
+            //   of wake_after(&waker, node), which will be implemented like
+            //     node.attach(kj::defer([&waker]() { waker.wake_by_ref(); }))
+            //         .eagerlyEvaluate(nullptr)
+        }
+    }
+}
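With the IntoFuture implementation above, Rust async code can await a KJ promise directly; new_layered_ready_future_void() in test_futures.rs exercises exactly this. A hypothetical standalone usage, under the caveat that the enclosing future must ultimately be polled by a RootWaker on the owning thread (otherwise poll() above hits the unimplemented branch):

    // Hypothetical helper: `.await` desugars to
    // `IntoFuture::into_future(node)` followed by repeated polling.
    async fn wait_for_node(node: OwnPromiseNode) {
        node.await;
    }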
diff --git a/src/rust/async/cxx-bridge-test.c++ b/src/rust/async/cxx-bridge-test.c++
new file mode 100644
index 000000000000..ecb7130a60d8
--- /dev/null
+++ b/src/rust/async/cxx-bridge-test.c++
@@ -0,0 +1,127 @@
+#include <workerd/rust/async/await.h>
+#include <workerd/rust/async/future.h>
+#include <workerd/rust/async/lib.rs.h>
+#include <workerd/rust/async/waker.h>
+
+#include <kj/test.h>
+
+namespace workerd::rust::async {
+namespace {
+
+class TestCoroutineEvent: public kj::_::Event {
+public:
+  TestCoroutineEvent(kj::SourceLocation location = {})
+      : Event(location) {}
+  kj::Maybe<kj::Own<kj::_::Event>> fire() override {
+    KJ_UNIMPLEMENTED("nope");
+  }
+  void traceEvent(kj::_::TraceBuilder& builder) override {}
+};
+
+class TestFuturePoller: public FuturePollerBase {
+public:
+  TestFuturePoller(kj::_::Event& next, kj::SourceLocation location = {})
+      : FuturePollerBase(next, result, location) {}
+
+  kj::Maybe<kj::Own<kj::_::Event>> fire() override {
+    fired = true;
+    return kj::none;
+  }
+
+  bool fired = false;
+  kj::_::ExceptionOr<kj::_::Void> result;
+};
+
+KJ_TEST("BoxFutureVoid: C++ can poll() Rust Futures") {
+  kj::EventLoop loop;
+  kj::WaitScope waitScope(loop);
+
+  // Poll a Future which returns Pending.
+  {
+    TestCoroutineEvent coroutineEvent;
+    TestFuturePoller futurePoller{coroutineEvent};
+    RootWaker waker{futurePoller};
+
+    auto pending = new_pending_future_void();
+    KJ_EXPECT(!pending.poll(waker));
+
+    // The pending future never calls Waker::wake() because it has no intention of ever waking us
+    // up. Additionally, it never even calls `waker.clone()`, so we have no promise at all.
+    auto state = waker.reset();
+    KJ_EXPECT(state.wakeCount == 0);
+    KJ_EXPECT(state.cloned == kj::none);
+  }
+
+  // Poll a Future which returns Ready(()).
+  {
+    TestCoroutineEvent coroutineEvent;
+    TestFuturePoller futurePoller{coroutineEvent};
+    RootWaker waker{futurePoller};
+
+    auto ready = new_ready_future_void();
+    KJ_EXPECT(ready.poll(waker));
+
+    // The ready future never calls Waker::wake() because it instead indicates immediate
+    // readiness by its return value. Additionally, it never even calls `waker.clone()`, so we
+    // have no promise at all.
+    auto state = waker.reset();
+    KJ_EXPECT(state.wakeCount == 0);
+    KJ_EXPECT(state.cloned == kj::none);
+  }
+
+  // Poll a Future which returns Pending and immediately calls the Waker.
+  {
+    TestCoroutineEvent coroutineEvent;
+    TestFuturePoller futurePoller{coroutineEvent};
+    RootWaker waker{futurePoller};
+
+    auto waking = new_waking_future_void();
+    KJ_EXPECT(!waking.poll(waker));
+
+    // The waking future immediately called wake_by_ref() on the RootWaker. This incremented our
+    // count, but didn't populate a promise.
+    auto state = waker.reset();
+    KJ_EXPECT(state.wakeCount == 1);
+    KJ_EXPECT(state.cloned == kj::none);
+  }
+
+  // Poll a Future which clones the Waker on a different thread, then spawns a new thread to wake
+  // the waker after a delay.
+  {
+    TestCoroutineEvent coroutineEvent;
+    TestFuturePoller futurePoller{coroutineEvent};
+    RootWaker waker{futurePoller};
+
+    auto waking = new_threaded_delay_future_void();
+    KJ_EXPECT(!waking.poll(waker));
+
+    auto state = waker.reset();
+    KJ_EXPECT(state.wakeCount == 0);
+    KJ_EXPECT(state.cloned != kj::none);
+
+    KJ_EXPECT(!KJ_ASSERT_NONNULL(state.cloned).promise.poll(waitScope));
+    KJ_ASSERT_NONNULL(state.cloned).promise.wait(waitScope);
+  }
+}
+
+KJ_TEST("FutureAwaiter: C++ KJ coroutines can co_await Rust Futures") {
+  kj::EventLoop loop;
+  kj::WaitScope waitScope(loop);
+
+  []() -> kj::Promise<void> {
+    co_await new_ready_future_void();
+    co_await new_waking_future_void();
+  }().wait(waitScope);
+}
+
+KJ_TEST("OwnPromiseNode: Rust can wait for KJ promises") {
+  kj::EventLoop loop;
+  kj::WaitScope waitScope(loop);
+
+  []() -> kj::Promise<void> {
+    co_await new_layered_ready_future_void();
+  }().wait(waitScope);
+}
+
+} // namespace
+} // namespace workerd::rust::async
diff --git a/src/rust/async/future.c++ b/src/rust/async/future.c++
new file mode 100644
index 000000000000..1148a178bd5a
--- /dev/null
+++ b/src/rust/async/future.c++
@@ -0,0 +1,20 @@
+#include <workerd/rust/async/future.h>
+#include <workerd/rust/async/lib.rs.h>
+
+namespace workerd::rust::async {
+
+BoxFutureVoid::BoxFutureVoid(BoxFutureVoid&& other) noexcept: repr(other.repr) {
+  other.repr = {0, 0};
+}
+
+BoxFutureVoid::~BoxFutureVoid() noexcept {
+  if (repr != std::array<std::uintptr_t, 2>{0, 0}) {
+    box_future_void_drop_in_place(this);
+  }
+}
+
+bool BoxFutureVoid::poll(const RootWaker& waker) noexcept {
+  return box_future_void_poll(*this, waker);
+}
+
+} // namespace workerd::rust::async
diff --git a/src/rust/async/future.h b/src/rust/async/future.h
new file mode 100644
index 000000000000..23de8451b176
--- /dev/null
+++ b/src/rust/async/future.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <workerd/rust/async/waker.h>
+
+#include <array>
+
+namespace workerd::rust::async {
+
+// A `Pin<Box<dyn Future<Output = ()>>>` owned by C++.
+//
+// The only way to construct a BoxFutureVoid is by returning one from a Rust function.
+//
+// TODO(now): Figure out how to make this a template, BoxFuture<T>.
+class BoxFutureVoid {
+public:
+  BoxFutureVoid(BoxFutureVoid&&) noexcept;
+  ~BoxFutureVoid() noexcept;
+
+  // This function constructs a `std::task::Context` in Rust wrapping the given `RootWaker`. It
+  // then calls the future's `Future::poll()` trait function.
+  //
+  // The reason we pass a `const RootWaker&`, and not the more generic `const CxxWaker&`, is
+  // because `RootWaker` exposes an API which Rust can use to optimize awaiting KJ Promises inside
+  // of this future.
+  //
+  // Returns true if the future returned `Poll::Ready`, false if the future returned
+  // `Poll::Pending`.
+  //
+  // TODO(now): Figure out how to return non-unit/void values and exceptions.
+  bool poll(const RootWaker& waker) noexcept;
+
+  // Tell cxx-rs that this type follows Rust's move semantics, and can thus be passed across the
+  // FFI boundary.
+  using IsRelocatable = std::true_type;
+
+private:
+  // Match Rust's representation of a `Box<dyn Future<Output = ()>>`.
+  std::array<std::uintptr_t, 2> repr;
+};
+
+// We define this pointer typedef so that cxx-rs can associate it with the same pointer type our
+// drop function uses.
+using PtrBoxFutureVoid = BoxFutureVoid*;
+
+} // namespace workerd::rust::async
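The two-word `repr` above relies on `Box<dyn Future>` being a fat pointer (data pointer plus vtable pointer). That layout is not a documented ABI guarantee, which is why the comment says "match" rather than asserting it; a Rust-side compile-time check of the assumption might look like:

    use std::future::Future;
    use std::pin::Pin;

    // Two words: one for the data pointer, one for the vtable pointer. This
    // is the layout BoxFutureVoid's `repr` member mirrors on the C++ side.
    const _: () = assert!(
        std::mem::size_of::<Pin<Box<dyn Future<Output = ()>>>>()
            == 2 * std::mem::size_of::<usize>()
    );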
diff --git a/src/rust/async/future.rs b/src/rust/async/future.rs
new file mode 100644
index 000000000000..5ea33001d147
--- /dev/null
+++ b/src/rust/async/future.rs
@@ -0,0 +1,56 @@
+use std::future::Future;
+use std::pin::Pin;
+use std::task::Context;
+use std::task::Waker;
+
+use cxx::ExternType;
+
+use crate::ffi::RootWaker;
+
+// Expose Pin<Box<dyn Future<Output = ()>>> to C++ as BoxFutureVoid.
+//
+// We want to allow C++ to own Rust Futures in a Box. At present, cxx-rs can easily expose Box<T>
+// directly to C++ only if T implements Sized and Unpin. Dynamic trait types like `dyn Future`
+// don't meet these requirements. One workaround is to pass Box<Box<dyn Future>> around. With a
+// few more lines of boilerplate, we can avoid the extra Box, as dtolnay showed in this demo PR:
+// https://github.com/dtolnay/cxx/pull/672/files
+
+pub struct BoxFuture<T>(Pin<Box<dyn Future<Output = T>>>);
+
+#[repr(transparent)]
+pub struct PtrBoxFuture<T>(*mut BoxFuture<T>);
+
+// A From implementation to make it easier to convert from an arbitrary Future
+// type into a BoxFuture<T>.
+//
+// TODO(now): Understand why 'static is needed.
+impl<T, F: Future<Output = T> + 'static> From<Pin<Box<F>>> for BoxFuture<T> {
+    fn from(value: Pin<Box<F>>) -> Self {
+        BoxFuture(value)
+    }
+}
+
+// We must manually implement the ExternType trait, poll, and drop functions for each possible T
+// of BoxFuture<T> and PtrBoxFuture<T>.
+//
+// TODO(now): Make this a macro so we can define them more easily?
+unsafe impl ExternType for BoxFuture<()> {
+    type Id = cxx::type_id!("workerd::rust::async::BoxFutureVoid");
+    type Kind = cxx::kind::Trivial;
+}
+
+unsafe impl ExternType for PtrBoxFuture<()> {
+    type Id = cxx::type_id!("workerd::rust::async::PtrBoxFutureVoid");
+    type Kind = cxx::kind::Trivial;
+}
+
+pub fn box_future_void_poll(future: &mut BoxFuture<()>, waker: &RootWaker) -> bool {
+    let waker = Waker::from(waker);
+    let mut cx = Context::from_waker(&waker);
+    // TODO(now): Figure out how to propagate value-or-exception.
+    future.0.as_mut().poll(&mut cx).is_ready()
+}
+
+pub unsafe fn box_future_void_drop_in_place(ptr: PtrBoxFuture<()>) {
+    std::ptr::drop_in_place(ptr.0);
+}
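For illustration, converting an arbitrary future into a `BoxFuture<()>` with the `From` impl above is a one-liner (this is the same pattern test_futures.rs uses):

    // The `.into()` goes through `From<Pin<Box<F>>> for BoxFuture<T>` above.
    fn make_noop_future() -> BoxFuture<()> {
        Box::pin(async {}).into()
    }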
diff --git a/src/rust/async/lazy_pin_init.rs b/src/rust/async/lazy_pin_init.rs
new file mode 100644
index 000000000000..ce5734d5dfe6
--- /dev/null
+++ b/src/rust/async/lazy_pin_init.rs
@@ -0,0 +1,54 @@
+use std::mem::MaybeUninit;
+use std::pin::Pin;
+
+// Based on StackInit from the `pinned-init` crate:
+// https://github.com/Rust-for-Linux/pinned-init/blob/67c0a0c35bf23b8584f8e7792f9098de5fe0c8b0/src/__internal.rs#L142
+
+/// # Invariants
+///
+/// If `self.is_init` is true, then `self.value` is initialized.
+pub struct LazyPinInit<T> {
+    value: MaybeUninit<T>,
+    is_init: bool,
+}
+
+impl<T> Drop for LazyPinInit<T> {
+    #[inline]
+    fn drop(&mut self) {
+        if self.is_init {
+            // SAFETY: As we are being dropped, we only call this once. And since `self.is_init`
+            // is true, `self.value` is initialized.
+            unsafe { self.value.assume_init_drop() };
+        }
+    }
+}
+
+impl<T> LazyPinInit<T> {
+    /// Creates a new `LazyPinInit<T>` that is uninitialized.
+    #[inline]
+    pub fn uninit() -> Self {
+        Self {
+            value: MaybeUninit::uninit(),
+            is_init: false,
+        }
+    }
+
+    /// Initializes the contents on the first call, then returns the pinned value.
+    #[inline]
+    pub fn get(self: Pin<&mut Self>, init: impl FnOnce(*mut T)) -> Pin<&mut T> {
+        // SAFETY: We never move out of `this`.
+        let this = unsafe { Pin::into_inner_unchecked(self) };
+        // Run the init function exactly once. On subsequent calls the value is already
+        // initialized, and `Pin`'s guarantees require us to leave it in place.
+        if !this.is_init {
+            // SAFETY: The memory slot is valid and this type ensures that it will stay pinned.
+            init(this.value.as_mut_ptr());
+            // INVARIANT: `this.value` is initialized above.
+            this.is_init = true;
+        }
+        // SAFETY: The value is initialized, and it stays pinned because we never hand out a
+        // plain `&mut T`.
+        unsafe { Pin::new_unchecked(this.value.assume_init_mut()) }
+    }
+}
+
+// TODO(now): Test, exception-handling
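A sketch of the kind of test the TODO above calls for, checking that the init closure runs exactly once (hypothetical, not part of this change):

    #[cfg(test)]
    mod tests {
        use super::*;
        use std::pin::pin;

        #[test]
        fn init_runs_exactly_once() {
            let mut count = 0;
            let mut lazy = pin!(LazyPinInit::<u32>::uninit());
            // The first call initializes the slot...
            lazy.as_mut().get(|ptr| {
                count += 1;
                unsafe { ptr.write(42) };
            });
            // ...subsequent calls return the existing value without re-running init.
            let value = lazy.as_mut().get(|_| count += 1);
            assert_eq!(*value, 42);
            assert_eq!(count, 1);
        }
    }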
diff --git a/src/rust/async/leak.h b/src/rust/async/leak.h
new file mode 100644
index 000000000000..24ec79577c0e
--- /dev/null
+++ b/src/rust/async/leak.h
@@ -0,0 +1,64 @@
+#pragma once
+
+#include <kj/refcount.h>
+
+namespace workerd::rust::async {
+
+// Consume a `kj::Arc<T>` such that its destructor never runs, then return the pointer it owned.
+// You must arrange to call `unleak()` on the returned pointer later if you want to destroy it.
+template <typename T>
+T* leak(kj::Arc<T> value);
+
+// Given a pointer to a `kj::AtomicRefcounted` which has previously been returned by `leak()`,
+// reassume ownership by wrapping the pointer in a `kj::Arc<T>`.
+template <typename T>
+kj::Arc<T> unleak(T* ptr);
+
+// =======================================================================================
+// Implementation details
+
+template <typename T>
+T* leak(kj::Arc<T> value) {
+  // Unions do not run non-trivial constructors or destructors for their members, unless the
+  // union's own constructor/destructor are explicitly written to do so. Here, we run kj::Arc's
+  // move constructor, but not its destructor, causing it to leak.
+  //
+  // TODO(cleanup): libkj _almost_ has a way to do this: if it were possible to cast
+  //   AtomicRefcounted objects to their private Disposer base class, we could use
+  //   `value.toOwn().disown(disposer)`.
+  union Leak {
+    kj::Arc<T> value;
+    Leak(kj::Arc<T> value): value(kj::mv(value)) {}
+    ~Leak() {}
+  } leak{kj::mv(value)};
+  return leak.value.get();
+}
+
+// HACK: We partially specialize kj::Arc<Unleak<T>> below in order to gain fraudulent friend
+// access to kj::Arc's ownership-assuming constructor.
+template <typename T>
+struct Unleak {};
+
+// Take a pointer to a kj::AtomicRefcounted value and assume ownership of it.
+template <typename T>
+kj::Arc<T> unleak(T* ptr) {
+  return kj::Arc<Unleak<T>>::unleak(ptr);
+}
+
+} // namespace workerd::rust::async
+
+namespace kj {
+
+template <typename T>
+class Arc<workerd::rust::async::Unleak<T>> {
+public:
+  static_assert(kj::canConvert<T&, const AtomicRefcounted&>());
+  static Arc<T> unleak(T* ptr) {
+    // This unary pointer-accepting constructor seems to be the easiest way to assume ownership
+    // of a raw kj::AtomicRefcounted pointer. It is private, which is why we specialized
+    // Arc<Unleak<T>> in order to gain friend access.
+    return Arc<T>{ptr};
+  }
+};
+
+} // namespace kj
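This leak()/unleak() pair is the KJ analogue of what Rust itself provides for std::sync::Arc when handing refcounted pointers across an FFI boundary:

    use std::sync::Arc;

    fn demo() {
        // Rust's equivalent of leak(): give up ownership without running the
        // destructor, yielding a raw pointer to hand across the FFI boundary.
        let raw: *const String = Arc::into_raw(Arc::new(String::from("hello")));

        // Rust's equivalent of unleak(): reassume ownership of the pointer.
        let arc: Arc<String> = unsafe { Arc::from_raw(raw) };
        drop(arc);
    }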
+ extern "Rust" { + fn new_pending_future_void() -> BoxFutureVoid; + fn new_ready_future_void() -> BoxFutureVoid; + fn new_waking_future_void() -> BoxFutureVoid; + fn new_threaded_delay_future_void() -> BoxFutureVoid; + fn new_layered_ready_future_void() -> BoxFutureVoid; + } +} diff --git a/src/rust/async/promise.c++ b/src/rust/async/promise.c++ new file mode 100644 index 000000000000..07497f69b62a --- /dev/null +++ b/src/rust/async/promise.c++ @@ -0,0 +1,11 @@ +#include + +#include + +namespace workerd::rust::async { + +void own_promise_node_drop_in_place(OwnPromiseNode* node) { + node->~OwnPromiseNode(); +} + +} // namespace workerd::rust::async diff --git a/src/rust/async/promise.h b/src/rust/async/promise.h new file mode 100644 index 000000000000..291f534760dd --- /dev/null +++ b/src/rust/async/promise.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace workerd::rust::async { + +using OwnPromiseNode = kj::_::OwnPromiseNode; +using PtrOwnPromiseNode = OwnPromiseNode*; + +void own_promise_node_drop_in_place(OwnPromiseNode*); + +} // namespace workerd::rust::async + +namespace rust { + +// OwnPromiseNodes happen to follow Rust move semantics. +template <> +struct IsRelocatable<::workerd::rust::async::OwnPromiseNode>: std::true_type {}; + +} // namespace rust diff --git a/src/rust/async/promise.rs b/src/rust/async/promise.rs new file mode 100644 index 000000000000..3be68f34d6b7 --- /dev/null +++ b/src/rust/async/promise.rs @@ -0,0 +1,34 @@ +use cxx::ExternType; + +use crate::ffi::own_promise_node_drop_in_place; + +// OwnPromiseNode + +// The inner pointer is never read on Rust's side, so Rust thinks it's dead code. +#[allow(dead_code)] +pub struct OwnPromiseNode(*const ()); + +#[repr(transparent)] +pub struct PtrOwnPromiseNode(*mut OwnPromiseNode); + +unsafe impl ExternType for OwnPromiseNode { + type Id = cxx::type_id!("workerd::rust::async::OwnPromiseNode"); + type Kind = cxx::kind::Trivial; +} + +unsafe impl ExternType for PtrOwnPromiseNode { + type Id = cxx::type_id!("workerd::rust::async::PtrOwnPromiseNode"); + type Kind = cxx::kind::Trivial; +} + +impl Drop for OwnPromiseNode { + fn drop(&mut self) { + // The pin crate suggests implementing drop traits for address-sensitive types with an inner + // function which accepts a `Pin<&mut Type>` parameter, to help uphold pinning guarantees. + // However, since our drop function is actually a C++ destructor to which we must pass a raw + // pointer, there is no benefit in creating a Pin from `self`. 
diff --git a/src/rust/async/test-promises.c++ b/src/rust/async/test-promises.c++
new file mode 100644
index 000000000000..c444cd7c5881
--- /dev/null
+++ b/src/rust/async/test-promises.c++
@@ -0,0 +1,17 @@
+#include <workerd/rust/async/test-promises.h>
+
+namespace workerd::rust::async {
+
+OwnPromiseNode new_ready_promise_node() {
+  return kj::_::PromiseNode::from(kj::Promise<void>(kj::READY_NOW));
+}
+
+OwnPromiseNode new_coroutine_promise_node() {
+  return kj::_::PromiseNode::from([]() -> kj::Promise<void> {
+    co_await kj::Promise<void>(kj::READY_NOW);
+    co_await kj::Promise<void>(kj::READY_NOW);
+    co_await kj::Promise<void>(kj::READY_NOW);
+  }());
+}
+
+} // namespace workerd::rust::async
diff --git a/src/rust/async/test-promises.h b/src/rust/async/test-promises.h
new file mode 100644
index 000000000000..92fdc0154e43
--- /dev/null
+++ b/src/rust/async/test-promises.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <workerd/rust/async/promise.h>
+
+namespace workerd::rust::async {
+
+OwnPromiseNode new_ready_promise_node();
+OwnPromiseNode new_coroutine_promise_node();
+
+} // namespace workerd::rust::async
diff --git a/src/rust/async/test_futures.rs b/src/rust/async/test_futures.rs
new file mode 100644
index 000000000000..989da1b40f14
--- /dev/null
+++ b/src/rust/async/test_futures.rs
@@ -0,0 +1,85 @@
+use std::pin::Pin;
+use std::future::Future;
+use std::task::Poll;
+
+use crate::BoxFuture;
+
+pub fn new_pending_future_void() -> BoxFuture<()> {
+    Box::pin(std::future::pending()).into()
+}
+pub fn new_ready_future_void() -> BoxFuture<()> {
+    Box::pin(std::future::ready(())).into()
+}
+
+// TODO(now): Make configurable:
+// - Synchronous wake_by_ref()
+// - Synchronous clone().wake()
+// - wake_by_ref() on different thread
+// - clone().wake() on different thread
+// - List of different wake styles
+struct WakingFuture {
+    done: bool,
+}
+
+impl WakingFuture {
+    fn new() -> Self {
+        Self { done: false }
+    }
+}
+
+impl Future for WakingFuture {
+    type Output = ();
+    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<()> {
+        let waker = cx.waker();
+        // waker.clone().wake();
+        waker.wake_by_ref();
+        if self.done {
+            Poll::Ready(())
+        } else {
+            self.done = true;
+            Poll::Pending
+        }
+    }
+}
+
+pub fn new_waking_future_void() -> BoxFuture<()> {
+    Box::pin(WakingFuture::new()).into()
+}
+
+struct ThreadedDelayFuture {
+    handle: Option<std::thread::JoinHandle<()>>,
+}
+
+impl ThreadedDelayFuture {
+    fn new() -> Self {
+        Self { handle: None }
+    }
+}
+
+impl Future for ThreadedDelayFuture {
+    type Output = ();
+    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<()> {
+        if let Some(handle) = self.handle.take() {
+            let _ = handle.join();
+            return Poll::Ready(());
+        }
+        let waker = cx.waker();
+        let waker = std::thread::scope(|scope| scope.spawn(|| waker.clone()).join().unwrap());
+        self.handle = Some(std::thread::spawn(|| {
+            std::thread::sleep(std::time::Duration::from_millis(100));
+            waker.wake();
+        }));
+        Poll::Pending
+    }
+}
+
+pub fn new_threaded_delay_future_void() -> BoxFuture<()> {
+    Box::pin(ThreadedDelayFuture::new()).into()
+}
+
+pub fn new_layered_ready_future_void() -> BoxFuture<()> {
+    Box::pin(async {
+        crate::ffi::new_ready_promise_node().await;
+        crate::ffi::new_coroutine_promise_node().await;
+    })
+    .into()
+}
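One possible shape for the configurable waking future described by the TODO near the top of test_futures.rs (a hypothetical sketch, not part of this change; thread-based styles would follow ThreadedDelayFuture's pattern):

    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    // Exercise different synchronous wake styles from a single future type.
    enum WakeStyle {
        ByRef,
        CloneAndWake,
    }

    struct ConfigurableWakingFuture {
        style: WakeStyle,
        done: bool,
    }

    impl Future for ConfigurableWakingFuture {
        type Output = ();
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
            if self.done {
                return Poll::Ready(());
            }
            match self.style {
                WakeStyle::ByRef => cx.waker().wake_by_ref(),
                WakeStyle::CloneAndWake => cx.waker().clone().wake(),
            }
            self.done = true;
            Poll::Pending
        }
    }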
diff --git a/src/rust/async/waker.c++ b/src/rust/async/waker.c++
new file mode 100644
index 000000000000..d29e70484bd2
--- /dev/null
+++ b/src/rust/async/waker.c++
@@ -0,0 +1,106 @@
+#include <workerd/rust/async/await.h>
+#include <workerd/rust/async/waker.h>
+
+#include <kj/debug.h>
+
+namespace workerd::rust::async {
+
+// =======================================================================================
+// ArcWaker
+
+ArcWaker::ArcWaker(kj::Own<kj::CrossThreadPromiseFulfiller<WakeInstruction>> fulfiller)
+    : fulfiller(kj::mv(fulfiller)) {}
+ArcWaker::~ArcWaker() noexcept(false) {
+  // We can't leave the promise hanging or else the fulfiller's destructor will reject it for us.
+  // So, settle the promise with our no-op IGNORE value in case we're still waiting here.
+  fulfiller->fulfill(WakeInstruction::IGNORE);
+}
+
+const CxxWaker* ArcWaker::clone() const {
+  return leak(addRefToThis());
+}
+void ArcWaker::wake() const {
+  wake_by_ref();
+  drop();
+}
+void ArcWaker::wake_by_ref() const {
+  fulfiller->fulfill(WakeInstruction::WAKE);
+}
+void ArcWaker::drop() const {
+  auto drop = unleak(this);
+}
+
+PromiseArcWakerPair newPromiseAndArcWaker(const kj::Executor& executor) {
+  // TODO(perf): newPromiseAndCrossThreadFulfiller() makes two heap allocations, but it is
+  //   probably optimizable to one.
+  auto [promise, fulfiller] = kj::newPromiseAndCrossThreadFulfiller<WakeInstruction>(executor);
+  return {
+    .promise = kj::mv(promise),
+    // TODO(perf): This heap allocation could also probably be collapsed into the fulfiller's.
+    .waker = kj::arc<ArcWaker>(kj::mv(fulfiller)),
+  };
+}
+
+// =======================================================================================
+// RootWaker
+
+RootWaker::RootWaker(FuturePollerBase& futurePoller): futurePoller(futurePoller) {}
+
+const CxxWaker* RootWaker::clone() const {
+  // Rust code wants to suspend and wait for something other than an OwnPromiseNode from the same
+  // thread as this RootWaker. We'll start handing out ArcWakers if we haven't already been woken
+  // synchronously.
+
+  if (wakeCount.load(std::memory_order_relaxed) > 0) {
+    // We were already woken synchronously, so there's no point handing out more wakers for the
+    // current call to `Future::poll()`. We can hand out a no-op waker by returning nullptr.
+    return nullptr;
+  }
+
+  auto lock = cloned.lockExclusive();
+
+  if (*lock == kj::none) {
+    // We haven't been cloned before, so make a new ArcWaker.
+    *lock = newPromiseAndArcWaker(executor);
+  }
+
+  return KJ_ASSERT_NONNULL(*lock).waker->clone();
+}
+
+void RootWaker::wake() const {
+  // RootWakers are only exposed to Rust by const borrow, meaning Rust can never arrange to call
+  // `wake()`, which drops `self`, on this object.
+  KJ_UNIMPLEMENTED("Rust user code should never have a consumable reference to RootWaker");
+}
+
+void RootWaker::wake_by_ref() const {
+  // Woken synchronously during a call to `future.poll()`.
+  wakeCount.fetch_add(1, std::memory_order_relaxed);
+}
+
+void RootWaker::drop() const {
+  ++dropCount;
+}
+
+bool RootWaker::is_current() const {
+  return &executor == &kj::getCurrentThreadExecutor();
+}
+
+FuturePollerBase& RootWaker::getFuturePoller() {
+  return futurePoller;
+}
+
+RootWaker::State RootWaker::reset() {
+  // Getting the state without a lock is safe, because this function is only called after
+  // `future.poll()` has returned, meaning Rust has dropped its reference.
+  KJ_ASSERT(dropCount == 1);
+  KJ_DEFER(dropCount = 0);
+  KJ_DEFER(wakeCount.store(0, std::memory_order_relaxed));
+  KJ_DEFER(cloned.getWithoutLock() = kj::none);
+  return {
+    .wakeCount = wakeCount.load(std::memory_order_relaxed),
+    .cloned = kj::mv(cloned.getWithoutLock()),
+  };
+}
+
+} // namespace workerd::rust::async
diff --git a/src/rust/async/waker.h b/src/rust/async/waker.h
new file mode 100644
index 000000000000..6826c4aa91ef
--- /dev/null
+++ b/src/rust/async/waker.h
@@ -0,0 +1,178 @@
+#pragma once
+
+#include <kj/async.h>
+
+#include <kj/debug.h>
+#include <kj/mutex.h>
+#include <kj/refcount.h>
+#include <atomic>
+
+namespace workerd::rust::async {
+
+// =======================================================================================
+// CxxWaker
+
+// CxxWaker is an abstract base class which defines an interface mirroring Rust's RawWakerVTable
+// struct. Rust has four trampoline functions, defined in waker.rs, which translate
+// Waker::clone(), Waker::wake(), etc. calls to the virtual member functions on this class.
+//
+// Rust requires Wakers to be Send and Sync, meaning all of the functions defined here may be
+// called concurrently by any thread. Derived class implementations of these functions must
+// handle this, which is why all of the virtual member functions are `const`-qualified.
+class CxxWaker {
+public:
+  // Return a pointer to a new strong ref to a CxxWaker. Note that `clone()` may return nullptr,
+  // in which case the Rust implementation in waker.rs will treat it as a no-op Waker. Rust
+  // immediately wraps this pointer in its own Waker object, which is responsible for later
+  // releasing the strong reference.
+  //
+  // TODO(cleanup): Build kj::Arc into cxx-rs so we can return one instead of a raw pointer.
+  virtual const CxxWaker* clone() const = 0;
+
+  // Wake and drop this waker.
+  virtual void wake() const = 0;
+
+  // Wake this waker, but do not drop it.
+  virtual void wake_by_ref() const = 0;
+
+  // Drop this waker.
+  virtual void drop() const = 0;
+};
+
+// =======================================================================================
+// ArcWaker
+
+// The result type for ArcWaker's Promise.
+enum class WakeInstruction {
+  // The `IGNORE` instruction means the Waker was dropped without ever being used.
+  IGNORE,
+  // The `WAKE` instruction means `wake()` was called on the Waker.
+  WAKE,
+};
+
+// ArcWaker is an atomic-refcounted wrapper around a
+// `CrossThreadPromiseFulfiller<WakeInstruction>`. The atomic-refcounted aspect makes it safe to
+// call `clone()` and `drop()` concurrently, while the `CrossThreadPromiseFulfiller` aspect makes
+// it safe to call `wake_by_ref()` concurrently. Finally, `wake()` is implemented in terms of
+// `wake_by_ref()` and `drop()`.
+//
+// This class is mostly an implementation detail of RootWaker.
+class ArcWaker: public kj::AtomicRefcounted,
+                public kj::EnableAddRefToThis<ArcWaker>,
+                public CxxWaker {
+public:
+  ArcWaker(kj::Own<kj::CrossThreadPromiseFulfiller<WakeInstruction>> fulfiller);
+  ~ArcWaker() noexcept(false);
+  KJ_DISALLOW_COPY_AND_MOVE(ArcWaker);
+
+  const CxxWaker* clone() const override;
+  void wake() const override;
+  void wake_by_ref() const override;
+  void drop() const override;
+
+private:
+  kj::Own<kj::CrossThreadPromiseFulfiller<WakeInstruction>> fulfiller;
+};
+
+struct PromiseArcWakerPair {
+  kj::Promise<WakeInstruction> promise;
+  kj::Arc<ArcWaker> waker;
+};
+
+// TODO(now): Doc comment.
+PromiseArcWakerPair newPromiseAndArcWaker(const kj::Executor& executor);
+
+// =======================================================================================
+// RootWaker
+
+class FuturePollerBase;
+
+// RootWaker is the waker passed to Rust's `Future::poll()` function.
+// RootWaker itself is not refcounted -- instead it is intended to live locally on the stack or
+// in a coroutine frame, and trying to `clone()` it will cause it to allocate an ArcWaker for the
+// caller.
+//
+// This class is mostly an implementation detail of our `co_await` operator implementation for
+// Rust Futures. RootWaker exists in order to optimize the case where Rust async code awaits a KJ
+// promise, in which case we can make the outer KJ coroutine wait more or less directly on the
+// inner KJ promise which Rust owns.
+class RootWaker: public CxxWaker {
+public:
+  // Saves a reference to the FuturePoller which is using this RootWaker. The FuturePoller
+  // creates RootWakers on the stack in `await_ready()`, so its lifetime always encloses
+  // RootWakers.
+  explicit RootWaker(FuturePollerBase& futurePoller);
+
+  // Create a new or clone an existing ArcWaker, leak its pointer, and return it. This may be
+  // called by any thread.
+  const CxxWaker* clone() const override;
+
+  // Unimplemented, because Rust user code cannot consume the `std::task::Waker` we create which
+  // wraps this RootWaker.
+  void wake() const override;
+
+  // Rust user code can wake us synchronously during the execution of `future.poll()` using this
+  // function. This may be called by any thread.
+  void wake_by_ref() const override;
+
+  // Does not actually destroy this object. Instead, we increment a counter so we can assert that
+  // it was dropped exactly once before `future.poll()` returned. This can only be called on the
+  // thread which is doing the awaiting, because our implementation of `future.poll()` never
+  // transfers the Waker object to a different thread.
+  void drop() const override;
+
+  // In addition to the above functions, Rust may invoke two more functions, `is_current()` and
+  // `wake_after()`, during `future.poll()` execution. It uses these to implement a short-circuit
+  // optimization when it awaits a KJ promise.
+
+  // True if the current thread's kj::Executor is the same as the RootWaker's.
+  bool is_current() const;
+
+  // Called by RustPromiseAwaiter's constructor to get a reference to an Event which will call
+  // the current Future's `poll()` function. This is used to `.await` OwnPromiseNodes in Rust
+  // without having to clone an ArcWaker.
+  FuturePollerBase& getFuturePoller();
+
+  struct State {
+    // Number of times this Waker was synchronously woken during `future.poll()`. Incremented by
+    // `wake_by_ref()`.
+    uint wakeCount = 0;
+
+    // Filled in lazily by `clone()`. If `clone()` is never called, this will remain kj::none.
+    kj::Maybe<PromiseArcWakerPair> cloned;
+  };
+
+  // Used by the owner of RootWaker after `future.poll()` has returned, to retrieve the
+  // RootWaker's state for further processing. This is non-const, because by the time this is
+  // called, Rust has dropped all of its borrows of this class, meaning we no longer have to
+  // worry about thread safety.
+  //
+  // This function will assert if `drop()` has not been called since RootWaker was constructed,
+  // or since the last call to `reset()`.
+  State reset();
+
+private:
+  FuturePollerBase& futurePoller;
+
+  // We store the kj::Executor for the constructing thread so that we can lazily instantiate a
+  // CrossThreadPromiseFulfiller from any thread in our `clone()` implementation. This also
+  // allows us to guarantee that `wake_after()` will only be called from the awaiting thread,
+  // allowing us to ignore thread-safety for the `wakeAfter` promise.
+  const kj::Executor& executor = kj::getCurrentThreadExecutor();
+
+  // Initialized by `clone()`, which may be called by any thread.
+  kj::MutexGuarded<kj::Maybe<PromiseArcWakerPair>> cloned;
+
+  // Incremented by `wake_by_ref()`, which may be called by any thread. All operations use
+  // relaxed memory order, because this counter doesn't guard any memory.
+  mutable std::atomic<uint> wakeCount { 0 };
+
+  // Incremented by `drop()`, so we can validate that `drop()` is only called once on this
+  // object.
+  //
+  // Rust requires that Wakers be droppable by any thread. However, we own the implementation of
+  // `poll()` to which `RootWaker&` is passed, and those implementations store the Rust
+  // `std::task::Waker` object on the stack, and never move it elsewhere. Since that object is
+  // responsible for calling `RootWaker::drop()`, we know for sure that `drop()` will only ever
+  // be called on the thread which constructed it. Therefore, there is no need to make
+  // `dropCount` thread-safe.
+  mutable uint dropCount = 0;
+};
+
+} // namespace workerd::rust::async
diff --git a/src/rust/async/waker.rs b/src/rust/async/waker.rs
new file mode 100644
index 000000000000..1165d088bdb0
--- /dev/null
+++ b/src/rust/async/waker.rs
@@ -0,0 +1,87 @@
+use std::task::RawWaker;
+use std::task::RawWakerVTable;
+use std::task::Waker;
+
+use crate::ffi::CxxWaker;
+
+unsafe impl Send for CxxWaker {}
+unsafe impl Sync for CxxWaker {}
+
+fn deref_cxx_waker<'a>(data: *const ()) -> Option<&'a CxxWaker> {
+    if !data.is_null() {
+        let p = data as *const CxxWaker;
+        Some(unsafe { &*p })
+    } else {
+        None
+    }
+}
+
+pub fn cxx_waker_clone(data: *const ()) -> RawWaker {
+    let new_data = if let Some(cxx_waker) = deref_cxx_waker(data) {
+        cxx_waker.clone() as *const ()
+    } else {
+        std::ptr::null() as *const ()
+    };
+    RawWaker::new(new_data, &CXX_WAKER_VTABLE)
+}
+
+pub fn cxx_waker_wake(data: *const ()) {
+    if let Some(cxx_waker) = deref_cxx_waker(data) {
+        cxx_waker.wake();
+    }
+}
+
+pub fn cxx_waker_wake_by_ref(data: *const ()) {
+    if let Some(cxx_waker) = deref_cxx_waker(data) {
+        cxx_waker.wake_by_ref();
+    }
+}
+
+pub fn cxx_waker_drop(data: *const ()) {
+    if let Some(cxx_waker) = deref_cxx_waker(data) {
+        cxx_waker.drop();
+    }
+}
+
+static CXX_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
+    cxx_waker_clone,
+    cxx_waker_wake,
+    cxx_waker_wake_by_ref,
+    cxx_waker_drop,
+);
+
+use crate::ffi::RootWaker;
+
+unsafe impl Send for RootWaker {}
+unsafe impl Sync for RootWaker {}
+
+impl From<&RootWaker> for Waker {
+    fn from(waker: &RootWaker) -> Self {
+        let waker = RawWaker::new(waker as *const RootWaker as *const (), &ROOT_WAKER_VTABLE);
+        unsafe { Waker::from_raw(waker) }
+    }
+}
+
+pub fn deref_root_waker<'a>(waker: &Waker) -> Option<&'a RootWaker> {
+    if waker.vtable() == &ROOT_WAKER_VTABLE {
+        let data = waker.data();
+        assert!(!data.is_null());
+        let p = data as *const RootWaker;
+        let root_waker = unsafe { &*p };
+
+        if root_waker.is_current() {
+            Some(root_waker)
+        } else {
+            None
+        }
+    } else {
+        None
+    }
+}
+
+static ROOT_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
+    cxx_waker_clone,
+    cxx_waker_wake,
+    cxx_waker_wake_by_ref,
+    cxx_waker_drop,
+);
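For comparison, when `data` is null (e.g. after CxxWaker::clone() returns nullptr), the vtable above degenerates into a no-op waker. The same behavior written as a standalone Rust waker:

    use std::task::{RawWaker, RawWakerVTable, Waker};

    // A free-standing no-op waker: wakes do nothing, and clones produce more
    // no-op wakers. This is what the null-data path above amounts to.
    fn noop_raw_waker() -> RawWaker {
        fn clone(_: *const ()) -> RawWaker {
            noop_raw_waker()
        }
        fn noop(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
        RawWaker::new(std::ptr::null(), &VTABLE)
    }

    fn noop_waker() -> Waker {
        unsafe { Waker::from_raw(noop_raw_waker()) }
    }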