diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index 4a8d9d2cfc..b3600292e9 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -443,6 +443,7 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + keep_join_keys: bool, ) -> DaftResult { self.join_with_null_safe_equal( right, @@ -453,6 +454,7 @@ impl LogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + keep_join_keys, ) } @@ -467,6 +469,7 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + keep_join_keys: bool, ) -> DaftResult { let logical_plan: LogicalPlan = ops::Join::try_new( self.plan.clone(), @@ -478,6 +481,7 @@ impl LogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + keep_join_keys, )? .into(); Ok(self.with_new_plan(logical_plan)) @@ -497,6 +501,7 @@ impl LogicalPlanBuilder { None, join_suffix, join_prefix, + false, // no join keys to keep ) } @@ -937,6 +942,7 @@ impl PyLogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + false, // dataframes do not keep the join keys when joining )? .into()) } diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index 162a5b34f2..88ba77787a 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -111,6 +111,7 @@ mod test { None, None, None, + false, )? .filter(col("first_name").eq(lit("hello")))? .select(vec![col("first_name")])? @@ -185,6 +186,7 @@ Project1 --> Limit0 None, None, None, + false, )? .filter(col("first_name").eq(lit("hello")))? .select(vec![col("first_name")])? diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 6c5bce7c66..3ed901cbdc 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -273,7 +273,8 @@ impl LogicalPlan { *join_type, *join_strategy, None, // The suffix is already eagerly computed in the constructor - None // the prefix is already eagerly computed in the constructor + None, // the prefix is already eagerly computed in the constructor + false // this is already eagerly computed in the constructor ).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index c0b7496b8c..3d485f8997 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -60,6 +60,11 @@ impl Join { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + // if true, then duplicate column names will be kept + // ex: select * from a left join b on a.id = b.id + // if true, then the resulting schema will have two columns named id (id, and b.id) + // In SQL the join column is always kept, while in dataframes it is not + keep_join_keys: bool, ) -> logical_plan::Result { let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; let (right_on, _) = @@ -136,19 +141,27 @@ impl Join { let right_rename_mapping: HashMap<_, _> = right_names .iter() .filter_map(|name| { - if !names_so_far.contains(name) || common_join_keys.contains(name) { + if !names_so_far.contains(name) + || (common_join_keys.contains(name) && !keep_join_keys) + { None } else { let mut new_name = name.clone(); while names_so_far.contains(&new_name) { - if let Some(prefix) = join_prefix { - new_name = format!("{}{}", prefix, new_name); - } else if join_suffix.is_none() { - new_name = format!("right.{}", new_name); - } - if let Some(suffix) = join_suffix { - new_name = format!("{}{}", new_name, suffix); - } + new_name = match (join_prefix, join_suffix) { + (Some(prefix), Some(suffix)) => { + format!("{}{}{}", prefix, new_name, suffix) + } + (Some(prefix), None) => { + format!("{}{}", prefix, new_name) + } + (None, Some(suffix)) => { + format!("{}{}", new_name, suffix) + } + (None, None) => { + format!("right.{}", new_name) + } + }; } names_so_far.insert(new_name.clone()); @@ -253,6 +266,7 @@ impl Join { } _ => { let unique_id = Uuid::new_v4().to_string(); + let renamed_left_expr = left_expr.alias(format!("{}_{}", left_expr.name(), unique_id)); let renamed_right_expr = diff --git a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs index 6292dfe1d9..c8e888fecf 100644 --- a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs +++ b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs @@ -524,6 +524,7 @@ mod tests { None, None, None, + false, )? .build(); @@ -554,6 +555,7 @@ mod tests { None, None, None, + false, )? .filter(col("a").eq(col("right.a")).or(col("right.b").eq(col("a"))))? .build(); @@ -588,6 +590,7 @@ mod tests { None, None, None, + false, )? .filter(expr2.and(expr4))? .build(); @@ -622,6 +625,7 @@ mod tests { None, None, None, + false, )? .filter(expr2.or(expr4))? .build(); @@ -682,6 +686,7 @@ mod tests { None, None, None, + false, )? .filter(col("t2.c").lt(lit(15u32)).or(col("t2.c").eq(lit(688u32))))? .build(); @@ -699,6 +704,7 @@ mod tests { None, None, None, + false, )? .filter( col("t4.c") @@ -724,6 +730,7 @@ mod tests { None, None, None, + false, )? .filter(col("t4.c").lt(lit(15u32)).or(col("t4.c").eq(lit(688u32))))? .build(); diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 0d66fea700..ac20b0b38d 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -683,6 +683,7 @@ mod tests { None, None, None, + false, )? .filter(pred.clone())? .build(); @@ -704,6 +705,7 @@ mod tests { None, None, None, + false, )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -747,6 +749,7 @@ mod tests { None, None, None, + false, )? .filter(pred.clone())? .build(); @@ -768,6 +771,7 @@ mod tests { None, None, None, + false, )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -824,6 +828,7 @@ mod tests { None, None, None, + false, )? .filter(pred.clone())? .build(); @@ -853,6 +858,7 @@ mod tests { None, None, None, + false, )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -892,6 +898,7 @@ mod tests { None, None, None, + false, )? .filter(pred)? .build(); @@ -934,6 +941,7 @@ mod tests { None, None, None, + false, )? .filter(pred)? .build(); diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index 044a56f164..76fd37d55f 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -1210,6 +1210,7 @@ mod tests { Some(JoinStrategy::Hash), None, None, + false, )? .build(); logical_to_physical(logical_plan, cfg) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index ff955b60fd..fcd348f02c 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -271,7 +271,8 @@ mod tests { JoinType::Inner, None, None, - None, + Some("tbl3."), + true, )? .select(vec![col("*")])? .build(); diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index fdecb669bb..4613f7e139 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -204,15 +204,16 @@ impl SQLPlanner { let selection = match query.body.as_ref() { SetExpr::Select(selection) => selection, SetExpr::Query(_) => unsupported_sql_err!("Subqueries are not supported"), - SetExpr::SetOperation { .. } => { - unsupported_sql_err!("Set operations are not supported") + SetExpr::SetOperation { + op, set_quantifier, .. + } => { + unsupported_sql_err!("{op} {set_quantifier} is not supported.",) } SetExpr::Values(..) => unsupported_sql_err!("VALUES are not supported"), SetExpr::Insert(..) => unsupported_sql_err!("INSERT is not supported"), SetExpr::Update(..) => unsupported_sql_err!("UPDATE is not supported"), SetExpr::Table(..) => unsupported_sql_err!("TABLE is not supported"), }; - check_select_features(selection)?; if let Some(with) = &query.with { @@ -606,15 +607,9 @@ impl SQLPlanner { self.table_map.insert(right.get_name(), right.clone()); let right_join_prefix = Some(format!("{}.", right.get_name())); - rel.inner = rel.inner.join( - right.inner, - vec![], - vec![], - JoinType::Inner, - None, - None, - right_join_prefix.as_deref(), - )?; + rel.inner = + rel.inner + .cross_join(right.inner, None, right_join_prefix.as_deref())?; } return Ok(rel); } @@ -658,8 +653,33 @@ impl SQLPlanner { let null_equals_null = *op == BinaryOperator::Spaceship; collect_compound_identifiers(left, right, left_rel, right_rel) .map(|(left, right)| (left, right, vec![null_equals_null])) + } else if let ( + sqlparser::ast::Expr::Identifier(left), + sqlparser::ast::Expr::Identifier(right), + ) = (left.as_ref(), right.as_ref()) + { + let left = ident_to_str(left); + let right = ident_to_str(right); + + // we don't know which table the identifiers belong to, so we need to check both + let left_schema = left_rel.schema(); + let right_schema = right_rel.schema(); + + // if the left side is in the left schema, then we assume the right side is in the right schema + let (left_on, right_on) = if left_schema.get_field(&left).is_ok() { + (col(left), col(right)) + // if the right side is in the left schema, then we assume the left side is in the right schema + } else if right_schema.get_field(&left).is_ok() { + (col(right), col(left)) + } else { + unsupported_sql_err!("JOIN clauses must reference columns in the joined tables; found `{}`", left); + }; + + let null_equals_null = *op == BinaryOperator::Spaceship; + + Ok((vec![left_on], vec![right_on], vec![null_equals_null])) } else { - unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); + unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found `{left} {op} {right}`"); } } BinaryOperator::And => { @@ -673,7 +693,7 @@ impl SQLPlanner { Ok((left_i, right_i, null_equals_nulls_i)) } _ => { - unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); + unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{}'", op); } } } else if let sqlparser::ast::Expr::Nested(expr) = expression { @@ -690,10 +710,7 @@ impl SQLPlanner { for join in &from.joins { use sqlparser::ast::{ JoinConstraint, - JoinOperator::{ - AsOf, CrossApply, CrossJoin, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, - OuterApply, RightAnti, RightOuter, RightSemi, - }, + JoinOperator::{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}, }; let right_rel = self.plan_relation(&join.relation)?; self.table_map @@ -701,94 +718,45 @@ impl SQLPlanner { let right_rel_name = right_rel.get_name(); let right_join_prefix = Some(format!("{right_rel_name}.")); - match &join.join_operator { - Inner(JoinConstraint::On(expr)) => { + let (join_type, constraint) = match &join.join_operator { + Inner(constraint) => (JoinType::Inner, constraint), + LeftOuter(constraint) => (JoinType::Left, constraint), + RightOuter(constraint) => (JoinType::Right, constraint), + FullOuter(constraint) => (JoinType::Outer, constraint), + LeftSemi(constraint) => (JoinType::Semi, constraint), + LeftAnti(constraint) => (JoinType::Anti, constraint), + + _ => unsupported_sql_err!("Unsupported join type: {:?}", join.join_operator), + }; + + let (left_on, right_on, null_eq_null, keep_join_keys) = match &constraint { + JoinConstraint::On(expr) => { let (left_on, right_on, null_equals_nulls) = process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Inner, - None, - None, - right_join_prefix.as_deref(), - )?; + (left_on, right_on, Some(null_equals_nulls), true) } - Inner(JoinConstraint::Using(idents)) => { + JoinConstraint::Using(idents) => { let on = idents .iter() .map(|i| col(i.value.clone())) .collect::>(); - - left_rel.inner = left_rel.inner.join( - right_rel.inner, - on.clone(), - on, - JoinType::Inner, - None, - None, - right_join_prefix.as_deref(), - )?; + (on.clone(), on, None, false) } - LeftOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Left, - None, - None, - right_join_prefix.as_deref(), - )?; - } - RightOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Right, - None, - None, - right_join_prefix.as_deref(), - )?; - } - - FullOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Outer, - None, - None, - right_join_prefix.as_deref(), - )?; - } - CrossJoin => unsupported_sql_err!("CROSS JOIN"), - LeftSemi(_) => unsupported_sql_err!("LEFT SEMI JOIN"), - RightSemi(_) => unsupported_sql_err!("RIGHT SEMI JOIN"), - LeftAnti(_) => unsupported_sql_err!("LEFT ANTI JOIN"), - RightAnti(_) => unsupported_sql_err!("RIGHT ANTI JOIN"), - CrossApply => unsupported_sql_err!("CROSS APPLY"), - OuterApply => unsupported_sql_err!("OUTER APPLY"), - AsOf { .. } => unsupported_sql_err!("AS OF"), - join_type => unsupported_sql_err!("join type: {join_type:?}"), + JoinConstraint::Natural => unsupported_sql_err!("NATURAL JOIN not supported"), + JoinConstraint::None => unsupported_sql_err!("JOIN without ON/USING not supported"), }; + + left_rel.inner = left_rel.inner.join_with_null_safe_equal( + right_rel.inner, + left_on, + right_on, + null_eq_null, + join_type, + None, + None, + right_join_prefix.as_deref(), + keep_join_keys, + )?; } Ok(left_rel) diff --git a/tests/sql/test_joins.py b/tests/sql/test_joins.py index 48d7001df5..3914d43f00 100644 --- a/tests/sql/test_joins.py +++ b/tests/sql/test_joins.py @@ -1,5 +1,6 @@ +import pytest + import daft -from daft import col from daft.sql import SQLCatalog @@ -19,11 +20,13 @@ def test_joins_with_alias(): df1 = daft.from_pydict({"idx": [1, 2], "val": [10, 20]}) df2 = daft.from_pydict({"idx": [1, 2], "score": [0.1, 0.2]}) - df_sql = daft.sql("select * from df1 as foo join df2 as bar on (foo.idx=bar.idx) where bar.score>0.1") + catalog = SQLCatalog({"df1": df1, "df2": df2}) + + df_sql = daft.sql("select * from df1 as foo join df2 as bar on foo.idx=bar.idx where bar.score>0.1", catalog) actual = df_sql.collect().to_pydict() - expected = df1.join(df2, on="idx").filter(col("score") > 0.1).collect().to_pydict() + expected = {"idx": [2], "val": [20], "bar.idx": [2], "score": [0.2]} assert actual == expected @@ -47,31 +50,24 @@ def test_joins_with_wildcard_expansion(): df2 = daft.from_pydict({"idx": [3], "score": [0.1]}) df3 = daft.from_pydict({"idx": [1], "score": [0.1], "a": [1], "b": [2], "c": [3]}) + catalog = SQLCatalog({"df1": df1, "df2": df2, "df3": df3}) + df_sql = ( - daft.sql(""" + daft.sql( + """ select df3.* from df1 left join df2 on (df1.idx=df2.idx) left join df3 on (df1.idx=df3.idx) - """) - .collect() - .to_pydict() - ) - - expected = ( - df1.join(df2, on="idx", how="left") - .join(df3, on="idx", how="left") - .select( - "idx", - col("right.score").alias("score"), - col("a"), - col("b"), - col("c"), + """, + catalog, ) .collect() .to_pydict() ) + expected = {"idx": [1, None], "score": [0.1, None], "a": [1, None], "b": [2, None], "c": [3, None]} + assert df_sql == expected # make sure it works with exclusion patterns too @@ -86,9 +82,46 @@ def test_joins_with_wildcard_expansion(): .to_pydict() ) + expected = {"idx": [1, None], "score": [0.1, None]} + + assert df_sql == expected + + +def test_joins_with_duplicate_columns(): + table1 = daft.from_pydict({"id": [1, 2, 3, 4], "value": ["a", "b", "c", "d"]}) + + table2 = daft.from_pydict({"id": [2, 3, 4, 5], "value": ["b", "c", "d", "e"]}) + + catalog = SQLCatalog({"table1": table1, "table2": table2}) + + actual = daft.sql( + """ + SELECT * + FROM table1 t1 + LEFT JOIN table2 t2 on t2.id = t1.id; + """, + catalog, + ).collect() + expected = { - "idx": [1, 2], - "score": [0.1, None], + "id": [1, 2, 3, 4], + "value": ["a", "b", "c", "d"], + "t2.id": [None, 2, 3, 4], + "t2.value": [None, "b", "c", "d"], } + assert actual.to_pydict() == expected + + +@pytest.mark.parametrize("join_condition", ["idx=idax", "idax=idx"]) +def test_joins_without_compound_ident(join_condition): + df1 = daft.from_pydict({"idx": [1, None], "val": [10, 20]}) + df2 = daft.from_pydict({"idax": [1, None], "score": [0.1, 0.2]}) + + catalog = SQLCatalog({"df1": df1, "df2": df2}) + + df_sql = daft.sql(f"select * from df1 join df2 on {join_condition}", catalog).to_pydict() + + expected = {"idx": [1], "val": [10], "idax": [1], "score": [0.1]} + assert df_sql == expected