Skip to content

Commit

Permalink
Implement cost model
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Apr 14, 2024
1 parent d5ec258 commit d3516a6
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 73 deletions.
137 changes: 73 additions & 64 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ fn get_node_cost(
cid: &ClassId,
// non-empty cost sets for children eclasses
child_cost_sets: &[&CostSet],
cm: &CostModel,
cm: &impl CostModel,
termdag: &mut TermDag,
) -> CostSet {
// cost is 0 unless otherwise specified
let op_cost = cm.ops.get(op).copied().unwrap_or(NotNan::new(0.).unwrap());
let op_cost = cm.get_op_cost(op);
let term = get_term(op, child_cost_sets, termdag);
if cm.ignored.contains(op) {
if cm.ignore_children(op) {
return CostSet {
total: op_cost,
costs: [(cid.clone(), op_cost)].into(),
Expand All @@ -119,9 +119,8 @@ fn get_node_cost(
let mut resulting_set = HashMap::<ClassId, Cost>::new();
let mut resulting_total = NotNan::new(0.).unwrap();

let unshared_default = vec![];
let unshared_children = cm.regions.get(op).unwrap_or(&unshared_default);
if !cm.ignored.contains(op) {
let unshared_children = cm.unshared_children(op);
if !cm.ignore_children(op) {
for (i, child_set) in child_cost_sets.iter().enumerate() {
if unshared_children.contains(&i) {
// don't add to the cost set, but do add to the total
Expand Down Expand Up @@ -154,7 +153,7 @@ fn calculate_cost_set(
node_id: NodeId,
costs: &FxHashMap<ClassId, CostSet>,
termdag: &mut TermDag,
cm: &CostModel,
cm: &impl CostModel,
) -> CostSet {
let node = &egraph[&node_id];
let cid = egraph.nid_to_cid(&node_id);
Expand Down Expand Up @@ -205,7 +204,7 @@ pub fn extract(
// it only checks unextractable at the function level.
unextractables: HashSet<String>,
termdag: &mut TermDag,
cm: &CostModel,
cm: impl CostModel,
) -> HashMap<ClassId, CostSet> {
let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
let parents = build_parent_index(egraph);
Expand All @@ -228,7 +227,7 @@ pub fn extract(
prev_cost = lookup.unwrap().total;
}

let cost_set = calculate_cost_set(egraph, node_id.clone(), &costs, termdag, cm);
let cost_set = calculate_cost_set(egraph, node_id.clone(), &costs, termdag, &cm);
if cost_set.total < prev_cost {
costs.insert(class_id.clone(), cost_set);
worklist.extend(parents[class_id].iter().cloned());
Expand All @@ -246,63 +245,73 @@ pub fn extract(
.collect()
}

pub struct CostModel {
ops: HashMap<&'static str, Cost>,
// Children of these constructors are ignored
ignored: HashSet<&'static str>,
// for each region node, regions[region] is a list of
// children that should not be shared.
regions: HashMap<&'static str, Vec<usize>>,
/// Cost policy used by the extractor in this file.
///
/// Every method is keyed by `op`, the operator name of the node being costed.
pub trait CostModel {
// TODO: we could do better with type info
/// Returns the cost charged for the operator `op` itself.
fn get_op_cost(&self, op: &str) -> Cost;

/// If true, the op's children are ignored: `get_node_cost` returns a cost
/// set containing only the op's own cost.
fn ignore_children(&self, op: &str) -> bool;

/// Returns a slice of indices into the children vec.
///
/// In `get_node_cost`, a child at one of these indices is added to the
/// total but not to the cost set.
fn unshared_children(&self, op: &str) -> &'static [usize];
}

impl CostModel {
pub fn simple_cost_model() -> CostModel {
let ops = vec![
// ========== Leaf operators ==========
// Bop
// TODO: actually we also need type info
// to figure out the cost
("Add", 1.),
("Sub", 1.),
("Mul", 1.),
("Div", 1.),
("Eq", 1.),
("LessThan", 1.),
("GreaterThan", 1.),
("LessEq", 1.),
("GreaterEq", 1.),
("And", 1.),
("Or", 1.),
("PtrAdd", 1.),
("Print", 1.),
("Load", 1.),
("Free", 1.),
// Uop
("Not", 1.),
// Top
("Write", 1.),
// ========== Non-leaf operators ==========
("Alloc", 100.),
// TODO: The cost of Call is more complicated than that.
// Call
("Call", 10.),
];
let ignored = HashSet::from(["InLoop", "InFunc", "InSwitch", "InIf"]);
let ops: HashMap<_, _> = ops
.into_iter()
.map(|(op, cost)| (op, NotNan::new(cost).unwrap()))
.collect();
let regions = HashMap::from([
("DoWhile", vec![1]),
("Function", vec![3]),
("If", vec![2, 3]),
// TODO this doesn't support Switch properly- branches share nodes
("Switch", vec![2]),
]);
CostModel {
ops,
ignored,
regions,
/// Stateless cost model whose costs, ignored ops and unshared-child indices
/// are all hard-coded in the `match` tables below.
pub struct DefaultCostModel;

impl CostModel for DefaultCostModel {
/// Looks up the fixed cost of `op`.
///
/// # Panics
///
/// Panics with `no cost for {op}` when `op` matches none of the arms below.
fn get_op_cost(&self, op: &str) -> Cost {
match op {
// Leaves
"Const" => 1.,
"Arg" => 0.,
// Any op that parses as an i64 or begins with a double quote is free
// (presumably integer and string literals; confirm against the serializer).
_ if op.parse::<i64>().is_ok() || op.starts_with('"') => 0.,
"true" | "false" | "()" => 0.,
// Lists
"Empty" | "Single" | "Concat" | "Get" | "Nil" | "Cons" => 0.,
// Types
"IntT" | "BoolT" | "PointerT" | "StateT" => 0.,
"Base" | "TupleT" | "TNil" | "TCons" => 0.,
"Int" | "Bool" => 0.,
// Algebra
"Add" | "PtrAdd" | "Sub" | "And" | "Or" | "Not" => 10.,
"Mul" => 30.,
"Div" => 50.,
// Comparisons
"Eq" | "LessThan" | "GreaterThan" | "LessEq" | "GreaterEq" => 10.,
// Effects
"Print" | "Write" | "Load" => 50.,
"Alloc" | "Free" => 100.,
"Call" => 1000., // TODO: we could make this more accurate
// Control
"Program" | "Function" => 1.,
"DoWhile" => 100., // TODO: we could make this more accurate
"If" | "Switch" => 50.,
// Unreachable
"HasType" | "HasArgType" | "ContextOf" | "NoContext" | "ExpectType" => 0.,
"ExprIsPure" | "ListExprIsPure" | "BinaryOpIsPure" | "UnaryOpIsPure" => 0.,
"IsLeaf" | "BodyContainsExpr" | "ScopeContext" => 0.,
"Region" | "Full" | "IntI" | "BoolI" => 0.,
// Schema
"Bop" | "Uop" | "Top" => 0.,
"InContext" => 0.,
// Ops whose children are ignored are free unless an earlier arm already
// priced them ("NoContext" is matched above).
_ if self.ignore_children(op) => 0.,
_ => panic!("no cost for {op}"),
}
// Convert the float picked above into `Cost`; every value in the table is a
// finite literal (presumably the conversion only rejects NaN; confirm).
.try_into()
.unwrap()
}

/// Returns true for the ops whose children are ignored.
// NOTE(review): "InFunc" is not listed here and has no arm in `get_op_cost`,
// so it would reach the panic arm. Confirm it can no longer appear.
fn ignore_children(&self, op: &str) -> bool {
matches!(op, "InLoop" | "NoContext" | "InSwitch" | "InIf")
}

/// Returns the hard-coded unshared child indices for `op`; any op not
/// listed gets an empty slice.
fn unshared_children(&self, op: &str) -> &'static [usize] {
match op {
"DoWhile" => &[1],
"Function" => &[3],
"If" => &[2, 3],
"Switch" => &[2], // TODO: Switch branches can share nodes
_ => &[],
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion dag_in_context/src/interval_analysis.egg
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

; Interval Table
(function ival (Expr) Interval
:merge (interval-intersect old new))
:unextractable
:merge (interval-intersect old new))

; =================================
; Constants
Expand Down
9 changes: 2 additions & 7 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::HashMap;

use egglog::{Term, TermDag};
use from_egglog::FromEgglog;
use greedy_dag_extractor::{extract, serialized_egraph, CostModel};
use greedy_dag_extractor::{extract, serialized_egraph, DefaultCostModel};
use interpreter::Value;
use schema::TreeProgram;
use std::fmt::Write;
Expand Down Expand Up @@ -108,12 +108,7 @@ pub fn optimize(program: &TreeProgram) -> std::result::Result<TreeProgram, egglo

let (serialized, unextractables) = serialized_egraph(egraph);
let mut termdag = egglog::TermDag::default();
let results = extract(
&serialized,
unextractables,
&mut termdag,
&CostModel::simple_cost_model(),
);
let results = extract(&serialized, unextractables, &mut termdag, DefaultCostModel);
assert_eq!(results.len(), 1);
let (_cid, costset) = results.into_iter().next().unwrap();
let mut from_egglog = FromEgglog {
Expand Down
2 changes: 1 addition & 1 deletion dag_in_context/src/type_analysis.egg
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
(TCons hd (TLConcat tl r))
:ruleset type-helpers)

(function TypeList-length (TypeList) i64)
(function TypeList-length (TypeList) i64 :unextractable)
(function TypeList-ith (TypeList i64) BaseType :unextractable)
(function TypeList-suffix (TypeList i64) TypeList :unextractable)

Expand Down

0 comments on commit d3516a6

Please sign in to comment.