diff --git a/capnp-rpc/src/rpc.rs b/capnp-rpc/src/rpc.rs index 2a8a35882..772a961da 100644 --- a/capnp-rpc/src/rpc.rs +++ b/capnp-rpc/src/rpc.rs @@ -302,8 +302,8 @@ pub struct Import where VatId: 'static, { - // Becomes null when the import is destroyed. - import_client: Option<(Weak>>, usize)>, + // The usize is the numeric value of a raw pointer to the ImportClient. + import_client: (Weak>>, usize), // Either a copy of importClient, or, in the case of promises, the wrapping PromiseClient. // Becomes null when it is discarded *or* when the import is destroyed (e.g. the promise is @@ -315,9 +315,12 @@ where } impl Import { - fn new() -> Self { + fn new(import_client: &Rc>>) -> Self { Self { - import_client: None, + import_client: ( + Rc::downgrade(import_client), + &*import_client.borrow() as *const _ as _, + ), app_client: None, promise_client_to_resolve: None, } @@ -1463,22 +1466,18 @@ impl ConnectionState { let connection_state = state.clone(); let import_client = { - let slots = &mut state.imports.borrow_mut().slots; - let v = slots.entry(import_id).or_insert_with(Import::new); - if v.import_client.is_some() { - v.import_client - .as_ref() - .unwrap() + match state.imports.borrow_mut().slots.entry(import_id) { + std::collections::hash_map::Entry::Occupied(occ) => occ + .get() + .import_client .0 .upgrade() - .expect("dangling ref to import client?") - } else { - let import_client = ImportClient::new(&connection_state, import_id); - v.import_client = Some(( - Rc::downgrade(&import_client), - (&*import_client.borrow()) as *const _ as usize, - )); - import_client + .expect("dangling ref to import client?"), + std::collections::hash_map::Entry::Vacant(v) => { + let import_client = ImportClient::new(&connection_state, import_id); + v.insert(Import::new(&import_client)); + import_client + } } }; @@ -2605,10 +2604,9 @@ impl Drop for ImportClient { // Remove self from the import table, if the table is still pointing at us. let mut remove = false; if let Some(import) = connection_state.imports.borrow().slots.get(&self.import_id) { - if let Some((_, ptr)) = import.import_client { - if ptr == ((&*self) as *const _ as usize) { - remove = true; - } + let (_, ptr) = import.import_client; + if ptr == ((&*self) as *const _ as usize) { + remove = true; } }