From d3516a64e0dde4846b022310e5979198e4cc2aea Mon Sep 17 00:00:00 2001 From: Alex Fischman Date: Sun, 14 Apr 2024 14:37:35 -0700 Subject: [PATCH] Implement cost model --- dag_in_context/src/greedy_dag_extractor.rs | 137 +++++++++++---------- dag_in_context/src/interval_analysis.egg | 3 +- dag_in_context/src/lib.rs | 9 +- dag_in_context/src/type_analysis.egg | 2 +- 4 files changed, 78 insertions(+), 73 deletions(-) diff --git a/dag_in_context/src/greedy_dag_extractor.rs b/dag_in_context/src/greedy_dag_extractor.rs index c1eb2f45b..6b393005a 100644 --- a/dag_in_context/src/greedy_dag_extractor.rs +++ b/dag_in_context/src/greedy_dag_extractor.rs @@ -102,13 +102,13 @@ fn get_node_cost( cid: &ClassId, // non-empty cost sets for children eclasses child_cost_sets: &[&CostSet], - cm: &CostModel, + cm: &impl CostModel, termdag: &mut TermDag, ) -> CostSet { // cost is 0 unless otherwise specified - let op_cost = cm.ops.get(op).copied().unwrap_or(NotNan::new(0.).unwrap()); + let op_cost = cm.get_op_cost(op); let term = get_term(op, child_cost_sets, termdag); - if cm.ignored.contains(op) { + if cm.ignore_children(op) { return CostSet { total: op_cost, costs: [(cid.clone(), op_cost)].into(), @@ -119,9 +119,8 @@ fn get_node_cost( let mut resulting_set = HashMap::::new(); let mut resulting_total = NotNan::new(0.).unwrap(); - let unshared_default = vec![]; - let unshared_children = cm.regions.get(op).unwrap_or(&unshared_default); - if !cm.ignored.contains(op) { + let unshared_children = cm.unshared_children(op); + if !cm.ignore_children(op) { for (i, child_set) in child_cost_sets.iter().enumerate() { if unshared_children.contains(&i) { // don't add to the cost set, but do add to the total @@ -154,7 +153,7 @@ fn calculate_cost_set( node_id: NodeId, costs: &FxHashMap, termdag: &mut TermDag, - cm: &CostModel, + cm: &impl CostModel, ) -> CostSet { let node = &egraph[&node_id]; let cid = egraph.nid_to_cid(&node_id); @@ -205,7 +204,7 @@ pub fn extract( // it only checks unextractable at the function level. unextractables: HashSet, termdag: &mut TermDag, - cm: &CostModel, + cm: impl CostModel, ) -> HashMap { let n2c = |nid: &NodeId| egraph.nid_to_cid(nid); let parents = build_parent_index(egraph); @@ -228,7 +227,7 @@ pub fn extract( prev_cost = lookup.unwrap().total; } - let cost_set = calculate_cost_set(egraph, node_id.clone(), &costs, termdag, cm); + let cost_set = calculate_cost_set(egraph, node_id.clone(), &costs, termdag, &cm); if cost_set.total < prev_cost { costs.insert(class_id.clone(), cost_set); worklist.extend(parents[class_id].iter().cloned()); @@ -246,63 +245,73 @@ pub fn extract( .collect() } -pub struct CostModel { - ops: HashMap<&'static str, Cost>, - // Children of these constructors are ignored - ignored: HashSet<&'static str>, - // for each regon nodes, regions[region] is a list of - // children that should not be shared. - regions: HashMap<&'static str, Vec>, +pub trait CostModel { + // TODO: we could do better with type info + fn get_op_cost(&self, op: &str) -> Cost; + + // if true, the op's children are ignored + fn ignore_children(&self, op: &str) -> bool; + + // returns a slice of indices into the children vec + fn unshared_children(&self, op: &str) -> &'static [usize]; } -impl CostModel { - pub fn simple_cost_model() -> CostModel { - let ops = vec![ - // ========== Leaf operators ========== - // Bop - // TODO: actually we also need type info - // to figure out the cost - ("Add", 1.), - ("Sub", 1.), - ("Mul", 1.), - ("Div", 1.), - ("Eq", 1.), - ("LessThan", 1.), - ("GreaterThan", 1.), - ("LessEq", 1.), - ("GreaterEq", 1.), - ("And", 1.), - ("Or", 1.), - ("PtrAdd", 1.), - ("Print", 1.), - ("Load", 1.), - ("Free", 1.), - // Uop - ("Not", 1.), - // Top - ("Write", 1.), - // ========== Non-leaf operators ========== - ("Alloc", 100.), - // TODO: The cost of Call is more complicated than that. - // Call - ("Call", 10.), - ]; - let ignored = HashSet::from(["InLoop", "InFunc", "InSwitch", "InIf"]); - let ops: HashMap<_, _> = ops - .into_iter() - .map(|(op, cost)| (op, NotNan::new(cost).unwrap())) - .collect(); - let regions = HashMap::from([ - ("DoWhile", vec![1]), - ("Function", vec![3]), - ("If", vec![2, 3]), - // TODO this doesn't support Switch properly- branches share nodes - ("Switch", vec![2]), - ]); - CostModel { - ops, - ignored, - regions, +pub struct DefaultCostModel; + +impl CostModel for DefaultCostModel { + fn get_op_cost(&self, op: &str) -> Cost { + match op { + // Leaves + "Const" => 1., + "Arg" => 0., + _ if op.parse::().is_ok() || op.starts_with('"') => 0., + "true" | "false" | "()" => 0., + // Lists + "Empty" | "Single" | "Concat" | "Get" | "Nil" | "Cons" => 0., + // Types + "IntT" | "BoolT" | "PointerT" | "StateT" => 0., + "Base" | "TupleT" | "TNil" | "TCons" => 0., + "Int" | "Bool" => 0., + // Algebra + "Add" | "PtrAdd" | "Sub" | "And" | "Or" | "Not" => 10., + "Mul" => 30., + "Div" => 50., + // Comparisons + "Eq" | "LessThan" | "GreaterThan" | "LessEq" | "GreaterEq" => 10., + // Effects + "Print" | "Write" | "Load" => 50., + "Alloc" | "Free" => 100., + "Call" => 1000., // TODO: we could make this more accurate + // Control + "Program" | "Function" => 1., + "DoWhile" => 100., // TODO: we could make this more accurate + "If" | "Switch" => 50., + // Unreachable + "HasType" | "HasArgType" | "ContextOf" | "NoContext" | "ExpectType" => 0., + "ExprIsPure" | "ListExprIsPure" | "BinaryOpIsPure" | "UnaryOpIsPure" => 0., + "IsLeaf" | "BodyContainsExpr" | "ScopeContext" => 0., + "Region" | "Full" | "IntI" | "BoolI" => 0., + // Schema + "Bop" | "Uop" | "Top" => 0., + "InContext" => 0., + _ if self.ignore_children(op) => 0., + _ => panic!("no cost for {op}"), + } + .try_into() + .unwrap() + } + + fn ignore_children(&self, op: &str) -> bool { + matches!(op, "InLoop" | "NoContext" | "InSwitch" | "InIf") + } + + fn unshared_children(&self, op: &str) -> &'static [usize] { + match op { + "DoWhile" => &[1], + "Function" => &[3], + "If" => &[2, 3], + "Switch" => &[2], // TODO: Switch branches can share nodes + _ => &[], } } } diff --git a/dag_in_context/src/interval_analysis.egg b/dag_in_context/src/interval_analysis.egg index 8192a70ad..284781b83 100644 --- a/dag_in_context/src/interval_analysis.egg +++ b/dag_in_context/src/interval_analysis.egg @@ -24,7 +24,8 @@ ; Interval Table (function ival (Expr) Interval - :merge (interval-intersect old new)) + :unextractable + :merge (interval-intersect old new)) ; ================================= ; Constants diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index 4377258f1..19258c052 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use egglog::{Term, TermDag}; use from_egglog::FromEgglog; -use greedy_dag_extractor::{extract, serialized_egraph, CostModel}; +use greedy_dag_extractor::{extract, serialized_egraph, DefaultCostModel}; use interpreter::Value; use schema::TreeProgram; use std::fmt::Write; @@ -108,12 +108,7 @@ pub fn optimize(program: &TreeProgram) -> std::result::Result