diff --git a/src/common/scan-info/src/test/mod.rs b/src/common/scan-info/src/test/mod.rs
index 0da27600e9..a9d248c5d6 100644
--- a/src/common/scan-info/src/test/mod.rs
+++ b/src/common/scan-info/src/test/mod.rs
@@ -17,12 +17,14 @@ use crate::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeR
 struct DummyScanTask {
     pub schema: SchemaRef,
     pub pushdowns: Pushdowns,
+    pub in_memory_size: Option<usize>,
 }
 
 #[derive(Debug)]
 pub struct DummyScanOperator {
     pub schema: SchemaRef,
     pub num_scan_tasks: u32,
+    pub in_memory_size_per_task: Option<usize>,
 }
 
 #[typetag::serde]
@@ -67,7 +69,7 @@ impl ScanTaskLike for DummyScanTask {
     }
 
     fn estimate_in_memory_size_bytes(&self, _: Option<&DaftExecutionConfig>) -> Option<usize> {
-        None
+        self.in_memory_size
     }
 
     fn file_format_config(&self) -> Arc<FileFormatConfig> {
@@ -136,6 +138,7 @@ impl ScanOperator for DummyScanOperator {
         let scan_task = Arc::new(DummyScanTask {
             schema: self.schema.clone(),
             pushdowns,
+            in_memory_size: self.in_memory_size_per_task,
         });
 
         Ok((0..self.num_scan_tasks)
diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs
new file mode 100644
index 0000000000..0f5a12592d
--- /dev/null
+++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/greedy_join_order.rs
@@ -0,0 +1,153 @@
+use std::{collections::HashMap, sync::Arc};
+
+use common_error::DaftResult;
+use daft_dsl::{col, ExprRef};
+
+use super::join_graph::{JoinCondition, JoinGraph};
+use crate::{LogicalPlanBuilder, LogicalPlanRef};
+
+// This is an implementation of the Greedy Operator Ordering (GOO) algorithm [1] for join selection. The algorithm
+// greedily selects join edges by picking the edge with the smallest cost at each step. This is similar to Kruskal's
+// minimum spanning tree algorithm, with the caveat that edge costs are updated at each step, because cardinalities
+// and selectivities change as join nodes are combined.
+//
+// Compared to DP-based algorithms, GOO is not always optimal. However, GOO has a complexity of O(n^3) and is more viable
+// than DP-based algorithms when performing join ordering on many relations. DP Connected subgraph Complement Pairs (DPccp) [2]
+// is the DP-based algorithm widely used in database systems today and has O(3^n) complexity. The latest
+// literature offers a super-polynomially faster DP algorithm, but it still has O(2^n) to O(2^n * n^3) complexity [3].
+//
+// For this reason, we maintain a greedy join ordering algorithm to use when the number of relations is large, and default
+// to DP-based algorithms otherwise.
+//
+// [1]: Fegaras, L. (1998). A New Heuristic for Optimizing Large Queries. International Conference on Database and Expert Systems Applications.
+// [2]: Moerkotte, G., & Neumann, T. (2006). Analysis of two existing and one new dynamic programming algorithm for the generation of optimal bushy join trees without cross products. Very Large Data Bases Conference.
+// [3]: Stoian, M., & Kipf, A. (2024). DPconv: Super-Polynomially Faster Join Ordering. ArXiv, abs/2409.08013.
+pub(crate) struct GreedyJoinOrderer {}
+
+impl GreedyJoinOrderer {
+    /// Consumes the join graph and transforms it into a logical plan with joins reordered.
+    pub(crate) fn compute_join_order(join_graph: &mut JoinGraph) -> DaftResult<LogicalPlanRef> {
+        // While the join graph consists of more than one join node, select the edge that has the smallest cost,
+        // then join the left and right nodes connected by this edge.
+        while join_graph.adj_list.0.len() > 1 {
+            let selected_pair = GreedyJoinOrderer::find_minimum_cost_join(&join_graph.adj_list.0);
+            if let Some((left, right, join_conds)) = selected_pair {
+                // Join the left and right relations using the given join conditions.
+                let (left_on, right_on) = join_conds
+                    .iter()
+                    .map(|join_cond| {
+                        (
+                            col(join_cond.left_on.clone()),
+                            col(join_cond.right_on.clone()),
+                        )
+                    })
+                    .collect::<(Vec<ExprRef>, Vec<ExprRef>)>();
+                let left_builder = LogicalPlanBuilder::from(left.clone());
+                let join = left_builder
+                    .inner_join(right.clone(), left_on, right_on)?
+                    .build();
+                let join = Arc::new(Arc::unwrap_or_clone(join).with_materialized_stats());
+
+                // Add the new node into the adjacency list.
+                let left_neighbors = join_graph.adj_list.0.remove(&left).unwrap();
+                let right_neighbors = join_graph.adj_list.0.remove(&right).unwrap();
+                let mut new_join_edges = HashMap::new();
+
+                // Helper closure that takes the neighbors of the left or right node, then combines edges that point
+                // back to the left and/or right nodes into edges that point to the new join node.
+                let mut update_neighbors =
+                    |neighbors: HashMap<LogicalPlanRef, Vec<JoinCondition>>| {
+                        for (neighbor, _) in neighbors {
+                            if neighbor == right || neighbor == left {
+                                // Skip the nodes that we just joined.
+                                continue;
+                            }
+                            let mut join_conditions = Vec::new();
+                            // If this neighbor was connected to left or right nodes, collect the join conditions.
+                            let neighbor_edges = join_graph
+                                .adj_list
+                                .0
+                                .get_mut(&neighbor)
+                                .expect("The neighbor should still be in the join graph");
+                            if let Some(left_conds) = neighbor_edges.remove(&left) {
+                                join_conditions.extend(left_conds);
+                            }
+                            if let Some(right_conds) = neighbor_edges.remove(&right) {
+                                join_conditions.extend(right_conds);
+                            }
+                            // If this neighbor had any connections to left or right, create a new edge to the new join node.
+                            if !join_conditions.is_empty() {
+                                neighbor_edges.insert(join.clone(), join_conditions.clone());
+                                new_join_edges.insert(
+                                    neighbor.clone(),
+                                    join_conditions.iter().map(|cond| cond.flip()).collect(),
+                                );
+                            }
+                        }
+                    };
+
+                // Process all neighbors from both the left and right sides.
+                update_neighbors(left_neighbors);
+                update_neighbors(right_neighbors);
+
+                // Add the new join node and its edges to the graph.
+                join_graph.adj_list.0.insert(join, new_join_edges);
+            } else {
+                panic!(
+                    "No valid join edge selected despite join graph containing more than one relation"
+                );
+            }
+        }
+        // Apply projections and filters on top of the fully joined plan.
+        if let Some(joined_plan) = join_graph.adj_list.0.drain().map(|(plan, _)| plan).last() {
+            join_graph.apply_projections_and_filters_to_plan(joined_plan)
+        } else {
+            panic!("No valid logical plan after join reordering")
+        }
+    }
+
+    /// Helper function that finds the next join edge in the adjacency list that has the smallest cost.
+    /// Currently, cost is determined by the max size in bytes of the candidate left and right relations.
+    fn find_minimum_cost_join(
+        adj_list: &HashMap<LogicalPlanRef, HashMap<LogicalPlanRef, Vec<JoinCondition>>>,
+    ) -> Option<(LogicalPlanRef, LogicalPlanRef, Vec<JoinCondition>)> {
+        let mut min_cost = None;
+        let mut selected_pair = None;
+
+        for (candidate_left, neighbors) in adj_list {
+            for (candidate_right, join_conds) in neighbors {
+                let left_stats = candidate_left.materialized_stats();
+                let right_stats = candidate_right.materialized_stats();
+
+                // Assume a primary key-foreign key join, which would have a size bounded by the foreign key relation,
+                // which is typically the larger of the two.
+                let cur_cost = left_stats
+                    .approx_stats
+                    .upper_bound_bytes
+                    .max(right_stats.approx_stats.upper_bound_bytes);
+
+                if let Some(existing_min) = min_cost {
+                    if let Some(current) = cur_cost {
+                        if current < existing_min {
+                            min_cost = Some(current);
+                            selected_pair = Some((
+                                candidate_left.clone(),
+                                candidate_right.clone(),
+                                join_conds.clone(),
+                            ));
+                        }
+                    }
+                } else {
+                    min_cost = cur_cost;
+                    selected_pair = Some((
+                        candidate_left.clone(),
+                        candidate_right.clone(),
+                        join_conds.clone(),
+                    ));
+                }
+            }
+        }
+
+        selected_pair
+    }
+}
diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs
index f004fe0b3d..e7c09f3174 100644
--- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs
+++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/join_graph.rs
@@ -4,12 +4,13 @@ use std::{
     sync::Arc,
 };
 
+use common_error::DaftResult;
 use daft_core::join::JoinType;
 use daft_dsl::{col, optimization::replace_columns_with_expressions, ExprRef};
 
 use crate::{
     ops::{Filter, Join, Project},
-    LogicalPlan, LogicalPlanRef,
+    LogicalPlan, LogicalPlanBuilder, LogicalPlanRef,
 };
 
 #[derive(Debug)]
@@ -19,6 +20,10 @@ struct JoinNode {
     plan: LogicalPlanRef,
     final_name: String,
 }
 
+// TODO(desmond): We should also take into account user provided values for:
+// - null equals null
+// - join strategy
+
 /// JoinNodes represent a relation (i.e. a non-reorderable logical plan node), the column
 /// that's being accessed from the relation, and the final name of the column in the output.
 impl JoinNode {
@@ -46,36 +51,76 @@ impl Display for JoinNode {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "{}", self.simple_repr())
     }
 }
 
-/// JoinEdges currently represent a bidirectional edge between two relations that have
-/// an equi-join condition between each other.
-#[derive(Debug)]
-struct JoinEdge(JoinNode, JoinNode);
+#[derive(Clone, Debug)]
+pub(crate) struct JoinCondition {
+    pub left_on: String,
+    pub right_on: String,
+}
 
-impl JoinEdge {
-    /// Helper function that summarizes join edge information.
-    fn simple_repr(&self) -> String {
-        format!("{} <-> {}", self.0, self.1)
+impl JoinCondition {
+    pub(crate) fn flip(&self) -> Self {
+        JoinCondition {
+            left_on: self.right_on.clone(),
+            right_on: self.left_on.clone(),
+        }
     }
 }
 
-impl Display for JoinEdge {
+pub(crate) struct JoinAdjList(
+    pub HashMap<LogicalPlanRef, HashMap<LogicalPlanRef, Vec<JoinCondition>>>,
+);
+
+impl std::fmt::Display for JoinAdjList {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.simple_repr())
+        writeln!(f, "Join Graph Adjacency List:")?;
+        for (node, neighbors) in &self.0 {
+            writeln!(f, "Node {}:", node.name())?;
+            for (neighbor, join_conds) in neighbors {
+                writeln!(f, "  -> {} with conditions:", neighbor.name())?;
+                for (i, cond) in join_conds.iter().enumerate() {
+                    writeln!(f, "    {}: {} = {}", i, cond.left_on, cond.right_on)?;
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+impl JoinAdjList {
+    fn add_unidirectional_edge(&mut self, left: &JoinNode, right: &JoinNode) {
+        // TODO(desmond): We should also keep track of projections that we need to do.
+        let join_condition = JoinCondition {
+            left_on: left.final_name.clone(),
+            right_on: right.final_name.clone(),
+        };
+        if let Some(neighbors) = self.0.get_mut(&left.plan) {
+            if let Some(join_conditions) = neighbors.get_mut(&right.plan) {
+                join_conditions.push(join_condition);
+            } else {
+                neighbors.insert(right.plan.clone(), vec![join_condition]);
+            }
+        } else {
+            let mut neighbors = HashMap::new();
+            neighbors.insert(right.plan.clone(), vec![join_condition]);
+            self.0.insert(left.plan.clone(), neighbors);
+        }
+    }
+    fn add_bidirectional_edge(&mut self, node1: JoinNode, node2: JoinNode) {
+        self.add_unidirectional_edge(&node1, &node2);
+        self.add_unidirectional_edge(&node2, &node1);
     }
 }
 
 #[derive(Debug)]
-enum ProjectionOrFilter {
+pub(crate) enum ProjectionOrFilter {
     Projection(Vec<ExprRef>),
     Filter(ExprRef),
 }
 
 /// Representation of a logical plan as edges between relations, along with additional information needed to
 /// reconstruct a logcial plan that's equivalent to the plan that produced this graph.
-struct JoinGraph {
-    // TODO(desmond): Instead of simply storing edges, we might want to maintain adjacency lists between
-    // relations. We can make this decision later when we implement join order selection.
-    edges: Vec<JoinEdge>,
+pub(crate) struct JoinGraph {
+    pub adj_list: JoinAdjList,
     // List of projections and filters that should be applied after join reordering. This list respects
     // pre-order traversal of projections and filters in the query tree, so we should apply these operators
     // starting from the back of the list.
@@ -84,47 +129,84 @@ impl JoinGraph {
     pub(crate) fn new(
-        edges: Vec<JoinEdge>,
+        adj_list: JoinAdjList,
         final_projections_and_filters: Vec<ProjectionOrFilter>,
     ) -> Self {
         Self {
-            edges,
+            adj_list,
             final_projections_and_filters,
         }
     }
 
+    pub(crate) fn apply_projections_and_filters_to_plan(
+        &mut self,
+        plan: LogicalPlanRef,
+    ) -> DaftResult<LogicalPlanRef> {
+        let mut plan = LogicalPlanBuilder::from(plan);
+        // Apply projections and filters in post-traversal order.
+        let mut reversed_items = self
+            .final_projections_and_filters
+            .drain(..)
+            .rev()
+            .peekable();
+        while let Some(projection_or_filter) = reversed_items.next() {
+            let is_last = reversed_items.peek().is_none();
+
+            match projection_or_filter {
+                ProjectionOrFilter::Projection(projections) => {
+                    if is_last {
+                        // The final projection is the output projection, so we apply it with `select()`.
+                        plan = plan.select(projections)?;
+                    } else {
+                        // Intermediate projections might only transform a subset of columns, so we use `with_columns()` instead of `select()`.
+                        plan = plan.with_columns(projections)?;
+                    }
+                }
+                ProjectionOrFilter::Filter(predicate) => {
+                    plan = plan.filter(predicate)?;
+                }
+            }
+        }
+        Ok(plan.build())
+    }
+
     /// Test helper function to get the number of edges that the current graph contains.
     pub(crate) fn num_edges(&self) -> usize {
-        self.edges.len()
+        let mut num_edges = 0;
+        for (_, edges) in &self.adj_list.0 {
+            num_edges += edges.len();
+        }
+        // Each edge is bidirectional, so we divide by 2 to get the correct number of edges.
+        num_edges / 2
     }
 
     /// Test helper function to check that all relations in this graph are connected.
     pub(crate) fn fully_connected(&self) -> bool {
-        // Assuming that we're not testing an empty graph, there should be at least one edge in a connected graph.
-        if self.edges.is_empty() {
-            return false;
-        }
-        let mut adj_list: HashMap<*const _, Vec<*const _>> = HashMap::new();
-        for edge in &self.edges {
-            let l_ptr = Arc::as_ptr(&edge.0.plan);
-            let r_ptr = Arc::as_ptr(&edge.1.plan);
-
-            adj_list.entry(l_ptr).or_default().push(r_ptr);
-            adj_list.entry(r_ptr).or_default().push(l_ptr);
-        }
-        let start_ptr = Arc::as_ptr(&self.edges[0].0.plan);
+        let start = if let Some((node, _)) = self.adj_list.0.iter().next() {
+            node
+        } else {
+            // There are no nodes. The empty graph is fully connected.
+            return true;
+        };
+        // let start_ptr = Arc::as_ptr(&self.edges[0].0.plan);
         let mut seen = HashSet::new();
-        let mut stack = vec![start_ptr];
+        let mut stack = vec![start];
 
         while let Some(current) = stack.pop() {
             if seen.insert(current) {
                 // If this is a new node, add all its neighbors to the stack.
-                if let Some(neighbors) = adj_list.get(&current) {
-                    stack.extend(neighbors.iter().filter(|&&n| !seen.contains(&n)));
+                if let Some(neighbors) = self.adj_list.0.get(current) {
+                    stack.extend(neighbors.iter().filter_map(|(neighbor, _)| {
+                        if !seen.contains(neighbor) {
+                            Some(neighbor)
+                        } else {
+                            None
+                        }
+                    }));
                 }
             }
         }
-        seen.len() == adj_list.len()
+        seen.len() == self.adj_list.0.len()
     }
 
     /// Test helper function that checks if the graph contains the given projection/filter expressions
@@ -152,13 +234,27 @@ impl JoinGraph {
     /// Helper function that loosely checks if a given edge (represented by a simple string)
     /// exists in the current graph.
-    pub(crate) fn contains_edge(&self, edge_string: &str) -> bool {
-        for edge in &self.edges {
-            if edge.simple_repr() == edge_string {
-                return true;
+    pub(crate) fn contains_edges(&self, to_check: Vec<&str>) -> bool {
+        let mut edge_strings = HashSet::new();
+        for (left, neighbors) in &self.adj_list.0 {
+            for (right, join_conds) in neighbors {
+                for join_cond in join_conds {
+                    edge_strings.insert(format!(
+                        "{}({}) <-> {}({})",
+                        left.name(),
+                        join_cond.left_on,
+                        right.name(),
+                        join_cond.right_on
+                    ));
+                }
             }
         }
-        false
+        for cur_check in to_check {
+            if !edge_strings.contains(cur_check) {
+                return false;
+            }
+        }
+        true
     }
 }
 
@@ -167,14 +263,14 @@ struct JoinGraphBuilder {
     plan: LogicalPlanRef,
     join_conds_to_resolve: Vec<(String, LogicalPlanRef, bool)>,
    final_name_map: HashMap<String, ExprRef>,
-    edges: Vec<JoinEdge>,
+    adj_list: JoinAdjList,
     final_projections_and_filters: Vec<ProjectionOrFilter>,
 }
 
 impl JoinGraphBuilder {
     pub(crate) fn build(mut self) -> JoinGraph {
         self.process_node(&self.plan.clone());
-        JoinGraph::new(self.edges, self.final_projections_and_filters)
+        JoinGraph::new(self.adj_list, self.final_projections_and_filters)
     }
 
     pub(crate) fn from_logical_plan(plan: LogicalPlanRef) -> Self {
@@ -192,7 +288,7 @@ impl JoinGraphBuilder {
             plan,
             join_conds_to_resolve: vec![],
             final_name_map: HashMap::new(),
-            edges: vec![],
+            adj_list: JoinAdjList(HashMap::new()),
             final_projections_and_filters: vec![ProjectionOrFilter::Projection(output_projection)],
         }
     }
@@ -328,7 +424,7 @@ impl JoinGraphBuilder {
                     rnode.clone(),
                     self.final_name_map.get(&rname).unwrap().name().to_string(),
                 );
-                self.edges.push(JoinEdge(node1, node2));
+                self.adj_list.add_bidirectional_edge(node1, node2);
             } else {
                 panic!("Join conditions were unresolved");
            }
@@ -337,13 +433,43 @@
            // TODO(desmond): There are potentially more reorderable nodes. For example, we can move repartitions around.
            _ => {
                // This is an unreorderable node. All unresolved columns coming out of this node should be marked as resolved.
+                // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this
+                // node as the root node. We can do this once we add the optimizer rule.
+                let mut projections = vec![];
+                let mut needs_projection = false;
+                let mut seen_names = HashSet::new();
                 for (name, _, done) in self.join_conds_to_resolve.iter_mut() {
-                    if schema.has_field(name) {
+                    if schema.has_field(name) && !*done && !seen_names.contains(name) {
+                        if let Some(final_name) = self.final_name_map.get(name) {
+                            let final_name = final_name.name().to_string();
+                            if final_name != *name {
+                                needs_projection = true;
+                                projections.push(col(name.clone()).alias(final_name));
+                            } else {
+                                projections.push(col(name.clone()));
+                            }
+                        } else {
+                            projections.push(col(name.clone()));
+                        }
+                        seen_names.insert(name);
+                    }
+                }
+                // Apply projections and return the new plan as the relation for the appropriate join conditions.
+                let projected_plan = if needs_projection {
+                    let projected_plan = LogicalPlanBuilder::from(plan.clone())
+                        .select(projections)
+                        .expect("Computed projections could not be applied to relation")
+                        .build();
+                    Arc::new(Arc::unwrap_or_clone(projected_plan).with_materialized_stats())
+                } else {
+                    plan.clone()
+                };
+                for (name, node, done) in self.join_conds_to_resolve.iter_mut() {
+                    if schema.has_field(name) && !*done {
                         *done = true;
+                        *node = projected_plan.clone();
                     }
                 }
-                // TODO(desmond): At this point we should perform a fresh join reorder optimization starting from this
-                // node as the root node. We can do this once we add the optimizer rule.
             }
         }
     }
@@ -354,12 +480,21 @@ mod tests {
     use std::sync::Arc;
 
     use common_scan_info::Pushdowns;
+    use common_treenode::TransformedResult;
     use daft_core::prelude::CountMode;
     use daft_dsl::{col, AggExpr, Expr, LiteralValue};
     use daft_schema::{dtype::DataType, field::Field};
 
     use super::JoinGraphBuilder;
-    use crate::test::{dummy_scan_node_with_pushdowns, dummy_scan_operator};
+    use crate::{
+        optimization::rules::{
+            reorder_joins::greedy_join_order::GreedyJoinOrderer, EnrichWithStats, MaterializeScans,
+            OptimizerRule,
+        },
+        test::{
+            dummy_scan_node_with_pushdowns, dummy_scan_operator, dummy_scan_operator_with_size,
+        },
+    };
 
     #[test]
     fn test_create_join_graph_basic_1() {
@@ -372,21 +507,21 @@
         //                  |
         //             Scan(c_prime)
         let scan_a = dummy_scan_node_with_pushdowns(
-            dummy_scan_operator(vec![Field::new("a", DataType::Int64)]),
+            dummy_scan_operator_with_size(vec![Field::new("a", DataType::Int64)], Some(100)),
             Pushdowns::default(),
         );
         let scan_b = dummy_scan_node_with_pushdowns(
-            dummy_scan_operator(vec![Field::new("b", DataType::Int64)]),
+            dummy_scan_operator_with_size(vec![Field::new("b", DataType::Int64)], Some(10_000)),
             Pushdowns::default(),
         );
         let scan_c = dummy_scan_node_with_pushdowns(
-            dummy_scan_operator(vec![Field::new("c_prime", DataType::Int64)]),
+            dummy_scan_operator_with_size(vec![Field::new("c_prime", DataType::Int64)], Some(100)),
            Pushdowns::default(),
        )
        .select(vec![col("c_prime").alias("c")])
        .unwrap();
        let scan_d = dummy_scan_node_with_pushdowns(
-            dummy_scan_operator(vec![Field::new("d", DataType::Int64)]),
+            dummy_scan_operator_with_size(vec![Field::new("d", DataType::Int64)], Some(100)),
            Pushdowns::default(),
        );
        let join_plan_l = scan_a
@@ -410,17 +545,29 @@
            vec![Arc::new(Expr::Column(Arc::from("d")))],
        )
        .unwrap();
-        let plan = join_plan.build();
-        let join_graph = JoinGraphBuilder::from_logical_plan(plan).build();
+        let original_plan = join_plan.build();
+        let scan_materializer = MaterializeScans::new();
+        let original_plan = scan_materializer
+            .try_optimize(original_plan)
+            .data()
+            .unwrap();
+        let stats_enricher = EnrichWithStats::new();
+        let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap();
+        let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build();
         assert!(join_graph.fully_connected());
         // There should be edges between:
         // - a <-> b
-        // - c_prime <-> d
+        // - c <-> d
         // - a <-> d
         assert!(join_graph.num_edges() == 3);
-        assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)"));
-        assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)"));
-        assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)"));
+        assert!(join_graph.contains_edges(vec![
+            "Source(a) <-> Source(b)",
+            "Project(c) <-> Source(d)",
+            "Source(a) <-> Source(d)"
+        ]));
+        // Test greedy join reordering.
+        let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap();
+        assert!(reordered_plan.schema() == original_plan.schema());
     }
 
     #[test]
@@ -472,17 +619,29 @@
             vec![Arc::new(Expr::Column(Arc::from("d")))],
         )
         .unwrap();
-        let plan = join_plan.build();
-        let join_graph = JoinGraphBuilder::from_logical_plan(plan).build();
+        let original_plan = join_plan.build();
+        let scan_materializer = MaterializeScans::new();
+        let original_plan = scan_materializer
+            .try_optimize(original_plan)
+            .data()
+            .unwrap();
+        let stats_enricher = EnrichWithStats::new();
+        let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap();
+        let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build();
         assert!(join_graph.fully_connected());
         // There should be edges between:
         // - a <-> b
-        // - c_prime <-> d
+        // - c <-> d
         // - b <-> d
         assert!(join_graph.num_edges() == 3);
-        assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)"));
-        assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)"));
-        assert!(join_graph.contains_edge("b#Source(b) <-> d#Source(d)"));
+        assert!(join_graph.contains_edges(vec![
+            "Source(a) <-> Source(b)",
+            "Project(c) <-> Source(d)",
+            "Source(b) <-> Source(d)",
+        ]));
+        // Test greedy join reordering.
+        let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap();
+        assert!(reordered_plan.schema() == original_plan.schema());
     }
 
     #[test]
@@ -528,15 +687,27 @@
             vec![Arc::new(Expr::Column(Arc::from("c")))],
         )
         .unwrap();
-        let plan = join_plan_2.build();
-        let join_graph = JoinGraphBuilder::from_logical_plan(plan).build();
+        let original_plan = join_plan_2.build();
+        let scan_materializer = MaterializeScans::new();
+        let original_plan = scan_materializer
+            .try_optimize(original_plan)
+            .data()
+            .unwrap();
+        let stats_enricher = EnrichWithStats::new();
+        let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap();
+        let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build();
         assert!(join_graph.fully_connected());
         // There should be edges between:
-        // - a <-> b
-        // - a <-> c
+        // - a_beta <-> b
+        // - a_beta <-> c
         assert!(join_graph.num_edges() == 2);
-        assert!(join_graph.contains_edge("a_beta#Source(a) <-> b#Source(b)"));
-        assert!(join_graph.contains_edge("a_beta#Source(a) <-> c#Source(c)"));
+        assert!(join_graph.contains_edges(vec![
+            "Project(a_beta) <-> Source(b)",
+            "Project(a_beta) <-> Source(c)",
+        ]));
+        // Test greedy join reordering.
+        let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap();
+        assert!(reordered_plan.schema() == original_plan.schema());
     }
 
     #[test]
@@ -589,21 +760,33 @@
             vec![Arc::new(Expr::Column(Arc::from("d")))],
         )
         .unwrap();
-        let plan = join_plan.build();
-        let join_graph = JoinGraphBuilder::from_logical_plan(plan).build();
+        let original_plan = join_plan.build();
+        let scan_materializer = MaterializeScans::new();
+        let original_plan = scan_materializer
+            .try_optimize(original_plan)
+            .data()
+            .unwrap();
+        let stats_enricher = EnrichWithStats::new();
+        let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap();
+        let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build();
         assert!(join_graph.fully_connected());
         // There should be edges between:
         // - a <-> b
-        // - c_prime <-> d
+        // - c <-> d
         // - a <-> d
         assert!(join_graph.num_edges() == 3);
-        assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)"));
-        assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)"));
-        assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)"));
+        assert!(join_graph.contains_edges(vec![
+            "Source(a) <-> Source(b)",
+            "Project(c) <-> Source(d)",
+            "Source(a) <-> Source(d)"
+        ]));
         // Check for non-join projections at the end.
         // `c_prime` gets renamed to `c` in the final projection
         let double_proj = col("c").add(col("c")).alias("double");
         assert!(join_graph.contains_projections_and_filters(vec![&double_proj]));
+        // Test greedy join reordering.
+        let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap();
+        assert!(reordered_plan.schema() == original_plan.schema());
     }
 
     #[test]
@@ -674,17 +857,26 @@
             vec![Arc::new(Expr::Column(Arc::from("d")))],
         )
         .unwrap();
-        let plan = join_plan.build();
-        let join_graph = JoinGraphBuilder::from_logical_plan(plan).build();
+        let original_plan = join_plan.build();
+        let scan_materializer = MaterializeScans::new();
+        let original_plan = scan_materializer
+            .try_optimize(original_plan)
+            .data()
+            .unwrap();
+        let stats_enricher = EnrichWithStats::new();
+        let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap();
+        let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build();
         assert!(join_graph.fully_connected());
         // There should be edges between:
         // - a <-> b
-        // - c_prime <-> d
+        // - c <-> d
         // - a <-> d
         assert!(join_graph.num_edges() == 3);
-        assert!(join_graph.contains_edge("a#Source(a) <-> b#Source(b)"));
-        assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)"));
-        assert!(join_graph.contains_edge("a#Source(a) <-> d#Source(d)"));
+        assert!(join_graph.contains_edges(vec![
+            "Source(a) <-> Source(b)",
+            "Project(c) <-> Source(d)",
+            "Source(a) <-> Source(d)",
+        ]));
         // Check for non-join projections and filters at the end.
         // `c_prime` gets renamed to `c` in the final projection
         let double_proj = col("c").add(col("c")).alias("double");
         let filter_c_prime = col("c_prime").gt(Arc::new(Expr::Literal(LiteralValue::Int64(0))));
         assert!(join_graph.contains_projections_and_filters(vec![
             &filter_c_prime,
             &double_proj,
             &filter_c_prime,
         ]));
+        // Test greedy join reordering.
+        let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap();
+        assert!(reordered_plan.schema() == original_plan.schema());
     }
 
     #[test]
@@ -760,18 +955,30 @@
             vec![Arc::new(Expr::Column(Arc::from("d")))],
         )
         .unwrap();
-        let plan = join_plan.build();
-        let join_graph = JoinGraphBuilder::from_logical_plan(plan).build();
+        let original_plan = join_plan.build();
+        let scan_materializer = MaterializeScans::new();
+        let original_plan = scan_materializer
+            .try_optimize(original_plan)
+            .data()
+            .unwrap();
+        let stats_enricher = EnrichWithStats::new();
+        let original_plan = stats_enricher.try_optimize(original_plan).data().unwrap();
+        let mut join_graph = JoinGraphBuilder::from_logical_plan(original_plan.clone()).build();
         assert!(join_graph.fully_connected());
         // There should be edges between:
         // - a <-> b
-        // - c_prime <-> d
+        // - c <-> d
         // - a <-> d
         assert!(join_graph.num_edges() == 3);
-        assert!(join_graph.contains_edge("a#Aggregate(a) <-> b#Source(b)"));
-        assert!(join_graph.contains_edge("c#Source(c_prime) <-> d#Source(d)"));
-        assert!(join_graph.contains_edge("a#Aggregate(a) <-> d#Source(d)"));
+        assert!(join_graph.contains_edges(vec![
+            "Aggregate(a) <-> Source(b)",
+            "Project(c) <-> Source(d)",
+            "Aggregate(a) <-> Source(d)"
+        ]));
         // Projections below the aggregation should not be part of the final projections.
         assert!(!join_graph.contains_projections_and_filters(vec![&a_proj]));
+        // Test greedy join reordering.
+        let reordered_plan = GreedyJoinOrderer::compute_join_order(&mut join_graph).unwrap();
+        assert!(reordered_plan.schema() == original_plan.schema());
     }
 }
diff --git a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs
index 09ece20040..58987555ab 100644
--- a/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs
+++ b/src/daft-logical-plan/src/optimization/rules/reorder_joins/mod.rs
@@ -1,2 +1,4 @@
 #[cfg(test)]
+mod greedy_join_order;
+#[cfg(test)]
 mod join_graph;
diff --git a/src/daft-logical-plan/src/test/mod.rs b/src/daft-logical-plan/src/test/mod.rs
index 75f8ad386b..7ac8da51c1 100644
--- a/src/daft-logical-plan/src/test/mod.rs
+++ b/src/daft-logical-plan/src/test/mod.rs
@@ -7,10 +7,20 @@ use crate::builder::LogicalPlanBuilder;
 
 /// Create a dummy scan node containing the provided fields in its schema and the provided limit.
 pub fn dummy_scan_operator(fields: Vec<Field>) -> ScanOperatorRef {
+    dummy_scan_operator_with_size(fields, None)
+}
+
+/// Create a dummy scan node containing the provided fields in its schema,
+/// and with the provided in-memory size estimate per scan task.
+pub fn dummy_scan_operator_with_size(
+    fields: Vec<Field>,
+    in_memory_size_per_task: Option<usize>,
+) -> ScanOperatorRef {
     let schema = Arc::new(Schema::new(fields).unwrap());
     ScanOperatorRef(Arc::new(DummyScanOperator {
         schema,
-        num_scan_tasks: 0,
+        num_scan_tasks: 1,
+        in_memory_size_per_task,
     }))
 }
 
diff --git a/src/daft-physical-plan/src/test/mod.rs b/src/daft-physical-plan/src/test/mod.rs
index 3e8de6a74c..29f9d81997 100644
--- a/src/daft-physical-plan/src/test/mod.rs
+++ b/src/daft-physical-plan/src/test/mod.rs
@@ -10,6 +10,7 @@ pub fn dummy_scan_operator(fields: Vec<Field>) -> ScanOperatorRef {
     ScanOperatorRef(Arc::new(DummyScanOperator {
         schema,
         num_scan_tasks: 1,
+        in_memory_size_per_task: None,
     }))
 }
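
The greedy selection loop in greedy_join_order.rs can be illustrated outside of Daft with a small, self-contained sketch. Everything below is illustrative only: the `Rel` alias, the `greedy_order` helper, and the hard-coded sizes are assumptions, not part of the crate. The sketch uses the same cost model as `find_minimum_cost_join` (the cost of joining two relations is the larger of their estimated in-memory sizes, and a primary key-foreign key join is assumed so the join result keeps that size). The sizes mirror the dummy scan sizes used in `test_create_join_graph_basic_1` (a = 100, b = 10_000, d = 100), so the small a-d join is picked before the join with the large b.

use std::collections::HashMap;

// Relations are plain strings here; in Daft they are LogicalPlanRef nodes.
type Rel = String;

/// Repeatedly joins the pair of connected relations with the smallest cost
/// (the max of the two estimated sizes) until one relation remains.
/// Returns the chosen join order as (left, right) pairs.
fn greedy_order(mut sizes: HashMap<Rel, u64>, mut edges: Vec<(Rel, Rel)>) -> Vec<(Rel, Rel)> {
    let mut order = Vec::new();
    while sizes.len() > 1 {
        // Pick the cheapest edge, mirroring find_minimum_cost_join.
        let (idx, (left, right)) = edges
            .iter()
            .enumerate()
            .min_by_key(|(_, (l, r))| sizes[l].max(sizes[r]))
            .map(|(i, e)| (i, e.clone()))
            .expect("graph must stay connected");
        edges.remove(idx);

        // Merge `right` into `left`: a PK-FK join is assumed, so the result size is
        // approximated by the larger input (the foreign key side).
        let new_size = sizes[&left].max(sizes[&right]);
        sizes.remove(&right);
        sizes.insert(left.clone(), new_size);

        // Re-point edges that referenced `right` to the merged node and drop self-edges,
        // the same rewiring that update_neighbors performs on the adjacency list.
        for edge in &mut edges {
            if edge.0 == right {
                edge.0 = left.clone();
            }
            if edge.1 == right {
                edge.1 = left.clone();
            }
        }
        edges.retain(|(a, b)| a != b);

        order.push((left, right));
    }
    order
}

fn main() {
    let sizes = HashMap::from([
        ("a".to_string(), 100),
        ("b".to_string(), 10_000),
        ("d".to_string(), 100),
    ]);
    let edges = vec![
        ("a".to_string(), "b".to_string()),
        ("a".to_string(), "d".to_string()),
    ];
    // Prints [("a", "d"), ("a", "b")]: the cheap a-d join happens first.
    println!("{:?}", greedy_order(sizes, edges));
}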
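
For reference, here is a minimal sketch of the bidirectional bookkeeping that `JoinAdjList::add_bidirectional_edge` performs, with plain strings standing in for `LogicalPlanRef` nodes. The `AdjList` alias and `add_bidirectional_edge` free function below are illustrative names, not the crate's API; the point is that every equi-join edge is stored once per direction, with the flipped `JoinCondition` on the reverse edge (equivalent to calling the unidirectional helper twice with the arguments swapped), which is why `num_edges()` divides the raw count by two.

use std::collections::HashMap;

#[derive(Clone, Debug)]
struct JoinCondition {
    left_on: String,
    right_on: String,
}

impl JoinCondition {
    // Swap the sides so the reverse edge sees its own column first.
    fn flip(&self) -> Self {
        JoinCondition {
            left_on: self.right_on.clone(),
            right_on: self.left_on.clone(),
        }
    }
}

// node -> (neighbor -> join conditions), the shape of JoinAdjList's inner map.
type AdjList = HashMap<String, HashMap<String, Vec<JoinCondition>>>;

fn add_bidirectional_edge(adj: &mut AdjList, left: &str, right: &str, cond: JoinCondition) {
    // Forward direction.
    adj.entry(left.to_string())
        .or_default()
        .entry(right.to_string())
        .or_default()
        .push(cond.clone());
    // Reverse direction stores the flipped condition.
    adj.entry(right.to_string())
        .or_default()
        .entry(left.to_string())
        .or_default()
        .push(cond.flip());
}

fn main() {
    let mut adj = AdjList::new();
    add_bidirectional_edge(
        &mut adj,
        "Source(a)",
        "Source(b)",
        JoinCondition {
            left_on: "a".to_string(),
            right_on: "b".to_string(),
        },
    );
    // One logical join edge shows up twice, once per direction.
    assert_eq!(adj.len(), 2);
    println!("{adj:#?}");
}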