diff --git a/dag_in_context/src/greedy_dag_extractor.rs b/dag_in_context/src/greedy_dag_extractor.rs index c1eb2f45b..98e96789d 100644 --- a/dag_in_context/src/greedy_dag_extractor.rs +++ b/dag_in_context/src/greedy_dag_extractor.rs @@ -3,7 +3,7 @@ use egraph_serialize::{ClassId, NodeId}; use indexmap::*; use ordered_float::NotNan; use rustc_hash::FxHashMap; -use std::collections::{HashMap, HashSet}; +use std::collections::{HashMap, HashSet, VecDeque}; pub fn serialized_egraph( egglog_egraph: egglog::EGraph, @@ -96,7 +96,8 @@ fn get_term(op: &str, cost_sets: &[&CostSet], termdag: &mut TermDag) -> Term { /// Given an operator, eclass, and cost sets for children eclasses, /// calculate the new cost set for this operator. -/// This is done by unioning the child costs sets and summing them up, except for special cases like regions. +/// This is done by unioning the child costs sets and summing them up, +/// except for special cases like regions. fn get_node_cost( op: &str, cid: &ClassId, @@ -105,48 +106,34 @@ fn get_node_cost( cm: &CostModel, termdag: &mut TermDag, ) -> CostSet { - // cost is 0 unless otherwise specified - let op_cost = cm.ops.get(op).copied().unwrap_or(NotNan::new(0.).unwrap()); + let mut total = cm.ops.get(op).copied().unwrap_or(NotNan::new(1.).unwrap()); + let mut costs = HashMap::from([(cid.clone(), total)]); let term = get_term(op, child_cost_sets, termdag); - if cm.ignored.contains(op) { - return CostSet { - total: op_cost, - costs: [(cid.clone(), op_cost)].into(), - term, - }; - } - let mut resulting_set = HashMap::::new(); - let mut resulting_total = NotNan::new(0.).unwrap(); - - let unshared_default = vec![]; + let unshared_default: &[usize] = &[]; let unshared_children = cm.regions.get(op).unwrap_or(&unshared_default); if !cm.ignored.contains(op) { for (i, child_set) in child_cost_sets.iter().enumerate() { if unshared_children.contains(&i) { // don't add to the cost set, but do add to the total - resulting_total += child_set.total; + total += child_set.total; } else { for (child_cid, child_cost) in &child_set.costs { // it was already present in the set - if let Some(existing) = resulting_set.insert(child_cid.clone(), *child_cost) { + if let Some(existing) = costs.insert(child_cid.clone(), *child_cost) { assert_eq!( existing, *child_cost, "Two different costs found for the same child enode!" ); } else { - resulting_total += child_cost; + total += child_cost; } } } } } - CostSet { - total: resulting_total, - costs: resulting_set, - term, - } + CostSet { total, costs, term } } fn calculate_cost_set( @@ -252,7 +239,7 @@ pub struct CostModel { ignored: HashSet<&'static str>, // for each regon nodes, regions[region] is a list of // children that should not be shared. - regions: HashMap<&'static str, Vec>, + regions: HashMap<&'static str, &'static [usize]>, } impl CostModel { @@ -287,18 +274,22 @@ impl CostModel { // Call ("Call", 10.), ]; - let ignored = HashSet::from(["InLoop", "InFunc", "InSwitch", "InIf"]); let ops: HashMap<_, _> = ops .into_iter() .map(|(op, cost)| (op, NotNan::new(cost).unwrap())) .collect(); + + let ignored = HashSet::from(["InLoop", "InFunc", "InSwitch", "InIf"]); + + let do_while_regions: &[usize] = &[1]; // needed for type inference let regions = HashMap::from([ - ("DoWhile", vec![1]), - ("Function", vec![3]), - ("If", vec![2, 3]), + ("DoWhile", do_while_regions), + ("Function", &[3]), + ("If", &[2, 3]), // TODO this doesn't support Switch properly- branches share nodes - ("Switch", vec![2]), + ("Switch", &[2]), ]); + CostModel { ops, ignored, @@ -319,7 +310,7 @@ where T: Eq + std::hash::Hash + Clone, { set: HashSet, - queue: std::collections::VecDeque, + queue: VecDeque, } impl Default for UniqueQueue @@ -329,7 +320,7 @@ where fn default() -> Self { UniqueQueue { set: Default::default(), - queue: std::collections::VecDeque::new(), + queue: Default::default(), } } } diff --git a/dag_in_context/src/schema.egg b/dag_in_context/src/schema.egg index 2e380b1b8..0c23b324a 100644 --- a/dag_in_context/src/schema.egg +++ b/dag_in_context/src/schema.egg @@ -130,10 +130,9 @@ ; ================================= ; Switch on a list of lazily-evaluated branches. -; Does not create a region. ; pred must be an integer -; pred Expr branches chosen -(function Switch (Expr Expr ListExpr) Expr) +; pred inputs branches chosen +(function Switch (Expr Expr ListExpr) Expr) ; If is like switch, but with a boolean predicate ; pred inputs then else (function If (Expr Expr Expr Expr) Expr) diff --git a/dag_in_context/src/utility/in_context.egg b/dag_in_context/src/utility/in_context.egg index 7bbef64f9..5267de4f9 100644 --- a/dag_in_context/src/utility/in_context.egg +++ b/dag_in_context/src/utility/in_context.egg @@ -41,6 +41,8 @@ (Full) ;; Don't make new contexts for sub-regions (Region)) + +;; Add these to the egraph so we can match on them (Full) (Region) diff --git a/dag_in_context/src/utility/subst.egg b/dag_in_context/src/utility/subst.egg index 844093272..411920799 100644 --- a/dag_in_context/src/utility/subst.egg +++ b/dag_in_context/src/utility/subst.egg @@ -8,7 +8,7 @@ (let inf-fuel 1000000) -;; (Subst fuel assumption to in) substitutes to for `(Arg ty)` in `in`. +;; (Subst fuel assumption to in) substitutes `to` for `(Arg ty)` in `in`. ;; It also replaces any contexts found by updating them to `assumption`. ;; `assumption` *justifies* this substitution, as the context that the result is used in. ;; In other words, it must refine the equivalence relation of `in` with `to` as the argument.