Skip to content

Commit

Permalink
Use an incremental rule scheduler for egg
Browse files Browse the repository at this point in the history
  • Loading branch information
mwillsey committed Jun 10, 2024
1 parent d57d21b commit 4062cc5
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 72 deletions.
2 changes: 1 addition & 1 deletion rust/cubesql/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion rust/cubesql/cubesql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ nanoid = "0.3.0"
tokio-util = { version = "0.6.2", features=["compat"] }
comfy-table = "7.1.0"
bitflags = "1.3.2"
egg = { rev = "58c2586473360f0821e91ef196b55070ac1afedc", git = "https://github.com/cube-js/egg.git" }
egg = { rev = "952f8c2a1033e5da097d23c523b0d8e392eb532b", git = "https://github.com/cube-js/egg.git" }
paste = "1.0.6"
csv = "1.1.6"
tracing = { version = "0.1.40", features = ["async-await"] }
Expand Down
21 changes: 20 additions & 1 deletion rust/cubesql/cubesql/src/compile/rewrite/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ use datafusion::{
scalar::ScalarValue,
};
use egg::{Analysis, DidMerge, EGraph, Id};
use std::{fmt::Debug, ops::Index, sync::Arc};
use std::{cmp::Ordering, fmt::Debug, ops::Index, sync::Arc};

pub type MemberNameToExpr = (Option<String>, Member, Expr);

#[derive(Clone, Debug)]
pub struct LogicalPlanData {
pub time: usize,
pub original_expr: Option<OriginalExpr>,
pub member_name_to_expr: Option<Vec<MemberNameToExpr>>,
pub trivial_push_down: Option<usize>,
Expand Down Expand Up @@ -221,6 +222,7 @@ impl Member {

#[derive(Clone)]
pub struct LogicalPlanAnalysis {
pub time: usize,
cube_context: Arc<CubeContext>,
planner: Arc<DefaultPhysicalPlanner>,
}
Expand Down Expand Up @@ -252,6 +254,7 @@ impl<'a> Index<Id> for SingleNodeIndex<'a> {
impl LogicalPlanAnalysis {
pub fn new(cube_context: Arc<CubeContext>, planner: Arc<DefaultPhysicalPlanner>) -> Self {
Self {
time: 0,
cube_context,
planner,
}
Expand Down Expand Up @@ -1221,6 +1224,17 @@ impl LogicalPlanAnalysis {

res
}

fn merge_max_field<T: Ord>(&mut self, a: &mut T, mut b: T) -> DidMerge {
match Ord::cmp(a, &mut b) {
Ordering::Less => {
*a = b;
DidMerge(true, false)
}
Ordering::Equal => DidMerge(false, false),
Ordering::Greater => DidMerge(false, true),
}
}
}

impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
Expand All @@ -1231,6 +1245,7 @@ impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
enode: &LogicalPlanLanguage,
) -> Self::Data {
LogicalPlanData {
time: egraph.analysis.time,
original_expr: Self::make_original_expr(egraph, enode),
member_name_to_expr: Self::make_member_name_to_expr(egraph, enode),
trivial_push_down: Self::make_trivial_push_down(egraph, enode),
Expand Down Expand Up @@ -1270,6 +1285,7 @@ impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
| column_name
| filter_operators
| is_empty_list
| self.merge_max_field(&mut a.time, b.time)
}

fn modify(egraph: &mut EGraph<LogicalPlanLanguage, Self>, id: Id) {
Expand Down Expand Up @@ -1299,6 +1315,9 @@ impl Analysis<LogicalPlanLanguage> for LogicalPlanAnalysis {
)));
let alias_expr = egraph.add(LogicalPlanLanguage::AliasExpr([literal_expr, alias]));
egraph.union(id, alias_expr);
// egraph[id]
// .nodes
// .retain(|n| matches!(n, LogicalPlanLanguage::AliasExpr(_)));
}
}
}
Expand Down
104 changes: 36 additions & 68 deletions rust/cubesql/cubesql/src/compile/rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use std::{
borrow::Cow,
fmt::{self, Display, Formatter},
ops::Index,
slice::Iter,
str::FromStr,
sync::Arc,
};
Expand Down Expand Up @@ -943,14 +942,16 @@ impl Searcher<LogicalPlanLanguage, LogicalPlanAnalysis> for ListNodeSearcher {
(!matches.substs.is_empty()).then(|| matches)
}

