Skip to content

Commit

Permalink
Move reachability to extract(), add ignored_children
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Apr 13, 2024
1 parent 2c137f0 commit 2756d25
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 72 deletions.
129 changes: 59 additions & 70 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use egglog::*;
use egraph_serialize::{ClassId, NodeId};
use egraph_serialize::{ClassId, Node, NodeId};
use indexmap::*;
use ordered_float::NotNan;
use rustc_hash::FxHashMap;
Expand All @@ -10,14 +10,9 @@ pub fn serialized_egraph(
) -> (egraph_serialize::EGraph, HashSet<String>) {
let config = SerializeConfig::default();
let mut egraph = egglog_egraph.serialize(config);
let root_nodes: Vec<NodeId> = egraph
.nodes
.iter()
.filter(|(_nid, node)| node.op == "Program")
.map(|(nid, _node)| nid.clone())
.collect();
for nid in &root_nodes {
egraph.root_eclasses.push(egraph.nid_to_cid(nid).clone());
let root_nodes = root_nodes(&egraph.nodes);
for nid in root_nodes {
egraph.root_eclasses.push(egraph.nid_to_cid(&nid).clone());
}
let unextractables: HashSet<String> = egglog_egraph
.functions
Expand All @@ -31,27 +26,16 @@ pub fn serialized_egraph(
})
.collect();

// Find all reachable nodes
let mut frontier = std::collections::VecDeque::from(root_nodes);
let mut reachable = std::collections::HashSet::new();
while let Some(nid) = frontier.pop_front() {
// Try to add the node to the reachable set
if !reachable.insert(nid.clone()) {
// If we already reached the node, it is shared by several parents or lies on a cycle
continue;
}
// Add the children to the frontier
for child in &egraph.nodes.get(&nid).unwrap().children {
frontier.push_back(child.clone());
}
}

// Remove unreachable nodes from the egraph
egraph.nodes.retain(|nid, _node| reachable.contains(nid));

(egraph, unextractables)
}

/// Lazily yields a clone of the id of every node in `nodes` whose operator
/// is exactly `"Program"`, in the map's iteration order.
///
/// The returned iterator borrows `nodes`, so the map cannot be mutated
/// until the iterator is dropped or fully consumed.
fn root_nodes(nodes: &IndexMap<NodeId, Node>) -> impl Iterator<Item = NodeId> + '_ {
    nodes.iter().filter_map(|(id, candidate)| {
        // Only clone the ids of nodes that actually are program roots.
        (candidate.op == "Program").then(|| id.clone())
    })
}

type Cost = NotNan<f64>;

