Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
s-arash committed Jan 19, 2025
1 parent 3d7330a commit 4aff71c
Show file tree
Hide file tree
Showing 19 changed files with 125 additions and 71 deletions.
2 changes: 1 addition & 1 deletion ascent/src/c_lat_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq> RelIndexRead<'a>
Some(res)
}

fn len(&self) -> usize { self.unwrap_frozen().len() }
fn len_estimate(&self) -> usize { self.unwrap_frozen().len() }
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Clone + Hash + Eq + Sync> CRelIndexRead<'a> for CLatIndex<K, V> {
Expand Down
2 changes: 1 addition & 1 deletion ascent/src/c_rel_full_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelFullIndex<K,
Some(res)
}

fn len(&self) -> usize { self.unwrap_frozen().len() }
fn len_estimate(&self) -> usize { self.unwrap_frozen().len() }
}

impl<'a, K: 'a + Clone + Hash + Eq, V: 'a + Sync> CRelIndexRead<'a> for CRelFullIndex<K, V> {
Expand Down
2 changes: 1 addition & 1 deletion ascent/src/c_rel_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl<'a, K: 'a + Clone + Hash + Eq, V: 'a> RelIndexRead<'a> for CRelIndex<K, V>
Some(res)
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
// approximate len
let sample_size = 4;
let shards = self.unwrap_frozen().shards();
Expand Down
4 changes: 2 additions & 2 deletions ascent/src/c_rel_no_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl<'a, V: 'a> RelIndexRead<'a> for CRelNoIndex<V> {
}

#[inline(always)]
fn len(&self) -> usize { 1 }
fn len_estimate(&self) -> usize { 1 }
}

