diff --git a/src/rust/async/await.c++ b/src/rust/async/await.c++ index 5a701535227..d8ad6d5cc90 100644 --- a/src/rust/async/await.c++ +++ b/src/rust/async/await.c++ @@ -40,7 +40,7 @@ kj::Maybe> ArcWakerAwaiter::fire() { // We should only ever receive a WakeInstruction, never an exception. But if we do, propagate // it to the coroutine. KJ_IF_SOME(exception, result.exception) { - coAwaitWaker.internalReject({}, kj::mv(exception)); + coAwaitWaker.getFuturePoller().reject(kj::mv(exception)); return kj::none; } @@ -48,7 +48,7 @@ kj::Maybe> ArcWakerAwaiter::fire() { if (value == WakeInstruction::WAKE) { // This was an actual wakeup. - coAwaitWaker.armDepthFirst(); + coAwaitWaker.getFuturePoller().armDepthFirst(); } else { // All of our Wakers were dropped. We are awaiting the Rust equivalent of kj::NEVER_DONE. } @@ -58,9 +58,9 @@ kj::Maybe> ArcWakerAwaiter::fire() { void ArcWakerAwaiter::traceEvent(kj::_::TraceBuilder& builder) { if (coAwaitWaker.wouldTrace({}, *this)) { - // Our associated CoAwaitWaker's `traceEvent()` implementation would call our `tracePromise()` - // function. Just forward the call to CoAwaitWaker. - coAwaitWaker.traceEvent(builder); + // Our associated FuturePoller's `traceEvent()` implementation would call our `tracePromise()` + // function. Just forward the call to the FuturePoller. + coAwaitWaker.getFuturePoller().traceEvent(builder); } else { // Our CoAwaitWaker would choose a different branch to trace, so just record our own trace // address(es) and stop here. @@ -174,7 +174,7 @@ kj::Maybe> RustPromiseAwaiter::fire() { KJ_DEFER(setDone()); KJ_IF_SOME(coAwaitWaker, linkedGroup().tryGet()) { - coAwaitWaker.armDepthFirst(); + coAwaitWaker.getFuturePoller().armDepthFirst(); linkedGroup().set(kj::none); } else KJ_IF_SOME(waker, rustWaker) { // This call to `waker.wake()` consumes RustWaker's inner Waker. If we call it more than once, @@ -195,7 +195,7 @@ void RustPromiseAwaiter::traceEvent(kj::_::TraceBuilder& builder) { if (coAwaitWaker.wouldTrace({}, *this)) { // We are associated with a CoAwaitWaker, and CoAwaitWaker's `traceEvent()` implementation // would call our `tracePromise()` function. Just forward the call to CoAwaitWaker. - coAwaitWaker.traceEvent(builder); + coAwaitWaker.getFuturePoller().traceEvent(builder); return; } } @@ -292,16 +292,27 @@ void guarded_rust_promise_awaiter_drop_in_place(PtrGuardedRustPromiseAwaiter ptr // ======================================================================================= // CoAwaitWaker -CoAwaitWaker::CoAwaitWaker( - kj::_::ExceptionOrValue& resultRef, - kj::SourceLocation location) - : Event(location), - resultRef(resultRef) {} +void FuturePoller::reject(kj::Exception&& exception) { + maybeException = kj::mv(exception); + armDepthFirst(); +} + +bool FuturePoller::isWaiting() { + return maybeException == kj::none; +} -CoAwaitWaker::~CoAwaitWaker() noexcept(false) { - getCoroutine().clearPromiseNodeForTrace(); +void FuturePoller::throwIfRejected() { + KJ_IF_SOME(exception, maybeException) { + KJ_DEFER(maybeException = kj::none); + kj::throwFatalException(kj::mv(exception)); + } } +// ======================================================================================= +// CoAwaitWaker + +CoAwaitWaker::CoAwaitWaker(FuturePoller& futurePoller): futurePoller(futurePoller) {} + bool CoAwaitWaker::is_current() const { return &kjWaker.getExecutor() == &kj::getCurrentThreadExecutor(); } @@ -355,10 +366,8 @@ void CoAwaitWaker::tracePromise(kj::_::TraceBuilder& builder, bool stopAtNextEve } } -void CoAwaitWaker::traceEvent(kj::_::TraceBuilder& builder) { - // Just defer to our enclosing Coroutine. It will immediately call our `tracePromise` implementation. - auto& coroutine = getCoroutine(); - static_cast(coroutine).traceEvent(builder); +FuturePoller& CoAwaitWaker::getFuturePoller() { + return futurePoller; } bool CoAwaitWaker::wouldTrace(kj::Badge, ArcWakerAwaiter& awaiter) { @@ -382,11 +391,6 @@ bool CoAwaitWaker::wouldTrace(kj::Badge, RustPromiseAwaiter& return false; } -void CoAwaitWaker::internalReject(kj::Badge, kj::Exception exception) { - resultRef.addException(kj::mv(exception)); - getCoroutine().armDepthFirst(); -} - void CoAwaitWaker::awaitBegin() { auto state = kjWaker.reset(); @@ -394,7 +398,7 @@ void CoAwaitWaker::awaitBegin() { // The future returned Pending, but synchronously called `wake_by_ref()` on the KjWaker, // indicating it wants to immediately be polled again. We should arm our event right now, // which will call `await_ready()` again on the event loop. - armDepthFirst(); + futurePoller.armDepthFirst(); } else KJ_IF_SOME(promise, state.cloned) { // The future returned Pending and cloned an ArcWaker to notify us later. We'll arrange for // the ArcWaker's promise to arm our event once it's fulfilled. @@ -404,33 +408,6 @@ void CoAwaitWaker::awaitBegin() { // clone an ArcWaker. Rust is either awaiting a KJ promise, or the Rust equivalent of // kj::NEVER_DONE. } - - // Integrate with our enclosing coroutine's tracing. - getCoroutine().setPromiseNodeForTrace(self); -} - -void CoAwaitWaker::awaitEnd() { - getCoroutine().clearPromiseNodeForTrace(); -} - -void CoAwaitWaker::scheduleResumption() { - getCoroutine().armDepthFirst(); -} - -void CoAwaitWaker::setCoroutine(kj::_::CoroutineBase& coroutine) { - maybeCoroutine = coroutine; -} - -kj::_::CoroutineBase& CoAwaitWaker::getCoroutine() { - return KJ_ASSERT_NONNULL(maybeCoroutine, "CoroutineBase reference should be initialized"); -} - -BoxFutureVoidAwaiter operator co_await(BoxFutureVoid future) { - return BoxFutureVoidAwaiter{kj::mv(future)}; -} - -BoxFutureVoidAwaiter operator co_await(BoxFutureVoid& future) { - return BoxFutureVoidAwaiter{kj::mv(future)}; } } // namespace workerd::rust::async diff --git a/src/rust/async/await.h b/src/rust/async/await.h index 7dd8c90cf21..ee1a4ee6800 100644 --- a/src/rust/async/await.h +++ b/src/rust/async/await.h @@ -5,6 +5,8 @@ #include #include +#include + namespace workerd::rust::async { // TODO(perf): This is only an Event because we need to handle the case where all the Wakers are @@ -147,9 +149,40 @@ void guarded_rust_promise_awaiter_new_in_place( void guarded_rust_promise_awaiter_drop_in_place(PtrGuardedRustPromiseAwaiter); // ======================================================================================= -// CoAwaitWaker +// FuturePoller // Base class for the awaitable created by `co_await` when awaiting a Rust Future in a KJ coroutine. +class FuturePoller: public kj::_::Event, + public kj::PromiseRejector { +public: + using Event::Event; + + // Reject this Future with an exception. This is not an expected code path, and indicates a bug in + // our implementation. Currently it is only required because it is not possible to "disarm" + // CrossThreadPromiseFulfillers. + void reject(kj::Exception&& exception) override final; + + // True if we haven't been rejected yet. + bool isWaiting() override final; + +protected: + // Throw the saved exception if one exists. Helper to implement derived classes. + void throwIfRejected(); + +private: + // We use this only to reject the `co_await` with an exception if there's a bug in our usage of + // our ArcWaker. Specifically, if our ArcWaker promise is rejected, the exception ends up here. + // Since we never reject that promise, this is dead code. Or at least, it should be. + // + // TODO(cleanup): If we can make CrossThreadPromiseFulfillers disarmable -- i.e., convert the + // promise into an effectively NEVER_DONE promise -- then this dead code path can go away. + kj::Maybe maybeException; +}; + +// ======================================================================================= +// CoAwaitWaker + +// A CxxWaker implementation which provides an optimized path for awaiting KJ Promises in Rust. // // The PromiseNode base class is a hack to implement async tracing. That is, we only implement the // `tracePromise()` function, and, instead of calling `coroutine.awaitBegin(p)` with the Promise `p` @@ -159,15 +192,10 @@ void guarded_rust_promise_awaiter_drop_in_place(PtrGuardedRustPromiseAwaiter); // RustPromiseAwaiter LinkedObjects have independent lifetimes from the CoAwaitWaker, so we mustn't // leave references to them, or their members, lying around in the Coroutine class. class CoAwaitWaker: public CxxWaker, - public kj::_::Event, public kj::_::PromiseNode, public LinkedGroup { public: - // Initialize `next` with the enclosing coroutine's `Event`. - CoAwaitWaker( - kj::_::ExceptionOrValue& resultRef, - kj::SourceLocation location = {}); - ~CoAwaitWaker() noexcept(false); + CoAwaitWaker(FuturePoller& futurePoller); // True if the current thread's kj::Executor is the same as the one that was active when this // CoAwaitWaker was constructed. This allows Rust to optimize Promise `.await`s. @@ -183,12 +211,6 @@ class CoAwaitWaker: public CxxWaker, void wake_by_ref() const override; void drop() const override; - // ------------------------------------------------------- - // Event API - - void traceEvent(kj::_::TraceBuilder& builder) override; - // fire() implemented in derived class - // ------------------------------------------------------- // PromiseNode API // @@ -203,113 +225,130 @@ class CoAwaitWaker: public CxxWaker, // ------------------------------------------------------- // Other stuff - // True if `CoAwaitWaker::traceEvent()` would immediately call - // `awaiter.tracePromise()` + FuturePoller& getFuturePoller(); + + // True if our `tracePromise()` implementation would choose the given awaiter's promise for + // tracing. If our wrapped Future is awaiting multiple other Promises and/or Futures, our + // `tracePromise()` implementation might choose a different branch to go down. bool wouldTrace(kj::Badge, ArcWakerAwaiter& awaiter); bool wouldTrace(kj::Badge, RustPromiseAwaiter& awaiter); // TODO(now): Propagate value-or-exception. - // Reject the Future with an exception. Arms the enclosing coroutine's event. The event will - // resume the coroutine, which will then rethrow the exception from `await_resume()`. This is not - // an expected code path, and indicates a bug. - void internalReject(kj::Badge, kj::Exception exception); - -protected: - // API for derived class. + // TODO(now): Can we get rid of this API? void awaitBegin(); - void awaitEnd(); - void scheduleResumption(); // TODO(now): Rename to fulfill()? - - // Called from `await_suspend()`, which is the earliest we get access to the coroutine handle. - void setCoroutine(kj::_::CoroutineBase& coroutine); private: - // Helper to access `maybeCoroutine`, which is effectively always non-none. - kj::_::CoroutineBase& getCoroutine(); - - // The enclosing coroutine, which we will arm once our wrapped Future returns Ready, or an - // internal error occurs. - // - // This member is a Maybe because we don't have access to the coroutine until `await_suspend()` is - // called, which initializes this member by calling `setCoroutine()`. Since our derived classes' - // `await_ready()` implementations do nothing but immediately return false, we can assume that - // this Maybe is non-none effectively everywhere in the implementation of this class. - kj::Maybe maybeCoroutine; + FuturePoller& futurePoller; // This KjWaker is our actual implementation of the CxxWaker interface. We forward all calls here. KjWaker kjWaker; - // HACK: We implement the PromiseNode interface to integrate with the Coroutine class' current - // tracing implementation. - OwnPromiseNode self { this }; - - // Reference to a member of our derived class. We use this only to reject the `co_await` with an - // exception if an internal error occurs. What is an internal error? Any condition which causes - // `internalReject()` to be called. :) - // - // The referee is of course uninitialized in our constructor and destructor. - kj::_::ExceptionOrValue& resultRef; - + // TODO(now): Can this be moved into KjWaker? kj::Maybe arcWakerAwaiter; }; -class BoxFutureVoidAwaiter: public CoAwaitWaker { +template +class BoxFutureAwaiter final: public FuturePoller { public: - BoxFutureVoidAwaiter(BoxFutureVoid&& future, kj::SourceLocation location = {}) - : CoAwaitWaker(result), + BoxFutureAwaiter( + kj::_::CoroutineBase& coroutine, + BoxFuture future, + kj::SourceLocation location = {}) + : FuturePoller(location), + coroutine(coroutine), + coAwaitWaker(*this), future(kj::mv(future)) {} + ~BoxFutureAwaiter() noexcept(false) { + coroutine.clearPromiseNodeForTrace(); + } + KJ_DISALLOW_COPY_AND_MOVE(BoxFutureAwaiter); + + // Poll the wrapped Future, returning false if we should _not_ suspend, true if we should suspend. + bool awaitSuspendImpl() { + // TODO(perf): Check if we already have an ArcWaker from a previous suspension and give it to + // KjWaker for cloning if we have the last reference to it at this point. This could save + // memory allocations, but would depend on making XThreadFulfiller and XThreadPaf resettable + // to really benefit. + + if (future.poll(coAwaitWaker)) { + // Future is ready, we're done. + // TODO(now): Propagate value-or-exception. + return false; + } + + // TODO(now): Get rid of this? + coAwaitWaker.awaitBegin(); - bool await_ready() const { - return false; + // Integrate with our enclosing coroutine's tracing. + coroutine.setPromiseNodeForTrace(promiseNodeForTrace); + + return true; } - template requires (kj::canConvert()) - bool await_suspend(kj::_::stdcoro::coroutine_handle handle) { - setCoroutine(handle.promise()); - return awaitSuspendImpl(); + void awaitResumeImpl() { + coroutine.clearPromiseNodeForTrace(); + throwIfRejected(); } - // Unit futures return void. - void await_resume() { - awaitEnd(); + // ------------------------------------------------------- + // Event API - KJ_IF_SOME(exception, result.exception) { - kj::throwFatalException(kj::mv(exception)); - } + void traceEvent(kj::_::TraceBuilder& builder) override { + // Just defer to our enclosing Coroutine. It will immediately call our CoAwaitWaker's + // `tracePromise()` implementation. + static_cast(coroutine).traceEvent(builder); } +protected: kj::Maybe> fire() override { if (!awaitSuspendImpl()) { - scheduleResumption(); + coroutine.armDepthFirst(); } return kj::none; } private: - // Poll the wrapped Future, returning false if we should _not_ suspend, true if we should suspend. - bool awaitSuspendImpl() { - // TODO(perf): Check if we already have an ArcWaker from a previous suspension and give it to - // KjWaker for cloning if we have the last reference to it at this point. This could save - // memory allocations, but would depend on making XThreadFulfiller and XThreadPaf resettable - // to really benefit. + kj::_::CoroutineBase& coroutine; + CoAwaitWaker coAwaitWaker; + // HACK: CoAwaitWaker implements the PromiseNode interface to integrate with the Coroutine class' + // current tracing implementation. + OwnPromiseNode promiseNodeForTrace { &coAwaitWaker }; - if (future.poll(*this)) { - // Future is ready, we're done. - // TODO(now): Propagate value-or-exception. - return false; - } + BoxFuture future; +}; + +template +class LazyBoxFutureAwaiter { +public: + LazyBoxFutureAwaiter(BoxFuture&& future): impl(kj::mv(future)) {} - awaitBegin(); + bool await_ready() const { return false; } - return true; + template requires (kj::canConvert()) + bool await_suspend(kj::_::stdcoro::coroutine_handle handle) { + auto future = kj::mv(KJ_ASSERT_NONNULL(impl.template tryGet>())); + return impl.template init>(handle.promise(), kj::mv(future)) + .awaitSuspendImpl(); } - BoxFutureVoid future; - kj::_::ExceptionOr result; + // TODO(now): Return non-void T. + void await_resume() { + KJ_ASSERT_NONNULL(impl.template tryGet>()).awaitResumeImpl(); + } + +private: + // TODO(now): Comment. + kj::OneOf, BoxFutureAwaiter> impl; }; -BoxFutureVoidAwaiter operator co_await(BoxFutureVoid future); -BoxFutureVoidAwaiter operator co_await(BoxFutureVoid& future); +template +LazyBoxFutureAwaiter operator co_await(BoxFuture future) { + return kj::mv(future); +} +template +LazyBoxFutureAwaiter operator co_await(BoxFuture& future) { + return kj::mv(future); +} } // namespace workerd::rust::async