diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index b2abe3610c..f91f73d110 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1455,8 +1455,8 @@ class PyMicroPartition: right: PyMicroPartition, left_on: list[PyExpr], right_on: list[PyExpr], - null_equals_nulls: list[bool] | None, how: JoinType, + null_equals_nulls: list[bool] | None = None, ) -> PyMicroPartition: ... def pivot( self, @@ -1643,9 +1643,9 @@ class LogicalPlanBuilder: left_on: list[PyExpr], right_on: list[PyExpr], join_type: JoinType, - strategy: JoinStrategy | None = None, - join_prefix: str | None = None, - join_suffix: str | None = None, + join_strategy: JoinStrategy | None = None, + prefix: str | None = None, + suffix: str | None = None, ) -> LogicalPlanBuilder: ... def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ... def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 0926c48b04..686a7fe971 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2012,8 +2012,8 @@ def join( right_on=right_exprs, how=join_type, strategy=join_strategy, - join_prefix=prefix, - join_suffix=suffix, + prefix=prefix, + suffix=suffix, ) return DataFrame(builder) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 3d8178c729..0a469a14f2 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -270,8 +270,8 @@ def join( # type: ignore[override] right_on: list[Expression], how: JoinType = JoinType.Inner, strategy: JoinStrategy | None = None, - join_suffix: str | None = None, - join_prefix: str | None = None, + prefix: str | None = None, + suffix: str | None = None, ) -> LogicalPlanBuilder: builder = self._builder.join( right._builder, @@ -279,8 +279,8 @@ def join( # type: ignore[override] [expr._expr for expr in right_on], how, strategy, - join_suffix, - join_prefix, + prefix, + suffix, ) return LogicalPlanBuilder(builder) diff --git a/src/arrow2/src/array/dyn_ord.rs b/src/arrow2/src/array/dyn_ord.rs index 7813cd4661..074c599693 100644 --- a/src/arrow2/src/array/dyn_ord.rs +++ b/src/arrow2/src/array/dyn_ord.rs @@ -1,19 +1,16 @@ +use std::cmp::Ordering; + use num_traits::Float; use ord::total_cmp; -use std::cmp::Ordering; - -use crate::datatypes::*; -use crate::error::Error; -use crate::offset::Offset; -use crate::{array::*, types::NativeType}; +use crate::{array::*, datatypes::*, error::Error, offset::Offset, types::NativeType}; /// Compare the values at two arbitrary indices in two arbitrary arrays. pub type DynArrayComparator = Box Ordering + Send + Sync>; #[inline] -unsafe fn is_valid(arr: &A, i: usize) -> bool { +unsafe fn is_valid(arr: &dyn Array, i: usize) -> bool { // avoid dyn function hop by using generic arr.validity() .as_ref() @@ -122,6 +119,16 @@ fn compare_dyn_boolean(nulls_equal: bool) -> DynArrayComparator { }) } +fn compare_dyn_null(nulls_equal: bool) -> DynArrayComparator { + let ordering = if nulls_equal { + Ordering::Equal + } else { + Ordering::Less + }; + + Box::new(move |_, _, _, _| ordering) +} + pub fn build_dyn_array_compare( left: &DataType, right: &DataType, @@ -187,6 +194,7 @@ pub fn build_dyn_array_compare( // } // } // } + (Null, Null) => compare_dyn_null(nulls_equal), (lhs, _) => { return Err(Error::InvalidArgumentError(format!( "The data type type {lhs:?} has no natural order" diff --git a/src/arrow2/src/array/ord.rs b/src/arrow2/src/array/ord.rs index 6bf0d95126..5e574546d3 100644 --- a/src/arrow2/src/array/ord.rs +++ b/src/arrow2/src/array/ord.rs @@ -2,10 +2,7 @@ use std::cmp::Ordering; -use crate::datatypes::*; -use crate::error::Error; -use crate::offset::Offset; -use crate::{array::*, types::NativeType}; +use crate::{array::*, datatypes::*, error::Error, offset::Offset, types::NativeType}; /// Compare the values at two arbitrary indices in two arrays. pub type DynComparator = Box Ordering + Send + Sync>; @@ -157,6 +154,14 @@ macro_rules! dyn_dict { }}; } +fn compare_null() -> DynComparator { + Box::new(move |_i: usize, _j: usize| { + // nulls do not have a canonical ordering, but it is trivially implemented so that + // null arrays can be used in things that depend on `build_compare` + Ordering::Less + }) +} + /// returns a comparison function that compares values at two different slots /// between two [`Array`]. /// # Example @@ -243,6 +248,7 @@ pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result compare_null(), (lhs, _) => { return Err(Error::InvalidArgumentError(format!( "The data type type {lhs:?} has no natural order" diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 6595a022d3..b2a72ece88 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -3,6 +3,7 @@ mod tests; use std::{ any::Any, + collections::HashSet, hash::{DefaultHasher, Hash, Hasher}, io::{self, Write}, str::FromStr, @@ -21,7 +22,6 @@ use daft_core::{ utils::supertype::try_get_supertype, }; use derive_more::Display; -use itertools::Itertools; use serde::{Deserialize, Serialize}; use super::functions::FunctionExpr; @@ -1320,9 +1320,9 @@ impl FromStr for Operator { // Check if one set of columns is a reordering of the other pub fn is_partition_compatible(a: &[ExprRef], b: &[ExprRef]) -> bool { // sort a and b by name - let a: Vec<&str> = a.iter().map(|a| a.name()).sorted().collect(); - let b: Vec<&str> = b.iter().map(|a| a.name()).sorted().collect(); - a == b + let a_set: HashSet<&ExprRef> = HashSet::from_iter(a); + let b_set: HashSet<&ExprRef> = HashSet::from_iter(b); + a_set == b_set } pub fn has_agg(expr: &ExprRef) -> bool { @@ -1443,3 +1443,31 @@ pub fn exprs_to_schema(exprs: &[ExprRef], input_schema: SchemaRef) -> DaftResult .collect::>()?; Ok(Arc::new(Schema::new(fields)?)) } + +/// Adds aliases as appropriate to ensure that all expressions have unique names. +pub fn deduplicate_expr_names(exprs: &[ExprRef]) -> Vec { + let mut names_so_far = HashSet::new(); + + exprs + .iter() + .map(|e| { + let curr_name = e.name(); + + let mut i = 0; + let mut new_name = curr_name.to_string(); + + while names_so_far.contains(&new_name) { + i += 1; + new_name = format!("{}_{}", curr_name, i); + } + + names_so_far.insert(new_name.clone()); + + if i == 0 { + e.clone() + } else { + e.alias(new_name) + } + }) + .collect() +} diff --git a/src/daft-dsl/src/join.rs b/src/daft-dsl/src/join.rs new file mode 100644 index 0000000000..1260a6d63b --- /dev/null +++ b/src/daft-dsl/src/join.rs @@ -0,0 +1,99 @@ +use common_error::DaftResult; +use daft_core::{prelude::*, utils::supertype::try_get_supertype}; +use indexmap::IndexSet; + +use crate::{deduplicate_expr_names, ExprRef}; + +pub fn get_common_join_cols<'a>( + left_schema: &'a SchemaRef, + right_schema: &'a SchemaRef, +) -> impl Iterator { + left_schema + .fields + .keys() + .filter(|name| right_schema.has_field(name)) +} + +/// Infer the schema of a join operation +pub fn infer_join_schema( + left_schema: &SchemaRef, + right_schema: &SchemaRef, + join_type: JoinType, +) -> DaftResult { + if matches!(join_type, JoinType::Anti | JoinType::Semi) { + Ok(left_schema.clone()) + } else { + let common_cols = get_common_join_cols(left_schema, right_schema).collect::>(); + + // common columns, then unique left fields, then unique right fields + let fields = common_cols + .iter() + .map(|name| { + let left_field = left_schema.get_field(name).unwrap(); + let right_field = right_schema.get_field(name).unwrap(); + + Ok(match join_type { + JoinType::Inner => left_field.clone(), + JoinType::Left => left_field.clone(), + JoinType::Right => right_field.clone(), + JoinType::Outer => { + let supertype = try_get_supertype(&left_field.dtype, &right_field.dtype)?; + + Field::new(*name, supertype) + } + JoinType::Anti | JoinType::Semi => unreachable!(), + }) + }) + .chain( + left_schema + .fields + .iter() + .chain(right_schema.fields.iter()) + .filter_map(|(name, field)| { + if common_cols.contains(name) { + None + } else { + Some(field.clone()) + } + }) + .map(Ok), + ) + .collect::>()?; + + Ok(Schema::new(fields)?.into()) + } +} + +/// Casts join keys to the same types and make their names unique. +pub fn normalize_join_keys( + left_on: Vec, + right_on: Vec, + left_schema: SchemaRef, + right_schema: SchemaRef, +) -> DaftResult<(Vec, Vec)> { + let (left_on, right_on) = left_on + .into_iter() + .zip(right_on) + .map(|(mut l, mut r)| { + let l_dtype = l.to_field(&left_schema)?.dtype; + let r_dtype = r.to_field(&right_schema)?.dtype; + + let supertype = try_get_supertype(&l_dtype, &r_dtype)?; + + if l_dtype != supertype { + l = l.cast(&supertype); + } + + if r_dtype != supertype { + r = r.cast(&supertype); + } + + Ok((l, r)) + }) + .collect::, Vec<_>)>>()?; + + let left_on = deduplicate_expr_names(&left_on); + let right_on = deduplicate_expr_names(&right_on); + + Ok((left_on, right_on)) +} diff --git a/src/daft-dsl/src/join/mod.rs b/src/daft-dsl/src/join/mod.rs deleted file mode 100644 index 1de29b995e..0000000000 --- a/src/daft-dsl/src/join/mod.rs +++ /dev/null @@ -1,84 +0,0 @@ -#[cfg(test)] -mod tests; - -use std::sync::Arc; - -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; -use indexmap::IndexSet; - -use crate::{Expr, ExprRef}; - -/// Get the columns between the two sides of the join that should be merged in the order of the join keys. -/// Join keys should only be merged if they are column expressions. -pub fn get_common_join_keys<'a>( - left_on: &'a [ExprRef], - right_on: &'a [ExprRef], -) -> impl Iterator> { - left_on.iter().zip(right_on.iter()).filter_map(|(l, r)| { - if let (Expr::Column(l_name), Expr::Column(r_name)) = (&**l, &**r) - && l_name == r_name - { - Some(l_name) - } else { - None - } - }) -} - -/// Infer the schema of a join operation -/// -/// This function assumes that the only common field names between the left and right schemas are the join fields, -/// which is valid because the right columns are renamed during the construction of a join logical operation. -pub fn infer_join_schema( - left_schema: &SchemaRef, - right_schema: &SchemaRef, - left_on: &[ExprRef], - right_on: &[ExprRef], - how: JoinType, -) -> DaftResult { - if left_on.len() != right_on.len() { - return Err(DaftError::ValueError(format!( - "Length of left_on does not match length of right_on for Join {} vs {}", - left_on.len(), - right_on.len() - ))); - } - - if matches!(how, JoinType::Anti | JoinType::Semi) { - Ok(left_schema.clone()) - } else { - let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on) - .map(|k| k.to_string()) - .collect(); - - // common join fields, then unique left fields, then unique right fields - let fields: Vec<_> = common_join_keys - .iter() - .map(|name| { - left_schema - .get_field(name) - .expect("Common join key should exist in left schema") - }) - .chain(left_schema.fields.iter().filter_map(|(name, field)| { - if common_join_keys.contains(name) { - None - } else { - Some(field) - } - })) - .chain(right_schema.fields.iter().filter_map(|(name, field)| { - if common_join_keys.contains(name) { - None - } else if left_schema.fields.contains_key(name) { - unreachable!("Right schema should have renamed columns") - } else { - Some(field) - } - })) - .cloned() - .collect(); - - Ok(Schema::new(fields)?.into()) - } -} diff --git a/src/daft-dsl/src/join/tests.rs b/src/daft-dsl/src/join/tests.rs deleted file mode 100644 index 52d58a76c0..0000000000 --- a/src/daft-dsl/src/join/tests.rs +++ /dev/null @@ -1,27 +0,0 @@ -use super::*; -use crate::col; - -#[test] -fn test_get_common_join_keys() { - let left_on: &[ExprRef] = &[ - col("a"), - col("b_left"), - col("c").alias("c_new"), - col("d").alias("d_new"), - col("e").add(col("f")), - ]; - - let right_on: &[ExprRef] = &[ - col("a"), - col("b_right"), - col("c"), - col("d").alias("d_new"), - col("e"), - ]; - - let common_join_keys = get_common_join_keys(left_on, right_on) - .map(|k| k.to_string()) - .collect::>(); - - assert_eq!(common_join_keys, vec!["a"]); -} diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index c29ca9f779..a28fa41c46 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -14,9 +14,10 @@ pub mod python; mod treenode; pub use common_treenode; pub use expr::{ - binary_op, col, count_actor_pool_udfs, estimated_selectivity, exprs_to_schema, has_agg, - is_actor_pool_udf, is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef, - Operator, OuterReferenceColumn, SketchType, Subquery, SubqueryPlan, + binary_op, col, count_actor_pool_udfs, deduplicate_expr_names, estimated_selectivity, + exprs_to_schema, has_agg, is_actor_pool_udf, is_partition_compatible, AggExpr, + ApproxPercentileParams, Expr, ExprRef, Operator, OuterReferenceColumn, SketchType, Subquery, + SubqueryPlan, }; pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue}; #[cfg(feature = "python")] diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 0033758491..49844f2c61 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -7,15 +7,10 @@ use common_display::{ tree::TreeDisplay, DisplayLevel, }; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use common_file_formats::FileFormat; -use daft_core::{ - datatypes::Field, - join::JoinSide, - prelude::{Schema, SchemaRef}, - utils::supertype, -}; -use daft_dsl::{col, join::get_common_join_keys}; +use daft_core::{join::JoinSide, prelude::Schema}; +use daft_dsl::{col, join::get_common_join_cols}; use daft_local_plan::{ ActorPoolProject, Concat, CrossJoin, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, MonotonicallyIncreasingId, PhysicalWrite, Pivot, @@ -424,7 +419,7 @@ pub fn physical_plan_to_pipeline( let build_schema = build_child.schema(); let probe_schema = probe_child.schema(); || -> DaftResult<_> { - let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on) + let common_join_cols: IndexSet<_> = get_common_join_cols(left_schema, right_schema) .map(std::string::ToString::to_string) .collect(); let build_key_fields = build_on @@ -435,29 +430,16 @@ pub fn physical_plan_to_pipeline( .iter() .map(|e| e.to_field(probe_schema)) .collect::>>()?; - let key_schema: SchemaRef = Schema::new( - build_key_fields - .into_iter() - .zip(probe_key_fields.into_iter()) - .map(|(l, r)| { - // TODO we should be using the comparison_op function here instead but i'm just using existing behavior for now - let dtype = supertype::try_get_supertype(&l.dtype, &r.dtype)?; - Ok(Field::new(l.name, dtype)) - }) - .collect::>>()?, - )? - .into(); - let casted_build_on = build_on - .iter() - .zip(key_schema.fields.values()) - .map(|(e, f)| e.clone().cast(&f.dtype)) - .collect::>(); - let casted_probe_on = probe_on - .iter() - .zip(key_schema.fields.values()) - .map(|(e, f)| e.clone().cast(&f.dtype)) - .collect::>(); + for (build_field, probe_field) in build_key_fields.iter().zip(probe_key_fields.iter()) { + if build_field.dtype != probe_field.dtype { + return Err(DaftError::SchemaMismatch( + format!("Expected build and probe key field datatypes to match, found: {} vs {}", build_field.dtype, probe_field.dtype) + )); + } + } + let key_schema = Arc::new(Schema::new(build_key_fields)?); + // we should move to a builder pattern let probe_state_bridge = BroadcastStateBridge::new(); let track_indices = if matches!(join_type, JoinType::Anti | JoinType::Semi) { @@ -467,7 +449,7 @@ pub fn physical_plan_to_pipeline( }; let build_sink = HashJoinBuildSink::new( key_schema, - casted_build_on, + build_on.clone(), null_equals_null.clone(), track_indices, probe_state_bridge.clone(), @@ -485,7 +467,7 @@ pub fn physical_plan_to_pipeline( match join_type { JoinType::Anti | JoinType::Semi => Ok(StreamingSinkNode::new( Arc::new(AntiSemiProbeSink::new( - casted_probe_on, + probe_on.clone(), join_type, schema, probe_state_bridge, @@ -497,11 +479,11 @@ pub fn physical_plan_to_pipeline( .boxed()), JoinType::Inner => Ok(IntermediateNode::new( Arc::new(InnerHashJoinProbeOperator::new( - casted_probe_on, + probe_on.clone(), left_schema, right_schema, build_on_left, - common_join_keys, + common_join_cols, schema, probe_state_bridge, )), @@ -512,15 +494,15 @@ pub fn physical_plan_to_pipeline( JoinType::Left | JoinType::Right | JoinType::Outer => { Ok(StreamingSinkNode::new( Arc::new(OuterHashJoinProbeSink::new( - casted_probe_on, + probe_on.clone(), left_schema, right_schema, *join_type, build_on_left, - common_join_keys, + common_join_cols, schema, probe_state_bridge, - )), + )?), vec![build_node, probe_child_node], stats_state.clone(), ) diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 110b45ff87..847a8c9e6a 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -1,12 +1,11 @@ use std::sync::Arc; +use bitmap::{and, Bitmap, MutableBitmap}; use common_error::DaftResult; use daft_core::{ - prelude::{ - bitmap::{and, Bitmap, MutableBitmap}, - BooleanArray, Schema, SchemaRef, - }, + prelude::*, series::{IntoSeries, Series}, + utils::supertype::try_get_supertype, }; use daft_dsl::ExprRef; use daft_logical_plan::JoinType; @@ -125,9 +124,10 @@ impl StreamingSinkState for OuterHashJoinState { struct OuterHashJoinParams { probe_on: Vec, - common_join_keys: Vec, + common_join_cols: Vec, left_non_join_columns: Vec, right_non_join_columns: Vec, + outer_common_col_schema: SchemaRef, left_non_join_schema: SchemaRef, right_non_join_schema: SchemaRef, join_type: JoinType, @@ -149,10 +149,10 @@ impl OuterHashJoinProbeSink { right_schema: &SchemaRef, join_type: JoinType, build_on_left: bool, - common_join_keys: IndexSet, + common_join_cols: IndexSet, output_schema: &SchemaRef, probe_state_bridge: BroadcastStateBridgeRef, - ) -> Self { + ) -> DaftResult { let needs_bitmap = join_type == JoinType::Outer || join_type == JoinType::Right && !build_on_left || join_type == JoinType::Left && build_on_left; @@ -161,31 +161,42 @@ impl OuterHashJoinProbeSink { (JoinType::Outer, false) => (right_schema, left_schema), _ => (left_schema, right_schema), }; + let outer_common_col_fields = common_join_cols + .iter() + .map(|name| { + let supertype = try_get_supertype( + &left_schema.get_field(name)?.dtype, + &right_schema.get_field(name)?.dtype, + )?; + + Ok(Field::new(name.clone(), supertype)) + }) + .collect::>()?; + let outer_common_col_schema = Arc::new(Schema::new(outer_common_col_fields)?); let left_non_join_fields = left_schema .fields .values() - .filter(|f| !common_join_keys.contains(&f.name)) + .filter(|f| !common_join_cols.contains(&f.name)) .cloned() .collect(); - let left_non_join_schema = - Arc::new(Schema::new(left_non_join_fields).expect("left schema should be valid")); + let left_non_join_schema = Arc::new(Schema::new(left_non_join_fields)?); let left_non_join_columns = left_non_join_schema.fields.keys().cloned().collect(); let right_non_join_fields = right_schema .fields .values() - .filter(|f| !common_join_keys.contains(&f.name)) + .filter(|f| !common_join_cols.contains(&f.name)) .cloned() .collect(); - let right_non_join_schema = - Arc::new(Schema::new(right_non_join_fields).expect("right schema should be valid")); + let right_non_join_schema = Arc::new(Schema::new(right_non_join_fields)?); let right_non_join_columns = right_non_join_schema.fields.keys().cloned().collect(); - let common_join_keys = common_join_keys.into_iter().collect(); - Self { + let common_join_cols = common_join_cols.into_iter().collect(); + Ok(Self { params: Arc::new(OuterHashJoinParams { probe_on, - common_join_keys, + common_join_cols, left_non_join_columns, right_non_join_columns, + outer_common_col_schema, left_non_join_schema, right_non_join_schema, join_type, @@ -194,7 +205,7 @@ impl OuterHashJoinProbeSink { needs_bitmap, output_schema: output_schema.clone(), probe_state_bridge, - } + }) } fn probe_left_right_with_bitmap( @@ -203,7 +214,7 @@ impl OuterHashJoinProbeSink { probe_state: &ProbeState, join_type: JoinType, probe_on: &[ExprRef], - common_join_keys: &[String], + common_join_cols: &[String], left_non_join_columns: &[String], right_non_join_columns: &[String], ) -> DaftResult> { @@ -248,12 +259,12 @@ impl OuterHashJoinProbeSink { let probe_side_table = probe_side_growable.build()?; let final_table = if join_type == JoinType::Left { - let join_table = build_side_table.get_columns(common_join_keys)?; + let join_table = build_side_table.get_columns(common_join_cols)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = probe_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? } else { - let join_table = build_side_table.get_columns(common_join_keys)?; + let join_table = build_side_table.get_columns(common_join_cols)?; let left = probe_side_table.get_columns(left_non_join_columns)?; let right = build_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? @@ -270,7 +281,7 @@ impl OuterHashJoinProbeSink { probe_state: &ProbeState, join_type: JoinType, probe_on: &[ExprRef], - common_join_keys: &[String], + common_join_cols: &[String], left_non_join_columns: &[String], right_non_join_columns: &[String], ) -> DaftResult> { @@ -317,12 +328,12 @@ impl OuterHashJoinProbeSink { let probe_side_table = probe_side_growable.build()?; let final_table = if join_type == JoinType::Left { - let join_table = probe_side_table.get_columns(common_join_keys)?; + let join_table = probe_side_table.get_columns(common_join_cols)?; let left = probe_side_table.get_columns(left_non_join_columns)?; let right = build_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? } else { - let join_table = probe_side_table.get_columns(common_join_keys)?; + let join_table = probe_side_table.get_columns(common_join_cols)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = probe_side_table.get_columns(right_non_join_columns)?; join_table.union(&left)?.union(&right)? @@ -340,7 +351,8 @@ impl OuterHashJoinProbeSink { probe_state: &ProbeState, bitmap_builder: &mut IndexBitmapBuilder, probe_on: &[ExprRef], - common_join_keys: &[String], + common_join_cols: &[String], + outer_common_col_schema: &SchemaRef, left_non_join_columns: &[String], right_non_join_columns: &[String], build_on_left: bool, @@ -387,7 +399,9 @@ impl OuterHashJoinProbeSink { let build_side_table = build_side_growable.build()?; let probe_side_table = probe_side_growable.build()?; - let join_table = probe_side_table.get_columns(common_join_keys)?; + let join_table = probe_side_table + .get_columns(common_join_cols)? + .cast_to_schema(outer_common_col_schema)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = probe_side_table.get_columns(right_non_join_columns)?; // If we built the probe table on the right, flip the order of union. @@ -461,13 +475,16 @@ impl OuterHashJoinProbeSink { async fn finalize_outer( states: Vec>, - common_join_keys: &[String], + common_join_cols: &[String], + outer_common_col_schema: &SchemaRef, left_non_join_columns: &[String], right_non_join_schema: &SchemaRef, build_on_left: bool, ) -> DaftResult>> { let build_side_table = Self::merge_bitmaps_and_construct_null_table(states).await?; - let join_table = build_side_table.get_columns(common_join_keys)?; + let join_table = build_side_table + .get_columns(common_join_cols)? + .cast_to_schema(outer_common_col_schema)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = { let columns = right_non_join_schema @@ -493,12 +510,12 @@ impl OuterHashJoinProbeSink { async fn finalize_left( states: Vec>, - common_join_keys: &[String], + common_join_cols: &[String], left_non_join_columns: &[String], right_non_join_schema: &SchemaRef, ) -> DaftResult>> { let build_side_table = Self::merge_bitmaps_and_construct_null_table(states).await?; - let join_table = build_side_table.get_columns(common_join_keys)?; + let join_table = build_side_table.get_columns(common_join_cols)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = { let columns = right_non_join_schema @@ -518,12 +535,12 @@ impl OuterHashJoinProbeSink { async fn finalize_right( states: Vec>, - common_join_keys: &[String], + common_join_cols: &[String], right_non_join_columns: &[String], left_non_join_schema: &SchemaRef, ) -> DaftResult>> { let build_side_table = Self::merge_bitmaps_and_construct_null_table(states).await?; - let join_table = build_side_table.get_columns(common_join_keys)?; + let join_table = build_side_table.get_columns(common_join_cols)?; let left = { let columns = left_non_join_schema .fields @@ -581,7 +598,7 @@ impl StreamingSink for OuterHashJoinProbeSink { &probe_state, params.join_type, ¶ms.probe_on, - ¶ms.common_join_keys, + ¶ms.common_join_cols, ¶ms.left_non_join_columns, ¶ms.right_non_join_columns, ) @@ -591,7 +608,7 @@ impl StreamingSink for OuterHashJoinProbeSink { &probe_state, params.join_type, ¶ms.probe_on, - ¶ms.common_join_keys, + ¶ms.common_join_cols, ¶ms.left_non_join_columns, ¶ms.right_non_join_columns, ), @@ -606,7 +623,8 @@ impl StreamingSink for OuterHashJoinProbeSink { &probe_state, bitmap_builder, ¶ms.probe_on, - ¶ms.common_join_keys, + ¶ms.common_join_cols, + ¶ms.outer_common_col_schema, ¶ms.left_non_join_columns, ¶ms.right_non_join_columns, params.build_on_left, @@ -670,21 +688,22 @@ impl StreamingSink for OuterHashJoinProbeSink { match params.join_type { JoinType::Left => Self::finalize_left( states, - ¶ms.common_join_keys, + ¶ms.common_join_cols, ¶ms.left_non_join_columns, ¶ms.right_non_join_schema, ) .await, JoinType::Right => Self::finalize_right( states, - ¶ms.common_join_keys, + ¶ms.common_join_cols, ¶ms.right_non_join_columns, ¶ms.left_non_join_schema, ) .await, JoinType::Outer => Self::finalize_outer( states, - ¶ms.common_join_keys, + ¶ms.common_join_cols, + ¶ms.outer_common_col_schema, ¶ms.left_non_join_columns, ¶ms.right_non_join_schema, params.build_on_left, diff --git a/src/daft-local-plan/src/translate.rs b/src/daft-local-plan/src/translate.rs index 0547e9e737..d5d7090fab 100644 --- a/src/daft-local-plan/src/translate.rs +++ b/src/daft-local-plan/src/translate.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common_error::{DaftError, DaftResult}; use common_scan_info::ScanState; use daft_core::join::JoinStrategy; -use daft_dsl::ExprRef; +use daft_dsl::{join::normalize_join_keys, ExprRef}; use daft_logical_plan::{JoinType, LogicalPlan, LogicalPlanRef, SourceInfo}; use super::plan::{LocalPhysicalPlan, LocalPhysicalPlanRef}; @@ -147,10 +147,14 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { let left = translate(&join.left)?; let right = translate(&join.right)?; - if join.left_on.is_empty() - && join.right_on.is_empty() - && join.join_type == JoinType::Inner - { + let (left_on, right_on) = normalize_join_keys( + join.left_on.clone(), + join.right_on.clone(), + join.left.schema(), + join.right.schema(), + )?; + + if left_on.is_empty() && right_on.is_empty() && join.join_type == JoinType::Inner { Ok(LocalPhysicalPlan::cross_join( left, right, @@ -161,8 +165,8 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { Ok(LocalPhysicalPlan::hash_join( left, right, - join.left_on.clone(), - join.right_on.clone(), + left_on, + right_on, join.null_equals_nulls.clone(), join.join_type, join.output_schema.clone(), diff --git a/src/daft-logical-plan/src/builder/mod.rs b/src/daft-logical-plan/src/builder/mod.rs index f0776517f6..58433f1849 100644 --- a/src/daft-logical-plan/src/builder/mod.rs +++ b/src/daft-logical-plan/src/builder/mod.rs @@ -34,7 +34,7 @@ use { use crate::{ logical_plan::LogicalPlan, - ops, + ops::{self, join::JoinOptions}, optimization::OptimizerBuilder, partitioning::{ HashRepartitionConfig, IntoPartitionsConfig, RandomShuffleConfig, RepartitionSpec, @@ -496,9 +496,7 @@ impl LogicalPlanBuilder { right_on, JoinType::Inner, None, - None, - None, - false, + Default::default(), ) } @@ -510,9 +508,7 @@ impl LogicalPlanBuilder { right_on: Vec, join_type: JoinType, join_strategy: Option, - join_suffix: Option<&str>, - join_prefix: Option<&str>, - keep_join_keys: bool, + options: JoinOptions, ) -> DaftResult { self.join_with_null_safe_equal( right, @@ -521,9 +517,7 @@ impl LogicalPlanBuilder { None, join_type, join_strategy, - join_suffix, - join_prefix, - keep_join_keys, + options, ) } @@ -536,9 +530,7 @@ impl LogicalPlanBuilder { null_equals_nulls: Option>, join_type: JoinType, join_strategy: Option, - join_suffix: Option<&str>, - join_prefix: Option<&str>, - keep_join_keys: bool, + options: JoinOptions, ) -> DaftResult { let left_plan = self.plan.clone(); let right_plan = right.into(); @@ -548,18 +540,8 @@ impl LogicalPlanBuilder { let (left_on, _) = expr_resolver.resolve(left_on, &left_plan.schema())?; let (right_on, _) = expr_resolver.resolve(right_on, &right_plan.schema())?; - // TODO(kevin): we should do this, but it has not been properly used before and is nondeterministic, which causes some tests to break - // let (left_on, right_on) = ops::Join::rename_join_keys(left_on, right_on); - - let (right_plan, right_on) = ops::Join::rename_right_columns( - left_plan.clone(), - right_plan, - left_on.clone(), - right_on, - join_type, - join_suffix, - join_prefix, - keep_join_keys, + let (left_plan, right_plan, left_on, right_on) = ops::join::Join::deduplicate_join_columns( + left_plan, right_plan, left_on, right_on, join_type, options, )?; let logical_plan: LogicalPlan = ops::Join::try_new( @@ -578,19 +560,9 @@ impl LogicalPlanBuilder { pub fn cross_join>( &self, right: Right, - join_suffix: Option<&str>, - join_prefix: Option<&str>, + options: JoinOptions, ) -> DaftResult { - self.join( - right, - vec![], - vec![], - JoinType::Inner, - None, - join_suffix, - join_prefix, - false, // no join keys to keep - ) + self.join(right, vec![], vec![], JoinType::Inner, None, options) } pub fn concat(&self, other: &Self) -> DaftResult { @@ -1051,8 +1023,8 @@ impl PyLogicalPlanBuilder { right_on, join_type, join_strategy=None, - join_suffix=None, - join_prefix=None + prefix=None, + suffix=None, ))] pub fn join( &self, @@ -1061,8 +1033,8 @@ impl PyLogicalPlanBuilder { right_on: Vec, join_type: JoinType, join_strategy: Option, - join_suffix: Option<&str>, - join_prefix: Option<&str>, + prefix: Option, + suffix: Option, ) -> PyResult { Ok(self .builder @@ -1072,9 +1044,11 @@ impl PyLogicalPlanBuilder { pyexprs_to_exprs(right_on), join_type, join_strategy, - join_suffix, - join_prefix, - false, // dataframes do not keep the join keys when joining + JoinOptions { + prefix, + suffix, + merge_matching_join_keys: true, + }, )? .into()) } diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index 3958477b77..50f3aa39f8 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -39,8 +39,8 @@ mod test { use pretty_assertions::assert_eq; use crate::{ - ops::Source, source_info::PlaceHolderInfo, ClusteringSpec, LogicalPlan, LogicalPlanBuilder, - LogicalPlanRef, SourceInfo, + ops::Source, source_info::PlaceHolderInfo, ClusteringSpec, JoinOptions, LogicalPlan, + LogicalPlanBuilder, LogicalPlanRef, SourceInfo, }; fn plan_1() -> LogicalPlanRef { @@ -106,9 +106,7 @@ mod test { vec![col("id")], JoinType::Inner, None, - None, - None, - false, + JoinOptions::default().merge_matching_join_keys(true), )? .filter(col("first_name").eq(lit("hello")))? .select(vec![col("first_name")])? @@ -181,9 +179,7 @@ Project1 --> Limit0 Some(vec![true]), JoinType::Inner, None, - None, - None, - false, + JoinOptions::default().merge_matching_join_keys(true), )? .filter(col("first_name").eq(lit("hello")))? .select(vec![col("first_name")])? diff --git a/src/daft-logical-plan/src/lib.rs b/src/daft-logical-plan/src/lib.rs index 317a92535e..cd3eeb9831 100644 --- a/src/daft-logical-plan/src/lib.rs +++ b/src/daft-logical-plan/src/lib.rs @@ -23,6 +23,7 @@ use common_file_formats::{ }; pub use daft_core::join::{JoinStrategy, JoinType}; pub use logical_plan::{LogicalPlan, LogicalPlanRef}; +pub use ops::join::JoinOptions; pub use partitioning::ClusteringSpec; #[cfg(feature = "python")] use pyo3::prelude::*; @@ -41,6 +42,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_class::()?; parent.add_function(wrap_pyfunction!( builder::py_check_column_name_validity, parent diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index f7ad07737c..e254d9fbd1 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -4,16 +4,15 @@ use std::{ }; use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{prelude::*, utils::supertype::try_get_supertype}; use daft_dsl::{ - col, - join::{get_common_join_keys, infer_join_schema}, - optimization::replace_columns_with_expressions, - Expr, ExprRef, + col, join::infer_join_schema, optimization::replace_columns_with_expressions, Expr, ExprRef, }; +use indexmap::IndexSet; use itertools::Itertools; +#[cfg(feature = "python")] +use pyo3::prelude::*; use snafu::ResultExt; -use uuid::Uuid; use crate::{ logical_plan::{self, CreationSnafu}, @@ -22,7 +21,7 @@ use crate::{ LogicalPlan, LogicalPlanRef, }; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Join { // Upstream nodes. pub left: Arc, @@ -37,21 +36,11 @@ pub struct Join { pub stats_state: StatsState, } -impl std::hash::Hash for Join { - fn hash(&self, state: &mut H) { - std::hash::Hash::hash(&self.left, state); - std::hash::Hash::hash(&self.right, state); - std::hash::Hash::hash(&self.left_on, state); - std::hash::Hash::hash(&self.right_on, state); - std::hash::Hash::hash(&self.null_equals_nulls, state); - std::hash::Hash::hash(&self.join_type, state); - std::hash::Hash::hash(&self.join_strategy, state); - std::hash::Hash::hash(&self.output_schema, state); - } -} - impl Join { - #[allow(clippy::too_many_arguments)] + /// Create a new join node, checking the validity of the inputs and deriving the output schema. + /// + /// Columns that have the same name between left and right are assumed to be merged. + /// If that is not the desired behavior, call `Join::deduplicate_join_keys` before initializing the join node. pub(crate) fn try_new( left: Arc, right: Arc, @@ -61,35 +50,38 @@ impl Join { join_type: JoinType, join_strategy: Option, ) -> logical_plan::Result { - for (on_exprs, side) in [(&left_on, &left), (&right_on, &right)] { - for expr in on_exprs { - // Null type check for both fields and expressions - if matches!(expr.to_field(&side.schema())?.dtype, DataType::Null) { - return Err(DaftError::ValueError(format!( - "Can't join on null type expressions: {expr}" - ))) - .context(CreationSnafu); - } - } + if left_on.len() != right_on.len() { + return Err(DaftError::ValueError(format!( + "Expected length of left_on to match length of right_on for Join, received: {} vs {}", + left_on.len(), + right_on.len() + ))) + .context(CreationSnafu); + } + + for (l, r) in left_on.iter().zip(right_on.iter()) { + let l_dtype = l.to_field(&left.schema())?.dtype; + let r_dtype = r.to_field(&right.schema())?.dtype; + + try_get_supertype(&l_dtype, &r_dtype).map_err(|_| { + DaftError::TypeError( + format!("Expected dtypes of left_on and right_on for Join to have a valid supertype, received: {l_dtype} vs {r_dtype}") + ) + })?; } if let Some(null_equals_null) = &null_equals_nulls { if null_equals_null.len() != left_on.len() { - return Err(DaftError::ValueError( - "null_equals_nulls must have the same length as left_on or right_on" - .to_string(), - )) + return Err(DaftError::ValueError(format!( + "Expected null_equals_nulls to have the same length as left_on or right_on, received: {} vs {}", + null_equals_null.len(), + left_on.len() + ))) .context(CreationSnafu); } } - let output_schema = infer_join_schema( - &left.schema(), - &right.schema(), - &left_on, - &right_on, - join_type, - )?; + let output_schema = infer_join_schema(&left.schema(), &right.schema(), join_type)?; Ok(Self { left, @@ -106,24 +98,37 @@ impl Join { /// Add a project under the right side plan when necessary in order to resolve naming conflicts /// between left and right side columns. - #[allow(clippy::too_many_arguments)] - pub(crate) fn rename_right_columns( + /// + /// Returns: + /// - left (unchanged) + /// - updated right + /// - left_on (unchanged) + /// - updated right_on + pub(crate) fn deduplicate_join_columns( left: LogicalPlanRef, right: LogicalPlanRef, left_on: Vec, right_on: Vec, join_type: JoinType, - join_suffix: Option<&str>, - join_prefix: Option<&str>, - keep_join_keys: bool, - ) -> DaftResult<(LogicalPlanRef, Vec)> { + options: JoinOptions, + ) -> DaftResult<(LogicalPlanRef, LogicalPlanRef, Vec, Vec)> { if matches!(join_type, JoinType::Anti | JoinType::Semi) { - Ok((right, right_on)) + Ok((left, right, left_on, right_on)) } else { - let common_join_keys: HashSet<_> = - get_common_join_keys(left_on.as_slice(), right_on.as_slice()) - .map(|k| k.to_string()) - .collect(); + let merged_cols = if options.merge_matching_join_keys { + left_on + .iter() + .zip(right_on.iter()) + .filter_map(|(l, r)| match (l.as_ref(), r.as_ref()) { + (Expr::Column(l_name), Expr::Column(r_name)) if l_name == r_name => { + Some(l_name.to_string()) + } + _ => None, + }) + .collect() + } else { + IndexSet::new() + }; let left_names = left.schema().names(); let right_names = right.schema().names(); @@ -135,14 +140,13 @@ impl Join { let right_rename_mapping: HashMap<_, _> = right_names .iter() .filter_map(|name| { - if !names_so_far.contains(name) - || (common_join_keys.contains(name) && !keep_join_keys) - { + if !names_so_far.contains(name) || merged_cols.contains(name.as_str()) { + names_so_far.insert(name.clone()); None } else { let mut new_name = name.clone(); while names_so_far.contains(&new_name) { - new_name = match (join_prefix, join_suffix) { + new_name = match (&options.prefix, &options.suffix) { (Some(prefix), Some(suffix)) => { format!("{}{}{}", prefix, new_name, suffix) } @@ -165,7 +169,7 @@ impl Join { .collect(); if right_rename_mapping.is_empty() { - Ok((right, right_on)) + Ok((left, right, left_on, right_on)) } else { // projection to update the right side with the new column names let new_right_projection: Vec<_> = right_names @@ -192,66 +196,11 @@ impl Join { .map(|expr| replace_columns_with_expressions(expr, &right_on_replace_map)) .collect::>(); - Ok((new_right.into(), new_right_on)) + Ok((left, new_right.into(), left_on, new_right_on)) } } } - /// Renames join keys for the given left and right expressions. This is required to - /// prevent errors when the join keys on the left and right expressions have the same key - /// name. - /// - /// This function takes two vectors of expressions (`left_exprs` and `right_exprs`) and - /// checks for pairs of column expressions that differ. If both expressions in a pair - /// are column expressions and they are not identical, it generates a unique identifier - /// and renames both expressions by appending this identifier to their original names. - /// - /// The function returns two vectors of expressions, where the renamed expressions are - /// substituted for the original expressions in the cases where renaming occurred. - /// - /// # Parameters - /// - `left_exprs`: A vector of expressions from the left side of a join. - /// - `right_exprs`: A vector of expressions from the right side of a join. - /// - /// # Returns - /// A tuple containing two vectors of expressions, one for the left side and one for the - /// right side, where expressions that needed to be renamed have been modified. - /// - /// # Example - /// ``` - /// let (renamed_left, renamed_right) = rename_join_keys(left_expressions, right_expressions); - /// ``` - /// - /// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649). - #[allow(dead_code)] - pub(crate) fn rename_join_keys( - left_exprs: Vec>, - right_exprs: Vec>, - ) -> (Vec>, Vec>) { - left_exprs - .into_iter() - .zip(right_exprs) - .map( - |(left_expr, right_expr)| match (&*left_expr, &*right_expr) { - (Expr::Column(left_name), Expr::Column(right_name)) - if left_name == right_name => - { - (left_expr, right_expr) - } - _ => { - let unique_id = Uuid::new_v4().to_string(); - - let renamed_left_expr = - left_expr.alias(format!("{}_{}", left_expr.name(), unique_id)); - let renamed_right_expr = - right_expr.alias(format!("{}_{}", right_expr.name(), unique_id)); - (renamed_left_expr, renamed_right_expr) - } - }, - ) - .unzip() - } - pub(crate) fn with_materialized_stats(mut self) -> Self { // Assume a Primary-key + Foreign-Key join which would yield the max of the two tables. // TODO(desmond): We can do better estimations here. For now, use the old logic. @@ -314,3 +263,52 @@ impl Join { res } } + +#[cfg_attr(feature = "python", pyclass)] +#[derive(Clone, Default)] +pub struct JoinOptions { + pub prefix: Option, + pub suffix: Option, + /// For join predicates in the form col(a) = col(a), + /// merge column "a" from both sides into one column. + pub merge_matching_join_keys: bool, +} + +impl JoinOptions { + pub fn prefix(mut self, val: impl Into) -> Self { + self.prefix = Some(val.into()); + self + } + + pub fn suffix(mut self, val: impl Into) -> Self { + self.suffix = Some(val.into()); + self + } + + pub fn merge_matching_join_keys(mut self, val: bool) -> Self { + self.merge_matching_join_keys = val; + self + } +} + +#[cfg(feature = "python")] +#[pymethods] +impl JoinOptions { + #[new] + #[pyo3(signature = ( + prefix, + suffix, + merge_matching_join_keys, + ))] + pub fn new( + prefix: Option, + suffix: Option, + merge_matching_join_keys: bool, + ) -> Self { + Self { + prefix, + suffix, + merge_matching_join_keys, + } + } +} diff --git a/src/daft-logical-plan/src/ops/mod.rs b/src/daft-logical-plan/src/ops/mod.rs index c042c04e7a..1ec45f3ced 100644 --- a/src/daft-logical-plan/src/ops/mod.rs +++ b/src/daft-logical-plan/src/ops/mod.rs @@ -4,7 +4,7 @@ mod concat; mod distinct; mod explode; mod filter; -mod join; +pub mod join; mod limit; mod monotonically_increasing_id; mod pivot; diff --git a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs index cd192f0df9..2a2d93fa04 100644 --- a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs +++ b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs @@ -430,8 +430,8 @@ mod tests { use super::*; use crate::{ - logical_plan::Source, source_info::PlaceHolderInfo, ClusteringSpec, LogicalPlan, - LogicalPlanBuilder, LogicalPlanRef, SourceInfo, + logical_plan::Source, source_info::PlaceHolderInfo, ClusteringSpec, JoinOptions, + LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, SourceInfo, }; #[fixture] @@ -502,7 +502,7 @@ mod tests { fn eliminate_cross_with_simple_and(t1: LogicalPlanRef, t2: LogicalPlanRef) -> DaftResult<()> { // could eliminate to inner join since filter has Join predicates let plan = LogicalPlanBuilder::from(t1.clone()) - .cross_join(t2.clone(), None, None)? + .cross_join(t2.clone(), Default::default())? .filter(col("a").eq(col("right.a")).and(col("b").eq(col("right.b"))))? .build(); @@ -517,9 +517,7 @@ mod tests { vec![col("right.a"), col("right.b")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .build(); @@ -533,7 +531,7 @@ mod tests { // could not eliminate to inner join since filter OR expression and there is no common // Join predicates in left and right of OR expr. let plan = LogicalPlanBuilder::from(t1.clone()) - .cross_join(t2.clone(), None, None)? + .cross_join(t2.clone(), Default::default())? .filter(col("a").eq(col("right.a")).or(col("right.b").eq(col("a"))))? .build(); @@ -548,9 +546,7 @@ mod tests { vec![], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter(col("a").eq(col("right.a")).or(col("right.b").eq(col("a"))))? .build(); @@ -568,7 +564,7 @@ mod tests { let expr4 = col("right.c").eq(lit(10u32)); // could eliminate to inner join let plan = LogicalPlanBuilder::from(t1.clone()) - .cross_join(t2.clone(), None, None)? + .cross_join(t2.clone(), Default::default())? .filter(expr1.and(expr2.clone()).and(expr3).and(expr4.clone()))? .build(); @@ -583,9 +579,7 @@ mod tests { vec![col("right.a")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter(expr2.and(expr4))? .build(); @@ -603,7 +597,7 @@ mod tests { let expr3 = col("a").eq(col("right.a")); let expr4 = col("right.c").eq(lit(688u32)); let plan = LogicalPlanBuilder::from(t1.clone()) - .cross_join(t2.clone(), None, None)? + .cross_join(t2.clone(), Default::default())? .filter(expr1.and(expr2.clone()).or(expr3.and(expr4.clone())))? .build(); @@ -618,9 +612,7 @@ mod tests { vec![col("right.a")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter(expr2.or(expr4))? .build(); @@ -639,7 +631,7 @@ mod tests { ) -> DaftResult<()> { // could eliminate to inner join let plan1 = LogicalPlanBuilder::from(t1.clone()) - .cross_join(t2.clone(), None, Some("t2."))? + .cross_join(t2.clone(), JoinOptions::default().prefix("t2."))? .filter( col("a") .eq(col("t2.a")) @@ -649,7 +641,7 @@ mod tests { .build(); let plan2 = LogicalPlanBuilder::from(t3.clone()) - .cross_join(t4.clone(), None, Some("t4."))? + .cross_join(t4.clone(), JoinOptions::default().prefix("t4."))? .filter( (col("a") .eq(col("t4.a")) @@ -660,7 +652,7 @@ mod tests { .build(); let plan = LogicalPlanBuilder::from(plan1.clone()) - .cross_join(plan2.clone(), None, Some("t3."))? + .cross_join(plan2.clone(), JoinOptions::default().prefix("t3."))? .filter( col("t3.a") .eq(col("a")) @@ -679,9 +671,7 @@ mod tests { vec![col("t2.a")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter(col("t2.c").lt(lit(15u32)).or(col("t2.c").eq(lit(688u32))))? .build(); @@ -697,9 +687,7 @@ mod tests { vec![col("t4.a")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter( col("t4.c") @@ -723,9 +711,7 @@ mod tests { vec![col("t3.a")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter(col("t4.c").lt(lit(15u32)).or(col("t4.c").eq(lit(688u32))))? .build(); diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs index 80aab76c18..449aebff81 100644 --- a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs +++ b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -222,9 +222,7 @@ mod tests { vec![col("c")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .build(); @@ -237,9 +235,7 @@ mod tests { vec![col("c")], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .build(); @@ -270,9 +266,7 @@ mod tests { Some(vec![false, true, false]), JoinType::Left, None, - None, - None, - false, + Default::default(), )? .build(); @@ -287,9 +281,7 @@ mod tests { Some(vec![false, true, false]), JoinType::Left, None, - None, - None, - false, + Default::default(), )? .build(); diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index e10e14efb9..621c831f93 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -699,9 +699,7 @@ mod tests { null_equals_nulls.clone(), how, None, - None, - None, - false, + Default::default(), )? .filter(pred.clone())? .build(); @@ -721,9 +719,7 @@ mod tests { null_equals_nulls, how, None, - None, - None, - false, + Default::default(), )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -765,9 +761,7 @@ mod tests { null_equals_nulls.clone(), how, None, - None, - None, - false, + Default::default(), )? .filter(pred.clone())? .build(); @@ -787,9 +781,7 @@ mod tests { null_equals_nulls, how, None, - None, - None, - false, + Default::default(), )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -812,6 +804,8 @@ mod tests { )] how: JoinType, ) -> DaftResult<()> { + use crate::JoinOptions; + let left_scan_op = dummy_scan_operator(vec![ Field::new("a", DataType::Utf8), Field::new("b", DataType::Int64), @@ -844,9 +838,7 @@ mod tests { null_equals_nulls.clone(), how, None, - None, - None, - false, + JoinOptions::default().merge_matching_join_keys(true), )? .filter(pred.clone())? .build(); @@ -874,9 +866,7 @@ mod tests { null_equals_nulls, how, None, - None, - None, - false, + JoinOptions::default().merge_matching_join_keys(true), )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -914,9 +904,7 @@ mod tests { null_equals_nulls, how, None, - None, - None, - false, + Default::default(), )? .filter(pred)? .build(); @@ -957,9 +945,7 @@ mod tests { null_equals_nulls, how, None, - None, - None, - false, + Default::default(), )? .filter(pred)? .build(); @@ -983,9 +969,7 @@ mod tests { vec![], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter( (col("a").eq(lit("FRANCE")).and(col("b").eq(lit("GERMANY")))) @@ -1010,9 +994,7 @@ mod tests { vec![], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter( (col("a").eq(lit("FRANCE")).or(col("b").eq(lit("FRANCE")))) diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs index 18e05a8218..b4ea7a165d 100644 --- a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs +++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs @@ -126,16 +126,15 @@ impl UnnestScalarSubquery { JoinType::Left }; - let (decorrelated_subquery, subquery_on) = Join::rename_right_columns( - curr_input.clone(), - decorrelated_subquery, - input_on.clone(), - subquery_on, - join_type, - None, - None, - false, - )?; + let (curr_input, decorrelated_subquery, input_on, subquery_on) = + Join::deduplicate_join_columns( + curr_input, + decorrelated_subquery, + input_on, + subquery_on, + join_type, + Default::default(), + )?; Ok(Arc::new(LogicalPlan::Join(Join::try_new( curr_input, @@ -618,9 +617,7 @@ mod tests { vec![], JoinType::Inner, None, - None, - None, - false, + Default::default(), )? .filter(col("key").eq(col(subquery_alias)))? .select(vec![col("key"), col("val")])? @@ -673,9 +670,7 @@ mod tests { vec![col("inner_key")], JoinType::Left, None, - None, - None, - false, + Default::default(), )? .filter(col("outer_key").eq(col(subquery_alias)))? .select(vec![col("outer_key"), col("inner_key"), col("val")])? @@ -713,9 +708,7 @@ mod tests { vec![col("key")], JoinType::Semi, None, - None, - None, - false, + Default::default(), )? .select(vec![col("val")])? .build(); @@ -757,9 +750,7 @@ mod tests { vec![col("key")], JoinType::Anti, None, - None, - None, - false, + Default::default(), )? .select(vec![col("val")])? .build(); diff --git a/src/daft-logical-plan/src/treenode.rs b/src/daft-logical-plan/src/treenode.rs index 151fe7eb4a..a54ff46e29 100644 --- a/src/daft-logical-plan/src/treenode.rs +++ b/src/daft-logical-plan/src/treenode.rs @@ -152,8 +152,7 @@ impl LogicalPlan { null_equals_nulls, join_type, join_strategy, - output_schema, - stats_state, + .. }) => { let o = left_on .into_iter() @@ -164,7 +163,7 @@ impl LogicalPlan { let (left_on, right_on) = o.data.into_iter().unzip(); if o.transformed { - Transformed::yes(Self::Join(Join { + Transformed::yes(Self::Join(Join::try_new( left, right, left_on, @@ -172,11 +171,9 @@ impl LogicalPlan { null_equals_nulls, join_type, join_strategy, - output_schema, - stats_state, - })) + )?)) } else { - Transformed::no(Self::Join(Join { + Transformed::no(Self::Join(Join::try_new( left, right, left_on, @@ -184,9 +181,7 @@ impl LogicalPlan { null_equals_nulls, join_type, join_strategy, - output_schema, - stats_state, - })) + )?)) } } lp => Transformed::no(lp), diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index 1b8e7f207b..7438706a42 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ b/src/daft-micropartition/src/ops/join.rs @@ -25,7 +25,7 @@ impl MicroPartition { where F: FnOnce(&Table, &Table, &[ExprRef], &[ExprRef], JoinType) -> DaftResult, { - let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on, how)?; + let join_schema = infer_join_schema(&self.schema, &right.schema, how)?; match (how, self.len(), right.len()) { (JoinType::Inner | JoinType::Left | JoinType::Semi, 0, _) | (JoinType::Inner | JoinType::Right, _, 0) diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index c9c0c5f24c..13b660b698 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -10,8 +10,8 @@ use common_file_formats::FileFormat; use common_scan_info::{PhysicalScanInfo, ScanState, SPLIT_AND_MERGE_PASS}; use daft_core::{join::JoinSide, prelude::*}; use daft_dsl::{ - col, estimated_selectivity, functions::agg::merge_mean, is_partition_compatible, AggExpr, - ApproxPercentileParams, Expr, ExprRef, SketchType, + col, estimated_selectivity, functions::agg::merge_mean, is_partition_compatible, + join::normalize_join_keys, AggExpr, ApproxPercentileParams, Expr, ExprRef, SketchType, }; use daft_functions::{list::unique_count, numeric::sqrt}; use daft_logical_plan::{ @@ -435,6 +435,13 @@ pub(super) fn translate_single_logical_node( let mut right_physical = physical_children.pop().expect("requires 1 inputs"); let mut left_physical = physical_children.pop().expect("requires 2 inputs"); + let (left_on, right_on) = normalize_join_keys( + left_on.clone(), + right_on.clone(), + left.schema(), + right.schema(), + )?; + let left_clustering_spec = left_physical.clustering_spec(); let right_clustering_spec = right_physical.clustering_spec(); let num_partitions = max( @@ -444,10 +451,10 @@ pub(super) fn translate_single_logical_node( let is_left_hash_partitioned = matches!(left_clustering_spec.as_ref(), ClusteringSpec::Hash(..)) - && is_partition_compatible(&left_clustering_spec.partition_by(), left_on); + && is_partition_compatible(&left_clustering_spec.partition_by(), &left_on); let is_right_hash_partitioned = matches!(right_clustering_spec.as_ref(), ClusteringSpec::Hash(..)) - && is_partition_compatible(&right_clustering_spec.partition_by(), right_on); + && is_partition_compatible(&right_clustering_spec.partition_by(), &right_on); // Left-side of join is considered to be sort-partitioned on the join key if it is sort-partitioned on a // sequence of expressions that has the join key as a prefix. @@ -530,8 +537,8 @@ pub(super) fn translate_single_logical_node( // TODO(Kevin): Support sort-merge join for other types of joins. // TODO(advancedxy): Rewrite null safe equals to support SMJ } else if *join_type == JoinType::Inner - && keys_are_primitive(left_on, &left.schema()) - && keys_are_primitive(right_on, &right.schema()) + && keys_are_primitive(&left_on, &left.schema()) + && keys_are_primitive(&right_on, &right.schema()) && (is_left_sort_partitioned || is_right_sort_partitioned) && (!is_larger_partitioned || (left_is_larger && is_left_sort_partitioned @@ -566,8 +573,8 @@ pub(super) fn translate_single_logical_node( Ok(PhysicalPlan::BroadcastJoin(BroadcastJoin::new( left_physical, right_physical, - left_on.clone(), - right_on.clone(), + left_on, + right_on, null_equals_nulls.clone(), *join_type, is_swapped, @@ -619,8 +626,8 @@ pub(super) fn translate_single_logical_node( Ok(PhysicalPlan::SortMergeJoin(SortMergeJoin::new( left_physical, right_physical, - left_on.clone(), - right_on.clone(), + left_on, + right_on, *join_type, num_partitions, left_is_larger, @@ -681,8 +688,8 @@ pub(super) fn translate_single_logical_node( Ok(PhysicalPlan::HashJoin(HashJoin::new( left_physical, right_physical, - left_on.clone(), - right_on.clone(), + left_on, + right_on, null_equals_nulls.clone(), *join_type, )) @@ -1345,9 +1352,7 @@ mod tests { vec![col("a"), col("b")], JoinType::Inner, Some(JoinStrategy::Hash), - None, - None, - false, + Default::default(), )? .build(); logical_to_physical(logical_plan, cfg) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index dec033a709..ad84220ffa 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -32,8 +32,8 @@ mod tests { use daft_core::prelude::*; use daft_dsl::{col, lit, Expr, OuterReferenceColumn, Subquery}; use daft_logical_plan::{ - logical_plan::Source, source_info::PlaceHolderInfo, ClusteringSpec, LogicalPlan, - LogicalPlanBuilder, LogicalPlanRef, SourceInfo, + logical_plan::Source, source_info::PlaceHolderInfo, ClusteringSpec, JoinOptions, + LogicalPlan, LogicalPlanBuilder, LogicalPlanRef, SourceInfo, }; use error::SQLPlannerResult; use rstest::{fixture, rstest}; @@ -306,9 +306,7 @@ mod tests { Some(vec![null_equals_null]), JoinType::Inner, None, - None, - Some("tbl3."), - true, + JoinOptions::default().prefix("tbl3."), )? .select(vec![col("*")])? .build(); @@ -334,9 +332,7 @@ mod tests { Some(vec![false]), JoinType::Inner, None, - None, - Some("tbl3."), - true, + JoinOptions::default().prefix("tbl3."), )? .select(vec![col("*")])? .build(); diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 8e16dcc414..05f34a365e 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -20,7 +20,7 @@ use daft_functions::{ numeric::{ceil::ceil, floor::floor}, utf8::{ilike, like, to_date, to_datetime}, }; -use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; +use daft_logical_plan::{JoinOptions, LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ self, ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct, @@ -759,11 +759,12 @@ impl<'a> SQLPlanner<'a> { for tbl in from_iter { let right = self.plan_relation(&tbl.relation)?; self.table_map.insert(right.get_name(), right.clone()); - let right_join_prefix = Some(format!("{}.", right.get_name())); + let right_join_prefix = format!("{}.", right.get_name()); - rel.inner = - rel.inner - .cross_join(right.inner, None, right_join_prefix.as_deref())?; + rel.inner = rel.inner.cross_join( + right.inner, + JoinOptions::default().prefix(right_join_prefix), + )?; } self.current_relation = Some(rel); return Ok(()); @@ -895,7 +896,7 @@ impl<'a> SQLPlanner<'a> { }; let right_rel = self.plan_relation(&join.relation)?; let right_rel_name = right_rel.get_name(); - let right_join_prefix = Some(format!("{right_rel_name}.")); + let right_join_prefix = format!("{right_rel_name}."); // construct a planner with the right table to use for expr planning let mut right_planner = self.new_with_context(); @@ -920,7 +921,7 @@ impl<'a> SQLPlanner<'a> { let mut left_filters = Vec::new(); let mut right_filters = Vec::new(); - let (keep_join_keys, null_eq_nulls) = match &constraint { + let (merge_matching_join_keys, null_eq_nulls) = match &constraint { JoinConstraint::On(expr) => { let mut null_eq_nulls = Vec::new(); @@ -935,7 +936,7 @@ impl<'a> SQLPlanner<'a> { &mut right_filters, )?; - (true, Some(null_eq_nulls)) + (false, Some(null_eq_nulls)) } JoinConstraint::Using(idents) => { left_on = idents @@ -944,7 +945,7 @@ impl<'a> SQLPlanner<'a> { .collect::>(); right_on.clone_from(&left_on); - (false, None) + (true, None) } JoinConstraint::Natural => unsupported_sql_err!("NATURAL JOIN not supported"), JoinConstraint::None => unsupported_sql_err!("JOIN without ON/USING not supported"), @@ -967,9 +968,9 @@ impl<'a> SQLPlanner<'a> { null_eq_nulls, join_type, None, - None, - right_join_prefix.as_deref(), - keep_join_keys, + JoinOptions::default() + .prefix(right_join_prefix) + .merge_matching_join_keys(merge_matching_join_keys), )?; self.table_map.insert(right_rel_name, right_rel); } diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs index 3a2dd08a52..dccbe574a7 100644 --- a/src/daft-table/src/ops/joins/hash_join.rs +++ b/src/daft-table/src/ops/joins/hash_join.rs @@ -1,4 +1,4 @@ -use std::{cmp, iter::repeat, sync::Arc}; +use std::{cmp, iter::repeat, ops::Not, sync::Arc}; use arrow2::{bitmap::MutableBitmap, types::IndexRange}; use common_error::DaftResult; @@ -7,7 +7,7 @@ use daft_core::{ prelude::*, }; use daft_dsl::{ - join::{get_common_join_keys, infer_join_schema}, + join::{get_common_join_cols, infer_join_schema}, ExprRef, }; @@ -20,13 +20,7 @@ pub(super) fn hash_inner_join( right_on: &[ExprRef], null_equals_nulls: &[bool], ) -> DaftResult
{ - let join_schema = infer_join_schema( - &left.schema, - &right.schema, - left_on, - right_on, - JoinType::Inner, - )?; + let join_schema = infer_join_schema(&left.schema, &right.schema, JoinType::Inner)?; let lkeys = left.eval_expression_list(left_on)?; let rkeys = right.eval_expression_list(right_on)?; @@ -87,14 +81,13 @@ pub(super) fn hash_inner_join( } }; - let common_join_keys: Vec<_> = get_common_join_keys(left_on, right_on).collect(); + let common_cols: Vec<_> = get_common_join_cols(&left.schema, &right.schema).collect(); - let join_series = left - .get_columns(common_join_keys.as_slice())? - .take(&lidx)? - .columns; - - let mut join_series = Arc::unwrap_or_clone(join_series); + let mut join_series = Arc::unwrap_or_clone( + left.get_columns(common_cols.as_slice())? + .take(&lidx)? + .columns, + ); drop(lkeys); drop(rkeys); @@ -113,13 +106,7 @@ pub(super) fn hash_left_right_join( null_equals_nulls: &[bool], left_side: bool, ) -> DaftResult
{ - let join_schema = infer_join_schema( - &left.schema, - &right.schema, - left_on, - right_on, - JoinType::Right, - )?; + let join_schema = infer_join_schema(&left.schema, &right.schema, JoinType::Right)?; let lkeys = left.eval_expression_list(left_on)?; let rkeys = right.eval_expression_list(right_on)?; @@ -197,23 +184,21 @@ pub(super) fn hash_left_right_join( (lkeys, rkeys, lidx, ridx) }; - let common_join_keys = get_common_join_keys(left_on, right_on); + let common_cols = get_common_join_cols(&left.schema, &right.schema).collect::>(); - let mut join_series = if left_side { - Arc::unwrap_or_clone( - left.get_columns(common_join_keys.collect::>().as_slice())? - .take(&lidx)? - .columns, - ) + let (common_cols_tbl, common_cols_idx) = if left_side { + (left, &lidx) } else { - common_join_keys - .map(|name| { - let col_dtype = &left.schema.get_field(name)?.dtype; - right.get_column(name)?.take(&ridx)?.cast(col_dtype) - }) - .collect::>>()? + (right, &ridx) }; + let mut join_series = Arc::unwrap_or_clone( + common_cols_tbl + .get_columns(&common_cols)? + .take(common_cols_idx)? + .columns, + ); + drop(lkeys); drop(rkeys); @@ -291,13 +276,7 @@ pub(super) fn hash_outer_join( right_on: &[ExprRef], null_equals_nulls: &[bool], ) -> DaftResult
{ - let join_schema = infer_join_schema( - &left.schema, - &right.schema, - left_on, - right_on, - JoinType::Outer, - )?; + let join_schema = infer_join_schema(&left.schema, &right.schema, JoinType::Outer)?; let lkeys = left.eval_expression_list(left_on)?; let rkeys = right.eval_expression_list(right_on)?; @@ -405,33 +384,21 @@ pub(super) fn hash_outer_join( } }; - let common_join_keys: Vec<_> = get_common_join_keys(left_on, right_on).collect(); + let common_cols: Vec<_> = get_common_join_cols(&left.schema, &right.schema).collect(); - let mut join_series = if common_join_keys.is_empty() { + let mut join_series = if common_cols.is_empty() { vec![] } else { - let join_key_predicate = BooleanArray::from(( - "join_key_predicate", - arrow2::array::BooleanArray::from_trusted_len_values_iter( - lidx.u64()? - .into_iter() - .zip(ridx.u64()?) - .map(|(l, r)| match (l, r) { - (Some(_), _) => true, - (None, Some(_)) => false, - (None, None) => unreachable!("Join should not have None for both sides"), - }), - ), - )) - .into_series(); - - common_join_keys + // use right side value if left is null + let take_from_left = lidx.is_null()?.not()?; + + common_cols .into_iter() .map(|name| { let lcol = left.get_column(name)?.take(&lidx)?; let rcol = right.get_column(name)?.take(&ridx)?; - lcol.if_else(&rcol, &join_key_predicate) + lcol.if_else(&rcol, &take_from_left) }) .collect::>>()? }; diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index 5eb22b7cbd..8ffc86b847 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -5,7 +5,7 @@ use daft_core::{ array::growable::make_growable, join::JoinSide, prelude::*, utils::supertype::try_get_supertype, }; use daft_dsl::{ - join::{get_common_join_keys, infer_join_schema}, + join::{get_common_join_cols, infer_join_schema}, ExprRef, }; use hash_join::hash_semi_anti_join; @@ -157,20 +157,14 @@ impl Table { return left.sort_merge_join(&right, left_on, right_on, true); } - let join_schema = infer_join_schema( - &self.schema, - &right.schema, - left_on, - right_on, - JoinType::Inner, - )?; + let join_schema = infer_join_schema(&self.schema, &right.schema, JoinType::Inner)?; let ltable = self.eval_expression_list(left_on)?; let rtable = right.eval_expression_list(right_on)?; let (ltable, rtable) = match_types_for_tables(<able, &rtable)?; let (lidx, ridx) = merge_join::merge_inner_join(<able, &rtable)?; - let mut join_series = get_common_join_keys(left_on, right_on) + let mut join_series = get_common_join_cols(&self.schema, &right.schema) .map(|name| { let lcol = self.get_column(name)?; let rcol = right.get_column(name)?; diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 5f99c331fd..abf5d42636 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -6,7 +6,6 @@ import daft from daft import col from daft.datatype import DataType -from daft.errors import ExpressionTypeError from tests.conftest import get_tests_daft_runner_name from tests.utils import sort_arrow_table @@ -725,8 +724,23 @@ def test_join_all_null(join_strategy, join_type, expected, make_df, repartition_ [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True, ) -@pytest.mark.parametrize("join_type", ["inner", "left", "right", "outer"]) -def test_join_null_type_column(join_strategy, join_type, make_df, with_morsel_size): +@pytest.mark.parametrize( + "join_type,expected", + [ + ("inner", {"id": [], "values_left": [], "values_right": []}), + ("left", {"id": [None, None, None], "values_left": ["a1", "b1", "c1"], "values_right": [None, None, None]}), + ("right", {"id": [None, None, None], "values_left": [None, None, None], "values_right": ["a2", "b2", "c2"]}), + ( + "outer", + { + "id": [None, None, None, None, None, None], + "values_left": ["a1", "b1", "c1", None, None, None], + "values_right": [None, None, None, "a2", "b2", "c2"], + }, + ), + ], +) +def test_join_null_type_column(join_strategy, join_type, expected, make_df, with_morsel_size): skip_invalid_join_strategies(join_strategy, join_type) daft_df = make_df( @@ -742,8 +756,10 @@ def test_join_null_type_column(join_strategy, join_type, make_df, with_morsel_si } ) - with pytest.raises((ExpressionTypeError, ValueError)): - daft_df.join(daft_df2, on="id", how=join_type, strategy=join_strategy) + daft_df = daft_df.join(daft_df2, on="id", how=join_type, strategy=join_strategy).sort( + ["values_left", "values_right"] + ) + assert pa.Table.from_pydict(daft_df.to_pydict()) == pa.Table.from_pydict(expected) @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) @@ -860,8 +876,44 @@ def test_join_semi_anti_different_names( ) -@pytest.mark.parametrize("join_type", ["inner", "left", "right", "outer"]) -def test_join_true_join_keys(join_type, make_df, with_morsel_size): +@pytest.mark.parametrize( + "join_type,expected_dtypes", + [ + ( + "inner", + { + "id": DataType.int64(), + "values": DataType.string(), + "right.values": DataType.string(), + }, + ), + ( + "left", + { + "id": DataType.int64(), + "values": DataType.string(), + "right.values": DataType.string(), + }, + ), + ( + "right", + { + "id": DataType.float64(), + "values": DataType.string(), + "right.values": DataType.string(), + }, + ), + ( + "outer", + { + "id": DataType.float64(), + "values": DataType.string(), + "right.values": DataType.string(), + }, + ), + ], +) +def test_join_true_join_keys(join_type, expected_dtypes, make_df, with_morsel_size): daft_df = make_df( { "id": [1, 2, 3], @@ -878,9 +930,9 @@ def test_join_true_join_keys(join_type, make_df, with_morsel_size): result = daft_df.join(daft_df2, left_on=["id", "values"], right_on=["id", col("values").str.left(1)], how=join_type) assert result.schema().column_names() == ["id", "values", "right.values"] - assert result.schema()["id"].dtype == daft_df.schema()["id"].dtype - assert result.schema()["values"].dtype == daft_df.schema()["values"].dtype - assert result.schema()["right.values"].dtype == daft_df2.schema()["values"].dtype + assert result.schema()["id"].dtype == expected_dtypes["id"] + assert result.schema()["values"].dtype == expected_dtypes["values"] + assert result.schema()["right.values"].dtype == expected_dtypes["right.values"] @pytest.mark.parametrize( @@ -1208,30 +1260,12 @@ def test_cross_join(left_partitions, right_partitions, make_df, with_morsel_size ], ) def test_join_empty(join_type, repartition_nparts, left, right, expected, make_df, with_morsel_size): - left = pa.Table.from_pydict( - left, - schema=pa.schema( - [ - ("a", pa.int32()), - ("b", pa.string()), - ] - ), - ) left_df = make_df( left, repartition=repartition_nparts, repartition_columns=["a"], ) - right = pa.Table.from_pydict( - right, - schema=pa.schema( - [ - ("c", pa.int32()), - ("d", pa.string()), - ] - ), - ) right_df = make_df( right, repartition=repartition_nparts, @@ -1246,11 +1280,79 @@ def test_join_empty(join_type, repartition_nparts, left, right, expected, make_d right_on = ["c"] result = left_df.join(right_df, left_on=left_on, right_on=right_on, how=join_type) - if join_type in ["inner", "left", "right", "outer", "cross"]: - result = result.sort(["a", "b", "c", "d"]) - else: - result = result.sort(["a", "b"]) - expected_result = expected[join_type] + sort_by = ["a", "b", "c", "d"] if join_type in ["inner", "left", "right", "outer", "cross"] else ["a", "b"] + + assert sort_arrow_table(pa.Table.from_pydict(result.to_pydict()), *sort_by) == sort_arrow_table( + pa.Table.from_pydict(expected[join_type]), *sort_by + ) - assert result.to_pydict() == expected_result + +@pytest.mark.parametrize( + "join_type,expected", + [ + ("inner", {"a": [1, 1, 2], "b": ["a", "a", "b"], "c": [1.0, 1.0, 2.0], "d": ["g", "h", "j"]}), + ( + "left", + { + "a": [1, 1, 2, 3, 4, 5, 6], + "b": ["a", "a", "b", "c", "d", "e", "f"], + "c": [1.0, 1.0, 2.0, None, None, None, None], + "d": ["g", "h", "j", None, None, None, None], + }, + ), + ( + "right", + { + "a": [1, 1, 2, None, None, None], + "b": ["a", "a", "b", None, None, None], + "c": [1.0, 1.0, 2.0, 1.5, 2.1, 2.5], + "d": ["g", "h", "j", "i", "k", "l"], + }, + ), + ( + "outer", + { + "a": [1, 1, 2, 3, 4, 5, 6, None, None, None], + "b": ["a", "a", "b", "c", "d", "e", "f", None, None, None], + "c": [1.0, 1.0, 2.0, None, None, None, None, 1.5, 2.1, 2.5], + "d": ["g", "h", "j", None, None, None, None, "i", "k", "l"], + }, + ), + ( + "anti", + { + "a": [3, 4, 5, 6], + "b": ["c", "d", "e", "f"], + }, + ), + ( + "semi", + { + "a": [1, 2], + "b": ["a", "b"], + }, + ), + ], +) +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +def test_join_different_join_key_types(join_type, expected, repartition_nparts, make_df, with_morsel_size): + left_df = make_df( + {"a": [1, 2, 3, 4, 5, 6], "b": ["a", "b", "c", "d", "e", "f"]}, + repartition=repartition_nparts, + repartition_columns=["a"], + ) + + right_df = make_df( + {"c": [1.0, 1.0, 1.5, 2.0, 2.1, 2.5], "d": ["g", "h", "i", "j", "k", "l"]}, + repartition=repartition_nparts, + repartition_columns=["c"], + ) + + result = left_df.join(right_df, left_on="a", right_on="c", how=join_type) + + sort_by = ["a", "b", "c", "d"] if join_type in ["inner", "left", "right", "outer"] else ["a", "b"] + + assert sort_arrow_table(pa.Table.from_pydict(result.to_pydict()), *sort_by) == sort_arrow_table( + pa.Table.from_pydict(expected), *sort_by + ) diff --git a/tests/utils.py b/tests/utils.py index 0d8ffebad0..836758c0bd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,7 +3,6 @@ import re import pyarrow as pa -import pyarrow.compute as pac from daft.table import Table @@ -13,10 +12,8 @@ TH_STYLE = 'style="text-wrap: nowrap; max-width:192px; overflow:auto; text-align:left"' -def sort_arrow_table(tbl: pa.Table, sort_by: str): - """In arrow versions < 7, pa.Table does not support sorting yet so we add a helper method here.""" - sort_indices = pac.sort_indices(tbl.column(sort_by)) - return pac.take(tbl, sort_indices) +def sort_arrow_table(tbl: pa.Table, *sort_by: str): + return tbl.sort_by([(name, "descending") for name in sort_by]) def assert_pyarrow_tables_equal(from_daft: pa.Table, expected: pa.Table):