Skip to content

Commit

Permalink
Make reachable go over e-classes instead of e-nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Apr 14, 2024
1 parent 9bac92e commit 0bd4e41
Showing 1 changed file with 38 additions and 28 deletions.
66 changes: 38 additions & 28 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,39 +189,49 @@ pub fn extract(
termdag: &mut TermDag,
cm: impl CostModel,
) -> HashMap<ClassId, CostSet> {
// Find all reachable nodes
let mut frontier: VecDeque<NodeId> = root_nodes(&egraph.nodes).collect();
let mut reachable = HashSet::new();
while let Some(nid) = frontier.pop_front() {
let node = egraph.nodes.get(&nid).unwrap();
// Don't add unextractable nodes to the reachable set
if unextractables.contains(&node.op) {
continue;
}
// Try to add the node to the reachable set
if !reachable.insert(nid.clone()) {
// If we already reached the node, we're in a cycle
continue;
}
// Add the non-ignored children to the frontier
let ignored_children = cm.ignored_children(&node.op);
for (i, child) in node.children.iter().enumerate() {
if !ignored_children.contains(&i) {
frontier.push_back(child.clone());
let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
let parents = build_parent_index(egraph);

// Find all reachable classes
let mut frontier = UniqueQueue::default();
for nid in root_nodes(&egraph.nodes) {
frontier.insert(n2c(&nid));
}
let mut reachable: HashSet<ClassId> = HashSet::new();
while let Some(cid) = frontier.pop() {
for nid in &egraph.classes().get(cid).unwrap().nodes {
let node = egraph.nodes.get(nid).unwrap();

// Don't add unextractable nodes to the reachable set
if unextractables.contains(&node.op) {
continue;
}

// Try to add the node to the reachable set
if !reachable.insert(cid.clone()) {
// If we already reached the node, we're in a cycle
continue;
}

// Add the non-ignored children to the frontier
let ignored_children = cm.ignored_children(&node.op);
for (i, child) in node.children.iter().enumerate() {
if !ignored_children.contains(&i) {
frontier.insert(n2c(child));
}
}
}
}

let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);
let parents = build_parent_index(egraph);

// start the analysis from reachable nodes with no non-ignored children
let mut worklist = UniqueQueue::default();
for node_id in &reachable {
let node = &egraph[node_id];
for class_id in &reachable {
for node_id in &egraph.classes().get(class_id).unwrap().nodes {
let node = &egraph[node_id];

if cm.ignored_children(&node.op).len() == node.children.len() {
worklist.insert(node_id.clone());
if cm.ignored_children(&node.op).len() == node.children.len() {
worklist.insert(node_id.clone());
}
}
}

Expand All @@ -234,7 +244,7 @@ pub fn extract(
let class_id = n2c(&node_id);
let node = &egraph[&node_id];

assert!(reachable.contains(&node_id));
assert!(reachable.contains(class_id));
assert!(!unextractables.contains(&node.op));

let ignored_children = cm.ignored_children(&node.op);
Expand All @@ -260,7 +270,7 @@ pub fn extract(
worklist.extend(
parents[class_id]
.iter()
.filter(|nid| reachable.contains(nid))
.filter(|nid| reachable.contains(n2c(nid)))
.cloned(),
);
}
Expand Down

0 comments on commit 0bd4e41

Please sign in to comment.