diff --git a/Cargo.toml b/Cargo.toml index 69393bd8f..1e788cfb0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ rustc-hash = "2" salsa-macro-rules = { version = "0.1.0", path = "components/salsa-macro-rules" } salsa-macros = { path = "components/salsa-macros" } smallvec = "1" -lazy_static = "1" rayon = "1.10.0" [dev-dependencies] diff --git a/benches/compare.rs b/benches/compare.rs index ac15ac978..c885b54d9 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -1,4 +1,8 @@ -use codspeed_criterion_compat::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use std::mem::transmute; + +use codspeed_criterion_compat::{ + criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion, +}; use salsa::Setter; #[salsa::input] @@ -26,25 +30,30 @@ fn mutating_inputs(c: &mut Criterion) { codspeed_criterion_compat::measurement::WallTime, > = c.benchmark_group("Mutating Inputs"); - let mut db = salsa::DatabaseImpl::default(); - for n in &[10, 20, 30] { - let base_string = "hello, world!".to_owned(); - let base_len = base_string.len(); - - let string = base_string.clone().repeat(*n); - let new_len = string.len(); - group.bench_function(BenchmarkId::new("mutating", n), |b| { - b.iter(|| { - let input = Input::new(&db, base_string.clone()); - let actual_len = length(&db, input); - assert_eq!(base_len, actual_len); - - input.set_text(&mut db).to(string.clone()); - let actual_len = length(&db, input); - assert_eq!(new_len, actual_len); - }) + b.iter_batched_ref( + || { + let db = salsa::DatabaseImpl::default(); + let base_string = "hello, world!".to_owned(); + let base_len = base_string.len(); + + let string = base_string.clone().repeat(*n); + let new_len = string.len(); + + let input = Input::new(&db, base_string.clone()); + let actual_len = length(&db, input); + assert_eq!(base_len, actual_len); + + (db, input, string, new_len) + }, + |&mut (ref mut db, input, ref string, new_len)| { + input.set_text(db).to(string.clone()); + let actual_len = 
length(db, input); + assert_eq!(new_len, actual_len); + }, + BatchSize::SmallInput, + ) }); } @@ -56,34 +65,58 @@ fn inputs(c: &mut Criterion) { codspeed_criterion_compat::measurement::WallTime, > = c.benchmark_group("Mutating Inputs"); - let db = salsa::DatabaseImpl::default(); - group.bench_function(BenchmarkId::new("new", "InternedInput"), |b| { - b.iter(|| { - let input: InternedInput = InternedInput::new(&db, "hello, world!".to_owned()); - interned_length(&db, input); - }) + b.iter_batched_ref( + salsa::DatabaseImpl::default, + |db| { + let input: InternedInput = InternedInput::new(db, "hello, world!".to_owned()); + interned_length(db, input); + }, + BatchSize::SmallInput, + ) }); group.bench_function(BenchmarkId::new("amortized", "InternedInput"), |b| { - let input = InternedInput::new(&db, "hello, world!".to_owned()); - let _ = interned_length(&db, input); - - b.iter(|| interned_length(&db, input)); + b.iter_batched_ref( + || { + let db = salsa::DatabaseImpl::default(); + // we can't pass this along otherwise, and the lifetime is generally informational + let input: InternedInput<'static> = + unsafe { transmute(InternedInput::new(&db, "hello, world!".to_owned())) }; + let _ = interned_length(&db, input); + (db, input) + }, + |&mut (ref db, input)| { + interned_length(db, input); + }, + BatchSize::SmallInput, + ) }); group.bench_function(BenchmarkId::new("new", "Input"), |b| { - b.iter(|| { - let input = Input::new(&db, "hello, world!".to_owned()); - length(&db, input); - }) + b.iter_batched_ref( + salsa::DatabaseImpl::default, + |db| { + let input = Input::new(db, "hello, world!".to_owned()); + length(db, input); + }, + BatchSize::SmallInput, + ) }); group.bench_function(BenchmarkId::new("amortized", "Input"), |b| { - let input = Input::new(&db, "hello, world!".to_owned()); - let _ = length(&db, input); - - b.iter(|| length(&db, input)); + b.iter_batched_ref( + || { + let db = salsa::DatabaseImpl::default(); + let input = Input::new(&db, "hello, 
world!".to_owned()); + let _ = length(&db, input); + (db, input) + }, + |&mut (ref db, input)| { + length(db, input); + }, + BatchSize::SmallInput, + ) }); group.finish(); diff --git a/components/salsa-macro-rules/src/setup_input_struct.rs b/components/salsa-macro-rules/src/setup_input_struct.rs index 51cd482bc..b506b28b0 100644 --- a/components/salsa-macro-rules/src/setup_input_struct.rs +++ b/components/salsa-macro-rules/src/setup_input_struct.rs @@ -124,6 +124,9 @@ macro_rules! setup_input_struct { } impl $zalsa::SalsaStructInDb for $Struct { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } impl $Struct { diff --git a/components/salsa-macro-rules/src/setup_interned_struct.rs b/components/salsa-macro-rules/src/setup_interned_struct.rs index bf9d98f53..f3eeb83c3 100644 --- a/components/salsa-macro-rules/src/setup_interned_struct.rs +++ b/components/salsa-macro-rules/src/setup_interned_struct.rs @@ -69,21 +69,29 @@ macro_rules! setup_interned_struct { /// Key to use during hash lookups. Each field is some type that implements `Lookup` /// for the owned type. This permits interning with an `&str` when a `String` is required and so forth. 
- struct StructKey<$db_lt, $($indexed_ty: $zalsa::interned::Lookup<$field_ty>),*>( + #[derive(Hash)] + struct StructKey<$db_lt, $($indexed_ty),*>( $($indexed_ty,)* std::marker::PhantomData<&$db_lt ()>, ); - impl<$db_lt, $($indexed_ty: $zalsa::interned::Lookup<$field_ty>),*> $zalsa::interned::Lookup> - for StructKey<$db_lt, $($indexed_ty),*> { + impl<$db_lt, $($indexed_ty,)*> $zalsa::interned::HashEqLike> + for StructData<$db_lt> + where + $($field_ty: $zalsa::interned::HashEqLike<$indexed_ty>),* + { fn hash(&self, h: &mut H) { - $($zalsa::interned::Lookup::hash(&self.$field_index, &mut *h);)* + $($zalsa::interned::HashEqLike::<$indexed_ty>::hash(&self.$field_index, &mut *h);)* } - fn eq(&self, data: &StructData<$db_lt>) -> bool { - ($($zalsa::interned::Lookup::eq(&self.$field_index, &data.$field_index) && )* true) + fn eq(&self, data: &StructKey<$db_lt, $($indexed_ty),*>) -> bool { + ($($zalsa::interned::HashEqLike::<$indexed_ty>::eq(&self.$field_index, &data.$field_index) && )* true) } + } + + impl<$db_lt, $($indexed_ty: $zalsa::interned::Lookup<$field_ty>),*> $zalsa::interned::Lookup> + for StructKey<$db_lt, $($indexed_ty),*> { #[allow(unused_unit)] fn into_owned(self) -> StructData<$db_lt> { @@ -141,6 +149,9 @@ macro_rules! setup_interned_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } unsafe impl $zalsa::Update for $Struct<'_> { @@ -155,14 +166,17 @@ macro_rules! 
setup_interned_struct { } impl<$db_lt> $Struct<$db_lt> { - pub fn $new_fn<$Db>(db: &$db_lt $Db, $($field_id: impl $zalsa::interned::Lookup<$field_ty>),*) -> Self + pub fn $new_fn<$Db, $($indexed_ty: $zalsa::interned::Lookup<$field_ty> + std::hash::Hash,)*>(db: &$db_lt $Db, $($field_id: $indexed_ty),*) -> Self where // FIXME(rust-lang/rust#65991): The `db` argument *should* have the type `dyn Database` $Db: ?Sized + salsa::Database, + $( + $field_ty: $zalsa::interned::HashEqLike<$indexed_ty>, + )* { let current_revision = $zalsa::current_revision(db); $Configuration::ingredient(db).intern(db.as_dyn_database(), - StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default())) + StructKey::<$db_lt>($($field_id,)* std::marker::PhantomData::default()), |_, data| ($($zalsa::interned::Lookup::into_owned(data.$field_index),)*)) } $( diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 72a38039c..6bb1a698a 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -102,6 +102,9 @@ macro_rules! setup_tracked_fn { $zalsa::IngredientCache::new(); impl $zalsa::SalsaStructInDb for $InternedData<'_> { + fn lookup_ingredient_index(_aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + None + } } impl $zalsa::interned::Configuration for $Configuration { @@ -207,7 +210,19 @@ macro_rules! setup_tracked_fn { aux: &dyn $zalsa::JarAux, first_index: $zalsa::IngredientIndex, ) -> Vec> { + let struct_index = $zalsa::macro_if! { + if $needs_interner { + first_index.successor(0) + } else { + <$InternedData as $zalsa::SalsaStructInDb>::lookup_ingredient_index(aux) + .expect( + "Salsa struct is passed as an argument of a tracked function, but its ingredient hasn't been added!" 
+ ) + } + }; + let fn_ingredient = <$zalsa::function::IngredientImpl<$Configuration>>::new( + struct_index, first_index, aux, ); @@ -227,6 +242,10 @@ macro_rules! setup_tracked_fn { } } } + + fn salsa_struct_type_id(&self) -> Option { + None + } } #[allow(non_local_definitions)] @@ -238,7 +257,7 @@ macro_rules! setup_tracked_fn { use salsa::plumbing as $zalsa; let key = $zalsa::macro_if! { if $needs_interner { - $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)) + $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data) } else { $zalsa::AsId::as_id(&($($input_id),*)) } @@ -274,7 +293,7 @@ macro_rules! setup_tracked_fn { let result = $zalsa::macro_if! { if $needs_interner { { - let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*)); + let key = $Configuration::intern_ingredient($db).intern_id($db.as_dyn_database(), ($($input_id),*), |_, data| data); $Configuration::fn_ingredient($db).fetch($db, key) } } else { diff --git a/components/salsa-macro-rules/src/setup_tracked_struct.rs b/components/salsa-macro-rules/src/setup_tracked_struct.rs index d0d42c6da..a783e3762 100644 --- a/components/salsa-macro-rules/src/setup_tracked_struct.rs +++ b/components/salsa-macro-rules/src/setup_tracked_struct.rs @@ -152,6 +152,9 @@ macro_rules! 
setup_tracked_struct { } impl $zalsa::SalsaStructInDb for $Struct<'_> { + fn lookup_ingredient_index(aux: &dyn $zalsa::JarAux) -> core::option::Option<$zalsa::IngredientIndex> { + aux.lookup_jar_by_type(&<$zalsa_struct::JarImpl<$Configuration>>::default()) + } } impl $zalsa::TrackedStructInDb for $Struct<'_> { diff --git a/src/accumulator.rs b/src/accumulator.rs index 5fd9f6cc6..c6072341d 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -54,6 +54,10 @@ impl Jar for JarImpl { ) -> Vec> { vec![Box::new(>::new(first_index))] } + + fn salsa_struct_type_id(&self) -> Option { + None + } } pub struct IngredientImpl { @@ -62,7 +66,7 @@ pub struct IngredientImpl { } impl IngredientImpl { - /// Find the accumulator ingrediate for `A` in the database, if any. + /// Find the accumulator ingredient for `A` in the database, if any. pub fn from_db(db: &Db) -> Option<&Self> where Db: ?Sized + Database, @@ -101,7 +105,7 @@ impl Ingredient for IngredientImpl { fn maybe_changed_after( &self, _db: &dyn Database, - _input: Option, + _input: Id, _revision: Revision, ) -> VerifyResult { panic!("nothing should ever depend on an accumulator directly") @@ -123,7 +127,7 @@ impl Ingredient for IngredientImpl { &self, _db: &dyn Database, _executor: DatabaseKeyIndex, - _output_key: Option, + _output_key: crate::Id, ) { } @@ -131,7 +135,7 @@ impl Ingredient for IngredientImpl { &self, _db: &dyn Database, _executor: DatabaseKeyIndex, - _stale_output_key: Option, + _stale_output_key: crate::Id, ) { } diff --git a/src/active_query.rs b/src/active_query.rs index 1178620d0..8dcf7bd1f 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -1,14 +1,15 @@ use rustc_hash::{FxHashMap, FxHashSet}; -use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; +use super::zalsa_local::{QueryEdges, QueryOrigin, QueryRevisions}; +use crate::key::OutputDependencyIndex; use crate::tracked_struct::IdentityHash; +use crate::zalsa_local::QueryEdge; use crate::{ 
accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, durability::Durability, hash::FxIndexSet, - key::{DatabaseKeyIndex, DependencyIndex}, + key::{DatabaseKeyIndex, InputDependencyIndex}, tracked_struct::{Disambiguator, Identity}, - zalsa_local::EMPTY_DEPENDENCIES, Id, Revision, }; @@ -30,7 +31,7 @@ pub(crate) struct ActiveQuery { /// * tracked structs created /// * invocations of `specify` /// * accumulators pushed to - input_outputs: FxIndexSet<(EdgeKind, DependencyIndex)>, + input_outputs: FxIndexSet, /// True if there was an untracked read. untracked_read: bool, @@ -73,13 +74,13 @@ impl ActiveQuery { pub(super) fn add_read( &mut self, - input: DependencyIndex, + input: InputDependencyIndex, durability: Durability, revision: Revision, accumulated: InputAccumulatedValues, cycle_heads: Option<&FxHashSet>, ) { - self.input_outputs.insert((EdgeKind::Input, input)); + self.input_outputs.insert(QueryEdge::Input(input)); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(revision); self.accumulated.add_input(accumulated); @@ -101,24 +102,17 @@ impl ActiveQuery { } /// Adds a key to our list of outputs. - pub(super) fn add_output(&mut self, key: DependencyIndex) { - self.input_outputs.insert((EdgeKind::Output, key)); + pub(super) fn add_output(&mut self, key: OutputDependencyIndex) { + self.input_outputs.insert(QueryEdge::Output(key)); } /// True if the given key was output by this query. 
- pub(super) fn is_output(&self, key: DependencyIndex) -> bool { - self.input_outputs.contains(&(EdgeKind::Output, key)) + pub(super) fn is_output(&self, key: OutputDependencyIndex) -> bool { + self.input_outputs.contains(&QueryEdge::Output(key)) } pub(crate) fn into_revisions(self) -> QueryRevisions { - let input_outputs = if self.input_outputs.is_empty() { - EMPTY_DEPENDENCIES.clone() - } else { - self.input_outputs.into_iter().collect() - }; - - let edges = QueryEdges::new(input_outputs); - + let edges = QueryEdges::new(self.input_outputs); let origin = if self.untracked_read { QueryOrigin::DerivedUntracked(edges) } else { diff --git a/src/database.rs b/src/database.rs index a978df0e9..0c6867f98 100644 --- a/src/database.rs +++ b/src/database.rs @@ -54,6 +54,24 @@ pub trait Database: Send + AsDynDatabase + Any + ZalsaDatabase { ) } + /// Starts unwinding the stack if the current revision is cancelled. + /// + /// This method can be called by query implementations that perform + /// potentially expensive computations, in order to speed up propagation of + /// cancellation. + /// + /// Cancellation will automatically be triggered by salsa on any query + /// invocation. + /// + /// This method should not be overridden by `Database` implementors. A + /// `salsa_event` is emitted when this method is called, so that should be + /// used instead. + fn unwind_if_revision_cancelled(&self) { + let db = self.as_dyn_database(); + let zalsa_local = db.zalsa_local(); + zalsa_local.unwind_if_revision_cancelled(db); + } + /// Execute `op` with the database in thread-local storage for debug print-outs. 
fn attach(&self, op: impl FnOnce(&Self) -> R) -> R where diff --git a/src/event.rs b/src/event.rs index 8eb1ee605..ac413647e 100644 --- a/src/event.rs +++ b/src/event.rs @@ -1,6 +1,9 @@ use std::thread::ThreadId; -use crate::{key::DatabaseKeyIndex, key::DependencyIndex}; +use crate::{ + key::DatabaseKeyIndex, + key::{InputDependencyIndex, OutputDependencyIndex}, +}; /// The `Event` struct identifies various notable things that can /// occur during salsa execution. Instances of this struct are given @@ -14,6 +17,15 @@ pub struct Event { pub kind: EventKind, } +impl Event { + pub fn new(kind: EventKind) -> Self { + Self { + thread_id: std::thread::current().id(), + kind, + } + } +} + /// An enum identifying the various kinds of events that can occur. #[derive(Debug)] pub enum EventKind { @@ -64,7 +76,7 @@ pub enum EventKind { execute_key: DatabaseKeyIndex, /// Key for the query that is no longer output - output_key: DependencyIndex, + output_key: OutputDependencyIndex, }, /// Tracked structs or memoized data were discarded (freed). 
@@ -79,6 +91,6 @@ pub enum EventKind { executor_key: DatabaseKeyIndex, /// Accumulator that was accumulated into - accumulator: DependencyIndex, + accumulator: InputDependencyIndex, }, } diff --git a/src/function.rs b/src/function.rs index 013024ce7..11ef3779f 100644 --- a/src/function.rs +++ b/src/function.rs @@ -128,10 +128,10 @@ impl IngredientImpl where C: Configuration, { - pub fn new(index: IngredientIndex, aux: &dyn JarAux) -> Self { + pub fn new(struct_index: IngredientIndex, index: IngredientIndex, aux: &dyn JarAux) -> Self { Self { index, - memo_ingredient_index: aux.next_memo_ingredient_index(index), + memo_ingredient_index: aux.next_memo_ingredient_index(struct_index, index), lru: Default::default(), deleted_entries: Default::default(), } @@ -194,12 +194,11 @@ where fn maybe_changed_after( &self, db: &dyn Database, - input: Option, + input: Id, revision: Revision, ) -> VerifyResult { - let key = input.unwrap(); let db = db.as_view::(); - self.maybe_changed_after(db, key, revision) + self.maybe_changed_after(db, input, revision) } fn is_verified_final<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { @@ -219,9 +218,8 @@ where &self, db: &dyn Database, executor: DatabaseKeyIndex, - output_key: Option, + output_key: crate::Id, ) { - let output_key = output_key.unwrap(); self.validate_specified_value(db, executor, output_key); } @@ -229,7 +227,7 @@ where &self, _db: &dyn Database, _executor: DatabaseKeyIndex, - _stale_output_key: Option, + _stale_output_key: crate::Id, ) { // This function is invoked when a query Q specifies the value for `stale_output_key` in rev 1, // but not in rev 2. We don't do anything in this case, we just leave the (now stale) memo. 
diff --git a/src/function/accumulated.rs b/src/function/accumulated.rs index d8b0fcfc9..588e42d1e 100644 --- a/src/function/accumulated.rs +++ b/src/function/accumulated.rs @@ -53,8 +53,9 @@ where continue; } + let ingredient = zalsa.lookup_ingredient(k.ingredient_index); // Extend `output` with any values accumulated by `k`. - if let Some(accumulated_map) = k.accumulated(db) { + if let Some(accumulated_map) = ingredient.accumulated(db, k.key_index) { accumulated_map.extend_with_accumulated(accumulator.index(), &mut output); // Skip over the inputs because we know that the entire sub-graph has no accumulated values @@ -69,10 +70,7 @@ where // output vector, we want to push in execution order, so reverse order to // ensure the first child that was executed will be the first child popped // from the stack. - let Some(origin) = zalsa - .lookup_ingredient(k.ingredient_index) - .origin(db, k.key_index) - else { + let Some(origin) = ingredient.origin(db, k.key_index) else { continue; }; diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index d89f0d52f..a727d7720 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -1,6 +1,6 @@ use super::{memo::Memo, Configuration, IngredientImpl}; use crate::{ - hash::FxHashSet, key::DependencyIndex, zalsa_local::QueryRevisions, AsDynDatabase as _, + hash::FxHashSet, key::OutputDependencyIndex, zalsa_local::QueryRevisions, AsDynDatabase as _, DatabaseKeyIndex, Event, EventKind, }; @@ -32,11 +32,8 @@ where if !old_outputs.is_empty() { // Remove the outputs that are no longer present in the current revision // to prevent that the next revision is seeded with a id mapping that no longer exists. 
- revisions.tracked_struct_ids.retain(|k, value| { - !old_outputs.contains(&DependencyIndex { - ingredient_index: k.ingredient_index(), - key_index: Some(*value), - }) + revisions.tracked_struct_ids.retain(|&k, &mut value| { + !old_outputs.contains(&OutputDependencyIndex::new(k.ingredient_index(), value)) }); } @@ -45,15 +42,14 @@ where } } - fn report_stale_output(db: &C::DbView, key: DatabaseKeyIndex, output: DependencyIndex) { + fn report_stale_output(db: &C::DbView, key: DatabaseKeyIndex, output: OutputDependencyIndex) { let db = db.as_dyn_database(); - db.salsa_event(&|| Event { - thread_id: std::thread::current().id(), - kind: EventKind::WillDiscardStaleOutput { + db.salsa_event(&|| { + Event::new(EventKind::WillDiscardStaleOutput { execute_key: key, output_key: output, - }, + }) }); output.remove_stale_output(db, key); diff --git a/src/function/execute.rs b/src/function/execute.rs index adbd74318..37fa2fcdd 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -31,11 +31,10 @@ where tracing::info!("{:?}: executing query", database_key_index); - db.salsa_event(&|| Event { - thread_id: std::thread::current().id(), - kind: EventKind::WillExecute { + db.salsa_event(&|| { + Event::new(EventKind::WillExecute { database_key: database_key_index, - }, + }) }); let mut iteration_count: u32 = 0; diff --git a/src/function/fetch.rs b/src/function/fetch.rs index bc4295a65..5b322051c 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -56,10 +56,9 @@ where if memo.value.is_some() && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo, false) { - // Unsafety invariant: memo is present in memo_map - unsafe { - return Some(self.extend_memo_lifetime(memo)); - } + // Unsafety invariant: memo is present in memo_map and we have verified that it is + // still valid for the current revision. 
+ return unsafe { Some(self.extend_memo_lifetime(memo)) }; } } None @@ -137,10 +136,9 @@ where self.deep_verify_memo(db, old_memo, &active_query) { if cycle_heads.is_empty() { - // Unsafety invariant: memo is present in memo_map. - unsafe { - return Some(self.extend_memo_lifetime(old_memo)); - } + // Unsafety invariant: memo is present in memo_map and we have verified that it is + // still valid for the current revision. + return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; } } } diff --git a/src/function/input_outputs.rs b/src/function/input_outputs.rs deleted file mode 100644 index dab66fdf0..000000000 --- a/src/function/input_outputs.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::{ - accumulator::accumulated_map::AccumulatedMap, zalsa::Zalsa, zalsa_local::QueryOrigin, Id, -}; - -use super::{Configuration, IngredientImpl}; - -impl IngredientImpl -where - C: Configuration, -{ - pub(super) fn origin(&self, zalsa: &Zalsa, key: Id) -> Option { - self.get_memo_from_table_for(zalsa, key) - .map(|m| m.revisions.origin.clone()) - } - - pub(super) fn accumulated(&self, zalsa: &Zalsa, key: Id) -> Option<&AccumulatedMap> { - // NEXT STEP: stash and refactor `fetch` to return an `&Memo` so we can make this work - self.get_memo_from_table_for(zalsa, key) - .map(|m| &m.revisions.accumulated) - } -} diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index d26a25819..c93492eb1 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -3,7 +3,7 @@ use crate::{ key::DatabaseKeyIndex, table::sync::ClaimResult, zalsa::{Zalsa, ZalsaDatabase}, - zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin}, + zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin}, AsDynDatabase as _, Id, Revision, }; use rustc_hash::FxHashSet; @@ -266,9 +266,9 @@ where // valid, then some later input I1 might never have executed at all, so verifying // it is still up to date is meaningless. 
let last_verified_at = old_memo.verified_at.load(); - for &(edge_kind, dependency_index) in edges.input_outputs.iter() { - match edge_kind { - EdgeKind::Input => { + for &edge in edges.input_outputs.iter() { + match edge { + QueryEdge::Input(dependency_index) => { match dependency_index .maybe_changed_after(db.as_dyn_database(), last_verified_at) { @@ -276,7 +276,7 @@ where VerifyResult::Unchanged(cycles) => cycle_heads.extend(cycles), } } - EdgeKind::Output => { + QueryEdge::Output(dependency_index) => { // Subtle: Mark outputs as validated now, even though we may // later find an input that requires us to re-execute the function. // Even if it re-execute, the function will wind up writing the same value, diff --git a/src/function/memo.rs b/src/function/memo.rs index b9c2b3ef5..b4bfba901 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -62,31 +62,47 @@ impl IngredientImpl { /// with an equivalent memo that has no value. If the memo is untracked, BaseInput, /// or has values assigned as output of another query, this has no effect. pub(super) fn evict_value_from_memo_for<'db>(&'db self, zalsa: &'db Zalsa, id: Id) { - let Some(memo) = self.get_memo_from_table_for(zalsa, id) else { - return; - }; - - match memo.revisions.origin { - QueryOrigin::Assigned(_) - | QueryOrigin::DerivedUntracked(_) - | QueryOrigin::BaseInput - | QueryOrigin::FixpointInitial => { - // Careful: Cannot evict memos whose values were - // assigned as output of another query - // or those with untracked inputs - // as their values cannot be reconstructed. 
- } - - QueryOrigin::Derived(_) => { - let memo_evicted = Arc::new(Memo::new( - None::>, - memo.verified_at.load(), - memo.revisions.clone(), - )); - - self.insert_memo_into_table_for(zalsa, id, memo_evicted); - } - } + zalsa.memo_table_for(id).map_memo::>>( + self.memo_ingredient_index, + |memo| { + match memo.revisions.origin { + QueryOrigin::Assigned(_) + | QueryOrigin::DerivedUntracked(_) + | QueryOrigin::BaseInput + | QueryOrigin::FixpointInitial => { + // Careful: Cannot evict memos whose values were + // assigned as output of another query + // or those with untracked inputs + // as their values cannot be reconstructed. + memo + } + QueryOrigin::Derived(_) => { + // QueryRevisions: !Clone to discourage cloning, we need it here though + let QueryRevisions { + changed_at, + durability, + origin, + tracked_struct_ids, + accumulated, + cycle_heads, + } = &memo.revisions; + // Re-assemble the memo but with the value set to `None` + Arc::new(Memo::new( + None, + memo.verified_at.load(), + QueryRevisions { + changed_at: *changed_at, + durability: *durability, + origin: origin.clone(), + tracked_struct_ids: tracked_struct_ids.clone(), + accumulated: accumulated.clone(), + cycle_heads: cycle_heads.clone(), + }, + )) + } + } + }, + ); } pub(super) fn initial_value<'db>( @@ -162,11 +178,10 @@ impl Memo { revision_now: Revision, database_key_index: DatabaseKeyIndex, ) { - db.salsa_event(&|| Event { - thread_id: std::thread::current().id(), - kind: EventKind::DidValidateMemoizedValue { + db.salsa_event(&|| { + Event::new(EventKind::DidValidateMemoizedValue { database_key: database_key_index, - }, + }) }); self.verified_at.store(revision_now); diff --git a/src/ingredient.rs b/src/ingredient.rs index d705df8b8..bf122282e 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -24,10 +24,33 @@ pub trait Jar: Any { aux: &dyn JarAux, first_index: IngredientIndex, ) -> Vec>; + + /// If this jar's first ingredient is a salsa struct, return its `TypeId` + fn 
salsa_struct_type_id(&self) -> Option; } +/// Methods on the Salsa database available to jars while they are creating their ingredients. pub trait JarAux { - fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex; + /// Return index of first ingredient from `jar` (based on the dynamic type of `jar`). + /// Returns `None` if the jar has not yet been added. + /// Used by tracked functions to lookup the ingredient index for the salsa struct they take as argument. + fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option; + + /// Returns the memo ingredient index that should be used to attach data from the given tracked function + /// to the given salsa struct (which the fn accepts as argument). + /// + /// The memo ingredient indices for a given function must be distinct from the memo indices + /// of all other functions that take the same salsa struct. + /// + /// # Parameters + /// + /// * `struct_ingredient_index`, the index of the salsa struct the memo will be attached to + /// * `ingredient_index`, the index of the tracked function whose data is stored in the memo + fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex; } pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { @@ -37,7 +60,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { fn maybe_changed_after<'db>( &'db self, db: &'db dyn Database, - input: Option, + input: Id, revision: Revision, ) -> VerifyResult; @@ -64,7 +87,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { &'db self, db: &'db dyn Database, executor: DatabaseKeyIndex, - output_key: Option, + output_key: crate::Id, ); /// Invoked when the value `stale_output` was output by `executor` in a previous @@ -75,7 +98,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { &self, db: &dyn Database, executor: DatabaseKeyIndex, - stale_output_key: Option, + stale_output_key: Id, 
); /// Returns the [`IngredientIndex`] of this ingredient. diff --git a/src/input.rs b/src/input.rs index bf6d5e32e..9f5751aa9 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,4 +1,8 @@ -use std::{any::Any, fmt, ops::DerefMut}; +use std::{ + any::{Any, TypeId}, + fmt, + ops::DerefMut, +}; pub mod input_field; pub mod setter; @@ -13,7 +17,7 @@ use crate::{ function::VerifyResult, id::{AsId, FromId}, ingredient::{fmt_index, Ingredient}, - key::{DatabaseKeyIndex, DependencyIndex}, + key::{DatabaseKeyIndex, InputDependencyIndex}, plumbing::{Jar, JarAux, Stamp}, table::{memo::MemoTable, sync::SyncTable, Slot, Table}, zalsa::{IngredientIndex, Zalsa}, @@ -62,6 +66,10 @@ impl Jar for JarImpl { })) .collect() } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct>()) + } } pub struct IngredientImpl { @@ -110,7 +118,7 @@ impl IngredientImpl { None }; - let id = zalsa_local.allocate(zalsa.table(), self.ingredient_index, || Value:: { + let id = zalsa_local.allocate(zalsa.table(), self.ingredient_index, |_| Value:: { fields, stamps, memos: Default::default(), @@ -184,10 +192,7 @@ impl IngredientImpl { let value = Self::data(zalsa, id); let stamp = &value.stamps[field_index]; zalsa_local.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(id), - }, + InputDependencyIndex::new(field_ingredient_index, id), stamp.durability, stamp.changed_at, InputAccumulatedValues::Empty, @@ -214,7 +219,7 @@ impl Ingredient for IngredientImpl { fn maybe_changed_after( &self, _db: &dyn Database, - _input: Option, + _input: Id, _revision: Revision, ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. 
@@ -238,7 +243,7 @@ impl Ingredient for IngredientImpl { &self, _db: &dyn Database, executor: DatabaseKeyIndex, - output_key: Option, + output_key: Id, ) { unreachable!( "mark_validated_output({:?}, {:?}): input cannot be the output of a tracked function", @@ -250,7 +255,7 @@ impl Ingredient for IngredientImpl { &self, _db: &dyn Database, executor: DatabaseKeyIndex, - stale_output_key: Option, + stale_output_key: Id, ) { unreachable!( "remove_stale_output({:?}, {:?}): input cannot be the output of a tracked function", diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 374b2f89e..461f7c3e1 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -53,11 +53,10 @@ where fn maybe_changed_after( &self, db: &dyn Database, - input: Option, + input: Id, revision: Revision, ) -> VerifyResult { let zalsa = db.zalsa(); - let input = input.unwrap(); let value = >::data(zalsa, input); VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) } @@ -74,7 +73,7 @@ where &self, _db: &dyn Database, _executor: DatabaseKeyIndex, - _output_key: Option, + _output_key: Id, ) { } @@ -82,7 +81,7 @@ where &self, _db: &dyn Database, _executor: DatabaseKeyIndex, - _stale_output_key: Option, + _stale_output_key: Id, ) { } diff --git a/src/interned.rs b/src/interned.rs index 3a7cc240f..178b2e4fa 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -1,9 +1,10 @@ +use dashmap::SharedValue; + use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::durability::Durability; use crate::function::VerifyResult; -use crate::id::AsId; use crate::ingredient::fmt_index; -use crate::key::DependencyIndex; +use crate::key::InputDependencyIndex; use crate::plumbing::{Jar, JarAux}; use crate::table::memo::MemoTable; use crate::table::sync::SyncTable; @@ -11,6 +12,7 @@ use crate::table::Slot; use crate::zalsa::IngredientIndex; use crate::zalsa_local::QueryOrigin; use crate::{Database, DatabaseKeyIndex, Id}; +use std::any::TypeId; use 
std::fmt; use std::hash::{BuildHasher, Hash, Hasher}; use std::marker::PhantomData; @@ -94,6 +96,10 @@ impl Jar for JarImpl { ) -> Vec> { vec![Box::new(IngredientImpl::::new(first_index)) as _] } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct<'static>>()) + } } impl IngredientImpl @@ -116,74 +122,107 @@ where unsafe { std::mem::transmute(data) } } - pub fn intern_id<'db>( + /// Intern data to a unique reference. + /// + /// If `key` is already interned, returns the existing [`Id`] for the interned data without + /// invoking `assemble`. + /// Otherwise, invokes `assemble` with the given `key` and the [`Id`] to be allocated for this + /// interned value. The resulting [`C::Data`] will then be interned. + /// + /// Note: Using the database within the `assemble` function may result in a deadlock if + /// the database ends up trying to intern or allocate a new value. + pub fn intern<'db, Key>( &'db self, db: &'db dyn crate::Database, - data: impl Lookup>, - ) -> crate::Id { - C::deref_struct(self.intern(db, data)).as_id() + key: Key, + assemble: impl FnOnce(Id, Key) -> C::Data<'db>, + ) -> C::Struct<'db> + where + Key: Hash, + C::Data<'db>: HashEqLike, + { + C::struct_from_id(self.intern_id(db, key, assemble)) } /// Intern data to a unique reference. - pub fn intern<'db>( + /// + /// If `key` is already interned, returns the existing [`Id`] for the interned data without + /// invoking `assemble`. + /// Otherwise, invokes `assemble` with the given `key` and the [`Id`] to be allocated for this + /// interned value. The resulting [`C::Data`] will then be interned. + /// + /// Note: Using the database within the `assemble` function may result in a deadlock if + /// the database ends up trying to intern or allocate a new value. 
+ pub fn intern_id<'db, Key>( &'db self, db: &'db dyn crate::Database, - data: impl Lookup>, - ) -> C::Struct<'db> { + key: Key, + assemble: impl FnOnce(Id, Key) -> C::Data<'db>, + ) -> crate::Id + where + Key: Hash, + // We'd want the following predicate, but this currently implies `'static` due to a rustc + // bug + // for<'db> C::Data<'db>: HashEqLike, + // so instead we go with this and transmute the lifetime in the `eq` closure + C::Data<'db>: HashEqLike, + { let zalsa_local = db.zalsa_local(); zalsa_local.report_tracked_read( - DependencyIndex::for_table(self.ingredient_index), + InputDependencyIndex::for_table(self.ingredient_index), Durability::MAX, self.reset_at, InputAccumulatedValues::Empty, None, ); - // Optimisation to only get read lock on the map if the data has already - // been interned. - // We need to use the raw API for this lookup. See the [`Lookup`][] trait definition for an explanation of why. - let data_hash = { - let mut hasher = self.key_map.hasher().build_hasher(); - data.hash(&mut hasher); - hasher.finish() + // Optimization to only get read lock on the map if the data has already been interned. 
+ let data_hash = self.key_map.hasher().hash_one(&key); + let shard = &self.key_map.shards()[self.key_map.determine_shard(data_hash as _)]; + let eq = |(data, _): &_| { + // SAFETY: it's safe to go from Data<'static> to Data<'db> + // shrink lifetime here to use a single lifetime in Lookup::eq(&StructKey<'db>, &C::Data<'db>) + let data: &C::Data<'db> = unsafe { std::mem::transmute(data) }; + HashEqLike::eq(data, &key) }; - let shard = self.key_map.determine_shard(data_hash as _); + { - let lock = self.key_map.shards()[shard].read(); - if let Some(bucket) = lock.find(data_hash, |(a, _)| { - // SAFETY: it's safe to go from Data<'static> to Data<'db> - // shrink lifetime here to use a single lifetime in Lookup::eq(&StructKey<'db>, &C::Data<'db>) - let a: &C::Data<'db> = unsafe { std::mem::transmute(a) }; - Lookup::eq(&data, a) - }) { + let lock = shard.read(); + if let Some(bucket) = lock.find(data_hash, eq) { // SAFETY: Read lock on map is held during this block - return C::struct_from_id(unsafe { *bucket.as_ref().1.get() }); + return unsafe { *bucket.as_ref().1.get() }; } - }; - - let data = data.into_owned(); - - let internal_data = unsafe { self.to_internal_data(data) }; + } - match self.key_map.entry(internal_data.clone()) { + let mut lock = shard.write(); + match lock.find_or_find_insert_slot(data_hash, eq, |(element, _)| { + self.key_map.hasher().hash_one(element) + }) { // Data has been interned by a racing call, use that ID instead - dashmap::mapref::entry::Entry::Occupied(entry) => { - let id = *entry.get(); - drop(entry); - C::struct_from_id(id) - } - + Ok(slot) => unsafe { *slot.as_ref().1.get() }, // We won any races so should intern the data - dashmap::mapref::entry::Entry::Vacant(entry) => { + Err(slot) => { let zalsa = db.zalsa(); let table = zalsa.table(); - let next_id = zalsa_local.allocate(table, self.ingredient_index, || Value:: { - data: internal_data, + let id = zalsa_local.allocate(table, self.ingredient_index, |id| Value:: { + data: unsafe { 
self.to_internal_data(assemble(id, key)) }, memos: Default::default(), syncs: Default::default(), }); - entry.insert(next_id); - C::struct_from_id(next_id) + unsafe { + lock.insert_in_slot( + data_hash, + slot, + (table.get::>(id).data.clone(), SharedValue::new(id)), + ) + }; + debug_assert_eq!( + data_hash, + self.key_map + .hasher() + .hash_one(table.get::>(id).data.clone()) + ); + id } } } @@ -220,7 +259,7 @@ where fn maybe_changed_after( &self, _db: &dyn Database, - _input: Option, + _input: Id, revision: Revision, ) -> VerifyResult { VerifyResult::changed_if(revision < self.reset_at) @@ -242,7 +281,7 @@ where &self, _db: &dyn Database, executor: DatabaseKeyIndex, - output_key: Option, + output_key: crate::Id, ) { unreachable!( "mark_validated_output({:?}, {:?}): input cannot be the output of a tracked function", @@ -254,7 +293,7 @@ where &self, _db: &dyn Database, executor: DatabaseKeyIndex, - stale_output_key: Option, + stale_output_key: crate::Id, ) { unreachable!( "remove_stale_output({:?}, {:?}): interned ids are not outputs", @@ -314,6 +353,12 @@ where } } +/// A trait for types that hash and compare like `O`. +pub trait HashEqLike { + fn hash(&self, h: &mut H); + fn eq(&self, data: &O) -> bool; +} + /// The `Lookup` trait is a more flexible variant on [`std::borrow::Borrow`] /// and [`std::borrow::ToOwned`]. /// @@ -329,12 +374,14 @@ where /// requires that `&(K1...)` be convertible to `&ViewStruct` which just isn't /// possible. `Lookup` instead offers direct `hash` and `eq` methods. 
pub trait Lookup { - fn hash(&self, h: &mut H); - fn eq(&self, data: &O) -> bool; fn into_owned(self) -> O; } - -impl Lookup for T +impl Lookup for T { + fn into_owned(self) -> T { + self + } +} +impl HashEqLike for T where T: Hash + Eq, { @@ -345,30 +392,18 @@ where fn eq(&self, data: &T) -> bool { self == data } - - fn into_owned(self) -> T { - self - } } impl Lookup for &T where - T: Clone + Eq + Hash, + T: Clone, { - fn hash(&self, h: &mut H) { - Hash::hash(self, &mut *h); - } - - fn eq(&self, data: &T) -> bool { - *self == data - } - fn into_owned(self) -> T { Clone::clone(self) } } -impl<'a, T> Lookup> for &'a T +impl<'a, T> HashEqLike<&'a T> for Box where T: ?Sized + Hash + Eq, Box: From<&'a T>, @@ -376,55 +411,61 @@ where fn hash(&self, h: &mut H) { Hash::hash(self, &mut *h) } - fn eq(&self, data: &Box) -> bool { + fn eq(&self, data: &&T) -> bool { **self == **data } +} + +impl<'a, T> Lookup> for &'a T +where + T: ?Sized + Hash + Eq, + Box: From<&'a T>, +{ fn into_owned(self) -> Box { Box::from(self) } } impl Lookup for &str { + fn into_owned(self) -> String { + self.to_owned() + } +} +impl HashEqLike<&str> for String { fn hash(&self, h: &mut H) { Hash::hash(self, &mut *h) } - fn eq(&self, data: &String) -> bool { - self == data - } - - fn into_owned(self) -> String { - self.to_owned() + fn eq(&self, data: &&str) -> bool { + self == *data } } -impl + Clone + Lookup, T> Lookup> for &[A] { +impl> HashEqLike<&[A]> for Vec { fn hash(&self, h: &mut H) { - for a in *self { - Hash::hash(a, h); - } + Hash::hash(self, h); } - fn eq(&self, data: &Vec) -> bool { + fn eq(&self, data: &&[A]) -> bool { self.len() == data.len() && data.iter().enumerate().all(|(i, a)| &self[i] == a) } - +} +impl + Clone + Lookup, T> Lookup> for &[A] { fn into_owned(self) -> Vec { self.iter().map(|a| Lookup::into_owned(a.clone())).collect() } } -impl + Clone + Lookup, T> Lookup> for [A; N] { +impl> HashEqLike<[A; N]> for Vec { fn hash(&self, h: &mut H) { - for a in self { - Hash::hash(a, 
h); - } + Hash::hash(self, h); } - fn eq(&self, data: &Vec) -> bool { + fn eq(&self, data: &[A; N]) -> bool { self.len() == data.len() && data.iter().enumerate().all(|(i, a)| &self[i] == a) } - +} +impl + Clone + Lookup, T> Lookup> for [A; N] { fn into_owned(self) -> Vec { self.into_iter() .map(|a| Lookup::into_owned(a.clone())) @@ -432,15 +473,16 @@ impl + Clone + Lookup, T> Lookup< } } -impl Lookup for &Path { +impl HashEqLike<&Path> for PathBuf { fn hash(&self, h: &mut H) { Hash::hash(self, h); } - fn eq(&self, data: &PathBuf) -> bool { + fn eq(&self, data: &&Path) -> bool { self == data } - +} +impl Lookup for &Path { fn into_owned(self) -> PathBuf { self.to_owned() } diff --git a/src/key.rs b/src/key.rs index 7d8d4d269..b15077329 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,39 +1,35 @@ -use crate::{ - accumulator::accumulated_map::AccumulatedMap, function::VerifyResult, zalsa::IngredientIndex, - Database, Id, -}; +use core::fmt; + +use crate::{function::VerifyResult, zalsa::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the -/// database. Used to track dependencies between queries. Fully ordered and +/// database. Used to track output dependencies between queries. Fully ordered and /// equatable but those orderings are arbitrary, and meant to be used only for /// inserting into maps and the like. -#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct DependencyIndex { - pub(crate) ingredient_index: IngredientIndex, - pub(crate) key_index: Option, +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct OutputDependencyIndex { + ingredient_index: IngredientIndex, + key_index: Id, } -impl DependencyIndex { - /// Create a database-key-index for an interning or entity table. - /// The `key_index` here is always zero, which deliberately corresponds to - /// no particular id or entry. This is because the data in such tables - /// remains valid until the table as a whole is reset. 
Using a single id avoids - /// creating tons of dependencies in the dependency listings. - pub(crate) fn for_table(ingredient_index: IngredientIndex) -> Self { +/// An integer that uniquely identifies a particular query instance within the +/// database. Used to track input dependencies between queries. Fully ordered and +/// equatable but those orderings are arbitrary, and meant to be used only for +/// inserting into maps and the like. +#[derive(Copy, Clone, PartialEq, Eq, Hash)] +pub struct InputDependencyIndex { + ingredient_index: IngredientIndex, + key_index: Option, +} + +impl OutputDependencyIndex { + pub(crate) fn new(ingredient_index: IngredientIndex, key_index: Id) -> Self { Self { ingredient_index, - key_index: None, + key_index, } } - pub fn ingredient_index(self) -> IngredientIndex { - self.ingredient_index - } - - pub fn key_index(self) -> Option { - self.key_index - } - pub(crate) fn remove_stale_output(&self, db: &dyn Database, executor: DatabaseKeyIndex) { db.zalsa() .lookup_ingredient(self.ingredient_index) @@ -49,26 +45,70 @@ impl DependencyIndex { .lookup_ingredient(self.ingredient_index) .mark_validated_output(db, database_key_index, self.key_index) } +} + +impl fmt::Debug for OutputDependencyIndex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + crate::attach::with_attached_database(|db| { + let ingredient = db.zalsa().lookup_ingredient(self.ingredient_index); + ingredient.fmt_index(Some(self.key_index), f) + }) + .unwrap_or_else(|| { + f.debug_tuple("OutputDependencyIndex") + .field(&self.ingredient_index) + .field(&self.key_index) + .finish() + }) + } +} + +impl InputDependencyIndex { + /// Create a database-key-index for an interning or entity table. + /// The `key_index` here is always `None`, which deliberately corresponds to + /// no particular id or entry. This is because the data in such tables + /// remains valid until the table as a whole is reset. 
Using a single id avoids + /// creating tons of dependencies in the dependency listings. + pub(crate) fn for_table(ingredient_index: IngredientIndex) -> Self { + Self { + ingredient_index, + key_index: None, + } + } + + pub(crate) fn new(ingredient_index: IngredientIndex, key_index: Id) -> Self { + Self { + ingredient_index, + key_index: Some(key_index), + } + } pub(crate) fn maybe_changed_after( &self, db: &dyn Database, last_verified_at: crate::Revision, ) -> VerifyResult { - db.zalsa() - .lookup_ingredient(self.ingredient_index) - .maybe_changed_after(db, self.key_index, last_verified_at) + match self.key_index { + Some(key_index) => db + .zalsa() + .lookup_ingredient(self.ingredient_index) + .maybe_changed_after(db, key_index, last_verified_at), + None => VerifyResult::unchanged(), + } + } + + pub fn set_key_index(&mut self, key_index: Id) { + self.key_index = Some(key_index); } } -impl std::fmt::Debug for DependencyIndex { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl fmt::Debug for InputDependencyIndex { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { crate::attach::with_attached_database(|db| { let ingredient = db.zalsa().lookup_ingredient(self.ingredient_index); ingredient.fmt_index(self.key_index, f) }) .unwrap_or_else(|| { - f.debug_tuple("DependencyIndex") + f.debug_tuple("InputDependencyIndex") .field(&self.ingredient_index) .field(&self.key_index) .finish() @@ -80,7 +120,7 @@ impl std::fmt::Debug for DependencyIndex { /// An "active" database key index represents a database key index /// that is actively executing. In that case, the `key_index` cannot be /// None. 
-#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash)] pub struct DatabaseKeyIndex { pub(crate) ingredient_index: IngredientIndex, pub(crate) key_index: Id, @@ -95,22 +135,24 @@ impl DatabaseKeyIndex { pub fn key_index(self) -> Id { self.key_index } - - pub(crate) fn accumulated(self, db: &dyn Database) -> Option<&AccumulatedMap> { - db.zalsa() - .lookup_ingredient(self.ingredient_index) - .accumulated(db, self.key_index) - } } impl std::fmt::Debug for DatabaseKeyIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let i: DependencyIndex = (*self).into(); - std::fmt::Debug::fmt(&i, f) + crate::attach::with_attached_database(|db| { + let ingredient = db.zalsa().lookup_ingredient(self.ingredient_index); + ingredient.fmt_index(Some(self.key_index), f) + }) + .unwrap_or_else(|| { + f.debug_tuple("DatabaseKeyIndex") + .field(&self.ingredient_index) + .field(&self.key_index) + .finish() + }) } } -impl From for DependencyIndex { +impl From for InputDependencyIndex { fn from(value: DatabaseKeyIndex) -> Self { Self { ingredient_index: value.ingredient_index, @@ -119,10 +161,19 @@ impl From for DependencyIndex { } } -impl TryFrom for DatabaseKeyIndex { +impl From for OutputDependencyIndex { + fn from(value: DatabaseKeyIndex) -> Self { + Self { + ingredient_index: value.ingredient_index, + key_index: value.key_index, + } + } +} + +impl TryFrom for DatabaseKeyIndex { type Error = (); - fn try_from(value: DependencyIndex) -> Result { + fn try_from(value: InputDependencyIndex) -> Result { let key_index = value.key_index.ok_or(())?; Ok(Self { ingredient_index: value.ingredient_index, diff --git a/src/lib.rs b/src/lib.rs index 9e9bcf6dd..63783b102 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -133,6 +133,7 @@ pub mod plumbing { pub mod interned { pub use crate::interned::Configuration; + pub use crate::interned::HashEqLike; pub use crate::interned::IngredientImpl; pub use crate::interned::JarImpl; pub use 
crate::interned::Lookup; diff --git a/src/runtime.rs b/src/runtime.rs index db6b5bf09..8b9b4b12f 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -179,12 +179,11 @@ impl Runtime { return BlockResult::Cycle; } - db.salsa_event(&|| Event { - thread_id, - kind: EventKind::WillBlockOn { + db.salsa_event(&|| { + Event::new(EventKind::WillBlockOn { other_thread_id: other_id, database_key, - }, + }) }); let result = local_state.with_query_stack(|stack| { diff --git a/src/salsa_struct.rs b/src/salsa_struct.rs index fcf7920a7..8674dc125 100644 --- a/src/salsa_struct.rs +++ b/src/salsa_struct.rs @@ -1 +1,5 @@ -pub trait SalsaStructInDb {} +use crate::{plumbing::JarAux, IngredientIndex}; + +pub trait SalsaStructInDb { + fn lookup_ingredient_index(aux: &dyn JarAux) -> Option; +} diff --git a/src/storage.rs b/src/storage.rs index 409862918..10341e0b8 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -23,12 +23,14 @@ pub unsafe trait HasStorage: Database + Clone + Sized { /// Concrete implementation of the [`Database`][] trait. /// Takes an optional type parameter `U` that allows you to thread your own data. pub struct Storage { - /// Reference to the database. This is always `Some` except during destruction. - zalsa_impl: Option>, + // Note: Drop order is important, zalsa_impl needs to drop before coordinate + /// Reference to the database. + zalsa_impl: Arc, + // Note: Drop order is important, coordinate needs to drop after zalsa_impl /// Coordination data for cancellation of other handles when `zalsa_mut` is called. /// This could be stored in Zalsa but it makes things marginally cleaner to keep it separate. 
- coordinate: Arc, + coordinate: CoordinateDrop, /// Per-thread state zalsa_local: zalsa_local::ZalsaLocal, @@ -46,11 +48,11 @@ struct Coordinate { impl Default for Storage { fn default() -> Self { Self { - zalsa_impl: Some(Arc::new(Zalsa::new::())), - coordinate: Arc::new(Coordinate { + zalsa_impl: Arc::new(Zalsa::new::()), + coordinate: CoordinateDrop(Arc::new(Coordinate { clones: Mutex::new(1), cvar: Default::default(), - }), + })), zalsa_local: ZalsaLocal::new(), phantom: PhantomData, } @@ -58,13 +60,6 @@ impl Default for Storage { } impl Storage { - /// Access the `Arc`. This should always be - /// possible as `zalsa_impl` only becomes - /// `None` once we are in the `Drop` impl. - fn zalsa_impl(&self) -> &Arc { - self.zalsa_impl.as_ref().unwrap() - } - // ANCHOR: cancel_other_workers /// Sets cancellation flag and blocks until all other workers with access /// to this storage have completed. @@ -72,14 +67,9 @@ impl Storage { /// This could deadlock if there is a single worker with two handles to the /// same database! 
fn cancel_others(&self, db: &Db) { - let zalsa = self.zalsa_impl(); - zalsa.set_cancellation_flag(); - - db.salsa_event(&|| Event { - thread_id: std::thread::current().id(), + self.zalsa_impl.set_cancellation_flag(); - kind: EventKind::DidSetCancellationFlag, - }); + db.salsa_event(&|| Event::new(EventKind::DidSetCancellationFlag)); let mut clones = self.coordinate.clones.lock(); while *clones != 1 { @@ -91,16 +81,15 @@ impl Storage { unsafe impl ZalsaDatabase for T { fn zalsa(&self) -> &Zalsa { - self.storage().zalsa_impl.as_ref().unwrap() + &self.storage().zalsa_impl } fn zalsa_mut(&mut self) -> &mut Zalsa { self.storage().cancel_others(self); - // The ref count on the `Arc` should now be 1 let storage = self.storage_mut(); - let arc_zalsa_mut = storage.zalsa_impl.as_mut().unwrap(); - let zalsa_mut = Arc::get_mut(arc_zalsa_mut).unwrap(); + // The ref count on the `Arc` should now be 1 + let zalsa_mut = Arc::get_mut(&mut storage.zalsa_impl).unwrap(); zalsa_mut.new_revision(); zalsa_mut } @@ -122,20 +111,26 @@ impl Clone for Storage { Self { zalsa_impl: self.zalsa_impl.clone(), - coordinate: Arc::clone(&self.coordinate), + coordinate: CoordinateDrop(Arc::clone(&self.coordinate)), zalsa_local: ZalsaLocal::new(), phantom: PhantomData, } } } -impl Drop for Storage { - fn drop(&mut self) { - // Drop the database handle *first* - self.zalsa_impl.take(); +struct CoordinateDrop(Arc); - // *Now* decrement the number of clones and notify once we have completed - *self.coordinate.clones.lock() -= 1; - self.coordinate.cvar.notify_all(); +impl std::ops::Deref for CoordinateDrop { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Drop for CoordinateDrop { + fn drop(&mut self) { + *self.0.clones.lock() -= 1; + self.0.cvar.notify_all(); } } diff --git a/src/table.rs b/src/table.rs index af6ace70b..233314116 100644 --- a/src/table.rs +++ b/src/table.rs @@ -228,7 +228,7 @@ impl Page { pub(crate) fn allocate(&self, page: PageIndex, value: V) -> 
Result where - V: FnOnce() -> T, + V: FnOnce(Id) -> T, { let guard = self.allocation_lock.lock(); let index = self.allocated.load(Ordering::Acquire); @@ -237,14 +237,15 @@ impl Page { } // Initialize entry `index` + let id = make_id(page, SlotIndex::new(index)); let data = &self.data[index]; - unsafe { (*data.get()).write(value()) }; + unsafe { (*data.get()).write(value(id)) }; // Update the length (this must be done after initialization!) self.allocated.store(index + 1, Ordering::Release); drop(guard); - Ok(make_id(page, SlotIndex::new(index))) + Ok(id) } } @@ -293,7 +294,7 @@ impl dyn TablePage { fn make_id(page: PageIndex, slot: SlotIndex) -> Id { let page = page.0 as u32; let slot = slot.0 as u32; - Id::from_u32(page << PAGE_LEN_BITS | slot) + Id::from_u32((page << PAGE_LEN_BITS) | slot) } fn split_id(id: Id) -> (PageIndex, SlotIndex) { diff --git a/src/table/memo.rs b/src/table/memo.rs index 73bd52bde..fe2fc9583 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -163,11 +163,40 @@ impl MemoTable { unsafe { Some(Self::from_dummy(arc_swap.load_full())) } } - pub(crate) fn into_memos( - mut self, - ) -> impl Iterator)> { - let memos = std::mem::take(self.memos.get_mut()); - memos + /// Calls `f` on the memo at `memo_ingredient_index` and replaces the memo with the result of `f`. + /// If the memo is not present, `f` is not called. + pub(crate) fn map_memo( + &self, + memo_ingredient_index: MemoIngredientIndex, + f: impl FnOnce(Arc) -> Arc, + ) { + // If the memo slot is already occupied, it must already have the + // right type info etc, and we only need the read-lock. 
+ let memos = self.memos.read(); + let Some(MemoEntry { + data: + Some(MemoEntryData { + type_id, + to_dyn_fn: _, + arc_swap, + }), + }) = memos.get(memo_ingredient_index.as_usize()) + else { + return; + }; + assert_eq!( + *type_id, + TypeId::of::(), + "inconsistent type-id for `{memo_ingredient_index:?}`" + ); + // SAFETY: type_id check asserted above + let memo = f(unsafe { Self::from_dummy(arc_swap.load_full()) }); + unsafe { Self::from_dummy::(arc_swap.swap(Self::to_dummy(memo))) }; + } + + pub(crate) fn into_memos(self) -> impl Iterator)> { + self.memos + .into_inner() .into_iter() .zip(0..) .filter_map(|(mut memo, index)| memo.data.take().map(|d| (d, index))) diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index e2b842b4d..329421c9d 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -1,4 +1,4 @@ -use std::{fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; +use std::{any::TypeId, fmt, hash::Hash, marker::PhantomData, ops::DerefMut}; use crossbeam::{atomic::AtomicCell, queue::SegQueue}; use tracked_field::FieldIngredientImpl; @@ -8,7 +8,7 @@ use crate::{ cycle::CycleRecoveryStrategy, function::VerifyResult, ingredient::{fmt_index, Ingredient, Jar, JarAux}, - key::{DatabaseKeyIndex, DependencyIndex}, + key::{DatabaseKeyIndex, InputDependencyIndex}, plumbing::ZalsaLocal, runtime::StampedValue, salsa_struct::SalsaStructInDb, @@ -114,6 +114,10 @@ impl Jar for JarImpl { })) .collect() } + + fn salsa_struct_type_id(&self) -> Option { + Some(TypeId::of::<::Struct<'static>>()) + } } pub trait TrackedStructInDb: SalsaStructInDb { @@ -317,7 +321,7 @@ where current_deps: &StampedValue<()>, fields: C::Fields<'db>, ) -> Id { - let value = || Value { + let value = |_| Value { updated_at: AtomicCell::new(Some(current_revision)), durability: current_deps.durability, fields: unsafe { self.to_static(fields) }, @@ -336,7 +340,7 @@ where // Overwrite the free-list entry. 
Use `*foo = ` because the entry // has been previously initialized and we want to free the old contents. unsafe { - *data_raw = value(); + *data_raw = value(id); } id @@ -466,11 +470,10 @@ where /// unspecified results (but not UB). See [`InternedIngredient::delete_index`] for more /// discussion and important considerations. pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id) { - db.salsa_event(&|| Event { - thread_id: std::thread::current().id(), - kind: crate::EventKind::DidDiscard { + db.salsa_event(&|| { + Event::new(crate::EventKind::DidDiscard { key: self.database_key_index(id), - }, + }) }); let zalsa = db.zalsa(); @@ -503,22 +506,18 @@ where // and the code that references the memo-table has a read-lock. let memo_table = unsafe { (*data).take_memo_table() }; for (memo_ingredient_index, memo) in memo_table.into_memos() { - let ingredient_index = zalsa.ingredient_index_for_memo(memo_ingredient_index); + let ingredient_index = + zalsa.ingredient_index_for_memo(self.ingredient_index, memo_ingredient_index); let executor = DatabaseKeyIndex { ingredient_index, key_index: id, }; - db.salsa_event(&|| Event { - thread_id: std::thread::current().id(), - kind: EventKind::DidDiscard { key: executor }, - }); + db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor })); for stale_output in memo.origin().outputs() { - zalsa - .lookup_ingredient(stale_output.ingredient_index) - .remove_stale_output(db, executor, stale_output.key_index); + stale_output.remove_stale_output(db, executor); } } @@ -557,10 +556,7 @@ where let field_changed_at = data.revisions[field_index]; zalsa_local.report_tracked_read( - DependencyIndex { - ingredient_index: field_ingredient_index, - key_index: Some(id), - }, + InputDependencyIndex::new(field_ingredient_index, id), data.durability, field_changed_at, InputAccumulatedValues::Empty, @@ -582,7 +578,7 @@ where fn maybe_changed_after( &self, _db: &dyn Database, - _input: Option, + _input: Id, _revision: Revision, ) -> 
VerifyResult { VerifyResult::unchanged() @@ -604,7 +600,7 @@ where &'db self, _db: &'db dyn Database, _executor: DatabaseKeyIndex, - _output_key: Option, + _output_key: crate::Id, ) { // we used to update `update_at` field but now we do it lazilly when data is accessed // @@ -615,13 +611,13 @@ where &self, db: &dyn Database, _executor: DatabaseKeyIndex, - stale_output_key: Option, + stale_output_key: crate::Id, ) { // This method is called when, in prior revisions, // `executor` creates a tracked struct `salsa_output_key`, // but it did not in the current revision. // In that case, we can delete `stale_output_key` and any data associated with it. - self.delete_entity(db.as_dyn_database(), stale_output_key.unwrap()); + self.delete_entity(db.as_dyn_database(), stale_output_key); } fn fmt_index(&self, index: Option, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index f030e3cf1..98ee5cab3 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -51,12 +51,11 @@ where fn maybe_changed_after<'db>( &'db self, db: &'db dyn Database, - input: Option, + input: Id, revision: crate::Revision, ) -> VerifyResult { let zalsa = db.zalsa(); - let id = input.unwrap(); - let data = >::data(zalsa.table(), id); + let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; VerifyResult::changed_if(field_changed_at > revision) } @@ -77,7 +76,7 @@ where &self, _db: &dyn Database, _executor: crate::DatabaseKeyIndex, - _output_key: Option, + _output_key: crate::Id, ) { panic!("tracked field ingredients have no outputs") } @@ -86,7 +85,7 @@ where &self, _db: &dyn Database, _executor: crate::DatabaseKeyIndex, - _stale_output_key: Option, + _stale_output_key: crate::Id, ) { panic!("tracked field ingredients have no outputs") } diff --git a/src/zalsa.rs b/src/zalsa.rs index 34b9d0213..0eb1a5bde 100644 --- a/src/zalsa.rs +++ 
b/src/zalsa.rs @@ -1,5 +1,5 @@ use append_only_vec::AppendOnlyVec; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use rustc_hash::FxHashMap; use std::any::{Any, TypeId}; use std::marker::PhantomData; @@ -113,8 +113,10 @@ pub struct Zalsa { nonce: Nonce, - /// Number of memo ingredient indices created by calls to [`next_memo_ingredient_index`](`Self::next_memo_ingredient_index`) - memo_ingredients: Mutex>, + /// Map from the [`IngredientIndex::as_usize`][] of a salsa struct to a list of + /// [ingredient-indices](`IngredientIndex`) for tracked functions that have this salsa struct + /// as input. + memo_ingredient_indices: RwLock>>, /// Map from the type-id of an `impl Jar` to the index of its first ingredient. /// This is using a `Mutex` (versus, say, a `FxDashMap`) @@ -146,7 +148,7 @@ impl Zalsa { ingredients_vec: AppendOnlyVec::new(), ingredients_requiring_reset: AppendOnlyVec::new(), runtime: Runtime::default(), - memo_ingredients: Default::default(), + memo_ingredient_indices: Default::default(), } } @@ -180,11 +182,20 @@ impl Zalsa { { let jar_type_id = jar.type_id(); let mut jar_map = self.jar_map.lock(); - *jar_map - .entry(jar_type_id) - .or_insert_with(|| { - let index = IngredientIndex::from(self.ingredients_vec.len()); - let ingredients = jar.create_ingredients(self, index); + let mut should_create = false; + // First record the index we will use into the map and then go and create the ingredients. + // Those ingredients may invoke methods on the `JarAux` trait that read from this map + // to lookup ingredient indices for already created jars. + // + // Note that we still hold the lock above so only one jar is being created at a time and hence + // ingredient indices cannot overlap. 
+ let index = *jar_map.entry(jar_type_id).or_insert_with(|| { + should_create = true; + IngredientIndex::from(self.ingredients_vec.len()) + }); + if should_create { + let aux = JarAuxImpl(self, &jar_map); + let ingredients = jar.create_ingredients(&aux, index); for ingredient in ingredients { let expected_index = ingredient.ingredient_index(); @@ -192,9 +203,7 @@ impl Zalsa { self.ingredients_requiring_reset.push(expected_index); } - let actual_index = self - .ingredients_vec - .push(ingredient); + let actual_index = self.ingredients_vec.push(ingredient); assert_eq!( expected_index.as_usize(), actual_index, @@ -203,10 +212,10 @@ impl Zalsa { expected_index, actual_index, ); - } - index - }) + } + + index } } @@ -284,15 +293,34 @@ impl Zalsa { pub(crate) fn ingredient_index_for_memo( &self, + struct_ingredient_index: IngredientIndex, memo_ingredient_index: MemoIngredientIndex, ) -> IngredientIndex { - self.memo_ingredients.lock()[memo_ingredient_index.as_usize()] + self.memo_ingredient_indices.read()[struct_ingredient_index.as_usize()] + [memo_ingredient_index.as_usize()] } } -impl JarAux for Zalsa { - fn next_memo_ingredient_index(&self, ingredient_index: IngredientIndex) -> MemoIngredientIndex { - let mut memo_ingredients = self.memo_ingredients.lock(); +struct JarAuxImpl<'a>(&'a Zalsa, &'a FxHashMap); + +impl JarAux for JarAuxImpl<'_> { + fn lookup_jar_by_type(&self, jar: &dyn Jar) -> Option { + self.1.get(&jar.type_id()).map(ToOwned::to_owned) + } + + fn next_memo_ingredient_index( + &self, + struct_ingredient_index: IngredientIndex, + ingredient_index: IngredientIndex, + ) -> MemoIngredientIndex { + let mut memo_ingredients = self.0.memo_ingredient_indices.write(); + let idx = struct_ingredient_index.as_usize(); + let memo_ingredients = if let Some(memo_ingredients) = memo_ingredients.get_mut(idx) { + memo_ingredients + } else { + memo_ingredients.resize_with(idx + 1, Vec::new); + &mut memo_ingredients[idx] + }; let mi = 
MemoIngredientIndex(u32::try_from(memo_ingredients.len()).unwrap()); memo_ingredients.push(ingredient_index); mi diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 4cb5c1eea..30087c5ae 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -4,8 +4,7 @@ use tracing::debug; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; use crate::active_query::ActiveQuery; use crate::durability::Durability; -use crate::key::DatabaseKeyIndex; -use crate::key::DependencyIndex; +use crate::key::{DatabaseKeyIndex, InputDependencyIndex, OutputDependencyIndex}; use crate::runtime::StampedValue; use crate::table::PageIndex; use crate::table::Slot; @@ -20,7 +19,6 @@ use crate::EventKind; use crate::Id; use crate::Revision; use std::cell::RefCell; -use std::sync::Arc; /// State that is specific to a single execution thread. /// @@ -58,7 +56,7 @@ impl ZalsaLocal { &self, table: &Table, ingredient: IngredientIndex, - mut value: impl FnOnce() -> T, + mut value: impl FnOnce(Id) -> T, ) -> Id { // Find the most recent page, pushing a page if needed let mut page = *self @@ -71,7 +69,7 @@ impl ZalsaLocal { // Try to allocate an entry on that page let page_ref = table.page::(page); match page_ref.allocate(page, value) { - // If succesful, return + // If successful, return Ok(id) => return id, // Otherwise, create a new page and try again @@ -100,10 +98,6 @@ impl ZalsaLocal { c(self.query_stack.borrow_mut().as_mut()) } - fn query_in_progress(&self) -> bool { - self.with_query_stack(|stack| !stack.is_empty()) - } - /// Returns the index of the active query along with its *current* durability/changed-at /// information. As the query continues to execute, naturally, that information may change. 
pub(crate) fn active_query(&self) -> Option<(DatabaseKeyIndex, StampedValue<()>)> { @@ -140,7 +134,7 @@ impl ZalsaLocal { } /// Add an output to the current query's list of dependencies - pub(crate) fn add_output(&self, entity: DependencyIndex) { + pub(crate) fn add_output(&self, entity: OutputDependencyIndex) { self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { top_query.add_output(entity) @@ -149,7 +143,7 @@ impl ZalsaLocal { } /// Check whether `entity` is an output of the currently active query (if any) - pub(crate) fn is_output_of_active_query(&self, entity: DependencyIndex) -> bool { + pub(crate) fn is_output_of_active_query(&self, entity: OutputDependencyIndex) -> bool { self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { top_query.is_output(entity) @@ -162,7 +156,7 @@ impl ZalsaLocal { /// Register that currently active query reads the given input pub(crate) fn report_tracked_read( &self, - input: DependencyIndex, + input: InputDependencyIndex, durability: Durability, changed_at: Revision, accumulated: InputAccumulatedValues, @@ -216,13 +210,10 @@ impl ZalsaLocal { /// * the disambiguator index #[track_caller] pub(crate) fn disambiguate(&self, key: IdentityHash) -> (StampedValue<()>, Disambiguator) { - assert!( - self.query_in_progress(), - "cannot create a tracked struct disambiguator outside of a tracked function" - ); - self.with_query_stack(|stack| { - let top_query = stack.last_mut().unwrap(); + let top_query = stack.last_mut().expect( + "cannot create a tracked struct disambiguator outside of a tracked function", + ); let disambiguator = top_query.disambiguate(key); ( StampedValue { @@ -237,25 +228,20 @@ impl ZalsaLocal { #[track_caller] pub(crate) fn tracked_struct_id(&self, identity: &Identity) -> Option { - debug_assert!( - self.query_in_progress(), - "cannot create a tracked struct disambiguator outside of a tracked function" - ); - self.with_query_stack(|stack| { - let top_query = 
stack.last().unwrap(); + let top_query = stack.last().expect( + "cannot create a tracked struct disambiguator outside of a tracked function", + ); top_query.tracked_struct_ids.get(identity).copied() }) } #[track_caller] pub(crate) fn store_tracked_struct_id(&self, identity: Identity, id: Id) { - debug_assert!( - self.query_in_progress(), - "cannot create a tracked struct disambiguator outside of a tracked function" - ); self.with_query_stack(|stack| { - let top_query = stack.last_mut().unwrap(); + let top_query = stack.last_mut().expect( + "cannot create a tracked struct disambiguator outside of a tracked function", + ); let old_id = top_query.tracked_struct_ids.insert(identity, id); assert!( old_id.is_none(), @@ -277,12 +263,7 @@ impl ZalsaLocal { /// `salsa_event` is emitted when this method is called, so that should be /// used instead. pub(crate) fn unwind_if_revision_cancelled(&self, db: &dyn Database) { - let thread_id = std::thread::current().id(); - db.salsa_event(&|| Event { - thread_id, - - kind: EventKind::WillCheckCancellation, - }); + db.salsa_event(&|| Event::new(EventKind::WillCheckCancellation)); let zalsa = db.zalsa(); if zalsa.load_cancellation_flag() { self.unwind_cancelled(zalsa.current_revision()); @@ -300,7 +281,8 @@ impl std::panic::RefUnwindSafe for ZalsaLocal {} /// Summarizes "all the inputs that a query used" /// and "all the outputs its wrote to" -#[derive(Debug, Clone)] +#[derive(Debug)] +// #[derive(Clone)] cloning this is expensive, so we don't derive pub(crate) struct QueryRevisions { /// The most revision in which some input changed. 
pub(crate) changed_at: Revision, @@ -411,7 +393,7 @@ pub enum QueryOrigin { impl QueryOrigin { /// Indices for queries *read* by this query - pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { + pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { @@ -422,7 +404,7 @@ impl QueryOrigin { } /// Indices for queries *written* by this query (if any) - pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { + pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { @@ -433,16 +415,6 @@ impl QueryOrigin { } } -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub enum EdgeKind { - Input, - Output, -} - -lazy_static::lazy_static! { - pub(crate) static ref EMPTY_DEPENDENCIES: Arc<[(EdgeKind, DependencyIndex)]> = Arc::new([]); -} - /// The edges between a memoized value and other queries in the dependency graph. /// These edges include both dependency edges /// e.g., when creating the memoized value for Q0 executed another function Q1) @@ -462,33 +434,42 @@ pub struct QueryEdges { /// Important: /// /// * The inputs must be in **execution order** for the red-green algorithm to work. - pub input_outputs: Arc<[(EdgeKind, DependencyIndex)]>, + // pub input_outputs: ThinBox<[DependencyEdge]>, once that is a thing + pub input_outputs: Box<[QueryEdge]>, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum QueryEdge { + Input(InputDependencyIndex), + Output(OutputDependencyIndex), } impl QueryEdges { /// Returns the (tracked) inputs that were executed in computing this memoized value. /// /// These will always be in execution order. 
- pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { - self.input_outputs - .iter() - .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Input) - .map(|(_, dependency_index)| *dependency_index) + pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { + self.input_outputs.iter().filter_map(|&edge| match edge { + QueryEdge::Input(dependency_index) => Some(dependency_index), + QueryEdge::Output(_) => None, + }) } /// Returns the (tracked) outputs that were executed in computing this memoized value. /// /// These will always be in execution order. - pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { - self.input_outputs - .iter() - .filter(|(edge_kind, _)| *edge_kind == EdgeKind::Output) - .map(|(_, dependency_index)| *dependency_index) + pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { + self.input_outputs.iter().filter_map(|&edge| match edge { + QueryEdge::Output(dependency_index) => Some(dependency_index), + QueryEdge::Input(_) => None, + }) } /// Creates a new `QueryEdges`; the values given for each field must meet struct invariants. 
- pub(crate) fn new(input_outputs: Arc<[(EdgeKind, DependencyIndex)]>) -> Self { - Self { input_outputs } + pub(crate) fn new(input_outputs: impl IntoIterator) -> Self { + Self { + input_outputs: input_outputs.into_iter().collect(), + } } } diff --git a/tests/compile-fail/get-set-on-private-input-field.rs b/tests/compile-fail/get-set-on-private-input-field.rs index 5ecec5836..345590b75 100644 --- a/tests/compile-fail/get-set-on-private-input-field.rs +++ b/tests/compile-fail/get-set-on-private-input-field.rs @@ -1,5 +1,3 @@ -use salsa::prelude::*; - mod a { #[salsa::input] pub struct MyInput { diff --git a/tests/compile-fail/get-set-on-private-input-field.stderr b/tests/compile-fail/get-set-on-private-input-field.stderr index 887ab00d9..40acd8c2d 100644 --- a/tests/compile-fail/get-set-on-private-input-field.stderr +++ b/tests/compile-fail/get-set-on-private-input-field.stderr @@ -1,25 +1,17 @@ error[E0624]: method `field` is private - --> tests/compile-fail/get-set-on-private-input-field.rs:14:11 + --> tests/compile-fail/get-set-on-private-input-field.rs:12:11 | -4 | #[salsa::input] +2 | #[salsa::input] | --------------- private method defined here ... -14 | input.field(&db); +12 | input.field(&db); | ^^^^^ private method error[E0624]: method `set_field` is private - --> tests/compile-fail/get-set-on-private-input-field.rs:15:11 + --> tests/compile-fail/get-set-on-private-input-field.rs:13:11 | -4 | #[salsa::input] +2 | #[salsa::input] | --------------- private method defined here ... 
-15 | input.set_field(&mut db).to(23); +13 | input.set_field(&mut db).to(23); | ^^^^^^^^^ private method - -warning: unused import: `salsa::prelude` - --> tests/compile-fail/get-set-on-private-input-field.rs:1:5 - | -1 | use salsa::prelude::*; - | ^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default diff --git a/tests/interned-structs.rs b/tests/interned-structs.rs index 70dbc2632..a8263ce3a 100644 --- a/tests/interned-structs.rs +++ b/tests/interned-structs.rs @@ -5,6 +5,11 @@ use expect_test::expect; use std::path::{Path, PathBuf}; use test_log::test; +#[salsa::interned] +struct InternedBoxed<'db> { + data: Box, +} + #[salsa::interned] struct InternedString<'db> { data: String, @@ -73,6 +78,16 @@ fn interning_returns_equal_keys_for_equal_data_multi_field() { assert_ne!(s1, new); } +#[test] +fn interning_boxed() { + let db = salsa::DatabaseImpl::new(); + + assert_eq!( + InternedBoxed::new(&db, "Hello"), + InternedBoxed::new(&db, Box::from("Hello")) + ); +} + #[test] fn interning_vec() { let db = salsa::DatabaseImpl::new(); diff --git a/tests/interned-structs_self_ref.rs b/tests/interned-structs_self_ref.rs new file mode 100644 index 000000000..269724b8a --- /dev/null +++ b/tests/interned-structs_self_ref.rs @@ -0,0 +1,182 @@ +//! Test that a `tracked` fn on a `salsa::input` +//! compiles and executes successfully. 
+ +use std::convert::identity; + +use test_log::test; + +#[test] +fn interning_returns_equal_keys_for_equal_data() { + let db = salsa::DatabaseImpl::new(); + let s1 = InternedString::new(&db, "Hello, ".to_string(), identity); + let s2 = InternedString::new(&db, "World, ".to_string(), |_| s1); + let s1_2 = InternedString::new(&db, "Hello, ", identity); + let s2_2 = InternedString::new(&db, "World, ", |_| s2); + assert_eq!(s1, s1_2); + assert_eq!(s2, s2_2); +} +// Recursive expansion of interned macro +// #[salsa::interned] +// struct InternedString<'db> { +// data: String, +// other: InternedString<'db>, +// } +// ====================================== + +#[derive(Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +struct InternedString<'db>( + salsa::Id, + std::marker::PhantomData<&'db salsa::plumbing::interned::Value>>, +); + +#[allow(warnings)] +const _: () = { + use salsa::plumbing as zalsa_; + use zalsa_::interned as zalsa_struct_; + type Configuration_ = InternedString<'static>; + #[derive(Clone)] + struct StructData<'db>(String, InternedString<'db>); + + impl<'db> Eq for StructData<'db> {} + impl<'db> PartialEq for StructData<'db> { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } + } + + impl<'db> std::hash::Hash for StructData<'db> { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } + } + + #[doc = r" Key to use during hash lookups. Each field is some type that implements `Lookup`"] + #[doc = r" for the owned type. 
This permits interning with an `&str` when a `String` is required and so forth."] + #[derive(Hash)] + struct StructKey<'db, T0>(T0, std::marker::PhantomData<&'db ()>); + + impl<'db, T0> zalsa_::interned::HashEqLike> for StructData<'db> + where + String: zalsa_::interned::HashEqLike, + { + fn hash(&self, h: &mut H) { + zalsa_::interned::HashEqLike::::hash(&self.0, &mut *h); + } + fn eq(&self, data: &StructKey<'db, T0>) -> bool { + (zalsa_::interned::HashEqLike::::eq(&self.0, &data.0) && true) + } + } + impl zalsa_struct_::Configuration for Configuration_ { + const DEBUG_NAME: &'static str = "InternedString"; + type Data<'a> = StructData<'a>; + type Struct<'a> = InternedString<'a>; + fn struct_from_id<'db>(id: salsa::Id) -> Self::Struct<'db> { + InternedString(id, std::marker::PhantomData) + } + fn deref_struct(s: Self::Struct<'_>) -> salsa::Id { + s.0 + } + } + impl Configuration_ { + pub fn ingredient(db: &Db) -> &zalsa_struct_::IngredientImpl + where + Db: ?Sized + zalsa_::Database, + { + static CACHE: zalsa_::IngredientCache> = + zalsa_::IngredientCache::new(); + CACHE.get_or_create(db.as_dyn_database(), || { + db.zalsa() + .add_or_lookup_jar_by_type(&>::default()) + }) + } + } + impl zalsa_::AsId for InternedString<'_> { + fn as_id(&self) -> salsa::Id { + self.0 + } + } + impl zalsa_::FromId for InternedString<'_> { + fn from_id(id: salsa::Id) -> Self { + Self(id, std::marker::PhantomData) + } + } + unsafe impl Send for InternedString<'_> {} + + unsafe impl Sync for InternedString<'_> {} + + impl std::fmt::Debug for InternedString<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Self::default_debug_fmt(*self, f) + } + } + impl zalsa_::SalsaStructInDb for InternedString<'_> { + fn lookup_ingredient_index( + aux: &dyn zalsa_::JarAux, + ) -> core::option::Option { + aux.lookup_jar_by_type(&>::default()) + } + } + + unsafe impl zalsa_::Update for InternedString<'_> { + unsafe fn maybe_update(old_pointer: *mut Self, new_value: Self) -> 
bool { + if unsafe { *old_pointer } != new_value { + unsafe { *old_pointer = new_value }; + true + } else { + false + } + } + } + impl<'db> InternedString<'db> { + pub fn new + std::hash::Hash>( + db: &'db Db_, + data: T0, + other: impl FnOnce(InternedString<'db>) -> InternedString<'db>, + ) -> Self + where + Db_: ?Sized + salsa::Database, + String: zalsa_::interned::HashEqLike, + { + let current_revision = zalsa_::current_revision(db); + Configuration_::ingredient(db).intern( + db.as_dyn_database(), + StructKey::<'db>(data, std::marker::PhantomData::default()), + |id, data| { + StructData( + zalsa_::interned::Lookup::into_owned(data.0), + other(zalsa_::FromId::from_id(id)), + ) + }, + ) + } + fn data(self, db: &'db Db_) -> String + where + Db_: ?Sized + zalsa_::Database, + { + let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + std::clone::Clone::clone((&fields.0)) + } + fn other(self, db: &'db Db_) -> InternedString<'db> + where + Db_: ?Sized + zalsa_::Database, + { + let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), self); + std::clone::Clone::clone((&fields.1)) + } + #[doc = r" Default debug formatting for this struct (may be useful if you define your own `Debug` impl)"] + pub fn default_debug_fmt(this: Self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + zalsa_::with_attached_database(|db| { + let fields = Configuration_::ingredient(db).fields(db.as_dyn_database(), this); + let mut f = f.debug_struct("InternedString"); + let f = f.field("data", &fields.0); + let f = f.field("other", &fields.1); + f.finish() + }) + .unwrap_or_else(|| { + f.debug_tuple("InternedString") + .field(&zalsa_::AsId::as_id(&this)) + .finish() + }) + } + } +}; diff --git a/tests/parallel/parallel_map.rs b/tests/parallel/parallel_map.rs index 80e4ebeaf..dee46fd8f 100644 --- a/tests/parallel/parallel_map.rs +++ b/tests/parallel/parallel_map.rs @@ -15,17 +15,18 @@ struct ParallelInput { fn tracked_fn(db: &dyn salsa::Database, 
input: ParallelInput) -> Vec { salsa::par_map(db, input.field(db), |_db, field| field + 1) } + #[salsa::tracked] fn a1(db: &dyn KnobsDatabase, input: ParallelInput) -> Vec { db.signal(1); salsa::par_map(db, input.field(db), |db, field| { db.wait_for(2); - field + 1 + field + dummy(db) }) } #[salsa::tracked] -fn dummy(_db: &dyn KnobsDatabase, _input: ParallelInput) -> ParallelInput { +fn dummy(_db: &dyn KnobsDatabase) -> u32 { panic!("should never get here!") } diff --git a/tests/parallel/signal.rs b/tests/parallel/signal.rs index f09aecc83..e93cb469a 100644 --- a/tests/parallel/signal.rs +++ b/tests/parallel/signal.rs @@ -8,8 +8,6 @@ pub(crate) struct Signal { impl Signal { pub(crate) fn signal(&self, stage: usize) { - dbg!(format!("signal({})", stage)); - // This check avoids acquiring the lock for things that will // clearly be a no-op. Not *necessary* but helps to ensure we // are more likely to encounter weird race conditions; @@ -27,8 +25,6 @@ impl Signal { /// Waits until the given condition is true; the fn is invoked /// with the current stage. pub(crate) fn wait_for(&self, stage: usize) { - dbg!(format!("wait_for({})", stage)); - // As above, avoid lock if clearly a no-op. if stage > 0 { let mut v = self.value.lock(); diff --git a/tests/tracked_fn_multiple_args.rs b/tests/tracked_fn_multiple_args.rs new file mode 100644 index 000000000..7c014356c --- /dev/null +++ b/tests/tracked_fn_multiple_args.rs @@ -0,0 +1,25 @@ +//! Test that a `tracked` fn on multiple salsa struct args +//! compiles and executes successfully. 
+ +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::interned] +struct MyInterned<'db> { + field: u32, +} + +#[salsa::tracked] +fn tracked_fn<'db>(db: &'db dyn salsa::Database, input: MyInput, interned: MyInterned<'db>) -> u32 { + input.field(db) + interned.field(db) +} + +#[test] +fn execute() { + let db = salsa::DatabaseImpl::new(); + let input = MyInput::new(&db, 22); + let interned = MyInterned::new(&db, 33); + assert_eq!(tracked_fn(&db, input, interned), 55); +}