pub struct CostSet {
Expand Down Expand Up @@ -127,47 +111,39 @@ fn get_node_cost(
cm: &impl CostModel,
termdag: &mut TermDag,
) -> CostSet {
// cost is 0 unless otherwise specified
let op_cost = cm.get_op_cost(op);
let term = get_term(op, child_cost_sets, termdag);
if cm.is_ignored(op) {
return CostSet {
total: op_cost,
costs: [(cid.clone(), op_cost)].into(),
term,
};
}

let mut resulting_set = HashMap::<ClassId, Cost>::new();
let mut resulting_total = NotNan::new(0.).unwrap();
let mut total = cm.get_op_cost(op);
let mut costs: HashMap<ClassId, Cost> = [(cid.clone(), total)].into();

let ignored_children = cm.ignored_children(op);
let unshared_children = cm.unshared_children(op);
if !cm.is_ignored(op) {
for (i, child_set) in child_cost_sets.iter().enumerate() {
if unshared_children.contains(&i) {
// don't add to the cost set, but do add to the total
resulting_total += child_set.total;
} else {
for (child_cid, child_cost) in &child_set.costs {
// it was already present in the set
if let Some(existing) = resulting_set.insert(child_cid.clone(), *child_cost) {
assert_eq!(
existing, *child_cost,
"Two different costs found for the same child enode!"
);
} else {
resulting_total += child_cost;
}

for (i, child_set) in child_cost_sets.iter().enumerate() {
let is_ignored = ignored_children.contains(&i);
let is_unshared = unshared_children.contains(&i);

if is_ignored {
// don't add to the cost set or the total
assert!(!is_unshared);
} else if is_unshared {
// don't add to the cost set, but do add to the total
total += child_set.total;
} else {
for (child_cid, child_cost) in &child_set.costs {
// it was already present in the set
if let Some(existing) = costs.insert(child_cid.clone(), *child_cost) {
assert_eq!(
existing, *child_cost,
"Two different costs found for the same child enode!"
);
} else {
total += child_cost;
}
}
}
}

CostSet {
total: resulting_total,
costs: resulting_set,
term,
}
let term = get_term(op, child_cost_sets, termdag);
CostSet { total, costs, term }
}

fn calculate_cost_set(
Expand Down Expand Up @@ -220,14 +196,31 @@ fn calculate_cost_set(
}

pub fn extract(
egraph: &egraph_serialize::EGraph,
egraph: &mut egraph_serialize::EGraph,
// TODO: once our egglog program uses `subsume` actions,
// unextractables will be more complex, as right now
// it only checks unextractable at the function level.
unextractables: HashSet<String>,
termdag: &mut TermDag,
cm: impl CostModel,
) -> HashMap<ClassId, CostSet> {
// Find all reachable nodes
let mut frontier: std::collections::VecDeque<NodeId> = root_nodes(&egraph.nodes).collect();
let mut reachable = std::collections::HashSet::new();
while let Some(nid) = frontier.pop_front() {
// Try to add the node to the reachable set
if !reachable.insert(nid.clone()) {
// If we already reached the node, it is shared by several parents or lies on a cycle
continue;
}
// Add the children to the frontier
for child in &egraph.nodes.get(&nid).unwrap().children {
frontier.push_back(child.clone());
}
}
// Remove unreachable nodes from the egraph
egraph.nodes.retain(|nid, _node| reachable.contains(nid));

let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
let parents = build_parent_index(egraph);
let mut worklist = initialize_worklist(egraph);
Expand Down Expand Up @@ -285,11 +278,7 @@ impl CostModel for DefaultCostModel {
let cost = match op {
// Constants
"Const" => 1.,
_ if op.parse::<i64>().is_ok() => 0.,
"Int" | "IntT" => 0.,
"true" | "false" | "Bool" | "BoolT" => 0.,
"PointerT" | "StateT" => 0.,
"Nil" | "Cons" | "TNil" | "TCons" => 0.,
// "Nil" | "Cons" | "TNil" | "TCons" => 0.,
// "()" => 0.,
// "Arg" => 0.,
// "Base" | "TupleT" => 0.,
Expand All @@ -316,9 +305,9 @@ impl CostModel for DefaultCostModel {

fn ignored_children(&self, op: &str) -> &[usize] {
match op {
"Arg" => &[0], // arg type
"Const" => &[0, 1], // constant, arg type
"InContext" => &[0], // assumption
"Arg" => &[0], // arg type
"Const" => &[0, 1], // constant, arg type
"InContext" => &[0], // assumption
"Function" => &[0, 1, 2], // name, input type, output type
_ => &[],
}
Expand Down
9 changes: 7 additions & 2 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ pub fn optimize(program: &TreeProgram) -> std::result::Result<TreeProgram, egglo
let mut egraph = egglog::EGraph::default();
egraph.parse_and_run_program(&program)?;

let (serialized, unextractables) = serialized_egraph(egraph);
let (mut serialized, unextractables) = serialized_egraph(egraph);
let mut termdag = egglog::TermDag::default();
let results = extract(&serialized, unextractables, &mut termdag, DefaultCostModel);
let results = extract(
&mut serialized,
unextractables,
&mut termdag,
DefaultCostModel,
);
assert_eq!(results.len(), 1);
let (_cid, costset) = results.into_iter().next().unwrap();
let mut from_egglog = FromEgglog {
Expand Down

0 comments on commit 2756d25

Please sign in to comment.