Skip to content

Commit

Permalink
Merge pull request #455 from egraphs-good/fix-get-node-cost-bug
Browse files Browse the repository at this point in the history
Fix get node cost bug
  • Loading branch information
oflatt authored Apr 15, 2024
2 parents d5ec258 + 43c49fd commit 3a99f6b
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 35 deletions.
53 changes: 22 additions & 31 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use egraph_serialize::{ClassId, NodeId};
use indexmap::*;
use ordered_float::NotNan;
use rustc_hash::FxHashMap;
use std::collections::{HashMap, HashSet};
use std::collections::{HashMap, HashSet, VecDeque};

pub fn serialized_egraph(
egglog_egraph: egglog::EGraph,
Expand Down Expand Up @@ -96,7 +96,8 @@ fn get_term(op: &str, cost_sets: &[&CostSet], termdag: &mut TermDag) -> Term {

/// Given an operator, eclass, and cost sets for children eclasses,
/// calculate the new cost set for this operator.
/// This is done by unioning the child costs sets and summing them up, except for special cases like regions.
/// This is done by unioning the child costs sets and summing them up,
/// except for special cases like regions.
fn get_node_cost(
op: &str,
cid: &ClassId,
Expand All @@ -105,48 +106,34 @@ fn get_node_cost(
cm: &CostModel,
termdag: &mut TermDag,
) -> CostSet {
// cost is 0 unless otherwise specified
let op_cost = cm.ops.get(op).copied().unwrap_or(NotNan::new(0.).unwrap());
let mut total = cm.ops.get(op).copied().unwrap_or(NotNan::new(1.).unwrap());
let mut costs = HashMap::from([(cid.clone(), total)]);
let term = get_term(op, child_cost_sets, termdag);
if cm.ignored.contains(op) {
return CostSet {
total: op_cost,
costs: [(cid.clone(), op_cost)].into(),
term,
};
}

let mut resulting_set = HashMap::<ClassId, Cost>::new();
let mut resulting_total = NotNan::new(0.).unwrap();

let unshared_default = vec![];
let unshared_default: &[usize] = &[];
let unshared_children = cm.regions.get(op).unwrap_or(&unshared_default);
if !cm.ignored.contains(op) {
for (i, child_set) in child_cost_sets.iter().enumerate() {
if unshared_children.contains(&i) {
// don't add to the cost set, but do add to the total
resulting_total += child_set.total;
total += child_set.total;
} else {
for (child_cid, child_cost) in &child_set.costs {
// it was already present in the set
if let Some(existing) = resulting_set.insert(child_cid.clone(), *child_cost) {
if let Some(existing) = costs.insert(child_cid.clone(), *child_cost) {
assert_eq!(
existing, *child_cost,
"Two different costs found for the same child enode!"
);
} else {
resulting_total += child_cost;
total += child_cost;
}
}
}
}
}

CostSet {
total: resulting_total,
costs: resulting_set,
term,
}
CostSet { total, costs, term }
}

fn calculate_cost_set(
Expand Down Expand Up @@ -252,7 +239,7 @@ pub struct CostModel {
ignored: HashSet<&'static str>,
// for each regon nodes, regions[region] is a list of
// children that should not be shared.
regions: HashMap<&'static str, Vec<usize>>,
regions: HashMap<&'static str, &'static [usize]>,
}

impl CostModel {
Expand Down Expand Up @@ -287,18 +274,22 @@ impl CostModel {
// Call
("Call", 10.),
];
let ignored = HashSet::from(["InLoop", "InFunc", "InSwitch", "InIf"]);
let ops: HashMap<_, _> = ops
.into_iter()
.map(|(op, cost)| (op, NotNan::new(cost).unwrap()))
.collect();

let ignored = HashSet::from(["InLoop", "InFunc", "InSwitch", "InIf"]);

let do_while_regions: &[usize] = &[1]; // needed for type inference
let regions = HashMap::from([
("DoWhile", vec![1]),
("Function", vec![3]),
("If", vec![2, 3]),
("DoWhile", do_while_regions),
("Function", &[3]),
("If", &[2, 3]),
// TODO this doesn't support Switch properly- branches share nodes
("Switch", vec![2]),
("Switch", &[2]),
]);

CostModel {
ops,
ignored,
Expand All @@ -319,7 +310,7 @@ where
T: Eq + std::hash::Hash + Clone,
{
set: HashSet<T>,
queue: std::collections::VecDeque<T>,
queue: VecDeque<T>,
}

impl<T> Default for UniqueQueue<T>
Expand All @@ -329,7 +320,7 @@ where
fn default() -> Self {
UniqueQueue {
set: Default::default(),
queue: std::collections::VecDeque::new(),
queue: Default::default(),
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions dag_in_context/src/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,9 @@
; =================================

; Switch on a list of lazily-evaluated branches.
; Does not create a region.
; pred must be an integer
; pred Expr branches chosen
(function Switch (Expr Expr ListExpr) Expr)
; pred inputs branches chosen
(function Switch (Expr Expr ListExpr) Expr)
; If is like switch, but with a boolean predicate
; pred inputs then else
(function If (Expr Expr Expr Expr) Expr)
Expand Down
2 changes: 2 additions & 0 deletions dag_in_context/src/utility/in_context.egg
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
(Full)
;; Don't make new contexts for sub-regions
(Region))

;; Add these to the egraph so we can match on them
(Full)
(Region)

Expand Down
2 changes: 1 addition & 1 deletion dag_in_context/src/utility/subst.egg
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

(let inf-fuel 1000000)

;; (Subst fuel assumption to in) substitutes to for `(Arg ty)` in `in`.
;; (Subst fuel assumption to in) substitutes `to` for `(Arg ty)` in `in`.
;; It also replaces any contexts found by updating them to `assumption`.
;; `assumption` *justifies* this substitution, as the context that the result is used in.
;; In other words, it must refine the equivalence relation of `in` with `to` as the argument.
Expand Down

0 comments on commit 3a99f6b

Please sign in to comment.