impl<'a, V: 'a + Sync + Send> CRelIndexRead<'a> for CRelNoIndex<V> {
Expand Down Expand Up @@ -92,7 +92,7 @@ impl<'a, V: 'a> RelIndexWrite for CRelNoIndex<V> {
impl<'a, V: 'a> RelIndexMerge for CRelNoIndex<V> {
fn move_index_contents(from: &mut Self, to: &mut Self) {
let before = Instant::now();
assert_eq!(from.len(), to.len());
assert_eq!(from.len_estimate(), to.len_estimate());
// not necessary because we have a mut reference
// assert!(!from.frozen);
// assert!(!to.frozen);
Expand Down
2 changes: 1 addition & 1 deletion ascent/src/rel_index_boilerplate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ where T: RelIndexRead<'a>
fn index_get(&'a self, key: &Self::Key) -> Option<Self::IteratorType> { (**self).index_get(key) }

#[inline(always)]
fn len(&self) -> usize { (**self).len() }
fn len_estimate(&self) -> usize { (**self).len_estimate() }
}

impl<'a, T> RelIndexReadAll<'a> for &'a T
Expand Down
16 changes: 11 additions & 5 deletions ascent/src/rel_index_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@ pub trait RelIndexRead<'a> {
type Value;
type IteratorType: Iterator<Item = Self::Value> + Clone + 'a;
fn index_get(&'a self, key: &Self::Key) -> Option<Self::IteratorType>;
fn len(&'a self) -> usize;
fn len_estimate(&'a self) -> usize;
/// Is the relation **definitely** empty?
/// It is OK for implementations to return `false` even if the relation may be empty,
/// as this is used to enable certain optimizations.
fn is_empty(&'a self) -> bool {
false
}
}

pub trait RelIndexReadAll<'a> {
Expand All @@ -36,7 +42,7 @@ impl<'a, K: Eq + std::hash::Hash + 'a, V: Clone + 'a> RelIndexRead<'a> for RelIn
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for RelIndexType1<K, V> {
Expand Down Expand Up @@ -67,7 +73,7 @@ impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for HashBrownR
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for HashBrownRelFullIndexType<K, V> {
Expand Down Expand Up @@ -98,7 +104,7 @@ impl<'a, K: Eq + std::hash::Hash, V: 'a + Clone> RelIndexRead<'a> for LatticeInd
}

#[inline(always)]
fn len(&self) -> usize { Self::len(self) }
fn len_estimate(&self) -> usize { Self::len(self) }
}

impl<'a, K: Eq + std::hash::Hash + 'a, V: 'a + Clone> RelIndexReadAll<'a> for LatticeIndexType<K, V> {
Expand Down Expand Up @@ -154,7 +160,7 @@ where
}

#[inline(always)]
fn len(&self) -> usize { self.ind1.len() + self.ind2.len() }
fn len_estimate(&self) -> usize { self.ind1.len_estimate() + self.ind2.len_estimate() }
}

// impl <'a, Ind> RelIndexRead<'a> for RelIndexCombined<'a, Ind, Ind>
Expand Down
2 changes: 1 addition & 1 deletion ascent_base/src/lattice/ord_lattice.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::fmt::{Debug, Display, Formatter};
use std::fmt::{Debug, Formatter};

use crate::Lattice;

Expand Down
49 changes: 35 additions & 14 deletions ascent_macro/src/ascent_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,29 @@ fn compile_mir_rule(rule: &MirRule, scc: &MirScc, mir: &AscentMir) -> proc_macro
} else {
0
};
compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0)
let rule_body_clauses = rule.body_items.iter().filter_map(|bi| bi.clause()).collect_vec();
let check_any_empty_rel_can_help = rule_body_clauses.len() > 1 && !(rule.simple_join_start_index.is_some() && rule_body_clauses.len() == 2);
let any_empty_rel_code = check_any_empty_rel_can_help.then(|| {
rule_body_clauses
.iter()
.map(|bclause| {
let rel_expr = expr_for_rel(&bclause.rel, mir);
quote_spanned! { bclause.rel_args_span=> #rel_expr.is_empty() }
})
.reduce(|l, r| quote! { #l || #r})
}).flatten();

let rule_compiled = compile_mir_rule_inner(rule, scc, mir, par_iter_to_ind, head_update_code, 0);
if let Some(check_any_empty_code) = any_empty_rel_code {
quote! {
let any_rel_empty = #check_any_empty_code;
if !any_rel_empty {
#rule_compiled
}
}
} else {
rule_compiled
}
}

fn compile_mir_rule_inner(
Expand All @@ -844,20 +866,19 @@ fn compile_mir_rule_inner(
let rule_cp2_compiled =
compile_mir_rule_inner(&rule_cp2, _scc, mir, par_iter_to_ind, head_update_code, clause_ind);

if let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind + 2] {
let rel1_var_name = expr_for_rel(&bcl1.rel, mir);
let rel2_var_name = expr_for_rel(&bcl2.rel, mir);

return quote_spanned! {bcl1.rel_args_span=>
if #rel1_var_name.len() <= #rel2_var_name.len() {
#rule_cp1_compiled
} else {
#rule_cp2_compiled
}
};
} else {
let [MirBodyItem::Clause(bcl1), MirBodyItem::Clause(bcl2)] = &rule.body_items[clause_ind..clause_ind + 2] else {
panic!("unexpected body items in reorderable rule")
}
};
let rel1_var_name = expr_for_rel(&bcl1.rel, mir);
let rel2_var_name = expr_for_rel(&bcl2.rel, mir);

return quote_spanned! {bcl1.rel_args_span=>
if #rel1_var_name.len_estimate() <= #rel2_var_name.len_estimate() {
#rule_cp1_compiled
} else {
#rule_cp2_compiled
}
};
}
if clause_ind < rule.body_items.len() {
let bitem = &rule.body_items[clause_ind];
Expand Down
7 changes: 7 additions & 0 deletions ascent_macro/src/ascent_mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ impl MirBodyItem {
}
}

pub fn clause(&self) -> Option<&MirBodyClause> {
match self {
MirBodyItem::Clause(mir_body_clause) => Some(mir_body_clause),
_ => None,
}
}

pub fn bound_vars(&self) -> Vec<Ident> {
match self {
MirBodyItem::Clause(cl) => {
Expand Down
20 changes: 20 additions & 0 deletions ascent_tests/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -972,3 +972,23 @@ fn test_ds_attr() {

assert_rels_eq!(res.bar, [(0, 1)]);
}

#[test]
fn test_rel_empty_check() {
let res = ascent_run_m_par! {
relation edge(i32, i32);
relation path(i32, i32);
relation legit(i32);

path(x, z) <-- edge(x, y), path(y, z), legit(x);
path(x, y) <-- edge(x, y), legit(*&x);

legit(0);
legit(y) <-- legit(x), path(x, y);

edge(x, x + 1) <-- for x in 0..9;
};

println!("{:?}", res.path);
assert_eq!(res.path.len(), 9 * 10 / 2);
}
8 changes: 4 additions & 4 deletions byods/ascent-byods-rels/src/adaptor/bin_rel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0<'a, TBinRel>
Some(res)
}

fn len(&'a self) -> usize { self.0.ind0_len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.ind0_len_estimate() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0<'a, TBinRel> {
Expand Down Expand Up @@ -99,7 +99,7 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd1<'a, TBinRel>
Some(res)
}

fn len(&'a self) -> usize { self.0.ind1_len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.ind1_len_estimate() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd1<'a, TBinRel> {
Expand Down Expand Up @@ -136,7 +136,7 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelInd0_1<'a, TBinRe
if self.0.contains(&key.0, &key.1) { Some(once(())) } else { None }
}

fn len(&'a self) -> usize { self.0.len_estimate() }
fn len_estimate(&'a self) -> usize { self.0.len_estimate() }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelInd0_1<'a, TBinRel> {
Expand Down Expand Up @@ -200,7 +200,7 @@ impl<'a, TBinRel: ByodsBinRel> RelIndexRead<'a> for ByodsBinRelIndNone<'a, TBinR
Some(IteratorFromDyn::new(res))
}

fn len(&'a self) -> usize { 1 }
fn len_estimate(&'a self) -> usize { 1 }
}

impl<'a, TBinRel: ByodsBinRel> RelIndexReadAll<'a> for ByodsBinRelIndNone<'a, TBinRel> {
Expand Down
16 changes: 8 additions & 8 deletions byods/ascent-byods-rels/src/adaptor/bin_rel_to_ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ where
Some(IteratorFromDyn::new(|| trrel.iter_all()))
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 4;
let sum = self.0.map.values().map(|x| x.len_estimate()).sum::<usize>();
sum * self.0.map.len() / sample_size.min(self.0.map.len()).max(1)
Expand Down Expand Up @@ -189,7 +189,7 @@ where
Some(res)
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind0_len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
Expand Down Expand Up @@ -248,7 +248,7 @@ where
Some(res)
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|trrel| trrel.ind1_len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
Expand Down Expand Up @@ -314,7 +314,7 @@ where

fn index_get(&'a self, (x1,): &Self::Key) -> Option<Self::IteratorType> { self.get(x1) }

fn len(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() }
fn len_estimate(&self) -> usize { self.0.reverse_map1.as_ref().unwrap().len() }
}

pub struct BinRelToTernaryInd2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -374,7 +374,7 @@ where

fn index_get(&'a self, (x2,): &Self::Key) -> Option<Self::IteratorType> { self.get(x2) }

fn len(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() }
fn len_estimate(&self) -> usize { self.0.reverse_map2.as_ref().unwrap().len() }
}

pub struct BinRelToTernaryInd1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -433,7 +433,7 @@ where
Some(IteratorFromDyn::new(res))
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
// TODO random estimate, could be very wrong
self.0.reverse_map1.as_ref().unwrap().len() * self.0.reverse_map2.as_ref().unwrap().len()
/ ((self.0.map.len() as f32).sqrt() as usize)
Expand Down Expand Up @@ -480,7 +480,7 @@ where
Some(IteratorFromDyn::new(res))
}

fn len(&self) -> usize { 1 }
fn len_estimate(&self) -> usize { 1 }
}

pub struct BinRelToTernaryInd0_1_2<'a, T0, T1, T2, TBinRel>(&'a BinRelToTernary<T0, T1, T2, TBinRel>)
Expand Down Expand Up @@ -544,7 +544,7 @@ where
if self.0.map.get(x0)?.contains(x1, x2) { Some(once(())) } else { None }
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let sample_size = 3;
let sum = self.0.map.values().take(sample_size).map(|rel| rel.len_estimate()).sum::<usize>();
let map_len = self.0.map.len();
Expand Down
4 changes: 2 additions & 2 deletions byods/ascent-byods-rels/src/binary_rel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for MapRelIndexAdaptor<'a, T> {
Some(res)
}

fn len(&'a self) -> usize { self.0.len() }
fn len_estimate(&'a self) -> usize { self.0.len() }
}

pub struct RelIndexValTransformer<T, F> {
Expand All @@ -143,7 +143,7 @@ where
Some(res)
}

fn len(&'a self) -> usize { self.rel.len() }
fn len_estimate(&'a self) -> usize { self.rel.len_estimate() }
}

impl<'a, T: 'a, F: 'a, V: 'a, U: 'a> RelIndexReadAll<'a> for RelIndexValTransformer<T, F>
Expand Down
8 changes: 4 additions & 4 deletions byods/ascent-byods-rels/src/ceqrel_ind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0_1<'a, T> {

fn index_get(&'a self, key: &Self::Key) -> Option<Self::IteratorType> { self.0.index_get(key) }

fn len(&self) -> usize { self.0.len() }
fn len_estimate(&self) -> usize { self.0.len_estimate() }
}

impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelInd0_1<'a, T> {
Expand Down Expand Up @@ -199,7 +199,7 @@ impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelInd0<'a, T> {
Some(IteratorFromDyn::new(producer))
}

fn len(&self) -> usize { self.0.unwrap_frozen().combined.elem_ids.len() }
fn len_estimate(&self) -> usize { self.0.unwrap_frozen().combined.elem_ids.len() }
}

impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelInd0<'a, T> {
Expand Down Expand Up @@ -430,7 +430,7 @@ impl<'a, T: Clone + Hash + Eq + 'a> RelIndexRead<'a> for CEqRelIndCommon<T> {
if self_.combined.contains(x, y) && !self_.old.contains(x, y) { Some(std::iter::once(())) } else { None }
}

fn len(&self) -> usize {
fn len_estimate(&self) -> usize {
let self_ = self.unwrap_frozen();
let sample_size = 3;
let sum: usize = self_.combined.sets.iter().take(sample_size).map(|s| s.len().pow(2)).sum();
Expand Down Expand Up @@ -540,7 +540,7 @@ impl<'a, T: Clone + Hash + Eq> RelIndexRead<'a> for EqRelIndNone<'a, T> {
Some(IteratorFromDyn::new(|| self.0.iter_all_added()))
}

fn len(&self) -> usize { 1 }
fn len_estimate(&self) -> usize { 1 }
}

impl<'a, T: Clone + Hash + Eq + Sync> CRelIndexRead<'a> for EqRelIndNone<'a, T> {
Expand Down
Loading

0 comments on commit 4aff71c

Please sign in to comment.