diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 106474b3bc..495d0e0a97 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1550,6 +1550,7 @@ class PyMicroPartition: right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], + null_equals_nulls: list[bool] | None, how: JoinType, ) -> PyMicroPartition: ... def pivot( diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index b0763b2c25..dcf2f35b28 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -788,6 +788,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) class HashJoin(SingleOutputInstruction): left_on: ExpressionsProjection right_on: ExpressionsProjection + null_equals_nulls: list[bool] | None how: JoinType is_swapped: bool @@ -810,6 +811,7 @@ def _hash_join(self, inputs: list[MicroPartition]) -> list[MicroPartition]: right, left_on=self.left_on, right_on=self.right_on, + null_equals_nulls=self.null_equals_nulls, how=self.how, ) return [result] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index e3a8501ce5..80fb766d3e 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -346,6 +346,7 @@ def hash_join( right_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: None | list[bool], how: JoinType, ) -> InProgressPhysicalPlan[PartitionT]: """Hash-based pairwise join the partitions from `left_child_plan` and `right_child_plan` together.""" @@ -387,6 +388,7 @@ def hash_join( instruction=execution_step.HashJoin( left_on=left_on, right_on=right_on, + null_equals_nulls=null_equals_nulls, how=how, is_swapped=False, ) @@ -432,6 +434,7 @@ def _create_broadcast_join_step( receiver_part: SingleOutputPartitionTask[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: None | list[bool], how: JoinType, is_swapped: bool, ) -> PartitionTaskBuilder[PartitionT]: @@ -477,6 +480,7 @@ def _create_broadcast_join_step( instruction=execution_step.BroadcastJoin( left_on=left_on, right_on=right_on, + null_equals_nulls=null_equals_nulls, how=how, is_swapped=is_swapped, ) @@ -488,6 +492,7 @@ def broadcast_join( receiver_plan: InProgressPhysicalPlan[PartitionT], left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: None | list[bool], how: JoinType, is_swapped: bool, ) -> InProgressPhysicalPlan[PartitionT]: @@ -530,7 +535,15 @@ def broadcast_join( # Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop. while receiver_requests and receiver_requests[0].done(): receiver_part = receiver_requests.popleft() - yield _create_broadcast_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped) + yield _create_broadcast_join_step( + broadcaster_parts, + receiver_part, + left_on, + right_on, + null_equals_nulls, + how, + is_swapped, + ) # Execute single child step to pull in more input partitions. try: diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 225fd13185..151f68061c 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -243,6 +243,7 @@ def hash_join( right: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], right_on: list[PyExpr], + null_equals_nulls: list[bool] | None, join_type: JoinType, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) @@ -253,6 +254,7 @@ def hash_join( left_on=left_on_expr_proj, right_on=right_on_expr_proj, how=join_type, + null_equals_nulls=null_equals_nulls, ) @@ -303,6 +305,7 @@ def broadcast_join( receiver: physical_plan.InProgressPhysicalPlan[PartitionT], left_on: list[PyExpr], right_on: list[PyExpr], + null_equals_nulls: list[bool] | None, join_type: JoinType, is_swapped: bool, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: @@ -315,6 +318,7 @@ def broadcast_join( right_on=right_on_expr_proj, how=join_type, is_swapped=is_swapped, + null_equals_nulls=null_equals_nulls, ) diff --git a/daft/table/micropartition.py b/daft/table/micropartition.py index 81b2b8b2ac..5baf3d379c 100644 --- a/daft/table/micropartition.py +++ b/daft/table/micropartition.py @@ -248,6 +248,7 @@ def hash_join( right: MicroPartition, left_on: ExpressionsProjection, right_on: ExpressionsProjection, + null_equals_nulls: list[bool] | None = None, how: JoinType = JoinType.Inner, ) -> MicroPartition: if len(left_on) != len(right_on): @@ -262,7 +263,13 @@ def hash_join( right_exprs = [e._expr for e in right_on] return MicroPartition._from_pymicropartition( - self._micropartition.hash_join(right._micropartition, left_on=left_exprs, right_on=right_exprs, how=how) + self._micropartition.hash_join( + right._micropartition, + left_on=left_exprs, + right_on=right_exprs, + null_equals_nulls=null_equals_nulls, + how=how, + ) ) def sort_merge_join( diff --git a/src/daft-core/src/array/ops/arrow2/comparison.rs b/src/daft-core/src/array/ops/arrow2/comparison.rs index a9c37c50fb..72a5547551 100644 --- a/src/daft-core/src/array/ops/arrow2/comparison.rs +++ b/src/daft-core/src/array/ops/arrow2/comparison.rs @@ -80,17 +80,17 @@ pub fn build_is_equal( pub fn build_multi_array_is_equal( left: &[Series], right: &[Series], - nulls_equal: bool, - nan_equal: bool, + nulls_equal: &[bool], + nans_equal: &[bool], ) -> DaftResult bool + Send + Sync>> { let mut fn_list = Vec::with_capacity(left.len()); - for (l, r) in left.iter().zip(right.iter()) { + for (idx, (l, r)) in left.iter().zip(right.iter()).enumerate() { fn_list.push(build_is_equal( l.to_arrow().as_ref(), r.to_arrow().as_ref(), - nulls_equal, - nan_equal, + nulls_equal[idx], + nans_equal[idx], )?); } diff --git a/src/daft-core/src/utils/dyn_compare.rs b/src/daft-core/src/utils/dyn_compare.rs index f5c11a6eaf..d58d295365 100644 --- a/src/daft-core/src/utils/dyn_compare.rs +++ b/src/daft-core/src/utils/dyn_compare.rs @@ -34,16 +34,16 @@ pub fn build_dyn_compare( pub fn build_dyn_multi_array_compare( schema: &Schema, - nulls_equal: bool, - nans_equal: bool, + nulls_equal: &[bool], + nans_equal: &[bool], ) -> DaftResult { let mut fn_list = Vec::with_capacity(schema.len()); - for field in schema.fields.values() { + for (idx, field) in schema.fields.values().enumerate() { fn_list.push(build_dyn_compare( &field.dtype, &field.dtype, - nulls_equal, - nans_equal, + nulls_equal[idx], + nans_equal[idx], )?); } let combined_fn = Box::new( diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index f15ab50543..2bda0a302a 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -313,6 +313,7 @@ pub fn physical_plan_to_pipeline( right, left_on, right_on, + null_equals_null, join_type, schema, }) => { @@ -371,9 +372,13 @@ pub fn physical_plan_to_pipeline( .zip(key_schema.fields.values()) .map(|(e, f)| e.clone().cast(&f.dtype)) .collect::>(); - // we should move to a builder pattern - let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; + let build_sink = HashJoinBuildSink::new( + key_schema, + casted_build_on, + null_equals_null.clone(), + join_type, + )?; let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?; let build_node = BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed(); diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 677f63279d..56e6ab8dbc 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -25,11 +25,16 @@ impl ProbeTableState { fn new( key_schema: &SchemaRef, projection: Vec, + nulls_equal_aware: Option<&Vec>, join_type: &JoinType, ) -> DaftResult { let track_indices = !matches!(join_type, JoinType::Anti | JoinType::Semi); Ok(Self::Building { - probe_table_builder: Some(make_probeable_builder(key_schema.clone(), track_indices)?), + probe_table_builder: Some(make_probeable_builder( + key_schema.clone(), + nulls_equal_aware, + track_indices, + )?), projection, tables: Vec::new(), }) @@ -83,6 +88,7 @@ impl BlockingSinkState for ProbeTableState { pub struct HashJoinBuildSink { key_schema: SchemaRef, projection: Vec, + nulls_equal_aware: Option>, join_type: JoinType, } @@ -90,11 +96,13 @@ impl HashJoinBuildSink { pub(crate) fn new( key_schema: SchemaRef, projection: Vec, + nulls_equal_aware: Option>, join_type: &JoinType, ) -> DaftResult { Ok(Self { key_schema, projection, + nulls_equal_aware, join_type: *join_type, }) } @@ -144,6 +152,7 @@ impl BlockingSink for HashJoinBuildSink { Ok(Box::new(ProbeTableState::new( &self.key_schema, self.projection.clone(), + self.nulls_equal_aware.as_ref(), &self.join_type, )?)) } diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index 0d671d0fe3..a4bf18b546 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ b/src/daft-micropartition/src/ops/join.rs @@ -82,11 +82,17 @@ impl MicroPartition { right: &Self, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: Option>, how: JoinType, ) -> DaftResult { let io_stats = IOStatsContext::new("MicroPartition::hash_join"); + let null_equals_nulls = null_equals_nulls.unwrap_or_else(|| vec![false; left_on.len()]); + let table_join = + |lt: &Table, rt: &Table, lo: &[ExprRef], ro: &[ExprRef], _how: JoinType| { + Table::hash_join(lt, rt, lo, ro, null_equals_nulls.as_slice(), _how) + }; - self.join(right, io_stats, left_on, right_on, how, Table::hash_join) + self.join(right, io_stats, left_on, right_on, how, table_join) } pub fn sort_merge_join( diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index ab9b4a7db1..39bc7ad5c5 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -260,6 +260,7 @@ impl PyMicroPartition { left_on: Vec, right_on: Vec, how: JoinType, + null_equals_nulls: Option>, ) -> PyResult { let left_exprs: Vec = left_on.into_iter().map(std::convert::Into::into).collect(); @@ -272,6 +273,7 @@ impl PyMicroPartition { &right.inner, left_exprs.as_slice(), right_exprs.as_slice(), + null_equals_nulls, how, )? .into()) diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 39eeca96c5..0fd750ab82 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -261,6 +261,7 @@ impl LocalPhysicalPlan { right: LocalPhysicalPlanRef, left_on: Vec, right_on: Vec, + null_equals_null: Option>, join_type: JoinType, schema: SchemaRef, ) -> LocalPhysicalPlanRef { @@ -269,6 +270,7 @@ impl LocalPhysicalPlan { right, left_on, right_on, + null_equals_null, join_type, schema, }) @@ -452,6 +454,7 @@ pub struct HashJoin { pub right: LocalPhysicalPlanRef, pub left_on: Vec, pub right_on: Vec, + pub null_equals_null: Option>, pub join_type: JoinType, pub schema: SchemaRef, } diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index bba77357d7..f839a80bab 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -130,6 +130,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { right, join.left_on.clone(), join.right_on.clone(), + join.null_equals_nulls.clone(), join.join_type, join.output_schema.clone(), )) diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index c08a975dfd..e5929baeed 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -454,12 +454,37 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + ) -> DaftResult { + self.join_with_null_safe_equal( + right, + left_on, + right_on, + None, + join_type, + join_strategy, + join_suffix, + join_prefix, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn join_with_null_safe_equal>( + &self, + right: Right, + left_on: Vec, + right_on: Vec, + null_equals_nulls: Option>, + join_type: JoinType, + join_strategy: Option, + join_suffix: Option<&str>, + join_prefix: Option<&str>, ) -> DaftResult { let logical_plan: LogicalPlan = logical_ops::Join::try_new( self.plan.clone(), right.into(), left_on, right_on, + null_equals_nulls, join_type, join_strategy, join_suffix, diff --git a/src/daft-plan/src/display.rs b/src/daft-plan/src/display.rs index 28416c597d..0f3228cc03 100644 --- a/src/daft-plan/src/display.rs +++ b/src/daft-plan/src/display.rs @@ -227,10 +227,11 @@ Project1 --> Limit0 .build(); let plan = LogicalPlanBuilder::new(subplan, None) - .join( + .join_with_null_safe_equal( subplan2, vec![col("id")], vec![col("id")], + Some(vec![true]), JoinType::Inner, None, None, diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 0ba6182535..b3310657d1 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -29,6 +29,7 @@ pub struct Join { pub left_on: Vec, pub right_on: Vec, + pub null_equals_nulls: Option>, pub join_type: JoinType, pub join_strategy: Option, pub output_schema: SchemaRef, @@ -40,6 +41,7 @@ impl std::hash::Hash for Join { std::hash::Hash::hash(&self.right, state); std::hash::Hash::hash(&self.left_on, state); std::hash::Hash::hash(&self.right_on, state); + std::hash::Hash::hash(&self.null_equals_nulls, state); std::hash::Hash::hash(&self.join_type, state); std::hash::Hash::hash(&self.join_strategy, state); std::hash::Hash::hash(&self.output_schema, state); @@ -53,6 +55,7 @@ impl Join { right: Arc, left_on: Vec, right_on: Vec, + null_equals_nulls: Option>, join_type: JoinType, join_strategy: Option, join_suffix: Option<&str>, @@ -92,6 +95,16 @@ impl Join { } } + if let Some(null_equals_null) = &null_equals_nulls { + if null_equals_null.len() != left_on.len() { + return Err(DaftError::ValueError( + "null_equals_nulls must have the same length as left_on or right_on" + .to_string(), + )) + .context(CreationSnafu); + } + } + if matches!(join_type, JoinType::Anti | JoinType::Semi) { // The output schema is the same as the left input schema for anti and semi joins. @@ -102,6 +115,7 @@ impl Join { right, left_on, right_on, + null_equals_nulls, join_type, join_strategy, output_schema, @@ -188,6 +202,7 @@ impl Join { right, left_on, right_on, + null_equals_nulls, join_type, join_strategy, output_schema, @@ -276,6 +291,12 @@ impl Join { )); } } + if let Some(null_equals_nulls) = &self.null_equals_nulls { + res.push(format!( + "Null equals Nulls = [{}]", + null_equals_nulls.iter().map(|b| b.to_string()).join(", ") + )); + } res.push(format!( "Output schema = {}", self.output_schema.short_string() diff --git a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs index 92d6e30cb1..a78a549215 100644 --- a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs +++ b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs @@ -306,6 +306,7 @@ fn find_inner_join( left: left_input, right: right_input, left_on: left_keys, + null_equals_nulls: None, right_on: right_keys, join_type: JoinType::Inner, join_strategy: None, @@ -327,6 +328,7 @@ fn find_inner_join( right, left_on: vec![], right_on: vec![], + null_equals_nulls: None, join_type: JoinType::Inner, join_strategy: None, output_schema: Arc::new(join_schema), diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs index 976a788ae8..0000606be8 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs @@ -650,6 +650,7 @@ mod tests { #[rstest] fn filter_commutes_with_join_left_side( #[values(false, true)] push_into_left_scan: bool, + #[values(false, true)] null_equals_null: bool, #[values(JoinType::Inner, JoinType::Left, JoinType::Anti, JoinType::Semi)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -666,12 +667,18 @@ mod tests { ); let right_scan_plan = dummy_scan_node(right_scan_op.clone()); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equals_null { + Some(vec![true]) + } else { + None + }; let pred = col("a").lt(lit(2)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on.clone(), + null_equals_nulls.clone(), how, None, None, @@ -688,10 +695,11 @@ mod tests { left_scan_plan.filter(pred)? }; let expected = expected_left_filter_scan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -706,6 +714,7 @@ mod tests { #[rstest] fn filter_commutes_with_join_right_side( #[values(false, true)] push_into_right_scan: bool, + #[values(false, true)] null_equals_null: bool, #[values(JoinType::Inner, JoinType::Right)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -722,12 +731,18 @@ mod tests { Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), ); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equals_null { + Some(vec![true]) + } else { + None + }; let pred = col("c").lt(lit(2.0)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on.clone(), + null_equals_nulls.clone(), how, None, None, @@ -744,10 +759,11 @@ mod tests { right_scan_plan.filter(pred)? }; let expected = left_scan_plan - .join( + .join_with_null_safe_equal( &expected_right_filter_scan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -763,6 +779,7 @@ mod tests { fn filter_commutes_with_join_on_join_key( #[values(false, true)] push_into_left_scan: bool, #[values(false, true)] push_into_right_scan: bool, + #[values(false, true)] null_equals_null: bool, #[values( JoinType::Inner, JoinType::Left, @@ -791,12 +808,18 @@ mod tests { Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }), ); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equals_null { + Some(vec![true]) + } else { + None + }; let pred = col("b").lt(lit(2)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on.clone(), + null_equals_nulls.clone(), how, None, None, @@ -821,10 +844,11 @@ mod tests { right_scan_plan.filter(pred)? }; let expected = expected_left_filter_scan - .join( + .join_with_null_safe_equal( &expected_right_filter_scan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -838,6 +862,7 @@ mod tests { /// Tests that Filter can be pushed into the left side of a Join. #[rstest] fn filter_does_not_commute_with_join_left_side( + #[values(false, true)] null_equal_null: bool, #[values(JoinType::Right, JoinType::Outer)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -851,12 +876,18 @@ mod tests { let left_scan_plan = dummy_scan_node(left_scan_op.clone()); let right_scan_plan = dummy_scan_node(right_scan_op.clone()); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equal_null { + Some(vec![true]) + } else { + None + }; let pred = col("a").lt(lit(2)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, @@ -873,6 +904,7 @@ mod tests { /// Tests that Filter can be pushed into the right side of a Join. #[rstest] fn filter_does_not_commute_with_join_right_side( + #[values(false, true)] null_equal_null: bool, #[values(JoinType::Left, JoinType::Outer)] how: JoinType, ) -> DaftResult<()> { let left_scan_op = dummy_scan_operator(vec![ @@ -886,12 +918,18 @@ mod tests { let left_scan_plan = dummy_scan_node(left_scan_op.clone()); let right_scan_plan = dummy_scan_node(right_scan_op.clone()); let join_on = vec![col("b")]; + let null_equals_nulls = if null_equal_null { + Some(vec![true]) + } else { + None + }; let pred = col("c").lt(lit(2.0)); let plan = left_scan_plan - .join( + .join_with_null_safe_equal( &right_scan_plan, join_on.clone(), join_on, + null_equals_nulls, how, None, None, diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 866c7940b4..37b75217c8 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -264,11 +264,12 @@ impl LogicalPlan { [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()), - Self::Join(Join { left_on, right_on, join_type, join_strategy, .. }) => Self::Join(Join::try_new( + Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. }) => Self::Join(Join::try_new( input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), + null_equals_nulls.clone(), *join_type, *join_strategy, None, // The suffix is already eagerly computed in the constructor diff --git a/src/daft-plan/src/physical_ops/broadcast_join.rs b/src/daft-plan/src/physical_ops/broadcast_join.rs index b45ce7c4fa..e8048f305c 100644 --- a/src/daft-plan/src/physical_ops/broadcast_join.rs +++ b/src/daft-plan/src/physical_ops/broadcast_join.rs @@ -13,6 +13,7 @@ pub struct BroadcastJoin { pub receiver: PhysicalPlanRef, pub left_on: Vec, pub right_on: Vec, + pub null_equals_nulls: Option>, pub join_type: JoinType, pub is_swapped: bool, } @@ -23,6 +24,7 @@ impl BroadcastJoin { receiver: PhysicalPlanRef, left_on: Vec, right_on: Vec, + null_equals_nulls: Option>, join_type: JoinType, is_swapped: bool, ) -> Self { @@ -31,6 +33,7 @@ impl BroadcastJoin { receiver, left_on, right_on, + null_equals_nulls, join_type, is_swapped, } @@ -58,6 +61,13 @@ impl BroadcastJoin { )); } } + + if let Some(null_equals_nulls) = &self.null_equals_nulls { + res.push(format!( + "Null equals Nulls = [{}]", + null_equals_nulls.iter().map(|b| b.to_string()).join(", ") + )); + } res.push(format!("Is swapped = {}", self.is_swapped)); res } diff --git a/src/daft-plan/src/physical_ops/hash_join.rs b/src/daft-plan/src/physical_ops/hash_join.rs index 5d1895c1d8..c41ac018fe 100644 --- a/src/daft-plan/src/physical_ops/hash_join.rs +++ b/src/daft-plan/src/physical_ops/hash_join.rs @@ -13,6 +13,7 @@ pub struct HashJoin { pub right: PhysicalPlanRef, pub left_on: Vec, pub right_on: Vec, + pub null_equals_nulls: Option>, pub join_type: JoinType, } @@ -22,6 +23,7 @@ impl HashJoin { right: PhysicalPlanRef, left_on: Vec, right_on: Vec, + null_equals_nulls: Option>, join_type: JoinType, ) -> Self { Self { @@ -29,6 +31,7 @@ impl HashJoin { right, left_on, right_on, + null_equals_nulls, join_type, } } @@ -55,6 +58,12 @@ impl HashJoin { )); } } + if let Some(null_equals_nulls) = &self.null_equals_nulls { + res.push(format!( + "Null equals Nulls = [{}]", + null_equals_nulls.iter().map(|b| b.to_string()).join(", ") + )); + } res } } diff --git a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs index 76910ed834..30a1e38be7 100644 --- a/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs +++ b/src/daft-plan/src/physical_optimization/rules/reorder_partition_keys.rs @@ -272,6 +272,7 @@ mod tests { plan2, vec![col("b"), col("a")], vec![col("x"), col("y")], + None, JoinType::Inner, )) .arced(); @@ -285,6 +286,7 @@ mod tests { add_repartition(base2, 1, vec![col("x"), col("y")]), vec![col("b"), col("a")], vec![col("x"), col("y")], + None, JoinType::Inner, )) .arced(); diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index e3aa9607e7..25327e66e6 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -460,14 +460,15 @@ impl PhysicalPlan { Self::InMemoryScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), Self::TabularScan(..) | Self::EmptyScan(..) => panic!("Source nodes don't have children, with_new_children() should never be called for source ops"), - Self::HashJoin(HashJoin { left_on, right_on, join_type, .. }) => Self::HashJoin(HashJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type)), + Self::HashJoin(HashJoin { left_on, right_on, null_equals_nulls, join_type, .. }) => Self::HashJoin(HashJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), null_equals_nulls.clone(), *join_type)), Self::BroadcastJoin(BroadcastJoin { left_on, right_on, + null_equals_nulls, join_type, is_swapped, .. - }) => Self::BroadcastJoin(BroadcastJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type, *is_swapped)), + }) => Self::BroadcastJoin(BroadcastJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), null_equals_nulls.clone(), *join_type, *is_swapped)), Self::SortMergeJoin(SortMergeJoin { left_on, right_on, join_type, num_partitions, left_is_larger, needs_presort, .. }) => Self::SortMergeJoin(SortMergeJoin::new(input1.clone(), input2.clone(), left_on.clone(), right_on.clone(), *join_type, *num_partitions, *left_is_larger, *needs_presort)), Self::Concat(_) => Self::Concat(Concat::new(input1.clone(), input2.clone())), _ => panic!("Physical op {:?} has one input, but got two", self), diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 2567cd3fca..2bfc4a2aed 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -400,6 +400,7 @@ pub(super) fn translate_single_logical_node( right, left_on, right_on, + null_equals_nulls, join_type, join_strategy, .. @@ -474,6 +475,9 @@ pub(super) fn translate_single_logical_node( } else { is_right_hash_partitioned || is_right_sort_partitioned }; + let has_null_safe_equals = null_equals_nulls + .as_ref() + .map_or(false, |v| v.iter().any(|b| *b)); let join_strategy = join_strategy.unwrap_or_else(|| { fn keys_are_primitive(on: &[ExprRef], schema: &SchemaRef) -> bool { on.iter().all(|expr| { @@ -506,6 +510,7 @@ pub(super) fn translate_single_logical_node( // TODO(Clark): Also do a sort-merge join if a downstream op needs the table to be sorted on the join key. // TODO(Clark): Look into defaulting to sort-merge join over hash join under more input partitioning setups. // TODO(Kevin): Support sort-merge join for other types of joins. + // TODO(advancedxy): Rewrite null safe equals to support SMJ } else if *join_type == JoinType::Inner && keys_are_primitive(left_on, &left.schema()) && keys_are_primitive(right_on, &right.schema()) @@ -513,6 +518,7 @@ pub(super) fn translate_single_logical_node( && (!is_larger_partitioned || (left_is_larger && is_left_sort_partitioned || !left_is_larger && is_right_sort_partitioned)) + && !has_null_safe_equals { JoinStrategy::SortMerge // Otherwise, use a hash join. @@ -544,6 +550,7 @@ pub(super) fn translate_single_logical_node( right_physical, left_on.clone(), right_on.clone(), + null_equals_nulls.clone(), *join_type, is_swapped, )) @@ -555,6 +562,11 @@ pub(super) fn translate_single_logical_node( "Sort-merge join currently only supports inner joins".to_string(), )); } + if has_null_safe_equals { + return Err(common_error::DaftError::ValueError( + "Sort-merge join does not support null-safe equals yet".to_string(), + )); + } let needs_presort = if cfg.sort_merge_join_sort_with_aligned_boundaries { // Use the special-purpose presorting that ensures join inputs are sorted with aligned @@ -645,6 +657,7 @@ pub(super) fn translate_single_logical_node( right_physical, left_on.clone(), right_on.clone(), + null_equals_nulls.clone(), *join_type, )) .arced()) diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 9d0894d1d8..65f1cbc67b 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -626,6 +626,7 @@ fn physical_plan_to_partition_tasks( right, left_on, right_on, + null_equals_nulls, join_type, .. }) => { @@ -649,6 +650,7 @@ fn physical_plan_to_partition_tasks( upstream_right_iter, left_on_pyexprs, right_on_pyexprs, + null_equals_nulls.clone(), *join_type, ))?; Ok(py_iter.into()) @@ -706,6 +708,7 @@ fn physical_plan_to_partition_tasks( receiver: right, left_on, right_on, + null_equals_nulls, join_type, is_swapped, }) => { @@ -729,6 +732,7 @@ fn physical_plan_to_partition_tasks( upstream_right_iter, left_on_pyexprs, right_on_pyexprs, + null_equals_nulls.clone(), *join_type, *is_swapped, ))?; diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index b27a0060ce..76e30d5912 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -419,6 +419,7 @@ impl SQLPlanner { left_rel: &Relation, right_rel: &Relation, ) -> SQLPlannerResult<(Vec, Vec)> { + // TODO: support null safe equal, a.k.a. <=>. if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression { match *op { BinaryOperator::Eq => { diff --git a/src/daft-table/src/ops/groups.rs b/src/daft-table/src/ops/groups.rs index 1edccccdc7..580d2c4288 100644 --- a/src/daft-table/src/ops/groups.rs +++ b/src/daft-table/src/ops/groups.rs @@ -65,8 +65,8 @@ impl Table { let comparator = build_multi_array_is_equal( self.columns.as_slice(), self.columns.as_slice(), - true, - true, + vec![true; self.columns.len()].as_slice(), + vec![true; self.columns.len()].as_slice(), )?; // To group the argsort values together, we will traverse the table in argsort order, diff --git a/src/daft-table/src/ops/hash.rs b/src/daft-table/src/ops/hash.rs index 4b4d120da0..5421988518 100644 --- a/src/daft-table/src/ops/hash.rs +++ b/src/daft-table/src/ops/hash.rs @@ -32,8 +32,8 @@ impl Table { let comparator = build_multi_array_is_equal( self.columns.as_slice(), self.columns.as_slice(), - true, - true, + vec![true; self.columns.len()].as_slice(), + vec![true; self.columns.len()].as_slice(), )?; let mut probe_table = @@ -77,8 +77,8 @@ impl Table { let comparator = build_multi_array_is_equal( self.columns.as_slice(), self.columns.as_slice(), - true, - true, + vec![true; self.columns.len()].as_slice(), + vec![true; self.columns.len()].as_slice(), )?; let mut probe_table = diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs index 1da3ea5f16..7f74666443 100644 --- a/src/daft-table/src/ops/joins/hash_join.rs +++ b/src/daft-table/src/ops/joins/hash_join.rs @@ -18,6 +18,7 @@ pub(super) fn hash_inner_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], ) -> DaftResult { let join_schema = infer_join_schema( &left.schema, @@ -55,8 +56,8 @@ pub(super) fn hash_inner_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; let mut left_idx = vec![]; @@ -107,6 +108,7 @@ pub(super) fn hash_left_right_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], left_side: bool, ) -> DaftResult
{ let join_schema = infer_join_schema( @@ -147,8 +149,8 @@ pub(super) fn hash_left_right_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; // we will have at least as many rows in the join table as the right table @@ -222,6 +224,7 @@ pub(super) fn hash_semi_anti_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], is_anti: bool, ) -> DaftResult
{ let lkeys = left.eval_expression_list(left_on)?; @@ -246,8 +249,8 @@ pub(super) fn hash_semi_anti_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; let rows = rkeys.len(); @@ -282,6 +285,7 @@ pub(super) fn hash_outer_join( right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], ) -> DaftResult
{ let join_schema = infer_join_schema( &left.schema, @@ -333,8 +337,8 @@ pub(super) fn hash_outer_join( let is_equal = build_multi_array_is_equal( lkeys.columns.as_slice(), rkeys.columns.as_slice(), - false, - false, + null_equals_nulls, + vec![false; lkeys.columns.len()].as_slice(), )?; // we will have at least as many rows in the join table as the max of the left and right tables diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 0c6b678d35..97c034a9df 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -75,6 +75,7 @@ impl Table { right: &Self, left_on: &[ExprRef], right_on: &[ExprRef], + null_equals_nulls: &[bool], how: JoinType, ) -> DaftResult { if left_on.len() != right_on.len() { @@ -92,12 +93,20 @@ impl Table { } match how { - JoinType::Inner => hash_inner_join(self, right, left_on, right_on), - JoinType::Left => hash_left_right_join(self, right, left_on, right_on, true), - JoinType::Right => hash_left_right_join(self, right, left_on, right_on, false), - JoinType::Outer => hash_outer_join(self, right, left_on, right_on), - JoinType::Semi => hash_semi_anti_join(self, right, left_on, right_on, false), - JoinType::Anti => hash_semi_anti_join(self, right, left_on, right_on, true), + JoinType::Inner => hash_inner_join(self, right, left_on, right_on, null_equals_nulls), + JoinType::Left => { + hash_left_right_join(self, right, left_on, right_on, null_equals_nulls, true) + } + JoinType::Right => { + hash_left_right_join(self, right, left_on, right_on, null_equals_nulls, false) + } + JoinType::Outer => hash_outer_join(self, right, left_on, right_on, null_equals_nulls), + JoinType::Semi => { + hash_semi_anti_join(self, right, left_on, right_on, null_equals_nulls, false) + } + JoinType::Anti => { + hash_semi_anti_join(self, right, left_on, right_on, null_equals_nulls, true) + } } } diff --git a/src/daft-table/src/probeable/mod.rs b/src/daft-table/src/probeable/mod.rs index 3346bd9869..c68cb60884 100644 --- a/src/daft-table/src/probeable/mod.rs +++ b/src/daft-table/src/probeable/mod.rs @@ -14,12 +14,19 @@ struct ArrowTableEntry(Vec>); pub fn make_probeable_builder( schema: SchemaRef, + nulls_equal_aware: Option<&Vec>, track_indices: bool, ) -> DaftResult> { if track_indices { - Ok(Box::new(ProbeTableBuilder(ProbeTable::new(schema)?))) + Ok(Box::new(ProbeTableBuilder(ProbeTable::new( + schema, + nulls_equal_aware, + )?))) } else { - Ok(Box::new(ProbeSetBuilder(ProbeSet::new(schema)?))) + Ok(Box::new(ProbeSetBuilder(ProbeSet::new( + schema, + nulls_equal_aware, + )?))) } } diff --git a/src/daft-table/src/probeable/probe_set.rs b/src/daft-table/src/probeable/probe_set.rs index adf9251756..40ab4ab86e 100644 --- a/src/daft-table/src/probeable/probe_set.rs +++ b/src/daft-table/src/probeable/probe_set.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_core::{ array::ops::as_arrow::AsArrow, prelude::SchemaRef, @@ -31,12 +31,26 @@ impl ProbeSet { const DEFAULT_SIZE: usize = 20; - pub(crate) fn new(schema: SchemaRef) -> DaftResult { + pub(crate) fn new( + schema: SchemaRef, + nulls_equal_aware: Option<&Vec>, + ) -> DaftResult { let hash_table = HashMap::::with_capacity_and_hasher( Self::DEFAULT_SIZE, Default::default(), ); - let compare_fn = build_dyn_multi_array_compare(&schema, false, false)?; + if let Some(null_equal_aware) = nulls_equal_aware { + if null_equal_aware.len() != schema.len() { + return Err(DaftError::InternalError( + format!("null_equal_aware should have the same length as the schema. Expected: {}, Found: {}", + schema.len(), null_equal_aware.len()))); + } + } + let default_nulls_equal = vec![false; schema.len()]; + let nulls_equal = nulls_equal_aware.unwrap_or_else(|| default_nulls_equal.as_ref()); + let nans_equal = &vec![false; schema.len()]; + let compare_fn = + build_dyn_multi_array_compare(&schema, nulls_equal.as_slice(), nans_equal.as_slice())?; Ok(Self { schema, hash_table, diff --git a/src/daft-table/src/probeable/probe_table.rs b/src/daft-table/src/probeable/probe_table.rs index 8e8c8c8647..ba4c11dc41 100644 --- a/src/daft-table/src/probeable/probe_table.rs +++ b/src/daft-table/src/probeable/probe_table.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_core::{ array::ops::as_arrow::AsArrow, prelude::SchemaRef, @@ -32,13 +32,24 @@ impl ProbeTable { const DEFAULT_SIZE: usize = 20; - pub(crate) fn new(schema: SchemaRef) -> DaftResult { + pub(crate) fn new(schema: SchemaRef, null_equal_aware: Option<&Vec>) -> DaftResult { let hash_table = HashMap::, IdentityBuildHasher>::with_capacity_and_hasher( Self::DEFAULT_SIZE, Default::default(), ); - let compare_fn = build_dyn_multi_array_compare(&schema, false, false)?; + if let Some(null_equal_aware) = null_equal_aware { + if null_equal_aware.len() != schema.len() { + return Err(DaftError::InternalError( + format!("null_equal_aware should have the same length as the schema. Expected: {}, Found: {}", + schema.len(), null_equal_aware.len()))); + } + } + let default_nulls_equal = vec![false; schema.len()]; + let nulls_equal = null_equal_aware.unwrap_or_else(|| default_nulls_equal.as_ref()); + let nans_equal = &vec![false; schema.len()]; + let compare_fn = + build_dyn_multi_array_compare(&schema, nulls_equal.as_slice(), nans_equal.as_slice())?; Ok(Self { schema, hash_table, diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs index 89f1a12016..4bef3f9b05 100644 --- a/src/daft-table/src/python.rs +++ b/src/daft-table/src/python.rs @@ -134,6 +134,7 @@ impl PyTable { left_on.into_iter().map(std::convert::Into::into).collect(); let right_exprs: Vec = right_on.into_iter().map(std::convert::Into::into).collect(); + let null_equals_nulls = vec![false; left_exprs.len()]; py.allow_threads(|| { Ok(self .table @@ -141,6 +142,7 @@ impl PyTable { &right.table, left_exprs.as_slice(), right_exprs.as_slice(), + null_equals_nulls.as_slice(), how, )? .into()) diff --git a/tests/table/test_joins.py b/tests/table/test_joins.py index 968ab97630..6c58bbce6e 100644 --- a/tests/table/test_joins.py +++ b/tests/table/test_joins.py @@ -26,29 +26,83 @@ daft_string_types = [DataType.string()] +def skip_null_safe_equal_for_smj(func): + from functools import wraps + from inspect import getfullargspec + + @wraps(func) + def wrapper(*args, **kwargs): + spec = getfullargspec(func) + join_impl, null_safe_equal = None, None + if "join_impl" in kwargs: + join_impl = kwargs["join_impl"] + elif "join_impl" in spec.args: + idx = spec.args.index("join_impl") + join_impl = args[idx] if idx >= 0 else None + if "null_safe_equal" in kwargs: + null_safe_equal = kwargs["null_safe_equal"] + elif "null_safe_equal" in spec.args: + idx = spec.args.index("null_safe_equal") + null_safe_equal = args[idx] if idx >= 0 else None + if join_impl == "sort_merge_join" and null_safe_equal: + pytest.skip("sort merge join does not support null safe equal yet") + return func(*args, **kwargs) + + return wrapper + + +@skip_null_safe_equal_for_smj @pytest.mark.parametrize( - "dtype, data", + "dtype, data, null_safe_equal", itertools.product( daft_numeric_types + daft_string_types, [ - ([0, 1, 2, 3, None], [0, 1, 2, 3, None], [(0, 0), (1, 1), (2, 2), (3, 3)]), - ([None, None, 3, 1, 2, 0], [0, 1, 2, 3, None], [(5, 0), (3, 1), (4, 2), (2, 3)]), - ([None, 4, 5, 6, 7], [0, 1, 2, 3, None], []), - ([None, 0, 0, 0, 1, None], [0, 1, 2, 3, None], [(1, 0), (2, 0), (3, 0), (4, 1)]), - ([None, 0, 0, 1, 1, None], [0, 1, 2, 3, None], [(1, 0), (2, 0), (3, 1), (4, 1)]), - ([None, 0, 0, 1, 1, None], [3, 1, 0, 2, None], [(1, 2), (2, 2), (3, 1), (4, 1)]), + ( + [0, 1, 2, 3, None], + [0, 1, 2, 3, None], + [(0, 0), (1, 1), (2, 2), (3, 3)], + [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)], + ), + ( + [None, None, 3, 1, 2, 0], + [0, 1, 2, 3, None], + [(5, 0), (3, 1), (4, 2), (2, 3)], + [(5, 0), (3, 1), (4, 2), (2, 3), (0, 4), (1, 4)], + ), + ([None, 4, 5, 6, 7], [0, 1, 2, 3, None], [], [(0, 4)]), + ( + [None, 0, 0, 0, 1, None], + [0, 1, 2, 3, None], + [(1, 0), (2, 0), (3, 0), (4, 1)], + [(1, 0), (2, 0), (3, 0), (4, 1), (5, 4), (0, 4)], + ), + ( + [None, 0, 0, 1, 1, None], + [0, 1, 2, 3, None], + [(1, 0), (2, 0), (3, 1), (4, 1)], + [(1, 0), (2, 0), (3, 1), (4, 1), (5, 4), (0, 4)], + ), + ( + [None, 0, 0, 1, 1, None], + [3, 1, 0, 2, None], + [(1, 2), (2, 2), (3, 1), (4, 1)], + [(1, 2), (2, 2), (3, 1), (4, 1), (5, 4), (0, 4)], + ), ], + [True, False], ), ) @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_single_column(join_impl, dtype, data) -> None: - left, right, expected_pairs = data +def test_table_join_single_column(join_impl, dtype, data, null_safe_equal) -> None: + left, right, expected_pairs, null_safe_expected = data + expected_pairs = null_safe_expected if null_safe_equal else expected_pairs + null_equals_nulls = {"null_equals_nulls": [null_safe_equal]} if null_safe_equal else {} left_table = MicroPartition.from_pydict({"x": left, "x_ind": list(range(len(left)))}).eval_expression_list( [col("x").cast(dtype), col("x_ind")] ) right_table = MicroPartition.from_pydict({"y": right, "y_ind": list(range(len(right)))}) result_table = getattr(left_table, join_impl)( - right_table, left_on=[col("x")], right_on=[col("y")], how=JoinType.Inner + right_table, left_on=[col("x")], right_on=[col("y")], how=JoinType.Inner, **null_equals_nulls ) assert result_table.column_names() == ["x", "x_ind", "y", "y_ind"] @@ -65,7 +119,7 @@ def test_table_join_single_column(join_impl, dtype, data) -> None: # make sure the result is the same with right table on left result_table = getattr(right_table, join_impl)( - left_table, right_on=[col("x")], left_on=[col("y")], how=JoinType.Inner + left_table, right_on=[col("x")], left_on=[col("y")], how=JoinType.Inner, **null_equals_nulls ) assert result_table.column_names() == ["y", "y_ind", "x", "x_ind"] @@ -90,6 +144,7 @@ def test_table_join_mismatch_column(join_impl) -> None: getattr(left_table, join_impl)(right_table, left_on=[col("x"), col("y")], right_on=[col("a")]) +@skip_null_safe_equal_for_smj @pytest.mark.parametrize( "left", [ @@ -105,7 +160,8 @@ def test_table_join_mismatch_column(join_impl) -> None: ], ) @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_empty_result(join_impl, left, right) -> None: +@pytest.mark.parametrize("null_safe_equal", [True, False]) +def test_table_join_multicolumn_empty_result(join_impl, left, right, null_safe_equal) -> None: """Various multicol joins that should all produce an empty result.""" left_table = MicroPartition.from_pydict(left).eval_expression_list( [col("a").cast(DataType.string()), col("b").cast(DataType.int32())] @@ -114,12 +170,18 @@ def test_table_join_multicolumn_empty_result(join_impl, left, right) -> None: [col("x").cast(DataType.string()), col("y").cast(DataType.int32())] ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls + ) assert result.to_pydict() == {"a": [], "b": [], "x": [], "y": []} -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_nocross(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_multicolumn_nocross(join_impl, null_safe_equal) -> None: """A multicol join that should produce two rows and no cross product results. Input has duplicate join values and overlapping single-column values, @@ -127,85 +189,104 @@ def test_table_join_multicolumn_nocross(join_impl) -> None: """ left_table = MicroPartition.from_pydict( { - "a": ["apple", "apple", "banana", "banana", "carrot"], - "b": [1, 2, 2, 2, 3], - "c": [1, 2, 3, 4, 5], + "a": ["apple", "apple", "banana", "banana", "carrot", None], + "b": [1, 2, 2, 2, 3, 3], + "c": [1, 2, 3, 4, 5, 5], } ) right_table = MicroPartition.from_pydict( { - "x": ["banana", "carrot", "apple", "banana", "apple", "durian"], - "y": [1, 3, 2, 1, 3, 6], - "z": [1, 2, 3, 4, 5, 6], + "x": ["banana", "carrot", "apple", "banana", "apple", "durian", None], + "y": [1, 3, 2, 1, 3, 6, 3], + "z": [1, 2, 3, 4, 5, 6, 6], } ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) - assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set( - utils.freeze( - [ - {"a": "apple", "b": 2, "c": 2, "x": "apple", "y": 2, "z": 3}, - {"a": "carrot", "b": 3, "c": 5, "x": "carrot", "y": 3, "z": 2}, - ] - ) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls ) + expected = [ + {"a": "apple", "b": 2, "c": 2, "x": "apple", "y": 2, "z": 3}, + {"a": "carrot", "b": 3, "c": 5, "x": "carrot", "y": 3, "z": 2}, + ] + if null_safe_equal: + expected.append({"a": None, "b": 3, "c": 5, "x": None, "y": 3, "z": 6}) + assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze(expected)) -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_cross(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_multicolumn_cross(join_impl, null_safe_equal) -> None: """A multicol join that should produce a cross product and a non-cross product.""" left_table = MicroPartition.from_pydict( { - "a": ["apple", "apple", "banana", "banana", "banana"], - "b": [1, 0, 1, 1, 1], - "c": [1, 2, 3, 4, 5], + "a": ["apple", "apple", "banana", "banana", "banana", None], + "b": [1, 0, 1, 1, 1, 1], + "c": [1, 2, 3, 4, 5, 5], } ) right_table = MicroPartition.from_pydict( { - "x": ["apple", "apple", "banana", "banana", "banana"], - "y": [1, 0, 1, 1, 0], - "z": [1, 2, 3, 4, 5], + "x": ["apple", "apple", "banana", "banana", "banana", None], + "y": [1, 0, 1, 1, 0, 1], + "z": [1, 2, 3, 4, 5, 5], } ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) - assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set( - utils.freeze( - [ - {"a": "apple", "b": 1, "c": 1, "x": "apple", "y": 1, "z": 1}, - {"a": "apple", "b": 0, "c": 2, "x": "apple", "y": 0, "z": 2}, - {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 3}, - {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 4}, - {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 3}, - {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 4}, - {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 3}, - {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 4}, - ] - ) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls ) + expected = [ + {"a": "apple", "b": 1, "c": 1, "x": "apple", "y": 1, "z": 1}, + {"a": "apple", "b": 0, "c": 2, "x": "apple", "y": 0, "z": 2}, + {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 3}, + {"a": "banana", "b": 1, "c": 3, "x": "banana", "y": 1, "z": 4}, + {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 3}, + {"a": "banana", "b": 1, "c": 4, "x": "banana", "y": 1, "z": 4}, + {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 3}, + {"a": "banana", "b": 1, "c": 5, "x": "banana", "y": 1, "z": 4}, + ] + if null_safe_equal: + expected.append({"a": None, "b": 1, "c": 5, "x": None, "y": 1, "z": 5}) + assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze(expected)) -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_multicolumn_all_nulls(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_multicolumn_all_nulls(join_impl, null_safe_equal) -> None: left_table = MicroPartition.from_pydict( { - "a": Series.from_pylist([None, None, None]).cast(DataType.int64()), - "b": Series.from_pylist([None, None, None]).cast(DataType.string()), - "c": [1, 2, 3], + "a": Series.from_pylist([None, None]).cast(DataType.int64()), + "b": Series.from_pylist([None, None]).cast(DataType.string()), + "c": [1, 2], } ) right_table = MicroPartition.from_pydict( { - "x": Series.from_pylist([None, None, None]).cast(DataType.int64()), - "y": Series.from_pylist([None, None, None]).cast(DataType.string()), - "z": [1, 2, 3], + "x": Series.from_pylist([None, None]).cast(DataType.int64()), + "y": Series.from_pylist([None, None]).cast(DataType.string()), + "z": [1, 2], } ) - result = getattr(left_table, join_impl)(right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")]) - assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze([])) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal] * 2} if null_safe_equal else {} + result = getattr(left_table, join_impl)( + right_table, left_on=[col("a"), col("b")], right_on=[col("x"), col("y")], **null_equals_nulls + ) + expected = [] + if null_safe_equal: + expected = [ + {"a": None, "b": None, "c": 1, "x": None, "y": None, "z": 1}, + {"a": None, "b": None, "c": 1, "x": None, "y": None, "z": 2}, + {"a": None, "b": None, "c": 2, "x": None, "y": None, "z": 1}, + {"a": None, "b": None, "c": 2, "x": None, "y": None, "z": 2}, + ] + assert set(utils.freeze(utils.pydict_to_rows(result.to_pydict()))) == set(utils.freeze(expected)) @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) @@ -217,16 +298,25 @@ def test_table_join_no_columns(join_impl) -> None: getattr(left_table, join_impl)(right_table, left_on=[], right_on=[]) -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_single_column_name_boolean(join_impl) -> None: +@pytest.mark.parametrize( + "join_impl,null_safe_equal", [("hash_join", True), ("hash_join", False), ("sort_merge_join", False)] +) +def test_table_join_single_column_name_boolean(join_impl, null_safe_equal) -> None: left_table = MicroPartition.from_pydict({"x": [False, True, None], "y": [0, 1, 2]}) right_table = MicroPartition.from_pydict({"x": [None, True, False, None], "right.y": [0, 1, 2, 3]}) - result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) + null_equals_nulls = {"null_equals_nulls": [null_safe_equal]} if null_safe_equal else {} + result_table = getattr(left_table, join_impl)( + right_table, left_on=[col("x")], right_on=[col("x")], **null_equals_nulls + ) assert result_table.column_names() == ["x", "y", "right.y"] result_sorted = result_table.sort([col("x")]) - assert result_sorted.get_column("y").to_pylist() == [0, 1] - assert result_sorted.get_column("right.y").to_pylist() == [2, 1] + if null_safe_equal: + assert result_sorted.get_column("y").to_pylist() == [0, 1, 2, 2] + assert result_sorted.get_column("right.y").to_pylist() == [2, 1, 0, 3] + else: + assert result_sorted.get_column("y").to_pylist() == [0, 1] + assert result_sorted.get_column("right.y").to_pylist() == [2, 1] @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"])