fn search_with_limit(
fn search_eclasses_with_limit(
&self,
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
eclasses: &mut dyn Iterator<Item = Id>,
limit: usize,
) -> Vec<SearchMatches<LogicalPlanLanguage>> {
let mut result: Vec<SearchMatches<_>> = vec![];

self.list_pattern
.search_with_fn(egraph, |id, list_subst| {
.search_eclasses_with_fn(egraph, eclasses, |id, list_subst| {
let last = match result.last_mut() {
Some(top) if top.eclass == id => top,
_ => {
Expand Down Expand Up @@ -2056,78 +2057,24 @@ pub fn original_expr_name(
})
}

fn search_match_chained<'a>(
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
cur_match: SearchMatches<'a, LogicalPlanLanguage>,
chain: Iter<(Var, Pattern<LogicalPlanLanguage>)>,
) -> Option<SearchMatches<'a, LogicalPlanLanguage>> {
let mut chain = chain.clone();
let mut matches_to_merge = Vec::new();
if let Some((var, pattern)) = chain.next() {
for subst in cur_match.substs.iter() {
if let Some(id) = subst.get(var.clone()) {
if let Some(next_match) = pattern.search_eclass(egraph, id.clone()) {
let chain_matches = search_match_chained(
egraph,
SearchMatches {
eclass: cur_match.eclass.clone(),
substs: next_match
.substs
.iter()
.map(|next_subst| {
let mut new_subst = subst.clone();
for pattern_var in pattern.vars().into_iter() {
if let Some(pattern_var_value) = next_subst.get(pattern_var)
{
new_subst
.insert(pattern_var, pattern_var_value.clone());
}
}
new_subst
})
.collect::<Vec<_>>(),
// TODO merge
ast: cur_match.ast.clone(),
},
chain.clone(),
);
matches_to_merge.extend(chain_matches);
}
}
}
if !matches_to_merge.is_empty() {
let mut substs = Vec::new();
for m in matches_to_merge {
substs.extend(m.substs.clone());
}
Some(SearchMatches {
eclass: cur_match.eclass.clone(),
substs,
// TODO merge
ast: cur_match.ast.clone(),
})
} else {
None
}
} else {
Some(cur_match)
}
}

pub struct ChainSearcher {
main: Pattern<LogicalPlanLanguage>,
chain: Vec<(Var, Pattern<LogicalPlanLanguage>)>,
}

impl Searcher<LogicalPlanLanguage, LogicalPlanAnalysis> for ChainSearcher {
fn search(
fn search_eclasses_with_limit(
&self,
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
eclasses: &mut dyn Iterator<Item = Id>,
limit: usize,
) -> Vec<SearchMatches<LogicalPlanLanguage>> {
let matches = self.main.search(egraph);
let matches = self
.main
.search_eclasses_with_limit(egraph, eclasses, limit);
let mut result = Vec::new();
for m in matches {
if let Some(m) = self.search_match_chained(egraph, m, self.chain.iter()) {
if let Some(m) = self.search_match_chained(egraph, m) {
result.push(m);
}
}
Expand All @@ -2141,7 +2088,7 @@ impl Searcher<LogicalPlanLanguage, LogicalPlanAnalysis> for ChainSearcher {
limit: usize,
) -> Option<SearchMatches<LogicalPlanLanguage>> {
if let Some(m) = self.main.search_eclass_with_limit(egraph, eclass, limit) {
self.search_match_chained(egraph, m, self.chain.iter())
self.search_match_chained(egraph, m)
} else {
None
}
Expand All @@ -2160,10 +2107,31 @@ impl ChainSearcher {
fn search_match_chained<'a>(
&self,
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
cur_match: SearchMatches<'a, LogicalPlanLanguage>,
chain: Iter<(Var, Pattern<LogicalPlanLanguage>)>,
mut cur_match: SearchMatches<'a, LogicalPlanLanguage>,
) -> Option<SearchMatches<'a, LogicalPlanLanguage>> {
search_match_chained(egraph, cur_match, chain)
let mut new_substs = vec![];
for (var, pattern) in &self.chain {
assert!(new_substs.is_empty());
for subst in &cur_match.substs {
let eclass = subst[*var];
pattern
.search_eclass_with_fn(egraph, eclass, |chain_subst| {
let mut subst = subst.clone();
subst.extend(chain_subst.iter());
new_substs.push(subst);
Ok(())
})
.unwrap_or_default();
}
std::mem::swap(&mut new_substs, &mut cur_match.substs);
new_substs.clear();
}

if cur_match.substs.is_empty() {
None
} else {
Some(cur_match)
}
}
}

Expand Down
45 changes: 44 additions & 1 deletion rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ impl Rewriter {
.map(|v| v.parse::<u64>().unwrap())
.unwrap_or(30),
))
.with_scheduler(egg::SimpleScheduler)
.with_scheduler(IncrementalScheduler::default())
.with_hook(|runner| {
runner.egraph.analysis.time = runner.iterations.len();
Ok(())
})
.with_egraph(egraph)
}

Expand Down Expand Up @@ -555,3 +559,42 @@ impl Rewriter {
pub trait RewriteRules {
fn rewrite_rules(&self) -> Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>;
}

struct IncrementalScheduler {
current_iter: usize,
current_eclasses: Vec<Id>,
}

impl Default for IncrementalScheduler {
fn default() -> Self {
Self {
current_iter: usize::MAX, // force an update on the first iteration
current_eclasses: Default::default(),
}
}
}

impl egg::RewriteScheduler<LogicalPlanLanguage, LogicalPlanAnalysis> for IncrementalScheduler {
fn search_rewrite<'a>(
&mut self,
iteration: usize,
egraph: &EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>,
rewrite: &'a Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>,
) -> Vec<egg::SearchMatches<'a, LogicalPlanLanguage>> {
if iteration != self.current_iter {
self.current_iter = iteration;
self.current_eclasses.clear();
self.current_eclasses.extend(
egraph
.classes()
.filter_map(|class| (class.data.time + 1 >= iteration).then(|| class.id)),
);
};
assert_eq!(iteration, self.current_iter);
rewrite.searcher.search_eclasses_with_limit(
egraph,
&mut self.current_eclasses.iter().copied(),
usize::MAX,
)
}
}

0 comments on commit 4062cc5

Please sign in to comment.