Skip to content

Commit

Permalink
Inline root_nodes()
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Apr 14, 2024
1 parent 0bd4e41 commit e940fa7
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use egglog::*;
use egraph_serialize::{ClassId, Node, NodeId};
use egraph_serialize::{ClassId, NodeId};
use indexmap::*;
use ordered_float::NotNan;
use rustc_hash::FxHashMap;
Expand All @@ -10,7 +10,12 @@ pub fn serialized_egraph(
) -> (egraph_serialize::EGraph, HashSet<String>) {
let config = SerializeConfig::default();
let mut egraph = egglog_egraph.serialize(config);
for nid in root_nodes(&egraph.nodes) {
let root_nodes = egraph
.nodes
.iter()
.filter(|(_nid, node)| node.op == "Program")
.map(|(nid, _node)| nid.clone());
for nid in root_nodes {
egraph.root_eclasses.push(egraph.nid_to_cid(&nid).clone());
}
let unextractables = egglog_egraph
Expand All @@ -27,13 +32,6 @@ pub fn serialized_egraph(
(egraph, unextractables)
}

fn root_nodes(nodes: &IndexMap<NodeId, Node>) -> impl Iterator<Item = NodeId> + '_ {
nodes
.iter()
.filter(|(_nid, node)| node.op == "Program")
.map(|(nid, _node)| nid.clone())
}

type Cost = NotNan<f64>;

pub struct CostSet {
Expand Down Expand Up @@ -194,9 +192,8 @@ pub fn extract(

// Find all reachable classes
let mut frontier = UniqueQueue::default();
for nid in root_nodes(&egraph.nodes) {
frontier.insert(n2c(&nid));
}
frontier.extend(&egraph.root_eclasses);

let mut reachable: HashSet<ClassId> = HashSet::new();
while let Some(cid) = frontier.pop() {
for nid in &egraph.classes().get(cid).unwrap().nodes {
Expand Down

0 comments on commit e940fa7

Please sign in to comment.