# refactor: logical op constructor+builder boundary (#3684)
## The problem
Plan ops are created in many places throughout our code: by the DataFrame and SQL interfaces, by optimization rules, and even by op constructors themselves, which can sometimes create other ops. All of these call sites generally go through the same `new`/`try_new` constructor for each op, which tries to accommodate every use case. This creates complexity, adds unnecessary compute at planning time, and conflates user input errors with Daft internal errors.

For example, I don't expect any optimization rule to create unresolved expressions; expression resolution should only be done in the builder. Another example is the Join op, whose `join_prefix` and `join_suffix` inputs are only applicable to renaming conflicting columns, which should likewise happen only via the builder. We recently added another initializer to some ops for that reason, but it bypasses the validation that is typically done and is not standardized across ops.

## My solution
Every op should provide a `try_new` constructor that contains explicit checks for all the requirements on the op's state (one example: all expression columns must exist in the schema), but otherwise simply puts those values into the struct, unmodified, and returns it.
- Functions such as `LogicalPlan::with_new_children` will just call `try_new`.
- Other constructors/helpers may exist that explicitly provide additional functionality and ultimately call `try_new`. E.g. `Join::rename_right_columns` renames right-side columns that conflict with the left side and is called to update the right-side schema before `try_new`.
- User input normalization, such as expression resolution, should be handled by the logical plan builder. Once a logical plan op has been constructed, everything is in a valid state from there on (see the sketch below).
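
To make the boundary concrete, here is a minimal, self-contained Rust sketch of the pattern. All of the types in it (`Expr`, `Schema`, `LogicalPlan`, `Filter`, `LogicalPlanBuilder`) are toy stand-ins rather than Daft's actual definitions, and the builder's normalization step is a deliberately simplified stand-in for expression resolution:

```rust
use std::sync::Arc;

// Toy stand-ins for Daft's real types (illustrative only).
#[derive(Debug)]
enum Expr {
    Column(String),
}

#[derive(Debug)]
struct Schema {
    fields: Vec<String>,
}

#[derive(Debug)]
enum LogicalPlan {
    Source { schema: Schema },
    Filter(Filter),
}

impl LogicalPlan {
    fn schema(&self) -> &Schema {
        match self {
            Self::Source { schema } => schema,
            Self::Filter(f) => f.input.schema(),
        }
    }
}

#[derive(Debug)]
struct Filter {
    input: Arc<LogicalPlan>,
    predicate: Expr,
}

impl Filter {
    /// `try_new` only checks the op's invariants (here: the predicate column
    /// must exist in the input schema) and stores the values unchanged.
    fn try_new(input: Arc<LogicalPlan>, predicate: Expr) -> Result<Self, String> {
        let Expr::Column(name) = &predicate;
        if !input.schema().fields.iter().any(|f| f == name) {
            return Err(format!("column `{name}` not found in input schema"));
        }
        Ok(Self { input, predicate })
    }
}

struct LogicalPlanBuilder {
    plan: Arc<LogicalPlan>,
}

impl LogicalPlanBuilder {
    /// The builder normalizes user input (here, trimming stands in for
    /// expression resolution) *before* calling `try_new`, so every
    /// constructed op is already in a valid state.
    fn filter(&self, user_column: &str) -> Result<Self, String> {
        let predicate = Expr::Column(user_column.trim().to_string());
        let op = Filter::try_new(self.plan.clone(), predicate)?;
        Ok(Self {
            plan: Arc::new(LogicalPlan::Filter(op)),
        })
    }
}

fn main() -> Result<(), String> {
    let source = Arc::new(LogicalPlan::Source {
        schema: Schema {
            fields: vec!["a".into(), "b".into()],
        },
    });
    let builder = LogicalPlanBuilder { plan: source };

    // User input flows through the builder, which normalizes it first.
    let filtered = builder.filter(" a ")?;
    println!("{:?}", filtered.plan);

    // Internal callers (e.g. a `with_new_children`-style helper) call
    // `try_new` directly; a violated invariant is a hard internal error.
    assert!(Filter::try_new(filtered.plan.clone(), Expr::Column("zzz".into())).is_err());
    Ok(())
}
```

The key property is that `try_new` never mutates or silently fixes its inputs: anything that needs normalization happens one layer up in the builder, and anything that reaches `try_new` in an invalid state is reported as an error rather than patched over.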
kevinzwang authored Jan 16, 2025
1 parent beae462 commit 3720c2a
Showing 29 changed files with 345 additions and 450 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 0 additions & 2 deletions src/daft-dsl/Cargo.toml
@@ -10,10 +10,8 @@ daft-sketch = {path = "../daft-sketch", default-features = false}
derive_more = {workspace = true}
indexmap = {workspace = true}
itertools = {workspace = true}
-log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true}
-typed-builder = {workspace = true}
typetag = {workspace = true}

[features]
8 changes: 8 additions & 0 deletions src/daft-dsl/src/expr/mod.rs
@@ -1435,3 +1435,11 @@ pub fn estimated_selectivity(expr: &Expr, schema: &Schema) -> f64 {
Expr::Agg(_) => panic!("Aggregates are not allowed in WHERE clauses"),
}
}

pub fn exprs_to_schema(exprs: &[ExprRef], input_schema: SchemaRef) -> DaftResult<SchemaRef> {
let fields = exprs
.iter()
.map(|e| e.to_field(&input_schema))
.collect::<DaftResult<_>>()?;
Ok(Arc::new(Schema::new(fields)?))
}
12 changes: 3 additions & 9 deletions src/daft-dsl/src/lib.rs
@@ -11,18 +11,16 @@ pub mod optimization;
mod pyobj_serde;
#[cfg(feature = "python")]
pub mod python;
-mod resolve_expr;
mod treenode;
pub use common_treenode;
pub use expr::{
-binary_op, col, count_actor_pool_udfs, estimated_selectivity, has_agg, is_actor_pool_udf,
-is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator,
-OuterReferenceColumn, SketchType, Subquery, SubqueryPlan,
+binary_op, col, count_actor_pool_udfs, estimated_selectivity, exprs_to_schema, has_agg,
+is_actor_pool_udf, is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef,
+Operator, OuterReferenceColumn, SketchType, Subquery, SubqueryPlan,
};
pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
use pyo3::prelude::*;
-pub use resolve_expr::{check_column_name_validity, ExprResolver};

#[cfg(feature = "python")]
pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
@@ -41,10 +39,6 @@ pub fn register_modules(parent: &Bound<PyModule>) -> PyResult<()> {
parent.add_function(wrap_pyfunction!(python::initialize_udfs, parent)?)?;
parent.add_function(wrap_pyfunction!(python::get_udf_names, parent)?)?;
parent.add_function(wrap_pyfunction!(python::eq, parent)?)?;
-parent.add_function(wrap_pyfunction!(
-python::check_column_name_validity,
-parent
-)?)?;

Ok(())
}
5 changes: 0 additions & 5 deletions src/daft-dsl/src/python.rs
@@ -257,11 +257,6 @@ pub fn eq(expr1: &PyExpr, expr2: &PyExpr) -> PyResult<bool> {
Ok(expr1.expr == expr2.expr)
}

-#[pyfunction]
-pub fn check_column_name_validity(name: &str, schema: &PySchema) -> PyResult<()> {
-Ok(crate::check_column_name_validity(name, &schema.schema)?)
-}

#[derive(FromPyObject)]
pub enum ApproxPercentileInput {
Single(f64),
1 change: 1 addition & 0 deletions src/daft-logical-plan/Cargo.toml
@@ -21,6 +21,7 @@ log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true, features = ["rc"]}
snafu = {workspace = true}
typed-builder = {workspace = true}
uuid = {version = "1", features = ["v4"]}

[dev-dependencies]
@@ -1,3 +1,7 @@
mod resolve_expr;
#[cfg(test)]
mod tests;

use std::{
collections::{HashMap, HashSet},
sync::Arc,
@@ -12,6 +16,10 @@ use common_scan_info::{PhysicalScanInfo, Pushdowns, ScanOperatorRef};
use daft_core::join::{JoinStrategy, JoinType};
use daft_dsl::{col, ExprRef};
use daft_schema::schema::{Schema, SchemaRef};
use indexmap::IndexSet;
#[cfg(feature = "python")]
pub use resolve_expr::py_check_column_name_validity;
use resolve_expr::ExprResolver;
#[cfg(feature = "python")]
use {
crate::sink_info::{CatalogInfo, IcebergCatalogInfo},
@@ -188,11 +196,19 @@ impl LogicalPlanBuilder {
}

pub fn select(&self, to_select: Vec<ExprRef>) -> DaftResult<Self> {
let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build();

let (to_select, _) = expr_resolver.resolve(to_select, &self.schema())?;

let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), to_select)?.into();
Ok(self.with_new_plan(logical_plan))
}

pub fn with_columns(&self, columns: Vec<ExprRef>) -> DaftResult<Self> {
let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build();

let (columns, _) = expr_resolver.resolve(columns, &self.schema())?;

let fields = &self.schema().fields;
let current_col_names = fields
.iter()
@@ -245,6 +261,10 @@ impl LogicalPlanBuilder {
}

pub fn filter(&self, predicate: ExprRef) -> DaftResult<Self> {
let expr_resolver = ExprResolver::default();

let (predicate, _) = expr_resolver.resolve_single(predicate, &self.schema())?;

let logical_plan: LogicalPlan = ops::Filter::try_new(self.plan.clone(), predicate)?.into();
Ok(self.with_new_plan(logical_plan))
}
@@ -255,6 +275,10 @@
}

pub fn explode(&self, to_explode: Vec<ExprRef>) -> DaftResult<Self> {
let expr_resolver = ExprResolver::default();

let (to_explode, _) = expr_resolver.resolve(to_explode, &self.schema())?;

let logical_plan: LogicalPlan =
ops::Explode::try_new(self.plan.clone(), to_explode)?.into();
Ok(self.with_new_plan(logical_plan))
@@ -264,25 +288,24 @@
&self,
ids: Vec<ExprRef>,
values: Vec<ExprRef>,
-variable_name: &str,
-value_name: &str,
+variable_name: String,
+value_name: String,
) -> DaftResult<Self> {
let expr_resolver = ExprResolver::default();
let (values, _) = expr_resolver.resolve(values, &self.schema())?;
let (ids, _) = expr_resolver.resolve(ids, &self.schema())?;

let values = if values.is_empty() {
-let ids_set = HashSet::<_>::from_iter(ids.iter());
+let ids_set = IndexSet::<_>::from_iter(ids.iter().cloned());

-self.schema()
+let columns_set = self
+.schema()
.fields
-.iter()
-.filter_map(|(name, _)| {
-let column = col(name.clone());
-
-if ids_set.contains(&column) {
-None
-} else {
-Some(column)
-}
-})
-.collect()
+.keys()
+.map(|name| col(name.clone()))
+.collect::<IndexSet<_>>();
+
+columns_set.difference(&ids_set).cloned().collect()
} else {
values
};
@@ -299,6 +322,10 @@
descending: Vec<bool>,
nulls_first: Vec<bool>,
) -> DaftResult<Self> {
let expr_resolver = ExprResolver::default();

let (sort_by, _) = expr_resolver.resolve(sort_by, &self.schema())?;

let logical_plan: LogicalPlan =
ops::Sort::try_new(self.plan.clone(), sort_by, descending, nulls_first)?.into();
Ok(self.with_new_plan(logical_plan))
@@ -309,28 +336,32 @@
num_partitions: Option<usize>,
partition_by: Vec<ExprRef>,
) -> DaftResult<Self> {
-let logical_plan: LogicalPlan = ops::Repartition::try_new(
+let expr_resolver = ExprResolver::default();
+
+let (partition_by, _) = expr_resolver.resolve(partition_by, &self.schema())?;
+
+let logical_plan: LogicalPlan = ops::Repartition::new(
self.plan.clone(),
RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)),
-)?
+)
.into();
Ok(self.with_new_plan(logical_plan))
}

pub fn random_shuffle(&self, num_partitions: Option<usize>) -> DaftResult<Self> {
-let logical_plan: LogicalPlan = ops::Repartition::try_new(
+let logical_plan: LogicalPlan = ops::Repartition::new(
self.plan.clone(),
RepartitionSpec::Random(RandomShuffleConfig::new(num_partitions)),
-)?
+)
.into();
Ok(self.with_new_plan(logical_plan))
}

pub fn into_partitions(&self, num_partitions: usize) -> DaftResult<Self> {
-let logical_plan: LogicalPlan = ops::Repartition::try_new(
+let logical_plan: LogicalPlan = ops::Repartition::new(
self.plan.clone(),
RepartitionSpec::IntoPartitions(IntoPartitionsConfig::new(num_partitions)),
-)?
+)
.into();
Ok(self.with_new_plan(logical_plan))
}
@@ -356,6 +387,12 @@
agg_exprs: Vec<ExprRef>,
groupby_exprs: Vec<ExprRef>,
) -> DaftResult<Self> {
let groupby_resolver = ExprResolver::default();
let (groupby_exprs, _) = groupby_resolver.resolve(groupby_exprs, &self.schema())?;

let agg_resolver = ExprResolver::builder().groupby(&groupby_exprs).build();
let (agg_exprs, _) = agg_resolver.resolve(agg_exprs, &self.schema())?;

let logical_plan: LogicalPlan =
ops::Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into();
Ok(self.with_new_plan(logical_plan))
@@ -369,6 +406,14 @@
agg_expr: ExprRef,
names: Vec<String>,
) -> DaftResult<Self> {
let agg_resolver = ExprResolver::builder().groupby(&group_by).build();
let (agg_expr, _) = agg_resolver.resolve_single(agg_expr, &self.schema())?;

let expr_resolver = ExprResolver::default();
let (group_by, _) = expr_resolver.resolve(group_by, &self.schema())?;
let (pivot_column, _) = expr_resolver.resolve_single(pivot_column, &self.schema())?;
let (value_column, _) = expr_resolver.resolve_single(value_column, &self.schema())?;

let pivot_logical_plan: LogicalPlan = ops::Pivot::try_new(
self.plan.clone(),
group_by,
@@ -438,17 +483,36 @@
join_prefix: Option<&str>,
keep_join_keys: bool,
) -> DaftResult<Self> {
let left_plan = self.plan.clone();
let right_plan = right.into();

let expr_resolver = ExprResolver::default();

let (left_on, _) = expr_resolver.resolve(left_on, &left_plan.schema())?;
let (right_on, _) = expr_resolver.resolve(right_on, &right_plan.schema())?;

// TODO(kevin): we should do this, but it has not been properly used before and is nondeterministic, which causes some tests to break
// let (left_on, right_on) = ops::Join::rename_join_keys(left_on, right_on);

let (right_plan, right_on) = ops::Join::rename_right_columns(
left_plan.clone(),
right_plan,
left_on.clone(),
right_on,
join_type,
join_suffix,
join_prefix,
keep_join_keys,
)?;

let logical_plan: LogicalPlan = ops::Join::try_new(
-self.plan.clone(),
-right.into(),
+left_plan,
+right_plan,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
-join_suffix,
-join_prefix,
-keep_join_keys,
)?
.into();
Ok(self.with_new_plan(logical_plan))
@@ -501,7 +565,7 @@

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
-ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into();
+ops::MonotonicallyIncreasingId::try_new(self.plan.clone(), column_name)?.into();
Ok(self.with_new_plan(logical_plan))
}

@@ -513,6 +577,16 @@
compression: Option<String>,
io_config: Option<IOConfig>,
) -> DaftResult<Self> {
let partition_cols = partition_cols
.map(|cols| {
let expr_resolver = ExprResolver::default();

expr_resolver
.resolve(cols, &self.schema())
.map(|(resolved_cols, _)| resolved_cols)
})
.transpose()?;

let sink_info = SinkInfo::OutputFileInfo(OutputFileInfo::new(
root_dir.into(),
file_format,
@@ -752,8 +826,8 @@ impl PyLogicalPlanBuilder {
&self,
ids: Vec<PyExpr>,
values: Vec<PyExpr>,
-variable_name: &str,
-value_name: &str,
+variable_name: String,
+value_name: String,
) -> PyResult<Self> {
let ids_exprs = ids
.iter()