fix(join): joining on different types (#3716)
This PR fixes a few things and also has some QOL changes.
Fixes:
- Joining on null-type join keys
- Joining on an empty table (resolves #3071), which turned out to be the above issue
- Joining on join keys with different types (see the usage sketch below)
- Combined column typing (right and outer joins should not just take the left columns' types)

QOL:
- Combine all the column renaming parameters into a `JoinOptions` type.
Reduces the parameters for a bunch of functions and also uses the
builder pattern
- Rename the `keep_join_keys` field to `merge_matching_join_keys` to make the behavior clearer
kevinzwang authored Jan 29, 2025
1 parent de5acf5 commit d00e444
Showing 31 changed files with 652 additions and 638 deletions.
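A minimal usage sketch of the fixed behavior, assuming Daft's public Python API (`daft.from_pydict`, `DataFrame.join`, `collect`); the data and the dtypes noted in the comments are illustrative, not taken from this PR:

```python
import daft

# Join keys with different types: with this fix, both key columns are cast to
# their common supertype before the join instead of erroring.
left = daft.from_pydict({"id": [1, 2, 3], "a": ["x", "y", "z"]})  # id: int64
right = daft.from_pydict({"id": [1.0, 2.0], "b": [10, 20]})       # id: float64
left.join(right, on="id", how="inner").collect()

# Joining against an empty table previously failed because the empty side's key
# column has the null type; that case is covered by the same fix.
empty = daft.from_pydict({"id": [], "b": []})                     # id: null
left.join(empty, on="id", how="left").collect()
```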
8 changes: 4 additions & 4 deletions daft/daft/__init__.pyi
@@ -1455,8 +1455,8 @@ class PyMicroPartition:
right: PyMicroPartition,
left_on: list[PyExpr],
right_on: list[PyExpr],
null_equals_nulls: list[bool] | None,
how: JoinType,
null_equals_nulls: list[bool] | None = None,
) -> PyMicroPartition: ...
def pivot(
self,
@@ -1643,9 +1643,9 @@ class LogicalPlanBuilder:
left_on: list[PyExpr],
right_on: list[PyExpr],
join_type: JoinType,
strategy: JoinStrategy | None = None,
join_prefix: str | None = None,
join_suffix: str | None = None,
join_strategy: JoinStrategy | None = None,
prefix: str | None = None,
suffix: str | None = None,
) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
4 changes: 2 additions & 2 deletions daft/dataframe/dataframe.py
@@ -2012,8 +2012,8 @@ def join(
right_on=right_exprs,
how=join_type,
strategy=join_strategy,
join_prefix=prefix,
join_suffix=suffix,
prefix=prefix,
suffix=suffix,
)
return DataFrame(builder)

8 changes: 4 additions & 4 deletions daft/logical/builder.py
@@ -270,17 +270,17 @@ def join( # type: ignore[override]
right_on: list[Expression],
how: JoinType = JoinType.Inner,
strategy: JoinStrategy | None = None,
join_suffix: str | None = None,
join_prefix: str | None = None,
prefix: str | None = None,
suffix: str | None = None,
) -> LogicalPlanBuilder:
builder = self._builder.join(
right._builder,
[expr._expr for expr in left_on],
[expr._expr for expr in right_on],
how,
strategy,
join_suffix,
join_prefix,
prefix,
suffix,
)
return LogicalPlanBuilder(builder)

22 changes: 15 additions & 7 deletions src/arrow2/src/array/dyn_ord.rs
@@ -1,19 +1,16 @@
use std::cmp::Ordering;

use num_traits::Float;
use ord::total_cmp;

use std::cmp::Ordering;

use crate::datatypes::*;
use crate::error::Error;
use crate::offset::Offset;
use crate::{array::*, types::NativeType};
use crate::{array::*, datatypes::*, error::Error, offset::Offset, types::NativeType};

/// Compare the values at two arbitrary indices in two arbitrary arrays.
pub type DynArrayComparator =
Box<dyn Fn(&dyn Array, &dyn Array, usize, usize) -> Ordering + Send + Sync>;

#[inline]
unsafe fn is_valid<A: Array>(arr: &A, i: usize) -> bool {
unsafe fn is_valid(arr: &dyn Array, i: usize) -> bool {
// avoid dyn function hop by using generic
arr.validity()
.as_ref()
@@ -122,6 +119,16 @@ fn compare_dyn_boolean(nulls_equal: bool) -> DynArrayComparator {
})
}

fn compare_dyn_null(nulls_equal: bool) -> DynArrayComparator {
let ordering = if nulls_equal {
Ordering::Equal
} else {
Ordering::Less
};

Box::new(move |_, _, _, _| ordering)
}

pub fn build_dyn_array_compare(
left: &DataType,
right: &DataType,
@@ -187,6 +194,7 @@ pub fn build_dyn_array_compare(
// }
// }
// }
(Null, Null) => compare_dyn_null(nulls_equal),
(lhs, _) => {
return Err(Error::InvalidArgumentError(format!(
"The data type type {lhs:?} has no natural order"
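As a rough illustration (not code from this PR), the new null comparators boil down to returning a fixed ordering, since nulls have no canonical order; in Python terms:

```python
from enum import Enum

class Ordering(Enum):
    LESS = -1
    EQUAL = 0

# Mirrors compare_dyn_null above: the array values are never inspected. When
# nulls_equal is true (presumably driven by the null_equals_nulls join option),
# two nulls compare as equal, which lets null-type join keys match; otherwise Less.
def compare_dyn_null(nulls_equal: bool):
    ordering = Ordering.EQUAL if nulls_equal else Ordering.LESS
    return lambda left, right, i, j: ordering

cmp = compare_dyn_null(nulls_equal=True)
print(cmp(None, None, 0, 0))  # Ordering.EQUAL
```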
14 changes: 10 additions & 4 deletions src/arrow2/src/array/ord.rs
@@ -2,10 +2,7 @@
use std::cmp::Ordering;

use crate::datatypes::*;
use crate::error::Error;
use crate::offset::Offset;
use crate::{array::*, types::NativeType};
use crate::{array::*, datatypes::*, error::Error, offset::Offset, types::NativeType};

/// Compare the values at two arbitrary indices in two arrays.
pub type DynComparator = Box<dyn Fn(usize, usize) -> Ordering + Send + Sync>;
@@ -157,6 +154,14 @@ macro_rules! dyn_dict {
}};
}

fn compare_null() -> DynComparator {
Box::new(move |_i: usize, _j: usize| {
// nulls do not have a canonical ordering, but it is trivially implemented so that
// null arrays can be used in things that depend on `build_compare`
Ordering::Less
})
}

/// returns a comparison function that compares values at two different slots
/// between two [`Array`].
/// # Example
@@ -243,6 +248,7 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result<DynComparator>
}
}
}
(Null, Null) => compare_null(),
(lhs, _) => {
return Err(Error::InvalidArgumentError(format!(
"The data type type {lhs:?} has no natural order"
36 changes: 32 additions & 4 deletions src/daft-dsl/src/expr/mod.rs
@@ -3,6 +3,7 @@ mod tests;

use std::{
any::Any,
collections::HashSet,
hash::{DefaultHasher, Hash, Hasher},
io::{self, Write},
str::FromStr,
@@ -21,7 +22,6 @@ use daft_core::{
utils::supertype::try_get_supertype,
};
use derive_more::Display;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

use super::functions::FunctionExpr;
@@ -1320,9 +1320,9 @@ impl FromStr for Operator {
// Check if one set of columns is a reordering of the other
pub fn is_partition_compatible(a: &[ExprRef], b: &[ExprRef]) -> bool {
// sort a and b by name
let a: Vec<&str> = a.iter().map(|a| a.name()).sorted().collect();
let b: Vec<&str> = b.iter().map(|a| a.name()).sorted().collect();
a == b
let a_set: HashSet<&ExprRef> = HashSet::from_iter(a);
let b_set: HashSet<&ExprRef> = HashSet::from_iter(b);
a_set == b_set
}

pub fn has_agg(expr: &ExprRef) -> bool {
@@ -1443,3 +1443,31 @@ pub fn exprs_to_schema(exprs: &[ExprRef], input_schema: SchemaRef) -> DaftResult<SchemaRef>
.collect::<DaftResult<_>>()?;
Ok(Arc::new(Schema::new(fields)?))
}

/// Adds aliases as appropriate to ensure that all expressions have unique names.
pub fn deduplicate_expr_names(exprs: &[ExprRef]) -> Vec<ExprRef> {
let mut names_so_far = HashSet::new();

exprs
.iter()
.map(|e| {
let curr_name = e.name();

let mut i = 0;
let mut new_name = curr_name.to_string();

while names_so_far.contains(&new_name) {
i += 1;
new_name = format!("{}_{}", curr_name, i);
}

names_so_far.insert(new_name.clone());

if i == 0 {
e.clone()
} else {
e.alias(new_name)
}
})
.collect()
}
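For reference, the renaming scheme in `deduplicate_expr_names` behaves like this standalone sketch over plain name strings (a hypothetical helper for illustration only):

```python
def deduplicate_names(names: list[str]) -> list[str]:
    # Keep the first occurrence of each name; suffix later duplicates with
    # _1, _2, ... until unique, matching the Rust logic above.
    seen: set[str] = set()
    result = []
    for name in names:
        i, new_name = 0, name
        while new_name in seen:
            i += 1
            new_name = f"{name}_{i}"
        seen.add(new_name)
        result.append(new_name)
    return result

assert deduplicate_names(["id", "id", "value"]) == ["id", "id_1", "value"]
```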
99 changes: 99 additions & 0 deletions src/daft-dsl/src/join.rs
@@ -0,0 +1,99 @@
use common_error::DaftResult;
use daft_core::{prelude::*, utils::supertype::try_get_supertype};
use indexmap::IndexSet;

use crate::{deduplicate_expr_names, ExprRef};

pub fn get_common_join_cols<'a>(
left_schema: &'a SchemaRef,
right_schema: &'a SchemaRef,
) -> impl Iterator<Item = &'a String> {
left_schema
.fields
.keys()
.filter(|name| right_schema.has_field(name))
}

/// Infer the schema of a join operation
pub fn infer_join_schema(
left_schema: &SchemaRef,
right_schema: &SchemaRef,
join_type: JoinType,
) -> DaftResult<SchemaRef> {
if matches!(join_type, JoinType::Anti | JoinType::Semi) {
Ok(left_schema.clone())
} else {
let common_cols = get_common_join_cols(left_schema, right_schema).collect::<IndexSet<_>>();

// common columns, then unique left fields, then unique right fields
let fields = common_cols
.iter()
.map(|name| {
let left_field = left_schema.get_field(name).unwrap();
let right_field = right_schema.get_field(name).unwrap();

Ok(match join_type {
JoinType::Inner => left_field.clone(),
JoinType::Left => left_field.clone(),
JoinType::Right => right_field.clone(),
JoinType::Outer => {
let supertype = try_get_supertype(&left_field.dtype, &right_field.dtype)?;

Field::new(*name, supertype)
}
JoinType::Anti | JoinType::Semi => unreachable!(),
})
})
.chain(
left_schema
.fields
.iter()
.chain(right_schema.fields.iter())
.filter_map(|(name, field)| {
if common_cols.contains(name) {
None
} else {
Some(field.clone())
}
})
.map(Ok),
)
.collect::<DaftResult<_>>()?;

Ok(Schema::new(fields)?.into())
}
}

/// Casts join keys to the same types and make their names unique.
pub fn normalize_join_keys(
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
left_schema: SchemaRef,
right_schema: SchemaRef,
) -> DaftResult<(Vec<ExprRef>, Vec<ExprRef>)> {
let (left_on, right_on) = left_on
.into_iter()
.zip(right_on)
.map(|(mut l, mut r)| {
let l_dtype = l.to_field(&left_schema)?.dtype;
let r_dtype = r.to_field(&right_schema)?.dtype;

let supertype = try_get_supertype(&l_dtype, &r_dtype)?;

if l_dtype != supertype {
l = l.cast(&supertype);
}

if r_dtype != supertype {
r = r.cast(&supertype);
}

Ok((l, r))
})
.collect::<DaftResult<(Vec<_>, Vec<_>)>>()?;

let left_on = deduplicate_expr_names(&left_on);
let right_on = deduplicate_expr_names(&right_on);

Ok((left_on, right_on))
}
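Taken together, `normalize_join_keys` and `infer_join_schema` cast the join keys to a common supertype and order the output schema as common columns first (widened to the supertype for outer joins), then the remaining left columns, then the remaining right columns. A conceptual sketch with plain dicts standing in for schemas; the supertype table is a made-up stand-in for `try_get_supertype`:

```python
# Made-up supertype lookup for the sketch only.
SUPERTYPE = {("int32", "int64"): "int64", ("int64", "int32"): "int64"}

def joined_schema(left: dict[str, str], right: dict[str, str], how: str) -> dict[str, str]:
    common = [name for name in left if name in right]
    out = {}
    for name in common:
        if how == "right":
            out[name] = right[name]
        elif how == "outer":
            # outer joins keep rows from both sides, so the column widens to the supertype
            out[name] = SUPERTYPE.get((left[name], right[name]), left[name])
        else:  # inner and left joins keep the left column's type
            out[name] = left[name]
    # common columns first, then unique left columns, then unique right columns
    out.update({n: t for n, t in left.items() if n not in common})
    out.update({n: t for n, t in right.items() if n not in common})
    return out

print(joined_schema({"id": "int32", "a": "utf8"}, {"id": "int64", "b": "bool"}, "outer"))
# {'id': 'int64', 'a': 'utf8', 'b': 'bool'}
```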
84 changes: 0 additions & 84 deletions src/daft-dsl/src/join/mod.rs

This file was deleted.

