Skip to content

Commit

Permalink
Refactor in preparation for exception propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
harrishancock committed Jan 27, 2025
1 parent 11b6118 commit 8fb92fd
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 130 deletions.
79 changes: 32 additions & 47 deletions src/rust/async/await.c++
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ kj::Maybe<kj::Own<kj::_::Event>> ArcWakerAwaiter::fire() {
// We should only ever receive a WakeInstruction, never an exception. But if we do, propagate
// it to the coroutine.
KJ_IF_SOME(exception, result.exception) {
coAwaitWaker.internalReject({}, kj::mv(exception));
coAwaitWaker.getFuturePoller().reject(kj::mv(exception));
return kj::none;
}

auto value = KJ_ASSERT_NONNULL(result.value);

if (value == WakeInstruction::WAKE) {
// This was an actual wakeup.
coAwaitWaker.armDepthFirst();
coAwaitWaker.getFuturePoller().armDepthFirst();
} else {
// All of our Wakers were dropped. We are awaiting the Rust equivalent of kj::NEVER_DONE.
}
Expand All @@ -58,9 +58,9 @@ kj::Maybe<kj::Own<kj::_::Event>> ArcWakerAwaiter::fire() {

void ArcWakerAwaiter::traceEvent(kj::_::TraceBuilder& builder) {
if (coAwaitWaker.wouldTrace({}, *this)) {
// Our associated CoAwaitWaker's `traceEvent()` implementation would call our `tracePromise()`
// function. Just forward the call to CoAwaitWaker.
coAwaitWaker.traceEvent(builder);
// Our associated FuturePoller's `traceEvent()` implementation would call our `tracePromise()`
// function. Just forward the call to the FuturePoller.
coAwaitWaker.getFuturePoller().traceEvent(builder);
} else {
// Our CoAwaitWaker would choose a different branch to trace, so just record our own trace
// address(es) and stop here.
Expand Down Expand Up @@ -174,7 +174,7 @@ kj::Maybe<kj::Own<kj::_::Event>> RustPromiseAwaiter::fire() {
KJ_DEFER(setDone());

KJ_IF_SOME(coAwaitWaker, linkedGroup().tryGet()) {
coAwaitWaker.armDepthFirst();
coAwaitWaker.getFuturePoller().armDepthFirst();
linkedGroup().set(kj::none);
} else KJ_IF_SOME(waker, rustWaker) {
// This call to `waker.wake()` consumes RustWaker's inner Waker. If we call it more than once,
Expand All @@ -195,7 +195,7 @@ void RustPromiseAwaiter::traceEvent(kj::_::TraceBuilder& builder) {
if (coAwaitWaker.wouldTrace({}, *this)) {
// We are associated with a CoAwaitWaker, and CoAwaitWaker's `traceEvent()` implementation
// would call our `tracePromise()` function. Just forward the call to CoAwaitWaker.
coAwaitWaker.traceEvent(builder);
coAwaitWaker.getFuturePoller().traceEvent(builder);
return;
}
}
Expand Down Expand Up @@ -292,16 +292,27 @@ void guarded_rust_promise_awaiter_drop_in_place(PtrGuardedRustPromiseAwaiter ptr
// =======================================================================================
// CoAwaitWaker

CoAwaitWaker::CoAwaitWaker(
kj::_::ExceptionOrValue& resultRef,
kj::SourceLocation location)
: Event(location),
resultRef(resultRef) {}
void FuturePoller::reject(kj::Exception&& exception) {
maybeException = kj::mv(exception);
armDepthFirst();
}

CoAwaitWaker::~CoAwaitWaker() noexcept(false) {
getCoroutine().clearPromiseNodeForTrace();
bool FuturePoller::isWaiting() {
return maybeException == kj::none;
}

void FuturePoller::throwIfRejected() {
KJ_IF_SOME(exception, maybeException) {
KJ_DEFER(maybeException = kj::none);
kj::throwFatalException(kj::mv(exception));
}
}

// =======================================================================================
// CoAwaitWaker

CoAwaitWaker::CoAwaitWaker(FuturePoller& futurePoller): futurePoller(futurePoller) {}

bool CoAwaitWaker::is_current() const {
return &kjWaker.getExecutor() == &kj::getCurrentThreadExecutor();
}
Expand Down Expand Up @@ -355,10 +366,8 @@ void CoAwaitWaker::tracePromise(kj::_::TraceBuilder& builder, bool stopAtNextEve
}
}

void CoAwaitWaker::traceEvent(kj::_::TraceBuilder& builder) {
// Just defer to our enclosing Coroutine. It will immediately call our `tracePromise` implementation.
auto& coroutine = getCoroutine();
static_cast<Event&>(coroutine).traceEvent(builder);
FuturePoller& CoAwaitWaker::getFuturePoller() {
return futurePoller;
}

bool CoAwaitWaker::wouldTrace(kj::Badge<ArcWakerAwaiter>, ArcWakerAwaiter& awaiter) {
Expand All @@ -382,19 +391,14 @@ bool CoAwaitWaker::wouldTrace(kj::Badge<RustPromiseAwaiter>, RustPromiseAwaiter&
return false;
}

void CoAwaitWaker::internalReject(kj::Badge<ArcWakerAwaiter>, kj::Exception exception) {
resultRef.addException(kj::mv(exception));
getCoroutine().armDepthFirst();
}

void CoAwaitWaker::awaitBegin() {
auto state = kjWaker.reset();

if (state.wakeCount > 0) {
// The future returned Pending, but synchronously called `wake_by_ref()` on the KjWaker,
// indicating it wants to immediately be polled again. We should arm our event right now,
// which will call `await_ready()` again on the event loop.
armDepthFirst();
futurePoller.armDepthFirst();
} else KJ_IF_SOME(promise, state.cloned) {
// The future returned Pending and cloned an ArcWaker to notify us later. We'll arrange for
// the ArcWaker's promise to arm our event once it's fulfilled.
Expand All @@ -404,33 +408,14 @@ void CoAwaitWaker::awaitBegin() {
// clone an ArcWaker. Rust is either awaiting a KJ promise, or the Rust equivalent of
// kj::NEVER_DONE.
}

// Integrate with our enclosing coroutine's tracing.
getCoroutine().setPromiseNodeForTrace(self);
}

void CoAwaitWaker::awaitEnd() {
getCoroutine().clearPromiseNodeForTrace();
}

void CoAwaitWaker::scheduleResumption() {
getCoroutine().armDepthFirst();
}

void CoAwaitWaker::setCoroutine(kj::_::CoroutineBase& coroutine) {
maybeCoroutine = coroutine;
}

kj::_::CoroutineBase& CoAwaitWaker::getCoroutine() {
return KJ_ASSERT_NONNULL(maybeCoroutine, "CoroutineBase reference should be initialized");
}

BoxFutureVoidAwaiter operator co_await(BoxFutureVoid future) {
return BoxFutureVoidAwaiter{kj::mv(future)};
BoxFutureAwaiterVoid operator co_await(BoxFutureVoid future) {
return BoxFutureAwaiterVoid{kj::mv(future)};
}

BoxFutureVoidAwaiter operator co_await(BoxFutureVoid& future) {
return BoxFutureVoidAwaiter{kj::mv(future)};
BoxFutureAwaiterVoid operator co_await(BoxFutureVoid& future) {
return BoxFutureAwaiterVoid{kj::mv(future)};
}

} // namespace workerd::rust::async
Loading

0 comments on commit 8fb92fd

Please sign in to comment.