From 594b5e265cc180df2fa34bbf04684b74aa94bb72 Mon Sep 17 00:00:00 2001 From: advancedxy <807537+advancedxy@users.noreply.github.com> Date: Wed, 6 Nov 2024 01:28:15 +0800 Subject: [PATCH] [FEAT] Support null safe equal in joins (#3161) This commit consists of the following parts: 1. Make the logical plan join's on condition null safe aware 2. Support translating null safe equal joins into physical plans 3. Make sure the optimization rules: eliminate cross join and push down filter don't break 4. Modifications to physical join ops to support null safe equal: hash and broadcast joins are supported. SMJ is not supported yet. 5. Glue code in Python side to make the whole pipeline work 6. Some UTs in the python side Fixes: #3069 TODOs(in follow-up PRs): - [ ] rewrite null safe equal for SMJ so that SMJ could support null safe equals as well - [ ] SQL supports null safe equal, a.k.a SpaceShip(a<=>b) - [ ] Python's DataFrame API supports null safe equal join --- daft/daft/__init__.pyi | 1 + daft/execution/execution_step.py | 2 + daft/execution/physical_plan.py | 15 +- daft/execution/rust_physical_plan_shim.py | 4 + daft/table/micropartition.py | 9 +- .../src/array/ops/arrow2/comparison.rs | 10 +- src/daft-core/src/utils/dyn_compare.rs | 10 +- src/daft-local-execution/src/pipeline.rs | 9 +- .../src/sinks/hash_join_build.rs | 11 +- src/daft-micropartition/src/ops/join.rs | 8 +- src/daft-micropartition/src/python.rs | 2 + src/daft-physical-plan/src/local_plan.rs | 3 + src/daft-physical-plan/src/translate.rs | 1 + src/daft-plan/src/builder.rs | 25 ++ src/daft-plan/src/display.rs | 3 +- src/daft-plan/src/logical_ops/join.rs | 21 ++ .../rules/eliminate_cross_join.rs | 2 + .../rules/push_down_filter.rs | 54 ++++- src/daft-plan/src/logical_plan.rs | 3 +- .../src/physical_ops/broadcast_join.rs | 10 + src/daft-plan/src/physical_ops/hash_join.rs | 9 + .../rules/reorder_partition_keys.rs | 2 + src/daft-plan/src/physical_plan.rs | 5 +- .../src/physical_planner/translate.rs | 13 + src/daft-scheduler/src/scheduler.rs | 4 + src/daft-sql/src/planner.rs | 1 + src/daft-table/src/ops/groups.rs | 4 +- src/daft-table/src/ops/hash.rs | 8 +- src/daft-table/src/ops/joins/hash_join.rs | 20 +- src/daft-table/src/ops/joins/mod.rs | 21 +- src/daft-table/src/probeable/mod.rs | 11 +- src/daft-table/src/probeable/probe_set.rs | 20 +- src/daft-table/src/probeable/probe_table.rs | 17 +- src/daft-table/src/python.rs | 2 + tests/table/test_joins.py | 222 ++++++++++++------ 35 files changed, 440 insertions(+), 122 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 106474b3bc..495d0e0a97 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1550,6 +1550,7 @@ class PyMicroPartition: right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], + null_equals_nulls: list[bool] | None, how: JoinType, ) -> PyMicroPartition: ... def pivot( diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index b0763b2c25..dcf2f35b28 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -788,6 +788,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) class HashJoin(SingleOutputInstruction): left_on: ExpressionsProjection right_on: ExpressionsProjection + null_equals_nulls: list[bool] | None how: JoinType is_swapped: bool @@ -810,6 +811,7 @@ def _hash_join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: right, left_on=self.left_on, right_on=self.right_on, + null_equals_nulls=self.null_equals_nulls, how=self.how, ) return [result] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index e3a8501ce5..80fb766d3e 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -346,6 +346,7 @@ def hash_join( right_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: None | list[bool], how: JoinType, ) -> InProgressPhysicalPlan[PartitionT]: """Hash-based pairwise join the partitions from `left_child_plan` and `right_child_plan` together.""" @@ -387,6 +388,7 @@ def hash_join( instruction=execution_step.HashJoin( left_on=left_on, right_on=right_on, + null_equals_nulls=null_equals_nulls, how=how, is_swapped=False, ) @@ -432,6 +434,7 @@ def _create_broadcast_join_step( receiver_part: SingleOutputPartitionTask[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: None | list[bool], how: JoinType, is_swapped: bool, ) -> PartitionTaskBuilder[PartitionT]: @@ -477,6 +480,7 @@ def _create_broadcast_join_step( instruction=execution_step.BroadcastJoin( left_on=left_on, right_on=right_on, + null_equals_nulls=null_equals_nulls, how=how, is_swapped=is_swapped, ) @@ -488,6 +492,7 @@ def broadcast_join( receiver_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: None | list[bool], how: JoinType, is_swapped: bool, ) -> InProgressPhysicalPlan[PartitionT]: @@ -530,7 +535,15 @@ def broadcast_join( # Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop. while receiver_requests and receiver_requests[0].done(): receiver_part = receiver_requests.popleft() - yield _create_broadcast_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped) + yield _create_broadcast_join_step( + broadcaster_parts, + receiver_part, + left_on, + right_on, + null_equals_nulls, + how, + is_swapped, + ) # Execute single child step to pull in more input partitions. try: diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 225fd13185..151f68061c 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -243,6 +243,7 @@ def hash_join( right: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], right_on: list[PyExpr], + null_equals_nulls: list[bool] | None, join_type: JoinType, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) @@ -253,6 +254,7 @@ def hash_join( left_on=left_on_expr_proj, right_on=right_on_expr_proj, how=join_type, + null_equals_nulls=null_equals_nulls, ) @@ -303,6 +305,7 @@ def broadcast_join( receiver: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], right_on: list[PyExpr], + null_equals_nulls: list[bool] | None, join_type: JoinType, is_swapped: bool, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: @@ -315,6 +318,7 @@ def broadcast_join( right_on=right_on_expr_proj, how=join_type, is_swapped=is_swapped, + null_equals_nulls=null_equals_nulls, ) diff --git a/daft/table/micropartition.py b/daft/table/micropartition.py index 81b2b8b2ac..5baf3d379c 100644 --- a/daft/table/micropartition.py +++ b/daft/table/micropartition.py @@ -248,6 +248,7 @@ def hash_join( right: MicroPartition, left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: list[bool] | None = None, how: JoinType = JoinType.Inner, ) -> MicroPartition: if len(left_on) != len(right_on): @@ -262,7 +263,13 @@ def hash_join( right_exprs = [e._expr for e in right_on] return MicroPartition._from_pymicropartition( - self._micropartition.hash_join(right._micropartition, left_on=left_exprs, right_on=right_exprs, how=how) + self._micropartition.hash_join( + right._micropartition, + left_on=left_exprs, + right_on=right_exprs, + null_equals_nulls=null_equals_nulls, + how=how, + ) ) def sort_merge_join( diff --git a/src/daft-core/src/array/ops/arrow2/comparison.rs b/src/daft-core/src/array/ops/arrow2/comparison.rs index a9c37c50fb..72a5547551 100644 --- a/src/daft-core/src/array/ops/arrow2/comparison.rs +++ b/src/daft-core/src/array/ops/arrow2/comparison.rs @@ -80,17 +80,17 @@ pub fn build_is_equal( pub fn build_multi_array_is_equal( left: &[Series], right: &[Series], - nulls_equal: bool, - nan_equal: bool, + nulls_equal: &[bool], + nans_equal: &[bool], ) -> DaftResult bool + Send + Sync>> { let mut fn_list = Vec::with_capacity(left.len()); - for (l, r) in left.iter().zip(right.iter()) { + for (idx, (l, r)) in left.iter().zip(right.iter()).enumerate() { fn_list.push(build_is_equal( l.to_arrow().as_ref(), r.to_arrow().as_ref(), - nulls_equal, - nan_equal, + nulls_equal[idx], + nans_equal[idx], )?); } diff --git a/src/daft-core/src/utils/dyn_compare.rs b/src/daft-core/src/utils/dyn_compare.rs index f5c11a6eaf..d58d295365 100644 --- a/src/daft-core/src/utils/dyn_compare.rs +++ b/src/daft-core/src/utils/dyn_compare.rs @@ -34,16 +34,16 @@ pub fn build_dyn_compare( pub fn build_dyn_multi_array_compare( schema: &Schema, - nulls_equal: bool, - nans_equal: bool, + nulls_equal: &[bool], + nans_equal: &[bool], ) -> DaftResult { let mut fn_list = Vec::with_capacity(schema.len()); - for field in schema.fields.values() { + for (idx, field) in schema.fields.values().enumerate() { fn_list.push(build_dyn_compare( &field.dtype, &field.dtype, - nulls_equal, - nans_equal, + nulls_equal[idx], + nans_equal[idx], )?); } let combined_fn = Box::new( diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index f15ab50543..2bda0a302a 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -313,6 +313,7 @@ pub fn physical_plan_to_pipeline( right, left_on, right_on, + null_equals_null, join_type, schema, }) => { @@ -371,9 +372,13 @@ pub fn physical_plan_to_pipeline( .zip(key_schema.fields.values()) .map(|(e, f)| e.clone().cast(&f.dtype)) .collect::>(); - // we should move to a builder pattern - let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; + let build_sink = HashJoinBuildSink::new( + key_schema, + casted_build_on, + null_equals_null.clone(), + join_type, + )?; let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?; let build_node = BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed(); diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 677f63279d..56e6ab8dbc 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -25,11 +25,16 @@ impl ProbeTableState { fn new( key_schema: &SchemaRef, projection: Vec, + nulls_equal_aware: Option<&Vec>, join_type: &JoinType, ) -> DaftResult { let track_indices = !matches!(join_type, JoinType::Anti | JoinType::Semi); Ok(Self::Building { - probe_table_builder: Some(make_probeable_builder(key_schema.clone(), track_indices)?), + probe_table_builder: Some(make_probeable_builder( + key_schema.clone(), + nulls_equal_aware, + track_indices, + )?), projection, tables: Vec::new(), }) @@ -83,6 +88,7 @@ impl BlockingSinkState for ProbeTableState { pub struct HashJoinBuildSink { key_schema: SchemaRef, projection: Vec, + nulls_equal_aware: Option>, join_type: JoinType, } @@ -90,11 +96,13 @@ impl HashJoinBuildSink { pub(crate) fn new( key_schema: SchemaRef, projection: Vec, + nulls_equal_aware: Option>, join_type: &JoinType, ) -> DaftResult { Ok(Self { key_schema, projection, + nulls_equal_aware, join_type: *join_type, }) } @@ -144,6 +152,7 @@ impl BlockingSink for HashJoinBuildSink { Ok(Box::new(ProbeTableState::new( &self.key_schema, self.projection.clone(), + self.nulls_equal_aware.as_ref(), &self.join_type, )?)) } diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index 0d671d0fe3..a4bf18b546 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ b/src/daft-micropartition/src/ops/join.rs @@ -82,11 +82,17 @@ impl MicroPartition { right: &Self, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: Option>, how: JoinType, ) -> DaftResult { let io_stats = IOStatsContext::new("MicroPartition::hash_join"); + let null_equals_nulls = null_equals_nulls.unwrap_or_else(|| vec![false; left_on.len()]); + let table_join = + |lt: &Table, rt: &Table, lo: &[ExprRef], ro: &[ExprRef], _how: JoinType| { + Table::hash_join(lt, rt, lo, ro, null_equals_nulls.as_slice(), _how) + }; - self.join(right, io_stats, left_on, right_on, how, Table::hash_join) + self.join(right, io_stats, left_on, right_on, how, table_join) } pub fn sort_merge_join( diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index ab9b4a7db1..39bc7ad5c5 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -260,6 +260,7 @@ impl PyMicroPartition { left_on: Vec, right_on: Vec, how: JoinType, + null_equals_nulls: Option>, ) -> PyResult { let left_exprs: Vec = left_on.into_iter().map(std::convert::Into::into).collect(); @@ -272,6 +273,7 @@ impl PyMicroPartition { &right.inner, left_exprs.as_slice(), right_exprs.as_slice(), + null_equals_nulls, how, )? .into()) diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 39eeca96c5..0fd750ab82 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -261,6 +261,7 @@ impl LocalPhysicalPlan { right: LocalPhysicalPlanRef, left_on: Vec, right_on: Vec, + null_equals_null: Option>, join_type: JoinType, schema: SchemaRef, ) -> LocalPhysicalPlanRef { @@ -269,6 +270,7 @@ impl LocalPhysicalPlan { right, left_on, right_on, + null_equals_null, join_type, schema, }) @@ -452,6 +454,7 @@ pub struct HashJoin { pub right: LocalPhysicalPlanRef, pub left_on: Vec, pub right_on: Vec, + pub null_equals_null: Option>, pub join_type: JoinType, pub schema: SchemaRef, } diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index bba77357d7..f839a80bab 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -130,6 +130,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { right, join.left_on.clone(), join.right_on.clone(), + join.null_equals_nulls.clone(), join.join_type, join.output_schema.clone(), )) diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index c08a975dfd..e5929baeed 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -454,12 +454,37 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + ) -> DaftResult { + self.join_with_null_safe_equal( + right, + left_on, + right_on, + None, + join_type, + join_strategy, + join_suffix, + join_prefix, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn join_with_null_safe_equal>( + &self, + right: Right, + left_on: Vec, + right_on: Vec, + null_equals_nulls: Option>, + join_type: JoinType, + join_strategy: Option, + join_suffix: Option<&str>, + join_prefix: Option<&str>, ) -> DaftResult { let logical_plan: LogicalPlan = logical_ops::Join::try_new( self.plan.clone(), right.into(), left_on, right_on, + null_equals_nulls, join_type, join_strategy, join_suffix, diff --git a/src/daft-plan/src/display.rs b/src/daft-plan/src/display.rs index 28416c597d..0f3228cc03 100644 --- a/src/daft-plan/src/display.rs +++ b/src/daft-plan/src/display.rs @@ -227,10 +227,11 @@ Project1 --> Limit0 .build(); let plan = LogicalPlanBuilder::new(subplan, None) - .join( + .join_with_null_safe_equal( subplan2, vec![col("id")], vec![col("id")], + Some(vec![true]), JoinType::Inner, None, None, diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 0ba6182535..b3310657d1 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -29,6 +29,7 @@ pub struct Join { pub left_on: Vec, pub right_on: Vec, + pub null_equals_nulls: Option>, pub join_type: JoinType, pub join_strategy: Option, pub output_schema: SchemaRef, @@ -40,6 +41,7 @@ impl std::hash::Hash for Join { std::hash::Hash::hash(&self.right, state); std::hash::Hash::hash(&self.left_on, state); std::hash::Hash::hash(&self.right_on, state); + std::hash::Hash::hash(&self.null_equals_nulls, state); std::hash::Hash::hash(&self.join_type, state); std::hash::Hash::hash(&self.join_strategy, state); std::hash::Hash::hash(&self.output_schema, state); @@ -53,6 +55,7 @@ impl Join { right: Arc, left_on: Vec, right_on: Vec, + null_equals_nulls: Option>, join_type: JoinType, join_strategy: Option, join_suffix: Option<&str>, @@ -92,6 +95,16 @@ impl Join { } } + if let Some(null_equals_null) = &null_equals_nulls { + if null_equals_null.len() != left_on.len() { + return Err(DaftError::ValueError( + "null_equals_nulls must have the same length as left_on or right_on" + .to_string(), + )) + .context(CreationSnafu); + } + } + if matches!(join_type, JoinType::Anti | JoinType::Semi) { // The output schema is the same as the left input schema for anti and semi joins. @@ -102,6 +115,7 @@ impl Join { right, left_on, right_on, + null_equals_nulls, join_type, join_strategy, output_schema, @@ -188,6 +202,7 @@ impl Join { right, left_on, right_on, + null_equals_nulls, join_type, join_strategy, output_schema, @@ -276,6 +291,12 @@ impl Join { )); } } + if let Some(null_equals_nulls) = &self.null_equals_nulls { + res.push(format!( + "Null equals Nulls = [{}]", + null_equals_nulls.iter().map(|b| b.to_string()).join(", ") + )); + } res.push(format!( "Output schema = {}", self.output_schema.short_string() diff --git a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs index 92d6e30cb1..a78a549215 100644 --- a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs +++ b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs @@ -306,6 +306,7 @@ fn find_inner_join( left: left_input, right: right_input, left_on: left_keys, + null_equals_nulls: None, right_on: right_keys, join_type: JoinType::Inner, join_strategy: None, @@ -327,6 +328,7 @@ fn find_inner_join( right, left_on: vec![], right_on: vec![], + null_equals_nulls: None, join_type: JoinType::Inner, join_strategy: None, output_schema: Arc::new(join_schema), diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs index 976a788ae8..0000606be8 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs @@ -650,6 +650,7 @@ mod tests { #[rstest] fn filter_commutes_with_join_left_side( #[values(false, true)] push_into_left_scan: bool, + #[values(false, true)] null_equals_null: bool, #[values(JoinType::Inner, JoinType::Left, JoinType::Anti, JoinType::Semi)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -666,12 +667,18 @@ mod tests { ); let right_scan_plan = dummy_scan_node(right_scan_op.clone()); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equals_null { + Some(vec![true]) + } else { + None + }; let pred = col("a").lt(lit(2)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on.clone(), + null_equals_nulls.clone(), how, None, None, @@ -688,10 +695,11 @@ mod tests { left_scan_plan.filter(pred)? }; let expected = expected_left_filter_scan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -706,6 +714,7 @@ mod tests { #[rstest] fn filter_commutes_with_join_right_side( #[values(false, true)] push_into_right_scan: bool, + #[values(false, true)] null_equals_null: bool, #[values(JoinType::Inner, JoinType::Right)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -722,12 +731,18 @@ mod tests { Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), ); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equals_null { + Some(vec![true]) + } else { + None + }; let pred = col("c").lt(lit(2.0)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on.clone(), + null_equals_nulls.clone(), how, None, None, @@ -744,10 +759,11 @@ mod tests { right_scan_plan.filter(pred)? }; let expected = left_scan_plan - .join( + .join_with_null_safe_equal( &expected_right_filter_scan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -763,6 +779,7 @@ mod tests { fn filter_commutes_with_join_on_join_key( #[values(false, true)] push_into_left_scan: bool, #[values(false, true)] push_into_right_scan: bool, + #[values(false, true)] null_equals_null: bool, #[values( JoinType::Inner, JoinType::Left, @@ -791,12 +808,18 @@ mod tests { Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), ); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equals_null { + Some(vec![true]) + } else { + None + }; let pred = col("b").lt(lit(2)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on.clone(), + null_equals_nulls.clone(), how, None, None, @@ -821,10 +844,11 @@ mod tests { right_scan_plan.filter(pred)? }; let expected = expected_left_filter_scan - .join( + .join_with_null_safe_equal( &expected_right_filter_scan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -838,6 +862,7 @@ mod tests { /// Tests that Filter can be pushed into the left side of a Join. #[rstest] fn filter_does_not_commute_with_join_left_side( + #[values(false, true)] null_equal_null: bool, #[values(JoinType::Right, JoinType::Outer)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -851,12 +876,18 @@ mod tests { let left_scan_plan = dummy_scan_node(left_scan_op.clone()); let right_scan_plan = dummy_scan_node(right_scan_op.clone()); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equal_null { + Some(vec![true]) + } else { + None + }; let pred = col("a").lt(lit(2)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -873,6 +904,7 @@ mod tests { /// Tests that Filter can be pushed into the right side of a Join. #[rstest] fn filter_does_not_commute_with_join_right_side( + #[values(false, true)] null_equal_null: bool, #[values(JoinType::Left, JoinType::Outer)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -886,12 +918,18 @@ mod tests { let left_scan_plan = dummy_scan_node(left_scan_op.clone()); let right_scan_plan = dummy_scan_node(right_scan_op.clone()); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equal_null { + Some(vec![true]) + } else { + None + }; let pred = col("c").lt(lit(2.0)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 866c7940b4..37b75217c8 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -264,11 +264,12 @@ impl LogicalPlan { [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()), - Self::Join(Join { left_on, right_on, join_type, join_strategy, .. }) => Self::Join(Join::try_new( + Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. }) => Self::Join(Join::try_new( input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), + null_equals_nulls.clone(), *join_type, *join_strategy, None, // The suffix is already eagerly computed in the constructor diff --git a/src/daft-plan/src/physical_ops/broadcast_join.rs b/src/daft-plan/src/physical_ops/broadcast_join.rs index b45ce7c4fa..e8048f305c 100644 --- a/src/daft-plan/src/physical_ops/broadcast_join.rs +++ b/src/daft-plan/src/physical_ops/broadcast_join.rs @@ -13,6 +13,7 @@ pub struct BroadcastJoin { pub receiver: PhysicalPlanRef, pub left_on: Vec, pub right_on: Vec, + pub null_equals_nulls: Option>, pub join_type: JoinType, pub is_swapped: bool, } @@ -23,6 +24,7 @@ impl BroadcastJoin { receiver: PhysicalPlanRef, left_on: Vec, right_on: Vec, + null_equals_nulls: Option>, join_type: JoinType, is_swapped: bool, ) -> Self { @@ -31,6 +33,7 @@ impl BroadcastJoin { receiver, left_on, right_on, + null_equals_nulls, join_type, is_swapped, } @@ -58,6 +61,13 @@ impl BroadcastJoin { )); } } + + if let Some(null_equals_nulls) = &self.null_equals_nulls { + res.push(format!( + "Null equals Nulls = [{}]", + null_equals_nulls.iter().map(|b| b.to_string()).join(", ") + )); + } res.push(format!("Is swapped = {}", self.is_swapped)); res } diff --git a/src/daft-plan/src/physical_ops/hash_join.rs b/src/daft-plan/src/physical_ops/hash_join.rs index 5d1895c1d8..c41ac018fe 100644 --- a/src/daft-plan/src/physical_ops/hash_join.rs +++ b/src/daft-plan/src/physical_ops/hash_join.rs @@ -13,6 +13,7 @@ pub struct HashJoin { pub right: PhysicalPlanRef, pub left_on: Vec, pub right_on: Vec, + pub null_equals_nulls: Option>, pub join_type: JoinType, } @@ -22,6 +23,7 @@ impl HashJoin { right: PhysicalPlanRef, left_on: Vec, right_on: Vec, + null_equals_nulls: Option>, join_type: JoinType, ) -> Self { Self { @@ -29,6 +31,7 @@ impl HashJoin { right, left_on, right_on, + null_equals_nulls, join_type, } } @@ -55,6 +58,12 @@ impl HashJoin { )); } } + if let Some(null_equals_nulls) = &self.null_equals_nulls { + res.push(format!( + "Null equals Nulls = [{}]", + null_equals_nulls.iter().map(|b| b.to_string()).join(", ") + )); + } res } } diff --git a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs index 76910ed834..30a1e38be7 100644 --- a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs +++ b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs @@ -272,6 +272,7 @@ mod tests { plan2, vec![col("b"), col("a")], vec![col("x"), col("y")], + None, JoinType::Inner, )) .arced(); @@ -285,6 +286,7 @@ mod tests { add_repartition(base2, 1, vec![col("x"), col("y")]), vec![col("b"), col("a")], vec![col("x"), col("y")], + None, JoinType::Inner, )) .arced(); diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index e3aa9607e7..25327e66e6 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -460,14 +460,15 @@ impl PhysicalPlan { Self::InMemoryScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), Self::TabularScan(..) | Self::EmptyScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), - Self::HashJoin(HashJoin { left_on, right_on, join_type, .. }) => Self::HashJoin(HashJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type)), + Self::HashJoin(HashJoin { left_on, right_on, null_equals_nulls, join_type, .. }) => Self::HashJoin(HashJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), null_equals_nulls.clone(), *join_type)), Self::BroadcastJoin(BroadcastJoin { left_on, right_on, + null_equals_nulls, join_type, is_swapped, .. - }) => Self::BroadcastJoin(BroadcastJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type, *is_swapped)), + }) => Self::BroadcastJoin(BroadcastJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), null_equals_nulls.clone(), *join_type, *is_swapped)), Self::SortMergeJoin(SortMergeJoin { left_on, right_on, join_type, num_partitions, left_is_larger, needs_presort, .. }) => Self::SortMergeJoin(SortMergeJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type, *num_partitions, *left_is_larger, *needs_presort)), Self::Concat(_) => Self::Concat(Concat::new(input1.clone(), input2.clone())), _ => panic!("Physical op {:?} has one input, but got two", self), diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 2567cd3fca..2bfc4a2aed 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -400,6 +400,7 @@ pub(super) fn translate_single_logical_node( right, left_on, right_on, + null_equals_nulls, join_type, join_strategy, .. @@ -474,6 +475,9 @@ pub(super) fn translate_single_logical_node( } else { is_right_hash_partitioned || is_right_sort_partitioned }; + let has_null_safe_equals = null_equals_nulls + .as_ref() + .map_or(false, |v| v.iter().any(|b| *b)); let join_strategy = join_strategy.unwrap_or_else(|| { fn keys_are_primitive(on: &[ExprRef], schema: &SchemaRef) -> bool { on.iter().all(|expr| { @@ -506,6 +510,7 @@ pub(super) fn translate_single_logical_node( // TODO(Clark): Also do a sort-merge join if a downstream op needs the table to be sorted on the join key. // TODO(Clark): Look into defaulting to sort-merge join over hash join under more input partitioning setups. // TODO(Kevin): Support sort-merge join for other types of joins. + // TODO(advancedxy): Rewrite null safe equals to support SMJ } else if *join_type == JoinType::Inner && keys_are_primitive(left_on, &left.schema()) && keys_are_primitive(right_on, &right.schema()) @@ -513,6 +518,7 @@ pub(super) fn translate_single_logical_node( && (!is_larger_partitioned || (left_is_larger && is_left_sort_partitioned || !left_is_larger && is_right_sort_partitioned)) + && !has_null_safe_equals { JoinStrategy::SortMerge // Otherwise, use a hash join. @@ -544,6 +550,7 @@ pub(super) fn translate_single_logical_node( right_physical, left_on.clone(), right_on.clone(), + null_equals_nulls.clone(), *join_type, is_swapped, )) @@ -555,6 +562,11 @@ pub(super) fn translate_single_logical_node( "Sort-merge join currently only supports inner joins".to_string(), )); } + if has_null_safe_equals { + return Err(common_error::DaftError::ValueError( + "Sort-merge join does not support null-safe equals yet".to_string(), + )); + } let needs_presort = if cfg.sort_merge_join_sort_with_aligned_boundaries { // Use the special-purpose presorting that ensures join inputs are sorted with aligned @@ -645,6 +657,7 @@ pub(super) fn translate_single_logical_node( right_physical, left_on.clone(), right_on.clone(), + null_equals_nulls.clone(), *join_type, )) .arced()) diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 9d0894d1d8..65f1cbc67b 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -626,6 +626,7 @@ fn physical_plan_to_partition_tasks( right, left_on, right_on, + null_equals_nulls, join_type, .. }) => { @@ -649,6 +650,7 @@ fn physical_plan_to_partition_tasks( upstream_right_iter, left_on_pyexprs, right_on_pyexprs, + null_equals_nulls.clone(), *join_type, ))?; Ok(py_iter.into()) @@ -706,6 +708,7 @@ fn physical_plan_to_partition_tasks( receiver: right, left_on, right_on, + null_equals_nulls, join_type, is_swapped, }) => { @@ -729,6 +732,7 @@ fn physical_plan_to_partition_tasks( upstream_right_iter, left_on_pyexprs, right_on_pyexprs, + null_equals_nulls.clone(), *join_type, *is_swapped, ))?; diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index b27a0060ce..76e30d5912 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -419,6 +419,7 @@ impl SQLPlanner { left_rel: &Relation, right_rel: &Relation, ) -> SQLPlannerResult<(Vec, Vec)> { + // TODO: support null safe equal, a.k.a. <=>. if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression { match *op { BinaryOperator::Eq => { diff --git a/src/daft-table/src/ops/groups.rs b/src/daft-table/src/ops/groups.rs index 1edccccdc7..580d2c4288 100644 --- a/src/daft-table/src/ops/groups.rs +++ b/src/daft-table/src/ops/groups.rs @@ -65,8 +65,8 @@ impl Table { let comparator = build_multi_array_is_equal( self.columns.as_slice(), self.columns.as_slice(), - true, - true, + vec![true; self.columns.len()].as_slice(), + vec![true; self.columns.len()].as_slice(), )?; // To group the argsort values together, we will traverse the table in argsort order, diff --git a/src/daft-table/src/ops/hash.rs b/src/daft-table/src/ops/hash.rs index 4b4d120da0..5421988518 100644 --- a/src/daft-table/src/ops/hash.rs +++ b/src/daft-table/src/ops/hash.rs @@ -32,8 +32,8 @@ impl Table { let comparator = build_multi_array_is_equal( self.columns.as_slice(), self.columns.as_slice(), - true, - true, + vec![true; self.columns.len()].as_slice(), + vec![true; self.columns.len()].as_slice(), )?; let mut probe_table = @@ -77,8 +77,8 @@ impl Table { let comparator = build_multi_array_is_equal( self.columns.as_slice(), self.columns.as_slice(), - true, - true, + vec![true; self.columns.len()].as_slice(), + vec![true; self.columns.len()].as_slice(), )?; let mut probe_table = diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs index 1da3ea5f16..7f74666443 100644 --- a/src/daft-table/src/ops/joins/hash_join.rs +++ b/src/daft-table/src/ops/joins/hash_join.rs @@ -18,6 +18,7 @@ pub(super) fn hash_inner_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], ) -> DaftResult { let join_schema = infer_join_schema( &left.schema, @@ -55,8 +56,8 @@ pub(super) fn hash_inner_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; let mut left_idx = vec![]; @@ -107,6 +108,7 @@ pub(super) fn hash_left_right_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], left_side: bool, ) -> DaftResult
{ let join_schema = infer_join_schema( @@ -147,8 +149,8 @@ pub(super) fn hash_left_right_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; // we will have at least as many rows in the join table as the right table @@ -222,6 +224,7 @@ pub(super) fn hash_semi_anti_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], is_anti: bool, ) -> DaftResult
{ let lkeys = left.eval_expression_list(left_on)?; @@ -246,8 +249,8 @@ pub(super) fn hash_semi_anti_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; let rows = rkeys.len(); @@ -282,6 +285,7 @@ pub(super) fn hash_outer_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], ) -> DaftResult
{ let join_schema = infer_join_schema( &left.schema, @@ -333,8 +337,8 @@ pub(super) fn hash_outer_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; // we will have at least as many rows in the join table as the max of the left and right tables diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 0c6b678d35..97c034a9df 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -75,6 +75,7 @@ impl Table { right: &Self, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], how: JoinType, ) -> DaftResult { if left_on.len() != right_on.len() { @@ -92,12 +93,20 @@ impl Table { } match how { - JoinType::Inner => hash_inner_join(self, right, left_on, right_on), - JoinType::Left => hash_left_right_join(self, right, left_on, right_on, true), - JoinType::Right => hash_left_right_join(self, right, left_on, right_on, false), - JoinType::Outer => hash_outer_join(self, right, left_on, right_on), - JoinType::Semi => hash_semi_anti_join(self, right, left_on, right_on, false), - JoinType::Anti => hash_semi_anti_join(self, right, left_on, right_on, true), + JoinType::Inner => hash_inner_join(self, right, left_on, right_on, null_equals_nulls), + JoinType::Left => { + hash_left_right_join(self, right, left_on, right_on, null_equals_nulls, true) + } + JoinType::Right => { + hash_left_right_join(self, right, left_on, right_on, null_equals_nulls, false) + } + JoinType::Outer => hash_outer_join(self, right, left_on, right_on, null_equals_nulls), + JoinType::Semi => { + hash_semi_anti_join(self, right, left_on, right_on, null_equals_nulls, false) + } + JoinType::Anti => { + hash_semi_anti_join(self, right, left_on, right_on, null_equals_nulls, true) + } } } diff --git a/src/daft-table/src/probeable/mod.rs b/src/daft-table/src/probeable/mod.rs index 3346bd9869..c68cb60884 100644 --- a/src/daft-table/src/probeable/mod.rs +++ b/src/daft-table/src/probeable/mod.rs @@ -14,12 +14,19 @@ struct ArrowTableEntry(Vec>); pub fn make_probeable_builder( schema: SchemaRef, + nulls_equal_aware: Option<&Vec>, track_indices: bool, ) -> DaftResult> { if track_indices { - Ok(Box::new(ProbeTableBuilder(ProbeTable::new(schema)?))) + Ok(Box::new(ProbeTableBuilder(ProbeTable::new( + schema, + nulls_equal_aware, + )?))) } else { - Ok(Box::new(ProbeSetBuilder(ProbeSet::new(schema)?))) + Ok(Box::new(ProbeSetBuilder(ProbeSet::new( + schema, + nulls_equal_aware, + )?))) } } diff --git a/src/daft-table/src/probeable/probe_set.rs b/src/daft-table/src/probeable/probe_set.rs index adf9251756..40ab4ab86e 100644 --- a/src/daft-table/src/probeable/probe_set.rs +++ b/src/daft-table/src/probeable/probe_set.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_core::{ array::ops::as_arrow::AsArrow, prelude::SchemaRef, @@ -31,12 +31,26 @@ impl ProbeSet { const DEFAULT_SIZE: usize = 20; - pub(crate) fn new(schema: SchemaRef) -> DaftResult { + pub(crate) fn new( + schema: SchemaRef, + nulls_equal_aware: Option<&Vec>, + ) -> DaftResult { let hash_table = HashMap::::with_capacity_and_hasher( Self::DEFAULT_SIZE, Default::default(), ); - let compare_fn = build_dyn_multi_array_compare(&schema, false, false)?; + if let Some(null_equal_aware) = nulls_equal_aware { + if null_equal_aware.len() != schema.len() { + return Err(DaftError::InternalError( + format!("null_equal_aware should have the same length as the schema. Expected: {}, Found: {}", + schema.len(), null_equal_aware.len()))); + } + } + let default_nulls_equal = vec![false; schema.len()]; + let nulls_equal = nulls_equal_aware.unwrap_or_else(|| default_nulls_equal.as_ref()); + let nans_equal = &vec![false; schema.len()]; + let compare_fn = + build_dyn_multi_array_compare(&schema, nulls_equal.as_slice(), nans_equal.as_slice())?; Ok(Self { schema, hash_table, diff --git a/src/daft-table/src/probeable/probe_table.rs b/src/daft-table/src/probeable/probe_table.rs index 8e8c8c8647..ba4c11dc41 100644 --- a/src/daft-table/src/probeable/probe_table.rs +++ b/src/daft-table/src/probeable/probe_table.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_core::{ array::ops::as_arrow::AsArrow, prelude::SchemaRef, @@ -32,13 +32,24 @@ impl ProbeTable { const DEFAULT_SIZE: usize = 20; - pub(crate) fn new(schema: SchemaRef) -> DaftResult { + pub(crate) fn new(schema: SchemaRef, null_equal_aware: Option<&Vec>) -> DaftResult { let hash_table = HashMap::, IdentityBuildHasher>::with_capacity_and_hasher( Self::DEFAULT_SIZE, Default::default(), ); - let compare_fn = build_dyn_multi_array_compare(&schema, false, false)?; + if let Some(null_equal_aware) = null_equal_aware { + if null_equal_aware.len() != schema.len() { + return Err(DaftError::InternalError( + format!("null_equal_aware should have the same length as the schema. Expected: {}, Found: {}", + schema.len(), null_equal_aware.len()))); + } + } + let default_nulls_equal = vec![false; schema.len()]; + let nulls_equal = null_equal_aware.unwrap_or_else(|| default_nulls_equal.as_ref()); + let nans_equal = &vec![false; schema.len()]; + let compare_fn = + build_dyn_multi_array_compare(&schema, nulls_equal.as_slice(), nans_equal.as_slice())?; Ok(Self { schema, hash_table, diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs index 89f1a12016..4bef3f9b05 100644 --- a/src/daft-table/src/python.rs +++ b/src/daft-table/src/python.rs @@ -134,6 +134,7 @@ impl PyTable { left_on.into_iter().map(std::convert::Into::into).collect(); let right_exprs: Vec = right_on.into_iter().map(std::convert::Into::into).collect(); + let null_equals_nulls = vec![false; left_exprs.len()]; py.allow_threads(|| { Ok(self .table @@ -141,6 +142,7 @@ impl PyTable { &right.table, left_exprs.as_slice(), right_exprs.as_slice(), + null_equals_nulls.as_slice(), how, )? .into()) diff --git a/tests/table/test_joins.py b/tests/table/test_joins.py index 968ab97630..6c58bbce6e 100644 --- a/tests/table/test_joins.py +++ b/tests/table/test_joins.py @@ -26,29 +26,83 @@ daft_string_types = [DataType.string()] +def skip_null_safe_equal_for_smj(func): + from functools import wraps + from inspect import getfullargspec + + @wraps(func) + def wrapper(*args, **kwargs): + spec = getfullargspec(func) + join_impl, null_safe_equal = None, None + if "join_impl" in kwargs: + join_impl = kwargs["join_impl"] + elif "join_impl" in spec.args: + idx = spec.args.index("join_impl") + join_impl = args[idx] if idx >= 0 else None + if "null_safe_equal" in kwargs: + null_safe_equal = kwargs["null_safe_equal"] + elif "null_safe_equal" in spec.args: + idx = spec.args.index("null_safe_equal") + null_safe_equal = args[idx] if idx >= 0 else None + if join_impl == "sort_merge_join" and null_safe_equal: + pytest.skip("sort merge join does not support null safe equal yet") + return func(*args, **kwargs) + + return wrapper + + +@skip_null_safe_equal_for_smj @pytest.mark.parametrize( - "dtype, data", + "dtype, data, null_safe_equal", itertools.product( daft_numeric_types + daft_string_types, [ - ([0, 1, 2, 3, None], [0, 1, 2, 3, None], [(0, 0), (1, 1), (2, 2), (3, 3)]), - ([None, None, 3, 1, 2, 0], [0, 1, 2, 3, None], [(5, 0), (3, 1), (4, 2), (2, 3)]), - ([None, 4, 5, 6, 7], [0, 1, 2, 3, None], []), - ([None, 0, 0, 0, 1, None], [0, 1, 2, 3, None], [(1, 0), (2, 0), (3, 0), (4, 1)]), - ([None, 0, 0, 1, 1, None], [0, 1, 2, 3, None], [(1, 0), (2, 0), (3, 1), (4, 1)]), - ([None, 0, 0, 1, 1, None], [3, 1, 0, 2, None], [(1, 2), (2, 2), (3, 1), (4, 1)]), + ( + [0, 1, 2, 3, None], + [0, 1, 2, 3, None], + [(0, 0), (1, 1), (2, 2), (3, 3)], + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)], + ), + ( + [None, None, 3, 1, 2, 0], + [0, 1, 2, 3, None], + [(5, 0), (3, 1), (4, 2), (2, 3)], + [(5, 0), (3, 1), (4, 2), (2, 3), (0, 4), (1, 4)], + ), + ([None, 4, 5, 6, 7], [0, 1, 2, 3, None], [], [(0, 4)]), + ( + [None, 0, 0, 0, 1, None], + [0, 1, 2, 3, None], + [(1, 0), (2, 0), (3, 0), (4, 1)], + [(1, 0), (2, 0), (3, 0), (4, 1), (5, 4), (0, 4)], + ), + ( + [None, 0, 0, 1, 1, None], + [0, 1, 2, 3, None], + [(1, 0), (2, 0), (3, 1), (4, 1)], + [(1, 0), (2, 0), (3, 1), (4, 1), (5, 4), (0, 4)], + ), + ( + [None, 0, 0, 1, 1, None], + [3, 1, 0, 2, None], + [(1, 2), (2, 2), (3, 1), (4, 1)], + [(1, 2), (2, 2), (3, 1), (4, 1), (5, 4), (0, 4)], + ), ], + [True, False], ), ) @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_single_column(join_impl, dtype, data) -> None: - left, right, expected_pairs = data +def test_table_join_single_column(join_impl, dtype, data, null_safe_equal) -> None: + left, right, expected_pairs, null_safe_expected = data + expected_pairs = null_safe_expected if null_safe_equal else expected_pairs + null_equals_nulls = {"null_equals_nulls": [null_safe_equal]} if null_safe_equal else {} left_table = MicroPartition.from_pydict({"x": left, "x_ind": list(range(len(left)))}).eval_expression_list( [col("x").cast(dtype), col("x_ind")] ) right_table = MicroPartition.from_pydict({"y": right, "y_ind": list(range(len(right)))}) result_table = getattr(left_table, join_impl)( - right_table, left_on=[col("x")], right_on=[col("y")], how=JoinType.Inner + right_table, left_on=[col("x")], right_on=[col("y")], how=JoinType.Inner, **null_equals_nulls ) assert result_table.column_names() == ["x", "x_ind", "y", "y_ind"] @@ -65,7 +119,7 @@ def test_table_join_single_column(join_impl, dtype, data) -> None: # make sure the result is the same with right table on left result_table = getattr(right_table, join_impl)( - left_table, right_on=[col("x")], left_on=[col("y")], how=JoinType.Inner + left_table, right_on=[col("x")], left_on=[col("y")], how=JoinType.Inner, **null_equals_nulls ) assert result_table.column_names() == ["y", "y_ind", "x", "x_ind"] @@ -90,6 +144,7 @@ def test_table_join_mismatch_column(join_impl) -> None: getattr(left_table, join_impl)(right_table, left_on=[col("x"), col("y")], right_on=[col("a")]) +@skip_null_safe_equal_for_smj @pytest.mark.parametrize( "left", [ @@ -105,7 +160,8 @@ def test_table_join_mismatch_column(join_impl) -> None: ], ) @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_empty_result(join_impl, left, right) -> None: +@pytest.mark.parametrize("null_safe_equal", [True, False]) +def test_table_join_multicolumn_empty_result(join_impl, left, right, null_safe_equal) -> None: """Various multicol joins that should all produce an empty result.""" left_table = MicroPartition.from_pydict(left).eval_expression_list( [col("a").cast(DataType.string()), col("b").cast(DataType.int32())] @@ -114,12 +170,18 @@ def test_table_join_multicolumn_empty_result(join_impl, left, right) -> None: [col("x").cast(DataType.string()), col("y").cast(DataType.int32())] ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls + ) assert result.to_pydict() == {"a": [], "b": [], "x": [], "y": []} -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_nocross(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_multicolumn_nocross(join_impl, null_safe_equal) -> None: """A multicol join that should produce two rows and no cross product results. Input has duplicate join values and overlapping single-column values, @@ -127,85 +189,104 @@ def test_table_join_multicolumn_nocross(join_impl) -> None: """ left_table = MicroPartition.from_pydict( { - "a": ["apple", "apple", "banana", "banana", "carrot"], - "b": [1, 2, 2, 2, 3], - "c": [1, 2, 3, 4, 5], + "a": ["apple", "apple", "banana", "banana", "carrot", None], + "b": [1, 2, 2, 2, 3, 3], + "c": [1, 2, 3, 4, 5, 5], } ) right_table = MicroPartition.from_pydict( { - "x": ["banana", "carrot", "apple", "banana", "apple", "durian"], - "y": [1, 3, 2, 1, 3, 6], - "z": [1, 2, 3, 4, 5, 6], + "x": ["banana", "carrot", "apple", "banana", "apple", "durian", None], + "y": [1, 3, 2, 1, 3, 6, 3], + "z": [1, 2, 3, 4, 5, 6, 6], } ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) - assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set( - utils.freeze( - [ - {"a": "apple", "b": 2, "c": 2, "x": "apple", "y": 2, "z": 3}, - {"a": "carrot", "b": 3, "c": 5, "x": "carrot", "y": 3, "z": 2}, - ] - ) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls ) + expected = [ + {"a": "apple", "b": 2, "c": 2, "x": "apple", "y": 2, "z": 3}, + {"a": "carrot", "b": 3, "c": 5, "x": "carrot", "y": 3, "z": 2}, + ] + if null_safe_equal: + expected.append({"a": None, "b": 3, "c": 5, "x": None, "y": 3, "z": 6}) + assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze(expected)) -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_cross(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_multicolumn_cross(join_impl, null_safe_equal) -> None: """A multicol join that should produce a cross product and a non-cross product.""" left_table = MicroPartition.from_pydict( { - "a": ["apple", "apple", "banana", "banana", "banana"], - "b": [1, 0, 1, 1, 1], - "c": [1, 2, 3, 4, 5], + "a": ["apple", "apple", "banana", "banana", "banana", None], + "b": [1, 0, 1, 1, 1, 1], + "c": [1, 2, 3, 4, 5, 5], } ) right_table = MicroPartition.from_pydict( { - "x": ["apple", "apple", "banana", "banana", "banana"], - "y": [1, 0, 1, 1, 0], - "z": [1, 2, 3, 4, 5], + "x": ["apple", "apple", "banana", "banana", "banana", None], + "y": [1, 0, 1, 1, 0, 1], + "z": [1, 2, 3, 4, 5, 5], } ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) - assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set( - utils.freeze( - [ - {"a": "apple", "b": 1, "c": 1, "x": "apple", "y": 1, "z": 1}, - {"a": "apple", "b": 0, "c": 2, "x": "apple", "y": 0, "z": 2}, - {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 3}, - {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 4}, - {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 3}, - {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 4}, - {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 3}, - {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 4}, - ] - ) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls ) + expected = [ + {"a": "apple", "b": 1, "c": 1, "x": "apple", "y": 1, "z": 1}, + {"a": "apple", "b": 0, "c": 2, "x": "apple", "y": 0, "z": 2}, + {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 3}, + {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 4}, + {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 3}, + {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 4}, + {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 3}, + {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 4}, + ] + if null_safe_equal: + expected.append({"a": None, "b": 1, "c": 5, "x": None, "y": 1, "z": 5}) + assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze(expected)) -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_all_nulls(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_multicolumn_all_nulls(join_impl, null_safe_equal) -> None: left_table = MicroPartition.from_pydict( { - "a": Series.from_pylist([None, None, None]).cast(DataType.int64()), - "b": Series.from_pylist([None, None, None]).cast(DataType.string()), - "c": [1, 2, 3], + "a": Series.from_pylist([None, None]).cast(DataType.int64()), + "b": Series.from_pylist([None, None]).cast(DataType.string()), + "c": [1, 2], } ) right_table = MicroPartition.from_pydict( { - "x": Series.from_pylist([None, None, None]).cast(DataType.int64()), - "y": Series.from_pylist([None, None, None]).cast(DataType.string()), - "z": [1, 2, 3], + "x": Series.from_pylist([None, None]).cast(DataType.int64()), + "y": Series.from_pylist([None, None]).cast(DataType.string()), + "z": [1, 2], } ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) - assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze([])) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls + ) + expected = [] + if null_safe_equal: + expected = [ + {"a": None, "b": None, "c": 1, "x": None, "y": None, "z": 1}, + {"a": None, "b": None, "c": 1, "x": None, "y": None, "z": 2}, + {"a": None, "b": None, "c": 2, "x": None, "y": None, "z": 1}, + {"a": None, "b": None, "c": 2, "x": None, "y": None, "z": 2}, + ] + assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze(expected)) @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) @@ -217,16 +298,25 @@ def test_table_join_no_columns(join_impl) -> None: getattr(left_table, join_impl)(right_table, left_on=[], right_on=[]) -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_single_column_name_boolean(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_single_column_name_boolean(join_impl, null_safe_equal) -> None: left_table = MicroPartition.from_pydict({"x": [False, True, None], "y": [0, 1, 2]}) right_table = MicroPartition.from_pydict({"x": [None, True, False, None], "right.y": [0, 1, 2, 3]}) - result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal]} if null_safe_equal else {} + result_table = getattr(left_table, join_impl)( + right_table, left_on=[col("x")], right_on=[col("x")], **null_equals_nulls + ) assert result_table.column_names() == ["x", "y", "right.y"] result_sorted = result_table.sort([col("x")]) - assert result_sorted.get_column("y").to_pylist() == [0, 1] - assert result_sorted.get_column("right.y").to_pylist() == [2, 1] + if null_safe_equal: + assert result_sorted.get_column("y").to_pylist() == [0, 1, 2, 2] + assert result_sorted.get_column("right.y").to_pylist() == [2, 1, 0, 3] + else: + assert result_sorted.get_column("y").to_pylist() == [0, 1] + assert result_sorted.get_column("right.y").to_pylist() == [2, 1] @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"])