Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement self-referential/subset key interning #633

Merged
merged 4 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions components/salsa-macro-rules/src/setup_interned_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,21 +69,29 @@ macro_rules! setup_interned_struct {

/// Key to use during hash lookups. Each field is some type that implements `Lookup<T>`
/// for the owned type. This permits interning with an `&str` when a `String` is required and so forth.
struct StructKey<$db_lt, $($indexed_ty: $zalsa::interned::Lookup<$field_ty>),*>(
#[derive(Hash)]
struct StructKey<$db_lt, $($indexed_ty),*>(
$($indexed_ty,)*
std::marker::PhantomData<&$db_lt ()>,
);

impl<$db_lt, $($indexed_ty: $zalsa::interned::Lookup<$field_ty>),*> $zalsa::interned::Lookup<StructData<$db_lt>>
for StructKey<$db_lt, $($indexed_ty),*> {
impl<$db_lt, $($indexed_ty,)*> $zalsa::interned::HashEqLike<StructKey<$db_lt, $($indexed_ty),*>>
for StructData<$db_lt>
where
$($field_ty: $zalsa::interned::HashEqLike<$indexed_ty>),*
{

fn hash<H: std::hash::Hasher>(&self, h: &mut H) {
$($zalsa::interned::Lookup::hash(&self.$field_index, &mut *h);)*
$($zalsa::interned::HashEqLike::<$indexed_ty>::hash(&self.$field_index, &mut *h);)*
}

fn eq(&self, data: &StructData<$db_lt>) -> bool {
($($zalsa::interned::Lookup::eq(&self.$field_index, &data.$field_index) && )* true)
fn eq(&self, data: &StructKey<$db_lt, $($indexed_ty),*>) -> bool {
($($zalsa::interned::HashEqLike::<$indexed_ty>::eq(&self.$field_index, &data.$field_index) && )* true)
}
}

impl<$db_lt, $($indexed_ty: $zalsa::interned::Lookup<$field_ty>),*> $zalsa::interned::Lookup<StructData<$db_lt>>
for StructKey<$db_lt, $($indexed_ty),*> {

#[allow(unused_unit)]
fn into_owned(self) -> StructData<$db_lt> {
Expand Down Expand Up @@ -158,14 +166,17 @@ macro_rules! setup_interned_struct {
}

impl<$db_lt> $Struct<$db_lt> {
pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: impl $zalsa::interned::Lookup<$field_ty>),*) -> Self
pub fn $new_fn<$Db, $($indexed_ty: $zalsa::interned::Lookup<$field_ty> + std::hash::Hash,)*>(db: &$db_lt $Db, $($field_id: $indexed_ty),*) -> Self
where
// FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database`
$Db: ?Sized + salsa::Database,
$(
$field_ty: $zalsa::interned::HashEqLike<$indexed_ty>,
)*
{
let current_revision = $zalsa::current_revision(db);
$Configuration::ingredient(db).intern(db.as_dyn_database(),
StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default()))
StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default()), |_, data| ($($zalsa::interned::Lookup::into_owned(data.$field_index),)*))
}

$(
Expand Down
4 changes: 2 additions & 2 deletions components/salsa-macro-rules/src/setup_tracked_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ macro_rules! setup_tracked_fn {
use salsa::plumbing as $zalsa;
let key = $zalsa::macro_if! {
if $needs_interner {
$Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*))
$Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data)
} else {
$zalsa::AsId::as_id(&($($input_id),*))
}
Expand Down Expand Up @@ -285,7 +285,7 @@ macro_rules! setup_tracked_fn {
let result = $zalsa::macro_if! {
if $needs_interner {
{
let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*));
let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data);
$Configuration::fn_ingredient($db).fetch($db, key)
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ impl<C: Configuration> IngredientImpl<C> {
None
};

let id = zalsa_local.allocate(zalsa.table(), self.ingredient_index, || Value::<C> {
let id = zalsa_local.allocate(zalsa.table(), self.ingredient_index, |_| Value::<C> {
fields,
stamps,
memos: Default::default(),
Expand Down
170 changes: 96 additions & 74 deletions src/interned.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use dashmap::SharedValue;

use crate::accumulator::accumulated_map::InputAccumulatedValues;
use crate::durability::Durability;
use crate::id::AsId;
use crate::ingredient::fmt_index;
use crate::key::DependencyIndex;
use crate::plumbing::{Jar, JarAux};
Expand Down Expand Up @@ -120,20 +121,34 @@ where
unsafe { std::mem::transmute(data) }
}

pub fn intern_id<'db>(
pub fn intern<'db, Key>(
&'db self,
db: &'db dyn crate::Database,
data: impl Lookup<C::Data<'db>>,
) -> crate::Id {
C::deref_struct(self.intern(db, data)).as_id()
key: Key,
assemble: impl FnOnce(Id, Key) -> C::Data<'db>,
) -> C::Struct<'db>
where
Key: Hash,
C::Data<'db>: HashEqLike<Key>,
{
C::struct_from_id(self.intern_id(db, key, assemble))
}

/// Intern data to a unique reference.
pub fn intern<'db>(
pub fn intern_id<'db, Key>(
&'db self,
db: &'db dyn crate::Database,
data: impl Lookup<C::Data<'db>>,
) -> C::Struct<'db> {
key: Key,
assemble: impl FnOnce(Id, Key) -> C::Data<'db>,
) -> crate::Id
where
Key: Hash,
// We'd want the following predicate, but this currently implies `'static` due to a rustc
// bug
// for<'db> C::Data<'db>: HashEqLike<Key>,
// so instead we go with this and transmute the lifetime in the `eq` closure
C::Data<'db>: HashEqLike<Key>,
{
let zalsa_local = db.zalsa_local();
zalsa_local.report_tracked_read(
DependencyIndex::for_table(self.ingredient_index),
Expand All @@ -142,51 +157,53 @@ where
InputAccumulatedValues::Empty,
);

// Optimisation to only get read lock on the map if the data has already
// been interned.
// We need to use the raw API for this lookup. See the [`Lookup`][] trait definition for an explanation of why.
let data_hash = {
let mut hasher = self.key_map.hasher().build_hasher();
data.hash(&mut hasher);
hasher.finish()
// Optimization to only get read lock on the map if the data has already been interned.
let data_hash = self.key_map.hasher().hash_one(&key);
let shard = &self.key_map.shards()[self.key_map.determine_shard(data_hash as _)];
let eq = |(data, _): &_| {
// SAFETY: it's safe to go from Data<'static> to Data<'db>
// shrink lifetime here to use a single lifetime in HashEqLike::eq(&C::Data<'db>, &Key)
let data: &C::Data<'db> = unsafe { std::mem::transmute(data) };
HashEqLike::eq(data, &key)
};
let shard = self.key_map.determine_shard(data_hash as _);

{
let lock = self.key_map.shards()[shard].read();
if let Some(bucket) = lock.find(data_hash, |(a, _)| {
// SAFETY: it's safe to go from Data<'static> to Data<'db>
// shrink lifetime here to use a single lifetime in Lookup::eq(&StructKey<'db>, &C::Data<'db>)
let a: &C::Data<'db> = unsafe { std::mem::transmute(a) };
Lookup::eq(&data, a)
}) {
let lock = shard.read();
if let Some(bucket) = lock.find(data_hash, eq) {
// SAFETY: Read lock on map is held during this block
return C::struct_from_id(unsafe { *bucket.as_ref().1.get() });
return unsafe { *bucket.as_ref().1.get() };
}
};

let data = data.into_owned();

let internal_data = unsafe { self.to_internal_data(data) };
}

match self.key_map.entry(internal_data.clone()) {
let mut lock = shard.write();
match lock.find_or_find_insert_slot(data_hash, eq, |(element, _)| {
self.key_map.hasher().hash_one(element)
}) {
// Data has been interned by a racing call, use that ID instead
dashmap::mapref::entry::Entry::Occupied(entry) => {
let id = *entry.get();
drop(entry);
C::struct_from_id(id)
}

Ok(slot) => unsafe { *slot.as_ref().1.get() },
// We won any races so should intern the data
dashmap::mapref::entry::Entry::Vacant(entry) => {
Err(slot) => {
let zalsa = db.zalsa();
let table = zalsa.table();
let next_id = zalsa_local.allocate(table, self.ingredient_index, || Value::<C> {
data: internal_data,
let id = zalsa_local.allocate(table, self.ingredient_index, |id| Value::<C> {
data: unsafe { self.to_internal_data(assemble(id, key)) },
memos: Default::default(),
syncs: Default::default(),
});
entry.insert(next_id);
C::struct_from_id(next_id)
unsafe {
lock.insert_in_slot(
data_hash,
slot,
(table.get::<Value<C>>(id).data.clone(), SharedValue::new(id)),
)
};
debug_assert_eq!(
data_hash,
self.key_map
.hasher()
.hash_one(table.get::<Value<C>>(id).data.clone())
);
id
}
}
}
Expand Down Expand Up @@ -312,6 +329,10 @@ where
&self.syncs
}
}
/// Hashing and equality of `Self` against a value of a (possibly different) type `O`.
///
/// NOTE(review): the `intern_id` hunk in this change hashes the lookup key through
/// its `Hash` impl and then compares stored data against it with [`HashEqLike::eq`],
/// so implementations presumably need `hash` to agree with the hash of any `O` that
/// compares equal — confirm against the interning code.
pub trait HashEqLike<O> {
/// Feeds `self` into the hasher `h`.
fn hash<H: Hasher>(&self, h: &mut H);
/// Returns whether `self` is considered equal to `data`.
fn eq(&self, data: &O) -> bool;
}

/// The `Lookup` trait is a more flexible variant on [`std::borrow::Borrow`]
/// and [`std::borrow::ToOwned`].
Expand All @@ -328,12 +349,14 @@ where
/// requires that `&(K1...)` be convertible to `&ViewStruct` which just isn't
/// possible. `Lookup` instead offers direct `hash` and `eq` methods.
pub trait Lookup<O> {
fn hash<H: Hasher>(&self, h: &mut H);
fn eq(&self, data: &O) -> bool;
fn into_owned(self) -> O;
}

impl<T> Lookup<T> for T
/// Every type is a lookup key for itself: it already is the owned value.
impl<T> Lookup<T> for T {
// No conversion is needed; the value is handed back unchanged.
fn into_owned(self) -> T {
self
}
}
impl<T> HashEqLike<T> for T
where
T: Hash + Eq,
{
Expand All @@ -344,30 +367,18 @@ where
fn eq(&self, data: &T) -> bool {
self == data
}

fn into_owned(self) -> T {
self
}
}

impl<T> Lookup<T> for &T
where
T: Clone + Eq + Hash,
T: Clone,
{
fn hash<H: Hasher>(&self, h: &mut H) {
Hash::hash(self, &mut *h);
}

fn eq(&self, data: &T) -> bool {
*self == data
}

fn into_owned(self) -> T {
Clone::clone(self)
}
}

impl<'a, T> Lookup<Box<T>> for &'a T
impl<'a, T> HashEqLike<Box<T>> for &'a T
where
T: ?Sized + Hash + Eq,
Box<T>: From<&'a T>,
Expand All @@ -378,64 +389,75 @@ where
fn eq(&self, data: &Box<T>) -> bool {
**self == **data
}
}

/// A borrowed `&T` (including unsized `T`) is a lookup key for `Box<T>` whenever
/// `Box<T>` can be built from that borrow.
impl<'a, T> Lookup<Box<T>> for &'a T
where
T: ?Sized + Hash + Eq,
Box<T>: From<&'a T>,
{
// Delegates to the `From<&'a T>` impl to produce the boxed value.
fn into_owned(self) -> Box<T> {
Box::from(self)
}
}

/// Lets a `&str` stand in as the lookup key for a `String`.
impl Lookup<String> for &str {
// Produces a new `String` from the borrowed slice.
fn into_owned(self) -> String {
self.to_owned()
}
}
impl HashEqLike<&str> for String {
fn hash<H: Hasher>(&self, h: &mut H) {
Hash::hash(self, &mut *h)
}

fn eq(&self, data: &String) -> bool {
self == data
}

fn into_owned(self) -> String {
self.to_owned()
fn eq(&self, data: &&str) -> bool {
self == *data
}
}

impl<A: Hash + Eq + PartialEq<T> + Clone + Lookup<T>, T> Lookup<Vec<T>> for &[A] {
impl<A, T: Hash + Eq + PartialEq<A>> HashEqLike<&[A]> for Vec<T> {
fn hash<H: Hasher>(&self, h: &mut H) {
Hash::hash(self, h);
}

fn eq(&self, data: &Vec<T>) -> bool {
fn eq(&self, data: &&[A]) -> bool {
self.len() == data.len() && data.iter().enumerate().all(|(i, a)| &self[i] == a)
}

}
/// A slice of element lookup keys is a lookup key for a `Vec` of the owned element type.
impl<A: Hash + Eq + PartialEq<T> + Clone + Lookup<T>, T> Lookup<Vec<T>> for &[A] {
// Each element is cloned out of the slice and then converted through its own
// `Lookup<T>` impl; the results are collected into the `Vec`.
fn into_owned(self) -> Vec<T> {
self.iter().map(|a| Lookup::into_owned(a.clone())).collect()
}
}

impl<const N: usize, A: Hash + Eq + PartialEq<T> + Clone + Lookup<T>, T> Lookup<Vec<T>> for [A; N] {
impl<const N: usize, A, T: Hash + Eq + PartialEq<A>> HashEqLike<[A; N]> for Vec<T> {
fn hash<H: Hasher>(&self, h: &mut H) {
Hash::hash(self, h);
}

fn eq(&self, data: &Vec<T>) -> bool {
fn eq(&self, data: &[A; N]) -> bool {
self.len() == data.len() && data.iter().enumerate().all(|(i, a)| &self[i] == a)
}

}
/// A fixed-size array of element lookup keys is a lookup key for a `Vec` of the owned
/// element type.
impl<const N: usize, A: Hash + Eq + PartialEq<T> + Clone + Lookup<T>, T> Lookup<Vec<T>> for [A; N] {
// Each element is cloned and then converted through its own `Lookup<T>` impl;
// the results are collected into the `Vec`.
fn into_owned(self) -> Vec<T> {
self.into_iter()
.map(|a| Lookup::into_owned(a.clone()))
.collect()
}
}

impl Lookup<PathBuf> for &Path {
impl HashEqLike<&Path> for PathBuf {
fn hash<H: Hasher>(&self, h: &mut H) {
Hash::hash(self, h);
}

fn eq(&self, data: &PathBuf) -> bool {
fn eq(&self, data: &&Path) -> bool {
self == data
}

}
impl Lookup<PathBuf> for &Path {
fn into_owned(self) -> PathBuf {
self.to_owned()
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ pub mod plumbing {

pub mod interned {
pub use crate::interned::Configuration;
pub use crate::interned::HashEqLike;
pub use crate::interned::IngredientImpl;
pub use crate::interned::JarImpl;
pub use crate::interned::Lookup;
Expand Down
Loading
Loading