From 3d00ae33b68e06726938870a85732212ab7f0380 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 28 Aug 2024 15:19:47 -0700 Subject: [PATCH 01/35] add example test for fixpoint iteration --- tests/cycle_fixpoint.rs | 147 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 tests/cycle_fixpoint.rs diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs new file mode 100644 index 000000000..55a626753 --- /dev/null +++ b/tests/cycle_fixpoint.rs @@ -0,0 +1,147 @@ +/// Minimal example use case for fixpoint iteration cycle resolution. +use salsa::{Database as Db, Setter}; + +/// A Use of a symbol. +#[salsa::input] +struct Use { + reaching_definitions: Vec, +} + +#[salsa::input] +struct Literal { + value: LiteralValue, +} + +#[derive(Clone, Debug)] +enum LiteralValue { + Int(usize), + Str(String), +} + +/// A Definition of a symbol, either of the form `base + increment` or `0 + increment`. +#[salsa::input] +struct Definition { + base: Option, + increment: Literal, +} + +#[derive(Eq, PartialEq, Clone, Debug)] +enum Type { + Unbound, + LiteralInt(usize), + LiteralStr(String), + Int, + Str, + Union(Vec), +} + +#[salsa::tracked] +fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { + let defs = u.reaching_definitions(db); + match defs[..] { + [] => Type::Unbound, + [def] => infer_definition(db, def), + _ => Type::Union(defs.iter().map(|&def| infer_definition(db, def)).collect()), + } +} + +#[salsa::tracked] +fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { + let increment_ty = infer_literal(db, def.increment(db)); + if let Some(base) = def.base(db) { + let base_ty = infer_use(db, base); + match (base_ty, increment_ty) { + (Type::Unbound, _) => panic!("unbound use"), + (Type::LiteralInt(b), Type::LiteralInt(i)) => Type::LiteralInt(b + i), + (Type::LiteralStr(b), Type::LiteralStr(i)) => Type::LiteralStr(format!("{}{}", b, i)), + (Type::Int, Type::LiteralInt(_)) => Type::Int, + (Type::LiteralInt(_), Type::Int) => Type::Int, + (Type::Str, Type::LiteralStr(_)) => Type::Str, + (Type::LiteralStr(_), Type::Str) => Type::Str, + _ => panic!("type error"), + } + } else { + increment_ty + } +} + +#[salsa::tracked] +fn infer_literal<'db>(db: &'db dyn Db, literal: Literal) -> Type { + match literal.value(db) { + LiteralValue::Int(i) => Type::LiteralInt(i), + LiteralValue::Str(s) => Type::LiteralStr(s), + } +} + +/// x = 1 +#[test] +fn simple() { + let db = salsa::DatabaseImpl::new(); + + let def = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); + let u = Use::new(&db, vec![def]); + + let ty = infer_use(&db, u); + + assert_eq!(ty, Type::LiteralInt(1)); +} + +/// x = "a" if flag else "b" +#[test] +fn union() { + let db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new( + &db, + None, + Literal::new(&db, LiteralValue::Str("a".to_string())), + ); + let def2 = Definition::new( + &db, + None, + Literal::new(&db, LiteralValue::Str("b".to_string())), + ); + let u = Use::new(&db, vec![def1, def2]); + + let ty = infer_use(&db, u); + + assert_eq!( + ty, + Type::Union(vec![ + Type::LiteralStr("a".to_string()), + Type::LiteralStr("b".to_string()) + ]) + ); +} + +/// x = 1; loop: x = x + 0 +#[test] +fn cycle_converges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); + let def2 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = 
infer_use(&db, u); + + /// Loop converges on LiteralInt(1) + assert_eq!(ty, Type::LiteralInt(1)); +} + +/// x = 1; loop: x = x + 1 +#[test] +fn cycle_diverges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); + let def2 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = infer_use(&db, u); + + /// Loop diverges. Cut it off and fallback from "all LiteralInt observed" to Type::Int + assert_eq!(ty, Type::Int); +} From 31979b9cfcb223019860f8945146d485f25118eb Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 28 Aug 2024 15:35:49 -0700 Subject: [PATCH 02/35] add a multi-symbol test --- tests/cycle_fixpoint.rs | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs index 55a626753..5b0a9d3a4 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle_fixpoint.rs @@ -114,7 +114,7 @@ fn union() { ); } -/// x = 1; loop: x = x + 0 +/// x = 1; loop { x = x + 0 } #[test] fn cycle_converges() { let mut db = salsa::DatabaseImpl::new(); @@ -126,11 +126,11 @@ fn cycle_converges() { let ty = infer_use(&db, u); - /// Loop converges on LiteralInt(1) + // Loop converges on LiteralInt(1) assert_eq!(ty, Type::LiteralInt(1)); } -/// x = 1; loop: x = x + 1 +/// x = 1; loop { x = x + 1 } #[test] fn cycle_diverges() { let mut db = salsa::DatabaseImpl::new(); @@ -142,6 +142,28 @@ fn cycle_diverges() { let ty = infer_use(&db, u); - /// Loop diverges. Cut it off and fallback from "all LiteralInt observed" to Type::Int + // Loop diverges. Cut it off and fallback from "all LiteralInt observed" to Type::Int assert_eq!(ty, Type::Int); } + +/// x = 0; y = 0; loop { x = y + 0; y = x + 0 } +#[test] +fn multi_symbol_cycle_converges() { + let mut db = salsa::DatabaseImpl::new(); + + let defx0 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); + let defy0 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); + let defx1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); + let defy1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); + let use_x = Use::new(&db, vec![defx0, defx1]); + let use_y = Use::new(&db, vec![defy0, defy1]); + defx1.set_base(&mut db).to(Some(use_y)); + defy1.set_base(&mut db).to(Some(use_x)); + + let x_ty = infer_use(&db, use_x); + let y_ty = infer_use(&db, use_y); + + // Both symbols converge on LiteralInt(0) + assert_eq!(x_ty, Type::LiteralInt(0)); + assert_eq!(y_ty, Type::LiteralInt(0)); +} From 35e236e3cde88c8b7887e7c3d68c63611e96656b Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 20 Sep 2024 18:32:25 -0700 Subject: [PATCH 03/35] simplify test case --- tests/cycle_fixpoint.rs | 148 +++++++++++++++++++++++----------------- 1 file changed, 84 insertions(+), 64 deletions(-) diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs index 5b0a9d3a4..9d788a2a4 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle_fixpoint.rs @@ -1,5 +1,7 @@ -/// Minimal example use case for fixpoint iteration cycle resolution. +/// Minimal(ish) example use case for fixpoint iteration cycle resolution. use salsa::{Database as Db, Setter}; +use std::collections::BTreeSet; +use std::iter::IntoIterator; /// A Use of a symbol. 
#[salsa::input] @@ -9,13 +11,7 @@ struct Use { #[salsa::input] struct Literal { - value: LiteralValue, -} - -#[derive(Clone, Debug)] -enum LiteralValue { - Int(usize), - Str(String), + value: usize, } /// A Definition of a symbol, either of the form `base + increment` or `0 + increment`. @@ -27,21 +23,39 @@ struct Definition { #[derive(Eq, PartialEq, Clone, Debug)] enum Type { - Unbound, - LiteralInt(usize), - LiteralStr(String), - Int, - Str, - Union(Vec), + Bottom, + Values(Box<[usize]>), + Top, +} + +impl Type { + fn join(tys: impl IntoIterator) -> Type { + let mut result = Type::Bottom; + for ty in tys.into_iter() { + result = match (result, ty) { + (result, Type::Bottom) => result, + (_, Type::Top) => Type::Top, + (Type::Top, _) => Type::Top, + (Type::Bottom, ty) => ty, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend(a_ints); + set.extend(b_ints); + Type::Values(set.into_iter().collect()) + } + } + } + result + } } #[salsa::tracked] fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { let defs = u.reaching_definitions(db); match defs[..] { - [] => Type::Unbound, + [] => Type::Bottom, [def] => infer_definition(db, def), - _ => Type::Union(defs.iter().map(|&def| infer_definition(db, def)).collect()), + _ => Type::join(defs.iter().map(|&def| infer_definition(db, def))), } } @@ -50,27 +64,31 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { let increment_ty = infer_literal(db, def.increment(db)); if let Some(base) = def.base(db) { let base_ty = infer_use(db, base); - match (base_ty, increment_ty) { - (Type::Unbound, _) => panic!("unbound use"), - (Type::LiteralInt(b), Type::LiteralInt(i)) => Type::LiteralInt(b + i), - (Type::LiteralStr(b), Type::LiteralStr(i)) => Type::LiteralStr(format!("{}{}", b, i)), - (Type::Int, Type::LiteralInt(_)) => Type::Int, - (Type::LiteralInt(_), Type::Int) => Type::Int, - (Type::Str, Type::LiteralStr(_)) => Type::Str, - (Type::LiteralStr(_), Type::Str) => Type::Str, - _ => panic!("type error"), - } + add(&base_ty, &increment_ty) } else { increment_ty } } +fn add(a: &Type, b: &Type) -> Type { + match (a, b) { + (Type::Bottom, _) | (_, Type::Bottom) => panic!("unbound use"), + (Type::Top, _) | (_, Type::Top) => Type::Top, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend( + a_ints + .into_iter() + .flat_map(|a| b_ints.into_iter().map(move |b| a + b)), + ); + Type::Values(set.into_iter().collect()) + } + } +} + #[salsa::tracked] fn infer_literal<'db>(db: &'db dyn Db, literal: Literal) -> Type { - match literal.value(db) { - LiteralValue::Int(i) => Type::LiteralInt(i), - LiteralValue::Str(s) => Type::LiteralStr(s), - } + Type::Values(Box::from([literal.value(db)])) } /// x = 1 @@ -78,40 +96,42 @@ fn infer_literal<'db>(db: &'db dyn Db, literal: Literal) -> Type { fn simple() { let db = salsa::DatabaseImpl::new(); - let def = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); + let def = Definition::new(&db, None, Literal::new(&db, 1)); let u = Use::new(&db, vec![def]); let ty = infer_use(&db, u); - assert_eq!(ty, Type::LiteralInt(1)); + assert_eq!(ty, Type::Values(Box::from([1]))); } -/// x = "a" if flag else "b" +/// x = 1 if flag else 2 #[test] fn union() { let db = salsa::DatabaseImpl::new(); - let def1 = Definition::new( - &db, - None, - Literal::new(&db, LiteralValue::Str("a".to_string())), - ); - let def2 = Definition::new( - &db, - None, - Literal::new(&db, LiteralValue::Str("b".to_string())), - ); + let def1 = 
Definition::new(&db, None, Literal::new(&db, 1)); + let def2 = Definition::new(&db, None, Literal::new(&db, 2)); let u = Use::new(&db, vec![def1, def2]); let ty = infer_use(&db, u); - assert_eq!( - ty, - Type::Union(vec![ - Type::LiteralStr("a".to_string()), - Type::LiteralStr("b".to_string()) - ]) - ); + assert_eq!(ty, Type::Values(Box::from([1, 2]))); +} + +/// x = 1 if flag else 2; y = x + 1 +#[test] +fn union_add() { + let db = salsa::DatabaseImpl::new(); + + let x1 = Definition::new(&db, None, Literal::new(&db, 1)); + let x2 = Definition::new(&db, None, Literal::new(&db, 2)); + let x_use = Use::new(&db, vec![x1, x2]); + let y_def = Definition::new(&db, Some(x_use), Literal::new(&db, 1)); + let y_use = Use::new(&db, vec![y_def]); + + let ty = infer_use(&db, y_use); + + assert_eq!(ty, Type::Values(Box::from([2, 3]))); } /// x = 1; loop { x = x + 0 } @@ -119,15 +139,15 @@ fn union() { fn cycle_converges() { let mut db = salsa::DatabaseImpl::new(); - let def1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); - let def2 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); + let def1 = Definition::new(&db, None, Literal::new(&db, 1)); + let def2 = Definition::new(&db, None, Literal::new(&db, 0)); let u = Use::new(&db, vec![def1, def2]); def2.set_base(&mut db).to(Some(u)); let ty = infer_use(&db, u); - // Loop converges on LiteralInt(1) - assert_eq!(ty, Type::LiteralInt(1)); + // Loop converges on 1 + assert_eq!(ty, Type::Values(Box::from([1]))); } /// x = 1; loop { x = x + 1 } @@ -135,15 +155,15 @@ fn cycle_converges() { fn cycle_diverges() { let mut db = salsa::DatabaseImpl::new(); - let def1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); - let def2 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(1))); + let def1 = Definition::new(&db, None, Literal::new(&db, 1)); + let def2 = Definition::new(&db, None, Literal::new(&db, 1)); let u = Use::new(&db, vec![def1, def2]); def2.set_base(&mut db).to(Some(u)); let ty = infer_use(&db, u); - // Loop diverges. Cut it off and fallback from "all LiteralInt observed" to Type::Int - assert_eq!(ty, Type::Int); + // Loop diverges. 
Cut it off and fallback to Type::Top + assert_eq!(ty, Type::Top); } /// x = 0; y = 0; loop { x = y + 0; y = x + 0 } @@ -151,10 +171,10 @@ fn cycle_diverges() { fn multi_symbol_cycle_converges() { let mut db = salsa::DatabaseImpl::new(); - let defx0 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); - let defy0 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); - let defx1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); - let defy1 = Definition::new(&db, None, Literal::new(&db, LiteralValue::Int(0))); + let defx0 = Definition::new(&db, None, Literal::new(&db, 0)); + let defy0 = Definition::new(&db, None, Literal::new(&db, 0)); + let defx1 = Definition::new(&db, None, Literal::new(&db, 0)); + let defy1 = Definition::new(&db, None, Literal::new(&db, 0)); let use_x = Use::new(&db, vec![defx0, defx1]); let use_y = Use::new(&db, vec![defy0, defy1]); defx1.set_base(&mut db).to(Some(use_y)); @@ -164,6 +184,6 @@ fn multi_symbol_cycle_converges() { let y_ty = infer_use(&db, use_y); // Both symbols converge on LiteralInt(0) - assert_eq!(x_ty, Type::LiteralInt(0)); - assert_eq!(y_ty, Type::LiteralInt(0)); + assert_eq!(x_ty, Type::Values(Box::from([0]))); + assert_eq!(y_ty, Type::Values(Box::from([0]))); } From 90ea6e97cf4c702c419af897971596256c863547 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 8 Oct 2024 16:08:09 -0700 Subject: [PATCH 04/35] WIP: remove existing cycle handling tests for now --- tests/cycles.rs | 436 ------------------ tests/parallel/parallel_cycle_all_recover.rs | 104 ----- tests/parallel/parallel_cycle_mid_recover.rs | 102 ---- tests/parallel/parallel_cycle_none_recover.rs | 78 ---- tests/parallel/parallel_cycle_one_recover.rs | 91 ---- 5 files changed, 811 deletions(-) delete mode 100644 tests/cycles.rs delete mode 100644 tests/parallel/parallel_cycle_all_recover.rs delete mode 100644 tests/parallel/parallel_cycle_mid_recover.rs delete mode 100644 tests/parallel/parallel_cycle_none_recover.rs delete mode 100644 tests/parallel/parallel_cycle_one_recover.rs diff --git a/tests/cycles.rs b/tests/cycles.rs deleted file mode 100644 index f07484188..000000000 --- a/tests/cycles.rs +++ /dev/null @@ -1,436 +0,0 @@ -#![allow(warnings)] - -use std::panic::{RefUnwindSafe, UnwindSafe}; - -use expect_test::expect; -use salsa::DatabaseImpl; -use salsa::Durability; - -// Axes: -// -// Threading -// * Intra-thread -// * Cross-thread -- part of cycle is on one thread, part on another -// -// Recovery strategies: -// * Panic -// * Fallback -// * Mixed -- multiple strategies within cycle participants -// -// Across revisions: -// * N/A -- only one revision -// * Present in new revision, not old -// * Present in old revision, not new -// * Present in both revisions -// -// Dependencies -// * Tracked -// * Untracked -- cycle participant(s) contain untracked reads -// -// Layers -// * Direct -- cycle participant is directly invoked from test -// * Indirect -- invoked a query that invokes the cycle -// -// -// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | -// | ------ | -------- | -------- | --------- | ------ | --------- | -// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | -// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | -// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | -// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | -// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | -// | Intra | Fallback | New | Tracked | direct | cycle_appears | -// 
| Intra | Fallback | Old | Tracked | direct | cycle_disappears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | -// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_mid_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | - -#[derive(PartialEq, Eq, Hash, Clone, Debug)] -struct Error { - cycle: Vec, -} - -use salsa::Database as Db; -use salsa::Setter; - -#[salsa::input] -struct MyInput {} - -#[salsa::tracked] -fn memoized_a(db: &dyn Db, input: MyInput) { - memoized_b(db, input) -} - -#[salsa::tracked] -fn memoized_b(db: &dyn Db, input: MyInput) { - memoized_a(db, input) -} - -#[salsa::tracked] -fn volatile_a(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_b(db, input) -} - -#[salsa::tracked] -fn volatile_b(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_a(db, input) -} - -/// The queries A, B, and C in `Database` can be configured -/// to invoke one another in arbitrary ways using this -/// enum. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum CycleQuery { - None, - A, - B, - C, - AthenC, -} - -#[salsa::input] -struct ABC { - a: CycleQuery, - b: CycleQuery, - c: CycleQuery, -} - -impl CycleQuery { - fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { - match self { - CycleQuery::A => cycle_a(db, abc), - CycleQuery::B => cycle_b(db, abc), - CycleQuery::C => cycle_c(db, abc), - CycleQuery::AthenC => { - let _ = cycle_a(db, abc); - cycle_c(db, abc) - } - CycleQuery::None => Ok(()), - } - } -} - -#[salsa::tracked(recovery_fn=recover_a)] -fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.a(db).invoke(db, abc) -} - -fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked(recovery_fn=recover_b)] -fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.b(db).invoke(db, abc) -} - -fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked] -fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.c(db).invoke(db, abc) -} - -#[track_caller] -fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { - let v = std::panic::catch_unwind(f); - if let Err(d) = &v { - if let Some(cycle) = d.downcast_ref::() { - return cycle.clone(); - } - } - panic!("unexpected value: {:?}", v) -} - -#[test] -fn cycle_memoized() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| memoized_a(db, input)); - let expected = expect![[r#" - [ - memoized_a(Id(0)), - memoized_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }) -} - -#[test] -fn cycle_volatile() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| volatile_a(db, input)); - let expected = expect![[r#" - [ - volatile_a(Id(0)), - volatile_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }); -} - -#[test] -fn expect_cycle() { - // A --> B - 
// ^ | - // +-----+ - - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(db, abc).is_err()); - }) -} - -#[test] -fn inner_cycle() { - // A --> B <-- C - // ^ | - // +-----+ - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); - let err = cycle_c(db, abc); - assert!(err.is_err()); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&err.unwrap_err().cycle); - }) -} - -#[test] -fn cycle_revalidate() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - abc.set_b(&mut db).to(CycleQuery::A); // same value as default - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_recovery_unchanged_twice() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - abc.set_c(&mut db).to(CycleQuery::A); // force new revision - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_appears() { - let mut db = salsa::DatabaseImpl::new(); - // A --> B - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); - - // A --> B - // ^ | - // +-----+ - abc.set_b(&mut db).to(CycleQuery::A); - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_disappears() { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - // A --> B - abc.set_b(&mut db).to(CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); -} - -/// A variant on `cycle_disappears` in which the values of -/// `a` and `b` are set with durability values. -/// If we are not careful, this could cause us to overlook -/// the fact that the cycle will no longer occur. -#[test] -fn cycle_disappears_durability() { - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new( - &mut db, - CycleQuery::None, - CycleQuery::None, - CycleQuery::None, - ); - abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::B); - abc.set_b(&mut db) - .with_durability(Durability::HIGH) - .to(CycleQuery::A); - - assert!(cycle_a(&db, abc).is_err()); - - // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, - // because `b` participates in the same cycle as `a`, its final durability - // should be `LOW`. - // - // Check that setting a `LOW` input causes us to re-execute `b` query, and - // observe that the cycle goes away. 
- abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::None); - - assert!(cycle_b(&mut db, abc).is_ok()); -} - -#[test] -fn cycle_mixed_1() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> B <-- C - // | ^ - // +-----+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); - - let expected = expect![[r#" - [ - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_c(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_mixed_2() { - salsa::DatabaseImpl::new().attach(|db| { - // Configuration: - // - // A --> B --> C - // ^ | - // +-----------+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_a(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_deterministic_order() { - // No matter whether we start from A or B, we get the same set of participants: - let f = || { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - (db, abc) - }; - let (db, abc) = f(); - let a = cycle_a(&db, abc); - let (db, abc) = f(); - let b = cycle_b(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); -} - -#[test] -fn cycle_multiple() { - // No matter whether we start from A or B, we get the same set of participants: - let mut db = salsa::DatabaseImpl::new(); - - // Configuration: - // - // A --> B <-- C - // ^ | ^ - // +-----+ | - // | | - // +-----+ - // - // Here, conceptually, B encounters a cycle with A and then - // recovers. - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); - - let c = cycle_c(&db, abc); - let b = cycle_b(&db, abc); - let a = cycle_a(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&( - c.unwrap_err().cycle, - b.unwrap_err().cycle, - a.unwrap_err().cycle, - )); -} - -#[test] -fn cycle_recovery_set_but_not_participating() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> C -+ - // ^ | - // +--+ - let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C); - - // Here we expect C to panic and A not to recover: - let r = extract_cycle(|| drop(cycle_a(db, abc))); - let expected = expect![[r#" - [ - cycle_c(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&r.all_participants(db)); - }) -} diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs deleted file mode 100644 index 9dc8c74e2..000000000 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ /dev/null @@ -1,104 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. 
- -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked(recovery_fn = recover_a1)] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a1"); - key.field(db) * 10 + 1 -} - -#[salsa::tracked(recovery_fn=recover_a2)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a2"); - key.field(db) * 10 + 2 -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 1 -} - -#[salsa::tracked(recovery_fn=recover_b2)] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b2"); - key.field(db) * 20 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected, recovers) -// | b2 completes, recovers -// | b1 completes, recovers -// a2 sees cycle, recovers -// a1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - assert_eq!(thread_a.join().unwrap(), 11); - assert_eq!(thread_b.join().unwrap(), 21); -} diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs deleted file mode 100644 index 593d46a66..000000000 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. 
- -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // tell thread b we have started - db.signal(1); - - // wait for thread b to block on a1 - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // create the cycle - b1(db, input) -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // wait for thread a to have started - db.wait_for(1); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will encounter a cycle but recover - b3(db, input); - b1(db, input); // hasn't recovered yet - 0 -} - -#[salsa::tracked(recovery_fn=recover_b3)] -pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will block on thread a, signaling stage 2 - a1(db, input) -} - -fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b3"); - key.field(db) * 200 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | | -// | b2 -// | b3 -// | a1 (blocks -> stage 2) -// (unblocked) | -// a2 (cycle detected) | -// b3 recovers -// b2 resumes -// b1 recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs deleted file mode 100644 index 89f1ecfb0..000000000 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Test a cycle where no queries recover that occurs across threads. -//! See the `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. 
- -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; -use expect_test::expect; -use salsa::Database; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - b(db, input) -} - -#[salsa::tracked] -pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - - // Now try to execute A - a(db, input) -} - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, -1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b(&db, input) - }); - - // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). - // Right now, it panics with a string. - let err_b = thread_b.join().unwrap_err(); - db.attach(|_| { - if let Some(c) = err_b.downcast_ref::() { - let expected = expect![[r#" - [ - a(Id(0)), - b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&c.all_participants(&db)); - } else { - panic!("b failed in an unexpected way: {:?}", err_b); - } - }); - - // We expect A to propagate a panic, which causes us to use the sentinel - // type `Canceled`. - assert!(thread_a - .join() - .unwrap_err() - .downcast_ref::() - .is_some()); -} diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs deleted file mode 100644 index c03782821..000000000 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked(recovery_fn=recover)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -// Recover cycle test: -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected) -// a2 recovery fn executes | -// a1 completes normally | -// b2 completes, recovers -// b1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block.store(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} From f218e58f0f425c5729200e9e9c72b827bf42b6f7 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 8 Oct 2024 17:35:42 -0700 Subject: [PATCH 05/35] WIP: remove all existing cycle handling, add fixpoint options --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 14 +- .../src/unexpected_cycle_recovery.rs | 22 +-- components/salsa-macros/src/accumulator.rs | 3 +- components/salsa-macros/src/input.rs | 4 +- components/salsa-macros/src/interned.rs | 4 +- components/salsa-macros/src/options.rs | 44 ++++-- components/salsa-macros/src/tracked_fn.rs | 32 +++- components/salsa-macros/src/tracked_struct.rs | 4 +- src/active_query.rs | 33 +--- src/cycle.rs | 109 ++----------- src/function.rs | 17 +-- src/function/execute.rs | 29 +--- src/lib.rs | 5 +- src/runtime.rs | 143 +----------------- src/runtime/dependency_graph.rs | 130 ---------------- src/zalsa_local.rs | 36 ----- tests/cycle_fixpoint.rs | 26 +++- tests/parallel/main.rs | 4 - 18 files changed, 146 insertions(+), 513 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 79ec96d50..eed83c4ed 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -37,6 +37,9 @@ macro_rules! setup_tracked_fn { // Path to the cycle recovery function to use. cycle_recovery_fn: ($($cycle_recovery_fn:tt)*), + // Path to function to get the initial value to use for cycle recovery. + cycle_recovery_initial: ($($cycle_recovery_initial:tt)*), + // Name of cycle recovery strategy variant to use. cycle_recovery_strategy: $cycle_recovery_strategy:ident, @@ -174,12 +177,15 @@ macro_rules! 
setup_tracked_fn { $inner($db, $($input_id),*) } + fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db) + } + fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - cycle: &$zalsa::Cycle, - ($($input_id),*): ($($input_ty),*) - ) -> Self::Output<$db_lt> { - $($cycle_recovery_fn)*(db, cycle, $($input_id),*) + value: Self::Output<$db_lt>, + ) -> $zalsa::CycleRecoveryAction> { + $($cycle_recovery_fn)*(db, value) } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index a8b8122b3..e556e104e 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -1,13 +1,17 @@ // Macro that generates the body of the cycle recovery function -// for the case where no cycle recovery is possible. This has to be -// a macro because it can take a variadic number of arguments. +// for the case where no cycle recovery is possible. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $cycle:ident, $($other_inputs:ident),*) => { - { - std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); - panic!("cannot recover from cycle `{:?}`", $cycle) - } - } + ($db:ident, $value:ident) => {{ + std::mem::drop($db); + panic!("cannot recover from cycle") + }}; +} + +#[macro_export] +macro_rules! unexpected_cycle_initial { + ($db:ident) => {{ + std::mem::drop($db); + panic!("no cycle initial value") + }}; } diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index 1e09cb08f..bee9a83f8 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -39,7 +39,8 @@ impl AllowedOptions for Accumulator { const SINGLETON: bool = false; const DATA: bool = false; const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + const CYCLE_INITIAL: bool = false; const LRU: bool = false; const CONSTRUCTOR_NAME: bool = false; } diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 9ad444913..0eb443e0d 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -50,7 +50,9 @@ impl crate::options::AllowedOptions for InputStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 8caba77e4..ca4f3e444 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -51,7 +51,9 @@ impl crate::options::AllowedOptions for InternedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index 6f30bb3e6..c4175c705 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -44,10 +44,15 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. pub db_path: Option, - /// The `recovery_fn = ` option is used to indicate the recovery function. + /// The `cycle_fn = ` option is used to indicate the cycle recovery function. 
/// /// If this is `Some`, the value is the ``. - pub recovery_fn: Option, + pub cycle_fn: Option, + + /// The `cycle_initial = ` option is the initial value for cycle iteration. + /// + /// If this is `Some`, the value is the ``. + pub cycle_initial: Option, /// The `data = ` option is used to define the name of the data type for an interned /// struct. @@ -79,7 +84,8 @@ impl Default for Options { no_debug: Default::default(), no_clone: Default::default(), db_path: Default::default(), - recovery_fn: Default::default(), + cycle_fn: Default::default(), + cycle_initial: Default::default(), data: Default::default(), constructor_name: Default::default(), phantom: Default::default(), @@ -99,7 +105,8 @@ pub(crate) trait AllowedOptions { const SINGLETON: bool; const DATA: bool; const DB: bool; - const RECOVERY_FN: bool; + const CYCLE_FN: bool; + const CYCLE_INITIAL: bool; const LRU: bool; const CONSTRUCTOR_NAME: bool; } @@ -207,20 +214,39 @@ impl syn::parse::Parse for Options { "`db` option not allowed here", )); } - } else if ident == "recovery_fn" { - if A::RECOVERY_FN { + } else if ident == "cycle_fn" { + if A::CYCLE_FN { + let _eq = Equals::parse(input)?; + let path = syn::Path::parse(input)?; + if let Some(old) = std::mem::replace(&mut options.cycle_fn, Some(path)) { + return Err(syn::Error::new( + old.span(), + "option `cycle_fn` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`cycle_fn` option not allowed here", + )); + } + } else if ident == "cycle_initial" { + if A::CYCLE_INITIAL { + // TODO(carljm) should it be an error to give cycle_initial without cycle_fn, + // or should we just allow this to fall into potentially infinite iteration, if + // iteration never converges? let _eq = Equals::parse(input)?; let path = syn::Path::parse(input)?; - if let Some(old) = std::mem::replace(&mut options.recovery_fn, Some(path)) { + if let Some(old) = std::mem::replace(&mut options.cycle_initial, Some(path)) { return Err(syn::Error::new( old.span(), - "option `recovery_fn` provided twice", + "option `cycle_initial` provided twice", )); } } else { return Err(syn::Error::new( ident.span(), - "`recovery_fn` option not allowed here", + "`cycle_initial` option not allowed here", )); } } else if ident == "data" { diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 57023ef24..caa4c70d4 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -39,7 +39,9 @@ impl crate::options::AllowedOptions for TrackedFn { const DB: bool = false; - const RECOVERY_FN: bool = true; + const CYCLE_FN: bool = true; + + const CYCLE_INITIAL: bool = true; const LRU: bool = true; @@ -68,7 +70,8 @@ impl Macro { let input_ids = self.input_ids(&item); let input_tys = self.input_tys(&item)?; let output_ty = self.output_ty(&db_lt, &item)?; - let (cycle_recovery_fn, cycle_recovery_strategy) = self.cycle_recovery(); + let (cycle_recovery_fn, cycle_recovery_initial, cycle_recovery_strategy) = + self.cycle_recovery()?; let is_specifiable = self.args.specify.is_some(); let no_eq = self.args.no_eq.is_some(); @@ -127,6 +130,7 @@ impl Macro { output_ty: #output_ty, inner_fn: #inner_fn, cycle_recovery_fn: #cycle_recovery_fn, + cycle_recovery_initial: #cycle_recovery_initial, cycle_recovery_strategy: #cycle_recovery_strategy, is_specifiable: #is_specifiable, no_eq: #no_eq, @@ -160,14 +164,26 @@ impl Macro { Ok(ValidFn { db_ident, db_path }) } - fn cycle_recovery(&self) -> (TokenStream, TokenStream) { - if 
let Some(recovery_fn) = &self.args.recovery_fn { - (quote!((#recovery_fn)), quote!(Fallback)) - } else { - ( + fn cycle_recovery(&self) -> syn::Result<(TokenStream, TokenStream, TokenStream)> { + match (&self.args.cycle_fn, &self.args.cycle_initial) { + (Some(cycle_fn), Some(cycle_initial)) => Ok(( + quote!((#cycle_fn)), + quote!((#cycle_initial)), + quote!(Recover), + )), + (None, None) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((salsa::plumbing::unexpected_cycle_initial!)), quote!(Panic), - ) + )), + (Some(_), None) => Err(syn::Error::new_spanned( + self.args.cycle_fn.as_ref().unwrap(), + "must provide `cycle_initial` along with `cycle_fn`", + )), + (None, Some(_)) => Err(syn::Error::new_spanned( + self.args.cycle_initial.as_ref().unwrap(), + "must provide `cycle_fn` along with `cycle_initial`", + )), } } diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index 1730b3404..009f6f8fc 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -45,7 +45,9 @@ impl crate::options::AllowedOptions for TrackedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/src/active_query.rs b/src/active_query.rs index 71781a385..adf544cc2 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -9,7 +9,7 @@ use crate::{ key::{DatabaseKeyIndex, DependencyIndex}, tracked_struct::{Disambiguator, Identity}, zalsa_local::EMPTY_DEPENDENCIES, - Cycle, Id, Revision, + Id, Revision, }; #[derive(Debug)] @@ -35,9 +35,6 @@ pub(crate) struct ActiveQuery { /// True if there was an untracked read. untracked_read: bool, - /// Stores the entire cycle, if one is found and this query is part of it. - pub(crate) cycle: Option, - /// When new tracked structs are created, their data is hashed, and the resulting /// hash is added to this map. If it is not present, then the disambiguator is 0. /// Otherwise it is 1 more than the current value (which is incremented). @@ -64,7 +61,6 @@ impl ActiveQuery { changed_at: Revision::start(), input_outputs: FxIndexSet::default(), untracked_read: false, - cycle: None, disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), accumulated: Default::default(), @@ -130,33 +126,6 @@ impl ActiveQuery { } } - /// Adds any dependencies from `other` into `self`. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn add_from(&mut self, other: &ActiveQuery) { - self.changed_at = self.changed_at.max(other.changed_at); - self.durability = self.durability.min(other.durability); - self.untracked_read |= other.untracked_read; - self.input_outputs - .extend(other.input_outputs.iter().copied()); - } - - /// Removes the participants in `cycle` from my dependencies. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn remove_cycle_participants(&mut self, cycle: &Cycle) { - for p in cycle.participant_keys() { - let p: DependencyIndex = p.into(); - self.input_outputs.shift_remove(&(EdgeKind::Input, p)); - } - } - - /// Copy the changed-at, durability, and dependencies from `cycle_query`. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. 
- pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) { - self.changed_at = cycle_query.changed_at; - self.durability = cycle_query.durability; - self.input_outputs.clone_from(&cycle_query.input_outputs); - } - pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator { let disambiguator = self .disambiguator_map diff --git a/src/cycle.rs b/src/cycle.rs index 2c67e164d..55ce28888 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,92 +1,11 @@ -use crate::{key::DatabaseKeyIndex, Database}; -use std::{panic::AssertUnwindSafe, sync::Arc}; - -/// Captures the participants of a cycle that occurred when executing a query. -/// -/// This type is meant to be used to help give meaningful error messages to the -/// user or to help salsa developers figure out why their program is resulting -/// in a computation cycle. -/// -/// It is used in a few ways: -/// -/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html), -/// where it is given to the fallback function. -/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants -/// lacks cycle recovery information) occurs. -/// -/// You can read more about cycle handling in -/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html). -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Cycle { - participants: CycleParticipants, -} - -// We want `Cycle`` to be thin -pub(crate) type CycleParticipants = Arc>; - -impl Cycle { - pub(crate) fn new(participants: CycleParticipants) -> Self { - Self { participants } - } - - /// True if two `Cycle` values represent the same cycle. - pub(crate) fn is(&self, cycle: &Cycle) -> bool { - Arc::ptr_eq(&self.participants, &cycle.participants) - } - - pub(crate) fn throw(self) -> ! { - tracing::debug!("throwing cycle {:?}", self); - std::panic::resume_unwind(Box::new(self)) - } - - pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { - match std::panic::catch_unwind(AssertUnwindSafe(execute)) { - Ok(v) => Ok(v), - Err(err) => match err.downcast::() { - Ok(cycle) => Err(*cycle), - Err(other) => std::panic::resume_unwind(other), - }, - } - } - - /// Iterate over the [`DatabaseKeyIndex`] for each query participating - /// in the cycle. The start point of this iteration within the cycle - /// is arbitrary but deterministic, but the ordering is otherwise determined - /// by the execution. - pub fn participant_keys(&self) -> impl Iterator + '_ { - self.participants.iter().copied() - } - - /// Returns a vector with the debug information for - /// all the participants in the cycle. - pub fn all_participants(&self, _db: &dyn Database) -> Vec { - self.participant_keys().collect() - } - - /// Returns a vector with the debug information for - /// those participants in the cycle that lacked recovery - /// information. 
- pub fn unexpected_participants(&self, db: &dyn Database) -> Vec { - self.participant_keys() - .filter(|&d| d.cycle_recovery_strategy(db) == CycleRecoveryStrategy::Panic) - .collect() - } -} - -impl std::fmt::Debug for Cycle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - crate::attach::with_attached_database(|db| { - f.debug_struct("UnexpectedCycle") - .field("all_participants", &self.all_participants(db)) - .field("unexpected_participants", &self.unexpected_participants(db)) - .finish() - }) - .unwrap_or_else(|| { - f.debug_struct("Cycle") - .field("participants", &self.participants) - .finish() - }) - } +/// Return value from a cycle recovery function. +#[derive(Debug)] +pub enum CycleRecoveryAction { + /// Iterate the cycle again to look for a fixpoint. + Iterate, + + /// Cut off iteration and use the given result value for this query. + Fallback(T), } /// Cycle recovery strategy: Is this query capable of recovering from @@ -96,14 +15,12 @@ pub enum CycleRecoveryStrategy { /// Cannot recover from cycles: panic. /// /// This is the default. - /// - /// In the case of a failure due to a cycle, the panic - /// value will be the `Cycle`. Panic, - /// Recovers from cycles by storing a sentinel value. + /// Recovers from cycles by fixpoint iterating and/or falling + /// back to a sentinel value. /// - /// This value is computed by the query's `recovery_fn` - /// function. - Fallback, + /// This choice is computed by the query's `cycle_recovery` + /// function and initial value. + Recover, } diff --git a/src/function.rs b/src/function.rs index 07f13d496..82c7384e5 100644 --- a/src/function.rs +++ b/src/function.rs @@ -2,14 +2,14 @@ use std::{any::Any, fmt, sync::Arc}; use crate::{ accumulator::accumulated_map::AccumulatedMap, - cycle::CycleRecoveryStrategy, + cycle::{CycleRecoveryAction, CycleRecoveryStrategy}, ingredient::fmt_index, key::DatabaseKeyIndex, plumbing::JarAux, salsa_struct::SalsaStructInDb, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, - Cycle, Database, Id, Revision, + Database, Id, Revision, }; use self::delete::DeletedEntries; @@ -67,15 +67,14 @@ pub trait Configuration: Any { /// This invokes the function the user wrote. fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; - /// If the cycle strategy is `Fallback`, then invoked when `key` is a participant - /// in a cycle to find out what value it should have. - /// - /// This invokes the recovery function given by the user. + /// Get the cycle recovery initial value. + fn cycle_initial(db: &Self::DbView) -> Self::Output<'_>; + + /// Decide whether to iterate a cycle again or fallback. fn recover_from_cycle<'db>( db: &'db Self::DbView, - cycle: &Cycle, - input: Self::Input<'db>, - ) -> Self::Output<'db>; + value: Self::Output<'db>, + ) -> CycleRecoveryAction>; } /// Function ingredients are the "workhorse" of salsa. diff --git a/src/function/execute.rs b/src/function/execute.rs index 4171fe6d4..11b283003 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,8 +1,6 @@ use std::sync::Arc; -use crate::{ - zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, -}; +use crate::{zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Database, Event, EventKind}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -48,30 +46,7 @@ where // stale, or value is absent. Let's execute! 
let database_key_index = active_query.database_key_index; let id = database_key_index.key_index; - let value = match Cycle::catch(|| C::execute(db, C::id_to_input(db, id))) { - Ok(v) => v, - Err(cycle) => { - tracing::debug!( - "{database_key_index:?}: caught cycle {cycle:?}, have strategy {:?}", - C::CYCLE_STRATEGY - ); - match C::CYCLE_STRATEGY { - crate::cycle::CycleRecoveryStrategy::Panic => cycle.throw(), - crate::cycle::CycleRecoveryStrategy::Fallback => { - if let Some(c) = active_query.take_cycle() { - assert!(c.is(&cycle)); - C::recover_from_cycle(db, &cycle, C::id_to_input(db, id)) - } else { - // we are not a participant in this cycle - debug_assert!(!cycle - .participant_keys() - .any(|k| k == database_key_index)); - cycle.throw() - } - } - } - } - }; + let value = C::execute(db, C::id_to_input(db, id)); let mut revisions = active_query.pop(); // If the new value is equal to the old one, then it didn't diff --git a/src/lib.rs b/src/lib.rs index 8cc739fab..59de46ffc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ mod zalsa_local; pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; -pub use self::cycle::Cycle; +pub use self::cycle::CycleRecoveryAction; pub use self::database::AsDynDatabase; pub use self::database::Database; pub use self::database_impl::DatabaseImpl; @@ -70,7 +70,7 @@ pub mod plumbing { pub use crate::array::Array; pub use crate::attach::attach; pub use crate::attach::with_attached_database; - pub use crate::cycle::Cycle; + pub use crate::cycle::CycleRecoveryAction; pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; pub use crate::database::Database; @@ -114,6 +114,7 @@ pub mod plumbing { pub use salsa_macro_rules::setup_method_body; pub use salsa_macro_rules::setup_tracked_fn; pub use salsa_macro_rules::setup_tracked_struct; + pub use salsa_macro_rules::unexpected_cycle_initial; pub use salsa_macro_rules::unexpected_cycle_recovery; pub mod accumulator { diff --git a/src/runtime.rs b/src/runtime.rs index 5fa281e9d..1f15c3b5a 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,19 +1,14 @@ use std::{ mem, - panic::panic_any, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, thread::ThreadId, }; use parking_lot::Mutex; use crate::{ - active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, revision::AtomicRevision, table::Table, zalsa_local::ZalsaLocal, - Cancelled, Cycle, Database, Event, EventKind, Revision, + durability::Durability, key::DatabaseKeyIndex, revision::AtomicRevision, table::Table, + zalsa_local::ZalsaLocal, Cancelled, Database, Event, EventKind, Revision, }; use self::dependency_graph::DependencyGraph; @@ -49,7 +44,6 @@ pub struct Runtime { pub(crate) enum WaitResult { Completed, Panicked, - Cycle(Cycle), } #[derive(Copy, Clone, Debug)] @@ -161,17 +155,6 @@ impl Runtime { /// /// If the thread `other_id` panics, then our thread is considered /// cancelled, so this function will panic with a `Cancelled` value. - /// - /// # Cycle handling - /// - /// If the thread `other_id` already depends on the current thread, - /// and hence there is a cycle in the query graph, then this function - /// will unwind instead of returning normally. The method of unwinding - /// depends on the [`Self::mutual_cycle_recovery_strategy`] - /// of the cycle participants: - /// - /// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value. 
- /// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`]. pub(crate) fn block_on_or_unwind( &self, db: &dyn Database, @@ -184,11 +167,7 @@ impl Runtime { let thread_id = std::thread::current().id(); if dg.depends_on(other_id, thread_id) { - self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id); - - // If the above fn returns, then (via cycle recovery) it has unblocked the - // cycle, so we can continue. - assert!(!dg.depends_on(other_id, thread_id)); + panic!("unexpected dependency graph cycle"); } db.salsa_event(&|| Event { @@ -219,120 +198,6 @@ impl Runtime { // cancelled. The assumption is that the panic will be detected // by the other thread and responded to appropriately. WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), - - WaitResult::Cycle(c) => c.throw(), - } - } - - /// Handles a cycle in the dependency graph that was detected when the - /// current thread tried to block on `database_key_index` which is being - /// executed by `to_id`. If this function returns, then `to_id` no longer - /// depends on the current thread, and so we should continue executing - /// as normal. Otherwise, the function will throw a `Cycle` which is expected - /// to be caught by some frame on our stack. This occurs either if there is - /// a frame on our stack with cycle recovery (possibly the top one!) or if there - /// is no cycle recovery at all. - fn unblock_cycle_and_maybe_throw( - &self, - db: &dyn Database, - local_state: &ZalsaLocal, - dg: &mut DependencyGraph, - database_key_index: DatabaseKeyIndex, - to_id: ThreadId, - ) { - tracing::debug!( - "unblock_cycle_and_maybe_throw(database_key={:?})", - database_key_index - ); - - let (me_recovered, others_recovered, cycle) = local_state.with_query_stack(|from_stack| { - let from_id = std::thread::current().id(); - - // Make a "dummy stack frame". As we iterate through the cycle, we will collect the - // inputs from each participant. Then, if we are participating in cycle recovery, we - // will propagate those results to all participants. - let mut cycle_query = ActiveQuery::new(database_key_index); - - // Identify the cycle participants: - let cycle = { - let mut v = vec![]; - dg.for_each_cycle_participant( - from_id, - from_stack, - database_key_index, - to_id, - |aqs| { - aqs.iter_mut().for_each(|aq| { - cycle_query.add_from(aq); - v.push(aq.database_key_index); - }); - }, - ); - - // We want to give the participants in a deterministic order - // (at least for this execution, not necessarily across executions), - // no matter where it started on the stack. Find the minimum - // key and rotate it to the front. - - if let Some((_, index, _)) = v - .iter() - .enumerate() - .map(|(idx, key)| (key.ingredient_index.debug_name(db), idx, key)) - .min() - { - v.rotate_left(index); - } - - Cycle::new(Arc::new(v.into_boxed_slice())) - }; - tracing::debug!("cycle {cycle:?}, cycle_query {cycle_query:#?}"); - - // We can remove the cycle participants from the list of dependencies; - // they are a strongly connected component (SCC) and we only care about - // dependencies to things outside the SCC that control whether it will - // form again. - cycle_query.remove_cycle_participants(&cycle); - - // Mark each cycle participant that has recovery set, along with - // any frames that come after them on the same thread. Those frames - // are going to be unwound so that fallback can occur. 
- dg.for_each_cycle_participant(from_id, from_stack, database_key_index, to_id, |aqs| { - aqs.iter_mut() - .skip_while(|aq| { - match db - .zalsa() - .lookup_ingredient(aq.database_key_index.ingredient_index) - .cycle_recovery_strategy() - { - CycleRecoveryStrategy::Panic => true, - CycleRecoveryStrategy::Fallback => false, - } - }) - .for_each(|aq| { - tracing::debug!("marking {:?} for fallback", aq.database_key_index); - aq.take_inputs_from(&cycle_query); - assert!(aq.cycle.is_none()); - aq.cycle = Some(cycle.clone()); - }); - }); - - // Unblock every thread that has cycle recovery with a `WaitResult::Cycle`. - // They will throw the cycle, which will be caught by the frame that has - // cycle recovery so that it can execute that recovery. - let (me_recovered, others_recovered) = - dg.maybe_unblock_runtimes_in_cycle(from_id, from_stack, database_key_index, to_id); - (me_recovered, others_recovered, cycle) - }); - - if me_recovered { - // If the current thread has recovery, we want to throw - // so that it can begin. - cycle.throw() - } else if others_recovered { - // If other threads have recovery but we didn't: return and we will block on them. - } else { - // if nobody has recover, then we panic - panic_any(cycle); } } diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index 84c5327fc..c90e650de 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -31,7 +31,6 @@ pub(super) struct DependencyGraph { #[derive(Debug)] struct Edge { blocked_on_id: ThreadId, - blocked_on_key: DatabaseKeyIndex, stack: QueryStack, /// Signalled whenever a query with dependents completes. @@ -55,115 +54,6 @@ impl DependencyGraph { p == to_id } - /// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle. - /// The cycle runs as follows: - /// - /// 1. The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`... - /// 2. ...but `database_key` is already being executed by `to_id`... - /// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`. - pub(super) fn for_each_cycle_participant( - &mut self, - from_id: ThreadId, - from_stack: &mut QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - mut closure: impl FnMut(&mut [ActiveQuery]), - ) { - debug_assert!(self.depends_on(to_id, from_id)); - - // To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v): - // - // database_key = QB2 - // from_id = A - // to_id = B - // from_stack = [QA1, QA2, QA3] - // - // self.edges[B] = { C, QC2, [QB1..QB3] } - // self.edges[C] = { A, QA2, [QC1..QC3] } - // - // The cyclic - // edge we have - // failed to add. - // : - // A : B C - // : - // QA1 v QB1 QC1 - // ┌► QA2 ┌──► QB2 ┌─► QC2 - // │ QA3 ───┘ QB3 ──┘ QC3 ───┐ - // │ │ - // └───────────────────────────────┘ - // - // Final output: [QB2, QB3, QC2, QC3, QA2, QA3] - - let mut id = to_id; - let mut key = database_key; - while id != from_id { - // Looking at the diagram above, the idea is to - // take the edge from `to_id` starting at `key` - // (inclusive) and down to the end. We can then - // load up the next thread (i.e., we start at B/QB2, - // and then load up the dependency on C/QC2). - let edge = self.edges.get_mut(&id).unwrap(); - closure(strip_prefix_query_stack_mut(&mut edge.stack, key)); - id = edge.blocked_on_id; - key = edge.blocked_on_key; - } - - // Finally, we copy in the results from `from_stack`. 
- closure(strip_prefix_query_stack_mut(from_stack, key)); - } - - /// Unblock each blocked runtime (excluding the current one) if some - /// query executing in that runtime is participating in cycle fallback. - /// - /// Returns a boolean (Current, Others) where: - /// * Current is true if the current runtime has cycle participants - /// with fallback; - /// * Others is true if other runtimes were unblocked. - pub(super) fn maybe_unblock_runtimes_in_cycle( - &mut self, - from_id: ThreadId, - from_stack: &QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - ) -> (bool, bool) { - // See diagram in `for_each_cycle_participant`. - let mut id = to_id; - let mut key = database_key; - let mut others_unblocked = false; - while id != from_id { - let edge = self.edges.get(&id).unwrap(); - let next_id = edge.blocked_on_id; - let next_key = edge.blocked_on_key; - - if let Some(cycle) = strip_prefix_query_stack(&edge.stack, key) - .iter() - .rev() - .find_map(|aq| aq.cycle.clone()) - { - // Remove `id` from the list of runtimes blocked on `next_key`: - self.query_dependents - .get_mut(&next_key) - .unwrap() - .retain(|r| *r != id); - - // Unblock runtime so that it can resume execution once lock is released: - self.unblock_runtime(id, WaitResult::Cycle(cycle)); - - others_unblocked = true; - } - - id = next_id; - key = next_key; - } - - let this_unblocked = strip_prefix_query_stack(from_stack, key) - .iter() - .any(|aq| aq.cycle.is_some()); - - (this_unblocked, others_unblocked) - } - /// Modifies the graph so that `from_id` is blocked /// on `database_key`, which is being computed by /// `to_id`. @@ -219,7 +109,6 @@ impl DependencyGraph { from_id, Edge { blocked_on_id: to_id, - blocked_on_key: database_key, stack: from_stack, condvar: condvar.clone(), }, @@ -260,22 +149,3 @@ impl DependencyGraph { edge.condvar.notify_one(); } } - -fn strip_prefix_query_stack(stack_mut: &[ActiveQuery], key: DatabaseKeyIndex) -> &[ActiveQuery] { - let prefix = stack_mut - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - &stack_mut[prefix..] -} - -fn strip_prefix_query_stack_mut( - stack_mut: &mut [ActiveQuery], - key: DatabaseKeyIndex, -) -> &mut [ActiveQuery] { - let prefix = stack_mut - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - &mut stack_mut[prefix..] -} diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 6988d6537..2d9bf9f12 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -14,7 +14,6 @@ use crate::tracked_struct::{Disambiguator, Identity, IdentityHash}; use crate::zalsa::IngredientIndex; use crate::Accumulator; use crate::Cancelled; -use crate::Cycle; use crate::Database; use crate::Event; use crate::EventKind; @@ -175,31 +174,6 @@ impl ZalsaLocal { self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { top_query.add_read(input, durability, changed_at, accumulated); - - // We are a cycle participant: - // - // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0 - // ^ ^ - // : | - // This edge -----+ | - // | - // | - // N0 - // - // In this case, the value we have just read from `Ci+1` - // is actually the cycle fallback value and not especially - // interesting. We unwind now with `CycleParticipant` to avoid - // executing the rest of our query function. This unwinding - // will be caught and our own fallback value will be used. - // - // Note that `Ci+1` may` have *other* callers who are not - // participants in the cycle (e.g., N0 in the graph above). 
- // They will not have the `cycle` marker set in their - // stack frames, so they will just read the fallback value - // from `Ci+1` and continue on their merry way. - if let Some(cycle) = &top_query.cycle { - cycle.clone().throw() - } } }) } @@ -536,18 +510,8 @@ impl ActiveQueryGuard<'_> { // Extract accumulated inputs. let popped_query = self.complete(); - // If this frame were a cycle participant, it would have unwound. - assert!(popped_query.cycle.is_none()); - popped_query.into_revisions() } - - /// If the active query is registered as a cycle participant, remove and - /// return that cycle. - pub(crate) fn take_cycle(&self) -> Option { - self.local_state - .with_query_stack(|stack| stack.last_mut()?.cycle.take()) - } } impl Drop for ActiveQueryGuard<'_> { diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs index 9d788a2a4..35f42b6bf 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle_fixpoint.rs @@ -1,5 +1,5 @@ -/// Minimal(ish) example use case for fixpoint iteration cycle resolution. -use salsa::{Database as Db, Setter}; +/// Test case for fixpoint iteration cycle resolution. +use salsa::{CycleRecoveryAction, Database as Db, Setter}; use std::collections::BTreeSet; use std::iter::IntoIterator; @@ -59,7 +59,7 @@ fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { } } -#[salsa::tracked] +#[salsa::tracked(cycle_fn=recover_definition_cycle, cycle_initial=initial_definition)] fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { let increment_ty = infer_literal(db, def.increment(db)); if let Some(base) = def.base(db) { @@ -70,6 +70,24 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } +fn initial_definition<'db>(_db: &'db dyn Db) -> Type { + Type::Bottom +} + +fn recover_definition_cycle<'db>(_db: &'db dyn Db, value: Type) -> CycleRecoveryAction { + match value { + Type::Bottom => CycleRecoveryAction::Iterate, + Type::Values(values) => { + if values.len() > 4 { + CycleRecoveryAction::Fallback(Type::Top) + } else { + CycleRecoveryAction::Iterate + } + } + Type::Top => CycleRecoveryAction::Iterate, + } +} + fn add(a: &Type, b: &Type) -> Type { match (a, b) { (Type::Bottom, _) | (_, Type::Bottom) => panic!("unbound use"), @@ -183,7 +201,7 @@ fn multi_symbol_cycle_converges() { let x_ty = infer_use(&db, use_x); let y_ty = infer_use(&db, use_y); - // Both symbols converge on LiteralInt(0) + // Both symbols converge on 0 assert_eq!(x_ty, Type::Values(Box::from([0]))); assert_eq!(y_ty, Type::Values(Box::from([0]))); } diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index e01e46546..ed895948a 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,9 +1,5 @@ mod setup; mod parallel_cancellation; -mod parallel_cycle_all_recover; -mod parallel_cycle_mid_recover; -mod parallel_cycle_none_recover; -mod parallel_cycle_one_recover; mod parallel_map; mod signal; From 6ce61c6d9791fb8f6059016701a47fa9a8cc460f Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 11 Oct 2024 18:53:17 -0700 Subject: [PATCH 06/35] WIP: added provisional value and cycle fields --- src/cycle.rs | 11 ++++++++ src/function/fetch.rs | 6 ++++- src/function/maybe_changed_after.rs | 1 + src/function/memo.rs | 39 +++++++++++++++++++++++++++-- src/function/specify.rs | 1 + src/table/sync.rs | 28 +++++++++++++++++++++ 6 files changed, 83 insertions(+), 3 deletions(-) diff --git a/src/cycle.rs b/src/cycle.rs index 55ce28888..83c88da2a 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,3 +1,5 @@ +use crate::DatabaseKeyIndex; + /// Return value from a 
cycle recovery function. #[derive(Debug)] pub enum CycleRecoveryAction { @@ -24,3 +26,12 @@ pub enum CycleRecoveryStrategy { /// function and initial value. Recover, } + +/// A query cycle. +#[derive(Clone, Copy, Debug)] +pub(crate) struct Cycle { + /// The head of the cycle. + /// + /// The query whose execution ultimately resulted in calling itself again. + head: DatabaseKeyIndex, +} diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f6d495dff..ed5d364da 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,4 +1,7 @@ -use super::{memo::Memo, Configuration, IngredientImpl}; +use super::{ + memo::{Memo, Value}, + Configuration, IngredientImpl, +}; use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; @@ -71,6 +74,7 @@ where zalsa_local, database_key_index, self.memo_ingredient_index, + self.initial_value(db), )?; // Push the query on the stack. diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index b1d671a36..475fbe58e 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -58,6 +58,7 @@ where zalsa_local, database_key_index, self.memo_ingredient_index, + self.initial_value(db), )?; let active_query = zalsa_local.push_query(database_key_index); diff --git a/src/function/memo.rs b/src/function/memo.rs index f982dc339..0fda1e64d 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -7,8 +7,12 @@ use crossbeam::atomic::AtomicCell; use crate::zalsa_local::QueryOrigin; use crate::{ - key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, - Revision, + cycle::{Cycle, CycleRecoveryStrategy}, + key::DatabaseKeyIndex, + table::sync::ProvisionalValue, + zalsa::Zalsa, + zalsa_local::QueryRevisions, + Event, EventKind, Id, Revision, }; use super::{Configuration, IngredientImpl}; @@ -86,6 +90,24 @@ impl IngredientImpl { } } } + + pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option { + match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Recover => Some(self.to_provisional_value(C::cycle_initial(db))), + CycleRecoveryStrategy::Panic => None, + } + } + + fn to_provisional_value<'db>(&'db self, value: C::Output<'db>) -> ProvisionalValue { + unsafe { self.value_to_static(Value { value }) }.into() + } + + unsafe fn value_to_static<'db>( + &'db self, + value: OutputValue<'db, C>, + ) -> OutputValue<'static, C> { + unsafe { std::mem::transmute(value) } + } } #[derive(Debug)] @@ -99,6 +121,9 @@ pub(super) struct Memo { /// Revision information pub(super) revisions: QueryRevisions, + + /// Cycle, if this result was created in cycle iteration + pub(super) cycle: Option, } impl Memo { @@ -107,6 +132,7 @@ impl Memo { value, verified_at: AtomicCell::new(revision_now), revisions, + cycle: None, } } /// True if this memo is known not to have changed based on its durability. 
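// Note on the `unsafe` lifetime erasure above: the provisional value produced by
// `C::cycle_initial` has to be parked in the sync table (which is not parameterized
// over `'db`) before the query body runs, so its `'db` output is transmuted to
// `'static` and boxed into a `ProvisionalValue`. The `Value` wrapper introduced in
// the next hunk is the concrete type behind that box.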
@@ -181,3 +207,12 @@ impl crate::table::memo::Memo for Memo { &self.revisions.origin } } + +type OutputValue<'lt, C> = Value<::Output<'lt>>; + +#[derive(Debug)] +pub(super) struct Value { + value: V, +} + +impl crate::table::sync::Value for Value {} diff --git a/src/function/specify.rs b/src/function/specify.rs index 9eccad65b..e56a09bfb 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -81,6 +81,7 @@ where value: Some(value), verified_at: AtomicCell::new(revision), revisions, + cycle: None, }; tracing::debug!( diff --git a/src/table/sync.rs b/src/table/sync.rs index dfe78a23a..8ef44fe4d 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -1,4 +1,6 @@ use std::{ + any::Any, + fmt::Debug, sync::atomic::{AtomicBool, Ordering}, thread::ThreadId, }; @@ -28,6 +30,29 @@ struct SyncState { /// Set to true if any other queries are blocked, /// waiting for this query to complete. anyone_waiting: AtomicBool, + + /// Provisional return value for fixpoint iteration + /// of a query, set in advance of query execution for + /// queries that anticipate possible cycles. + provisional_value: Option, +} + +pub(crate) trait Value: Any + Send + Sync + Debug {} + +/// Provisional value for a query, in case of fixpoint cycle iterator. +pub(crate) struct ProvisionalValue { + value: Box, +} + +impl From for ProvisionalValue +where + T: Value, +{ + fn from(value: T) -> Self { + Self { + value: Box::new(value), + } + } } impl SyncTable { @@ -37,6 +62,7 @@ impl SyncTable { zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, + provisional_value: Option, ) -> Option> { let mut syncs = self.syncs.write(); let zalsa = db.zalsa(); @@ -49,6 +75,7 @@ impl SyncTable { syncs[memo_ingredient_index.as_usize()] = Some(SyncState { id: thread_id, anyone_waiting: AtomicBool::new(false), + provisional_value, }); Some(ClaimGuard { database_key_index, @@ -60,6 +87,7 @@ impl SyncTable { Some(SyncState { id: other_id, anyone_waiting, + provisional_value: _, }) => { // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this From 29d110ef7b7354dc64dae5cd56557d7cc4697a8d Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 16 Oct 2024 17:14:51 -0700 Subject: [PATCH 07/35] rename to CycleRecoveryStrategy::Fixpoint --- components/salsa-macros/src/tracked_fn.rs | 2 +- src/cycle.rs | 2 +- src/function/memo.rs | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index caa4c70d4..897a7c33d 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -169,7 +169,7 @@ impl Macro { (Some(cycle_fn), Some(cycle_initial)) => Ok(( quote!((#cycle_fn)), quote!((#cycle_initial)), - quote!(Recover), + quote!(Fixpoint), )), (None, None) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), diff --git a/src/cycle.rs b/src/cycle.rs index 83c88da2a..b0d512bfb 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -24,7 +24,7 @@ pub enum CycleRecoveryStrategy { /// /// This choice is computed by the query's `cycle_recovery` /// function and initial value. - Recover, + Fixpoint, } /// A query cycle. 
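// For reference, the attribute shape that the macro arm above maps to the
// `Fixpoint` strategy (mirroring the tracked functions in the `cycle_fixpoint`
// test; illustrative only):
//
//     #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)]
//     fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { .. }
//
// Supplying both `cycle_fn` and `cycle_initial` selects
// `CycleRecoveryStrategy::Fixpoint`; supplying neither keeps the default `Panic`
// strategy.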
diff --git a/src/function/memo.rs b/src/function/memo.rs index 0fda1e64d..4578f240d 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -93,7 +93,9 @@ impl IngredientImpl { pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option { match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Recover => Some(self.to_provisional_value(C::cycle_initial(db))), + CycleRecoveryStrategy::Fixpoint => { + Some(self.to_provisional_value(C::cycle_initial(db))) + } CycleRecoveryStrategy::Panic => None, } } From adb6c771acfb9450ec320c2fe8cf81922669b4ff Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 22 Oct 2024 17:40:43 -0700 Subject: [PATCH 08/35] WIP: rip out ProvisionalValue --- src/function/fetch.rs | 6 +----- src/function/maybe_changed_after.rs | 1 - src/function/memo.rs | 27 ++------------------------- src/table/sync.rs | 26 -------------------------- 4 files changed, 3 insertions(+), 57 deletions(-) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index ed5d364da..f6d495dff 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,7 +1,4 @@ -use super::{ - memo::{Memo, Value}, - Configuration, IngredientImpl, -}; +use super::{memo::Memo, Configuration, IngredientImpl}; use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; @@ -74,7 +71,6 @@ where zalsa_local, database_key_index, self.memo_ingredient_index, - self.initial_value(db), )?; // Push the query on the stack. diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 475fbe58e..b1d671a36 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -58,7 +58,6 @@ where zalsa_local, database_key_index, self.memo_ingredient_index, - self.initial_value(db), )?; let active_query = zalsa_local.push_query(database_key_index); diff --git a/src/function/memo.rs b/src/function/memo.rs index 4578f240d..de0c7bde2 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -9,7 +9,6 @@ use crate::zalsa_local::QueryOrigin; use crate::{ cycle::{Cycle, CycleRecoveryStrategy}, key::DatabaseKeyIndex, - table::sync::ProvisionalValue, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, Revision, @@ -91,25 +90,12 @@ impl IngredientImpl { } } - pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option { + pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option> { match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Fixpoint => { - Some(self.to_provisional_value(C::cycle_initial(db))) - } + CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db)), CycleRecoveryStrategy::Panic => None, } } - - fn to_provisional_value<'db>(&'db self, value: C::Output<'db>) -> ProvisionalValue { - unsafe { self.value_to_static(Value { value }) }.into() - } - - unsafe fn value_to_static<'db>( - &'db self, - value: OutputValue<'db, C>, - ) -> OutputValue<'static, C> { - unsafe { std::mem::transmute(value) } - } } #[derive(Debug)] @@ -209,12 +195,3 @@ impl crate::table::memo::Memo for Memo { &self.revisions.origin } } - -type OutputValue<'lt, C> = Value<::Output<'lt>>; - -#[derive(Debug)] -pub(super) struct Value { - value: V, -} - -impl crate::table::sync::Value for Value {} diff --git a/src/table/sync.rs b/src/table/sync.rs index 8ef44fe4d..ee72353c6 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -30,29 +30,6 @@ struct SyncState { /// Set to true if any other queries are blocked, /// waiting for this query to 
complete. anyone_waiting: AtomicBool, - - /// Provisional return value for fixpoint iteration - /// of a query, set in advance of query execution for - /// queries that anticipate possible cycles. - provisional_value: Option, -} - -pub(crate) trait Value: Any + Send + Sync + Debug {} - -/// Provisional value for a query, in case of fixpoint cycle iterator. -pub(crate) struct ProvisionalValue { - value: Box, -} - -impl From for ProvisionalValue -where - T: Value, -{ - fn from(value: T) -> Self { - Self { - value: Box::new(value), - } - } } impl SyncTable { @@ -62,7 +39,6 @@ impl SyncTable { zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, - provisional_value: Option, ) -> Option> { let mut syncs = self.syncs.write(); let zalsa = db.zalsa(); @@ -75,7 +51,6 @@ impl SyncTable { syncs[memo_ingredient_index.as_usize()] = Some(SyncState { id: thread_id, anyone_waiting: AtomicBool::new(false), - provisional_value, }); Some(ClaimGuard { database_key_index, @@ -87,7 +62,6 @@ impl SyncTable { Some(SyncState { id: other_id, anyone_waiting, - provisional_value: _, }) => { // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this From 7f82b6d1831d52979cec36da1f2ca3f624104180 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 09:58:20 -0700 Subject: [PATCH 09/35] WIP: working single-iteration with provisional memo --- src/active_query.rs | 9 +++++++- src/cycle.rs | 11 ---------- src/function/execute.rs | 23 ++++++++++++++++++--- src/function/fetch.rs | 1 + src/function/maybe_changed_after.rs | 2 +- src/function/memo.rs | 12 +++-------- src/function/specify.rs | 2 +- src/input.rs | 1 + src/interned.rs | 1 + src/tracked_struct.rs | 1 + src/zalsa_local.rs | 32 +++++++++++++++++++++++++---- tests/cycle_fixpoint.rs | 10 ++++----- 12 files changed, 70 insertions(+), 35 deletions(-) diff --git a/src/active_query.rs b/src/active_query.rs index adf544cc2..036fce6c9 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -1,4 +1,4 @@ -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use super::zalsa_local::{EdgeKind, QueryEdges, QueryOrigin, QueryRevisions}; use crate::tracked_struct::IdentityHash; @@ -51,6 +51,9 @@ pub(crate) struct ActiveQuery { /// Stores the values accumulated to the given ingredient. /// The type of accumulated value is erased but known to the ingredient. pub(crate) accumulated: AccumulatedMap, + + /// Heads of cycles that this result was generated inside. 
+ pub(crate) cycle_heads: FxHashSet, } impl ActiveQuery { @@ -64,6 +67,7 @@ impl ActiveQuery { disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), accumulated: Default::default(), + cycle_heads: Default::default(), } } @@ -73,11 +77,13 @@ impl ActiveQuery { durability: Durability, revision: Revision, accumulated: InputAccumulatedValues, + cycle_heads: &FxHashSet, ) { self.input_outputs.insert((EdgeKind::Input, input)); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(revision); self.accumulated.add_input(accumulated); + self.cycle_heads.extend(cycle_heads); } pub(super) fn add_untracked_read(&mut self, changed_at: Revision) { @@ -123,6 +129,7 @@ impl ActiveQuery { durability: self.durability, tracked_struct_ids: self.tracked_struct_ids, accumulated: self.accumulated, + cycle_heads: self.cycle_heads, } } diff --git a/src/cycle.rs b/src/cycle.rs index b0d512bfb..c90f2170b 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,5 +1,3 @@ -use crate::DatabaseKeyIndex; - /// Return value from a cycle recovery function. #[derive(Debug)] pub enum CycleRecoveryAction { @@ -26,12 +24,3 @@ pub enum CycleRecoveryStrategy { /// function and initial value. Fixpoint, } - -/// A query cycle. -#[derive(Clone, Copy, Debug)] -pub(crate) struct Cycle { - /// The head of the cycle. - /// - /// The query whose execution ultimately resulted in calling itself again. - head: DatabaseKeyIndex, -} diff --git a/src/function/execute.rs b/src/function/execute.rs index 11b283003..3b34bce90 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,6 +1,10 @@ use std::sync::Arc; -use crate::{zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Database, Event, EventKind}; +use crate::{ + zalsa::ZalsaDatabase, + zalsa_local::{ActiveQueryGuard, QueryRevisions}, + Database, Event, EventKind, +}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -26,6 +30,7 @@ where let zalsa = db.zalsa(); let revision_now = zalsa.current_revision(); let database_key_index = active_query.database_key_index; + let id = database_key_index.key_index; tracing::info!("{:?}: executing query", database_key_index); @@ -36,6 +41,20 @@ where }, }); + // If this tracked function supports fixpoint iteration, pre-insert a provisional-value + // memo for its initial iteration value. + if let Some(initial_value) = self.initial_value(db) { + self.insert_memo( + zalsa, + id, + Memo::new( + Some(initial_value), + revision_now, + QueryRevisions::fixpoint_initial(database_key_index), + ), + ); + } + // If we already executed this query once, then use the tracked-struct ids from the // previous execution as the starting point for the new one. if let Some(old_memo) = &opt_old_memo { @@ -44,8 +63,6 @@ where // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! 
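// The provisional memo inserted above is what a re-entrant call to this same query
// will see: `QueryRevisions::fixpoint_initial` marks the query itself as a cycle
// head, so any result computed on top of that provisional value carries this query
// in its `cycle_heads` and is itself treated as provisional until iteration settles.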
- let database_key_index = active_query.database_key_index; - let id = database_key_index.key_index; let value = C::execute(db, C::id_to_input(db, id)); let mut revisions = active_query.pop(); diff --git a/src/function/fetch.rs b/src/function/fetch.rs index f6d495dff..79df58ffe 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -26,6 +26,7 @@ where durability, changed_at, InputAccumulatedValues::from_map(&memo.revisions.accumulated), + &memo.revisions.cycle_heads, ); value diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index b1d671a36..ea4ac9dc0 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -166,7 +166,7 @@ where // in rev 1 but not in rev 2. return false; } - QueryOrigin::BaseInput => { + QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { // This value was `set` by the mutator thread -- ie, it's a base input and it cannot be out of date. return true; } diff --git a/src/function/memo.rs b/src/function/memo.rs index de0c7bde2..8bddd5fa1 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -7,10 +7,7 @@ use crossbeam::atomic::AtomicCell; use crate::zalsa_local::QueryOrigin; use crate::{ - cycle::{Cycle, CycleRecoveryStrategy}, - key::DatabaseKeyIndex, - zalsa::Zalsa, - zalsa_local::QueryRevisions, + cycle::CycleRecoveryStrategy, key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, Revision, }; @@ -71,7 +68,8 @@ impl IngredientImpl { match memo.revisions.origin { QueryOrigin::Assigned(_) | QueryOrigin::DerivedUntracked(_) - | QueryOrigin::BaseInput => { + | QueryOrigin::BaseInput + | QueryOrigin::FixpointInitial => { // Careful: Cannot evict memos whose values were // assigned as output of another query // or those with untracked inputs @@ -109,9 +107,6 @@ pub(super) struct Memo { /// Revision information pub(super) revisions: QueryRevisions, - - /// Cycle, if this result was created in cycle iteration - pub(super) cycle: Option, } impl Memo { @@ -120,7 +115,6 @@ impl Memo { value, verified_at: AtomicCell::new(revision_now), revisions, - cycle: None, } } /// True if this memo is known not to have changed based on its durability. 
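// How provisional-ness propagates between queries: every read of a memo now also
// reports that memo's `cycle_heads`, and `ActiveQuery::add_read` unions them into
// the reading query's own set. Conceptually (condensed sketch of the calls added in
// this patch, not additional behavior):
//
//     // in fetch: reading Q's memo forwards Q's cycle heads to the caller
//     zalsa_local.report_tracked_read(input, durability, changed_at,
//                                     accumulated, &memo.revisions.cycle_heads);
//     // in ActiveQuery::add_read: the caller becomes provisional too
//     self.cycle_heads.extend(cycle_heads);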
diff --git a/src/function/specify.rs b/src/function/specify.rs index e56a09bfb..fa5e04278 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -70,6 +70,7 @@ where origin: QueryOrigin::Assigned(active_query_key), tracked_struct_ids: Default::default(), accumulated: Default::default(), + cycle_heads: Default::default(), }; if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { @@ -81,7 +82,6 @@ where value: Some(value), verified_at: AtomicCell::new(revision), revisions, - cycle: None, }; tracing::debug!( diff --git a/src/input.rs b/src/input.rs index fd3726def..e55891616 100644 --- a/src/input.rs +++ b/src/input.rs @@ -190,6 +190,7 @@ impl IngredientImpl { stamp.durability, stamp.changed_at, InputAccumulatedValues::Empty, + &Default::default(), ); &value.fields } diff --git a/src/interned.rs b/src/interned.rs index ba58b512d..7b24c04e7 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -135,6 +135,7 @@ where Durability::MAX, self.reset_at, InputAccumulatedValues::Empty, + &Default::default(), ); // Optimisation to only get read lock on the map if the data has already diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 367cc36cc..95e2a71b1 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -563,6 +563,7 @@ where data.durability, field_changed_at, InputAccumulatedValues::Empty, + &Default::default(), ); unsafe { self.to_self_ref(&data.fields) } diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 2d9bf9f12..25d50353c 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,4 +1,4 @@ -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use tracing::debug; use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}; @@ -166,6 +166,7 @@ impl ZalsaLocal { durability: Durability, changed_at: Revision, accumulated: InputAccumulatedValues, + cycle_heads: &FxHashSet, ) { debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", @@ -173,7 +174,7 @@ impl ZalsaLocal { ); self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { - top_query.add_read(input, durability, changed_at, accumulated); + top_query.add_read(input, durability, changed_at, accumulated, cycle_heads); } }) } @@ -330,9 +331,25 @@ pub(crate) struct QueryRevisions { pub(super) tracked_struct_ids: FxHashMap, pub(super) accumulated: AccumulatedMap, + + /// Active cycle heads, if this result was created in cycle iteration + pub(super) cycle_heads: FxHashSet, } impl QueryRevisions { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self { + let mut cycle_heads = FxHashSet::default(); + cycle_heads.insert(query); + Self { + changed_at: Revision::start(), + durability: Durability::MAX, + origin: QueryOrigin::FixpointInitial, + tracked_struct_ids: Default::default(), + accumulated: Default::default(), + cycle_heads, + } + } + pub(crate) fn stamped_value(&self, value: V) -> StampedValue { self.stamp_template().stamp(value) } @@ -381,6 +398,9 @@ pub enum QueryOrigin { /// The [`QueryEdges`] argument contains a listing of all the inputs we saw /// (but we know there were more). DerivedUntracked(QueryEdges), + + /// The value is an initial provisional value for a query that supports fixpoint iteration. 
+ FixpointInitial, } impl QueryOrigin { @@ -388,7 +408,9 @@ impl QueryOrigin { pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + None + } }; opt_edges.into_iter().flat_map(|edges| edges.inputs()) } @@ -397,7 +419,9 @@ impl QueryOrigin { pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) | QueryOrigin::BaseInput => None, + QueryOrigin::Assigned(_) | QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + None + } }; opt_edges.into_iter().flat_map(|edges| edges.outputs()) } diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs index 35f42b6bf..bccc01093 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle_fixpoint.rs @@ -49,7 +49,7 @@ impl Type { } } -#[salsa::tracked] +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)] fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { let defs = u.reaching_definitions(db); match defs[..] { @@ -59,7 +59,7 @@ fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { } } -#[salsa::tracked(cycle_fn=recover_definition_cycle, cycle_initial=initial_definition)] +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)] fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { let increment_ty = infer_literal(db, def.increment(db)); if let Some(base) = def.base(db) { @@ -70,11 +70,11 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn initial_definition<'db>(_db: &'db dyn Db) -> Type { +fn cycle_initial<'db>(_db: &'db dyn Db) -> Type { Type::Bottom } -fn recover_definition_cycle<'db>(_db: &'db dyn Db, value: Type) -> CycleRecoveryAction { +fn cycle_recover<'db>(_db: &'db dyn Db, value: Type) -> CycleRecoveryAction { match value { Type::Bottom => CycleRecoveryAction::Iterate, Type::Values(values) => { @@ -90,7 +90,7 @@ fn recover_definition_cycle<'db>(_db: &'db dyn Db, value: Type) -> CycleRecovery fn add(a: &Type, b: &Type) -> Type { match (a, b) { - (Type::Bottom, _) | (_, Type::Bottom) => panic!("unbound use"), + (Type::Bottom, _) | (_, Type::Bottom) => Type::Bottom, (Type::Top, _) | (_, Type::Top) => Type::Top, (Type::Values(a_ints), Type::Values(b_ints)) => { let mut set = BTreeSet::new(); From a7be3d9cdf78834a9e8461b90527d3c6b76e668e Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 10:06:21 -0700 Subject: [PATCH 10/35] WIP: add count arg to cycle_recovery_fn --- components/salsa-macro-rules/src/setup_tracked_fn.rs | 3 ++- .../salsa-macro-rules/src/unexpected_cycle_recovery.rs | 2 +- src/function.rs | 1 + tests/cycle_fixpoint.rs | 6 +++--- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index eed83c4ed..ef5ccb9a6 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -184,8 +184,9 @@ macro_rules! 
setup_tracked_fn { fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, value: Self::Output<$db_lt>, + count: u32, ) -> $zalsa::CycleRecoveryAction> { - $($cycle_recovery_fn)*(db, value) + $($cycle_recovery_fn)*(db, value, count) } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index e556e104e..8e18a8976 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -2,7 +2,7 @@ // for the case where no cycle recovery is possible. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $value:ident) => {{ + ($db:ident, $value:ident, $count:ident) => {{ std::mem::drop($db); panic!("cannot recover from cycle") }}; diff --git a/src/function.rs b/src/function.rs index 82c7384e5..f23e44f52 100644 --- a/src/function.rs +++ b/src/function.rs @@ -74,6 +74,7 @@ pub trait Configuration: Any { fn recover_from_cycle<'db>( db: &'db Self::DbView, value: Self::Output<'db>, + count: u32, ) -> CycleRecoveryAction>; } diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs index bccc01093..6d9d00ce6 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle_fixpoint.rs @@ -74,11 +74,11 @@ fn cycle_initial<'db>(_db: &'db dyn Db) -> Type { Type::Bottom } -fn cycle_recover<'db>(_db: &'db dyn Db, value: Type) -> CycleRecoveryAction { +fn cycle_recover<'db>(_db: &'db dyn Db, value: Type, count: u32) -> CycleRecoveryAction { match value { Type::Bottom => CycleRecoveryAction::Iterate, - Type::Values(values) => { - if values.len() > 4 { + Type::Values(_) => { + if count > 4 { CycleRecoveryAction::Fallback(Type::Top) } else { CycleRecoveryAction::Iterate From 6c6dd55cc5afaba378861cf9e827c711d7696fc2 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 10:23:22 -0700 Subject: [PATCH 11/35] WIP: move insert-initial out into fetch_cold --- src/function/execute.rs | 20 +------------------- src/function/fetch.rs | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 3b34bce90..b1247d14d 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,10 +1,6 @@ use std::sync::Arc; -use crate::{ - zalsa::ZalsaDatabase, - zalsa_local::{ActiveQueryGuard, QueryRevisions}, - Database, Event, EventKind, -}; +use crate::{zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Database, Event, EventKind}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -41,20 +37,6 @@ where }, }); - // If this tracked function supports fixpoint iteration, pre-insert a provisional-value - // memo for its initial iteration value. - if let Some(initial_value) = self.initial_value(db) { - self.insert_memo( - zalsa, - id, - Memo::new( - Some(initial_value), - revision_now, - QueryRevisions::fixpoint_initial(database_key_index), - ), - ); - } - // If we already executed this query once, then use the tracked-struct ids from the // previous execution as the starting point for the new one. 
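// (Re the `count` argument threaded through `recover_from_cycle` above: it gives a
// recovery function a way to bound iteration, as the test's `cycle_recover` does by
// returning `Fallback(Type::Top)` once `count > 4`, so a cycle that would otherwise
// diverge still terminates.)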
if let Some(old_memo) = &opt_old_memo { diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 79df58ffe..9cf9b0c2c 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,6 +1,9 @@ use super::{memo::Memo, Configuration, IngredientImpl}; use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; +use crate::{ + runtime::StampedValue, zalsa::ZalsaDatabase, zalsa_local::QueryRevisions, AsDynDatabase as _, + Id, +}; impl IngredientImpl where @@ -89,6 +92,18 @@ where } } + if let Some(initial_value) = self.initial_value(db) { + self.insert_memo( + zalsa, + id, + Memo::new( + Some(initial_value), + zalsa.current_revision(), + QueryRevisions::fixpoint_initial(database_key_index), + ), + ); + } + Some(self.execute(db, active_query, opt_old_memo)) } } From 8974361b65ec88bdd90cbbe8ac049de3ddce333a Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 11:40:11 -0700 Subject: [PATCH 12/35] WIP: cycle-head iteration --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 6 +- src/function.rs | 17 +++-- src/function/backdate.rs | 2 +- src/function/fetch.rs | 67 +++++++++++++++---- src/function/maybe_changed_after.rs | 3 +- src/function/memo.rs | 4 ++ src/lib.rs | 2 +- src/zalsa_local.rs | 3 +- tests/cycle_fixpoint.rs | 2 +- 9 files changed, 74 insertions(+), 32 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index ef5ccb9a6..aad76935d 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -158,7 +158,7 @@ macro_rules! setup_tracked_fn { const CYCLE_STRATEGY: $zalsa::CycleRecoveryStrategy = $zalsa::CycleRecoveryStrategy::$cycle_recovery_strategy; - fn should_backdate_value( + fn values_equal( old_value: &Self::Output<'_>, new_value: &Self::Output<'_>, ) -> bool { @@ -166,7 +166,7 @@ macro_rules! setup_tracked_fn { if $no_eq { false } else { - $zalsa::should_backdate_value(old_value, new_value) + $zalsa::values_equal(old_value, new_value) } } } @@ -183,7 +183,7 @@ macro_rules! setup_tracked_fn { fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - value: Self::Output<$db_lt>, + value: &$db_lt Self::Output<$db_lt>, count: u32, ) -> $zalsa::CycleRecoveryAction> { $($cycle_recovery_fn)*(db, value, count) diff --git a/src/function.rs b/src/function.rs index f23e44f52..c8af0d32a 100644 --- a/src/function.rs +++ b/src/function.rs @@ -49,13 +49,12 @@ pub trait Configuration: Any { /// (and, if so, how). const CYCLE_STRATEGY: CycleRecoveryStrategy; - /// Invokes after a new result `new_value`` has been computed for which an older memoized - /// value existed `old_value`. Returns true if the new value is equal to the older one - /// and hence should be "backdated" (i.e., marked as having last changed in an older revision, - /// even though it was recomputed). + /// Invokes after a new result `new_value`` has been computed for which an older memoized value + /// existed `old_value`, or in fixpoint iteration. Returns true if the new value is equal to + /// the older one. /// - /// This invokes user's code in form of the `Eq` impl. - fn should_backdate_value(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; + /// This invokes user code in form of the `Eq` impl. 
+ fn values_equal(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; /// Convert from the id used internally to the value that execute is expecting. /// This is a no-op if the input to the function is a salsa struct. @@ -73,7 +72,7 @@ pub trait Configuration: Any { /// Decide whether to iterate a cycle again or fallback. fn recover_from_cycle<'db>( db: &'db Self::DbView, - value: Self::Output<'db>, + value: &'db Self::Output<'db>, count: u32, ) -> CycleRecoveryAction>; } @@ -116,9 +115,9 @@ pub struct IngredientImpl { } /// True if `old_value == new_value`. Invoked by the generated -/// code for `should_backdate_value` so as to give a better +/// code for `values_equal` so as to give a better /// error message. -pub fn should_backdate_value(old_value: &V, new_value: &V) -> bool { +pub fn values_equal(old_value: &V, new_value: &V) -> bool { old_value == new_value } diff --git a/src/function/backdate.rs b/src/function/backdate.rs index bfca6f050..7eff1b3d4 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -21,7 +21,7 @@ where // consumers must be aware of. Becoming *more* durable // is not. See the test `constant_to_non_constant`. if revisions.durability >= old_memo.revisions.durability - && C::should_backdate_value(old_value, value) + && C::values_equal(old_value, value) { tracing::debug!( "value is equal, back-dating to {:?}", diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 9cf9b0c2c..6e3797683 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -77,33 +77,72 @@ where self.memo_ingredient_index, )?; - // Push the query on the stack. - let active_query = zalsa_local.push_query(database_key_index); - // Now that we've claimed the item, check again to see if there's a "hot" value. - let zalsa = db.zalsa(); let opt_old_memo = self.get_memo_from_table_for(zalsa, id); if let Some(old_memo) = &opt_old_memo { - if old_memo.value.is_some() && self.deep_verify_memo(db, old_memo, &active_query) { - // Unsafety invariant: memo is present in memo_map. - unsafe { - return Some(self.extend_memo_lifetime(old_memo)); + if old_memo.value.is_some() { + let active_query = zalsa_local.push_query(database_key_index); + if self.deep_verify_memo(db, old_memo, &active_query) { + // Unsafety invariant: memo is present in memo_map. 
+ unsafe { + return Some(self.extend_memo_lifetime(old_memo)); + } } } } + let revision_now = zalsa.current_revision(); - if let Some(initial_value) = self.initial_value(db) { - self.insert_memo( + let mut opt_last_provisional = if let Some(initial_value) = self.initial_value(db) { + Some(self.insert_memo( zalsa, id, Memo::new( Some(initial_value), - zalsa.current_revision(), + revision_now, QueryRevisions::fixpoint_initial(database_key_index), ), - ); - } + )) + } else { + None + }; + let mut iteration_count = 0; - Some(self.execute(db, active_query, opt_old_memo)) + loop { + let active_query = zalsa_local.push_query(database_key_index); + let mut result = self.execute(db, active_query, opt_old_memo.clone()); + + if result.in_cycle(database_key_index) { + if let Some(last_provisional) = opt_last_provisional { + match (&result.value, &last_provisional.value) { + (Some(result_value), Some(provisional_value)) + if !C::values_equal(result_value, provisional_value) => + { + match C::recover_from_cycle(db, result_value, iteration_count) { + crate::CycleRecoveryAction::Iterate => { + iteration_count += 1; + opt_last_provisional = Some(result); + continue; + } + crate::CycleRecoveryAction::Fallback(value) => { + result = self.insert_memo( + zalsa, + id, + Memo::new( + Some(value), + revision_now, + result.revisions.clone(), + ), + ); + } + } + } + _ => {} + } + } + // This is no longer a provisional result, it's our real result, so remove ourselves + // from the cycle heads. + } + return Some(result); + } } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index ea4ac9dc0..3e67eda36 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -131,8 +131,7 @@ where /// /// Takes an [`ActiveQueryGuard`] argument because this function recursively /// walks dependencies of `old_memo` and may even execute them to see if their - /// outputs have changed. As that could lead to cycles, it is important that the - /// query is on the stack. + /// outputs have changed. pub(super) fn deep_verify_memo( &self, db: &C::DbView, diff --git a/src/function/memo.rs b/src/function/memo.rs index 8bddd5fa1..4a4dfe256 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -158,6 +158,10 @@ impl Memo { } } + pub(super) fn in_cycle(&self, database_key_index: DatabaseKeyIndex) -> bool { + self.revisions.cycle_heads.contains(&database_key_index) + } + pub(super) fn tracing_debug(&self) -> impl std::fmt::Debug + '_ { struct TracingDebug<'a, T> { memo: &'a Memo, diff --git a/src/lib.rs b/src/lib.rs index 59de46ffc..9e9bcf6dd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,7 @@ pub mod plumbing { pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; pub use crate::database::Database; - pub use crate::function::should_backdate_value; + pub use crate::function::values_equal; pub use crate::id::AsId; pub use crate::id::FromId; pub use crate::id::Id; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 25d50353c..ccbd7d13f 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -332,7 +332,8 @@ pub(crate) struct QueryRevisions { pub(super) accumulated: AccumulatedMap, - /// Active cycle heads, if this result was created in cycle iteration + /// This result was computed based on provisional cycle-iteration + /// results from these queries. 
pub(super) cycle_heads: FxHashSet, } diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs index 6d9d00ce6..196f762da 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle_fixpoint.rs @@ -74,7 +74,7 @@ fn cycle_initial<'db>(_db: &'db dyn Db) -> Type { Type::Bottom } -fn cycle_recover<'db>(_db: &'db dyn Db, value: Type, count: u32) -> CycleRecoveryAction { +fn cycle_recover<'db>(_db: &'db dyn Db, value: &Type, count: u32) -> CycleRecoveryAction { match value { Type::Bottom => CycleRecoveryAction::Iterate, Type::Values(_) => { From 797655e1c305607e66d690a03fefc41d8c650656 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 12:36:18 -0700 Subject: [PATCH 13/35] WIP: move loop into execute --- src/function/execute.rs | 99 ++++++++++++++++++++++------- src/function/fetch.rs | 59 +---------------- src/function/maybe_changed_after.rs | 5 +- 3 files changed, 80 insertions(+), 83 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index b1247d14d..2b2179df4 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,6 +1,8 @@ use std::sync::Arc; -use crate::{zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Database, Event, EventKind}; +use crate::{ + zalsa::ZalsaDatabase, zalsa_local::QueryRevisions, Database, DatabaseKeyIndex, Event, EventKind, +}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -20,12 +22,11 @@ where pub(super) fn execute<'db>( &'db self, db: &'db C::DbView, - active_query: ActiveQueryGuard<'_>, + database_key_index: DatabaseKeyIndex, opt_old_memo: Option>>>, ) -> &'db Memo> { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let revision_now = zalsa.current_revision(); - let database_key_index = active_query.database_key_index; let id = database_key_index.key_index; tracing::info!("{:?}: executing query", database_key_index); @@ -37,28 +38,80 @@ where }, }); - // If we already executed this query once, then use the tracked-struct ids from the - // previous execution as the starting point for the new one. - if let Some(old_memo) = &opt_old_memo { - active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); - } + let mut opt_last_provisional = if let Some(initial_value) = self.initial_value(db) { + Some(self.insert_memo( + zalsa, + id, + Memo::new( + Some(initial_value), + revision_now, + QueryRevisions::fixpoint_initial(database_key_index), + ), + )) + } else { + None + }; - // Query was not previously executed, or value is potentially - // stale, or value is absent. Let's execute! - let value = C::execute(db, C::id_to_input(db, id)); - let mut revisions = active_query.pop(); + let mut iteration_count = 0; - // If the new value is equal to the old one, then it didn't - // really change, even if some of its inputs have. So we can - // "backdate" its `changed_at` revision to be the same as the - // old value. - if let Some(old_memo) = &opt_old_memo { - self.backdate_if_appropriate(old_memo, &mut revisions, &value); - self.diff_outputs(db, database_key_index, old_memo, &mut revisions); - } + loop { + let active_query = zalsa_local.push_query(database_key_index); + + // If we already executed this query once, then use the tracked-struct ids from the + // previous execution as the starting point for the new one. 
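// Shape of the iteration this loop implements (condensed, hedged sketch of the
// control flow below, with `store`, `input`, and `last_provisional` standing in for
// the local bookkeeping; not additional behavior):
//
//     loop {
//         let value = C::execute(db, input);
//         if !revisions.cycle_heads.contains(&database_key_index) { break store(value); }
//         if C::values_equal(&value, last_provisional) { break store(value); }
//         match C::recover_from_cycle(db, &value, iteration_count) {
//             CycleRecoveryAction::Iterate => { last_provisional = store(value); iteration_count += 1; }
//             CycleRecoveryAction::Fallback(v) => break store(v),
//         }
//     }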
+ if let Some(old_memo) = &opt_old_memo { + active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); + } - tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! + let value = C::execute(db, C::id_to_input(db, id)); + let mut revisions = active_query.pop(); - self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)) + // If the new value is equal to the old one, then it didn't + // really change, even if some of its inputs have. So we can + // "backdate" its `changed_at` revision to be the same as the + // old value. + if let Some(old_memo) = &opt_old_memo { + self.backdate_if_appropriate(old_memo, &mut revisions, &value); + self.diff_outputs(db, database_key_index, old_memo, &mut revisions); + } + + let mut result = + self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)); + + if result.in_cycle(database_key_index) { + if let Some(last_provisional) = opt_last_provisional { + match (&result.value, &last_provisional.value) { + (Some(result_value), Some(provisional_value)) + if !C::values_equal(result_value, provisional_value) => + { + match C::recover_from_cycle(db, result_value, iteration_count) { + crate::CycleRecoveryAction::Iterate => { + iteration_count += 1; + opt_last_provisional = Some(result); + continue; + } + crate::CycleRecoveryAction::Fallback(value) => { + result = self.insert_memo( + zalsa, + id, + Memo::new( + Some(value), + revision_now, + result.revisions.clone(), + ), + ); + } + } + } + _ => {} + } + } + // This is no longer a provisional result, it's our real result, so remove ourselves + // from the cycle heads. + } + return result; + } } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 6e3797683..38eff1768 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,9 +1,6 @@ use super::{memo::Memo, Configuration, IngredientImpl}; use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::{ - runtime::StampedValue, zalsa::ZalsaDatabase, zalsa_local::QueryRevisions, AsDynDatabase as _, - Id, -}; +use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; impl IngredientImpl where @@ -90,59 +87,7 @@ where } } } - let revision_now = zalsa.current_revision(); - let mut opt_last_provisional = if let Some(initial_value) = self.initial_value(db) { - Some(self.insert_memo( - zalsa, - id, - Memo::new( - Some(initial_value), - revision_now, - QueryRevisions::fixpoint_initial(database_key_index), - ), - )) - } else { - None - }; - let mut iteration_count = 0; - - loop { - let active_query = zalsa_local.push_query(database_key_index); - let mut result = self.execute(db, active_query, opt_old_memo.clone()); - - if result.in_cycle(database_key_index) { - if let Some(last_provisional) = opt_last_provisional { - match (&result.value, &last_provisional.value) { - (Some(result_value), Some(provisional_value)) - if !C::values_equal(result_value, provisional_value) => - { - match C::recover_from_cycle(db, result_value, iteration_count) { - crate::CycleRecoveryAction::Iterate => { - iteration_count += 1; - opt_last_provisional = Some(result); - continue; - } - crate::CycleRecoveryAction::Fallback(value) => { - result = self.insert_memo( - zalsa, - id, - Memo::new( - Some(value), - revision_now, - result.revisions.clone(), - ), - ); - } - } - } - _ => {} - } - } - // This is no longer a provisional result, it's our 
real result, so remove ourselves - // from the cycle heads. - } - return Some(result); - } + Some(self.execute(db, database_key_index, opt_old_memo)) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 3e67eda36..30b02277a 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -59,8 +59,6 @@ where database_key_index, self.memo_ingredient_index, )?; - let active_query = zalsa_local.push_query(database_key_index); - // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { return Some(true); @@ -73,6 +71,7 @@ where ); // Check if the inputs are still valid and we can just compare `changed_at`. + let active_query = zalsa_local.push_query(database_key_index); if self.deep_verify_memo(db, &old_memo, &active_query) { return Some(old_memo.revisions.changed_at > revision); } @@ -82,7 +81,7 @@ where // backdated. In that case, although we will have computed a new memo, // the value has not logically changed. if old_memo.value.is_some() { - let memo = self.execute(db, active_query, Some(old_memo)); + let memo = self.execute(db, database_key_index, Some(old_memo)); let changed_at = memo.revisions.changed_at; return Some(changed_at > revision); } From 27742a216f979ef6629c6443eca9950949fab905 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 12:51:00 -0700 Subject: [PATCH 14/35] WIP: delay storing memo --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 2 +- src/function.rs | 2 +- src/function/execute.rs | 42 ++++++++----------- 3 files changed, 20 insertions(+), 26 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index aad76935d..184ccb50c 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -183,7 +183,7 @@ macro_rules! setup_tracked_fn { fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - value: &$db_lt Self::Output<$db_lt>, + value: &Self::Output<$db_lt>, count: u32, ) -> $zalsa::CycleRecoveryAction> { $($cycle_recovery_fn)*(db, value, count) diff --git a/src/function.rs b/src/function.rs index c8af0d32a..225a9dead 100644 --- a/src/function.rs +++ b/src/function.rs @@ -72,7 +72,7 @@ pub trait Configuration: Any { /// Decide whether to iterate a cycle again or fallback. fn recover_from_cycle<'db>( db: &'db Self::DbView, - value: &'db Self::Output<'db>, + value: &Self::Output<'db>, count: u32, ) -> CycleRecoveryAction>; } diff --git a/src/function/execute.rs b/src/function/execute.rs index 2b2179df4..54efadfb1 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -65,7 +65,7 @@ where // Query was not previously executed, or value is potentially // stale, or value is absent. Let's execute! - let value = C::execute(db, C::id_to_input(db, id)); + let mut new_value = C::execute(db, C::id_to_input(db, id)); let mut revisions = active_query.pop(); // If the new value is equal to the old one, then it didn't @@ -73,45 +73,39 @@ where // "backdate" its `changed_at` revision to be the same as the // old value. 
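// A minimal standalone sketch of the backdating rule described above (illustrative types,
// not salsa's actual `Memo`): if recomputation yields a value equal to the previous one,
// keep the old `changed_at` so downstream queries still see it as unchanged, and only
// advance `verified_at` to the current revision.

struct MemoSketch<T> {
    value: T,
    changed_at: u64,
    verified_at: u64,
}

fn remember<T: PartialEq>(old: Option<&MemoSketch<T>>, value: T, now: u64) -> MemoSketch<T> {
    let changed_at = match old {
        // Backdate: logically unchanged, so dependents can keep their cached results.
        Some(prev) if prev.value == value => prev.changed_at,
        _ => now,
    };
    MemoSketch { value, changed_at, verified_at: now }
}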
if let Some(old_memo) = &opt_old_memo { - self.backdate_if_appropriate(old_memo, &mut revisions, &value); + self.backdate_if_appropriate(old_memo, &mut revisions, &new_value); self.diff_outputs(db, database_key_index, old_memo, &mut revisions); } - let mut result = - self.insert_memo(zalsa, id, Memo::new(Some(value), revision_now, revisions)); - - if result.in_cycle(database_key_index) { + if revisions.cycle_heads.contains(&database_key_index) { if let Some(last_provisional) = opt_last_provisional { - match (&result.value, &last_provisional.value) { - (Some(result_value), Some(provisional_value)) - if !C::values_equal(result_value, provisional_value) => - { - match C::recover_from_cycle(db, result_value, iteration_count) { + if let Some(provisional_value) = &last_provisional.value { + if !C::values_equal(&new_value, provisional_value) { + match C::recover_from_cycle(db, &new_value, iteration_count) { crate::CycleRecoveryAction::Iterate => { iteration_count += 1; - opt_last_provisional = Some(result); - continue; - } - crate::CycleRecoveryAction::Fallback(value) => { - result = self.insert_memo( + opt_last_provisional = Some(self.insert_memo( zalsa, id, - Memo::new( - Some(value), - revision_now, - result.revisions.clone(), - ), - ); + Memo::new(Some(new_value), revision_now, revisions), + )); + continue; + } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + new_value = fallback_value; } } } - _ => {} } } // This is no longer a provisional result, it's our real result, so remove ourselves // from the cycle heads. } - return result; + return self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + ); } } } From 4f4df72be37784ce4d266b23487c942c6de35a0f Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 13:00:21 -0700 Subject: [PATCH 15/35] WIP: remove ourself from cycle heads when done iterating --- src/active_query.rs | 2 +- src/function/execute.rs | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/active_query.rs b/src/active_query.rs index 036fce6c9..ba5ddc35c 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -52,7 +52,7 @@ pub(crate) struct ActiveQuery { /// The type of accumulated value is erased but known to the ingredient. pub(crate) accumulated: AccumulatedMap, - /// Heads of cycles that this result was generated inside. + /// Provisional cycle results that this query depends on. pub(crate) cycle_heads: FxHashSet, } diff --git a/src/function/execute.rs b/src/function/execute.rs index 54efadfb1..611bb2389 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -77,10 +77,15 @@ where self.diff_outputs(db, database_key_index, old_memo, &mut revisions); } + // Did the new result we got depend on our own provisional value, in a cycle? if revisions.cycle_heads.contains(&database_key_index) { if let Some(last_provisional) = opt_last_provisional { if let Some(provisional_value) = &last_provisional.value { + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. if !C::values_equal(&new_value, provisional_value) { + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: match C::recover_from_cycle(db, &new_value, iteration_count) { crate::CycleRecoveryAction::Iterate => { iteration_count += 1; @@ -98,8 +103,9 @@ where } } } - // This is no longer a provisional result, it's our real result, so remove ourselves - // from the cycle heads. 
+ // This is no longer a provisional result, it's our final result, so remove + // ourself from the cycle heads. + revisions.cycle_heads.remove(&database_key_index); } return self.insert_memo( zalsa, From 315944e3ebe6e4e35d9cec61b78e5a7b821d15a7 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 13:46:10 -0700 Subject: [PATCH 16/35] WIP: working convergence and fallback --- .../src/unexpected_cycle_recovery.rs | 3 ++- src/active_query.rs | 1 + src/function/execute.rs | 13 +++++++++++-- src/function/maybe_changed_after.rs | 6 ++++++ src/function/specify.rs | 1 + src/table/sync.rs | 2 -- src/zalsa_local.rs | 5 +++++ 7 files changed, 26 insertions(+), 5 deletions(-) diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index 8e18a8976..cf8bbce13 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -1,5 +1,6 @@ // Macro that generates the body of the cycle recovery function -// for the case where no cycle recovery is possible. +// for the case where no cycle recovery is possible. Must be a macro +// because the signature types must match the particular tracked function. #[macro_export] macro_rules! unexpected_cycle_recovery { ($db:ident, $value:ident, $count:ident) => {{ diff --git a/src/active_query.rs b/src/active_query.rs index ba5ddc35c..157bec117 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -129,6 +129,7 @@ impl ActiveQuery { durability: self.durability, tracked_struct_ids: self.tracked_struct_ids, accumulated: self.accumulated, + cycle_ignore: !self.cycle_heads.is_empty(), cycle_heads: self.cycle_heads, } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 611bb2389..f6ecaecc2 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -89,6 +89,7 @@ where match C::recover_from_cycle(db, &new_value, iteration_count) { crate::CycleRecoveryAction::Iterate => { iteration_count += 1; + revisions.cycle_ignore = false; opt_last_provisional = Some(self.insert_memo( zalsa, id, @@ -103,9 +104,17 @@ where } } } - // This is no longer a provisional result, it's our final result, so remove - // ourself from the cycle heads. + // This is no longer a provisional result, it's our final result, so remove ourself + // from the cycle heads, and iterate one last time to remove ourself from all other + // results in the cycle as well. 
revisions.cycle_heads.remove(&database_key_index); + revisions.cycle_ignore = false; + self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + ); + continue; } return self.insert_memo( zalsa, diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 30b02277a..0a0eaca77 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -100,6 +100,9 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo>, ) -> bool { + if memo.revisions.cycle_ignore { + return false; + } let verified_at = memo.verified_at.load(); let revision_now = zalsa.current_revision(); @@ -137,6 +140,9 @@ where old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, ) -> bool { + if old_memo.revisions.cycle_ignore { + return false; + } let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; diff --git a/src/function/specify.rs b/src/function/specify.rs index fa5e04278..f5803b3dc 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -71,6 +71,7 @@ where tracked_struct_ids: Default::default(), accumulated: Default::default(), cycle_heads: Default::default(), + cycle_ignore: false, }; if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { diff --git a/src/table/sync.rs b/src/table/sync.rs index ee72353c6..dfe78a23a 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -1,6 +1,4 @@ use std::{ - any::Any, - fmt::Debug, sync::atomic::{AtomicBool, Ordering}, thread::ThreadId, }; diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index ccbd7d13f..86e49f7db 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -335,6 +335,10 @@ pub(crate) struct QueryRevisions { /// This result was computed based on provisional cycle-iteration /// results from these queries. pub(super) cycle_heads: FxHashSet, + + /// True if this result is based on provisional results, and is not itself a cycle head; this + /// should not be used as a cached result. 
+ pub(super) cycle_ignore: bool, } impl QueryRevisions { @@ -348,6 +352,7 @@ impl QueryRevisions { tracked_struct_ids: Default::default(), accumulated: Default::default(), cycle_heads, + cycle_ignore: false, } } From 28242a43d3f05713fcb8ee3f0260ab0467179822 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Wed, 23 Oct 2024 14:08:09 -0700 Subject: [PATCH 17/35] WIP: clippy and cleanup --- src/function/execute.rs | 10 ++++------ src/function/memo.rs | 4 ---- src/key.rs | 9 +-------- src/zalsa.rs | 10 ---------- tests/cycle_fixpoint.rs | 4 ++-- tests/parallel/setup.rs | 6 ------ 6 files changed, 7 insertions(+), 36 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index f6ecaecc2..6b7290a27 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -38,8 +38,8 @@ where }, }); - let mut opt_last_provisional = if let Some(initial_value) = self.initial_value(db) { - Some(self.insert_memo( + let mut opt_last_provisional = self.initial_value(db).map(|initial_value| { + self.insert_memo( zalsa, id, Memo::new( @@ -47,10 +47,8 @@ where revision_now, QueryRevisions::fixpoint_initial(database_key_index), ), - )) - } else { - None - }; + ) + }); let mut iteration_count = 0; diff --git a/src/function/memo.rs b/src/function/memo.rs index 4a4dfe256..8bddd5fa1 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -158,10 +158,6 @@ impl Memo { } } - pub(super) fn in_cycle(&self, database_key_index: DatabaseKeyIndex) -> bool { - self.revisions.cycle_heads.contains(&database_key_index) - } - pub(super) fn tracing_debug(&self) -> impl std::fmt::Debug + '_ { struct TracingDebug<'a, T> { memo: &'a Memo, diff --git a/src/key.rs b/src/key.rs index 92e63541d..de84f710e 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,7 +1,4 @@ -use crate::{ - accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, - zalsa::IngredientIndex, Database, Id, -}; +use crate::{accumulator::accumulated_map::AccumulatedMap, zalsa::IngredientIndex, Database, Id}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. 
Fully ordered and @@ -96,10 +93,6 @@ impl DatabaseKeyIndex { self.key_index } - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - self.ingredient_index.cycle_recovery_strategy(db) - } - pub(crate) fn accumulated(self, db: &dyn Database) -> Option<&AccumulatedMap> { db.zalsa() .lookup_ingredient(self.ingredient_index) diff --git a/src/zalsa.rs b/src/zalsa.rs index e92e28919..04ee1c405 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -5,7 +5,6 @@ use std::any::{Any, TypeId}; use std::marker::PhantomData; use std::thread::ThreadId; -use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{Ingredient, Jar, JarAux}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::{Runtime, WaitResult}; @@ -86,18 +85,9 @@ impl IngredientIndex { self.0 as usize } - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - db.zalsa().lookup_ingredient(self).cycle_recovery_strategy() - } - pub fn successor(self, index: usize) -> Self { IngredientIndex(self.0 + 1 + index as u32) } - - /// Return the "debug name" of this ingredient (e.g., the name of the tracked struct it represents) - pub(crate) fn debug_name(self, db: &dyn Database) -> &'static str { - db.zalsa().lookup_ingredient(self).debug_name() - } } /// A special secondary index *just* for ingredients that attach diff --git a/tests/cycle_fixpoint.rs b/tests/cycle_fixpoint.rs index 196f762da..b792741b9 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle_fixpoint.rs @@ -70,11 +70,11 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn cycle_initial<'db>(_db: &'db dyn Db) -> Type { +fn cycle_initial(_db: &dyn Db) -> Type { Type::Bottom } -fn cycle_recover<'db>(_db: &'db dyn Db, value: &Type, count: u32) -> CycleRecoveryAction { +fn cycle_recover(_db: &dyn Db, value: &Type, count: u32) -> CycleRecoveryAction { match value { Type::Bottom => CycleRecoveryAction::Iterate, Type::Values(_) => { diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index 56d204eea..c266731a0 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -9,8 +9,6 @@ use crate::signal::Signal; /// a certain behavior. #[salsa::db] pub(crate) trait KnobsDatabase: Database { - fn knobs(&self) -> &Knobs; - fn signal(&self, stage: usize); fn wait_for(&self, stage: usize); @@ -68,10 +66,6 @@ impl salsa::Database for Knobs { #[salsa::db] impl KnobsDatabase for Knobs { - fn knobs(&self) -> &Knobs { - self - } - fn signal(&self, stage: usize) { self.signal.signal(stage); } From 0be825ab20441ea1fff09007bfe6e4558ad47cbe Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 24 Oct 2024 12:39:24 -0700 Subject: [PATCH 18/35] WIP: improve comments and add a type annotation --- src/function/execute.rs | 4 +++- src/zalsa_local.rs | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 6b7290a27..00e85067f 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -38,6 +38,8 @@ where }, }); + // If this query supports fixpoint iteration, populate the memo table with our initial + // value, in case some other query we call ends up calling us back. 
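// The initial value inserted here comes from the user's `cycle_initial` function. A minimal
// sketch of the user-facing side this enables, using the `cycle_fn`/`cycle_initial` tracked-fn
// arguments introduced earlier in this series (the names below are illustrative, not from the
// patch itself):

#[salsa::input]
struct Limit {
    bound: u32,
}

fn lower_initial(_db: &dyn salsa::Database) -> u32 {
    u32::MAX // optimistic starting point; iteration can only bring it down
}

fn lower_recover(
    _db: &dyn salsa::Database,
    _value: &u32,
    _count: u32,
) -> salsa::CycleRecoveryAction<u32> {
    salsa::CycleRecoveryAction::Iterate
}

#[salsa::tracked(cycle_fn=lower_recover, cycle_initial=lower_initial)]
fn lower_bound<'db>(db: &'db dyn salsa::Database, limit: Limit) -> u32 {
    // Self-referential: on the first call salsa detects the cycle, seeds this query with
    // `lower_initial`, and re-executes until the result converges (here, after one round).
    lower_bound(db, limit).min(limit.bound(db))
}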
let mut opt_last_provisional = self.initial_value(db).map(|initial_value| { self.insert_memo( zalsa, @@ -50,7 +52,7 @@ where ) }); - let mut iteration_count = 0; + let mut iteration_count: u32 = 0; loop { let active_query = zalsa_local.push_query(database_key_index); diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 86e49f7db..363e3bd66 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -332,12 +332,20 @@ pub(crate) struct QueryRevisions { pub(super) accumulated: AccumulatedMap, - /// This result was computed based on provisional cycle-iteration - /// results from these queries. + /// This result was computed based on provisional values from + /// these cycle heads. The "cycle head" is the query responsible + /// for managing a fixpoint iteration. In a cycle like + /// `--> A --> B --> C --> A`, the cycle head is query `A`: it is + /// the query whose value is requested while it is executing, + /// which must provide the initial provisional value and decide, + /// after each iteration, whether the cycle has converged or must + /// iterate again. pub(super) cycle_heads: FxHashSet, - /// True if this result is based on provisional results, and is not itself a cycle head; this - /// should not be used as a cached result. + /// True if this result is based on provisional results of other + /// queries, and is not created explicitly by the query managing + /// a fixpoint iteration (the "cycle head"); this should never be + /// treated as a valid cached result. pub(super) cycle_ignore: bool, } From c3c84c4f0b0af747f925411f78518b08d56b565e Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 24 Oct 2024 12:48:09 -0700 Subject: [PATCH 19/35] WIP: don't allow cycle_fn with no_eq --- components/salsa-macros/src/tracked_fn.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 897a7c33d..74cc3bcae 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -73,7 +73,17 @@ impl Macro { let (cycle_recovery_fn, cycle_recovery_initial, cycle_recovery_strategy) = self.cycle_recovery()?; let is_specifiable = self.args.specify.is_some(); - let no_eq = self.args.no_eq.is_some(); + let no_eq = if let Some(token) = &self.args.no_eq { + if self.args.cycle_fn.is_some() { + return Err(syn::Error::new_spanned( + token, + "the `no_eq` option cannot be used with `cycle_fn`", + )); + } + true + } else { + false + }; let mut inner_fn = item.clone(); inner_fn.vis = syn::Visibility::Inherited; From 72dff5fa218967f29fee314530f19ae0057bb3e0 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 24 Oct 2024 13:05:06 -0700 Subject: [PATCH 20/35] WIP: add tracing for cycle iteration --- src/function/execute.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 00e85067f..b8a5f1a18 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -81,6 +81,11 @@ where if revisions.cycle_heads.contains(&database_key_index) { if let Some(last_provisional) = opt_last_provisional { if let Some(provisional_value) = &last_provisional.value { + tracing::debug!( + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value \ + {provisional_value:#?} with new value {new_value:#?}" + ); // If the new result is equal to the last provisional result, the cycle has // converged and we are done. 
if !C::values_equal(&new_value, provisional_value) { @@ -88,6 +93,9 @@ where // cycle-recovery function what to do: match C::recover_from_cycle(db, &new_value, iteration_count) { crate::CycleRecoveryAction::Iterate => { + tracing::debug!( + "{database_key_index:?}: execute: iterate again" + ); iteration_count += 1; revisions.cycle_ignore = false; opt_last_provisional = Some(self.insert_memo( @@ -98,6 +106,9 @@ where continue; } crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( + "{database_key_index:?}: execute: fall back to {fallback_value:#?}" + ); new_value = fallback_value; } } @@ -106,7 +117,7 @@ where } // This is no longer a provisional result, it's our final result, so remove ourself // from the cycle heads, and iterate one last time to remove ourself from all other - // results in the cycle as well. + // results in the cycle as well and turn them into usable cached results. revisions.cycle_heads.remove(&database_key_index); revisions.cycle_ignore = false; self.insert_memo( @@ -116,6 +127,9 @@ where ); continue; } + + tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); + return self.insert_memo( zalsa, id, From 6b44c925d58389ebac0dc5bb5ffe5045d8e52a2c Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 24 Oct 2024 13:34:10 -0700 Subject: [PATCH 21/35] WIP: fail fast if we get an evicted provisional value --- src/function/execute.rs | 53 ++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index b8a5f1a18..b607d8863 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -80,37 +80,36 @@ where // Did the new result we got depend on our own provisional value, in a cycle? if revisions.cycle_heads.contains(&database_key_index) { if let Some(last_provisional) = opt_last_provisional { - if let Some(provisional_value) = &last_provisional.value { - tracing::debug!( - "{database_key_index:?}: execute: \ + // Memo value can only be `None` if LRU evicted; TODO should we explicitly + // prevent LRU eviction of cycle-head provisional memos? + let provisional_value = last_provisional.value.as_ref().unwrap(); + tracing::debug!( + "{database_key_index:?}: execute: \ I am a cycle head, comparing last provisional value \ {provisional_value:#?} with new value {new_value:#?}" - ); - // If the new result is equal to the last provisional result, the cycle has - // converged and we are done. - if !C::values_equal(&new_value, provisional_value) { - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle(db, &new_value, iteration_count) { - crate::CycleRecoveryAction::Iterate => { - tracing::debug!( - "{database_key_index:?}: execute: iterate again" - ); - iteration_count += 1; - revisions.cycle_ignore = false; - opt_last_provisional = Some(self.insert_memo( - zalsa, - id, - Memo::new(Some(new_value), revision_now, revisions), - )); - continue; - } - crate::CycleRecoveryAction::Fallback(fallback_value) => { - tracing::debug!( + ); + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. 
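// When the values differ, the user's `cycle_fn` decides between iterating again and falling
// back. A sketch of a recovery function that bounds a divergent cycle, mirroring the pattern
// the tests later in this series use (illustrative names and limits, not part of the patch):

fn widen_recover(
    _db: &dyn salsa::Database,
    value: &u32,
    count: u32,
) -> salsa::CycleRecoveryAction<u32> {
    if *value > 1_000 {
        // Value-based bound: stop chasing an ever-growing result.
        salsa::CycleRecoveryAction::Fallback(u32::MAX)
    } else if count > 20 {
        // Count-based bound: give up after too many iterations.
        salsa::CycleRecoveryAction::Fallback(u32::MAX)
    } else {
        salsa::CycleRecoveryAction::Iterate
    }
}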
+ if !C::values_equal(&new_value, provisional_value) { + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle(db, &new_value, iteration_count) { + crate::CycleRecoveryAction::Iterate => { + tracing::debug!("{database_key_index:?}: execute: iterate again"); + iteration_count += 1; + revisions.cycle_ignore = false; + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + )); + continue; + } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( "{database_key_index:?}: execute: fall back to {fallback_value:#?}" ); - new_value = fallback_value; - } + new_value = fallback_value; } } } From a029ef414dd1a0d88b6933df54ab2123eaffe72b Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 24 Oct 2024 13:40:22 -0700 Subject: [PATCH 22/35] WIP: use FxHashSet::from_iter --- src/zalsa_local.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 363e3bd66..706015f1e 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -351,8 +351,7 @@ pub(crate) struct QueryRevisions { impl QueryRevisions { pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self { - let mut cycle_heads = FxHashSet::default(); - cycle_heads.insert(query); + let cycle_heads = FxHashSet::from_iter([query]); Self { changed_at: Revision::start(), durability: Durability::MAX, From 492ae1bae19c9476cb391caf9f4c1a84b766b7f4 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 29 Oct 2024 17:35:09 -0700 Subject: [PATCH 23/35] add tests, fix multiple-revision, lazy provisional value --- src/function/execute.rs | 106 +-- src/function/fetch.rs | 33 +- src/function/maybe_changed_after.rs | 18 +- src/runtime.rs | 18 +- src/table/sync.rs | 18 +- src/zalsa.rs | 8 +- .../{cycle_fixpoint.rs => cycle/dataflow.rs} | 5 +- tests/cycle/main.rs | 755 ++++++++++++++++++ 8 files changed, 892 insertions(+), 69 deletions(-) rename tests/{cycle_fixpoint.rs => cycle/dataflow.rs} (96%) create mode 100644 tests/cycle/main.rs diff --git a/src/function/execute.rs b/src/function/execute.rs index b607d8863..4db620f7a 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,8 +1,6 @@ use std::sync::Arc; -use crate::{ - zalsa::ZalsaDatabase, zalsa_local::QueryRevisions, Database, DatabaseKeyIndex, Event, EventKind, -}; +use crate::{zalsa::ZalsaDatabase, Database, DatabaseKeyIndex, Event, EventKind}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -38,22 +36,13 @@ where }, }); - // If this query supports fixpoint iteration, populate the memo table with our initial - // value, in case some other query we call ends up calling us back. - let mut opt_last_provisional = self.initial_value(db).map(|initial_value| { - self.insert_memo( - zalsa, - id, - Memo::new( - Some(initial_value), - revision_now, - QueryRevisions::fixpoint_initial(database_key_index), - ), - ) - }); - let mut iteration_count: u32 = 0; + // Our provisional value from the previous iteration, when doing fixpoint iteration. + // Initially it's set to None, because the initial provisional value is created lazily, + // only when a cycle is actually encountered. + let mut opt_last_provisional: Option<&Memo<::Output<'db>>> = None; + loop { let active_query = zalsa_local.push_query(database_key_index); @@ -79,44 +68,65 @@ where // Did the new result we got depend on our own provisional value, in a cycle? 
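// A minimal standalone model of the fixpoint loop being built here (plain Rust, no salsa
// types; `step` stands for re-executing the query and `recover` for the user's `cycle_fn`):
// re-run against the last provisional value until the result stops changing, letting the
// recovery hook choose between iterating again and falling back.

enum Recovery<T> {
    Iterate,
    Fallback(T),
}

fn fixpoint<T: PartialEq>(
    initial: T,
    mut step: impl FnMut(&T) -> T,
    mut recover: impl FnMut(&T, u32) -> Recovery<T>,
) -> T {
    let mut provisional = initial;
    let mut count = 0;
    loop {
        let new = step(&provisional);
        if new == provisional {
            return new; // converged: the provisional value is the final value
        }
        match recover(&new, count) {
            Recovery::Iterate => {
                count += 1;
                provisional = new;
            }
            // Give up on convergence and take the fallback as the final value.
            Recovery::Fallback(fallback) => return fallback,
        }
    }
}

// e.g. fixpoint(u32::MAX, |v| (*v).min(3), |_, _| Recovery::Iterate) evaluates to 3.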
if revisions.cycle_heads.contains(&database_key_index) { - if let Some(last_provisional) = opt_last_provisional { - // Memo value can only be `None` if LRU evicted; TODO should we explicitly - // prevent LRU eviction of cycle-head provisional memos? - let provisional_value = last_provisional.value.as_ref().unwrap(); - tracing::debug!( - "{database_key_index:?}: execute: \ + let opt_owned_last_provisional; + let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { + // We have a last provisional value from our previous time around the loop. + &last_provisional + .value + .as_ref() + .expect("provisional value evicted by LRU?") + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. + opt_owned_last_provisional = self.get_memo_from_table_for(zalsa, id); + &opt_owned_last_provisional + .as_deref() + .expect( + "{database_key_index:#?} is a cycle head, \ + but no provisional memo found", + ) + .value + .as_ref() + .expect("provisional value evicted by LRU?") + }; + tracing::debug!( + "{database_key_index:?}: execute: \ I am a cycle head, comparing last provisional value \ - {provisional_value:#?} with new value {new_value:#?}" - ); - // If the new result is equal to the last provisional result, the cycle has - // converged and we are done. - if !C::values_equal(&new_value, provisional_value) { - // We are in a cycle that hasn't converged; ask the user's - // cycle-recovery function what to do: - match C::recover_from_cycle(db, &new_value, iteration_count) { - crate::CycleRecoveryAction::Iterate => { - tracing::debug!("{database_key_index:?}: execute: iterate again"); - iteration_count += 1; - revisions.cycle_ignore = false; - opt_last_provisional = Some(self.insert_memo( - zalsa, - id, - Memo::new(Some(new_value), revision_now, revisions), - )); - continue; - } - crate::CycleRecoveryAction::Fallback(fallback_value) => { - tracing::debug!( - "{database_key_index:?}: execute: fall back to {fallback_value:#?}" - ); - new_value = fallback_value; - } + {last_provisional_value:#?} with new value {new_value:#?}" + ); + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. + if !C::values_equal(&new_value, last_provisional_value) { + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle(db, &new_value, iteration_count) { + crate::CycleRecoveryAction::Iterate => { + tracing::debug!("{database_key_index:?}: execute: iterate again"); + iteration_count += 1; + revisions.cycle_ignore = false; + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + )); + continue; + } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( + "{database_key_index:?}: execute: fall back to {fallback_value:#?}" + ); + new_value = fallback_value; } } } // This is no longer a provisional result, it's our final result, so remove ourself // from the cycle heads, and iterate one last time to remove ourself from all other // results in the cycle as well and turn them into usable cached results. + // TODO Can we avoid doing this? the extra cycle is quite expensive if there is a + // nested cycle. Maybe track the relevant memos and replace them all with the cycle + // head removed? 
Or just let them keep the cycle head and allow cycle memos to be + // used when we are not actually iterating the cycle for that head? revisions.cycle_heads.remove(&database_key_index); revisions.cycle_ignore = false; self.insert_memo( diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 38eff1768..251c0470f 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,6 +1,9 @@ use super::{memo::Memo, Configuration, IngredientImpl}; use crate::accumulator::accumulated_map::InputAccumulatedValues; -use crate::{runtime::StampedValue, zalsa::ZalsaDatabase, AsDynDatabase as _, Id}; +use crate::{ + runtime::StampedValue, table::sync::ClaimResult, zalsa::ZalsaDatabase, + zalsa_local::QueryRevisions, AsDynDatabase as _, Id, +}; impl IngredientImpl where @@ -67,12 +70,36 @@ where let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. - let _claim_guard = zalsa.sync_table_for(id).claim( + let _claim_guard = match zalsa.sync_table_for(id).claim( db.as_dyn_database(), zalsa_local, database_key_index, self.memo_ingredient_index, - )?; + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => { + return self + .initial_value(db) + .map(|initial_value| { + self.insert_memo( + zalsa, + id, + Memo::new( + Some(initial_value), + zalsa.current_revision(), + QueryRevisions::fixpoint_initial(database_key_index), + ), + ) + }) + .or_else(|| { + panic!( + "dependency graph cycle querying {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ) + }) + } + ClaimResult::Claimed(guard) => guard, + }; // Now that we've claimed the item, check again to see if there's a "hot" value. let opt_old_memo = self.get_memo_from_table_for(zalsa, id); diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 0a0eaca77..e0b33eb8c 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,5 +1,7 @@ use crate::{ + cycle::CycleRecoveryStrategy, key::DatabaseKeyIndex, + table::sync::ClaimResult, zalsa::{Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin}, AsDynDatabase as _, Id, Revision, @@ -53,12 +55,24 @@ where let (zalsa, zalsa_local) = db.zalsas(); let database_key_index = self.database_key_index(key_index); - let _claim_guard = zalsa.sync_table_for(key_index).claim( + let _claim_guard = match zalsa.sync_table_for(key_index).claim( db.as_dyn_database(), zalsa_local, database_key_index, self.memo_ingredient_index, - )?; + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => panic!( + "dependency graph cycle validating {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ), + // If we hit a cycle in memo validation, but we support fixpoint iteration, just + // consider the memo changed so we'll re-run the iteration in this revision. + CycleRecoveryStrategy::Fixpoint => return Some(true), + }, + ClaimResult::Claimed(guard) => guard, + }; // Load the current memo, if any. 
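// The `claim` matches above distinguish three outcomes; a standalone sketch of that decision
// (illustrative enum, not the real `SyncTable` API):

enum Claim<G> {
    Claimed(G), // we now own the computation and can execute it
    Retry,      // another thread finished it while we waited; re-read the memo
    Cycle,      // we are already computing this query somewhere on our own waits-for path
}

fn on_claim<G>(claim: Claim<G>, supports_fixpoint: bool) -> &'static str {
    match claim {
        Claim::Claimed(_) => "execute the query body",
        Claim::Retry => "loop around and use the other thread's result",
        Claim::Cycle if supports_fixpoint => "return the `cycle_initial` provisional value",
        Claim::Cycle => "panic: cycle with no cycle_fn/cycle_initial configured",
    }
}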
let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else { return Some(true); diff --git a/src/runtime.rs b/src/runtime.rs index 1f15c3b5a..6567917ec 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -46,6 +46,12 @@ pub(crate) enum WaitResult { Panicked, } +#[derive(Clone, Debug)] +pub(crate) enum BlockResult { + Completed, + Cycle, +} + #[derive(Copy, Clone, Debug)] pub struct StampedValue { pub value: V, @@ -144,8 +150,8 @@ impl Runtime { r_new } - /// Block until `other_id` completes executing `database_key`; - /// panic or unwind in the case of a cycle. + /// Block until `other_id` completes executing `database_key`, or return `BlockResult::Cycle` + /// immediately in case of a cycle. /// /// `query_mutex_guard` is the guard for the current query's state; /// it will be dropped after we have successfully registered the @@ -155,19 +161,19 @@ impl Runtime { /// /// If the thread `other_id` panics, then our thread is considered /// cancelled, so this function will panic with a `Cancelled` value. - pub(crate) fn block_on_or_unwind( + pub(crate) fn block_on( &self, db: &dyn Database, local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> BlockResult { let mut dg = self.dependency_graph.lock(); let thread_id = std::thread::current().id(); if dg.depends_on(other_id, thread_id) { - panic!("unexpected dependency graph cycle"); + return BlockResult::Cycle; } db.salsa_event(&|| Event { @@ -192,7 +198,7 @@ impl Runtime { }); match result { - WaitResult::Completed => (), + WaitResult::Completed => BlockResult::Completed, // If the other thread panicked, then we consider this thread // cancelled. The assumption is that the panic will be detected diff --git a/src/table/sync.rs b/src/table/sync.rs index dfe78a23a..750779277 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -7,7 +7,7 @@ use parking_lot::RwLock; use crate::{ key::DatabaseKeyIndex, - runtime::WaitResult, + runtime::{BlockResult, WaitResult}, zalsa::{MemoIngredientIndex, Zalsa}, zalsa_local::ZalsaLocal, Database, @@ -30,6 +30,12 @@ struct SyncState { anyone_waiting: AtomicBool, } +pub(crate) enum ClaimResult<'a> { + Retry, + Cycle, + Claimed(ClaimGuard<'a>), +} + impl SyncTable { pub(crate) fn claim<'me>( &'me self, @@ -37,7 +43,7 @@ impl SyncTable { zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, - ) -> Option> { + ) -> ClaimResult<'me> { let mut syncs = self.syncs.write(); let zalsa = db.zalsa(); let thread_id = std::thread::current().id(); @@ -50,7 +56,7 @@ impl SyncTable { id: thread_id, anyone_waiting: AtomicBool::new(false), }); - Some(ClaimGuard { + ClaimResult::Claimed(ClaimGuard { database_key_index, memo_ingredient_index, zalsa, @@ -68,8 +74,10 @@ impl SyncTable { // boolean is to decide *whether* to acquire the lock, // not to gate future atomic reads. 
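// Why `block_on` can report `BlockResult::Cycle`: before parking this thread to wait on
// `other_id`, the dependency graph checks whether `other_id` already (transitively) waits on
// us. A standalone sketch of that reachability test over a waits-for map, assuming each
// blocked thread waits on exactly one other and the existing graph is acyclic (an invariant
// this very check maintains); the types here are illustrative:

use std::collections::HashMap;

fn would_deadlock(waits_for: &HashMap<u64, u64>, other: u64, me: u64) -> bool {
    // Follow "X is blocked on Y" edges from `other`; reaching `me` means blocking on
    // `other` would close a cycle, so the caller should report a cycle instead of waiting.
    let mut current = other;
    while let Some(&next) = waits_for.get(&current) {
        if next == me {
            return true;
        }
        current = next;
    }
    false
}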
anyone_waiting.store(true, Ordering::Relaxed); - zalsa.block_on_or_unwind(db, zalsa_local, database_key_index, *other_id, syncs); - None + match zalsa.block_on(db, zalsa_local, database_key_index, *other_id, syncs) { + BlockResult::Completed => ClaimResult::Retry, + BlockResult::Cycle => ClaimResult::Cycle, + } } } } diff --git a/src/zalsa.rs b/src/zalsa.rs index 04ee1c405..34b9d0213 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -7,7 +7,7 @@ use std::thread::ThreadId; use crate::ingredient::{Ingredient, Jar, JarAux}; use crate::nonce::{Nonce, NonceGenerator}; -use crate::runtime::{Runtime, WaitResult}; +use crate::runtime::{BlockResult, Runtime, WaitResult}; use crate::table::memo::MemoTable; use crate::table::sync::SyncTable; use crate::table::Table; @@ -260,16 +260,16 @@ impl Zalsa { } /// See [`Runtime::block_on_or_unwind`][] - pub(crate) fn block_on_or_unwind( + pub(crate) fn block_on( &self, db: &dyn Database, local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> BlockResult { self.runtime - .block_on_or_unwind(db, local_state, database_key, other_id, query_mutex_guard) + .block_on(db, local_state, database_key, other_id, query_mutex_guard) } /// See [`Runtime::unblock_queries_blocked_on`][] diff --git a/tests/cycle_fixpoint.rs b/tests/cycle/dataflow.rs similarity index 96% rename from tests/cycle_fixpoint.rs rename to tests/cycle/dataflow.rs index b792741b9..e8ac8296f 100644 --- a/tests/cycle_fixpoint.rs +++ b/tests/cycle/dataflow.rs @@ -1,4 +1,7 @@ -/// Test case for fixpoint iteration cycle resolution. +//! Test case for fixpoint iteration cycle resolution. +//! +//! This test case is intended to simulate a (very simplified) version of a real dataflow analysis +//! using fixpoint iteration. use salsa::{CycleRecoveryAction, Database as Db, Setter}; use std::collections::BTreeSet; use std::iter::IntoIterator; diff --git a/tests/cycle/main.rs b/tests/cycle/main.rs new file mode 100644 index 000000000..66a5f1053 --- /dev/null +++ b/tests/cycle/main.rs @@ -0,0 +1,755 @@ +//! Test cases for fixpoint iteration cycle resolution. +//! +//! These test cases use a generic query setup that allows constructing arbitrary dependency +//! graphs, and attempts to achieve good coverage of various cases. +mod dataflow; + +use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; + +/// A vector of inputs a query can evaluate to get an iterator of u8 values to operate on. +/// +/// This allows creating arbitrary query graphs between the four queries below (`min_iterate`, +/// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors. +#[salsa::input] +struct Inputs { + inputs: Vec, +} + +impl Inputs { + fn values(self, db: &dyn Db) -> impl Iterator + '_ { + self.inputs(db).into_iter().map(|input| input.eval(db)) + } +} + +/// A single input, evaluating to a single u8 value. 
+#[derive(Clone, Debug)] +enum Input { + /// a simple value + Value(u8), + + /// a simple value, reported as an untracked read + UntrackedRead(u8), + + /// minimum of the given inputs, with fixpoint iteration on cycles + MinIterate(Inputs), + + /// maximum of the given inputs, with fixpoint iteration on cycles + MaxIterate(Inputs), + + /// minimum of the given inputs, panicking on cycles + MinPanic(Inputs), + + /// maximum of the given inputs, panicking on cycles + MaxPanic(Inputs), + + /// value of the given input, plus one + Successor(Box), +} + +impl Input { + fn eval(self, db: &dyn Db) -> u8 { + match self { + Self::Value(value) => value, + Self::UntrackedRead(value) => { + db.report_untracked_read(); + value + } + Self::MinIterate(inputs) => min_iterate(db, inputs), + Self::MaxIterate(inputs) => max_iterate(db, inputs), + Self::MinPanic(inputs) => min_panic(db, inputs), + Self::MaxPanic(inputs) => max_panic(db, inputs), + Self::Successor(input) => input.eval(db) + 1, + } + } + + fn assert(self, db: &dyn Db, expected: u8) { + assert_eq!(self.eval(db), expected) + } +} + +#[salsa::tracked(cycle_fn=min_recover, cycle_initial=min_initial)] +fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).min().expect("empty inputs!") +} + +const MIN_COUNT_FALLBACK: u8 = 100; +const MIN_VALUE_FALLBACK: u8 = 5; +const MIN_VALUE: u8 = 10; + +fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction { + if *value < MIN_VALUE { + CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK) + } else if count > 10 { + CycleRecoveryAction::Fallback(MIN_COUNT_FALLBACK) + } else { + CycleRecoveryAction::Iterate + } +} + +fn min_initial(_db: &dyn Db) -> u8 { + 255 +} + +#[salsa::tracked(cycle_fn=max_recover, cycle_initial=max_initial)] +fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).max().expect("empty inputs!") +} + +const MAX_COUNT_FALLBACK: u8 = 200; +const MAX_VALUE_FALLBACK: u8 = 250; +const MAX_VALUE: u8 = 245; + +fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction { + if *value > MAX_VALUE { + CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK) + } else if count > 10 { + CycleRecoveryAction::Fallback(MAX_COUNT_FALLBACK) + } else { + CycleRecoveryAction::Iterate + } +} + +fn max_initial(_db: &dyn Db) -> u8 { + 0 +} + +#[salsa::tracked] +fn min_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).min().expect("empty inputs!") +} + +#[salsa::tracked] +fn max_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { + inputs.values(db).max().expect("empty inputs!") +} + +// Diagram nomenclature for nodes: Each node is represented as a:xx(ii), where `a` is a sequential +// identifier from `a`, `b`, `c`..., xx is one of the four query kinds: +// - `Ni` for `min_iterate` +// - `Xi` for `max_iterate` +// - `Np` for `min_panic` +// - `Xp` for `max_panic` +// +// and `ii` is the inputs for that query, represented as a comma-separated list, with each +// component representing an input: +// - `a`, `b`, `c`... where the input is another node, +// - `uXX` for `UntrackedRead(XX)` +// - `vXX` for `Value(XX)` +// - `sY` for `Successor(Y)` +// +// We always enter from the top left node in the diagram. + +/// a:Np(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, no iteration, should panic. 
+#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Np(u10, a) -+ +/// ^ | +/// +-------------+ +/// +/// Simple self-cycle with untracked read, no iteration, should panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_untracked_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db) + .to(vec![Input::UntrackedRead(10), a.clone()]); + + a.eval(&db); +} + +/// a:Ni(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, iteration converges on initial value. +#[test] +fn self_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); +} + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with iteration, we converge on its initial value. +#[test] +fn two_mixed_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); +} + +/// a:Np(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with no iteration, we panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_mixed_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(b_in); + let b = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Ni(b) --> b:Xi(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we first enter from. +#[test] +fn two_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MaxIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); + b.assert(&db, 255); +} + +/// a:Xi(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we enter from. +/// (Same setup as above test, different query order.) +#[test] +fn two_iterate_converge_initial_value_2() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MinIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 0); + b.assert(&db, 0); +} + +/// a:Np(b) --> b:Ni(c) --> c:Xp(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node with iteration, converge on its initial value. 
+#[test] +fn two_indirect_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert(&db, 255); +} + +/// a:Xp(b) --> b:Np(c) --> c:Xi(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node without iteration, panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_indirect_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MaxIterate(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.eval(&db); +} + +/// a:Np(b) -> b:Ni(v250,c) -> c:Xp(b) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, converges to non-initial value. +#[test] +fn two_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![Input::Value(250), c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert(&db, 250); +} + +/// a:Xp(b) -> b:Xi(v10,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to >10 iterations. +#[test] +fn two_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![Input::Value(10), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xp(b) -> b:Xi(v241,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to value reaching >MAX_VALUE (we start at 241 and each +/// iteration increments until we reach >245). +#[test] +fn two_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![Input::Value(241), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Ni(b) -> b:Np(a, c) -> c:Np(v25, a) +/// ^ | | +/// +----------+------------------------+ +/// +/// Three-query cycle, (b) and (c) both depend on (a). We converge on 25. 
+#[test] +fn three_fork_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), a.clone()]); + + a.assert(&db, 25); +} + +/// a:Ni(b) -> b:Ni(a, c) -> c:Np(v25, b) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We converge on 25. +#[test] +fn layered_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db).to(vec![Input::Value(25), b]); + + a.assert(&db, 25); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v25, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max iterations and fall back. +#[test] +fn layered_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v240, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max value and fall back. +#[test] +fn layered_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(240), Input::Successor(Box::new(b))]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, a, b) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles. We converge on 25. +#[test] +fn nested_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), a.clone(), b]); + + a.assert(&db, 25); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, b, a) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles, inner first. We converge on 25. 
+#[test] +fn nested_inner_first_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), b, a.clone()]); + + a.assert(&db, 25); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, a, sb) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles. We hit max iterations and fall back. +#[test] +fn nested_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(25), + a.clone(), + Input::Successor(Box::new(b)), + ]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, b, sa) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles, inner first. We hit max iterations and fall back. +#[test] +fn nested_inner_first_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(25), + b, + Input::Successor(Box::new(a.clone())), + ]); + + a.assert(&db, MAX_COUNT_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v240, a, sb) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles. We hit max value and fall back. +#[test] +fn nested_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(240), + a.clone(), + Input::Successor(Box::new(b)), + ]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v240, b, sa) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles, inner first. We hit max value and fall back. +#[test] +fn nested_inner_first_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(240), + b, + Input::Successor(Box::new(a.clone())), + ]); + + a.assert(&db, MAX_VALUE_FALLBACK + 1); +} + +/// a:Ni(b) -> b:Ni(c, a) -> c:Np(v25, a, b) +/// ^ ^ | | +/// +----------+--------|------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We converge on 25. 
+#[test_log::test] +fn nested_double_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db) + .to(vec![Input::Value(25), a.clone(), b]); + + a.assert(&db, 25); +} + +// Multiple-revision cycles + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// a:Ni(b) --> b:Np(v30) +/// +/// Cycle becomes not-a-cycle in next revision. +#[test] +fn cycle_becomes_non_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.clone().assert(&db, 255); + + b_in.set_inputs(&mut db).to(vec![Input::Value(30)]); + + a.assert(&db, 30); +} + +/// a:Ni(b) --> b:Np(v30) +/// +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Non-cycle becomes a cycle in next revision. +#[test] +fn non_cycle_becomes_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![Input::Value(30)]); + + a.clone().assert(&db, 30); + + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert(&db, 255); +} + +/// a:Xi(b) -> b:Xi(c, a) -> c:Xp(v25, a, sb) +/// ^ ^ | | +/// +----------+--------|-------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We hit max iterations and fall back, then max value on the next +/// revision, then converge on the next. +#[test] +fn nested_double_multiple_revisions() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db).to(vec![ + Input::Value(25), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert(&db, MAX_COUNT_FALLBACK + 1); + + // next revision, we hit max value instead + c_in.set_inputs(&mut db).to(vec![ + Input::Value(240), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert(&db, MAX_VALUE_FALLBACK + 1); + + // and next revision, we converge + c_in.set_inputs(&mut db) + .to(vec![Input::Value(240), a.clone(), b]); + + a.assert(&db, 240); +} + +/// a:Ni(b) -> b:Np(a) +/// ^ | +/// +----------------+ +/// +/// In a cycle with some LOW durability and some HIGH durability inputs, changing a LOW durability +/// input still re-executes the full cycle in the next revision. 
+#[test] +fn cycle_durability() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + a_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![b]); + b_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![a.clone()]); + + a.clone().assert(&db, 255); + + // next revision, we hit max value instead + b_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![Input::Value(45), a.clone()]); + + a.assert(&db, 45); +} From 0f7d940b0022f1f08f616ae772008be4cbfc0d1b Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 31 Oct 2024 14:52:39 -0700 Subject: [PATCH 24/35] review feedback, more tracing --- src/function/execute.rs | 22 ++++++++++++++-------- src/function/fetch.rs | 4 ++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 4db620f7a..89d1a93a4 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -71,16 +71,16 @@ where let opt_owned_last_provisional; let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { // We have a last provisional value from our previous time around the loop. - &last_provisional + last_provisional .value .as_ref() - .expect("provisional value evicted by LRU?") + .expect("provisional value should not be evicted by LRU") } else { // This is our first time around the loop; a provisional value must have been // inserted into the memo table when the cycle was hit, so let's pull our // initial provisional value from there. opt_owned_last_provisional = self.get_memo_from_table_for(zalsa, id); - &opt_owned_last_provisional + opt_owned_last_provisional .as_deref() .expect( "{database_key_index:#?} is a cycle head, \ @@ -88,12 +88,11 @@ where ) .value .as_ref() - .expect("provisional value evicted by LRU?") + .expect("provisional value should not be evicted by LRU") }; tracing::debug!( "{database_key_index:?}: execute: \ - I am a cycle head, comparing last provisional value \ - {last_provisional_value:#?} with new value {new_value:#?}" + I am a cycle head, comparing last provisional value with new value" ); // If the new result is equal to the last provisional result, the cycle has // converged and we are done. @@ -103,7 +102,10 @@ where match C::recover_from_cycle(db, &new_value, iteration_count) { crate::CycleRecoveryAction::Iterate => { tracing::debug!("{database_key_index:?}: execute: iterate again"); - iteration_count += 1; + iteration_count = iteration_count.checked_add(1).expect( + "fixpoint iteration of {database_key_index:#?} should \ + converge before u32::MAX iterations", + ); revisions.cycle_ignore = false; opt_last_provisional = Some(self.insert_memo( zalsa, @@ -114,7 +116,7 @@ where } crate::CycleRecoveryAction::Fallback(fallback_value) => { tracing::debug!( - "{database_key_index:?}: execute: fall back to {fallback_value:#?}" + "{database_key_index:?}: execute: user cycle_fn says to fall back" ); new_value = fallback_value; } @@ -127,6 +129,10 @@ where // nested cycle. Maybe track the relevant memos and replace them all with the cycle // head removed? Or just let them keep the cycle head and allow cycle memos to be // used when we are not actually iterating the cycle for that head? 
+ tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value, \ + one more iteration to remove cycle heads from memos" + ); revisions.cycle_heads.remove(&database_key_index); revisions.cycle_ignore = false; self.insert_memo( diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 251c0470f..5014a6ae3 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -81,6 +81,10 @@ where return self .initial_value(db) .map(|initial_value| { + tracing::debug!( + "hit cycle at {database_key_index:#?}, \ + inserting and returning fixpoint initial value" + ); self.insert_memo( zalsa, id, From 5ef7a5feb96f24383975363ee0876f4cd09e01eb Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 31 Oct 2024 15:36:21 -0700 Subject: [PATCH 25/35] fix multi-revision bug --- src/function/maybe_changed_after.rs | 8 ++- tests/cycle/dataflow.rs | 77 ++++++++++++++++------------- 2 files changed, 48 insertions(+), 37 deletions(-) diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index e0b33eb8c..ad55bcf26 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -96,8 +96,12 @@ where // the value has not logically changed. if old_memo.value.is_some() { let memo = self.execute(db, database_key_index, Some(old_memo)); - let changed_at = memo.revisions.changed_at; - return Some(changed_at > revision); + // If we get back a memo that's part of a cycle and requires further iteration to + // resolve, we can't backdate that. + if !memo.revisions.cycle_ignore { + let changed_at = memo.revisions.changed_at; + return Some(changed_at > revision); + } } // Otherwise, nothing for it: have to consider the value to have changed. diff --git a/tests/cycle/dataflow.rs b/tests/cycle/dataflow.rs index e8ac8296f..82a36d12e 100644 --- a/tests/cycle/dataflow.rs +++ b/tests/cycle/dataflow.rs @@ -12,16 +12,11 @@ struct Use { reaching_definitions: Vec, } -#[salsa::input] -struct Literal { - value: usize, -} - /// A Definition of a symbol, either of the form `base + increment` or `0 + increment`. 
#[salsa::input] struct Definition { base: Option, - increment: Literal, + increment: usize, } #[derive(Eq, PartialEq, Clone, Debug)] @@ -64,7 +59,7 @@ fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { #[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)] fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { - let increment_ty = infer_literal(db, def.increment(db)); + let increment_ty = Type::Values(Box::from([def.increment(db)])); if let Some(base) = def.base(db) { let base_ty = infer_use(db, base); add(&base_ty, &increment_ty) @@ -107,17 +102,12 @@ fn add(a: &Type, b: &Type) -> Type { } } -#[salsa::tracked] -fn infer_literal<'db>(db: &'db dyn Db, literal: Literal) -> Type { - Type::Values(Box::from([literal.value(db)])) -} - /// x = 1 #[test] fn simple() { let db = salsa::DatabaseImpl::new(); - let def = Definition::new(&db, None, Literal::new(&db, 1)); + let def = Definition::new(&db, None, 1); let u = Use::new(&db, vec![def]); let ty = infer_use(&db, u); @@ -130,8 +120,8 @@ fn simple() { fn union() { let db = salsa::DatabaseImpl::new(); - let def1 = Definition::new(&db, None, Literal::new(&db, 1)); - let def2 = Definition::new(&db, None, Literal::new(&db, 2)); + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 2); let u = Use::new(&db, vec![def1, def2]); let ty = infer_use(&db, u); @@ -144,10 +134,10 @@ fn union() { fn union_add() { let db = salsa::DatabaseImpl::new(); - let x1 = Definition::new(&db, None, Literal::new(&db, 1)); - let x2 = Definition::new(&db, None, Literal::new(&db, 2)); + let x1 = Definition::new(&db, None, 1); + let x2 = Definition::new(&db, None, 2); let x_use = Use::new(&db, vec![x1, x2]); - let y_def = Definition::new(&db, Some(x_use), Literal::new(&db, 1)); + let y_def = Definition::new(&db, Some(x_use), 1); let y_use = Use::new(&db, vec![y_def]); let ty = infer_use(&db, y_use); @@ -157,11 +147,11 @@ fn union_add() { /// x = 1; loop { x = x + 0 } #[test] -fn cycle_converges() { +fn cycle_converges_then_diverges() { let mut db = salsa::DatabaseImpl::new(); - let def1 = Definition::new(&db, None, Literal::new(&db, 1)); - let def2 = Definition::new(&db, None, Literal::new(&db, 0)); + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 0); let u = Use::new(&db, vec![def1, def2]); def2.set_base(&mut db).to(Some(u)); @@ -169,15 +159,22 @@ fn cycle_converges() { // Loop converges on 1 assert_eq!(ty, Type::Values(Box::from([1]))); + + // Set the increment on x from 0 to 1 + let new_increment = 1; + def2.set_increment(&mut db).to(new_increment); + + // Now the loop diverges and we fall back to Top + assert_eq!(infer_use(&db, u), Type::Top); } /// x = 1; loop { x = x + 1 } #[test] -fn cycle_diverges() { +fn cycle_diverges_then_converges() { let mut db = salsa::DatabaseImpl::new(); - let def1 = Definition::new(&db, None, Literal::new(&db, 1)); - let def2 = Definition::new(&db, None, Literal::new(&db, 1)); + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 1); let u = Use::new(&db, vec![def1, def2]); def2.set_base(&mut db).to(Some(u)); @@ -185,26 +182,36 @@ fn cycle_diverges() { // Loop diverges. Cut it off and fallback to Type::Top assert_eq!(ty, Type::Top); + + // Set the increment from 1 to 0. + def2.set_increment(&mut db).to(0); + + // Now the loop converges on 1. 
+ assert_eq!(infer_use(&db, u), Type::Values(Box::from([1]))); } /// x = 0; y = 0; loop { x = y + 0; y = x + 0 } -#[test] -fn multi_symbol_cycle_converges() { +#[test_log::test] +fn multi_symbol_cycle_converges_then_diverges() { let mut db = salsa::DatabaseImpl::new(); - let defx0 = Definition::new(&db, None, Literal::new(&db, 0)); - let defy0 = Definition::new(&db, None, Literal::new(&db, 0)); - let defx1 = Definition::new(&db, None, Literal::new(&db, 0)); - let defy1 = Definition::new(&db, None, Literal::new(&db, 0)); + let defx0 = Definition::new(&db, None, 0); + let defy0 = Definition::new(&db, None, 0); + let defx1 = Definition::new(&db, None, 0); + let defy1 = Definition::new(&db, None, 0); let use_x = Use::new(&db, vec![defx0, defx1]); let use_y = Use::new(&db, vec![defy0, defy1]); defx1.set_base(&mut db).to(Some(use_y)); defy1.set_base(&mut db).to(Some(use_x)); - let x_ty = infer_use(&db, use_x); - let y_ty = infer_use(&db, use_y); - // Both symbols converge on 0 - assert_eq!(x_ty, Type::Values(Box::from([0]))); - assert_eq!(y_ty, Type::Values(Box::from([0]))); + assert_eq!(infer_use(&db, use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(&db, use_y), Type::Values(Box::from([0]))); + + // Set the increment on x from 0 to 1. + defx1.set_increment(&mut db).to(1); + + // Now the loop diverges and we fall back to Top. + assert_eq!(infer_use(&db, use_x), Type::Top); + assert_eq!(infer_use(&db, use_y), Type::Top); } From 30934607689a51929aa869043e73db6b5fed9e5f Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 31 Oct 2024 15:56:07 -0700 Subject: [PATCH 26/35] better fix for multi-revision bug --- src/function/execute.rs | 18 +++++++++--------- src/function/fetch.rs | 5 ++++- src/function/maybe_changed_after.rs | 8 ++------ src/zalsa_local.rs | 4 ++-- 4 files changed, 17 insertions(+), 18 deletions(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 89d1a93a4..158f94a02 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -57,15 +57,6 @@ where let mut new_value = C::execute(db, C::id_to_input(db, id)); let mut revisions = active_query.pop(); - // If the new value is equal to the old one, then it didn't - // really change, even if some of its inputs have. So we can - // "backdate" its `changed_at` revision to be the same as the - // old value. - if let Some(old_memo) = &opt_old_memo { - self.backdate_if_appropriate(old_memo, &mut revisions, &new_value); - self.diff_outputs(db, database_key_index, old_memo, &mut revisions); - } - // Did the new result we got depend on our own provisional value, in a cycle? if revisions.cycle_heads.contains(&database_key_index) { let opt_owned_last_provisional; @@ -145,6 +136,15 @@ where tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); + // If the new value is equal to the old one, then it didn't + // really change, even if some of its inputs have. So we can + // "backdate" its `changed_at` revision to be the same as the + // old value. 
+ if let Some(old_memo) = &opt_old_memo { + self.backdate_if_appropriate(old_memo, &mut revisions, &new_value); + self.diff_outputs(db, database_key_index, old_memo, &mut revisions); + } + return self.insert_memo( zalsa, id, diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 5014a6ae3..c63b1b43e 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -91,7 +91,10 @@ where Memo::new( Some(initial_value), zalsa.current_revision(), - QueryRevisions::fixpoint_initial(database_key_index), + QueryRevisions::fixpoint_initial( + database_key_index, + zalsa.current_revision(), + ), ), ) }) diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index ad55bcf26..e0b33eb8c 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -96,12 +96,8 @@ where // the value has not logically changed. if old_memo.value.is_some() { let memo = self.execute(db, database_key_index, Some(old_memo)); - // If we get back a memo that's part of a cycle and requires further iteration to - // resolve, we can't backdate that. - if !memo.revisions.cycle_ignore { - let changed_at = memo.revisions.changed_at; - return Some(changed_at > revision); - } + let changed_at = memo.revisions.changed_at; + return Some(changed_at > revision); } // Otherwise, nothing for it: have to consider the value to have changed. diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 706015f1e..ca566f496 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -350,10 +350,10 @@ pub(crate) struct QueryRevisions { } impl QueryRevisions { - pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex) -> Self { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, revision: Revision) -> Self { let cycle_heads = FxHashSet::from_iter([query]); Self { - changed_at: Revision::start(), + changed_at: revision, durability: Durability::MAX, origin: QueryOrigin::FixpointInitial, tracked_struct_ids: Default::default(), From 7d9ec1c218adc6d2a14a6e72704702ca5c05cfe2 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 31 Oct 2024 16:19:25 -0700 Subject: [PATCH 27/35] test fixes --- tests/cycle/dataflow.rs | 2 +- tests/cycle/main.rs | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/cycle/dataflow.rs b/tests/cycle/dataflow.rs index 82a36d12e..53dc301b7 100644 --- a/tests/cycle/dataflow.rs +++ b/tests/cycle/dataflow.rs @@ -191,7 +191,7 @@ fn cycle_diverges_then_converges() { } /// x = 0; y = 0; loop { x = y + 0; y = x + 0 } -#[test_log::test] +#[test] fn multi_symbol_cycle_converges_then_diverges() { let mut db = salsa::DatabaseImpl::new(); diff --git a/tests/cycle/main.rs b/tests/cycle/main.rs index 66a5f1053..05100f650 100644 --- a/tests/cycle/main.rs +++ b/tests/cycle/main.rs @@ -724,9 +724,9 @@ fn nested_double_multiple_revisions() { a.assert(&db, 240); } -/// a:Ni(b) -> b:Np(a) -/// ^ | -/// +----------------+ +/// a:Ni(b) -> b:Ni(c) -> c:Ni(a) +/// ^ | +/// +---------------------------+ /// /// In a cycle with some LOW durability and some HIGH durability inputs, changing a LOW durability /// input still re-executes the full cycle in the next revision. 
@@ -735,21 +735,26 @@ fn cycle_durability() { let mut db = DbImpl::new(); let a_in = Inputs::new(&db, vec![]); let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); let a = Input::MinIterate(a_in); let b = Input::MinIterate(b_in); + let c = Input::MinIterate(c_in); a_in.set_inputs(&mut db) .with_durability(Durability::LOW) - .to(vec![b]); + .to(vec![b.clone()]); b_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![c]); + c_in.set_inputs(&mut db) .with_durability(Durability::HIGH) .to(vec![a.clone()]); a.clone().assert(&db, 255); - // next revision, we hit max value instead - b_in.set_inputs(&mut db) + // next revision, we converge instead + a_in.set_inputs(&mut db) .with_durability(Durability::LOW) - .to(vec![Input::Value(45), a.clone()]); + .to(vec![Input::Value(45), b]); a.assert(&db, 45); } From aa4a73151b1cb13e58b60b41482eeb1b6a2b23ee Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 31 Oct 2024 17:14:01 -0700 Subject: [PATCH 28/35] pass inputs to cycle recovery functions --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 7 +++-- .../src/unexpected_cycle_recovery.rs | 10 ++++--- src/function.rs | 3 +- src/function/execute.rs | 7 ++++- src/function/fetch.rs | 2 +- src/function/memo.rs | 8 +++-- tests/cycle/dataflow.rs | 30 ++++++++++++++++--- tests/cycle/main.rs | 8 ++--- 8 files changed, 55 insertions(+), 20 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 184ccb50c..72a38039c 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -177,16 +177,17 @@ macro_rules! setup_tracked_fn { $inner($db, $($input_id),*) } - fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db) -> Self::Output<$db_lt> { - $($cycle_recovery_initial)*(db) + fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db, $($input_id),*) } fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, value: &Self::Output<$db_lt>, count: u32, + ($($input_id),*): ($($input_ty),*) ) -> $zalsa::CycleRecoveryAction> { - $($cycle_recovery_fn)*(db, value, count) + $($cycle_recovery_fn)*(db, value, count, $($input_id),*) } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index cf8bbce13..a1cd1e73f 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -1,18 +1,20 @@ // Macro that generates the body of the cycle recovery function -// for the case where no cycle recovery is possible. Must be a macro -// because the signature types must match the particular tracked function. +// for the case where no cycle recovery is possible. This has to be +// a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $value:ident, $count:ident) => {{ + ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); panic!("cannot recover from cycle") }}; } #[macro_export] macro_rules! 
unexpected_cycle_initial { - ($db:ident) => {{ + ($db:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); panic!("no cycle initial value") }}; } diff --git a/src/function.rs b/src/function.rs index 225a9dead..97511d223 100644 --- a/src/function.rs +++ b/src/function.rs @@ -67,13 +67,14 @@ pub trait Configuration: Any { fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; /// Get the cycle recovery initial value. - fn cycle_initial(db: &Self::DbView) -> Self::Output<'_>; + fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; /// Decide whether to iterate a cycle again or fallback. fn recover_from_cycle<'db>( db: &'db Self::DbView, value: &Self::Output<'db>, count: u32, + input: Self::Input<'db>, ) -> CycleRecoveryAction>; } diff --git a/src/function/execute.rs b/src/function/execute.rs index 158f94a02..ef43fc2c8 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -90,7 +90,12 @@ where if !C::values_equal(&new_value, last_provisional_value) { // We are in a cycle that hasn't converged; ask the user's // cycle-recovery function what to do: - match C::recover_from_cycle(db, &new_value, iteration_count) { + match C::recover_from_cycle( + db, + &new_value, + iteration_count, + C::id_to_input(db, id), + ) { crate::CycleRecoveryAction::Iterate => { tracing::debug!("{database_key_index:?}: execute: iterate again"); iteration_count = iteration_count.checked_add(1).expect( diff --git a/src/function/fetch.rs b/src/function/fetch.rs index c63b1b43e..c00523390 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -79,7 +79,7 @@ where ClaimResult::Retry => return None, ClaimResult::Cycle => { return self - .initial_value(db) + .initial_value(db, database_key_index.key_index) .map(|initial_value| { tracing::debug!( "hit cycle at {database_key_index:#?}, \ diff --git a/src/function/memo.rs b/src/function/memo.rs index 8bddd5fa1..8e167ccbd 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -88,9 +88,13 @@ impl IngredientImpl { } } - pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option> { + pub(super) fn initial_value<'db>( + &'db self, + db: &'db C::DbView, + key: Id, + ) -> Option> { match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db)), + CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))), CycleRecoveryStrategy::Panic => None, } } diff --git a/tests/cycle/dataflow.rs b/tests/cycle/dataflow.rs index 53dc301b7..d8ef4cf3a 100644 --- a/tests/cycle/dataflow.rs +++ b/tests/cycle/dataflow.rs @@ -47,7 +47,7 @@ impl Type { } } -#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)] fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { let defs = u.reaching_definitions(db); match defs[..] 
{ @@ -57,7 +57,7 @@ fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { } } -#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)] fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { let increment_ty = Type::Values(Box::from([def.increment(db)])); if let Some(base) = def.base(db) { @@ -68,11 +68,33 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn cycle_initial(_db: &dyn Db) -> Type { +fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { Type::Bottom } -fn cycle_recover(_db: &dyn Db, value: &Type, count: u32) -> CycleRecoveryAction { +fn def_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _def: Definition, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { + Type::Bottom +} + +fn use_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _use: Use, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { match value { Type::Bottom => CycleRecoveryAction::Iterate, Type::Values(_) => { diff --git a/tests/cycle/main.rs b/tests/cycle/main.rs index 05100f650..09cb6e830 100644 --- a/tests/cycle/main.rs +++ b/tests/cycle/main.rs @@ -76,7 +76,7 @@ const MIN_COUNT_FALLBACK: u8 = 100; const MIN_VALUE_FALLBACK: u8 = 5; const MIN_VALUE: u8 = 10; -fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction { +fn min_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { if *value < MIN_VALUE { CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK) } else if count > 10 { @@ -86,7 +86,7 @@ fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction } } -fn min_initial(_db: &dyn Db) -> u8 { +fn min_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { 255 } @@ -99,7 +99,7 @@ const MAX_COUNT_FALLBACK: u8 = 200; const MAX_VALUE_FALLBACK: u8 = 250; const MAX_VALUE: u8 = 245; -fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction { +fn max_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { if *value > MAX_VALUE { CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK) } else if count > 10 { @@ -109,7 +109,7 @@ fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction } } -fn max_initial(_db: &dyn Db) -> u8 { +fn max_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { 0 } From 67376f1a2b92aec021f4a2957a33d2c5b67167c3 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 14 Nov 2024 17:52:39 -0800 Subject: [PATCH 29/35] fixed cycle-unchanged test --- src/accumulator.rs | 3 +- src/function.rs | 4 +- src/function/fetch.rs | 14 +- src/function/maybe_changed_after.rs | 217 +++++++++++++++++----------- src/ingredient.rs | 3 +- src/input.rs | 5 +- src/input/input_field.rs | 5 +- src/interned.rs | 5 +- src/key.rs | 7 +- src/tracked_struct.rs | 5 +- src/tracked_struct/tracked_field.rs | 6 +- tests/{cycle/main.rs => cycle.rs} | 46 +++++- tests/{cycle => }/dataflow.rs | 2 +- 13 files changed, 215 insertions(+), 107 deletions(-) rename tests/{cycle/main.rs => cycle.rs} (93%) rename tests/{cycle => }/dataflow.rs (99%) diff --git a/src/accumulator.rs b/src/accumulator.rs index b9355419a..aedd0072b 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -12,6 +12,7 @@ use accumulated_map::AccumulatedMap; use crate::{ cycle::CycleRecoveryStrategy, + function::VerifyResult, 
ingredient::{fmt_index, Ingredient, Jar}, plumbing::JarAux, zalsa::IngredientIndex, @@ -102,7 +103,7 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { + ) -> VerifyResult { panic!("nothing should ever depend on an accumulator directly") } diff --git a/src/function.rs b/src/function.rs index 97511d223..b825c4921 100644 --- a/src/function.rs +++ b/src/function.rs @@ -16,6 +16,8 @@ use self::delete::DeletedEntries; use super::ingredient::Ingredient; +pub(crate) use maybe_changed_after::VerifyResult; + mod accumulated; mod backdate; mod delete; @@ -194,7 +196,7 @@ where db: &dyn Database, input: Option, revision: Revision, - ) -> bool { + ) -> VerifyResult { let key = input.unwrap(); let db = db.as_view::(); self.maybe_changed_after(db, key, revision) diff --git a/src/function/fetch.rs b/src/function/fetch.rs index c00523390..8c999ee9b 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,4 +1,4 @@ -use super::{memo::Memo, Configuration, IngredientImpl}; +use super::{memo::Memo, Configuration, IngredientImpl, VerifyResult}; use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::{ runtime::StampedValue, table::sync::ClaimResult, zalsa::ZalsaDatabase, @@ -113,10 +113,14 @@ where if let Some(old_memo) = &opt_old_memo { if old_memo.value.is_some() { let active_query = zalsa_local.push_query(database_key_index); - if self.deep_verify_memo(db, old_memo, &active_query) { - // Unsafety invariant: memo is present in memo_map. - unsafe { - return Some(self.extend_memo_lifetime(old_memo)); + if let VerifyResult::Unchanged(cycle_heads) = + self.deep_verify_memo(db, old_memo, &active_query) + { + if cycle_heads.is_empty() { + // Unsafety invariant: memo is present in memo_map. + unsafe { + return Some(self.extend_memo_lifetime(old_memo)); + } } } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index e0b33eb8c..7d86b27bb 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -6,9 +6,36 @@ use crate::{ zalsa_local::{ActiveQueryGuard, EdgeKind, QueryOrigin}, AsDynDatabase as _, Id, Revision, }; +use rustc_hash::FxHashSet; use super::{memo::Memo, Configuration, IngredientImpl}; +/// Result of memo validation. +pub enum VerifyResult { + /// Memo has changed and needs to be recomputed. + Changed, + + /// Memo remains valid. + /// + /// Database keys in the hashset represent cycle heads encountered in validation; don't mark + /// memos verified until we've iterated the full cycle to ensure no inputs changed. 
+    Unchanged(FxHashSet<DatabaseKeyIndex>),
+}
+
+impl VerifyResult {
+    pub(crate) fn changed_if(condition: bool) -> Self {
+        if condition {
+            Self::Changed
+        } else {
+            Self::unchanged()
+        }
+    }
+
+    pub(crate) fn unchanged() -> Self {
+        Self::Unchanged(FxHashSet::default())
+    }
+}
+
 impl<C> IngredientImpl<C>
 where
     C: Configuration,
@@ -18,7 +45,7 @@ where
         db: &'db C::DbView,
         id: Id,
         revision: Revision,
-    ) -> bool {
+    ) -> VerifyResult {
         let (zalsa, zalsa_local) = db.zalsas();
         zalsa_local.unwind_if_revision_cancelled(db.as_dyn_database());

@@ -31,7 +58,7 @@ where
             let memo_guard = self.get_memo_from_table_for(zalsa, id);
             if let Some(memo) = &memo_guard {
                 if self.shallow_verify_memo(db, zalsa, database_key_index, memo) {
-                    return memo.revisions.changed_at > revision;
+                    return VerifyResult::changed_if(memo.revisions.changed_at > revision);
                 }
                 drop(memo_guard); // release the arc-swap guard before cold path
                 if let Some(mcs) = self.maybe_changed_after_cold(db, id, revision) {
@@ -41,7 +68,7 @@ where
                 }
             } else {
                 // No memo? Assume has changed.
-                return true;
+                return VerifyResult::Changed;
             }
         }
     }
@@ -51,7 +78,7 @@ where
         db: &'db C::DbView,
         key_index: Id,
         revision: Revision,
-    ) -> Option<bool> {
+    ) -> Option<VerifyResult> {
         let (zalsa, zalsa_local) = db.zalsas();
         let database_key_index = self.database_key_index(key_index);

@@ -67,15 +94,17 @@ where
                    "dependency graph cycle validating {database_key_index:#?}; \
                    set cycle_fn/cycle_initial to fixpoint iterate"
                ),
-                // If we hit a cycle in memo validation, but we support fixpoint iteration, just
-                // consider the memo changed so we'll re-run the iteration in this revision.
-                CycleRecoveryStrategy::Fixpoint => return Some(true),
+                CycleRecoveryStrategy::Fixpoint => {
+                    return Some(VerifyResult::Unchanged(FxHashSet::from_iter([
+                        database_key_index,
+                    ])))
+                }
            },
            ClaimResult::Claimed(guard) => guard,
        };
        // Load the current memo, if any.
        let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index) else {
-            return Some(true);
+            return Some(VerifyResult::Changed);
        };

        tracing::debug!(
@@ -86,8 +115,14 @@ where
        // Check if the inputs are still valid and we can just compare `changed_at`.
        let active_query = zalsa_local.push_query(database_key_index);
-        if self.deep_verify_memo(db, &old_memo, &active_query) {
-            return Some(old_memo.revisions.changed_at > revision);
+        if let VerifyResult::Unchanged(cycle_heads) =
+            self.deep_verify_memo(db, &old_memo, &active_query)
+        {
+            return Some(if old_memo.revisions.changed_at > revision {
+                VerifyResult::Changed
+            } else {
+                VerifyResult::Unchanged(cycle_heads)
+            });
        }

        // If inputs have changed, but we have an old value, we can re-execute.
@@ -97,11 +132,11 @@ where
        if old_memo.value.is_some() {
            let memo = self.execute(db, database_key_index, Some(old_memo));
            let changed_at = memo.revisions.changed_at;
-            return Some(changed_at > revision);
+            return Some(VerifyResult::changed_if(changed_at > revision));
        }

        // Otherwise, nothing for it: have to consider the value to have changed.
-        Some(true)
+        Some(VerifyResult::Changed)
    }

    /// True if the memo's value and `changed_at` time is still valid in this revision.
@@ -141,9 +176,9 @@ where
        false
    }

-    /// True if the memo's value and `changed_at` time is up to date in the current
-    /// revision. When this returns true, it also updates the memo's `verified_at`
-    /// field if needed to make future calls cheaper.
+    /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up to date in the
+    /// current revision. 
When this returns Unchanged with no cycle heads, it also updates the + /// memo's `verified_at` field if needed to make future calls cheaper. /// /// Takes an [`ActiveQueryGuard`] argument because this function recursively /// walks dependencies of `old_memo` and may even execute them to see if their @@ -153,9 +188,9 @@ where db: &C::DbView, old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, - ) -> bool { + ) -> VerifyResult { if old_memo.revisions.cycle_ignore { - return false; + return VerifyResult::Changed; } let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; @@ -166,79 +201,95 @@ where ); if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo) { - return true; + return VerifyResult::Unchanged(Default::default()); } - match &old_memo.revisions.origin { - QueryOrigin::Assigned(_) => { - // If the value was assigneed by another query, - // and that query were up-to-date, - // then we would have updated the `verified_at` field already. - // So the fact that we are here means that it was not specified - // during this revision or is otherwise stale. - // - // Example of how this can happen: - // - // Conditionally specified queries - // where the value is specified - // in rev 1 but not in rev 2. - return false; - } - QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { - // This value was `set` by the mutator thread -- ie, it's a base input and it cannot be out of date. - return true; - } - QueryOrigin::DerivedUntracked(_) => { - // Untracked inputs? Have to assume that it changed. - return false; - } - QueryOrigin::Derived(edges) => { - // Fully tracked inputs? Iterate over the inputs and check them, one by one. - // - // NB: It's important here that we are iterating the inputs in the order that - // they executed. It's possible that if the value of some input I0 is no longer - // valid, then some later input I1 might never have executed at all, so verifying - // it is still up to date is meaningless. - let last_verified_at = old_memo.verified_at.load(); - for &(edge_kind, dependency_index) in edges.input_outputs.iter() { - match edge_kind { - EdgeKind::Input => { - if dependency_index - .maybe_changed_after(db.as_dyn_database(), last_verified_at) - { - return false; + loop { + let mut cycle_heads = FxHashSet::default(); + + match &old_memo.revisions.origin { + QueryOrigin::Assigned(_) => { + // If the value was assigneed by another query, + // and that query were up-to-date, + // then we would have updated the `verified_at` field already. + // So the fact that we are here means that it was not specified + // during this revision or is otherwise stale. + // + // Example of how this can happen: + // + // Conditionally specified queries + // where the value is specified + // in rev 1 but not in rev 2. + return VerifyResult::Changed; + } + QueryOrigin::BaseInput | QueryOrigin::FixpointInitial => { + // This value was `set` by the mutator thread -- ie, it's a base input and it cannot be out of date. + return VerifyResult::unchanged(); + } + QueryOrigin::DerivedUntracked(_) => { + // Untracked inputs? Have to assume that it changed. + return VerifyResult::Changed; + } + QueryOrigin::Derived(edges) => { + // Fully tracked inputs? Iterate over the inputs and check them, one by one. + // + // NB: It's important here that we are iterating the inputs in the order that + // they executed. 
It's possible that if the value of some input I0 is no longer + // valid, then some later input I1 might never have executed at all, so verifying + // it is still up to date is meaningless. + let last_verified_at = old_memo.verified_at.load(); + for &(edge_kind, dependency_index) in edges.input_outputs.iter() { + match edge_kind { + EdgeKind::Input => { + match dependency_index + .maybe_changed_after(db.as_dyn_database(), last_verified_at) + { + VerifyResult::Changed => return VerifyResult::Changed, + VerifyResult::Unchanged(cycles) => cycle_heads.extend(cycles), + } + } + EdgeKind::Output => { + // Subtle: Mark outputs as validated now, even though we may + // later find an input that requires us to re-execute the function. + // Even if it re-execute, the function will wind up writing the same value, + // since all prior inputs were green. It's important to do this during + // this loop, because it's possible that one of our input queries will + // re-execute and may read one of our earlier outputs + // (e.g., in a scenario where we do something like + // `e = Entity::new(..); query(e);` and `query` reads a field of `e`). + // + // NB. Accumulators are also outputs, but the above logic doesn't + // quite apply to them. Since multiple values are pushed, the first value + // may be unchanged, but later values could be different. + // In that case, however, the data accumulated + // by this function cannot be read until this function is marked green, + // so even if we mark them as valid here, the function will re-execute + // and overwrite the contents. + // + // TODO not if we found a cycle head other than ourself? + dependency_index.mark_validated_output( + db.as_dyn_database(), + database_key_index, + ); } - } - EdgeKind::Output => { - // Subtle: Mark outputs as validated now, even though we may - // later find an input that requires us to re-execute the function. - // Even if it re-execute, the function will wind up writing the same value, - // since all prior inputs were green. It's important to do this during - // this loop, because it's possible that one of our input queries will - // re-execute and may read one of our earlier outputs - // (e.g., in a scenario where we do something like - // `e = Entity::new(..); query(e);` and `query` reads a field of `e`). - // - // NB. Accumulators are also outputs, but the above logic doesn't - // quite apply to them. Since multiple values are pushed, the first value - // may be unchanged, but later values could be different. - // In that case, however, the data accumulated - // by this function cannot be read until this function is marked green, - // so even if we mark them as valid here, the function will re-execute - // and overwrite the contents. 
- dependency_index - .mark_validated_output(db.as_dyn_database(), database_key_index); } } } } - } - old_memo.mark_as_verified( - db.as_dyn_database(), - zalsa.current_revision(), - database_key_index, - ); - true + let in_heads = cycle_heads.remove(&database_key_index); + + if cycle_heads.is_empty() { + old_memo.mark_as_verified( + db.as_dyn_database(), + zalsa.current_revision(), + database_key_index, + ); + } + if in_heads { + continue; + } + return VerifyResult::Unchanged(cycle_heads); + } } } diff --git a/src/ingredient.rs b/src/ingredient.rs index 383fdc6b7..56445016c 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,6 +6,7 @@ use std::{ use crate::{ accumulator::accumulated_map::AccumulatedMap, cycle::CycleRecoveryStrategy, + function::VerifyResult, zalsa::{IngredientIndex, MemoIngredientIndex}, zalsa_local::QueryOrigin, Database, DatabaseKeyIndex, Id, @@ -38,7 +39,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &'db dyn Database, input: Option, revision: Revision, - ) -> bool; + ) -> VerifyResult; /// What were the inputs (if any) that were used to create the value at `key_index`. fn origin(&self, db: &dyn Database, key_index: Id) -> Option; diff --git a/src/input.rs b/src/input.rs index e55891616..88a25c279 100644 --- a/src/input.rs +++ b/src/input.rs @@ -10,6 +10,7 @@ use parking_lot::Mutex; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, + function::VerifyResult, id::{AsId, FromId}, ingredient::{fmt_index, Ingredient}, key::{DatabaseKeyIndex, DependencyIndex}, @@ -215,10 +216,10 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { + ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. 
- false + VerifyResult::unchanged() } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/input/input_field.rs b/src/input/input_field.rs index fd3082256..66213aef7 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,4 +1,5 @@ use crate::cycle::CycleRecoveryStrategy; +use crate::function::VerifyResult; use crate::ingredient::{fmt_index, Ingredient}; use crate::input::Configuration; use crate::zalsa::IngredientIndex; @@ -54,11 +55,11 @@ where db: &dyn Database, input: Option, revision: Revision, - ) -> bool { + ) -> VerifyResult { let zalsa = db.zalsa(); let input = input.unwrap(); let value = >::data(zalsa, input); - value.stamps[self.field_index].changed_at > revision + VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) } fn origin(&self, _db: &dyn Database, _key_index: Id) -> Option { diff --git a/src/interned.rs b/src/interned.rs index 7b24c04e7..774eadbe0 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -1,5 +1,6 @@ use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::durability::Durability; +use crate::function::VerifyResult; use crate::id::AsId; use crate::ingredient::fmt_index; use crate::key::DependencyIndex; @@ -221,8 +222,8 @@ where _db: &dyn Database, _input: Option, revision: Revision, - ) -> bool { - revision < self.reset_at + ) -> VerifyResult { + VerifyResult::changed_if(revision < self.reset_at) } fn cycle_recovery_strategy(&self) -> crate::cycle::CycleRecoveryStrategy { diff --git a/src/key.rs b/src/key.rs index de84f710e..7d8d4d269 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,4 +1,7 @@ -use crate::{accumulator::accumulated_map::AccumulatedMap, zalsa::IngredientIndex, Database, Id}; +use crate::{ + accumulator::accumulated_map::AccumulatedMap, function::VerifyResult, zalsa::IngredientIndex, + Database, Id, +}; /// An integer that uniquely identifies a particular query instance within the /// database. Used to track dependencies between queries. 
Fully ordered and @@ -51,7 +54,7 @@ impl DependencyIndex { &self, db: &dyn Database, last_verified_at: crate::Revision, - ) -> bool { + ) -> VerifyResult { db.zalsa() .lookup_ingredient(self.ingredient_index) .maybe_changed_after(db, self.key_index, last_verified_at) diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 95e2a71b1..2d85524f9 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -6,6 +6,7 @@ use tracked_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, cycle::CycleRecoveryStrategy, + function::VerifyResult, ingredient::{fmt_index, Ingredient, Jar, JarAux}, key::{DatabaseKeyIndex, DependencyIndex}, plumbing::ZalsaLocal, @@ -583,8 +584,8 @@ where _db: &dyn Database, _input: Option, _revision: Revision, - ) -> bool { - false + ) -> VerifyResult { + VerifyResult::unchanged() } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index ff1909397..0745c6d83 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,6 +1,6 @@ use std::marker::PhantomData; -use crate::{ingredient::Ingredient, zalsa::IngredientIndex, Database, Id}; +use crate::{function::VerifyResult, ingredient::Ingredient, zalsa::IngredientIndex, Database, Id}; use super::{Configuration, Value}; @@ -53,12 +53,12 @@ where db: &'db dyn Database, input: Option, revision: crate::Revision, - ) -> bool { + ) -> VerifyResult { let zalsa = db.zalsa(); let id = input.unwrap(); let data = >::data(zalsa.table(), id); let field_changed_at = data.revisions[self.field_index]; - field_changed_at > revision + VerifyResult::changed_if(field_changed_at > revision) } fn origin( diff --git a/tests/cycle/main.rs b/tests/cycle.rs similarity index 93% rename from tests/cycle/main.rs rename to tests/cycle.rs index 09cb6e830..c2ce2288d 100644 --- a/tests/cycle/main.rs +++ b/tests/cycle.rs @@ -2,8 +2,9 @@ //! //! These test cases use a generic query setup that allows constructing arbitrary dependency //! graphs, and attempts to achieve good coverage of various cases. -mod dataflow; - +mod common; +use common::{ExecuteValidateLoggerDatabase, LogDatabase}; +use expect_test::expect; use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; /// A vector of inputs a query can evaluate to get an iterator of u8 values to operate on. @@ -758,3 +759,44 @@ fn cycle_durability() { a.assert(&db, 45); } + +/// a:Np(v59, b) -> b:Ni(v60, c) -> c:Np(b) +/// ^ | +/// +---------------------+ +/// +/// If nothing in a cycle changed in the new revision, no part of the cycle should re-execute. +#[test] +fn cycle_unchanged() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db) + .to(vec![Input::Value(59), b.clone()]); + b_in.set_inputs(&mut db).to(vec![Input::Value(60), c]); + c_in.set_inputs(&mut db).to(vec![b.clone()]); + + a.clone().assert(&db, 59); + b.clone().assert(&db, 60); + + db.assert_logs_len(5); + + // next revision, we change only A, which is not part of the cycle and the cycle does not + // depend on. 
+ a_in.set_inputs(&mut db) + .to(vec![Input::Value(45), b.clone()]); + + b.assert(&db, 60); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + ]"#]]); + + a.assert(&db, 45); +} diff --git a/tests/cycle/dataflow.rs b/tests/dataflow.rs similarity index 99% rename from tests/cycle/dataflow.rs rename to tests/dataflow.rs index d8ef4cf3a..6cd6e91b1 100644 --- a/tests/cycle/dataflow.rs +++ b/tests/dataflow.rs @@ -213,7 +213,7 @@ fn cycle_diverges_then_converges() { } /// x = 0; y = 0; loop { x = y + 0; y = x + 0 } -#[test] +#[test_log::test] fn multi_symbol_cycle_converges_then_diverges() { let mut db = salsa::DatabaseImpl::new(); From 286b5fb9b0f8422adb6a84b96de396a43dc4f9f5 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 14 Nov 2024 18:12:37 -0800 Subject: [PATCH 30/35] add TODO comments for some outstanding questions --- components/salsa-macros/src/tracked_fn.rs | 2 ++ src/function/execute.rs | 3 +++ 2 files changed, 5 insertions(+) diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 74cc3bcae..45beea738 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -175,6 +175,8 @@ impl Macro { Ok(ValidFn { db_ident, db_path }) } fn cycle_recovery(&self) -> syn::Result<(TokenStream, TokenStream, TokenStream)> { + // TODO should we ask the user to specify a struct that impls a trait with two methods, + // rather than asking for two methods separately? match (&self.args.cycle_fn, &self.args.cycle_initial) { (Some(cycle_fn), Some(cycle_initial)) => Ok(( quote!((#cycle_fn)), diff --git a/src/function/execute.rs b/src/function/execute.rs index ef43fc2c8..41b95ee80 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -90,6 +90,9 @@ where if !C::values_equal(&new_value, last_provisional_value) { // We are in a cycle that hasn't converged; ask the user's // cycle-recovery function what to do: + // TODO do we need explicit prevention of people calling queries inside + // cycle-recovery functions (some no-queries-allowed state on Runtime?) + // or is this just an "if it hurts, don't do it" scenario? match C::recover_from_cycle( db, &new_value, From b2d4d9203f53661a3f5777db00311cf78b4f392d Mon Sep 17 00:00:00 2001 From: Niko Matsakis Date: Fri, 15 Nov 2024 16:37:12 +0000 Subject: [PATCH 31/35] add a test for the "AB peeping C" scenario --- tests/cycle.rs | 2 +- tests/parallel/cycle_ab_peeping_c.rs | 97 ++++++++++++++++++++++++++++ tests/parallel/main.rs | 1 + 3 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 tests/parallel/cycle_ab_peeping_c.rs diff --git a/tests/cycle.rs b/tests/cycle.rs index c2ce2288d..8012902cb 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -130,7 +130,7 @@ fn max_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { // - `Xi` for `max_iterate` // - `Np` for `min_panic` // - `Xp` for `max_panic` -// +//\ // and `ii` is the inputs for that query, represented as a comma-separated list, with each // component representing an input: // - `a`, `b`, `c`... where the input is another node, diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs new file mode 100644 index 000000000..f8ad20b25 --- /dev/null +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -0,0 +1,97 @@ +//! 
Test a specific cycle scenario:
+//!
+//! Thread T1 calls A which calls B which calls A.
+//!
+//! Thread T2 calls C which calls B.
+//!
+//! The trick is that the call from Thread T2 comes before B has reached a fixed point.
+//! We want to be sure that C sees the final value (and blocks until it is complete).
+
+use salsa::CycleRecoveryAction;
+
+use crate::setup::{Knobs, KnobsDatabase};
+
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
+struct CycleValue(u32);
+
+const MIN: CycleValue = CycleValue(0);
+const MID: CycleValue = CycleValue(11);
+const MAX: CycleValue = CycleValue(22);
+
+#[salsa::tracked(cycle_fn=query_a_cycle_fn, cycle_initial=query_a_initial)]
+fn query_a(db: &dyn KnobsDatabase) -> CycleValue {
+    eprintln!("query_a()");
+    let b_value = query_b(db);
+
+    eprintln!("query_a: {:?}", b_value);
+
+    // When we reach the mid point, signal stage 1 (unblocking T2)
+    // and then wait for T2 to signal stage 2.
+    if b_value == MID {
+        eprintln!("query_a: signal");
+        db.signal(1);
+        db.wait_for(2);
+    }
+
+    b_value
+}
+
+fn query_a_cycle_fn(
+    _db: &dyn KnobsDatabase,
+    value: &CycleValue,
+    count: u32,
+) -> CycleRecoveryAction<CycleValue> {
+    eprintln!("query_a_cycle_fn({:?}, {:?})", value, count);
+    CycleRecoveryAction::Iterate
+}
+
+fn query_a_initial(_db: &dyn KnobsDatabase) -> CycleValue {
+    MIN
+}
+
+#[salsa::tracked]
+fn query_b(db: &dyn KnobsDatabase) -> CycleValue {
+    eprintln!("query_b()");
+
+    let a_value = query_a(db);
+
+    eprintln!("query_b: {:?}", a_value);
+
+    CycleValue(a_value.0 + 1).min(MAX)
+}
+
+#[salsa::tracked]
+fn query_c(db: &dyn KnobsDatabase) -> CycleValue {
+    eprintln!("query_c()");
+
+    // Wait until T1 has reached MID then execute `query_b`.
+    // This should block and (due to the configuration on our database) signal stage 2.
+    db.wait_for(1);
+
+    eprintln!("query_c: signaled");
+
+    query_b(db)
+}
+
+#[test]
+fn the_test() {
+    eprintln!("hi");
+    std::thread::scope(|scope| {
+        let db_t1 = Knobs::default();
+
+        let db_t2 = db_t1.clone();
+        db_t2.signal_on_will_block.store(2);
+
+        // Thread 1:
+        scope.spawn(move || {
+            let r = query_a(&db_t1);
+            assert_eq!(r, MAX);
+        });
+
+        // Thread 2:
+        scope.spawn(move || {
+            let r = query_c(&db_t2);
+            assert_eq!(r, MAX);
+        });
+    });
+}
diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs
index ed895948a..6b34dc06f 100644
--- a/tests/parallel/main.rs
+++ b/tests/parallel/main.rs
@@ -1,5 +1,6 @@
 mod setup;
 
+mod cycle_ab_peeping_c;
 mod parallel_cancellation;
 mod parallel_map;
 mod signal;
From 670f88b530fb6ed98b053a3ecd72a102cb40ee4c Mon Sep 17 00:00:00 2001
From: Niko Matsakis
Date: Fri, 15 Nov 2024 16:58:58 +0000
Subject: [PATCH 32/35] another parallel test scenario

---
 tests/parallel/cycle_a_t1_b_t2.rs | 79 +++++++++++++++++++++++++++++++
 tests/parallel/main.rs            |  1 +
 tests/parallel/setup.rs           |  2 +
 3 files changed, 82 insertions(+)
 create mode 100644 tests/parallel/cycle_a_t1_b_t2.rs

diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs
new file mode 100644
index 000000000..5ea5c447d
--- /dev/null
+++ b/tests/parallel/cycle_a_t1_b_t2.rs
@@ -0,0 +1,79 @@
+//! Test a specific cycle scenario:
+//!
+//! ```text
+//! Thread T1 Thread T2
+//! --------- ---------
+//! | |
+//! v |
+//! query_a() |
+//! ^ | v
+//! | +------------> query_b()
+//! | |
+//! +--------------------+
+//! 
`````` + +use salsa::CycleRecoveryAction; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(22); + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` + +#[salsa::tracked(cycle_fn=query_a_cycle_fn, cycle_initial=query_a_initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + db.signal(1); + + // Wait for Thread T2 to enter `query_b` before we continue. + db.wait_for(2); + + query_b(db) +} + +fn query_a_cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn query_a_initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[salsa::tracked] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + // Wait for Thread T1 to enter `query_a` before we continue. + db.wait_for(1); + + db.signal(2); + + let a_value = query_a(db); + CycleValue(a_value.0 + 1).min(MAX) +} + +#[test] +fn the_test() { + std::thread::scope(|scope| { + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + + // Thread 1: + scope.spawn(move || { + let r = query_a(&db_t1); + assert_eq!(r, MAX); + }); + + // Thread 2: + scope.spawn(move || { + let r = query_b(&db_t2); + assert_eq!(r, MAX); + }); + }); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index 6b34dc06f..e4e423937 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,5 +1,6 @@ mod setup; +mod cycle_a_t1_b_t2; mod cycle_ab_peeping_c; mod parallel_cancellation; mod parallel_map; diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index c266731a0..b67d4f667 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -9,8 +9,10 @@ use crate::signal::Signal; /// a certain behavior. #[salsa::db] pub(crate) trait KnobsDatabase: Database { + /// Signal that we are entering stage 1. fn signal(&self, stage: usize); + /// Wait until we reach stage `stage` (no-op if we have already reached that stage). 
fn wait_for(&self, stage: usize); } From 00acc5616fc7a2ff4e3b5446d82d8d9546e4c07b Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 13 Dec 2024 16:10:57 -0800 Subject: [PATCH 33/35] WIP: removed cycle_ignore; nested cycles broken --- src/active_query.rs | 1 - src/function/execute.rs | 21 ++++++++++---- src/function/fetch.rs | 22 +++++++++++++- src/function/maybe_changed_after.rs | 9 +++--- src/function/memo.rs | 6 ++++ src/function/specify.rs | 1 - src/zalsa_local.rs | 7 ----- tests/cycle.rs | 43 ++++++++++++++++------------ tests/parallel/cycle_a_t1_b_t2.rs | 30 +++++++++---------- tests/parallel/cycle_ab_peeping_c.rs | 2 +- 10 files changed, 87 insertions(+), 55 deletions(-) diff --git a/src/active_query.rs b/src/active_query.rs index 157bec117..ba5ddc35c 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -129,7 +129,6 @@ impl ActiveQuery { durability: self.durability, tracked_struct_ids: self.tracked_struct_ids, accumulated: self.accumulated, - cycle_ignore: !self.cycle_heads.is_empty(), cycle_heads: self.cycle_heads, } } diff --git a/src/function/execute.rs b/src/function/execute.rs index 41b95ee80..6afde184c 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -85,6 +85,8 @@ where "{database_key_index:?}: execute: \ I am a cycle head, comparing last provisional value with new value" ); + dbg!(&new_value); + dbg!(last_provisional_value); // If the new result is equal to the last provisional result, the cycle has // converged and we are done. if !C::values_equal(&new_value, last_provisional_value) { @@ -105,7 +107,6 @@ where "fixpoint iteration of {database_key_index:#?} should \ converge before u32::MAX iterations", ); - revisions.cycle_ignore = false; opt_last_provisional = Some(self.insert_memo( zalsa, id, @@ -121,19 +122,26 @@ where } } } + iteration_count = iteration_count.checked_add(1).expect( + "fixpoint iteration of {database_key_index:#?} should \ + converge before u32::MAX iterations", + ); + if iteration_count > 10 { + panic!("too much iteration"); + } // This is no longer a provisional result, it's our final result, so remove ourself // from the cycle heads, and iterate one last time to remove ourself from all other // results in the cycle as well and turn them into usable cached results. - // TODO Can we avoid doing this? the extra cycle is quite expensive if there is a - // nested cycle. Maybe track the relevant memos and replace them all with the cycle - // head removed? Or just let them keep the cycle head and allow cycle memos to be - // used when we are not actually iterating the cycle for that head? + // TODO Can we avoid doing this? the extra iteration is quite expensive if there is + // a nested cycle. Maybe track the relevant memos and replace them all with the + // cycle head removed? Or just let them keep the cycle head and allow cycle memos + // to be used when we are not actually iterating the cycle for that head? tracing::debug!( "{database_key_index:?}: execute: fixpoint iteration has a final value, \ one more iteration to remove cycle heads from memos" ); revisions.cycle_heads.remove(&database_key_index); - revisions.cycle_ignore = false; + dbg!(&revisions.cycle_heads); self.insert_memo( zalsa, id, @@ -143,6 +151,7 @@ where } tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); + dbg!(&new_value); // If the new value is equal to the old one, then it didn't // really change, even if some of its inputs have. 
So we can diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 8c999ee9b..75e1c0188 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -54,6 +54,7 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { if memo.value.is_some() + && !memo.is_provisional() && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo) { // Unsafety invariant: memo is present in memo_map @@ -78,6 +79,25 @@ where ) { ClaimResult::Retry => return None, ClaimResult::Cycle => { + dbg!("hit cycle for", database_key_index); + // check if there's a provisional value for this query + let memo_guard = self.get_memo_from_table_for(zalsa, id); + if let Some(memo) = &memo_guard { + dbg!("found provisional value, shallow verifying it"); + if memo.value.is_some() + && memo.revisions.cycle_heads.contains(&database_key_index) + && self.shallow_verify_memo(db, zalsa, database_key_index, memo) + { + dbg!("verified provisional value, returning it"); + dbg!(&memo.value); + // Unsafety invariant: memo is present in memo_map. + unsafe { + return Some(self.extend_memo_lifetime(memo)); + } + } + } + // no provisional value; create/insert/return initial provisional value + dbg!("no provisional value found, checking for initial value"); return self .initial_value(db, database_key_index.key_index) .map(|initial_value| { @@ -103,7 +123,7 @@ where "dependency graph cycle querying {database_key_index:#?}; \ set cycle_fn/cycle_initial to fixpoint iterate" ) - }) + }); } ClaimResult::Claimed(guard) => guard, }; diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 7d86b27bb..713c09ccf 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -57,7 +57,9 @@ where // Check if we have a verified version: this is the hot path. let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { - if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { + if !memo.is_provisional() + && self.shallow_verify_memo(db, zalsa, database_key_index, memo) + { return VerifyResult::changed_if(memo.revisions.changed_at > revision); } drop(memo_guard); // release the arc-swap guard before cold path @@ -149,9 +151,6 @@ where database_key_index: DatabaseKeyIndex, memo: &Memo>, ) -> bool { - if memo.revisions.cycle_ignore { - return false; - } let verified_at = memo.verified_at.load(); let revision_now = zalsa.current_revision(); @@ -189,7 +188,7 @@ where old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, ) -> VerifyResult { - if old_memo.revisions.cycle_ignore { + if old_memo.is_provisional() { return VerifyResult::Changed; } let zalsa = db.zalsa(); diff --git a/src/function/memo.rs b/src/function/memo.rs index 8e167ccbd..dd9559794 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -121,6 +121,12 @@ impl Memo { revisions, } } + + /// True if this is a provisional cycle-iteration result. + pub(super) fn is_provisional(&self) -> bool { + !self.revisions.cycle_heads.is_empty() + } + /// True if this memo is known not to have changed based on its durability. 
pub(super) fn check_durability(&self, zalsa: &Zalsa) -> bool { let last_changed = zalsa.last_changed_revision(self.revisions.durability); diff --git a/src/function/specify.rs b/src/function/specify.rs index f5803b3dc..fa5e04278 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -71,7 +71,6 @@ where tracked_struct_ids: Default::default(), accumulated: Default::default(), cycle_heads: Default::default(), - cycle_ignore: false, }; if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key) { diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index ca566f496..aa9c3e203 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -341,12 +341,6 @@ pub(crate) struct QueryRevisions { /// after each iteration, whether the cycle has converged or must /// iterate again. pub(super) cycle_heads: FxHashSet, - - /// True if this result is based on provisional results of other - /// queries, and is not created explicitly by the query managing - /// a fixpoint iteration (the "cycle head"); this should never be - /// treated as a valid cached result. - pub(super) cycle_ignore: bool, } impl QueryRevisions { @@ -359,7 +353,6 @@ impl QueryRevisions { tracked_struct_ids: Default::default(), accumulated: Default::default(), cycle_heads, - cycle_ignore: false, } } diff --git a/tests/cycle.rs b/tests/cycle.rs index 8012902cb..4e7f5b41d 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -6,6 +6,7 @@ mod common; use common::{ExecuteValidateLoggerDatabase, LogDatabase}; use expect_test::expect; use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; +use test_log::test; /// A vector of inputs a query can evaluate to get an iterator of u8 values to operate on. /// @@ -59,11 +60,19 @@ impl Input { Self::MaxIterate(inputs) => max_iterate(db, inputs), Self::MinPanic(inputs) => min_panic(db, inputs), Self::MaxPanic(inputs) => max_panic(db, inputs), - Self::Successor(input) => input.eval(db) + 1, + Self::Successor(input) => { + let inval = input.eval(db); + if inval >= MIN_VALUE && inval <= MAX_VALUE { + inval + 1 + } else { + inval + } + } } } fn assert(self, db: &dyn Db, expected: u8) { + dbg!("ASSERT"); assert_eq!(self.eval(db), expected) } } @@ -73,14 +82,14 @@ fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { inputs.values(db).min().expect("empty inputs!") } -const MIN_COUNT_FALLBACK: u8 = 100; -const MIN_VALUE_FALLBACK: u8 = 5; +const MIN_COUNT_FALLBACK: u8 = 4; +const MIN_VALUE_FALLBACK: u8 = 7; const MIN_VALUE: u8 = 10; fn min_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { if *value < MIN_VALUE { CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK) - } else if count > 10 { + } else if count > 3 { CycleRecoveryAction::Fallback(MIN_COUNT_FALLBACK) } else { CycleRecoveryAction::Iterate @@ -96,14 +105,14 @@ fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { inputs.values(db).max().expect("empty inputs!") } -const MAX_COUNT_FALLBACK: u8 = 200; -const MAX_VALUE_FALLBACK: u8 = 250; +const MAX_COUNT_FALLBACK: u8 = 251; +const MAX_VALUE_FALLBACK: u8 = 248; const MAX_VALUE: u8 = 245; fn max_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { if *value > MAX_VALUE { CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK) - } else if count > 10 { + } else if count > 3 { CycleRecoveryAction::Fallback(MAX_COUNT_FALLBACK) } else { CycleRecoveryAction::Iterate @@ -332,11 +341,11 @@ fn two_converge() { a.assert(&db, 250); } -/// a:Xp(b) -> b:Xi(v10,c) -> c:Xp(sb) +/// a:Xp(b) -> 
b:Xi(v20,c) -> c:Xp(sb) /// ^ | /// +---------------------+ /// -/// Two-query cycle, falls back due to >10 iterations. +/// Two-query cycle, falls back due to >3 iterations. #[test] fn two_fallback_count() { let mut db = DbImpl::new(); @@ -347,18 +356,18 @@ fn two_fallback_count() { let b = Input::MaxIterate(b_in); let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); - b_in.set_inputs(&mut db).to(vec![Input::Value(10), c]); + b_in.set_inputs(&mut db).to(vec![Input::Value(20), c]); c_in.set_inputs(&mut db) .to(vec![Input::Successor(Box::new(b))]); - a.assert(&db, MAX_COUNT_FALLBACK + 1); + a.assert(&db, MAX_COUNT_FALLBACK); } -/// a:Xp(b) -> b:Xi(v241,c) -> c:Xp(sb) +/// a:Xp(b) -> b:Xi(v244,c) -> c:Xp(sb) /// ^ | /// +---------------------+ /// -/// Two-query cycle, falls back due to value reaching >MAX_VALUE (we start at 241 and each +/// Two-query cycle, falls back due to value reaching >MAX_VALUE (we start at 244 and each /// iteration increments until we reach >245). #[test] fn two_fallback_value() { @@ -370,11 +379,11 @@ fn two_fallback_value() { let b = Input::MaxIterate(b_in); let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); - b_in.set_inputs(&mut db).to(vec![Input::Value(241), c]); + b_in.set_inputs(&mut db).to(vec![Input::Value(244), c]); c_in.set_inputs(&mut db) .to(vec![Input::Successor(Box::new(b))]); - a.assert(&db, MAX_VALUE_FALLBACK + 1); + a.assert(&db, MAX_VALUE_FALLBACK); } /// a:Ni(b) -> b:Np(a, c) -> c:Np(v25, a) @@ -438,7 +447,6 @@ fn layered_fallback_count() { b_in.set_inputs(&mut db).to(vec![a.clone(), c]); c_in.set_inputs(&mut db) .to(vec![Input::Value(25), Input::Successor(Box::new(b))]); - a.assert(&db, MAX_COUNT_FALLBACK + 1); } @@ -579,7 +587,6 @@ fn nested_fallback_value() { a.clone(), Input::Successor(Box::new(b)), ]); - a.assert(&db, MAX_VALUE_FALLBACK + 1); } @@ -615,7 +622,7 @@ fn nested_inner_first_fallback_value() { /// +-------------------+ /// /// Nested cycles, double head. We converge on 25. -#[test_log::test] +#[test] fn nested_double_converge() { let mut db = DbImpl::new(); let a_in = Inputs::new(&db, vec![]); diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs index 5ea5c447d..b289d2cdd 100644 --- a/tests/parallel/cycle_a_t1_b_t2.rs +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -10,7 +10,7 @@ //! | +------------> query_b() //! | | //! +--------------------+ -//! `````` +//! ``` use salsa::CycleRecoveryAction; @@ -25,7 +25,7 @@ const MAX: CycleValue = CycleValue(22); // Signal 1: T1 has entered `query_a` // Signal 2: T2 has entered `query_b` -#[salsa::tracked(cycle_fn=query_a_cycle_fn, cycle_initial=query_a_initial)] +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { db.signal(1); @@ -35,19 +35,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { query_b(db) } -fn query_a_cycle_fn( - _db: &dyn KnobsDatabase, - _value: &CycleValue, - _count: u32, -) -> CycleRecoveryAction { - CycleRecoveryAction::Iterate -} - -fn query_a_initial(_db: &dyn KnobsDatabase) -> CycleValue { - MIN -} - -#[salsa::tracked] +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { // Wait for Thread T1 to enter `query_a` before we continue. 
db.wait_for(1); @@ -58,6 +46,18 @@ fn query_b(db: &dyn KnobsDatabase) -> CycleValue { CycleValue(a_value.0 + 1).min(MAX) } +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + #[test] fn the_test() { std::thread::scope(|scope| { diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs index f8ad20b25..0a6e5100b 100644 --- a/tests/parallel/cycle_ab_peeping_c.rs +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -65,7 +65,7 @@ fn query_c(db: &dyn KnobsDatabase) -> CycleValue { eprintln!("query_c()"); // Wait until T1 has reached MID then execute `query_b`. - // This shoul block and (due to the configuration on our database) signal stage 2. + // This should block and (due to the configuration on our database) signal stage 2. db.wait_for(1); eprintln!("query_c: signaled"); From 52025798e7aca17dccfe917b0c75460272cab64f Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 17 Dec 2024 18:44:01 -0800 Subject: [PATCH 34/35] fixed all single-thread cycles; multi-thread still not working --- src/accumulator.rs | 4 + src/active_query.rs | 6 +- src/cycle.rs | 5 + src/function.rs | 5 + src/function/execute.rs | 51 ++--- src/function/fetch.rs | 12 +- src/function/maybe_changed_after.rs | 49 ++++- src/function/memo.rs | 21 +- src/function/specify.rs | 1 + src/ingredient.rs | 3 + src/input.rs | 6 +- src/input/input_field.rs | 4 + src/interned.rs | 6 +- src/runtime.rs | 3 + src/table/sync.rs | 4 + src/tracked_struct.rs | 6 +- src/tracked_struct/tracked_field.rs | 4 + src/zalsa_local.rs | 2 +- tests/cycle.rs | 313 +++++++++++++++------------ tests/parallel/cycle_ab_peeping_c.rs | 8 +- 20 files changed, 317 insertions(+), 196 deletions(-) diff --git a/src/accumulator.rs b/src/accumulator.rs index aedd0072b..5fd9f6cc6 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -107,6 +107,10 @@ impl Ingredient for IngredientImpl { panic!("nothing should ever depend on an accumulator directly") } + fn is_verified_final<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { CycleRecoveryStrategy::Panic } diff --git a/src/active_query.rs b/src/active_query.rs index ba5ddc35c..1178620d0 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -77,13 +77,15 @@ impl ActiveQuery { durability: Durability, revision: Revision, accumulated: InputAccumulatedValues, - cycle_heads: &FxHashSet, + cycle_heads: Option<&FxHashSet>, ) { self.input_outputs.insert((EdgeKind::Input, input)); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(revision); self.accumulated.add_input(accumulated); - self.cycle_heads.extend(cycle_heads); + if let Some(cycle_heads) = cycle_heads { + self.cycle_heads.extend(cycle_heads); + } } pub(super) fn add_untracked_read(&mut self, changed_at: Revision) { diff --git a/src/cycle.rs b/src/cycle.rs index c90f2170b..46fdc4464 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,3 +1,8 @@ +/// The maximum number of times we'll fixpoint-iterate before panicking. +/// +/// Should only be relevant in case of a badly configured cycle recovery. +pub const MAX_ITERATIONS: u32 = 200; + /// Return value from a cycle recovery function. 
#[derive(Debug)] pub enum CycleRecoveryAction { diff --git a/src/function.rs b/src/function.rs index b825c4921..013024ce7 100644 --- a/src/function.rs +++ b/src/function.rs @@ -202,6 +202,11 @@ where self.maybe_changed_after(db, key, revision) } + fn is_verified_final<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { + self.get_memo_from_table_for(db.zalsa(), input) + .is_some_and(|memo| !memo.may_be_provisional()) + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { C::CYCLE_STRATEGY } diff --git a/src/function/execute.rs b/src/function/execute.rs index 6afde184c..adbd74318 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,6 +1,8 @@ use std::sync::Arc; -use crate::{zalsa::ZalsaDatabase, Database, DatabaseKeyIndex, Event, EventKind}; +use crate::{ + cycle::MAX_ITERATIONS, zalsa::ZalsaDatabase, Database, DatabaseKeyIndex, Event, EventKind, +}; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -103,51 +105,36 @@ where ) { crate::CycleRecoveryAction::Iterate => { tracing::debug!("{database_key_index:?}: execute: iterate again"); - iteration_count = iteration_count.checked_add(1).expect( - "fixpoint iteration of {database_key_index:#?} should \ - converge before u32::MAX iterations", - ); - opt_last_provisional = Some(self.insert_memo( - zalsa, - id, - Memo::new(Some(new_value), revision_now, revisions), - )); - continue; } crate::CycleRecoveryAction::Fallback(fallback_value) => { tracing::debug!( "{database_key_index:?}: execute: user cycle_fn says to fall back" ); new_value = fallback_value; + // We have to insert the fallback value for this query and then iterate + // one more time to fill in correct values for everything else in the + // cycle based on it; then we'll re-insert it as final value. } } - } - iteration_count = iteration_count.checked_add(1).expect( - "fixpoint iteration of {database_key_index:#?} should \ + iteration_count = iteration_count.checked_add(1).expect( + "fixpoint iteration of {database_key_index:#?} should \ converge before u32::MAX iterations", - ); - if iteration_count > 10 { - panic!("too much iteration"); + ); + if iteration_count > MAX_ITERATIONS { + panic!("{database_key_index:?}: execute: too many cycle iterations"); + } + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + )); + continue; } - // This is no longer a provisional result, it's our final result, so remove ourself - // from the cycle heads, and iterate one last time to remove ourself from all other - // results in the cycle as well and turn them into usable cached results. - // TODO Can we avoid doing this? the extra iteration is quite expensive if there is - // a nested cycle. Maybe track the relevant memos and replace them all with the - // cycle head removed? Or just let them keep the cycle head and allow cycle memos - // to be used when we are not actually iterating the cycle for that head? 
tracing::debug!( - "{database_key_index:?}: execute: fixpoint iteration has a final value, \ - one more iteration to remove cycle heads from memos" + "{database_key_index:?}: execute: fixpoint iteration has a final value" ); revisions.cycle_heads.remove(&database_key_index); dbg!(&revisions.cycle_heads); - self.insert_memo( - zalsa, - id, - Memo::new(Some(new_value), revision_now, revisions), - ); - continue; } tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 75e1c0188..bc4295a65 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -29,7 +29,7 @@ where durability, changed_at, InputAccumulatedValues::from_map(&memo.revisions.accumulated), - &memo.revisions.cycle_heads, + memo.cycle_heads(), ); value @@ -54,8 +54,7 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { if memo.value.is_some() - && !memo.is_provisional() - && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo) + && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo, false) { // Unsafety invariant: memo is present in memo_map unsafe { @@ -84,9 +83,10 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { dbg!("found provisional value, shallow verifying it"); + dbg!(memo.tracing_debug()); if memo.value.is_some() && memo.revisions.cycle_heads.contains(&database_key_index) - && self.shallow_verify_memo(db, zalsa, database_key_index, memo) + && self.shallow_verify_memo(db, zalsa, database_key_index, memo, true) { dbg!("verified provisional value, returning it"); dbg!(&memo.value); @@ -146,6 +146,8 @@ where } } - Some(self.execute(db, database_key_index, opt_old_memo)) + let memo = self.execute(db, database_key_index, opt_old_memo); + + Some(memo) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 713c09ccf..d26a25819 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -57,9 +57,7 @@ where // Check if we have a verified version: this is the hot path. let memo_guard = self.get_memo_from_table_for(zalsa, id); if let Some(memo) = &memo_guard { - if !memo.is_provisional() - && self.shallow_verify_memo(db, zalsa, database_key_index, memo) - { + if self.shallow_verify_memo(db, zalsa, database_key_index, memo, false) { return VerifyResult::changed_if(memo.revisions.changed_at > revision); } drop(memo_guard); // release the arc-swap guard before cold path @@ -150,14 +148,25 @@ where zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, + allow_provisional: bool, ) -> bool { - let verified_at = memo.verified_at.load(); - let revision_now = zalsa.current_revision(); - tracing::debug!( "{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})", memo = memo.tracing_debug() ); + if !allow_provisional { + if memo.may_be_provisional() { + tracing::debug!( + "{database_key_index:?}: validate_provisional(memo = {memo:#?})", + memo = memo.tracing_debug() + ); + if !self.validate_provisional(db, zalsa, memo) { + return false; + } + } + } + let verified_at = memo.verified_at.load(); + let revision_now = zalsa.current_revision(); if verified_at == revision_now { // Already verified. @@ -175,6 +184,26 @@ where false } + /// Check if this memo's cycle heads have all been finalized. If so, mark it verified final and + /// return true, if not return false. 
+ fn validate_provisional( + &self, + db: &C::DbView, + zalsa: &Zalsa, + memo: &Memo>, + ) -> bool { + for cycle_head in &memo.revisions.cycle_heads { + if !zalsa + .lookup_ingredient(cycle_head.ingredient_index) + .is_verified_final(db.as_dyn_database(), cycle_head.key_index) + { + return false; + } + } + memo.verified_final.store(true); + true + } + /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up to date in the /// current revision. When this returns Unchanged with no cycle heads, it also updates the /// memo's `verified_at` field if needed to make future calls cheaper. @@ -188,9 +217,6 @@ where old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, ) -> VerifyResult { - if old_memo.is_provisional() { - return VerifyResult::Changed; - } let zalsa = db.zalsa(); let database_key_index = active_query.database_key_index; @@ -199,9 +225,12 @@ where old_memo = old_memo.tracing_debug() ); - if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo) { + if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo, false) { return VerifyResult::Unchanged(Default::default()); } + if old_memo.may_be_provisional() { + return VerifyResult::Changed; + } loop { let mut cycle_heads = FxHashSet::default(); diff --git a/src/function/memo.rs b/src/function/memo.rs index dd9559794..b9c2b3ef5 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -1,3 +1,4 @@ +use rustc_hash::FxHashSet; use std::any::Any; use std::fmt::Debug; use std::fmt::Formatter; @@ -109,6 +110,9 @@ pub(super) struct Memo { /// as the current revision. pub(super) verified_at: AtomicCell, + /// Is this memo verified to not be a provisional cycle result? + pub(super) verified_final: AtomicCell, + /// Revision information pub(super) revisions: QueryRevisions, } @@ -118,13 +122,23 @@ impl Memo { Memo { value, verified_at: AtomicCell::new(revision_now), + verified_final: AtomicCell::new(revisions.cycle_heads.is_empty()), revisions, } } - /// True if this is a provisional cycle-iteration result. - pub(super) fn is_provisional(&self) -> bool { - !self.revisions.cycle_heads.is_empty() + /// True if this is may be a provisional cycle-iteration result. + pub(super) fn may_be_provisional(&self) -> bool { + !self.verified_final.load() + } + + /// Cycle heads that should be propagated to dependent queries. + pub(super) fn cycle_heads(&self) -> Option<&FxHashSet> { + if self.may_be_provisional() { + Some(&self.revisions.cycle_heads) + } else { + None + } } /// True if this memo is known not to have changed based on its durability. @@ -185,6 +199,7 @@ impl Memo { }, ) .field("verified_at", &self.memo.verified_at) + .field("verified_final", &self.memo.verified_final) .field("revisions", &self.memo.revisions) .finish() } diff --git a/src/function/specify.rs b/src/function/specify.rs index fa5e04278..26f10b3e6 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -81,6 +81,7 @@ where let memo = Memo { value: Some(value), verified_at: AtomicCell::new(revision), + verified_final: AtomicCell::new(true), revisions, }; diff --git a/src/ingredient.rs b/src/ingredient.rs index 56445016c..d705df8b8 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -41,6 +41,9 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { revision: Revision, ) -> VerifyResult; + /// Is the value for `input` in this ingredient marked as possibly a provisional cycle value? 
+ fn is_verified_final<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool; + /// What were the inputs (if any) that were used to create the value at `key_index`. fn origin(&self, db: &dyn Database, key_index: Id) -> Option; diff --git a/src/input.rs b/src/input.rs index 88a25c279..bf6d5e32e 100644 --- a/src/input.rs +++ b/src/input.rs @@ -191,7 +191,7 @@ impl IngredientImpl { stamp.durability, stamp.changed_at, InputAccumulatedValues::Empty, - &Default::default(), + None, ); &value.fields } @@ -222,6 +222,10 @@ impl Ingredient for IngredientImpl { VerifyResult::unchanged() } + fn is_verified_final<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { CycleRecoveryStrategy::Panic } diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 66213aef7..374b2f89e 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -62,6 +62,10 @@ where VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) } + fn is_verified_final<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + fn origin(&self, _db: &dyn Database, _key_index: Id) -> Option { None } diff --git a/src/interned.rs b/src/interned.rs index 774eadbe0..3a7cc240f 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -136,7 +136,7 @@ where Durability::MAX, self.reset_at, InputAccumulatedValues::Empty, - &Default::default(), + None, ); // Optimisation to only get read lock on the map if the data has already @@ -226,6 +226,10 @@ where VerifyResult::changed_if(revision < self.reset_at) } + fn is_verified_final<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + fn cycle_recovery_strategy(&self) -> crate::cycle::CycleRecoveryStrategy { crate::cycle::CycleRecoveryStrategy::Panic } diff --git a/src/runtime.rs b/src/runtime.rs index 6567917ec..db6b5bf09 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -172,7 +172,10 @@ impl Runtime { let mut dg = self.dependency_graph.lock(); let thread_id = std::thread::current().id(); + eprintln!("Runtime::block_on {database_key:?}, I am {thread_id:?}, other id {other_id:?}"); + if dg.depends_on(other_id, thread_id) { + eprintln!("thread dependency cycle"); return BlockResult::Cycle; } diff --git a/src/table/sync.rs b/src/table/sync.rs index 750779277..ae0b00275 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -50,8 +50,11 @@ impl SyncTable { util::ensure_vec_len(&mut syncs, memo_ingredient_index.as_usize() + 1); + eprintln!("SyncTable::claim {database_key_index:?}, thread {thread_id:?}"); + match &syncs[memo_ingredient_index.as_usize()] { None => { + eprintln!("not claimed, claiming"); syncs[memo_ingredient_index.as_usize()] = Some(SyncState { id: thread_id, anyone_waiting: AtomicBool::new(false), @@ -67,6 +70,7 @@ impl SyncTable { id: other_id, anyone_waiting, }) => { + eprintln!("already claimed"); // NB: `Ordering::Relaxed` is sufficient here, // as there are no loads that are "gated" on this // value. 
Everything that is written is also protected diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 2d85524f9..e2b842b4d 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -564,7 +564,7 @@ where data.durability, field_changed_at, InputAccumulatedValues::Empty, - &Default::default(), + None, ); unsafe { self.to_self_ref(&data.fields) } @@ -588,6 +588,10 @@ where VerifyResult::unchanged() } + fn is_verified_final<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { crate::cycle::CycleRecoveryStrategy::Panic } diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index 0745c6d83..f030e3cf1 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -61,6 +61,10 @@ where VerifyResult::changed_if(field_changed_at > revision) } + fn is_verified_final<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + fn origin( &self, _db: &dyn Database, diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index aa9c3e203..4cb5c1eea 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -166,7 +166,7 @@ impl ZalsaLocal { durability: Durability, changed_at: Revision, accumulated: InputAccumulatedValues, - cycle_heads: &FxHashSet, + cycle_heads: Option<&FxHashSet>, ) { debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", diff --git a/tests/cycle.rs b/tests/cycle.rs index 4e7f5b41d..10a144cb2 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -8,7 +8,24 @@ use expect_test::expect; use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; use test_log::test; -/// A vector of inputs a query can evaluate to get an iterator of u8 values to operate on. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Value { + N(u8), + OutOfBounds, + TooManyIterations, +} + +impl Value { + fn into_value(&self) -> Option { + if let Self::N(val) = self { + Some(*val) + } else { + None + } + } +} + +/// A vector of inputs a query can evaluate to get an iterator of values to operate on. /// /// This allows creating arbitrary query graphs between the four queries below (`min_iterate`, /// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors. @@ -18,19 +35,19 @@ struct Inputs { } impl Inputs { - fn values(self, db: &dyn Db) -> impl Iterator + '_ { + fn values(self, db: &dyn Db) -> impl Iterator + '_ { self.inputs(db).into_iter().map(|input| input.eval(db)) } } -/// A single input, evaluating to a single u8 value. +/// A single input, evaluating to a single [`Value`]. 
#[derive(Clone, Debug)] enum Input { /// a simple value - Value(u8), + Value(Value), /// a simple value, reported as an untracked read - UntrackedRead(u8), + UntrackedRead(Value), /// minimum of the given inputs, with fixpoint iteration on cycles MinIterate(Inputs), @@ -49,7 +66,7 @@ enum Input { } impl Input { - fn eval(self, db: &dyn Db) -> u8 { + fn eval(self, db: &dyn Db) -> Value { match self { Self::Value(value) => value, Self::UntrackedRead(value) => { @@ -60,77 +77,117 @@ impl Input { Self::MaxIterate(inputs) => max_iterate(db, inputs), Self::MinPanic(inputs) => min_panic(db, inputs), Self::MaxPanic(inputs) => max_panic(db, inputs), - Self::Successor(input) => { - let inval = input.eval(db); - if inval >= MIN_VALUE && inval <= MAX_VALUE { - inval + 1 - } else { - inval - } - } + Self::Successor(input) => match input.eval(db) { + Value::N(num) => Value::N(num + 1), + other => other, + }, } } - fn assert(self, db: &dyn Db, expected: u8) { - dbg!("ASSERT"); + fn assert(self, db: &dyn Db, expected: Value) { assert_eq!(self.eval(db), expected) } -} -#[salsa::tracked(cycle_fn=min_recover, cycle_initial=min_initial)] -fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { - inputs.values(db).min().expect("empty inputs!") + fn assert_value(self, db: &dyn Db, expected: u8) { + self.assert(db, Value::N(expected)) + } + + fn assert_bounds(self, db: &dyn Db) { + self.assert(db, Value::OutOfBounds) + } + + fn assert_count(self, db: &dyn Db) { + self.assert(db, Value::TooManyIterations) + } } -const MIN_COUNT_FALLBACK: u8 = 4; -const MIN_VALUE_FALLBACK: u8 = 7; const MIN_VALUE: u8 = 10; - -fn min_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { - if *value < MIN_VALUE { - CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK) - } else if count > 3 { - CycleRecoveryAction::Fallback(MIN_COUNT_FALLBACK) +const MAX_VALUE: u8 = 245; +const MAX_ITERATIONS: u32 = 3; + +/// Recover from a cycle by falling back to `Value::OutOfBounds` if the value is out of bounds, +/// `Value::TooManyIterations` if we've iterated more than `MAX_ITERATIONS` times, or else +/// iterating again. +fn cycle_recover( + _db: &dyn Db, + value: &Value, + count: u32, + _inputs: Inputs, +) -> CycleRecoveryAction { + if value + .into_value() + .is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE) + { + CycleRecoveryAction::Fallback(Value::OutOfBounds) + } else if count > MAX_ITERATIONS { + CycleRecoveryAction::Fallback(Value::TooManyIterations) } else { CycleRecoveryAction::Iterate } } -fn min_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { - 255 +/// Fold an iterator of `Value` into a `Value`, given some binary operator to apply to two `u8`. +/// `Value::TooManyIterations` and `Value::OutOfBounds` will always propagate, with +/// `Value::TooManyIterations` taking precedence. 
+fn fold_values(values: impl IntoIterator, op: F) -> Value +where + F: Fn(u8, u8) -> u8, +{ + values + .into_iter() + .fold(None, |accum, elem| { + let Some(accum) = accum else { + return Some(elem); + }; + match (accum, elem) { + (Value::TooManyIterations, _) | (_, Value::TooManyIterations) => { + Some(Value::TooManyIterations) + } + (Value::OutOfBounds, _) | (_, Value::OutOfBounds) => Some(Value::OutOfBounds), + (Value::N(val1), Value::N(val2)) => Some(Value::N(op(val1, val2))), + } + }) + .expect("inputs should not be empty") } -#[salsa::tracked(cycle_fn=max_recover, cycle_initial=max_initial)] -fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { - inputs.values(db).max().expect("empty inputs!") +/// Query minimum value of inputs, with cycle recovery. +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=min_initial)] +fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + dbg!(fold_values(inputs.values(db), u8::min)) } -const MAX_COUNT_FALLBACK: u8 = 251; -const MAX_VALUE_FALLBACK: u8 = 248; -const MAX_VALUE: u8 = 245; +fn min_initial(_db: &dyn Db, _inputs: Inputs) -> Value { + Value::N(255) +} -fn max_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { - if *value > MAX_VALUE { - CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK) - } else if count > 3 { - CycleRecoveryAction::Fallback(MAX_COUNT_FALLBACK) - } else { - CycleRecoveryAction::Iterate - } +/// Query maximum value of inputs, with cycle recovery. +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=max_initial)] +fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + dbg!(fold_values(inputs.values(db), u8::max)) } -fn max_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { - 0 +fn max_initial(_db: &dyn Db, _inputs: Inputs) -> Value { + Value::N(0) } +/// Query minimum value of inputs, without cycle recovery. #[salsa::tracked] -fn min_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { - inputs.values(db).min().expect("empty inputs!") +fn min_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + dbg!(fold_values(inputs.values(db), u8::min)) } +/// Query maximum value of inputs, without cycle recovery. 
#[salsa::tracked] -fn max_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> u8 { - inputs.values(db).max().expect("empty inputs!") +fn max_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + dbg!(fold_values(inputs.values(db), u8::max)) +} + +fn untracked(num: u8) -> Input { + Input::UntrackedRead(Value::N(num)) +} + +fn value(num: u8) -> Input { + Input::Value(Value::N(num)) } // Diagram nomenclature for nodes: Each node is represented as a:xx(ii), where `a` is a sequential @@ -176,8 +233,7 @@ fn self_untracked_panic() { let mut db = DbImpl::new(); let a_in = Inputs::new(&db, vec![]); let a = Input::MinPanic(a_in); - a_in.set_inputs(&mut db) - .to(vec![Input::UntrackedRead(10), a.clone()]); + a_in.set_inputs(&mut db).to(vec![untracked(10), a.clone()]); a.eval(&db); } @@ -194,7 +250,7 @@ fn self_converge_initial_value() { let a = Input::MinIterate(a_in); a_in.set_inputs(&mut db).to(vec![a.clone()]); - a.assert(&db, 255); + a.assert_value(&db, 255); } /// a:Ni(b) --> b:Np(a) @@ -213,7 +269,7 @@ fn two_mixed_converge_initial_value() { a_in.set_inputs(&mut db).to(vec![b]); b_in.set_inputs(&mut db).to(vec![a.clone()]); - a.assert(&db, 255); + a.assert_value(&db, 255); } /// a:Np(b) --> b:Ni(a) @@ -252,8 +308,8 @@ fn two_iterate_converge_initial_value() { a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![a.clone()]); - a.assert(&db, 255); - b.assert(&db, 255); + a.assert_value(&db, 255); + b.assert_value(&db, 255); } /// a:Xi(b) --> b:Ni(a) @@ -273,8 +329,8 @@ fn two_iterate_converge_initial_value_2() { a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![a.clone()]); - a.assert(&db, 0); - b.assert(&db, 0); + a.assert_value(&db, 0); + b.assert_value(&db, 0); } /// a:Np(b) --> b:Ni(c) --> c:Xp(b) @@ -295,7 +351,7 @@ fn two_indirect_iterate_converge_initial_value() { b_in.set_inputs(&mut db).to(vec![c]); c_in.set_inputs(&mut db).to(vec![b]); - a.assert(&db, 255); + a.assert_value(&db, 255); } /// a:Xp(b) --> b:Np(c) --> c:Xi(b) @@ -320,7 +376,7 @@ fn two_indirect_panic() { a.eval(&db); } -/// a:Np(b) -> b:Ni(v250,c) -> c:Xp(b) +/// a:Np(b) -> b:Ni(v200,c) -> c:Xp(b) /// ^ | /// +---------------------+ /// @@ -335,10 +391,10 @@ fn two_converge() { let b = Input::MinIterate(b_in); let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); - b_in.set_inputs(&mut db).to(vec![Input::Value(250), c]); + b_in.set_inputs(&mut db).to(vec![value(200), c]); c_in.set_inputs(&mut db).to(vec![b]); - a.assert(&db, 250); + a.assert_value(&db, 200); } /// a:Xp(b) -> b:Xi(v20,c) -> c:Xp(sb) @@ -356,11 +412,11 @@ fn two_fallback_count() { let b = Input::MaxIterate(b_in); let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); - b_in.set_inputs(&mut db).to(vec![Input::Value(20), c]); + b_in.set_inputs(&mut db).to(vec![value(20), c]); c_in.set_inputs(&mut db) .to(vec![Input::Successor(Box::new(b))]); - a.assert(&db, MAX_COUNT_FALLBACK); + a.assert_count(&db); } /// a:Xp(b) -> b:Xi(v244,c) -> c:Xp(sb) @@ -379,11 +435,11 @@ fn two_fallback_value() { let b = Input::MaxIterate(b_in); let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); - b_in.set_inputs(&mut db).to(vec![Input::Value(244), c]); + b_in.set_inputs(&mut db).to(vec![value(244), c]); c_in.set_inputs(&mut db) .to(vec![Input::Successor(Box::new(b))]); - a.assert(&db, MAX_VALUE_FALLBACK); + a.assert_bounds(&db); } /// a:Ni(b) -> b:Np(a, c) -> c:Np(v25, a) @@ -402,10 +458,9 @@ fn three_fork_converge() { let c = Input::MinPanic(c_in); 
a_in.set_inputs(&mut db).to(vec![b]); b_in.set_inputs(&mut db).to(vec![a.clone(), c]); - c_in.set_inputs(&mut db) - .to(vec![Input::Value(25), a.clone()]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone()]); - a.assert(&db, 25); + a.assert_value(&db, 25); } /// a:Ni(b) -> b:Ni(a, c) -> c:Np(v25, b) @@ -424,9 +479,9 @@ fn layered_converge() { let c = Input::MinPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![a.clone(), c]); - c_in.set_inputs(&mut db).to(vec![Input::Value(25), b]); + c_in.set_inputs(&mut db).to(vec![value(25), b]); - a.assert(&db, 25); + a.assert_value(&db, 25); } /// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v25, sb) @@ -446,11 +501,11 @@ fn layered_fallback_count() { a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![a.clone(), c]); c_in.set_inputs(&mut db) - .to(vec![Input::Value(25), Input::Successor(Box::new(b))]); - a.assert(&db, MAX_COUNT_FALLBACK + 1); + .to(vec![value(25), Input::Successor(Box::new(b))]); + a.assert_count(&db); } -/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v240, sb) +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v243, sb) /// ^ | ^ | /// +----------+ +----------+ /// @@ -467,9 +522,9 @@ fn layered_fallback_value() { a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![a.clone(), c]); c_in.set_inputs(&mut db) - .to(vec![Input::Value(240), Input::Successor(Box::new(b))]); + .to(vec![value(243), Input::Successor(Box::new(b))]); - a.assert(&db, MAX_VALUE_FALLBACK + 1); + a.assert_bounds(&db); } /// a:Ni(b) -> b:Ni(c) -> c:Np(v25, a, b) @@ -488,10 +543,9 @@ fn nested_converge() { let c = Input::MinPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![c]); - c_in.set_inputs(&mut db) - .to(vec![Input::Value(25), a.clone(), b]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone(), b]); - a.assert(&db, 25); + a.assert_value(&db, 25); } /// a:Ni(b) -> b:Ni(c) -> c:Np(v25, b, a) @@ -510,10 +564,9 @@ fn nested_inner_first_converge() { let c = Input::MinPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![c]); - c_in.set_inputs(&mut db) - .to(vec![Input::Value(25), b, a.clone()]); + c_in.set_inputs(&mut db).to(vec![value(25), b, a.clone()]); - a.assert(&db, 25); + a.assert_value(&db, 25); } /// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, a, sb) @@ -532,13 +585,10 @@ fn nested_fallback_count() { let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![c]); - c_in.set_inputs(&mut db).to(vec![ - Input::Value(25), - a.clone(), - Input::Successor(Box::new(b)), - ]); + c_in.set_inputs(&mut db) + .to(vec![value(25), a.clone(), Input::Successor(Box::new(b))]); - a.assert(&db, MAX_COUNT_FALLBACK + 1); + a.assert_count(&db); } /// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, b, sa) @@ -557,16 +607,13 @@ fn nested_inner_first_fallback_count() { let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![c]); - c_in.set_inputs(&mut db).to(vec![ - Input::Value(25), - b, - Input::Successor(Box::new(a.clone())), - ]); + c_in.set_inputs(&mut db) + .to(vec![value(25), b, Input::Successor(Box::new(a.clone()))]); - a.assert(&db, MAX_COUNT_FALLBACK + 1); + a.assert_count(&db); } -/// a:Xi(b) -> b:Xi(c) -> c:Xp(v240, a, sb) +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v243, a, sb) /// ^ ^ | /// +----------+--------------------------+ /// @@ -581,16 +628,18 @@ fn nested_fallback_value() { let b = Input::MaxIterate(b_in); let c = Input::MaxPanic(c_in); 
a_in.set_inputs(&mut db).to(vec![b.clone()]); - b_in.set_inputs(&mut db).to(vec![c]); + b_in.set_inputs(&mut db).to(vec![c.clone()]); c_in.set_inputs(&mut db).to(vec![ - Input::Value(240), + value(243), a.clone(), - Input::Successor(Box::new(b)), + Input::Successor(Box::new(b.clone())), ]); - a.assert(&db, MAX_VALUE_FALLBACK + 1); + a.assert_bounds(&db); + b.assert_bounds(&db); + c.assert_bounds(&db); } -/// a:Xi(b) -> b:Xi(c) -> c:Xp(v240, b, sa) +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v243, b, sa) /// ^ ^ | /// +----------+--------------------------+ /// @@ -606,13 +655,10 @@ fn nested_inner_first_fallback_value() { let c = Input::MaxPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![c]); - c_in.set_inputs(&mut db).to(vec![ - Input::Value(240), - b, - Input::Successor(Box::new(a.clone())), - ]); + c_in.set_inputs(&mut db) + .to(vec![value(243), b, Input::Successor(Box::new(a.clone()))]); - a.assert(&db, MAX_VALUE_FALLBACK + 1); + a.assert_bounds(&db); } /// a:Ni(b) -> b:Ni(c, a) -> c:Np(v25, a, b) @@ -633,10 +679,9 @@ fn nested_double_converge() { let c = Input::MinPanic(c_in); a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![c, a.clone()]); - c_in.set_inputs(&mut db) - .to(vec![Input::Value(25), a.clone(), b]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone(), b]); - a.assert(&db, 25); + a.assert_value(&db, 25); } // Multiple-revision cycles @@ -658,11 +703,11 @@ fn cycle_becomes_non_cycle() { a_in.set_inputs(&mut db).to(vec![b]); b_in.set_inputs(&mut db).to(vec![a.clone()]); - a.clone().assert(&db, 255); + a.clone().assert_value(&db, 255); - b_in.set_inputs(&mut db).to(vec![Input::Value(30)]); + b_in.set_inputs(&mut db).to(vec![value(30)]); - a.assert(&db, 30); + a.assert_value(&db, 30); } /// a:Ni(b) --> b:Np(v30) @@ -680,13 +725,13 @@ fn non_cycle_becomes_cycle() { let a = Input::MinIterate(a_in); let b = Input::MinPanic(b_in); a_in.set_inputs(&mut db).to(vec![b]); - b_in.set_inputs(&mut db).to(vec![Input::Value(30)]); + b_in.set_inputs(&mut db).to(vec![value(30)]); - a.clone().assert(&db, 30); + a.clone().assert_value(&db, 30); b_in.set_inputs(&mut db).to(vec![a.clone()]); - a.assert(&db, 255); + a.assert_value(&db, 255); } /// a:Xi(b) -> b:Xi(c, a) -> c:Xp(v25, a, sb) @@ -709,27 +754,26 @@ fn nested_double_multiple_revisions() { a_in.set_inputs(&mut db).to(vec![b.clone()]); b_in.set_inputs(&mut db).to(vec![c, a.clone()]); c_in.set_inputs(&mut db).to(vec![ - Input::Value(25), + value(25), a.clone(), Input::Successor(Box::new(b.clone())), ]); - a.clone().assert(&db, MAX_COUNT_FALLBACK + 1); + a.clone().assert_count(&db); // next revision, we hit max value instead c_in.set_inputs(&mut db).to(vec![ - Input::Value(240), + value(243), a.clone(), Input::Successor(Box::new(b.clone())), ]); - a.clone().assert(&db, MAX_VALUE_FALLBACK + 1); + a.clone().assert_bounds(&db); // and next revision, we converge - c_in.set_inputs(&mut db) - .to(vec![Input::Value(240), a.clone(), b]); + c_in.set_inputs(&mut db).to(vec![value(240), a.clone(), b]); - a.assert(&db, 240); + a.assert_value(&db, 240); } /// a:Ni(b) -> b:Ni(c) -> c:Ni(a) @@ -757,14 +801,14 @@ fn cycle_durability() { .with_durability(Durability::HIGH) .to(vec![a.clone()]); - a.clone().assert(&db, 255); + a.clone().assert_value(&db, 255); // next revision, we converge instead a_in.set_inputs(&mut db) .with_durability(Durability::LOW) - .to(vec![Input::Value(45), b]); + .to(vec![value(45), b]); - a.assert(&db, 45); + a.assert_value(&db, 45); } /// a:Np(v59, b) -> b:Ni(v60, 
c) -> c:Np(b) @@ -781,22 +825,19 @@ fn cycle_unchanged() { let a = Input::MinPanic(a_in); let b = Input::MinIterate(b_in); let c = Input::MinPanic(c_in); - a_in.set_inputs(&mut db) - .to(vec![Input::Value(59), b.clone()]); - b_in.set_inputs(&mut db).to(vec![Input::Value(60), c]); + a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(60), c]); c_in.set_inputs(&mut db).to(vec![b.clone()]); - a.clone().assert(&db, 59); - b.clone().assert(&db, 60); + a.clone().assert_value(&db, 59); + b.clone().assert_value(&db, 60); - db.assert_logs_len(5); + db.assert_logs_len(4); // next revision, we change only A, which is not part of the cycle and the cycle does not // depend on. - a_in.set_inputs(&mut db) - .to(vec![Input::Value(45), b.clone()]); - - b.assert(&db, 60); + a_in.set_inputs(&mut db).to(vec![value(45), b.clone()]); + b.assert_value(&db, 60); db.assert_logs(expect![[r#" [ @@ -805,5 +846,5 @@ fn cycle_unchanged() { "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", ]"#]]); - a.assert(&db, 45); + a.assert_value(&db, 45); } diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs index 0a6e5100b..7b3e462e8 100644 --- a/tests/parallel/cycle_ab_peeping_c.rs +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -18,7 +18,7 @@ const MIN: CycleValue = CycleValue(0); const MID: CycleValue = CycleValue(11); const MAX: CycleValue = CycleValue(22); -#[salsa::tracked(cycle_fn=query_a_cycle_fn, cycle_initial=query_a_initial)] +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] fn query_a(db: &dyn KnobsDatabase) -> CycleValue { eprintln!("query_a()"); let b_value = query_b(db); @@ -36,7 +36,7 @@ fn query_a(db: &dyn KnobsDatabase) -> CycleValue { b_value } -fn query_a_cycle_fn( +fn cycle_fn( _db: &dyn KnobsDatabase, value: &CycleValue, count: u32, @@ -45,11 +45,11 @@ fn query_a_cycle_fn( CycleRecoveryAction::Iterate } -fn query_a_initial(_db: &dyn KnobsDatabase) -> CycleValue { +fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { MIN } -#[salsa::tracked] +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] fn query_b(db: &dyn KnobsDatabase) -> CycleValue { eprintln!("query_b()"); From 3bf79307dbf01a9d63647e74bbe330a1ceaa4967 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 14 Jan 2025 16:20:30 -0800 Subject: [PATCH 35/35] panic if fallback fails to converge --- src/function/execute.rs | 12 ++++++++++++ tests/cycle.rs | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/src/function/execute.rs b/src/function/execute.rs index 37fa2fcdd..3b63208b6 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -38,6 +38,7 @@ where }); let mut iteration_count: u32 = 0; + let mut fell_back = false; // Our provisional value from the previous iteration, when doing fixpoint iteration. // Initially it's set to None, because the initial provisional value is created lazily, @@ -91,6 +92,16 @@ where // If the new result is equal to the last provisional result, the cycle has // converged and we are done. if !C::values_equal(&new_value, last_provisional_value) { + if fell_back { + // We fell back to a value last iteration, but the fallback didn't result + // in convergence. We only have bad options here: continue iterating + // (ignoring the request to fall back), or forcibly use the fallback and + // leave the cycle in an inconsistent state (we'll be using a value for + // this query that it doesn't evaluate to, given its inputs). 
Maybe we'll + // have to go with the latter, but for now let's panic and see if real use + // cases need non-converging fallbacks. + panic!("{database_key_index:?}: execute: fallback did not converge"); + } // We are in a cycle that hasn't converged; ask the user's // cycle-recovery function what to do: // TODO do we need explicit prevention of people calling queries inside @@ -113,6 +124,7 @@ where // We have to insert the fallback value for this query and then iterate // one more time to fill in correct values for everything else in the // cycle based on it; then we'll re-insert it as final value. + fell_back = true; } } iteration_count = iteration_count.checked_add(1).expect( diff --git a/tests/cycle.rs b/tests/cycle.rs index 10a144cb2..101023248 100644 --- a/tests/cycle.rs +++ b/tests/cycle.rs @@ -61,8 +61,11 @@ enum Input { /// maximum of the given inputs, panicking on cycles MaxPanic(Inputs), - /// value of the given input, plus one + /// value of the given input, plus one; propagates error values Successor(Box), + + /// successor, converts error values to zero + SuccessorOrZero(Box), } impl Input { @@ -81,6 +84,10 @@ impl Input { Value::N(num) => Value::N(num + 1), other => other, }, + Self::SuccessorOrZero(input) => match input.eval(db) { + Value::N(num) => Value::N(num + 1), + _ => Value::N(0), + }, } } @@ -203,6 +210,7 @@ fn value(num: u8) -> Input { // - `uXX` for `UntrackedRead(XX)` // - `vXX` for `Value(XX)` // - `sY` for `Successor(Y)` +// - `zY` for `SuccessorOrZero(Y)` // // We always enter from the top left node in the diagram. @@ -419,6 +427,29 @@ fn two_fallback_count() { a.assert_count(&db); } +/// a:Xp(b) -> b:Xi(v20,c) -> c:Xp(zb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back but fallback does not converge. +#[test] +#[should_panic(expected = "fallback did not converge")] +fn two_fallback_diverge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(20), c.clone()]); + c_in.set_inputs(&mut db) + .to(vec![Input::SuccessorOrZero(Box::new(b))]); + + a.assert_count(&db); +} + /// a:Xp(b) -> b:Xi(v244,c) -> c:Xp(sb) /// ^ | /// +---------------------+
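// A minimal standalone sketch (not part of the patches above) of the
// `two_fallback_diverge` shape, showing why the final patch panics when a
// fallback value fails to converge. The `Value` enum and `step` closure here
// are simplified stand-ins, not salsa APIs: `step` models one fixpoint
// iteration of `b = max(20, successor_or_zero(b))`, where the successor edge
// maps the fallback marker back to zero, so re-running the cycle with the
// fallback installed always produces a different value again.

#[derive(Clone, Copy, PartialEq, Debug)]
enum Value {
    N(u8),
    TooManyIterations, // the fallback marker requested by cycle recovery
}

fn main() {
    // One iteration of the two-query cycle: b = max(20, c), c = successor-or-zero of b.
    let step = |b: Value| match b {
        Value::N(n) => Value::N(n.saturating_add(1).max(20)),
        // The successor-or-zero edge converts the fallback marker back to 0,
        // so the fallback can never be a fixed point of the cycle.
        Value::TooManyIterations => Value::N(20),
    };

    let mut b = Value::N(0); // initial provisional value
    let mut fell_back = false;
    for count in 0.. {
        let next = step(b);
        if next == b {
            println!("converged on {next:?}");
            return;
        }
        if fell_back {
            // Mirrors the check added in patch 35/35: the fallback was
            // installed last iteration but the cycle still did not converge.
            panic!("fallback did not converge");
        }
        if count > 3 {
            // Ask recovery for a fallback, install it, and iterate once more.
            b = Value::TooManyIterations;
            fell_back = true;
        } else {
            b = next;
        }
    }
}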