From 3720c2aa872ea8881e769b30885d21b64ab358dd Mon Sep 17 00:00:00 2001
From: Kev Wang
Date: Thu, 16 Jan 2025 11:44:59 -0800
Subject: [PATCH] refactor: logical op constructor+builder boundary (#3684)

## The problem

Plan ops are created in various places throughout our code: from our DataFrame and SQL interfaces, to optimization rules, to op constructors themselves, which can sometimes create other ops. All of these cases generally go through the same `new`/`try_new` constructor for each op, which tries to accommodate every use case at once. This creates complexity, adds unnecessary compute at planning time, and conflates user input errors with Daft internal errors. For example, I don't expect any optimization rule to create unresolved expressions; expression resolution should only be done by the builder. Another example is the Join op, where inputs such as `join_prefix` and `join_suffix` are only applicable for renaming columns, which should also happen only via the builder. We recently added another initializer to some ops for that reason, but it bypasses the validation that is typically done and is not standardized across ops.

## My solution

Every op should provide a `try_new` constructor that contains explicit checks for all requirements on the op's state (for example, that all expression columns exist in the schema), but otherwise simply puts those values into the struct unmodified and returns it.

- Functions such as `LogicalPlan::with_new_children` will just call `try_new`.
- Other constructors/helpers may exist that explicitly provide additional functionality and ultimately call `try_new`. E.g. `Join::rename_right_columns`, which renames right-side columns that conflict with the left side, called to update the right-side schema before calling `try_new`.
- User input normalization, such as expression resolution, should be handled by the logical plan builder. Once a logical plan op has been constructed, everything is in a valid state from there on. A minimal sketch of this split is shown below.
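To make the split concrete, here is a minimal, self-contained sketch, assuming simplified stand-in types (`Schema`, `Expr`, `Filter`, and the resolver logic here are all illustrative; the real implementations are in the diff below):

```rust
use std::{collections::HashSet, sync::Arc};

// Stand-ins for the real daft types; the actual definitions live in
// daft-logical-plan and daft-dsl.
#[derive(Clone)]
struct Schema(HashSet<String>);

struct Expr(String); // a bare column reference
type ExprRef = Arc<Expr>;

struct Filter {
    #[allow(dead_code)]
    predicate: ExprRef,
}

impl Filter {
    /// `try_new` only checks invariants (here: the column exists in the
    /// schema) and stores its inputs unmodified. A failure here indicates a
    /// Daft internal error, not a user error.
    fn try_new(schema: &Schema, predicate: ExprRef) -> Result<Self, String> {
        if !schema.0.contains(&predicate.0) {
            return Err(format!("internal error: unresolved column {}", predicate.0));
        }
        Ok(Self { predicate })
    }
}

struct LogicalPlanBuilder {
    schema: Schema,
}

impl LogicalPlanBuilder {
    /// The builder is the only place user input is normalized; by the time
    /// `try_new` runs, expressions must already be resolved.
    fn filter(&self, raw_column: &str) -> Result<Filter, String> {
        // Stand-in for ExprResolver: normalize the user-provided name.
        let resolved = Arc::new(Expr(raw_column.trim().to_string()));
        Filter::try_new(&self.schema, resolved)
    }
}

fn main() {
    let builder = LogicalPlanBuilder {
        schema: Schema(HashSet::from(["is_adult".to_string()])),
    };
    assert!(builder.filter(" is_adult ").is_ok()); // builder normalizes input
    assert!(builder.filter("missing").is_err()); // try_new rejects invalid state
}
```

The upshot of the boundary: a `try_new` failure means an internal bug produced an invalid plan, while a resolution failure in the builder surfaces as a user input error.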
--- Cargo.lock | 3 +- src/daft-dsl/Cargo.toml | 2 - src/daft-dsl/src/expr/mod.rs | 8 + src/daft-dsl/src/lib.rs | 12 +- src/daft-dsl/src/python.rs | 5 - src/daft-logical-plan/Cargo.toml | 1 + .../src/{builder.rs => builder/mod.rs} | 132 ++++++++++++---- .../src/builder/resolve_expr.rs} | 31 ++-- .../src/builder}/tests.rs | 11 +- src/daft-logical-plan/src/display.rs | 6 +- src/daft-logical-plan/src/lib.rs | 4 + src/daft-logical-plan/src/logical_plan.rs | 7 +- .../src/ops/actor_pool_project.rs | 16 +- src/daft-logical-plan/src/ops/agg.rs | 26 +-- src/daft-logical-plan/src/ops/concat.rs | 9 +- src/daft-logical-plan/src/ops/explode.rs | 37 ++--- src/daft-logical-plan/src/ops/filter.rs | 19 ++- src/daft-logical-plan/src/ops/join.rs | 149 ++++++------------ .../src/ops/monotonically_increasing_id.rs | 30 ++-- src/daft-logical-plan/src/ops/pivot.rs | 35 ++-- src/daft-logical-plan/src/ops/project.rs | 31 ++-- src/daft-logical-plan/src/ops/repartition.rs | 31 +--- .../src/ops/set_operations.rs | 6 +- src/daft-logical-plan/src/ops/sink.rs | 42 +---- src/daft-logical-plan/src/ops/sort.rs | 13 +- src/daft-logical-plan/src/ops/unpivot.rs | 47 +++--- .../rules/eliminate_cross_join.rs | 17 +- .../src/optimization/rules/unnest_subquery.rs | 57 +++---- tests/sql/test_binary_op_exprs.py | 8 +- 29 files changed, 345 insertions(+), 450 deletions(-) rename src/daft-logical-plan/src/{builder.rs => builder/mod.rs} (88%) rename src/{daft-dsl/src/resolve_expr/mod.rs => daft-logical-plan/src/builder/resolve_expr.rs} (93%) rename src/{daft-dsl/src/resolve_expr => daft-logical-plan/src/builder}/tests.rs (94%) diff --git a/Cargo.lock b/Cargo.lock index e4eb914800..18baebb301 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2133,10 +2133,8 @@ dependencies = [ "derive_more", "indexmap 2.7.0", "itertools 0.11.0", - "log", "pyo3", "serde", - "typed-builder 0.20.0", "typetag", ] @@ -2379,6 +2377,7 @@ dependencies = [ "serde", "snafu", "test-log", + "typed-builder 0.20.0", "uuid 1.11.0", ] diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml index 6e04a977aa..87b2bc1bbc 100644 --- a/src/daft-dsl/Cargo.toml +++ b/src/daft-dsl/Cargo.toml @@ -10,10 +10,8 @@ daft-sketch = {path = "../daft-sketch", default-features = false} derive_more = {workspace = true} indexmap = {workspace = true} itertools = {workspace = true} -log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true} -typed-builder = {workspace = true} typetag = {workspace = true} [features] diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index f7acc6608c..6595a022d3 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -1435,3 +1435,11 @@ pub fn estimated_selectivity(expr: &Expr, schema: &Schema) -> f64 { Expr::Agg(_) => panic!("Aggregates are not allowed in WHERE clauses"), } } + +pub fn exprs_to_schema(exprs: &[ExprRef], input_schema: SchemaRef) -> DaftResult { + let fields = exprs + .iter() + .map(|e| e.to_field(&input_schema)) + .collect::>()?; + Ok(Arc::new(Schema::new(fields)?)) +} diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 3de03f5bf5..c29ca9f779 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -11,18 +11,16 @@ pub mod optimization; mod pyobj_serde; #[cfg(feature = "python")] pub mod python; -mod resolve_expr; mod treenode; pub use common_treenode; pub use expr::{ - binary_op, col, count_actor_pool_udfs, estimated_selectivity, has_agg, is_actor_pool_udf, - is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef, 
Operator, - OuterReferenceColumn, SketchType, Subquery, SubqueryPlan, + binary_op, col, count_actor_pool_udfs, estimated_selectivity, exprs_to_schema, has_agg, + is_actor_pool_udf, is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef, + Operator, OuterReferenceColumn, SketchType, Subquery, SubqueryPlan, }; pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue}; #[cfg(feature = "python")] use pyo3::prelude::*; -pub use resolve_expr::{check_column_name_validity, ExprResolver}; #[cfg(feature = "python")] pub fn register_modules(parent: &Bound) -> PyResult<()> { @@ -41,10 +39,6 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_function(wrap_pyfunction!(python::initialize_udfs, parent)?)?; parent.add_function(wrap_pyfunction!(python::get_udf_names, parent)?)?; parent.add_function(wrap_pyfunction!(python::eq, parent)?)?; - parent.add_function(wrap_pyfunction!( - python::check_column_name_validity, - parent - )?)?; Ok(()) } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 654d282f78..df380bd154 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -257,11 +257,6 @@ pub fn eq(expr1: &PyExpr, expr2: &PyExpr) -> PyResult { Ok(expr1.expr == expr2.expr) } -#[pyfunction] -pub fn check_column_name_validity(name: &str, schema: &PySchema) -> PyResult<()> { - Ok(crate::check_column_name_validity(name, &schema.schema)?) -} - #[derive(FromPyObject)] pub enum ApproxPercentileInput { Single(f64), diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index 6d5f6b2fb0..b46cf1d222 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -21,6 +21,7 @@ log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} snafu = {workspace = true} +typed-builder = {workspace = true} uuid = {version = "1", features = ["v4"]} [dev-dependencies] diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder/mod.rs similarity index 88% rename from src/daft-logical-plan/src/builder.rs rename to src/daft-logical-plan/src/builder/mod.rs index 937fb45f44..244a42f933 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder/mod.rs @@ -1,3 +1,7 @@ +mod resolve_expr; +#[cfg(test)] +mod tests; + use std::{ collections::{HashMap, HashSet}, sync::Arc, @@ -12,6 +16,10 @@ use common_scan_info::{PhysicalScanInfo, Pushdowns, ScanOperatorRef}; use daft_core::join::{JoinStrategy, JoinType}; use daft_dsl::{col, ExprRef}; use daft_schema::schema::{Schema, SchemaRef}; +use indexmap::IndexSet; +#[cfg(feature = "python")] +pub use resolve_expr::py_check_column_name_validity; +use resolve_expr::ExprResolver; #[cfg(feature = "python")] use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, @@ -188,11 +196,19 @@ impl LogicalPlanBuilder { } pub fn select(&self, to_select: Vec) -> DaftResult { + let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); + + let (to_select, _) = expr_resolver.resolve(to_select, &self.schema())?; + let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), to_select)?.into(); Ok(self.with_new_plan(logical_plan)) } pub fn with_columns(&self, columns: Vec) -> DaftResult { + let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); + + let (columns, _) = expr_resolver.resolve(columns, &self.schema())?; + let fields = &self.schema().fields; let current_col_names = fields .iter() @@ 
-245,6 +261,10 @@ impl LogicalPlanBuilder { } pub fn filter(&self, predicate: ExprRef) -> DaftResult { + let expr_resolver = ExprResolver::default(); + + let (predicate, _) = expr_resolver.resolve_single(predicate, &self.schema())?; + let logical_plan: LogicalPlan = ops::Filter::try_new(self.plan.clone(), predicate)?.into(); Ok(self.with_new_plan(logical_plan)) } @@ -255,6 +275,10 @@ impl LogicalPlanBuilder { } pub fn explode(&self, to_explode: Vec) -> DaftResult { + let expr_resolver = ExprResolver::default(); + + let (to_explode, _) = expr_resolver.resolve(to_explode, &self.schema())?; + let logical_plan: LogicalPlan = ops::Explode::try_new(self.plan.clone(), to_explode)?.into(); Ok(self.with_new_plan(logical_plan)) @@ -264,25 +288,24 @@ impl LogicalPlanBuilder { &self, ids: Vec, values: Vec, - variable_name: &str, - value_name: &str, + variable_name: String, + value_name: String, ) -> DaftResult { + let expr_resolver = ExprResolver::default(); + let (values, _) = expr_resolver.resolve(values, &self.schema())?; + let (ids, _) = expr_resolver.resolve(ids, &self.schema())?; + let values = if values.is_empty() { - let ids_set = HashSet::<_>::from_iter(ids.iter()); + let ids_set = IndexSet::<_>::from_iter(ids.iter().cloned()); - self.schema() + let columns_set = self + .schema() .fields - .iter() - .filter_map(|(name, _)| { - let column = col(name.clone()); - - if ids_set.contains(&column) { - None - } else { - Some(column) - } - }) - .collect() + .keys() + .map(|name| col(name.clone())) + .collect::>(); + + columns_set.difference(&ids_set).cloned().collect() } else { values }; @@ -299,6 +322,10 @@ impl LogicalPlanBuilder { descending: Vec, nulls_first: Vec, ) -> DaftResult { + let expr_resolver = ExprResolver::default(); + + let (sort_by, _) = expr_resolver.resolve(sort_by, &self.schema())?; + let logical_plan: LogicalPlan = ops::Sort::try_new(self.plan.clone(), sort_by, descending, nulls_first)?.into(); Ok(self.with_new_plan(logical_plan)) @@ -309,28 +336,32 @@ impl LogicalPlanBuilder { num_partitions: Option, partition_by: Vec, ) -> DaftResult { - let logical_plan: LogicalPlan = ops::Repartition::try_new( + let expr_resolver = ExprResolver::default(); + + let (partition_by, _) = expr_resolver.resolve(partition_by, &self.schema())?; + + let logical_plan: LogicalPlan = ops::Repartition::new( self.plan.clone(), RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)), - )? + ) .into(); Ok(self.with_new_plan(logical_plan)) } pub fn random_shuffle(&self, num_partitions: Option) -> DaftResult { - let logical_plan: LogicalPlan = ops::Repartition::try_new( + let logical_plan: LogicalPlan = ops::Repartition::new( self.plan.clone(), RepartitionSpec::Random(RandomShuffleConfig::new(num_partitions)), - )? + ) .into(); Ok(self.with_new_plan(logical_plan)) } pub fn into_partitions(&self, num_partitions: usize) -> DaftResult { - let logical_plan: LogicalPlan = ops::Repartition::try_new( + let logical_plan: LogicalPlan = ops::Repartition::new( self.plan.clone(), RepartitionSpec::IntoPartitions(IntoPartitionsConfig::new(num_partitions)), - )? 
+ ) .into(); Ok(self.with_new_plan(logical_plan)) } @@ -356,6 +387,12 @@ impl LogicalPlanBuilder { agg_exprs: Vec, groupby_exprs: Vec, ) -> DaftResult { + let groupby_resolver = ExprResolver::default(); + let (groupby_exprs, _) = groupby_resolver.resolve(groupby_exprs, &self.schema())?; + + let agg_resolver = ExprResolver::builder().groupby(&groupby_exprs).build(); + let (agg_exprs, _) = agg_resolver.resolve(agg_exprs, &self.schema())?; + let logical_plan: LogicalPlan = ops::Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into(); Ok(self.with_new_plan(logical_plan)) @@ -369,6 +406,14 @@ impl LogicalPlanBuilder { agg_expr: ExprRef, names: Vec, ) -> DaftResult { + let agg_resolver = ExprResolver::builder().groupby(&group_by).build(); + let (agg_expr, _) = agg_resolver.resolve_single(agg_expr, &self.schema())?; + + let expr_resolver = ExprResolver::default(); + let (group_by, _) = expr_resolver.resolve(group_by, &self.schema())?; + let (pivot_column, _) = expr_resolver.resolve_single(pivot_column, &self.schema())?; + let (value_column, _) = expr_resolver.resolve_single(value_column, &self.schema())?; + let pivot_logical_plan: LogicalPlan = ops::Pivot::try_new( self.plan.clone(), group_by, @@ -438,17 +483,36 @@ impl LogicalPlanBuilder { join_prefix: Option<&str>, keep_join_keys: bool, ) -> DaftResult { + let left_plan = self.plan.clone(); + let right_plan = right.into(); + + let expr_resolver = ExprResolver::default(); + + let (left_on, _) = expr_resolver.resolve(left_on, &left_plan.schema())?; + let (right_on, _) = expr_resolver.resolve(right_on, &right_plan.schema())?; + + // TODO(kevin): we should do this, but it has not been properly used before and is nondeterministic, which causes some tests to break + // let (left_on, right_on) = ops::Join::rename_join_keys(left_on, right_on); + + let (right_plan, right_on) = ops::Join::rename_right_columns( + left_plan.clone(), + right_plan, + left_on.clone(), + right_on, + join_type, + join_suffix, + join_prefix, + keep_join_keys, + )?; + let logical_plan: LogicalPlan = ops::Join::try_new( - self.plan.clone(), - right.into(), + left_plan, + right_plan, left_on, right_on, null_equals_nulls, join_type, join_strategy, - join_suffix, - join_prefix, - keep_join_keys, )? 
.into(); Ok(self.with_new_plan(logical_plan)) @@ -501,7 +565,7 @@ impl LogicalPlanBuilder { pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult { let logical_plan: LogicalPlan = - ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into(); + ops::MonotonicallyIncreasingId::try_new(self.plan.clone(), column_name)?.into(); Ok(self.with_new_plan(logical_plan)) } @@ -513,6 +577,16 @@ impl LogicalPlanBuilder { compression: Option, io_config: Option, ) -> DaftResult { + let partition_cols = partition_cols + .map(|cols| { + let expr_resolver = ExprResolver::default(); + + expr_resolver + .resolve(cols, &self.schema()) + .map(|(resolved_cols, _)| resolved_cols) + }) + .transpose()?; + let sink_info = SinkInfo::OutputFileInfo(OutputFileInfo::new( root_dir.into(), file_format, @@ -752,8 +826,8 @@ impl PyLogicalPlanBuilder { &self, ids: Vec, values: Vec, - variable_name: &str, - value_name: &str, + variable_name: String, + value_name: String, ) -> PyResult { let ids_exprs = ids .iter() diff --git a/src/daft-dsl/src/resolve_expr/mod.rs b/src/daft-logical-plan/src/builder/resolve_expr.rs similarity index 93% rename from src/daft-dsl/src/resolve_expr/mod.rs rename to src/daft-logical-plan/src/builder/resolve_expr.rs index 35d97bc9a8..cd98930ca7 100644 --- a/src/daft-dsl/src/resolve_expr/mod.rs +++ b/src/daft-logical-plan/src/builder/resolve_expr.rs @@ -1,6 +1,3 @@ -#[cfg(test)] -mod tests; - use std::{ cmp::Ordering, collections::{BinaryHeap, HashMap, HashSet}, @@ -10,15 +7,16 @@ use std::{ use common_error::{DaftError, DaftResult}; use common_treenode::{Transformed, TransformedResult, TreeNode}; use daft_core::prelude::*; +#[cfg(feature = "python")] +use daft_core::python::PySchema; +use daft_dsl::{col, functions::FunctionExpr, has_agg, is_actor_pool_udf, AggExpr, Expr, ExprRef}; +#[cfg(feature = "python")] +use pyo3::prelude::*; use typed_builder::TypedBuilder; -use crate::{ - col, expr::has_agg, functions::FunctionExpr, is_actor_pool_udf, AggExpr, Expr, ExprRef, -}; - // Calculates all the possible struct get expressions in a schema. // For each sugared string, calculates all possible corresponding expressions, in order of priority. -fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { +pub fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { #[derive(PartialEq, Eq)] struct BfsState<'a> { name: String, @@ -61,7 +59,7 @@ fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { for child in children { pq.push(BfsState { name: format!("{}.{}", name, child.name), - expr: crate::functions::struct_::get(expr.clone(), &child.name), + expr: daft_dsl::functions::struct_::get(expr.clone(), &child.name), field: child, }); } @@ -76,7 +74,7 @@ fn calculate_struct_expr_map(schema: &Schema) -> HashMap> { /// /// For example, if col("a.b.c") could be interpreted as either col("a.b").struct.get("c") /// or col("a").struct.get("b.c"), this function will resolve it to col("a.b").struct.get("c"). -fn transform_struct_gets( +pub fn transform_struct_gets( expr: ExprRef, struct_expr_map: &HashMap>, ) -> DaftResult { @@ -103,7 +101,10 @@ fn transform_struct_gets( // Finds the names of all the wildcard expressions in an expression tree. 
// Needs the schema because column names with stars must not count as wildcards -fn find_wildcards(expr: ExprRef, struct_expr_map: &HashMap>) -> Vec> { +pub fn find_wildcards( + expr: ExprRef, + struct_expr_map: &HashMap>, +) -> Vec> { match expr.as_ref() { Expr::Column(name) => { if name.contains('*') { @@ -346,7 +347,7 @@ impl<'a> ExprResolver<'a> { } } -pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> { +fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> { let struct_expr_map = calculate_struct_expr_map(schema); let names = if name == "*" || name.ends_with(".*") { @@ -371,3 +372,9 @@ pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> Ok(()) } + +#[cfg(feature = "python")] +#[pyfunction(name = "check_column_name_validity")] +pub fn py_check_column_name_validity(name: &str, schema: &PySchema) -> PyResult<()> { + Ok(check_column_name_validity(name, &schema.schema)?) +} diff --git a/src/daft-dsl/src/resolve_expr/tests.rs b/src/daft-logical-plan/src/builder/tests.rs similarity index 94% rename from src/daft-dsl/src/resolve_expr/tests.rs rename to src/daft-logical-plan/src/builder/tests.rs index dcb3147207..f8e98526f0 100644 --- a/src/daft-dsl/src/resolve_expr/tests.rs +++ b/src/daft-logical-plan/src/builder/tests.rs @@ -1,4 +1,11 @@ -use super::*; +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use daft_core::prelude::Schema; +use daft_dsl::{col, ExprRef}; +use daft_schema::{dtype::DataType, field::Field}; + +use super::resolve_expr::*; fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { let struct_expr_map = calculate_struct_expr_map(schema); @@ -7,7 +14,7 @@ fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult DaftResult<()> { - use crate::functions::struct_::get as struct_get; + use daft_dsl::functions::struct_::get as struct_get; let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index 84db90d273..3958477b77 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -94,7 +94,7 @@ mod test { startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? .limit(1000, false)? - .add_monotonically_increasing_id(None)? + .add_monotonically_increasing_id(Some("id2"))? .distinct()? .sort(vec![col("last_name")], vec![false], vec![false])? .build(); @@ -124,7 +124,7 @@ Filter2["Filter: col(first_name) == lit('hello')"] Join3["Join: Type = Inner Strategy = Auto On = col(id) -Output schema = id#Int32, text#Utf8, first_name#Utf8, last_name#Utf8"] +Output schema = id#Int32, text#Utf8, id2#UInt64, first_name#Utf8, last_name#Utf8"] Filter4["Filter: col(id) == lit(1)"] Source5["PlaceHolder: Source ID = 0 @@ -168,7 +168,7 @@ Project1 --> Limit0 startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? .limit(1000, false)? - .add_monotonically_increasing_id(None)? + .add_monotonically_increasing_id(Some("id2"))? .distinct()? .sort(vec![col("last_name")], vec![false], vec![false])? 
.build(); diff --git a/src/daft-logical-plan/src/lib.rs b/src/daft-logical-plan/src/lib.rs index 5296d99a23..317a92535e 100644 --- a/src/daft-logical-plan/src/lib.rs +++ b/src/daft-logical-plan/src/lib.rs @@ -41,6 +41,10 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_function(wrap_pyfunction!( + builder::py_check_column_name_validity, + parent + )?)?; Ok(()) } diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 4abd15fcaa..7a3ccd04b2 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -334,12 +334,12 @@ impl LogicalPlan { Self::Limit(Limit { limit, eager, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager)), Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()), Self::Sort(Sort { sort_by, descending, nulls_first, .. }) => Self::Sort(Sort::try_new(input.clone(), sort_by.clone(), descending.clone(), nulls_first.clone()).unwrap()), - Self::Repartition(Repartition { repartition_spec: scheme_config, .. }) => Self::Repartition(Repartition::try_new(input.clone(), scheme_config.clone()).unwrap()), + Self::Repartition(Repartition { repartition_spec: scheme_config, .. }) => Self::Repartition(Repartition::new(input.clone(), scheme_config.clone())), Self::Distinct(_) => Self::Distinct(Distinct::new(input.clone())), Self::Aggregate(Aggregate { aggregations, groupby, ..}) => Self::Aggregate(Aggregate::try_new(input.clone(), aggregations.clone(), groupby.clone()).unwrap()), Self::Pivot(Pivot { group_by, pivot_column, value_column, aggregation, names, ..}) => Self::Pivot(Pivot::try_new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), aggregation.into(), names.clone()).unwrap()), Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()), - Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. }) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), + Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. 
}) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::try_new(input.clone(), Some(column_name)).unwrap()), Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot::new(input.clone(), ids.clone(), values.clone(), variable_name.clone(), value_name.clone(), output_schema.clone())), Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), @@ -361,9 +361,6 @@ impl LogicalPlan { null_equals_nulls.clone(), *join_type, *join_strategy, - None, // The suffix is already eagerly computed in the constructor - None, // the prefix is already eagerly computed in the constructor - false // this is already eagerly computed in the constructor ).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, diff --git a/src/daft-logical-plan/src/ops/actor_pool_project.rs b/src/daft-logical-plan/src/ops/actor_pool_project.rs index d9a2aa0c4b..f76d9b9fc3 100644 --- a/src/daft-logical-plan/src/ops/actor_pool_project.rs +++ b/src/daft-logical-plan/src/ops/actor_pool_project.rs @@ -3,16 +3,15 @@ use std::sync::Arc; use common_error::DaftError; use common_resource_request::ResourceRequest; use daft_dsl::{ - count_actor_pool_udfs, + count_actor_pool_udfs, exprs_to_schema, functions::python::{get_concurrency, get_resource_request, get_udf_names}, - ExprRef, ExprResolver, + ExprRef, }; -use daft_schema::schema::{Schema, SchemaRef}; +use daft_schema::schema::SchemaRef; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{CreationSnafu, Error, Result}, + logical_plan::{Error, Result}, stats::StatsState, LogicalPlan, }; @@ -28,17 +27,12 @@ pub struct ActorPoolProject { impl ActorPoolProject { pub(crate) fn try_new(input: Arc, projection: Vec) -> Result { - let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); - let (projection, fields) = expr_resolver - .resolve(projection, input.schema().as_ref()) - .context(CreationSnafu)?; - let num_actor_pool_udfs: usize = count_actor_pool_udfs(&projection); if !num_actor_pool_udfs == 1 { return Err(Error::CreationError { source: DaftError::InternalError(format!("Expected ActorPoolProject to have exactly 1 actor pool UDF expression but found: {num_actor_pool_udfs}")) }); } - let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); + let projected_schema = exprs_to_schema(&projection, input.schema())?; Ok(Self { input, diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs index 060a77bd9e..826e98c9a8 100644 --- a/src/daft-logical-plan/src/ops/agg.rs +++ b/src/daft-logical-plan/src/ops/agg.rs @@ -1,12 +1,11 @@ use std::sync::Arc; -use daft_dsl::{ExprRef, ExprResolver}; -use daft_schema::schema::{Schema, SchemaRef}; +use daft_dsl::{exprs_to_schema, ExprRef}; +use daft_schema::schema::SchemaRef; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{self, CreationSnafu}, + logical_plan::{self}, stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -36,21 +35,10 @@ impl Aggregate { aggregations: Vec, groupby: Vec, ) -> logical_plan::Result { - let upstream_schema = input.schema(); - - let agg_resolver = ExprResolver::builder().groupby(&groupby).build(); - let (aggregations, aggregation_fields) = agg_resolver - .resolve(aggregations, &upstream_schema) - .context(CreationSnafu)?; - - let groupby_resolver = ExprResolver::default(); - let (groupby, groupby_fields) = groupby_resolver - 
.resolve(groupby, &upstream_schema) - .context(CreationSnafu)?; - - let fields = [groupby_fields, aggregation_fields].concat(); - - let output_schema = Schema::new(fields).context(CreationSnafu)?.into(); + let output_schema = exprs_to_schema( + &[groupby.as_slice(), aggregations.as_slice()].concat(), + input.schema(), + )?; Ok(Self { input, diff --git a/src/daft-logical-plan/src/ops/concat.rs b/src/daft-logical-plan/src/ops/concat.rs index fb18441c4c..207bceffed 100644 --- a/src/daft-logical-plan/src/ops/concat.rs +++ b/src/daft-logical-plan/src/ops/concat.rs @@ -18,14 +18,6 @@ pub struct Concat { } impl Concat { - pub(crate) fn new(input: Arc, other: Arc) -> Self { - Self { - input, - other, - stats_state: StatsState::NotMaterialized, - } - } - pub(crate) fn try_new( input: Arc, other: Arc, @@ -39,6 +31,7 @@ impl Concat { ))) .context(CreationSnafu); } + Ok(Self { input, other, diff --git a/src/daft-logical-plan/src/ops/explode.rs b/src/daft-logical-plan/src/ops/explode.rs index c1f6ee278a..ed214430e8 100644 --- a/src/daft-logical-plan/src/ops/explode.rs +++ b/src/daft-logical-plan/src/ops/explode.rs @@ -1,12 +1,11 @@ use std::sync::Arc; -use daft_dsl::{ExprRef, ExprResolver}; +use daft_dsl::{exprs_to_schema, ExprRef}; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{self, CreationSnafu}, + logical_plan::{self}, stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -26,35 +25,23 @@ impl Explode { input: Arc, to_explode: Vec, ) -> logical_plan::Result { - let upstream_schema = input.schema(); - - let expr_resolver = ExprResolver::default(); + let exploded_schema = { + let explode_exprs = to_explode + .iter() + .cloned() + .map(daft_functions::list::explode) + .collect::>(); - let (to_explode, _) = expr_resolver - .resolve(to_explode, &upstream_schema) - .context(CreationSnafu)?; + let explode_schema = exprs_to_schema(&explode_exprs, input.schema())?; - let explode_exprs = to_explode - .iter() - .cloned() - .map(daft_functions::list::explode) - .collect::>(); - let exploded_schema = { - let explode_schema = { - let explode_fields = explode_exprs - .iter() - .map(|e| e.to_field(&upstream_schema)) - .collect::>>() - .context(CreationSnafu)?; - Schema::new(explode_fields).context(CreationSnafu)? 
- }; - let fields = upstream_schema + let fields = input + .schema() .fields .iter() .map(|(name, field)| explode_schema.fields.get(name).unwrap_or(field)) .cloned() .collect::>(); - Schema::new(fields).context(CreationSnafu)?.into() + Schema::new(fields)?.into() }; Ok(Self { diff --git a/src/daft-logical-plan/src/ops/filter.rs b/src/daft-logical-plan/src/ops/filter.rs index 2a046b66e7..a8f6507641 100644 --- a/src/daft-logical-plan/src/ops/filter.rs +++ b/src/daft-logical-plan/src/ops/filter.rs @@ -2,11 +2,11 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{estimated_selectivity, ExprRef, ExprResolver}; +use daft_dsl::{estimated_selectivity, ExprRef}; use snafu::ResultExt; use crate::{ - logical_plan::{CreationSnafu, Result}, + logical_plan::{self, CreationSnafu}, stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -21,17 +21,16 @@ pub struct Filter { } impl Filter { - pub(crate) fn try_new(input: Arc, predicate: ExprRef) -> Result { - let expr_resolver = ExprResolver::default(); + pub(crate) fn try_new( + input: Arc, + predicate: ExprRef, + ) -> logical_plan::Result { + let dtype = predicate.to_field(&input.schema())?.dtype; - let (predicate, field) = expr_resolver - .resolve_single(predicate, &input.schema()) - .context(CreationSnafu)?; - - if !matches!(field.dtype, DataType::Boolean) { + if !matches!(dtype, DataType::Boolean) { return Err(DaftError::ValueError(format!( "Expected expression {predicate} to resolve to type Boolean, but received: {}", - field.dtype + dtype ))) .context(CreationSnafu); } diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index 18dede0720..f7ad07737c 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -9,7 +9,7 @@ use daft_dsl::{ col, join::{get_common_join_keys, infer_join_schema}, optimization::replace_columns_with_expressions, - Expr, ExprRef, ExprResolver, + Expr, ExprRef, }; use itertools::Itertools; use snafu::ResultExt; @@ -19,7 +19,7 @@ use crate::{ logical_plan::{self, CreationSnafu}, ops::Project, stats::{ApproxStats, PlanStats, StatsState}, - LogicalPlan, + LogicalPlan, LogicalPlanRef, }; #[derive(Clone, Debug, PartialEq, Eq)] @@ -51,30 +51,6 @@ impl std::hash::Hash for Join { } impl Join { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - left: Arc, - right: Arc, - left_on: Vec, - right_on: Vec, - null_equals_nulls: Option>, - join_type: JoinType, - join_strategy: Option, - output_schema: SchemaRef, - ) -> Self { - Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - } - } - #[allow(clippy::too_many_arguments)] pub(crate) fn try_new( left: Arc, @@ -84,45 +60,11 @@ impl Join { null_equals_nulls: Option>, join_type: JoinType, join_strategy: Option, - join_suffix: Option<&str>, - join_prefix: Option<&str>, - // if true, then duplicate column names will be kept - // ex: select * from a left join b on a.id = b.id - // if true, then the resulting schema will have two columns named id (id, and b.id) - // In SQL the join column is always kept, while in dataframes it is not - keep_join_keys: bool, ) -> logical_plan::Result { - let expr_resolver = ExprResolver::default(); - - let (left_on, _) = expr_resolver - .resolve(left_on, &left.schema()) - .context(CreationSnafu)?; - let (right_on, _) = expr_resolver - .resolve(right_on, &right.schema()) - .context(CreationSnafu)?; - - let (unique_left_on, 
unique_right_on) = - Self::rename_join_keys(left_on.clone(), right_on.clone()); - - let left_fields: Vec = unique_left_on - .iter() - .map(|e| e.to_field(&left.schema())) - .collect::>>() - .context(CreationSnafu)?; - - let right_fields: Vec = unique_right_on - .iter() - .map(|e| e.to_field(&right.schema())) - .collect::>>() - .context(CreationSnafu)?; - - for (on_exprs, on_fields) in [ - (&unique_left_on, &left_fields), - (&unique_right_on, &right_fields), - ] { - for (field, expr) in on_fields.iter().zip(on_exprs.iter()) { + for (on_exprs, side) in [(&left_on, &left), (&right_on, &right)] { + for expr in on_exprs { // Null type check for both fields and expressions - if matches!(field.dtype, DataType::Null) { + if matches!(expr.to_field(&side.schema())?.dtype, DataType::Null) { return Err(DaftError::ValueError(format!( "Can't join on null type expressions: {expr}" ))) @@ -141,22 +83,42 @@ impl Join { } } - if matches!(join_type, JoinType::Anti | JoinType::Semi) { - // The output schema is the same as the left input schema for anti and semi joins. + let output_schema = infer_join_schema( + &left.schema(), + &right.schema(), + &left_on, + &right_on, + join_type, + )?; - let output_schema = left.schema(); + Ok(Self { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state: StatsState::NotMaterialized, + }) + } - Ok(Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - }) + /// Add a project under the right side plan when necessary in order to resolve naming conflicts + /// between left and right side columns. + #[allow(clippy::too_many_arguments)] + pub(crate) fn rename_right_columns( + left: LogicalPlanRef, + right: LogicalPlanRef, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + keep_join_keys: bool, + ) -> DaftResult<(LogicalPlanRef, Vec)> { + if matches!(join_type, JoinType::Anti | JoinType::Semi) { + Ok((right, right_on)) } else { let common_join_keys: HashSet<_> = get_common_join_keys(left_on.as_slice(), right_on.as_slice()) @@ -202,8 +164,8 @@ impl Join { }) .collect(); - let (right, right_on) = if right_rename_mapping.is_empty() { - (right, right_on) + if right_rename_mapping.is_empty() { + Ok((right, right_on)) } else { // projection to update the right side with the new column names let new_right_projection: Vec<_> = right_names @@ -230,29 +192,8 @@ impl Join { .map(|expr| replace_columns_with_expressions(expr, &right_on_replace_map)) .collect::>(); - (new_right.into(), new_right_on) - }; - - let output_schema = infer_join_schema( - &left.schema(), - &right.schema(), - &left_on, - &right_on, - join_type, - ) - .context(CreationSnafu)?; - - Ok(Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - }) + Ok((new_right.into(), new_right_on)) + } } } @@ -282,8 +223,8 @@ impl Join { /// ``` /// /// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649). 
- - fn rename_join_keys( + #[allow(dead_code)] + pub(crate) fn rename_join_keys( left_exprs: Vec>, right_exprs: Vec>, ) -> (Vec>, Vec>) { diff --git a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs index 170296fa2a..ea288ab446 100644 --- a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs +++ b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs @@ -2,7 +2,11 @@ use std::sync::Arc; use daft_core::prelude::*; -use crate::{stats::StatsState, LogicalPlan}; +use crate::{ + logical_plan::{self}, + stats::StatsState, + LogicalPlan, +}; #[derive(Hash, Eq, PartialEq, Debug, Clone)] pub struct MonotonicallyIncreasingId { @@ -13,25 +17,23 @@ pub struct MonotonicallyIncreasingId { } impl MonotonicallyIncreasingId { - pub(crate) fn new(input: Arc, column_name: Option<&str>) -> Self { + pub(crate) fn try_new( + input: Arc, + column_name: Option<&str>, + ) -> logical_plan::Result { let column_name = column_name.unwrap_or("id"); - let mut schema_with_id_index_map = input.schema().fields.clone(); - schema_with_id_index_map.shift_insert( - 0, - column_name.to_string(), - Field::new(column_name, DataType::UInt64), - ); - let schema_with_id = Schema { - fields: schema_with_id_index_map, - }; - - Self { + let fields_with_id = std::iter::once(Field::new(column_name, DataType::UInt64)) + .chain(input.schema().fields.values().cloned()) + .collect(); + let schema_with_id = Schema::new(fields_with_id)?; + + Ok(Self { input, schema: Arc::new(schema_with_id), column_name: column_name.to_string(), stats_state: StatsState::NotMaterialized, - } + }) } pub(crate) fn with_materialized_stats(mut self) -> Self { diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs index 57ee3bb1c5..cb24e47232 100644 --- a/src/daft-logical-plan/src/ops/pivot.rs +++ b/src/daft-logical-plan/src/ops/pivot.rs @@ -1,14 +1,13 @@ use std::sync::Arc; -use common_error::DaftError; +use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; -use daft_dsl::{AggExpr, Expr, ExprRef, ExprResolver}; +use daft_dsl::{AggExpr, Expr, ExprRef}; use daft_schema::schema::{Schema, SchemaRef}; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{self, CreationSnafu}, + logical_plan::{self}, stats::StatsState, LogicalPlan, }; @@ -34,24 +33,6 @@ impl Pivot { aggregation: ExprRef, names: Vec, ) -> logical_plan::Result { - let upstream_schema = input.schema(); - - let agg_resolver = ExprResolver::builder().groupby(&group_by).build(); - let (aggregation, _) = agg_resolver - .resolve_single(aggregation, &upstream_schema) - .context(CreationSnafu)?; - - let expr_resolver = ExprResolver::default(); - let (group_by, group_by_fields) = expr_resolver - .resolve(group_by, &upstream_schema) - .context(CreationSnafu)?; - let (pivot_column, _) = expr_resolver - .resolve_single(pivot_column, &upstream_schema) - .context(CreationSnafu)?; - let (value_column, value_col_field) = expr_resolver - .resolve_single(value_column, &upstream_schema) - .context(CreationSnafu)?; - let Expr::Agg(agg_expr) = aggregation.as_ref() else { return Err(DaftError::ValueError(format!( "Pivot only supports using top level aggregation expressions, received {aggregation}", @@ -60,16 +41,22 @@ impl Pivot { }; let output_schema = { - let value_col_dtype = value_col_field.dtype; + let value_col_dtype = value_column.to_field(&input.schema())?.dtype; let pivot_value_fields = names .iter() .map(|f| Field::new(f, value_col_dtype.clone())) 
.collect::>(); + + let group_by_fields = group_by + .iter() + .map(|expr| expr.to_field(&input.schema())) + .collect::>>()?; + let fields = group_by_fields .into_iter() .chain(pivot_value_fields) .collect::>(); - Schema::new(fields).context(CreationSnafu)?.into() + Schema::new(fields)?.into() }; Ok(Self { diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index 165d989a09..171899203c 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -1,14 +1,14 @@ use std::sync::Arc; +use common_error::DaftResult; use common_treenode::Transformed; use daft_core::prelude::*; -use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef, ExprResolver}; +use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; -use snafu::ResultExt; use crate::{ - logical_plan::{CreationSnafu, Result}, + logical_plan::{self}, stats::StatsState, LogicalPlan, }; @@ -23,18 +23,20 @@ pub struct Project { } impl Project { - pub(crate) fn try_new(input: Arc, projection: Vec) -> Result { - let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); - - let (projection, fields) = expr_resolver - .resolve(projection, &input.schema()) - .context(CreationSnafu)?; - + pub(crate) fn try_new( + input: Arc, + projection: Vec, + ) -> logical_plan::Result { // Factor the projection and see if there are any substitutions to factor out. let (factored_input, factored_projection) = Self::try_factor_subexpressions(input, projection)?; - let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); + let fields = factored_projection + .iter() + .map(|expr| expr.to_field(&factored_input.schema())) + .collect::>()?; + + let projected_schema = Schema::new(fields)?.into(); Ok(Self { input: factored_input, @@ -45,7 +47,10 @@ impl Project { } /// Create a new Projection using the specified output schema - pub(crate) fn new_from_schema(input: Arc, schema: SchemaRef) -> Result { + pub(crate) fn new_from_schema( + input: Arc, + schema: SchemaRef, + ) -> logical_plan::Result { let expr: Vec = schema .names() .into_iter() @@ -75,7 +80,7 @@ impl Project { fn try_factor_subexpressions( input: Arc, projection: Vec, - ) -> Result<(Arc, Vec)> { + ) -> logical_plan::Result<(Arc, Vec)> { // Given construction parameters for a projection, // see if we can factor out common subexpressions. 
// Returns a new set of projection parameters diff --git a/src/daft-logical-plan/src/ops/repartition.rs b/src/daft-logical-plan/src/ops/repartition.rs index ac12970c49..d67ccd86f7 100644 --- a/src/daft-logical-plan/src/ops/repartition.rs +++ b/src/daft-logical-plan/src/ops/repartition.rs @@ -1,13 +1,6 @@ use std::sync::Arc; -use common_error::DaftResult; -use daft_dsl::ExprResolver; - -use crate::{ - partitioning::{HashRepartitionConfig, RepartitionSpec}, - stats::StatsState, - LogicalPlan, -}; +use crate::{partitioning::RepartitionSpec, stats::StatsState, LogicalPlan}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Repartition { @@ -18,28 +11,12 @@ pub struct Repartition { } impl Repartition { - pub(crate) fn try_new( - input: Arc, - repartition_spec: RepartitionSpec, - ) -> DaftResult { - let repartition_spec = match repartition_spec { - RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => { - let expr_resolver = ExprResolver::default(); - - let (resolved_by, _) = expr_resolver.resolve(by, &input.schema())?; - RepartitionSpec::Hash(HashRepartitionConfig { - num_partitions, - by: resolved_by, - }) - } - RepartitionSpec::Random(_) | RepartitionSpec::IntoPartitions(_) => repartition_spec, - }; - - Ok(Self { + pub(crate) fn new(input: Arc, repartition_spec: RepartitionSpec) -> Self { + Self { input, repartition_spec, stats_state: StatsState::NotMaterialized, - }) + } } pub(crate) fn with_materialized_stats(mut self) -> Self { diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 42009182b6..64521f4ed9 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -47,9 +47,6 @@ fn intersect_or_except_plan( Some(vec![true; left_on_size]), join_type, None, - None, - None, - false, ); join.map(|j| Distinct::new(j.into()).into()) } @@ -303,8 +300,7 @@ impl Union { } else { (self.lhs.clone(), self.rhs.clone()) }; - // we don't want to use `try_new` as we have already checked the schema - let concat = LogicalPlan::Concat(Concat::new(lhs, rhs)); + let concat = LogicalPlan::Concat(Concat::try_new(lhs, rhs)?); if self.is_all { Ok(concat) } else { diff --git a/src/daft-logical-plan/src/ops/sink.rs b/src/daft-logical-plan/src/ops/sink.rs index e5eb9f3f2e..46aa17b1dd 100644 --- a/src/daft-logical-plan/src/ops/sink.rs +++ b/src/daft-logical-plan/src/ops/sink.rs @@ -1,15 +1,14 @@ use std::sync::Arc; -use common_error::DaftResult; use daft_core::prelude::*; -use daft_dsl::ExprResolver; #[cfg(feature = "python")] use crate::sink_info::CatalogType; use crate::{ + logical_plan::{self}, sink_info::SinkInfo, stats::{PlanStats, StatsState}, - LogicalPlan, OutputFileInfo, + LogicalPlan, }; #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -23,41 +22,12 @@ pub struct Sink { } impl Sink { - pub(crate) fn try_new(input: Arc, sink_info: Arc) -> DaftResult { + pub(crate) fn try_new( + input: Arc, + sink_info: Arc, + ) -> logical_plan::Result { let schema = input.schema(); - // replace partition columns with resolved columns - let sink_info = match sink_info.as_ref() { - SinkInfo::OutputFileInfo(OutputFileInfo { - root_dir, - file_format, - partition_cols, - compression, - io_config, - }) => { - let expr_resolver = ExprResolver::default(); - - let resolved_partition_cols = partition_cols - .clone() - .map(|cols| { - expr_resolver - .resolve(cols, &schema) - .map(|(resolved_cols, _)| resolved_cols) - }) - .transpose()?; - - Arc::new(SinkInfo::OutputFileInfo(OutputFileInfo { - 
root_dir: root_dir.clone(), - file_format: *file_format, - partition_cols: resolved_partition_cols, - compression: compression.clone(), - io_config: io_config.clone(), - })) - } - #[cfg(feature = "python")] - SinkInfo::CatalogInfo(_) => sink_info, - }; - let fields = match sink_info.as_ref() { SinkInfo::OutputFileInfo(output_file_info) => { let mut fields = vec![Field::new("path", DataType::Utf8)]; diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index 9c2cd046fd..b5196c617c 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{ExprRef, ExprResolver}; +use daft_dsl::{exprs_to_schema, ExprRef}; use itertools::Itertools; use snafu::ResultExt; @@ -32,15 +32,10 @@ impl Sort { .context(CreationSnafu); } - let expr_resolver = ExprResolver::default(); + // TODO(Kevin): make sort by expression names unique so that we can do things like sort(col("a"), col("a") + col("b")) + let sort_by_schema = exprs_to_schema(&sort_by, input.schema())?; - let (sort_by, sort_by_fields) = expr_resolver - .resolve(sort_by, &input.schema()) - .context(CreationSnafu)?; - - let sort_by_resolved_schema = Schema::new(sort_by_fields).context(CreationSnafu)?; - - for (field, expr) in sort_by_resolved_schema.fields.values().zip(sort_by.iter()) { + for (field, expr) in sort_by_schema.fields.values().zip(sort_by.iter()) { // Disallow sorting by null, binary, and boolean columns. // TODO(Clark): This is a port of an existing constraint, we should look at relaxing this. if let dt @ (DataType::Null | DataType::Binary) = &field.dtype { diff --git a/src/daft-logical-plan/src/ops/unpivot.rs b/src/daft-logical-plan/src/ops/unpivot.rs index 293fe1cad6..e6dda83bef 100644 --- a/src/daft-logical-plan/src/ops/unpivot.rs +++ b/src/daft-logical-plan/src/ops/unpivot.rs @@ -1,8 +1,8 @@ use std::sync::Arc; -use common_error::DaftError; +use common_error::{DaftError, DaftResult}; use daft_core::{prelude::*, utils::supertype::try_get_supertype}; -use daft_dsl::{ExprRef, ExprResolver}; +use daft_dsl::ExprRef; use itertools::Itertools; use snafu::ResultExt; @@ -48,8 +48,8 @@ impl Unpivot { input: Arc, ids: Vec, values: Vec, - variable_name: &str, - value_name: &str, + variable_name: String, + value_name: String, ) -> logical_plan::Result { if values.is_empty() { return Err(DaftError::ValueError( @@ -58,40 +58,29 @@ impl Unpivot { .context(CreationSnafu); } - let expr_resolver = ExprResolver::default(); - - let input_schema = input.schema(); - let (values, values_fields) = expr_resolver - .resolve(values, &input_schema) - .context(CreationSnafu)?; - - let value_dtype = values_fields + let value_dtype = values .iter() - .map(|f| f.dtype.clone()) - .try_reduce(|a, b| try_get_supertype(&a, &b)) - .context(CreationSnafu)? 
- .unwrap(); + .map(|expr| Ok(expr.to_field(&input.schema())?.dtype)) + .reduce(|a, b| try_get_supertype(&a?, &b?)) + .unwrap()?; - let variable_field = Field::new(variable_name, DataType::Utf8); - let value_field = Field::new(value_name, value_dtype); + let variable_field = Field::new(&variable_name, DataType::Utf8); + let value_field = Field::new(&value_name, value_dtype); - let (ids, ids_fields) = expr_resolver - .resolve(ids, &input_schema) - .context(CreationSnafu)?; - - let output_fields = ids_fields - .into_iter() - .chain([variable_field, value_field]) - .collect::>(); + let output_fields = ids + .iter() + .map(|id| id.to_field(&input.schema())) + .chain([Ok(variable_field), Ok(value_field)]) + .collect::>>()?; - let output_schema = Schema::new(output_fields).context(CreationSnafu)?.into(); + let output_schema = Schema::new(output_fields)?.into(); Ok(Self { input, ids, values, - variable_name: variable_name.to_string(), - value_name: value_name.to_string(), + variable_name, + value_name, output_schema, stats_state: StatsState::NotMaterialized, }) diff --git a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs index e9e3a2e524..cd192f0df9 100644 --- a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs +++ b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs @@ -303,12 +303,10 @@ fn find_inner_join( if !join_keys.is_empty() { all_join_keys.insert_all(join_keys.iter()); let right_input = rights.remove(i); - let join_schema = left_input - .schema() - .non_distinct_union(right_input.schema().as_ref()); let (left_keys, right_keys) = join_keys.iter().cloned().unzip(); - return Ok(LogicalPlan::Join(Join::new( + + return Ok(LogicalPlan::Join(Join::try_new( left_input, right_input, left_keys, @@ -316,8 +314,7 @@ fn find_inner_join( None, JoinType::Inner, None, - Arc::new(join_schema), - )) + )?) .arced()); } } @@ -325,11 +322,8 @@ fn find_inner_join( // no matching right plan had any join keys, cross join with the first right // plan let right = rights.remove(0); - let join_schema = left_input - .schema() - .non_distinct_union(right.schema().as_ref()); - Ok(LogicalPlan::Join(Join::new( + Ok(LogicalPlan::Join(Join::try_new( left_input, right, vec![], @@ -337,8 +331,7 @@ fn find_inner_join( None, JoinType::Inner, None, - Arc::new(join_schema), - )) + )?) 
.arced()) } diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs index 5039cc9767..18e05a8218 100644 --- a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs +++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs @@ -119,35 +119,33 @@ impl UnnestScalarSubquery { let (decorrelated_subquery, subquery_on, input_on) = pull_up_correlated_cols(subquery_plan)?; - if subquery_on.is_empty() { - // uncorrelated scalar subquery - Ok(Arc::new(LogicalPlan::Join(Join::try_new( - curr_input, - decorrelated_subquery, - vec![], - vec![], - None, - JoinType::Inner, - None, - None, - None, - false, - )?))) + // use inner join when uncorrelated so that filter can be pushed into join and other optimizations + let join_type = if subquery_on.is_empty() { + JoinType::Inner } else { - // correlated scalar subquery - Ok(Arc::new(LogicalPlan::Join(Join::try_new( - curr_input, - decorrelated_subquery, - input_on, - subquery_on, - None, - JoinType::Left, - None, - None, - None, - false, - )?))) - } + JoinType::Left + }; + + let (decorrelated_subquery, subquery_on) = Join::rename_right_columns( + curr_input.clone(), + decorrelated_subquery, + input_on.clone(), + subquery_on, + join_type, + None, + None, + false, + )?; + + Ok(Arc::new(LogicalPlan::Join(Join::try_new( + curr_input, + decorrelated_subquery, + input_on, + subquery_on, + None, + join_type, + None, + )?))) })?; Ok(Transformed::yes((new_input, new_exprs))) @@ -335,9 +333,6 @@ impl OptimizerRule for UnnestPredicateSubquery { None, join_type, None, - None, - None, - false )?))) })?; diff --git a/tests/sql/test_binary_op_exprs.py b/tests/sql/test_binary_op_exprs.py index cfc47efb44..c4a16507c3 100644 --- a/tests/sql/test_binary_op_exprs.py +++ b/tests/sql/test_binary_op_exprs.py @@ -75,20 +75,20 @@ def test_unsupported_div_floor(): _assert_df_op_raise( lambda: df.select(daft.col("A") // daft.col("C")).collect(), - "TypeError Cannot perform floor divide on types: Int64, Boolean", + "Cannot perform floor divide on types: Int64, Boolean", ) _assert_df_op_raise( lambda: df.select(daft.col("C") // daft.col("A")).collect(), - "TypeError Cannot perform floor divide on types: Boolean, Int64", + "Cannot perform floor divide on types: Boolean, Int64", ) _assert_df_op_raise( lambda: df.select(daft.col("B") // daft.col("C")).collect(), - "TypeError Cannot perform floor divide on types: Float64, Boolean", + "Cannot perform floor divide on types: Float64, Boolean", ) _assert_df_op_raise( lambda: df.select(daft.col("C") // daft.col("B")).collect(), - "TypeError Cannot perform floor divide on types: Boolean, Float64", + "Cannot perform floor divide on types: Boolean, Float64", )