Skip to content

Commit

Permalink
Promise<T> can return T to Rust when awaited
Browse files Browse the repository at this point in the history
  • Loading branch information
harrishancock committed Jan 30, 2025
1 parent 044a80b commit 32a634f
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 24 deletions.
9 changes: 4 additions & 5 deletions src/rust/async/await_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ use crate::CxxResult;
// =======================================================================================
// GuardedRustPromiseAwaiter

use crate::ffi::guarded_rust_promise_awaiter_drop_in_place;
use crate::ffi::guarded_rust_promise_awaiter_new_in_place;

#[path = "await.h.rs"]
mod await_h;
pub use await_h::GuardedRustPromiseAwaiter;
Expand Down Expand Up @@ -48,7 +45,9 @@ impl Drop for GuardedRustPromiseAwaiter {
//
// https://doc.rust-lang.org/std/ptr/index.html#safety
unsafe {
guarded_rust_promise_awaiter_drop_in_place(PtrGuardedRustPromiseAwaiter(self));
crate::ffi::guarded_rust_promise_awaiter_drop_in_place(PtrGuardedRustPromiseAwaiter(
self,
));
}
}
}
Expand Down Expand Up @@ -170,7 +169,7 @@ impl PromiseAwaiter {
//
// https://doc.rust-lang.org/std/ptr/index.html#safety
awaiter.get_or_init(move |ptr: *mut GuardedRustPromiseAwaiter| unsafe {
guarded_rust_promise_awaiter_new_in_place(
crate::ffi::guarded_rust_promise_awaiter_new_in_place(
PtrGuardedRustPromiseAwaiter(ptr),
rust_waker_ptr,
node.expect("node should be Some in call to init()"),
Expand Down
10 changes: 10 additions & 0 deletions src/rust/async/cxx-bridge-test.c++
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,16 @@ KJ_TEST(".awaiting a Promise<T> from Rust can produce an Err Result") {
}().wait(waitScope);
}

KJ_TEST("Rust can await Promise<int32_t>") {
kj::EventLoop loop;
kj::WaitScope waitScope(loop);

[]() -> kj::Promise<void> {
kj::Maybe<kj::Exception> maybeException;
co_await new_awaiting_future_i32();
}().wait(waitScope);
}

// TODO(now): More test cases.
// - Standalone ArcWaker tests. Ensure Rust calls ArcWaker destructor when we expect.
// - Ensure Rust calls PromiseNode destructor from LazyRustPromiseAwaiter.
Expand Down
10 changes: 10 additions & 0 deletions src/rust/async/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use promise::PtrOwnPromiseNode;
use promise::PtrPromise;

mod test_futures;
use test_futures::new_awaiting_future_i32;
use test_futures::new_error_handling_future_void;
use test_futures::new_errored_future_fallible_void;
use test_futures::new_layered_ready_future_void;
Expand Down Expand Up @@ -146,6 +147,12 @@ mod ffi {
fn own_promise_node_unwrap_void(node: OwnPromiseNode) -> Result<()>;
unsafe fn promise_drop_in_place_void(promise: PtrPromiseVoid);
fn promise_into_own_promise_node_void(promise: PromiseVoid) -> OwnPromiseNode;

type PromiseI32 = crate::Promise<i32>;
type PtrPromiseI32 = crate::PtrPromise<i32>;
fn own_promise_node_unwrap_i32(node: OwnPromiseNode) -> Result<i32>;
unsafe fn promise_drop_in_place_i32(promise: PtrPromiseI32);
fn promise_into_own_promise_node_i32(promise: PromiseI32) -> OwnPromiseNode;
}
// -----------------------------------------------------
// Test functions
Expand All @@ -159,6 +166,7 @@ mod ffi {
fn new_coroutine_promise_void() -> PromiseVoid;

fn new_errored_promise_void() -> PromiseVoid;
fn new_ready_promise_i32(value: i32) -> PromiseI32;
}

enum CloningAction {
Expand Down Expand Up @@ -192,5 +200,7 @@ mod ffi {

fn new_errored_future_fallible_void() -> BoxFutureFallibleVoid;
fn new_error_handling_future_void() -> BoxFutureVoid;

fn new_awaiting_future_i32() -> BoxFutureVoid;
}
}
12 changes: 11 additions & 1 deletion src/rust/async/promise-boilerplate.c++
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace workerd::rust::async {
namespace {

template <typename T>
void unwrapNode(OwnPromiseNode node) {
T unwrapNode(OwnPromiseNode node) {
kj::_::ExceptionOr<kj::_::FixVoid<T>> result;

node->get(result);
Expand All @@ -31,4 +31,14 @@ void own_promise_node_unwrap_void(OwnPromiseNode node) {
return unwrapNode<void>(kj::mv(node));
}

OwnPromiseNode promise_into_own_promise_node_i32(PromiseI32 promise) {
return kj::_::PromiseNode::from(kj::mv(promise));
};
void promise_drop_in_place_i32(PtrPromiseI32 promise) {
kj::dtor(*promise);
}
int32_t own_promise_node_unwrap_i32(OwnPromiseNode node) {
return unwrapNode<int32_t>(kj::mv(node));
}

} // namespace workerd::rust::async
7 changes: 6 additions & 1 deletion src/rust/async/promise-boilerplate.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@ namespace workerd::rust::async {
// TODO(now): Generate boilerplate with a macro.
using PromiseVoid = Promise<void>;
using PtrPromiseVoid = PromiseVoid*;

void own_promise_node_unwrap_void(OwnPromiseNode);
void promise_drop_in_place_void(PtrPromiseVoid);
OwnPromiseNode promise_into_own_promise_node_void(PromiseVoid);

using PromiseI32 = Promise<int32_t>;
using PtrPromiseI32 = PromiseI32*;
int32_t own_promise_node_unwrap_i32(OwnPromiseNode);
void promise_drop_in_place_i32(PtrPromiseI32);
OwnPromiseNode promise_into_own_promise_node_i32(PromiseI32);

} // namespace workerd::rust::async
53 changes: 37 additions & 16 deletions src/rust/async/promise.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
use cxx::ExternType;

use crate::ffi::own_promise_node_drop_in_place;

use crate::CxxResult;

// TODO(now): Generate boilerplate with a macro.
use crate::ffi::own_promise_node_unwrap_void;
use crate::ffi::promise_drop_in_place_void;
use crate::ffi::promise_into_own_promise_node_void;

// The inner pointer is never read on Rust's side, so Rust thinks it's dead code.
#[allow(dead_code)]
pub struct OwnPromiseNode(*const ());
Expand All @@ -31,7 +24,7 @@ impl Drop for OwnPromiseNode {
//
// https://doc.rust-lang.org/std/ptr/index.html#safety
unsafe {
own_promise_node_drop_in_place(PtrOwnPromiseNode(self));
crate::ffi::own_promise_node_drop_in_place(PtrOwnPromiseNode(self));
}
}
}
Expand Down Expand Up @@ -80,32 +73,60 @@ impl<T: PromiseTarget> Drop for Promise<T> {
}
}

#[repr(transparent)]
pub struct PtrPromise<T: PromiseTarget>(*mut Promise<T>);

// =======================================================================================
// Boilerplate follows
//
// TODO(now): Generate boilerplate with a macro.

// TODO(now): Safety comment.
unsafe impl ExternType for Promise<()> {
type Id = cxx::type_id!("workerd::rust::async::PromiseVoid");
type Kind = cxx::kind::Trivial;
}

#[repr(transparent)]
pub struct PtrPromise<T: PromiseTarget>(*mut Promise<T>);

// TODO(now): Generate boilerplate with a macro.
// Safety: Raw pointers are the same size in both languages.
unsafe impl ExternType for PtrPromise<()> {
type Id = cxx::type_id!("workerd::rust::async::PtrPromiseVoid");
type Kind = cxx::kind::Trivial;
}

// TODO(now): Generate boilerplate with a macro.
impl PromiseTarget for () {
fn into_own_promise_node(promise: Promise<Self>) -> OwnPromiseNode {
promise_into_own_promise_node_void(promise)
crate::ffi::promise_into_own_promise_node_void(promise)
}
unsafe fn drop_in_place(ptr: PtrPromise<Self>) {
crate::ffi::promise_drop_in_place_void(ptr);
}
fn unwrap(node: OwnPromiseNode) -> std::result::Result<Self, cxx::Exception> {
crate::ffi::own_promise_node_unwrap_void(node)
}
}

// ---------------------------------------------------------

// TODO(now): Safety comment.
unsafe impl ExternType for Promise<i32> {
type Id = cxx::type_id!("workerd::rust::async::PromiseI32");
type Kind = cxx::kind::Trivial;
}

// Safety: Raw pointers are the same size in both languages.
unsafe impl ExternType for PtrPromise<i32> {
type Id = cxx::type_id!("workerd::rust::async::PtrPromiseI32");
type Kind = cxx::kind::Trivial;
}

impl PromiseTarget for i32 {
fn into_own_promise_node(promise: Promise<Self>) -> OwnPromiseNode {
crate::ffi::promise_into_own_promise_node_i32(promise)
}
unsafe fn drop_in_place(ptr: PtrPromise<Self>) {
promise_drop_in_place_void(ptr);
crate::ffi::promise_drop_in_place_i32(ptr);
}
fn unwrap(node: OwnPromiseNode) -> std::result::Result<Self, cxx::Exception> {
own_promise_node_unwrap_void(node)
crate::ffi::own_promise_node_unwrap_i32(node)
}
}
4 changes: 4 additions & 0 deletions src/rust/async/test-promises.c++
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ kj::Promise<void> new_ready_promise_void() {
return kj::Promise<void>(kj::READY_NOW);
}

kj::Promise<int32_t> new_ready_promise_i32(int32_t value) {
return kj::Promise<int32_t>(value);
}

kj::Promise<void> new_pending_promise_void() {
return kj::Promise<void>(kj::NEVER_DONE);
}
Expand Down
1 change: 1 addition & 0 deletions src/rust/async/test-promises.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ kj::Promise<void> new_pending_promise_void();
kj::Promise<void> new_coroutine_promise_void();

kj::Promise<void> new_errored_promise_void();
kj::Promise<int32_t> new_ready_promise_i32(int32_t);

} // namespace workerd::rust::async
9 changes: 8 additions & 1 deletion src/rust/async/test_futures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,14 @@ pub fn new_errored_future_fallible_void() -> BoxFuture<Result<()>> {

pub fn new_error_handling_future_void() -> BoxFuture<()> {
Box::pin(async {
let err = crate::ffi::new_errored_promise_void().await.expect_err("should see error");
let err = crate::ffi::new_errored_promise_void().await.expect_err("should throw");
assert!(err.what().contains("test error"));
}).into()
}

pub fn new_awaiting_future_i32() -> BoxFuture<()> {
Box::pin(async {
let value = crate::ffi::new_ready_promise_i32(123).await.expect("should not throw");
assert_eq!(value, 123);
}).into()
}

0 comments on commit 32a634f

Please sign in to comment.