Skip to content

Commit

Permalink
Add unextractable to reachability analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Apr 14, 2024
1 parent e704f43 commit 9bac92e
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,13 +193,17 @@ pub fn extract(
let mut frontier: VecDeque<NodeId> = root_nodes(&egraph.nodes).collect();
let mut reachable = HashSet::new();
while let Some(nid) = frontier.pop_front() {
let node = egraph.nodes.get(&nid).unwrap();
// Don't add unextractable nodes to the reachable set
if unextractables.contains(&node.op) {
continue;
}
// Try to add the node to the reachable set
if !reachable.insert(nid.clone()) {
// If we already reached the node, we're in a cycle
continue;
}
// Add the non-ignored children to the frontier
let node = egraph.nodes.get(&nid).unwrap();
let ignored_children = cm.ignored_children(&node.op);
for (i, child) in node.children.iter().enumerate() {
if !ignored_children.contains(&i) {
Expand Down Expand Up @@ -230,9 +234,8 @@ pub fn extract(
let class_id = n2c(&node_id);
let node = &egraph[&node_id];

if unextractables.contains(&node.op) || !reachable.contains(&node_id) {
continue;
}
assert!(reachable.contains(&node_id));
assert!(!unextractables.contains(&node.op));

let ignored_children = cm.ignored_children(&node.op);
let all_non_ignored_children_have_costs = node
Expand All @@ -254,7 +257,12 @@ pub fn extract(
costs.insert(class_id.clone(), cost_set);
}

worklist.extend(parents[class_id].iter().cloned());
worklist.extend(
parents[class_id]
.iter()
.filter(|nid| reachable.contains(nid))
.cloned(),
);
}
}

Expand Down

0 comments on commit 9bac92e

Please sign in to comment.