Skip to content

Commit

Permalink
[FEAT] Support null safe equal in joins (#3161)
Browse files Browse the repository at this point in the history
This commit consists of the following parts:
1. Make the logical plan join's `on` condition aware of null-safe equality
2. Support translating null safe equal joins into physical plans
3. Ensure the optimization rules (eliminate-cross-join and push-down-filter)
are not broken by the new condition
4. Modifications to physical join ops to support null safe equal: hash
and broadcast joins are supported. SMJ is not supported yet.
5. Add glue code on the Python side to make the whole pipeline work
6. Add unit tests on the Python side

Fixes: #3069 

TODOs(in follow-up PRs):
- [ ] Rewrite null-safe equals for SMJ so that sort-merge join can support
them as well
- [ ] Support null-safe equal in SQL, a.k.a. the spaceship operator (a <=> b)
- [ ] Support null-safe equal joins in Python's DataFrame API
  • Loading branch information
advancedxy authored Nov 5, 2024
1 parent c1d82c5 commit 594b5e2
Show file tree
Hide file tree
Showing 35 changed files with 440 additions and 122 deletions.
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,7 @@ class PyMicroPartition:
right: PyMicroPartition,
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
how: JoinType,
) -> PyMicroPartition: ...
def pivot(
Expand Down
2 changes: 2 additions & 0 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
class HashJoin(SingleOutputInstruction):
left_on: ExpressionsProjection
right_on: ExpressionsProjection
null_equals_nulls: list[bool] | None
how: JoinType
is_swapped: bool

Expand All @@ -810,6 +811,7 @@ def _hash_join(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
right,
left_on=self.left_on,
right_on=self.right_on,
null_equals_nulls=self.null_equals_nulls,
how=self.how,
)
return [result]
Expand Down
15 changes: 14 additions & 1 deletion daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def hash_join(
right_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: None | list[bool],
how: JoinType,
) -> InProgressPhysicalPlan[PartitionT]:
"""Hash-based pairwise join the partitions from `left_child_plan` and `right_child_plan` together."""
Expand Down Expand Up @@ -387,6 +388,7 @@ def hash_join(
instruction=execution_step.HashJoin(
left_on=left_on,
right_on=right_on,
null_equals_nulls=null_equals_nulls,
how=how,
is_swapped=False,
)
Expand Down Expand Up @@ -432,6 +434,7 @@ def _create_broadcast_join_step(
receiver_part: SingleOutputPartitionTask[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: None | list[bool],
how: JoinType,
is_swapped: bool,
) -> PartitionTaskBuilder[PartitionT]:
Expand Down Expand Up @@ -477,6 +480,7 @@ def _create_broadcast_join_step(
instruction=execution_step.BroadcastJoin(
left_on=left_on,
right_on=right_on,
null_equals_nulls=null_equals_nulls,
how=how,
is_swapped=is_swapped,
)
Expand All @@ -488,6 +492,7 @@ def broadcast_join(
receiver_plan: InProgressPhysicalPlan[PartitionT],
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: None | list[bool],
how: JoinType,
is_swapped: bool,
) -> InProgressPhysicalPlan[PartitionT]:
Expand Down Expand Up @@ -530,7 +535,15 @@ def broadcast_join(
# Broadcast all broadcaster partitions to each new receiver partition that was materialized on this dispatch loop.
while receiver_requests and receiver_requests[0].done():
receiver_part = receiver_requests.popleft()
yield _create_broadcast_join_step(broadcaster_parts, receiver_part, left_on, right_on, how, is_swapped)
yield _create_broadcast_join_step(
broadcaster_parts,
receiver_part,
left_on,
right_on,
null_equals_nulls,
how,
is_swapped,
)

# Execute single child step to pull in more input partitions.
try:
Expand Down
4 changes: 4 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def hash_join(
right: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
join_type: JoinType,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on])
Expand All @@ -253,6 +254,7 @@ def hash_join(
left_on=left_on_expr_proj,
right_on=right_on_expr_proj,
how=join_type,
null_equals_nulls=null_equals_nulls,
)


Expand Down Expand Up @@ -303,6 +305,7 @@ def broadcast_join(
receiver: physical_plan.InProgressPhysicalPlan[PartitionT],
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
join_type: JoinType,
is_swapped: bool,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
Expand All @@ -315,6 +318,7 @@ def broadcast_join(
right_on=right_on_expr_proj,
how=join_type,
is_swapped=is_swapped,
null_equals_nulls=null_equals_nulls,
)


Expand Down
9 changes: 8 additions & 1 deletion daft/table/micropartition.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def hash_join(
right: MicroPartition,
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
null_equals_nulls: list[bool] | None = None,
how: JoinType = JoinType.Inner,
) -> MicroPartition:
if len(left_on) != len(right_on):
Expand All @@ -262,7 +263,13 @@ def hash_join(
right_exprs = [e._expr for e in right_on]

return MicroPartition._from_pymicropartition(
self._micropartition.hash_join(right._micropartition, left_on=left_exprs, right_on=right_exprs, how=how)
self._micropartition.hash_join(
right._micropartition,
left_on=left_exprs,
right_on=right_exprs,
null_equals_nulls=null_equals_nulls,
how=how,
)
)

def sort_merge_join(
Expand Down
10 changes: 5 additions & 5 deletions src/daft-core/src/array/ops/arrow2/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ pub fn build_is_equal(
pub fn build_multi_array_is_equal(
left: &[Series],
right: &[Series],
nulls_equal: bool,
nan_equal: bool,
nulls_equal: &[bool],
nans_equal: &[bool],
) -> DaftResult<Box<dyn Fn(usize, usize) -> bool + Send + Sync>> {
let mut fn_list = Vec::with_capacity(left.len());

for (l, r) in left.iter().zip(right.iter()) {
for (idx, (l, r)) in left.iter().zip(right.iter()).enumerate() {
fn_list.push(build_is_equal(
l.to_arrow().as_ref(),
r.to_arrow().as_ref(),
nulls_equal,
nan_equal,
nulls_equal[idx],
nans_equal[idx],
)?);
}

Expand Down
10 changes: 5 additions & 5 deletions src/daft-core/src/utils/dyn_compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ pub fn build_dyn_compare(

pub fn build_dyn_multi_array_compare(
schema: &Schema,
nulls_equal: bool,
nans_equal: bool,
nulls_equal: &[bool],
nans_equal: &[bool],
) -> DaftResult<MultiDynArrayComparator> {
let mut fn_list = Vec::with_capacity(schema.len());
for field in schema.fields.values() {
for (idx, field) in schema.fields.values().enumerate() {
fn_list.push(build_dyn_compare(
&field.dtype,
&field.dtype,
nulls_equal,
nans_equal,
nulls_equal[idx],
nans_equal[idx],
)?);
}
let combined_fn = Box::new(
Expand Down
9 changes: 7 additions & 2 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ pub fn physical_plan_to_pipeline(
right,
left_on,
right_on,
null_equals_null,
join_type,
schema,
}) => {
Expand Down Expand Up @@ -371,9 +372,13 @@ pub fn physical_plan_to_pipeline(
.zip(key_schema.fields.values())
.map(|(e, f)| e.clone().cast(&f.dtype))
.collect::<Vec<_>>();

// we should move to a builder pattern
let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?;
let build_sink = HashJoinBuildSink::new(
key_schema,
casted_build_on,
null_equals_null.clone(),
join_type,
)?;
let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?;
let build_node =
BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed();
Expand Down
11 changes: 10 additions & 1 deletion src/daft-local-execution/src/sinks/hash_join_build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@ impl ProbeTableState {
fn new(
key_schema: &SchemaRef,
projection: Vec<ExprRef>,
nulls_equal_aware: Option<&Vec<bool>>,
join_type: &JoinType,
) -> DaftResult<Self> {
let track_indices = !matches!(join_type, JoinType::Anti | JoinType::Semi);
Ok(Self::Building {
probe_table_builder: Some(make_probeable_builder(key_schema.clone(), track_indices)?),
probe_table_builder: Some(make_probeable_builder(
key_schema.clone(),
nulls_equal_aware,
track_indices,
)?),
projection,
tables: Vec::new(),
})
Expand Down Expand Up @@ -83,18 +88,21 @@ impl BlockingSinkState for ProbeTableState {
// Blocking sink that consumes the build side of a hash join and constructs
// the probe table (see `ProbeTableState::new`, which receives these fields).
pub struct HashJoinBuildSink {
    // Schema of the join-key columns used to build the probe table.
    key_schema: SchemaRef,
    // Expressions projecting the build-side input onto the join keys.
    projection: Vec<ExprRef>,
    // Optional per-key flags: entry i tells whether NULL == NULL for key i.
    // None requests the default behavior — presumably nulls compare unequal,
    // matching the `unwrap_or_else(|| vec![false; ..])` convention used in
    // this change; TODO confirm against make_probeable_builder.
    nulls_equal_aware: Option<Vec<bool>>,
    // Join type; Anti/Semi joins skip index tracking in the probe table.
    join_type: JoinType,
}

impl HashJoinBuildSink {
/// Creates a `HashJoinBuildSink` that will build the hash-join probe table.
///
/// `nulls_equal_aware`, when provided, carries one flag per join key
/// controlling whether NULLs on that key compare equal; `None` keeps the
/// default comparison behavior.
pub(crate) fn new(
    key_schema: SchemaRef,
    projection: Vec<ExprRef>,
    nulls_equal_aware: Option<Vec<bool>>,
    join_type: &JoinType,
) -> DaftResult<Self> {
    // `join_type` arrives by reference; copy it out so the sink owns it.
    let join_type = *join_type;
    let sink = Self {
        key_schema,
        projection,
        nulls_equal_aware,
        join_type,
    };
    Ok(sink)
}
Expand Down Expand Up @@ -144,6 +152,7 @@ impl BlockingSink for HashJoinBuildSink {
Ok(Box::new(ProbeTableState::new(
&self.key_schema,
self.projection.clone(),
self.nulls_equal_aware.as_ref(),
&self.join_type,
)?))
}
Expand Down
8 changes: 7 additions & 1 deletion src/daft-micropartition/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,17 @@ impl MicroPartition {
right: &Self,
left_on: &[ExprRef],
right_on: &[ExprRef],
null_equals_nulls: Option<Vec<bool>>,
how: JoinType,
) -> DaftResult<Self> {
let io_stats = IOStatsContext::new("MicroPartition::hash_join");
let null_equals_nulls = null_equals_nulls.unwrap_or_else(|| vec![false; left_on.len()]);
let table_join =
|lt: &Table, rt: &Table, lo: &[ExprRef], ro: &[ExprRef], _how: JoinType| {
Table::hash_join(lt, rt, lo, ro, null_equals_nulls.as_slice(), _how)
};

self.join(right, io_stats, left_on, right_on, how, Table::hash_join)
self.join(right, io_stats, left_on, right_on, how, table_join)
}

pub fn sort_merge_join(
Expand Down
2 changes: 2 additions & 0 deletions src/daft-micropartition/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ impl PyMicroPartition {
left_on: Vec<PyExpr>,
right_on: Vec<PyExpr>,
how: JoinType,
null_equals_nulls: Option<Vec<bool>>,
) -> PyResult<Self> {
let left_exprs: Vec<daft_dsl::ExprRef> =
left_on.into_iter().map(std::convert::Into::into).collect();
Expand All @@ -272,6 +273,7 @@ impl PyMicroPartition {
&right.inner,
left_exprs.as_slice(),
right_exprs.as_slice(),
null_equals_nulls,
how,
)?
.into())
Expand Down
3 changes: 3 additions & 0 deletions src/daft-physical-plan/src/local_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ impl LocalPhysicalPlan {
right: LocalPhysicalPlanRef,
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
null_equals_null: Option<Vec<bool>>,
join_type: JoinType,
schema: SchemaRef,
) -> LocalPhysicalPlanRef {
Expand All @@ -269,6 +270,7 @@ impl LocalPhysicalPlan {
right,
left_on,
right_on,
null_equals_null,
join_type,
schema,
})
Expand Down Expand Up @@ -452,6 +454,7 @@ pub struct HashJoin {
pub right: LocalPhysicalPlanRef,
pub left_on: Vec<ExprRef>,
pub right_on: Vec<ExprRef>,
pub null_equals_null: Option<Vec<bool>>,
pub join_type: JoinType,
pub schema: SchemaRef,
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-physical-plan/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult<LocalPhysicalPlanRef> {
right,
join.left_on.clone(),
join.right_on.clone(),
join.null_equals_nulls.clone(),
join.join_type,
join.output_schema.clone(),
))
Expand Down
25 changes: 25 additions & 0 deletions src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,37 @@ impl LogicalPlanBuilder {
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
) -> DaftResult<Self> {
self.join_with_null_safe_equal(
right,
left_on,
right_on,
None,
join_type,
join_strategy,
join_suffix,
join_prefix,
)
}

#[allow(clippy::too_many_arguments)]
pub fn join_with_null_safe_equal<Right: Into<LogicalPlanRef>>(
&self,
right: Right,
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
null_equals_nulls: Option<Vec<bool>>,
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
) -> DaftResult<Self> {
let logical_plan: LogicalPlan = logical_ops::Join::try_new(
self.plan.clone(),
right.into(),
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
join_suffix,
Expand Down
3 changes: 2 additions & 1 deletion src/daft-plan/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,11 @@ Project1 --> Limit0
.build();

let plan = LogicalPlanBuilder::new(subplan, None)
.join(
.join_with_null_safe_equal(
subplan2,
vec![col("id")],
vec![col("id")],
Some(vec![true]),
JoinType::Inner,
None,
None,
Expand Down
Loading

0 comments on commit 594b5e2

Please sign in to comment.