From 64b8699617c3842106d064cc753777c4d143d533 Mon Sep 17 00:00:00 2001 From: Raunak Bhagat Date: Tue, 8 Oct 2024 15:13:00 -0700 Subject: [PATCH] [FEAT] Implement standard deviation (#3005) # Overview - Add a standard deviation function - similar in implementation to how `AggExpr::count` and `AggExpr::Mean` work ## Notes Implementations differ slightly for non- vs multi- partitioned based dataframes: 1. The non-partitioned implementation uses the simple, naive approach, derived from definition of stddev (i.e., `stddev(X) = sqrt(sum((x_i - mean(X))^2) / N)`). 2. The multi-partitioned implementation calculates `stddev(X) = sqrt(E(X^2) - E(X)^2)`. --- daft/daft/__init__.pyi | 2 + daft/dataframe/dataframe.py | 55 +++ daft/expressions/expressions.py | 5 + daft/series.py | 4 + docs/source/api_docs/dataframe.rst | 1 + docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/array/ops/mean.rs | 49 +- src/daft-core/src/array/ops/mod.rs | 7 + src/daft-core/src/array/ops/stddev.rs | 34 ++ src/daft-core/src/datatypes/agg_ops.rs | 2 +- src/daft-core/src/datatypes/mod.rs | 2 +- src/daft-core/src/series/ops/agg.rs | 60 +-- src/daft-core/src/utils/mod.rs | 1 + src/daft-core/src/utils/stats.rs | 82 ++++ .../src/{arithmetic.rs => arithmetic/mod.rs} | 23 +- src/daft-dsl/src/arithmetic/tests.rs | 16 + src/daft-dsl/src/{expr.rs => expr/mod.rs} | 427 +++++++----------- src/daft-dsl/src/expr/tests.rs | 83 ++++ src/daft-dsl/src/functions/map/mod.rs | 3 +- src/daft-dsl/src/functions/mod.rs | 18 +- .../src/functions/partitioning/mod.rs | 13 +- src/daft-dsl/src/functions/python/mod.rs | 8 +- src/daft-dsl/src/functions/python/udf.rs | 2 +- src/daft-dsl/src/functions/sketch/mod.rs | 3 +- src/daft-dsl/src/functions/struct_/mod.rs | 3 +- src/daft-dsl/src/functions/utf8/mod.rs | 57 ++- src/daft-dsl/src/{join.rs => join/mod.rs} | 34 +- src/daft-dsl/src/join/tests.rs | 27 ++ src/daft-dsl/src/lit.rs | 179 ++++---- src/daft-dsl/src/python.rs | 4 + .../{resolve_expr.rs => resolve_expr/mod.rs} | 199 ++------ src/daft-dsl/src/resolve_expr/tests.rs | 141 ++++++ src/daft-functions/src/list/mean.rs | 4 +- src/daft-plan/src/logical_ops/project.rs | 28 +- .../src/physical_planner/translate.rs | 203 ++++++--- src/daft-schema/src/dtype.rs | 12 + src/daft-sql/src/modules/aggs.rs | 4 + src/daft-table/src/lib.rs | 9 +- tests/dataframe/test_stddev.py | 144 ++++++ 39 files changed, 1190 insertions(+), 759 deletions(-) create mode 100644 src/daft-core/src/array/ops/stddev.rs create mode 100644 src/daft-core/src/utils/stats.rs rename src/daft-dsl/src/{arithmetic.rs => arithmetic/mod.rs} (57%) create mode 100644 src/daft-dsl/src/arithmetic/tests.rs rename src/daft-dsl/src/{expr.rs => expr/mod.rs} (77%) create mode 100644 src/daft-dsl/src/expr/tests.rs rename src/daft-dsl/src/{join.rs => join/mod.rs} (79%) create mode 100644 src/daft-dsl/src/join/tests.rs rename src/daft-dsl/src/{resolve_expr.rs => resolve_expr/mod.rs} (68%) create mode 100644 src/daft-dsl/src/resolve_expr/tests.rs create mode 100644 tests/dataframe/test_stddev.py diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 4f465eb5fe..47e980e907 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1051,6 +1051,7 @@ class PyExpr: def approx_count_distinct(self) -> PyExpr: ... def approx_percentiles(self, percentiles: float | list[float]) -> PyExpr: ... def mean(self) -> PyExpr: ... + def stddev(self) -> PyExpr: ... def min(self) -> PyExpr: ... def max(self) -> PyExpr: ... def any_value(self, ignore_nulls: bool) -> PyExpr: ... @@ -1336,6 +1337,7 @@ class PySeries: def count(self, mode: CountMode) -> PySeries: ... def sum(self) -> PySeries: ... def mean(self) -> PySeries: ... + def stddev(self) -> PySeries: ... def min(self) -> PySeries: ... def max(self) -> PySeries: ... def agg_list(self) -> PySeries: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 6211423e94..2408890d7b 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2118,6 +2118,33 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": """ return self._apply_agg_fn(Expression.mean, cols) + @DataframePublicAPI + def stddev(self, *cols: ColumnInputType) -> "DataFrame": + """Performs a global standard deviation on the DataFrame + + Example: + >>> import daft + >>> df = daft.from_pydict({"col_a":[0,1,2]}) + >>> df = df.stddev("col_a") + >>> df.show() + ╭───────────────────╮ + │ col_a │ + │ --- │ + │ Float64 │ + ╞═══════════════════╡ + │ 0.816496580927726 │ + ╰───────────────────╯ + + (Showing first 1 of 1 rows) + + + Args: + *cols (Union[str, Expression]): columns to stddev + Returns: + DataFrame: Globally aggregated standard deviation. Should be a single row. + """ + return self._apply_agg_fn(Expression.stddev, cols) + @DataframePublicAPI def min(self, *cols: ColumnInputType) -> "DataFrame": """Performs a global min on the DataFrame @@ -2856,6 +2883,34 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame": """ return self.df._apply_agg_fn(Expression.mean, cols, self.group_by) + def stddev(self, *cols: ColumnInputType) -> "DataFrame": + """Performs grouped standard deviation on this GroupedDataFrame. + + Example: + >>> import daft + >>> df = daft.from_pydict({"keys": ["a", "a", "a", "b"], "col_a": [0,1,2,100]}) + >>> df = df.groupby("keys").stddev() + >>> df.show() + ╭──────┬───────────────────╮ + │ keys ┆ col_a │ + │ --- ┆ --- │ + │ Utf8 ┆ Float64 │ + ╞══════╪═══════════════════╡ + │ a ┆ 0.816496580927726 │ + ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ b ┆ 0 │ + ╰──────┴───────────────────╯ + + (Showing first 2 of 2 rows) + + Args: + *cols (Union[str, Expression]): columns to stddev + + Returns: + DataFrame: DataFrame with grouped standard deviation. + """ + return self.df._apply_agg_fn(Expression.stddev, cols, self.group_by) + def min(self, *cols: ColumnInputType) -> "DataFrame": """Perform grouped min on this GroupedDataFrame. diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 52d1cae0d8..e67ddead64 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -862,6 +862,11 @@ def mean(self) -> Expression: expr = self._expr.mean() return Expression._from_pyexpr(expr) + def stddev(self) -> Expression: + """Calculates the standard deviation of the values in the expression""" + expr = self._expr.stddev() + return Expression._from_pyexpr(expr) + def min(self) -> Expression: """Calculates the minimum value in the expression""" expr = self._expr.min() diff --git a/daft/series.py b/daft/series.py index 15c5295b4c..5cbcfe7ba0 100644 --- a/daft/series.py +++ b/daft/series.py @@ -512,6 +512,10 @@ def mean(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.mean()) + def stddev(self) -> Series: + assert self._series is not None + return Series._from_pyseries(self._series.stddev()) + def sum(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.sum()) diff --git a/docs/source/api_docs/dataframe.rst b/docs/source/api_docs/dataframe.rst index f93f052742..14a4e9fa20 100644 --- a/docs/source/api_docs/dataframe.rst +++ b/docs/source/api_docs/dataframe.rst @@ -104,6 +104,7 @@ Aggregations DataFrame.groupby DataFrame.sum DataFrame.mean + DataFrame.stddev DataFrame.count DataFrame.min DataFrame.max diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index ec86e0bb5e..a53ef825fd 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -113,6 +113,7 @@ The following can be used with DataFrame.agg or GroupedDataFrame.agg Expression.count Expression.sum Expression.mean + Expression.stddev Expression.min Expression.max Expression.any_value diff --git a/src/daft-core/src/array/ops/mean.rs b/src/daft-core/src/array/ops/mean.rs index b4b4016bbc..d5764c4954 100644 --- a/src/daft-core/src/array/ops/mean.rs +++ b/src/daft-core/src/array/ops/mean.rs @@ -1,44 +1,27 @@ use std::sync::Arc; +use arrow2::array::PrimitiveArray; use common_error::DaftResult; -use super::{as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable}; -use crate::{array::ops::GroupIndices, count_mode::CountMode, datatypes::*}; -impl DaftMeanAggable for &DataArray { - type Output = DaftResult>; +use crate::{ + array::ops::{DaftMeanAggable, GroupIndices}, + datatypes::*, + utils::stats, +}; - fn mean(&self) -> Self::Output { - let sum_value = DaftSumAggable::sum(self)?.as_arrow().value(0); - let count_value = DaftCountAggable::count(self, CountMode::Valid)? - .as_arrow() - .value(0); - - let result = match count_value { - 0 => None, - count_value => Some(sum_value / count_value as f64), - }; - let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result])); +impl DaftMeanAggable for DataArray { + type Output = DaftResult; - DataArray::new( - Arc::new(Field::new(self.field.name.clone(), DataType::Float64)), - arrow_array, - ) + fn mean(&self) -> Self::Output { + let stats = stats::calculate_stats(self)?; + let data = PrimitiveArray::from([stats.mean]).boxed(); + let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64)); + Self::new(field, data) } fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output { - use arrow2::array::PrimitiveArray; - let sum_values = self.grouped_sum(groups)?; - let count_values = self.grouped_count(groups, CountMode::Valid)?; - assert_eq!(sum_values.len(), count_values.len()); - let mean_per_group = sum_values - .as_arrow() - .values_iter() - .zip(count_values.as_arrow().values_iter()) - .map(|(s, c)| match (s, c) { - (_, 0) => None, - (s, c) => Some(s / (*c as f64)), - }); - let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group)); - Ok(DataArray::from((self.field.name.as_ref(), mean_array))) + let grouped_means = stats::grouped_stats(self, groups)?.map(|(stats, _)| stats.mean); + let data = Box::new(PrimitiveArray::from_iter(grouped_means)); + Ok(Self::from((self.field.name.as_ref(), data))) } } diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index d3a940f376..3bcf0f0cb9 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -49,6 +49,7 @@ mod sketch_percentile; mod sort; pub(crate) mod sparse_tensor; mod sqrt; +mod stddev; mod struct_; mod sum; mod take; @@ -189,6 +190,12 @@ pub trait DaftMeanAggable { fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output; } +pub trait DaftStddevAggable { + type Output; + fn stddev(&self) -> Self::Output; + fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output; +} + pub trait DaftCompareAggable { type Output; fn min(&self) -> Self::Output; diff --git a/src/daft-core/src/array/ops/stddev.rs b/src/daft-core/src/array/ops/stddev.rs new file mode 100644 index 0000000000..c412922937 --- /dev/null +++ b/src/daft-core/src/array/ops/stddev.rs @@ -0,0 +1,34 @@ +use arrow2::array::PrimitiveArray; +use common_error::DaftResult; + +use crate::{ + array::{ + ops::{DaftStddevAggable, GroupIndices}, + DataArray, + }, + datatypes::Float64Type, + utils::stats, +}; + +impl DaftStddevAggable for DataArray { + type Output = DaftResult; + + fn stddev(&self) -> Self::Output { + let stats = stats::calculate_stats(self)?; + let values = self.into_iter().flatten().copied(); + let stddev = stats::calculate_stddev(stats, values); + let field = self.field.clone(); + let data = PrimitiveArray::::from([stddev]).boxed(); + Self::new(field, data) + } + + fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output { + let grouped_stddevs_iter = stats::grouped_stats(self, groups)?.map(|(stats, group)| { + let values = group.iter().filter_map(|&index| self.get(index as _)); + stats::calculate_stddev(stats, values) + }); + let field = self.field.clone(); + let data = PrimitiveArray::::from_iter(grouped_stddevs_iter).boxed(); + Self::new(field, data) + } +} diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs index a6420b039b..c1f04fecbe 100644 --- a/src/daft-core/src/datatypes/agg_ops.rs +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -23,7 +23,7 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { } /// Get the data type that the mean of a column of the given data type should be casted to. -pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { +pub fn try_mean_stddev_aggregation_supertype(dtype: &DataType) -> DaftResult { if dtype.is_numeric() { Ok(DataType::Float64) } else { diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 174098ada9..01a6b6ca6e 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -6,7 +6,7 @@ pub use infer_datatype::InferDataType; pub mod prelude; use std::ops::{Add, Div, Mul, Rem, Sub}; -pub use agg_ops::{try_mean_supertype, try_sum_supertype}; +pub use agg_ops::{try_mean_stddev_aggregation_supertype, try_sum_supertype}; use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index 4af93850ca..b3bfee765c 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -4,7 +4,10 @@ use logical::Decimal128Array; use crate::{ array::{ - ops::{DaftHllMergeAggable, GroupIndices}, + ops::{ + DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable, + DaftSumAggable, GroupIndices, + }, ListArray, }, count_mode::CountMode, @@ -26,12 +29,10 @@ impl Series { } pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftSumAggable, datatypes::DataType::*}; - match self.data_type() { // intX -> int64 (in line with numpy) - Int8 | Int16 | Int32 | Int64 => { - let casted = self.cast(&Int64)?; + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + let casted = self.cast(&DataType::Int64)?; match groups { Some(groups) => { Ok(DaftSumAggable::grouped_sum(&casted.i64()?, groups)?.into_series()) @@ -40,8 +41,8 @@ impl Series { } } // uintX -> uint64 (in line with numpy) - UInt8 | UInt16 | UInt32 | UInt64 => { - let casted = self.cast(&UInt64)?; + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + let casted = self.cast(&DataType::UInt64)?; match groups { Some(groups) => { Ok(DaftSumAggable::grouped_sum(&casted.u64()?, groups)?.into_series()) @@ -50,7 +51,7 @@ impl Series { } } // floatX -> floatX (in line with numpy) - Float32 => match groups { + DataType::Float32 => match groups { Some(groups) => Ok(DaftSumAggable::grouped_sum( &self.downcast::()?, groups, @@ -58,7 +59,7 @@ impl Series { .into_series()), None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), }, - Float64 => match groups { + DataType::Float64 => match groups { Some(groups) => Ok(DaftSumAggable::grouped_sum( &self.downcast::()?, groups, @@ -66,7 +67,7 @@ impl Series { .into_series()), None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), }, - Decimal128(_, _) => match groups { + DataType::Decimal128(_, _) => match groups { Some(groups) => Ok(Decimal128Array::new( Field { dtype: try_sum_supertype(self.data_type())?, @@ -95,12 +96,10 @@ impl Series { } pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftApproxSketchAggable, datatypes::DataType::*}; - // Upcast all numeric types to float64 and compute approx_sketch. match self.data_type() { dt if dt.is_numeric() => { - let casted = self.cast(&Float64)?; + let casted = self.cast(&DataType::Float64)?; match groups { Some(groups) => Ok(DaftApproxSketchAggable::grouped_approx_sketch( &casted.f64()?, @@ -149,24 +148,25 @@ impl Series { } pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult { - use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*}; - // Upcast all numeric types to float64 and use f64 mean kernel. - match self.data_type() { - dt if dt.is_numeric() => { - let casted = self.cast(&Float64)?; - match groups { - Some(groups) => { - Ok(DaftMeanAggable::grouped_mean(&casted.f64()?, groups)?.into_series()) - } - None => Ok(DaftMeanAggable::mean(&casted.f64()?)?.into_series()), - } - } - other => Err(DaftError::TypeError(format!( - "Numeric mean is not implemented for type {}", - other - ))), - } + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.mean(), |groups| casted.grouped_mean(groups))? + .into_series(); + Ok(series) + } + + pub fn stddev(&self, groups: Option<&GroupIndices>) -> DaftResult { + // Upcast all numeric types to float64 and use f64 stddev kernel. + self.data_type().assert_is_numeric()?; + let casted = self.cast(&DataType::Float64)?; + let casted = casted.f64()?; + let series = groups + .map_or_else(|| casted.stddev(), |groups| casted.grouped_stddev(groups))? + .into_series(); + Ok(series) } pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult { diff --git a/src/daft-core/src/utils/mod.rs b/src/daft-core/src/utils/mod.rs index 2e039e6953..baf1dc66fd 100644 --- a/src/daft-core/src/utils/mod.rs +++ b/src/daft-core/src/utils/mod.rs @@ -2,4 +2,5 @@ pub mod arrow; pub mod display; pub mod dyn_compare; pub mod identity_hash_set; +pub mod stats; pub mod supertype; diff --git a/src/daft-core/src/utils/stats.rs b/src/daft-core/src/utils/stats.rs new file mode 100644 index 0000000000..de43b186ea --- /dev/null +++ b/src/daft-core/src/utils/stats.rs @@ -0,0 +1,82 @@ +use common_error::DaftResult; + +use crate::{ + array::{ + ops::{DaftCountAggable, DaftSumAggable, GroupIndices, VecIndices}, + prelude::{Float64Array, UInt64Array}, + }, + count_mode::CountMode, +}; + +#[derive(Clone, Copy, Default, Debug)] +pub struct Stats { + pub sum: f64, + pub count: f64, + pub mean: Option, +} + +pub fn calculate_stats(array: &Float64Array) -> DaftResult { + let sum = array.sum()?.get(0); + let count = array.count(CountMode::Valid)?.get(0); + let stats = sum + .zip(count) + .map_or_else(Default::default, |(sum, count)| Stats { + sum, + count: count as _, + mean: calculate_mean(sum, count), + }); + Ok(stats) +} + +pub fn grouped_stats<'a>( + array: &Float64Array, + groups: &'a GroupIndices, +) -> DaftResult> { + let grouped_sum = array.grouped_sum(groups)?; + let grouped_count = array.grouped_count(groups, CountMode::Valid)?; + debug_assert_eq!(grouped_sum.len(), grouped_count.len()); + debug_assert_eq!(grouped_sum.len(), groups.len()); + Ok(GroupedStats { + grouped_sum, + grouped_count, + groups: groups.iter().enumerate(), + }) +} + +struct GroupedStats<'a, I: Iterator> { + grouped_sum: Float64Array, + grouped_count: UInt64Array, + groups: I, +} + +impl<'a, I: Iterator> Iterator for GroupedStats<'a, I> { + type Item = (Stats, &'a VecIndices); + + fn next(&mut self) -> Option { + let (index, group) = self.groups.next()?; + let sum = self.grouped_sum.get(index); + let count = self.grouped_count.get(index); + let stats = sum + .zip(count) + .map_or_else(Default::default, |(sum, count)| Stats { + sum, + count: count as _, + mean: calculate_mean(sum, count), + }); + Some((stats, group)) + } +} + +pub fn calculate_mean(sum: f64, count: u64) -> Option { + match count { + 0 => None, + _ => Some(sum / count as f64), + } +} + +pub fn calculate_stddev(stats: Stats, values: impl Iterator) -> Option { + stats.mean.map(|mean| { + let sum_of_squares = values.map(|value| (value - mean).powi(2)).sum::(); + (sum_of_squares / stats.count).sqrt() + }) +} diff --git a/src/daft-dsl/src/arithmetic.rs b/src/daft-dsl/src/arithmetic/mod.rs similarity index 57% rename from src/daft-dsl/src/arithmetic.rs rename to src/daft-dsl/src/arithmetic/mod.rs index 95faa64074..d4222fe64c 100644 --- a/src/daft-dsl/src/arithmetic.rs +++ b/src/daft-dsl/src/arithmetic/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use crate::{Expr, ExprRef, Operator}; macro_rules! impl_expr_op { @@ -21,23 +24,3 @@ impl_expr_op!(sub, Minus); impl_expr_op!(mul, Multiply); impl_expr_op!(div, TrueDivide); impl_expr_op!(rem, Modulus); - -#[cfg(test)] -mod tests { - use common_error::{DaftError, DaftResult}; - - use crate::{col, Expr}; - - #[test] - fn check_add_expr_type() -> DaftResult<()> { - let a = col("a"); - let b = col("b"); - let c = a.add(b); - match c.as_ref() { - Expr::BinaryOp { .. } => Ok(()), - other => Err(DaftError::ValueError(format!( - "expected expression to be a binary op expression, got {other:?}" - ))), - } - } -} diff --git a/src/daft-dsl/src/arithmetic/tests.rs b/src/daft-dsl/src/arithmetic/tests.rs new file mode 100644 index 0000000000..19a7c23310 --- /dev/null +++ b/src/daft-dsl/src/arithmetic/tests.rs @@ -0,0 +1,16 @@ +use common_error::{DaftError, DaftResult}; + +use crate::{col, Expr}; + +#[test] +fn check_add_expr_type() -> DaftResult<()> { + let a = col("a"); + let b = col("b"); + let c = a.add(b); + match c.as_ref() { + Expr::BinaryOp { .. } => Ok(()), + other => Err(DaftError::ValueError(format!( + "expected expression to be a binary op expression, got {other:?}" + ))), + } +} diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr/mod.rs similarity index 77% rename from src/daft-dsl/src/expr.rs rename to src/daft-dsl/src/expr/mod.rs index 55a16bc374..873f9013bd 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use std::{ io::{self, Write}, sync::Arc, @@ -7,7 +10,7 @@ use common_error::{DaftError, DaftResult}; use common_hashable_float_wrapper::FloatWrapper; use common_treenode::TreeNode; use daft_core::{ - datatypes::{try_mean_supertype, try_sum_supertype, InferDataType}, + datatypes::{try_mean_stddev_aggregation_supertype, try_sum_supertype, InferDataType}, prelude::*, utils::supertype::try_get_supertype, }; @@ -121,6 +124,9 @@ pub enum AggExpr { #[display("mean({_0})")] Mean(ExprRef), + #[display("stddev({_0})")] + Stddev(ExprRef), + #[display("min({_0})")] Min(ExprRef), @@ -159,36 +165,35 @@ pub fn binary_op(op: Operator, left: ExprRef, right: ExprRef) -> ExprRef { impl AggExpr { pub fn name(&self) -> &str { - use AggExpr::*; match self { - Count(expr, ..) - | Sum(expr) - | ApproxPercentile(ApproxPercentileParams { child: expr, .. }) - | ApproxCountDistinct(expr) - | ApproxSketch(expr, _) - | MergeSketch(expr, _) - | Mean(expr) - | Min(expr) - | Max(expr) - | AnyValue(expr, _) - | List(expr) - | Concat(expr) => expr.name(), - MapGroups { func: _, inputs } => inputs.first().unwrap().name(), + Self::Count(expr, ..) + | Self::Sum(expr) + | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) + | Self::ApproxCountDistinct(expr) + | Self::ApproxSketch(expr, _) + | Self::MergeSketch(expr, _) + | Self::Mean(expr) + | Self::Stddev(expr) + | Self::Min(expr) + | Self::Max(expr) + | Self::AnyValue(expr, _) + | Self::List(expr) + | Self::Concat(expr) => expr.name(), + Self::MapGroups { func: _, inputs } => inputs.first().unwrap().name(), } } pub fn semantic_id(&self, schema: &Schema) -> FieldID { - use AggExpr::*; match self { - Count(expr, mode) => { + Self::Count(expr, mode) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_count({mode})")) } - Sum(expr) => { + Self::Sum(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_sum()")) } - ApproxPercentile(ApproxPercentileParams { + Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, force_list_output, @@ -199,122 +204,126 @@ impl AggExpr { percentiles, )) } - ApproxCountDistinct(expr) => { + Self::ApproxCountDistinct(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_approx_count_distinct()")) } - ApproxSketch(expr, sketch_type) => { + Self::ApproxSketch(expr, sketch_type) => { let child_id = expr.semantic_id(schema); FieldID::new(format!( "{child_id}.local_approx_sketch(sketch_type={sketch_type:?})" )) } - MergeSketch(expr, sketch_type) => { + Self::MergeSketch(expr, sketch_type) => { let child_id = expr.semantic_id(schema); FieldID::new(format!( "{child_id}.local_merge_sketch(sketch_type={sketch_type:?})" )) } - Mean(expr) => { + Self::Mean(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_mean()")) } - Min(expr) => { + Self::Stddev(expr) => { + let child_id = expr.semantic_id(schema); + FieldID::new(format!("{child_id}.local_stddev()")) + } + Self::Min(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_min()")) } - Max(expr) => { + Self::Max(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_max()")) } - AnyValue(expr, ignore_nulls) => { + Self::AnyValue(expr, ignore_nulls) => { let child_id = expr.semantic_id(schema); FieldID::new(format!( "{child_id}.local_any_value(ignore_nulls={ignore_nulls})" )) } - List(expr) => { + Self::List(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_list()")) } - Concat(expr) => { + Self::Concat(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.local_concat()")) } - MapGroups { func, inputs } => function_semantic_id(func, inputs, schema), + Self::MapGroups { func, inputs } => function_semantic_id(func, inputs, schema), } } pub fn children(&self) -> Vec { - use AggExpr::*; match self { - Count(expr, ..) - | Sum(expr) - | ApproxPercentile(ApproxPercentileParams { child: expr, .. }) - | ApproxCountDistinct(expr) - | ApproxSketch(expr, _) - | MergeSketch(expr, _) - | Mean(expr) - | Min(expr) - | Max(expr) - | AnyValue(expr, _) - | List(expr) - | Concat(expr) => vec![expr.clone()], - MapGroups { func: _, inputs } => inputs.clone(), + Self::Count(expr, ..) + | Self::Sum(expr) + | Self::ApproxPercentile(ApproxPercentileParams { child: expr, .. }) + | Self::ApproxCountDistinct(expr) + | Self::ApproxSketch(expr, _) + | Self::MergeSketch(expr, _) + | Self::Mean(expr) + | Self::Stddev(expr) + | Self::Min(expr) + | Self::Max(expr) + | Self::AnyValue(expr, _) + | Self::List(expr) + | Self::Concat(expr) => vec![expr.clone()], + Self::MapGroups { func: _, inputs } => inputs.clone(), } } - pub fn with_new_children(&self, children: Vec) -> Self { - use AggExpr::*; - - if let MapGroups { func: _, inputs } = &self { + pub fn with_new_children(&self, mut children: Vec) -> Self { + if let Self::MapGroups { func: _, inputs } = &self { assert_eq!(children.len(), inputs.len()); } else { assert_eq!(children.len(), 1); } + let mut first_child = || children.pop().unwrap(); match self { - Count(_, count_mode) => Count(children[0].clone(), *count_mode), - Sum(_) => Sum(children[0].clone()), - Mean(_) => Mean(children[0].clone()), - Min(_) => Min(children[0].clone()), - Max(_) => Max(children[0].clone()), - AnyValue(_, ignore_nulls) => AnyValue(children[0].clone(), *ignore_nulls), - List(_) => List(children[0].clone()), - Concat(_) => Concat(children[0].clone()), - MapGroups { func, inputs: _ } => MapGroups { + Self::Count(_, count_mode) => Self::Count(first_child(), *count_mode), + Self::Sum(_) => Self::Sum(first_child()), + Self::Mean(_) => Self::Mean(first_child()), + Self::Stddev(_) => Self::Stddev(first_child()), + Self::Min(_) => Self::Min(first_child()), + Self::Max(_) => Self::Max(first_child()), + Self::AnyValue(_, ignore_nulls) => Self::AnyValue(first_child(), *ignore_nulls), + Self::List(_) => Self::List(first_child()), + Self::Concat(_) => Self::Concat(first_child()), + Self::MapGroups { func, inputs: _ } => Self::MapGroups { func: func.clone(), inputs: children, }, - ApproxPercentile(ApproxPercentileParams { + Self::ApproxPercentile(ApproxPercentileParams { percentiles, force_list_output, .. - }) => ApproxPercentile(ApproxPercentileParams { - child: children[0].clone(), + }) => Self::ApproxPercentile(ApproxPercentileParams { + child: first_child(), percentiles: percentiles.clone(), force_list_output: *force_list_output, }), - ApproxCountDistinct(_) => ApproxCountDistinct(children[0].clone()), - &ApproxSketch(_, sketch_type) => ApproxSketch(children[0].clone(), sketch_type), - &MergeSketch(_, sketch_type) => MergeSketch(children[0].clone(), sketch_type), + Self::ApproxCountDistinct(_) => Self::ApproxCountDistinct(first_child()), + &Self::ApproxSketch(_, sketch_type) => Self::ApproxSketch(first_child(), sketch_type), + &Self::MergeSketch(_, sketch_type) => Self::MergeSketch(first_child(), sketch_type), } } pub fn to_field(&self, schema: &Schema) -> DaftResult { - use AggExpr::*; match self { - Count(expr, ..) => { + Self::Count(expr, ..) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - Sum(expr) => { + Self::Sum(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), try_sum_supertype(&field.dtype)?, )) } - ApproxPercentile(ApproxPercentileParams { + + Self::ApproxPercentile(ApproxPercentileParams { child: expr, percentiles, force_list_output, @@ -337,11 +346,11 @@ impl AggExpr { }, )) } - ApproxCountDistinct(expr) => { + Self::ApproxCountDistinct(expr) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - ApproxSketch(expr, sketch_type) => { + Self::ApproxSketch(expr, sketch_type) => { let field = expr.to_field(schema)?; let dtype = match sketch_type { SketchType::DDSketch => { @@ -357,7 +366,7 @@ impl AggExpr { }; Ok(Field::new(field.name, dtype)) } - MergeSketch(expr, sketch_type) => { + Self::MergeSketch(expr, sketch_type) => { let field = expr.to_field(schema)?; let dtype = match sketch_type { SketchType::DDSketch => { @@ -374,19 +383,19 @@ impl AggExpr { }; Ok(Field::new(field.name, dtype)) } - Mean(expr) => { + Self::Mean(expr) | Self::Stddev(expr) => { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), - try_mean_supertype(&field.dtype)?, + try_mean_stddev_aggregation_supertype(&field.dtype)?, )) } - Min(expr) | Max(expr) | AnyValue(expr, _) => { + Self::Min(expr) | Self::Max(expr) | Self::AnyValue(expr, _) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), field.dtype)) } - List(expr) => expr.to_field(schema)?.to_list_field(), - Concat(expr) => { + Self::List(expr) => expr.to_field(schema)?.to_list_field(), + Self::Concat(expr) => { let field = expr.to_field(schema)?; match field.dtype { DataType::List(..) => Ok(field), @@ -399,23 +408,7 @@ impl AggExpr { ))), } } - MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func), - } - } - - pub fn from_name_and_child_expr(name: &str, child: ExprRef) -> DaftResult { - use AggExpr::*; - match name { - "count" => Ok(Count(child, CountMode::Valid)), - "sum" => Ok(Sum(child)), - "mean" => Ok(Mean(child)), - "min" => Ok(Min(child)), - "max" => Ok(Max(child)), - "list" => Ok(List(child)), - _ => Err(DaftError::ValueError(format!( - "{} not a valid aggregation name", - name - ))), + Self::MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func), } } } @@ -498,6 +491,10 @@ impl Expr { Self::Agg(AggExpr::Mean(self)).into() } + pub fn stddev(self: ExprRef) -> ExprRef { + Self::Agg(AggExpr::Stddev(self)).into() + } + pub fn min(self: ExprRef) -> ExprRef { Self::Agg(AggExpr::Min(self)).into() } @@ -576,57 +573,55 @@ impl Expr { } pub fn semantic_id(&self, schema: &Schema) -> FieldID { - use Expr::*; match self { // Base case - anonymous column reference. // Look up the column name in the provided schema and get its field ID. - Column(name) => FieldID::new(&**name), + Self::Column(name) => FieldID::new(&**name), // Base case - literal. - Literal(value) => FieldID::new(format!("Literal({value:?})")), + Self::Literal(value) => FieldID::new(format!("Literal({value:?})")), // Recursive cases. - Cast(expr, dtype) => { + Self::Cast(expr, dtype) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.cast({dtype})")) } - Not(expr) => { + Self::Not(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.not()")) } - IsNull(expr) => { + Self::IsNull(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.is_null()")) } - NotNull(expr) => { + Self::NotNull(expr) => { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.not_null()")) } - FillNull(expr, fill_value) => { + Self::FillNull(expr, fill_value) => { let child_id = expr.semantic_id(schema); let fill_value_id = fill_value.semantic_id(schema); FieldID::new(format!("{child_id}.fill_null({fill_value_id})")) } - IsIn(expr, items) => { + Self::IsIn(expr, items) => { let child_id = expr.semantic_id(schema); let items_id = items.semantic_id(schema); FieldID::new(format!("{child_id}.is_in({items_id})")) } - Between(expr, lower, upper) => { + Self::Between(expr, lower, upper) => { let child_id = expr.semantic_id(schema); let lower_id = lower.semantic_id(schema); let upper_id = upper.semantic_id(schema); FieldID::new(format!("{child_id}.between({lower_id},{upper_id})")) } - Function { func, inputs } => function_semantic_id(func, inputs, schema), - BinaryOp { op, left, right } => { + Self::Function { func, inputs } => function_semantic_id(func, inputs, schema), + Self::BinaryOp { op, left, right } => { let left_id = left.semantic_id(schema); let right_id = right.semantic_id(schema); // TODO: check for symmetry here. FieldID::new(format!("({left_id} {op} {right_id})")) } - - IfElse { + Self::IfElse { if_true, if_false, predicate, @@ -636,96 +631,100 @@ impl Expr { let predicate = predicate.semantic_id(schema); FieldID::new(format!("({if_true} if {predicate} else {if_false})")) } - // Alias: ID does not change. - Alias(expr, ..) => expr.semantic_id(schema), - + Self::Alias(expr, ..) => expr.semantic_id(schema), // Agg: Separate path. - Agg(agg_expr) => agg_expr.semantic_id(schema), - ScalarFunction(sf) => scalar_function_semantic_id(sf, schema), + Self::Agg(agg_expr) => agg_expr.semantic_id(schema), + Self::ScalarFunction(sf) => scalar_function_semantic_id(sf, schema), } } pub fn children(&self) -> Vec { - use Expr::*; match self { // No children. - Column(..) => vec![], - Literal(..) => vec![], + Self::Column(..) => vec![], + Self::Literal(..) => vec![], // One child. - Not(expr) | IsNull(expr) | NotNull(expr) | Cast(expr, ..) | Alias(expr, ..) => { + Self::Not(expr) + | Self::IsNull(expr) + | Self::NotNull(expr) + | Self::Cast(expr, ..) + | Self::Alias(expr, ..) => { vec![expr.clone()] } - Agg(agg_expr) => agg_expr.children(), + Self::Agg(agg_expr) => agg_expr.children(), // Multiple children. - Function { inputs, .. } => inputs.clone(), - BinaryOp { left, right, .. } => { + Self::Function { inputs, .. } => inputs.clone(), + Self::BinaryOp { left, right, .. } => { vec![left.clone(), right.clone()] } - IsIn(expr, items) => vec![expr.clone(), items.clone()], - Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()], - IfElse { + Self::IsIn(expr, items) => vec![expr.clone(), items.clone()], + Self::Between(expr, lower, upper) => vec![expr.clone(), lower.clone(), upper.clone()], + Self::IfElse { if_true, if_false, predicate, } => { vec![if_true.clone(), if_false.clone(), predicate.clone()] } - FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()], - ScalarFunction(sf) => sf.inputs.clone(), + Self::FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()], + Self::ScalarFunction(sf) => sf.inputs.clone(), } } pub fn with_new_children(&self, children: Vec) -> Self { - use Expr::*; match self { // no children - Column(..) | Literal(..) => { + Self::Column(..) | Self::Literal(..) => { assert!(children.is_empty(), "Should have no children"); self.clone() } // 1 child - Not(..) => Not(children.first().expect("Should have 1 child").clone()), - Alias(.., name) => Alias( + Self::Not(..) => Self::Not(children.first().expect("Should have 1 child").clone()), + Self::Alias(.., name) => Self::Alias( children.first().expect("Should have 1 child").clone(), name.clone(), ), - IsNull(..) => IsNull(children.first().expect("Should have 1 child").clone()), - NotNull(..) => NotNull(children.first().expect("Should have 1 child").clone()), - Cast(.., dtype) => Cast( + Self::IsNull(..) => { + Self::IsNull(children.first().expect("Should have 1 child").clone()) + } + Self::NotNull(..) => { + Self::NotNull(children.first().expect("Should have 1 child").clone()) + } + Self::Cast(.., dtype) => Self::Cast( children.first().expect("Should have 1 child").clone(), dtype.clone(), ), // 2 children - BinaryOp { op, .. } => BinaryOp { + Self::BinaryOp { op, .. } => Self::BinaryOp { op: *op, left: children.first().expect("Should have 1 child").clone(), right: children.get(1).expect("Should have 2 child").clone(), }, - IsIn(..) => IsIn( + Self::IsIn(..) => Self::IsIn( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), ), - Between(..) => Between( + Self::Between(..) => Self::Between( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), children.get(2).expect("Should have 3 child").clone(), ), - FillNull(..) => FillNull( + Self::FillNull(..) => Self::FillNull( children.first().expect("Should have 1 child").clone(), children.get(1).expect("Should have 2 child").clone(), ), // ternary - IfElse { .. } => IfElse { + Self::IfElse { .. } => Self::IfElse { if_true: children.first().expect("Should have 1 child").clone(), if_false: children.get(1).expect("Should have 2 child").clone(), predicate: children.get(2).expect("Should have 3 child").clone(), }, // N-ary - Agg(agg_expr) => Agg(agg_expr.with_new_children(children)), - Function { + Self::Agg(agg_expr) => Self::Agg(agg_expr.with_new_children(children)), + Self::Function { func, inputs: old_children, } => { @@ -733,18 +732,18 @@ impl Expr { children.len() == old_children.len(), "Should have same number of children" ); - Function { + Self::Function { func: func.clone(), inputs: children, } } - ScalarFunction(sf) => { + Self::ScalarFunction(sf) => { assert!( children.len() == sf.inputs.len(), "Should have same number of children" ); - ScalarFunction(crate::functions::ScalarFunction { + Self::ScalarFunction(crate::functions::ScalarFunction { udf: sf.udf.clone(), inputs: children, }) @@ -753,13 +752,12 @@ impl Expr { } pub fn to_field(&self, schema: &Schema) -> DaftResult { - use Expr::*; match self { - Alias(expr, name) => Ok(Field::new(name.as_ref(), expr.get_type(schema)?)), - Agg(agg_expr) => agg_expr.to_field(schema), - Cast(expr, dtype) => Ok(Field::new(expr.name(), dtype.clone())), - Column(name) => Ok(schema.get_field(name).cloned()?), - Not(expr) => { + Self::Alias(expr, name) => Ok(Field::new(name.as_ref(), expr.get_type(schema)?)), + Self::Agg(agg_expr) => agg_expr.to_field(schema), + Self::Cast(expr, dtype) => Ok(Field::new(expr.name(), dtype.clone())), + Self::Column(name) => Ok(schema.get_field(name).cloned()?), + Self::Not(expr) => { let child_field = expr.to_field(schema)?; match child_field.dtype { DataType::Boolean => Ok(Field::new(expr.name(), DataType::Boolean)), @@ -768,9 +766,9 @@ impl Expr { ))), } } - IsNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), - NotNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), - FillNull(expr, fill_value) => { + Self::IsNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), + Self::NotNull(expr) => Ok(Field::new(expr.name(), DataType::Boolean)), + Self::FillNull(expr, fill_value) => { let expr_field = expr.to_field(schema)?; let fill_value_field = fill_value.to_field(schema)?; match try_get_supertype(&expr_field.dtype, &fill_value_field.dtype) { @@ -780,7 +778,7 @@ impl Expr { ))) } } - IsIn(left, right) => { + Self::IsIn(left, right) => { let left_field = left.to_field(schema)?; let right_field = right.to_field(schema)?; let (result_type, _intermediate, _comp_type) = @@ -788,7 +786,7 @@ impl Expr { .membership_op(&InferDataType::from(&right_field.dtype))?; Ok(Field::new(left_field.name.as_str(), result_type)) } - Between(value, lower, upper) => { + Self::Between(value, lower, upper) => { let value_field = value.to_field(schema)?; let lower_field = lower.to_field(schema)?; let upper_field = upper.to_field(schema)?; @@ -803,11 +801,10 @@ impl Expr { .membership_op(&InferDataType::from(&upper_result_type))?; Ok(Field::new(value_field.name.as_str(), result_type)) } - Literal(value) => Ok(Field::new("literal", value.get_type())), - Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func), - ScalarFunction(sf) => sf.to_field(schema), - - BinaryOp { op, left, right } => { + Self::Literal(value) => Ok(Field::new("literal", value.get_type())), + Self::Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func), + Self::ScalarFunction(sf) => sf.to_field(schema), + Self::BinaryOp { op, left, right } => { let left_field = left.to_field(schema)?; let right_field = right.to_field(schema)?; @@ -873,7 +870,7 @@ impl Expr { } } } - IfElse { + Self::IfElse { if_true, if_false, predicate, @@ -903,33 +900,32 @@ impl Expr { } pub fn name(&self) -> &str { - use Expr::*; match self { - Alias(.., name) => name.as_ref(), - Agg(agg_expr) => agg_expr.name(), - Cast(expr, ..) => expr.name(), - Column(name) => name.as_ref(), - Not(expr) => expr.name(), - IsNull(expr) => expr.name(), - NotNull(expr) => expr.name(), - FillNull(expr, ..) => expr.name(), - IsIn(expr, ..) => expr.name(), - Between(expr, ..) => expr.name(), - Literal(..) => "literal", - Function { func, inputs } => match func { + Self::Alias(.., name) => name.as_ref(), + Self::Agg(agg_expr) => agg_expr.name(), + Self::Cast(expr, ..) => expr.name(), + Self::Column(name) => name.as_ref(), + Self::Not(expr) => expr.name(), + Self::IsNull(expr) => expr.name(), + Self::NotNull(expr) => expr.name(), + Self::FillNull(expr, ..) => expr.name(), + Self::IsIn(expr, ..) => expr.name(), + Self::Between(expr, ..) => expr.name(), + Self::Literal(..) => "literal", + Self::Function { func, inputs } => match func { FunctionExpr::Struct(StructExpr::Get(name)) => name, _ => inputs.first().unwrap().name(), }, - ScalarFunction(func) => match func.name() { + Self::ScalarFunction(func) => match func.name() { "to_struct" => "struct", // FIXME: make .name() use output name from schema _ => func.inputs.first().unwrap().name(), }, - BinaryOp { + Self::BinaryOp { op: _, left, right: _, } => left.name(), - IfElse { if_true, .. } => if_true.name(), + Self::IfElse { if_true, .. } => if_true.name(), } } @@ -1119,90 +1115,3 @@ pub fn has_stateful_udf(expr: &ExprRef) -> bool { ) }) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn check_comparison_type() -> DaftResult<()> { - let x = lit(10.); - let y = lit(12); - let schema = Schema::empty(); - - let z = Expr::BinaryOp { - left: x, - right: y, - op: Operator::Lt, - }; - assert_eq!(z.get_type(&schema)?, DataType::Boolean); - Ok(()) - } - - #[test] - fn check_alias_type() -> DaftResult<()> { - let a = col("a"); - let b = a.alias("b"); - match b.as_ref() { - Expr::Alias(..) => Ok(()), - other => Err(common_error::DaftError::ValueError(format!( - "expected expression to be a alias, got {other:?}" - ))), - } - } - - #[test] - fn check_arithmetic_type() -> DaftResult<()> { - let x = lit(10.); - let y = lit(12); - let schema = Schema::empty(); - - let z = Expr::BinaryOp { - left: x, - right: y, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - let x = lit(10.); - let y = lit(12); - - let z = Expr::BinaryOp { - left: y, - right: x, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - Ok(()) - } - - #[test] - fn check_arithmetic_type_with_columns() -> DaftResult<()> { - let x = col("x"); - let y = col("y"); - let schema = Schema::new(vec![ - Field::new("x", DataType::Float64), - Field::new("y", DataType::Int64), - ])?; - - let z = Expr::BinaryOp { - left: x, - right: y, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - let x = col("x"); - let y = col("y"); - - let z = Expr::BinaryOp { - left: y, - right: x, - op: Operator::Plus, - }; - assert_eq!(z.get_type(&schema)?, DataType::Float64); - - Ok(()) - } -} diff --git a/src/daft-dsl/src/expr/tests.rs b/src/daft-dsl/src/expr/tests.rs new file mode 100644 index 0000000000..aff680c5d3 --- /dev/null +++ b/src/daft-dsl/src/expr/tests.rs @@ -0,0 +1,83 @@ +use super::*; + +#[test] +fn check_comparison_type() -> DaftResult<()> { + let x = lit(10.); + let y = lit(12); + let schema = Schema::empty(); + + let z = Expr::BinaryOp { + left: x, + right: y, + op: Operator::Lt, + }; + assert_eq!(z.get_type(&schema)?, DataType::Boolean); + Ok(()) +} + +#[test] +fn check_alias_type() -> DaftResult<()> { + let a = col("a"); + let b = a.alias("b"); + match b.as_ref() { + Expr::Alias(..) => Ok(()), + other => Err(common_error::DaftError::ValueError(format!( + "expected expression to be a alias, got {other:?}" + ))), + } +} + +#[test] +fn check_arithmetic_type() -> DaftResult<()> { + let x = lit(10.); + let y = lit(12); + let schema = Schema::empty(); + + let z = Expr::BinaryOp { + left: x, + right: y, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + let x = lit(10.); + let y = lit(12); + + let z = Expr::BinaryOp { + left: y, + right: x, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + Ok(()) +} + +#[test] +fn check_arithmetic_type_with_columns() -> DaftResult<()> { + let x = col("x"); + let y = col("y"); + let schema = Schema::new(vec![ + Field::new("x", DataType::Float64), + Field::new("y", DataType::Int64), + ])?; + + let z = Expr::BinaryOp { + left: x, + right: y, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + let x = col("x"); + let y = col("y"); + + let z = Expr::BinaryOp { + left: y, + right: x, + op: Operator::Plus, + }; + assert_eq!(z.get_type(&schema)?, DataType::Float64); + + Ok(()) +} diff --git a/src/daft-dsl/src/functions/map/mod.rs b/src/daft-dsl/src/functions/map/mod.rs index 979a6ccd1e..083e99e7db 100644 --- a/src/daft-dsl/src/functions/map/mod.rs +++ b/src/daft-dsl/src/functions/map/mod.rs @@ -14,9 +14,8 @@ pub enum MapExpr { impl MapExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use MapExpr::*; match self { - Get => &GetEvaluator {}, + Self::Get => &GetEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 0386d7c54c..6f0b162422 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -1,5 +1,6 @@ pub mod map; pub mod partitioning; +pub mod python; pub mod scalar; pub mod sketch; pub mod struct_; @@ -12,6 +13,7 @@ use std::{ use common_error::DaftResult; use daft_core::prelude::*; +use python::PythonUDF; pub use scalar::*; use serde::{Deserialize, Serialize}; @@ -21,9 +23,6 @@ use self::{ }; use crate::{Expr, ExprRef, Operator}; -pub mod python; -use python::PythonUDF; - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { Utf8(Utf8Expr), @@ -48,14 +47,13 @@ pub trait FunctionEvaluator { impl FunctionExpr { #[inline] fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use FunctionExpr::*; match self { - Utf8(expr) => expr.get_evaluator(), - Map(expr) => expr.get_evaluator(), - Sketch(expr) => expr.get_evaluator(), - Struct(expr) => expr.get_evaluator(), - Python(expr) => expr.get_evaluator(), - Partitioning(expr) => expr.get_evaluator(), + Self::Utf8(expr) => expr.get_evaluator(), + Self::Map(expr) => expr.get_evaluator(), + Self::Sketch(expr) => expr.get_evaluator(), + Self::Struct(expr) => expr.get_evaluator(), + Self::Python(expr) => expr.get_evaluator(), + Self::Partitioning(expr) => expr.get_evaluator(), } } } diff --git a/src/daft-dsl/src/functions/partitioning/mod.rs b/src/daft-dsl/src/functions/partitioning/mod.rs index 9f37414e18..ead6ed91f8 100644 --- a/src/daft-dsl/src/functions/partitioning/mod.rs +++ b/src/daft-dsl/src/functions/partitioning/mod.rs @@ -24,14 +24,13 @@ pub enum PartitioningExpr { impl PartitioningExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use PartitioningExpr::*; match self { - Years => &YearsEvaluator {}, - Months => &MonthsEvaluator {}, - Days => &DaysEvaluator {}, - Hours => &HoursEvaluator {}, - IcebergBucket(..) => &IcebergBucketEvaluator {}, - IcebergTruncate(..) => &IcebergTruncateEvaluator {}, + Self::Years => &YearsEvaluator {}, + Self::Months => &MonthsEvaluator {}, + Self::Days => &DaysEvaluator {}, + Self::Hours => &HoursEvaluator {}, + Self::IcebergBucket(..) => &IcebergBucketEvaluator {}, + Self::IcebergTruncate(..) => &IcebergTruncateEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/python/mod.rs b/src/daft-dsl/src/functions/python/mod.rs index afdd2153ae..c4f69f8331 100644 --- a/src/daft-dsl/src/functions/python/mod.rs +++ b/src/daft-dsl/src/functions/python/mod.rs @@ -2,9 +2,13 @@ mod runtime_py_object; mod udf; mod udf_runtime_binding; -use std::{collections::HashMap, sync::Arc}; +#[cfg(feature = "python")] +use std::collections::HashMap; +use std::sync::Arc; -use common_error::{DaftError, DaftResult}; +#[cfg(feature = "python")] +use common_error::DaftError; +use common_error::DaftResult; use common_resource_request::ResourceRequest; use common_treenode::{TreeNode, TreeNodeRecursion}; use daft_core::datatypes::DataType; diff --git a/src/daft-dsl/src/functions/python/udf.rs b/src/daft-dsl/src/functions/python/udf.rs index 7b9da47bf7..100bd06566 100644 --- a/src/daft-dsl/src/functions/python/udf.rs +++ b/src/daft-dsl/src/functions/python/udf.rs @@ -1,5 +1,5 @@ use common_error::{DaftError, DaftResult}; -use daft_core::{datatypes::DataType, prelude::*}; +use daft_core::prelude::*; #[cfg(feature = "python")] use pyo3::{ types::{PyAnyMethods, PyModule}, diff --git a/src/daft-dsl/src/functions/sketch/mod.rs b/src/daft-dsl/src/functions/sketch/mod.rs index d3c43f1f7b..17770d9c11 100644 --- a/src/daft-dsl/src/functions/sketch/mod.rs +++ b/src/daft-dsl/src/functions/sketch/mod.rs @@ -30,9 +30,8 @@ pub enum SketchExpr { impl SketchExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use SketchExpr::*; match self { - Percentile { .. } => &PercentileEvaluator {}, + Self::Percentile { .. } => &PercentileEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/struct_/mod.rs b/src/daft-dsl/src/functions/struct_/mod.rs index c842c45c64..7d8d192d25 100644 --- a/src/daft-dsl/src/functions/struct_/mod.rs +++ b/src/daft-dsl/src/functions/struct_/mod.rs @@ -14,9 +14,8 @@ pub enum StructExpr { impl StructExpr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use StructExpr::*; match self { - Get(_) => &GetEvaluator {}, + Self::Get(_) => &GetEvaluator {}, } } } diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index cb3a07aca1..7a795250ff 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -95,36 +95,35 @@ pub enum Utf8Expr { impl Utf8Expr { #[inline] pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - use Utf8Expr::*; match self { - EndsWith => &EndswithEvaluator {}, - StartsWith => &StartswithEvaluator {}, - Contains => &ContainsEvaluator {}, - Split(_) => &SplitEvaluator {}, - Match => &MatchEvaluator {}, - Extract(_) => &ExtractEvaluator {}, - ExtractAll(_) => &ExtractAllEvaluator {}, - Replace(_) => &ReplaceEvaluator {}, - Length => &LengthEvaluator {}, - LengthBytes => &LengthBytesEvaluator {}, - Lower => &LowerEvaluator {}, - Upper => &UpperEvaluator {}, - Lstrip => &LstripEvaluator {}, - Rstrip => &RstripEvaluator {}, - Reverse => &ReverseEvaluator {}, - Capitalize => &CapitalizeEvaluator {}, - Left => &LeftEvaluator {}, - Right => &RightEvaluator {}, - Find => &FindEvaluator {}, - Rpad => &RpadEvaluator {}, - Lpad => &LpadEvaluator {}, - Repeat => &RepeatEvaluator {}, - Like => &LikeEvaluator {}, - Ilike => &IlikeEvaluator {}, - Substr => &SubstrEvaluator {}, - ToDate(_) => &ToDateEvaluator {}, - ToDatetime(_, _) => &ToDatetimeEvaluator {}, - Normalize(_) => &NormalizeEvaluator {}, + Self::EndsWith => &EndswithEvaluator {}, + Self::StartsWith => &StartswithEvaluator {}, + Self::Contains => &ContainsEvaluator {}, + Self::Split(_) => &SplitEvaluator {}, + Self::Match => &MatchEvaluator {}, + Self::Extract(_) => &ExtractEvaluator {}, + Self::ExtractAll(_) => &ExtractAllEvaluator {}, + Self::Replace(_) => &ReplaceEvaluator {}, + Self::Length => &LengthEvaluator {}, + Self::LengthBytes => &LengthBytesEvaluator {}, + Self::Lower => &LowerEvaluator {}, + Self::Upper => &UpperEvaluator {}, + Self::Lstrip => &LstripEvaluator {}, + Self::Rstrip => &RstripEvaluator {}, + Self::Reverse => &ReverseEvaluator {}, + Self::Capitalize => &CapitalizeEvaluator {}, + Self::Left => &LeftEvaluator {}, + Self::Right => &RightEvaluator {}, + Self::Find => &FindEvaluator {}, + Self::Rpad => &RpadEvaluator {}, + Self::Lpad => &LpadEvaluator {}, + Self::Repeat => &RepeatEvaluator {}, + Self::Like => &LikeEvaluator {}, + Self::Ilike => &IlikeEvaluator {}, + Self::Substr => &SubstrEvaluator {}, + Self::ToDate(_) => &ToDateEvaluator {}, + Self::ToDatetime(_, _) => &ToDatetimeEvaluator {}, + Self::Normalize(_) => &NormalizeEvaluator {}, } } } diff --git a/src/daft-dsl/src/join.rs b/src/daft-dsl/src/join/mod.rs similarity index 79% rename from src/daft-dsl/src/join.rs rename to src/daft-dsl/src/join/mod.rs index 2f1cf96cb2..1de29b995e 100644 --- a/src/daft-dsl/src/join.rs +++ b/src/daft-dsl/src/join/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use std::sync::Arc; use common_error::{DaftError, DaftResult}; @@ -79,34 +82,3 @@ pub fn infer_join_schema( Ok(Schema::new(fields)?.into()) } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::col; - - #[test] - fn test_get_common_join_keys() { - let left_on: &[ExprRef] = &[ - col("a"), - col("b_left"), - col("c").alias("c_new"), - col("d").alias("d_new"), - col("e").add(col("f")), - ]; - - let right_on: &[ExprRef] = &[ - col("a"), - col("b_right"), - col("c"), - col("d").alias("d_new"), - col("e"), - ]; - - let common_join_keys = get_common_join_keys(left_on, right_on) - .map(|k| k.to_string()) - .collect::>(); - - assert_eq!(common_join_keys, vec!["a"]); - } -} diff --git a/src/daft-dsl/src/join/tests.rs b/src/daft-dsl/src/join/tests.rs new file mode 100644 index 0000000000..52d58a76c0 --- /dev/null +++ b/src/daft-dsl/src/join/tests.rs @@ -0,0 +1,27 @@ +use super::*; +use crate::col; + +#[test] +fn test_get_common_join_keys() { + let left_on: &[ExprRef] = &[ + col("a"), + col("b_left"), + col("c").alias("c_new"), + col("d").alias("d_new"), + col("e").add(col("f")), + ]; + + let right_on: &[ExprRef] = &[ + col("a"), + col("b_right"), + col("c"), + col("d").alias("d_new"), + col("e"), + ]; + + let common_join_keys = get_common_join_keys(left_on, right_on) + .map(|k| k.to_string()) + .collect::>(); + + assert_eq!(common_join_keys, vec!["a"]); +} diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 45b24c20e9..52ca09af18 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -79,40 +79,38 @@ impl Eq for LiteralValue {} impl Hash for LiteralValue { fn hash(&self, state: &mut H) { - use LiteralValue::*; - match self { // Stable hash for Null variant. - Null => 1.hash(state), - Boolean(bool) => bool.hash(state), - Utf8(s) => s.hash(state), - Binary(arr) => arr.hash(state), - Int32(n) => n.hash(state), - UInt32(n) => n.hash(state), - Int64(n) => n.hash(state), - UInt64(n) => n.hash(state), - Date(n) => n.hash(state), - Time(n, tu) => { + Self::Null => 1.hash(state), + Self::Boolean(bool) => bool.hash(state), + Self::Utf8(s) => s.hash(state), + Self::Binary(arr) => arr.hash(state), + Self::Int32(n) => n.hash(state), + Self::UInt32(n) => n.hash(state), + Self::Int64(n) => n.hash(state), + Self::UInt64(n) => n.hash(state), + Self::Date(n) => n.hash(state), + Self::Time(n, tu) => { n.hash(state); tu.hash(state); } - Timestamp(n, tu, tz) => { + Self::Timestamp(n, tu, tz) => { n.hash(state); tu.hash(state); tz.hash(state); } - Duration(n, tu) => { + Self::Duration(n, tu) => { n.hash(state); tu.hash(state); } // Wrap float64 in hashable newtype. - Float64(n) => FloatWrapper(*n).hash(state), - Decimal(n, precision, scale) => { + Self::Float64(n) => FloatWrapper(*n).hash(state), + Self::Decimal(n, precision, scale) => { n.hash(state); precision.hash(state); scale.hash(state); } - Series(series) => { + Self::Series(series) => { let hash_result = series.hash(None); match hash_result { Ok(hash) => hash.into_iter().for_each(|i| i.hash(state)), @@ -120,8 +118,8 @@ impl Hash for LiteralValue { } } #[cfg(feature = "python")] - Python(py_obj) => py_obj.hash(state), - Struct(entries) => { + Self::Python(py_obj) => py_obj.hash(state), + Self::Struct(entries) => { entries.iter().for_each(|(v, f)| { v.hash(state); f.hash(state); @@ -134,32 +132,31 @@ impl Hash for LiteralValue { impl Display for LiteralValue { // `f` is a buffer, and this method must write the formatted string into it fn fmt(&self, f: &mut Formatter) -> Result { - use LiteralValue::*; match self { - Null => write!(f, "Null"), - Boolean(val) => write!(f, "{val}"), - Utf8(val) => write!(f, "\"{val}\""), - Binary(val) => write!(f, "Binary[{}]", val.len()), - Int32(val) => write!(f, "{val}"), - UInt32(val) => write!(f, "{val}"), - Int64(val) => write!(f, "{val}"), - UInt64(val) => write!(f, "{val}"), - Date(val) => write!(f, "{}", display_date32(*val)), - Time(val, tu) => write!(f, "{}", display_time64(*val, tu)), - Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)), - Duration(val, tu) => write!(f, "{}", display_duration(*val, tu)), - Float64(val) => write!(f, "{val:.1}"), - Decimal(val, precision, scale) => { + Self::Null => write!(f, "Null"), + Self::Boolean(val) => write!(f, "{val}"), + Self::Utf8(val) => write!(f, "\"{val}\""), + Self::Binary(val) => write!(f, "Binary[{}]", val.len()), + Self::Int32(val) => write!(f, "{val}"), + Self::UInt32(val) => write!(f, "{val}"), + Self::Int64(val) => write!(f, "{val}"), + Self::UInt64(val) => write!(f, "{val}"), + Self::Date(val) => write!(f, "{}", display_date32(*val)), + Self::Time(val, tu) => write!(f, "{}", display_time64(*val, tu)), + Self::Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)), + Self::Duration(val, tu) => write!(f, "{}", display_duration(*val, tu)), + Self::Float64(val) => write!(f, "{val:.1}"), + Self::Decimal(val, precision, scale) => { write!(f, "{}", display_decimal128(*val, *precision, *scale)) } - Series(series) => write!(f, "{}", display_series_literal(series)), + Self::Series(series) => write!(f, "{}", display_series_literal(series)), #[cfg(feature = "python")] - Python(pyobj) => write!(f, "PyObject({})", { + Self::Python(pyobj) => write!(f, "PyObject({})", { use pyo3::prelude::*; Python::with_gil(|py| pyobj.0.call_method0(py, pyo3::intern!(py, "__str__"))) .unwrap() }), - Struct(entries) => { + Self::Struct(entries) => { write!(f, "Struct(")?; for (i, (field, v)) in entries.iter().enumerate() { if i > 0 { @@ -175,106 +172,108 @@ impl Display for LiteralValue { impl LiteralValue { pub fn get_type(&self) -> DataType { - use LiteralValue::*; match self { - Null => DataType::Null, - Boolean(_) => DataType::Boolean, - Utf8(_) => DataType::Utf8, - Binary(_) => DataType::Binary, - Int32(_) => DataType::Int32, - UInt32(_) => DataType::UInt32, - Int64(_) => DataType::Int64, - UInt64(_) => DataType::UInt64, - Date(_) => DataType::Date, - Time(_, tu) => DataType::Time(*tu), - Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()), - Duration(_, tu) => DataType::Duration(*tu), - Float64(_) => DataType::Float64, - Decimal(_, precision, scale) => { + Self::Null => DataType::Null, + Self::Boolean(_) => DataType::Boolean, + Self::Utf8(_) => DataType::Utf8, + Self::Binary(_) => DataType::Binary, + Self::Int32(_) => DataType::Int32, + Self::UInt32(_) => DataType::UInt32, + Self::Int64(_) => DataType::Int64, + Self::UInt64(_) => DataType::UInt64, + Self::Date(_) => DataType::Date, + Self::Time(_, tu) => DataType::Time(*tu), + Self::Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()), + Self::Duration(_, tu) => DataType::Duration(*tu), + Self::Float64(_) => DataType::Float64, + Self::Decimal(_, precision, scale) => { DataType::Decimal128(*precision as usize, *scale as usize) } - Series(series) => series.data_type().clone(), + Self::Series(series) => series.data_type().clone(), #[cfg(feature = "python")] - Python(_) => DataType::Python, - Struct(entries) => DataType::Struct(entries.keys().cloned().collect()), + Self::Python(_) => DataType::Python, + Self::Struct(entries) => DataType::Struct(entries.keys().cloned().collect()), } } pub fn to_series(&self) -> Series { - use LiteralValue::*; - let result = match self { - Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(), - Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(), - Utf8(val) => Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series(), - Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(), - Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(), - UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(), - Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(), - UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(), - Date(val) => { + match self { + Self::Null => NullArray::full_null("literal", &DataType::Null, 1).into_series(), + Self::Boolean(val) => BooleanArray::from(("literal", [*val].as_slice())).into_series(), + Self::Utf8(val) => { + Utf8Array::from(("literal", [val.as_str()].as_slice())).into_series() + } + Self::Binary(val) => BinaryArray::from(("literal", val.as_slice())).into_series(), + Self::Int32(val) => Int32Array::from(("literal", [*val].as_slice())).into_series(), + Self::UInt32(val) => UInt32Array::from(("literal", [*val].as_slice())).into_series(), + Self::Int64(val) => Int64Array::from(("literal", [*val].as_slice())).into_series(), + Self::UInt64(val) => UInt64Array::from(("literal", [*val].as_slice())).into_series(), + Self::Date(val) => { let physical = Int32Array::from(("literal", [*val].as_slice())); DateArray::new(Field::new("literal", self.get_type()), physical).into_series() } - Time(val, ..) => { + Self::Time(val, ..) => { let physical = Int64Array::from(("literal", [*val].as_slice())); TimeArray::new(Field::new("literal", self.get_type()), physical).into_series() } - Timestamp(val, ..) => { + Self::Timestamp(val, ..) => { let physical = Int64Array::from(("literal", [*val].as_slice())); TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series() } - Duration(val, ..) => { + Self::Duration(val, ..) => { let physical = Int64Array::from(("literal", [*val].as_slice())); DurationArray::new(Field::new("literal", self.get_type()), physical).into_series() } - Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), - Decimal(val, ..) => { + Self::Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), + Self::Decimal(val, ..) => { let physical = Int128Array::from(("literal", [*val].as_slice())); Decimal128Array::new(Field::new("literal", self.get_type()), physical).into_series() } - Series(series) => series.clone().rename("literal"), + Self::Series(series) => series.clone().rename("literal"), #[cfg(feature = "python")] - Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(), - Struct(entries) => { + Self::Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(), + Self::Struct(entries) => { let struct_dtype = DataType::Struct(entries.keys().cloned().collect()); let struct_field = Field::new("literal", struct_dtype); let values = entries.values().map(|v| v.to_series()).collect(); StructArray::new(struct_field, values, None).into_series() } - }; - result + } } pub fn display_sql(&self, buffer: &mut W) -> io::Result<()> { - use LiteralValue::*; let display_sql_err = Err(io::Error::new( io::ErrorKind::Other, "Unsupported literal for SQL translation", )); match self { - Null => write!(buffer, "NULL"), - Boolean(v) => write!(buffer, "{}", v), - Int32(val) => write!(buffer, "{}", val), - UInt32(val) => write!(buffer, "{}", val), - Int64(val) => write!(buffer, "{}", val), - UInt64(val) => write!(buffer, "{}", val), - Float64(val) => write!(buffer, "{}", val), - Utf8(val) => write!(buffer, "'{}'", val), - Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)), + Self::Null => write!(buffer, "NULL"), + Self::Boolean(v) => write!(buffer, "{}", v), + Self::Int32(val) => write!(buffer, "{}", val), + Self::UInt32(val) => write!(buffer, "{}", val), + Self::Int64(val) => write!(buffer, "{}", val), + Self::UInt64(val) => write!(buffer, "{}", val), + Self::Float64(val) => write!(buffer, "{}", val), + Self::Utf8(val) => write!(buffer, "'{}'", val), + Self::Date(val) => write!(buffer, "DATE '{}'", display_date32(*val)), // The `display_timestamp` function formats a timestamp in the ISO 8601 format: "YYYY-MM-DDTHH:MM:SS.fffff". // ANSI SQL standard uses a space instead of 'T'. Some databases do not support 'T', hence it's replaced with a space. // Reference: https://docs.actian.com/ingres/10s/index.html#page/SQLRef/Summary_of_ANSI_Date_2fTime_Data_Types.html - Timestamp(val, tu, tz) => write!( + Self::Timestamp(val, tu, tz) => write!( buffer, "TIMESTAMP '{}'", display_timestamp(*val, tu, tz).replace('T', " ") ), // TODO(Colin): Implement the rest of the types in future work for SQL pushdowns. - Decimal(..) | Series(..) | Time(..) | Binary(..) | Duration(..) => display_sql_err, + Self::Decimal(..) + | Self::Series(..) + | Self::Time(..) + | Self::Binary(..) + | Self::Duration(..) => display_sql_err, #[cfg(feature = "python")] - Python(..) => display_sql_err, - Struct(..) => display_sql_err, + Self::Python(..) => display_sql_err, + Self::Struct(..) => display_sql_err, } } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index e62375e7b0..e0c6dc1700 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -327,6 +327,10 @@ impl PyExpr { Ok(self.expr.clone().mean().into()) } + pub fn stddev(&self) -> PyResult { + Ok(self.expr.clone().stddev().into()) + } + pub fn min(&self) -> PyResult { Ok(self.expr.clone().min().into()) } diff --git a/src/daft-dsl/src/resolve_expr.rs b/src/daft-dsl/src/resolve_expr/mod.rs similarity index 68% rename from src/daft-dsl/src/resolve_expr.rs rename to src/daft-dsl/src/resolve_expr/mod.rs index e9b6930c3a..5888774fe4 100644 --- a/src/daft-dsl/src/resolve_expr.rs +++ b/src/daft-dsl/src/resolve_expr/mod.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use std::{ cmp::Ordering, collections::{BinaryHeap, HashMap}, @@ -202,44 +205,47 @@ fn expand_wildcards( } fn extract_agg_expr(expr: &Expr) -> DaftResult { - use crate::Expr::*; - match expr { - Agg(agg_expr) => Ok(agg_expr.clone()), - Function { func, inputs } => Ok(AggExpr::MapGroups { + Expr::Agg(agg_expr) => Ok(agg_expr.clone()), + Expr::Function { func, inputs } => Ok(AggExpr::MapGroups { func: func.clone(), inputs: inputs.clone(), }), - Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { - use crate::AggExpr::*; - + Expr::Alias(e, name) => extract_agg_expr(e).map(|agg_expr| { // reorder expressions so that alias goes before agg match agg_expr { - Count(e, count_mode) => Count(Alias(e, name.clone()).into(), count_mode), - Sum(e) => Sum(Alias(e, name.clone()).into()), - ApproxPercentile(ApproxPercentileParams { + AggExpr::Count(e, count_mode) => { + AggExpr::Count(Expr::Alias(e, name.clone()).into(), count_mode) + } + AggExpr::Sum(e) => AggExpr::Sum(Expr::Alias(e, name.clone()).into()), + AggExpr::ApproxPercentile(ApproxPercentileParams { child: e, percentiles, force_list_output, - }) => ApproxPercentile(ApproxPercentileParams { - child: Alias(e, name.clone()).into(), + }) => AggExpr::ApproxPercentile(ApproxPercentileParams { + child: Expr::Alias(e, name.clone()).into(), percentiles, force_list_output, }), - ApproxCountDistinct(e) => ApproxCountDistinct(Alias(e, name.clone()).into()), - ApproxSketch(e, sketch_type) => { - ApproxSketch(Alias(e, name.clone()).into(), sketch_type) + AggExpr::ApproxCountDistinct(e) => { + AggExpr::ApproxCountDistinct(Expr::Alias(e, name.clone()).into()) + } + AggExpr::ApproxSketch(e, sketch_type) => { + AggExpr::ApproxSketch(Expr::Alias(e, name.clone()).into(), sketch_type) } - MergeSketch(e, sketch_type) => { - MergeSketch(Alias(e, name.clone()).into(), sketch_type) + AggExpr::MergeSketch(e, sketch_type) => { + AggExpr::MergeSketch(Expr::Alias(e, name.clone()).into(), sketch_type) } - Mean(e) => Mean(Alias(e, name.clone()).into()), - Min(e) => Min(Alias(e, name.clone()).into()), - Max(e) => Max(Alias(e, name.clone()).into()), - AnyValue(e, ignore_nulls) => AnyValue(Alias(e, name.clone()).into(), ignore_nulls), - List(e) => List(Alias(e, name.clone()).into()), - Concat(e) => Concat(Alias(e, name.clone()).into()), - MapGroups { func, inputs } => MapGroups { + AggExpr::Mean(e) => AggExpr::Mean(Expr::Alias(e, name.clone()).into()), + AggExpr::Stddev(e) => AggExpr::Stddev(Expr::Alias(e, name.clone()).into()), + AggExpr::Min(e) => AggExpr::Min(Expr::Alias(e, name.clone()).into()), + AggExpr::Max(e) => AggExpr::Max(Expr::Alias(e, name.clone()).into()), + AggExpr::AnyValue(e, ignore_nulls) => { + AggExpr::AnyValue(Expr::Alias(e, name.clone()).into(), ignore_nulls) + } + AggExpr::List(e) => AggExpr::List(Expr::Alias(e, name.clone()).into()), + AggExpr::Concat(e) => AggExpr::Concat(Expr::Alias(e, name.clone()).into()), + AggExpr::MapGroups { func, inputs } => AggExpr::MapGroups { func, inputs: inputs .into_iter() @@ -409,148 +415,3 @@ pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { - let struct_expr_map = calculate_struct_expr_map(schema); - transform_struct_gets(expr, &struct_expr_map) - } - - #[test] - fn test_substitute_expr_getter_sugar() -> DaftResult<()> { - use crate::functions::struct_::get as struct_get; - - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); - - assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); - assert!(substitute_expr_getter_sugar(col("a.b"), &schema).is_err()); - assert!(matches!( - substitute_expr_getter_sugar(col("a.b"), &schema).unwrap_err(), - DaftError::ValueError(..) - )); - - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Struct(vec![Field::new("b", DataType::Int64)]), - )])?); - - assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - struct_get(col("a"), "b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b").alias("c"), &schema)?, - struct_get(col("a"), "b").alias("c") - ); - - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Struct(vec![Field::new( - "b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - )]), - )])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - struct_get(col("a"), "b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(struct_get(col("a"), "b"), "c") - ); - - let schema = Arc::new(Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new( - "b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - )]), - ), - Field::new("a.b", DataType::Int64), - ])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b"), &schema)?, - col("a.b") - ); - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(struct_get(col("a"), "b"), "c") - ); - - let schema = Arc::new(Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new("b.c", DataType::Int64)]), - ), - Field::new( - "a.b", - DataType::Struct(vec![Field::new("c", DataType::Int64)]), - ), - ])?); - - assert_eq!( - substitute_expr_getter_sugar(col("a.b.c"), &schema)?, - struct_get(col("a.b"), "c") - ); - - Ok(()) - } - - #[test] - fn test_find_wildcards() -> DaftResult<()> { - let schema = Schema::new(vec![ - Field::new( - "a", - DataType::Struct(vec![Field::new("b.*", DataType::Int64)]), - ), - Field::new("c.*", DataType::Int64), - ])?; - let struct_expr_map = calculate_struct_expr_map(&schema); - - let wildcards = find_wildcards(col("test"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "*"); - - let wildcards = find_wildcards(col("t*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); - - let wildcards = find_wildcards(col("a.*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.*"); - - let wildcards = find_wildcards(col("c.*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("a.b.*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - let wildcards = find_wildcards(col("a.b*"), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.b*"); - - // nested expression - let wildcards = find_wildcards(col("t*").add(col("a.*")), &struct_expr_map); - assert!(wildcards.len() == 2); - assert!(wildcards.iter().any(|s| s.as_ref() == "t*")); - assert!(wildcards.iter().any(|s| s.as_ref() == "a.*")); - - let wildcards = find_wildcards(col("t*").add(col("a")), &struct_expr_map); - assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); - - // schema containing * - let schema = Schema::new(vec![Field::new("*", DataType::Int64)])?; - let struct_expr_map = calculate_struct_expr_map(&schema); - - let wildcards = find_wildcards(col("*"), &struct_expr_map); - assert!(wildcards.is_empty()); - - Ok(()) - } -} diff --git a/src/daft-dsl/src/resolve_expr/tests.rs b/src/daft-dsl/src/resolve_expr/tests.rs new file mode 100644 index 0000000000..dcb3147207 --- /dev/null +++ b/src/daft-dsl/src/resolve_expr/tests.rs @@ -0,0 +1,141 @@ +use super::*; + +fn substitute_expr_getter_sugar(expr: ExprRef, schema: &Schema) -> DaftResult { + let struct_expr_map = calculate_struct_expr_map(schema); + transform_struct_gets(expr, &struct_expr_map) +} + +#[test] +fn test_substitute_expr_getter_sugar() -> DaftResult<()> { + use crate::functions::struct_::get as struct_get; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64)])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert!(substitute_expr_getter_sugar(col("a.b"), &schema).is_err()); + assert!(matches!( + substitute_expr_getter_sugar(col("a.b"), &schema).unwrap_err(), + DaftError::ValueError(..) + )); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new("b", DataType::Int64)]), + )])?); + + assert_eq!(substitute_expr_getter_sugar(col("a"), &schema)?, col("a")); + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b").alias("c"), &schema)?, + struct_get(col("a"), "b").alias("c") + ); + + let schema = Arc::new(Schema::new(vec![Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + )])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + struct_get(col("a"), "b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new( + "b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + )]), + ), + Field::new("a.b", DataType::Int64), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b"), &schema)?, + col("a.b") + ); + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(struct_get(col("a"), "b"), "c") + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new("b.c", DataType::Int64)]), + ), + Field::new( + "a.b", + DataType::Struct(vec![Field::new("c", DataType::Int64)]), + ), + ])?); + + assert_eq!( + substitute_expr_getter_sugar(col("a.b.c"), &schema)?, + struct_get(col("a.b"), "c") + ); + + Ok(()) +} + +#[test] +fn test_find_wildcards() -> DaftResult<()> { + let schema = Schema::new(vec![ + Field::new( + "a", + DataType::Struct(vec![Field::new("b.*", DataType::Int64)]), + ), + Field::new("c.*", DataType::Int64), + ])?; + let struct_expr_map = calculate_struct_expr_map(&schema); + + let wildcards = find_wildcards(col("test"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "*"); + + let wildcards = find_wildcards(col("t*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); + + let wildcards = find_wildcards(col("a.*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.*"); + + let wildcards = find_wildcards(col("c.*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("a.b.*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + let wildcards = find_wildcards(col("a.b*"), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "a.b*"); + + // nested expression + let wildcards = find_wildcards(col("t*").add(col("a.*")), &struct_expr_map); + assert!(wildcards.len() == 2); + assert!(wildcards.iter().any(|s| s.as_ref() == "t*")); + assert!(wildcards.iter().any(|s| s.as_ref() == "a.*")); + + let wildcards = find_wildcards(col("t*").add(col("a")), &struct_expr_map); + assert!(wildcards.len() == 1 && wildcards.first().unwrap().as_ref() == "t*"); + + // schema containing * + let schema = Schema::new(vec![Field::new("*", DataType::Int64)])?; + let struct_expr_map = calculate_struct_expr_map(&schema); + + let wildcards = find_wildcards(col("*"), &struct_expr_map); + assert!(wildcards.is_empty()); + + Ok(()) +} diff --git a/src/daft-functions/src/list/mean.rs b/src/daft-functions/src/list/mean.rs index aa2b51ea81..b01d3c1fa1 100644 --- a/src/daft-functions/src/list/mean.rs +++ b/src/daft-functions/src/list/mean.rs @@ -1,6 +1,6 @@ use common_error::{DaftError, DaftResult}; use daft_core::{ - datatypes::try_mean_supertype, + datatypes::try_mean_stddev_aggregation_supertype, prelude::{Field, Schema}, series::Series, }; @@ -29,7 +29,7 @@ impl ScalarUDF for ListMean { let inner_field = input.to_field(schema)?.to_exploded_field()?; Ok(Field::new( inner_field.name.as_str(), - try_mean_supertype(&inner_field.dtype)?, + try_mean_stddev_aggregation_supertype(&inner_field.dtype)?, )) } _ => Err(DaftError::SchemaMismatch(format!( diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 419c176930..78de22bea6 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -368,24 +368,24 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::Count(ref child, mode) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::Count(transformed_child, mode), - |_| e.clone(), + |_| e, ) } AggExpr::Sum(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Sum, |_| e.clone()) + .map_yes_no(AggExpr::Sum, |_| e) } AggExpr::ApproxPercentile(ApproxPercentileParams { ref child, ref percentiles, - ref force_list_output, + force_list_output, }) => replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) .map_yes_no( |transformed_child| { AggExpr::ApproxPercentile(ApproxPercentileParams { child: transformed_child, percentiles: percentiles.clone(), - force_list_output: *force_list_output, + force_list_output, }) }, |_| e.clone(), @@ -397,40 +397,44 @@ fn replace_column_with_semantic_id_aggexpr( AggExpr::ApproxSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::ApproxSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } AggExpr::MergeSketch(ref child, sketch_type) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::MergeSketch(transformed_child, sketch_type), - |_| e.clone(), + |_| e, ) } AggExpr::Mean(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Mean, |_| e.clone()) + .map_yes_no(AggExpr::Mean, |_| e) + } + AggExpr::Stddev(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::Stddev, |_| e) } AggExpr::Min(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Min, |_| e.clone()) + .map_yes_no(AggExpr::Min, |_| e) } AggExpr::Max(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Max, |_| e.clone()) + .map_yes_no(AggExpr::Max, |_| e) } AggExpr::AnyValue(ref child, ignore_nulls) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no( |transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls), - |_| e.clone(), + |_| e, ) } AggExpr::List(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::List, |_| e.clone()) + .map_yes_no(AggExpr::List, |_| e) } AggExpr::Concat(ref child) => { replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) - .map_yes_no(AggExpr::Concat, |_| e.clone()) + .map_yes_no(AggExpr::Concat, |_| e) } AggExpr::MapGroups { func, inputs } => { let transforms = inputs diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 85da833ca5..c7a364c770 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -8,7 +8,10 @@ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_file_formats::FileFormat; use daft_core::prelude::*; -use daft_dsl::{col, is_partition_compatible, ApproxPercentileParams, ExprRef, SketchType}; +use daft_dsl::{ + col, is_partition_compatible, AggExpr, ApproxPercentileParams, ExprRef, SketchType, +}; +use daft_functions::numeric::sqrt; use daft_scan::PhysicalScanInfo; use crate::{ @@ -765,8 +768,6 @@ pub fn populate_aggregation_stages( HashMap, daft_dsl::AggExpr>, Vec, ) { - use daft_dsl::AggExpr::{self, *}; - // Aggregations to apply in the first and second stages. // Semantic column name -> AggExpr let mut first_stage_aggs: HashMap, AggExpr> = HashMap::new(); @@ -774,147 +775,245 @@ pub fn populate_aggregation_stages( // Project the aggregation results to their final output names let mut final_exprs: Vec = group_by.iter().map(|e| col(e.name())).collect(); + fn add_to_stage( + f: F, + expr: ExprRef, + schema: &Schema, + stage: &mut HashMap, AggExpr>, + ) -> Arc + where + F: Fn(ExprRef) -> AggExpr, + { + let id = f(expr.clone()).semantic_id(schema).id; + let agg_expr = f(expr.alias(id.clone())); + stage.insert(id.clone(), agg_expr); + id + } + for agg_expr in aggregations { let output_name = agg_expr.name(); match agg_expr { - Count(e, mode) => { + AggExpr::Count(e, mode) => { let count_id = agg_expr.semantic_id(schema).id; - let sum_of_count_id = Sum(col(count_id.clone())).semantic_id(schema).id; + let sum_of_count_id = AggExpr::Sum(col(count_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(count_id.clone()) - .or_insert(Count(e.alias(count_id.clone()).clone(), *mode)); + .or_insert(AggExpr::Count(e.alias(count_id.clone()).clone(), *mode)); second_stage_aggs .entry(sum_of_count_id.clone()) - .or_insert(Sum(col(count_id.clone()).alias(sum_of_count_id.clone()))); + .or_insert(AggExpr::Sum( + col(count_id.clone()).alias(sum_of_count_id.clone()), + )); final_exprs.push(col(sum_of_count_id.clone()).alias(output_name)); } - Sum(e) => { + AggExpr::Sum(e) => { let sum_id = agg_expr.semantic_id(schema).id; - let sum_of_sum_id = Sum(col(sum_id.clone())).semantic_id(schema).id; + let sum_of_sum_id = AggExpr::Sum(col(sum_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(sum_id.clone()) - .or_insert(Sum(e.alias(sum_id.clone()).clone())); + .or_insert(AggExpr::Sum(e.alias(sum_id.clone()).clone())); second_stage_aggs .entry(sum_of_sum_id.clone()) - .or_insert(Sum(col(sum_id.clone()).alias(sum_of_sum_id.clone()))); + .or_insert(AggExpr::Sum( + col(sum_id.clone()).alias(sum_of_sum_id.clone()), + )); final_exprs.push(col(sum_of_sum_id.clone()).alias(output_name)); } - Mean(e) => { - let sum_id = Sum(e.clone()).semantic_id(schema).id; - let count_id = Count(e.clone(), CountMode::Valid).semantic_id(schema).id; - let sum_of_sum_id = Sum(col(sum_id.clone())).semantic_id(schema).id; - let sum_of_count_id = Sum(col(count_id.clone())).semantic_id(schema).id; + AggExpr::Mean(e) => { + let sum_id = AggExpr::Sum(e.clone()).semantic_id(schema).id; + let count_id = AggExpr::Count(e.clone(), CountMode::Valid) + .semantic_id(schema) + .id; + let sum_of_sum_id = AggExpr::Sum(col(sum_id.clone())).semantic_id(schema).id; + let sum_of_count_id = AggExpr::Sum(col(count_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(sum_id.clone()) - .or_insert(Sum(e.alias(sum_id.clone()).clone())); + .or_insert(AggExpr::Sum(e.alias(sum_id.clone()).clone())); first_stage_aggs .entry(count_id.clone()) - .or_insert(Count(e.alias(count_id.clone()).clone(), CountMode::Valid)); + .or_insert(AggExpr::Count( + e.alias(count_id.clone()).clone(), + CountMode::Valid, + )); second_stage_aggs .entry(sum_of_sum_id.clone()) - .or_insert(Sum(col(sum_id.clone()).alias(sum_of_sum_id.clone()))); + .or_insert(AggExpr::Sum( + col(sum_id.clone()).alias(sum_of_sum_id.clone()), + )); second_stage_aggs .entry(sum_of_count_id.clone()) - .or_insert(Sum(col(count_id.clone()).alias(sum_of_count_id.clone()))); + .or_insert(AggExpr::Sum( + col(count_id.clone()).alias(sum_of_count_id.clone()), + )); final_exprs.push( (col(sum_of_sum_id.clone()).div(col(sum_of_count_id.clone()))) .alias(output_name), ); } - Min(e) => { + AggExpr::Stddev(sub_expr) => { + // The stddev calculation we're performing here is: + // stddev(X) = sqrt(E(X^2) - E(X)^2) + // where X is the sub_expr. + // + // First stage, we compute `sum(X^2)`, `sum(X)` and `count(X)`. + // Second stage, we `global_sqsum := sum(sum(X^2))`, `global_sum := sum(sum(X))` and `global_count := sum(count(X))` in order to get the global versions of the first stage. + // In the final projection, we then compute `sqrt((global_sqsum / global_count) - (global_sum / global_count) ^ 2)`. + + // first stage aggregation + let sum_id = add_to_stage( + AggExpr::Sum, + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + let sq_sum_id = add_to_stage( + |sub_expr| AggExpr::Sum(sub_expr.clone().mul(sub_expr)), + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + let count_id = add_to_stage( + |sub_expr| AggExpr::Count(sub_expr, CountMode::Valid), + sub_expr.clone(), + schema, + &mut first_stage_aggs, + ); + + // second stage aggregation + let global_sum_id = add_to_stage( + AggExpr::Sum, + col(sum_id.clone()), + schema, + &mut second_stage_aggs, + ); + let global_sq_sum_id = add_to_stage( + AggExpr::Sum, + col(sq_sum_id.clone()), + schema, + &mut second_stage_aggs, + ); + let global_count_id = add_to_stage( + AggExpr::Sum, + col(count_id.clone()), + schema, + &mut second_stage_aggs, + ); + + // final projection + let g_sq_sum = col(global_sq_sum_id); + let g_sum = col(global_sum_id); + let g_count = col(global_count_id); + let left = g_sq_sum.div(g_count.clone()); + let right = g_sum.div(g_count); + let right = right.clone().mul(right); + let result = sqrt::sqrt(left.sub(right)).alias(output_name); + + final_exprs.push(result); + } + AggExpr::Min(e) => { let min_id = agg_expr.semantic_id(schema).id; - let min_of_min_id = Min(col(min_id.clone())).semantic_id(schema).id; + let min_of_min_id = AggExpr::Min(col(min_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(min_id.clone()) - .or_insert(Min(e.alias(min_id.clone()).clone())); + .or_insert(AggExpr::Min(e.alias(min_id.clone()).clone())); second_stage_aggs .entry(min_of_min_id.clone()) - .or_insert(Min(col(min_id.clone()).alias(min_of_min_id.clone()))); + .or_insert(AggExpr::Min( + col(min_id.clone()).alias(min_of_min_id.clone()), + )); final_exprs.push(col(min_of_min_id.clone()).alias(output_name)); } - Max(e) => { + AggExpr::Max(e) => { let max_id = agg_expr.semantic_id(schema).id; - let max_of_max_id = Max(col(max_id.clone())).semantic_id(schema).id; + let max_of_max_id = AggExpr::Max(col(max_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(max_id.clone()) - .or_insert(Max(e.alias(max_id.clone()).clone())); + .or_insert(AggExpr::Max(e.alias(max_id.clone()).clone())); second_stage_aggs .entry(max_of_max_id.clone()) - .or_insert(Max(col(max_id.clone()).alias(max_of_max_id.clone()))); + .or_insert(AggExpr::Max( + col(max_id.clone()).alias(max_of_max_id.clone()), + )); final_exprs.push(col(max_of_max_id.clone()).alias(output_name)); } - AnyValue(e, ignore_nulls) => { + AggExpr::AnyValue(e, ignore_nulls) => { let any_id = agg_expr.semantic_id(schema).id; - let any_of_any_id = AnyValue(col(any_id.clone()), *ignore_nulls) + let any_of_any_id = AggExpr::AnyValue(col(any_id.clone()), *ignore_nulls) .semantic_id(schema) .id; first_stage_aggs .entry(any_id.clone()) - .or_insert(AnyValue(e.alias(any_id.clone()).clone(), *ignore_nulls)); + .or_insert(AggExpr::AnyValue( + e.alias(any_id.clone()).clone(), + *ignore_nulls, + )); second_stage_aggs .entry(any_of_any_id.clone()) - .or_insert(AnyValue( + .or_insert(AggExpr::AnyValue( col(any_id.clone()).alias(any_of_any_id.clone()), *ignore_nulls, )); final_exprs.push(col(any_of_any_id.clone()).alias(output_name)); } - List(e) => { + AggExpr::List(e) => { let list_id = agg_expr.semantic_id(schema).id; - let concat_of_list_id = Concat(col(list_id.clone())).semantic_id(schema).id; + let concat_of_list_id = + AggExpr::Concat(col(list_id.clone())).semantic_id(schema).id; first_stage_aggs .entry(list_id.clone()) - .or_insert(List(e.alias(list_id.clone()).clone())); + .or_insert(AggExpr::List(e.alias(list_id.clone()).clone())); second_stage_aggs .entry(concat_of_list_id.clone()) - .or_insert(Concat( + .or_insert(AggExpr::Concat( col(list_id.clone()).alias(concat_of_list_id.clone()), )); final_exprs.push(col(concat_of_list_id.clone()).alias(output_name)); } - Concat(e) => { + AggExpr::Concat(e) => { let concat_id = agg_expr.semantic_id(schema).id; - let concat_of_concat_id = Concat(col(concat_id.clone())).semantic_id(schema).id; + let concat_of_concat_id = AggExpr::Concat(col(concat_id.clone())) + .semantic_id(schema) + .id; first_stage_aggs .entry(concat_id.clone()) - .or_insert(Concat(e.alias(concat_id.clone()).clone())); + .or_insert(AggExpr::Concat(e.alias(concat_id.clone()).clone())); second_stage_aggs .entry(concat_of_concat_id.clone()) - .or_insert(Concat( + .or_insert(AggExpr::Concat( col(concat_id.clone()).alias(concat_of_concat_id.clone()), )); final_exprs.push(col(concat_of_concat_id.clone()).alias(output_name)); } - MapGroups { func, inputs } => { + AggExpr::MapGroups { func, inputs } => { let func_id = agg_expr.semantic_id(schema).id; // No first stage aggregation for MapGroups, do all the work in the second stage. second_stage_aggs .entry(func_id.clone()) - .or_insert(MapGroups { + .or_insert(AggExpr::MapGroups { func: func.clone(), inputs: inputs.clone(), }); final_exprs.push(col(output_name)); } - &ApproxPercentile(ApproxPercentileParams { + &AggExpr::ApproxPercentile(ApproxPercentileParams { child: ref e, ref percentiles, force_list_output, }) => { let percentiles = percentiles.iter().map(|p| p.0).collect::>(); let sketch_id = agg_expr.semantic_id(schema).id; - let approx_id = ApproxSketch(col(sketch_id.clone()), SketchType::DDSketch) + let approx_id = AggExpr::ApproxSketch(col(sketch_id.clone()), SketchType::DDSketch) .semantic_id(schema) .id; first_stage_aggs .entry(sketch_id.clone()) - .or_insert(ApproxSketch( + .or_insert(AggExpr::ApproxSketch( e.alias(sketch_id.clone()), SketchType::DDSketch, )); second_stage_aggs .entry(approx_id.clone()) - .or_insert(MergeSketch( + .or_insert(AggExpr::MergeSketch( col(sketch_id.clone()).alias(approx_id.clone()), SketchType::DDSketch, )); @@ -924,30 +1023,30 @@ pub fn populate_aggregation_stages( .alias(output_name), ); } - ApproxCountDistinct(e) => { + AggExpr::ApproxCountDistinct(e) => { let first_stage_id = agg_expr.semantic_id(schema).id; let second_stage_id = - MergeSketch(col(first_stage_id.clone()), SketchType::HyperLogLog) + AggExpr::MergeSketch(col(first_stage_id.clone()), SketchType::HyperLogLog) .semantic_id(schema) .id; first_stage_aggs .entry(first_stage_id.clone()) - .or_insert(ApproxSketch( + .or_insert(AggExpr::ApproxSketch( e.alias(first_stage_id.clone()), SketchType::HyperLogLog, )); second_stage_aggs .entry(second_stage_id.clone()) - .or_insert(MergeSketch( + .or_insert(AggExpr::MergeSketch( col(first_stage_id).alias(second_stage_id.clone()), SketchType::HyperLogLog, )); final_exprs.push(col(second_stage_id).alias(output_name)); } - ApproxSketch(..) => { + AggExpr::ApproxSketch(..) => { unimplemented!("User-facing approx_sketch aggregation is not implemented") } - MergeSketch(..) => { + AggExpr::MergeSketch(..) => { unimplemented!("User-facing merge_sketch aggregation is not implemented") } } diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index d697d0f022..00ef1083ca 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -376,6 +376,18 @@ impl DataType { } } + #[inline] + pub fn assert_is_numeric(&self) -> DaftResult<()> { + if self.is_numeric() { + Ok(()) + } else { + Err(DaftError::TypeError(format!( + "Numeric mean is not implemented for type {}", + self, + ))) + } + } + #[inline] pub fn is_fixed_size_numeric(&self) -> bool { match self { diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index aaaac5eb0a..7e8ceb5fcb 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -109,6 +109,10 @@ pub fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult { ensure!(args.len() == 1, "mean takes exactly one argument"); Ok(args[0].clone().mean()) } + AggExpr::Stddev(_) => { + ensure!(args.len() == 1, "stddev takes exactly one argument"); + Ok(args[0].clone().stddev()) + } AggExpr::Min(_) => { ensure!(args.len() == 1, "min takes exactly one argument"); Ok(args[0].clone().min()) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index b1138e1e72..cf96344a53 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -479,6 +479,7 @@ impl Table { } } AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), + AggExpr::Stddev(expr) => self.eval_expression(expr)?.stddev(groups), AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), &AggExpr::AnyValue(ref expr, ignore_nulls) => { @@ -584,10 +585,10 @@ impl Table { assert!( !(expected_field.dtype != series.field().dtype), "Data type mismatch in expression evaluation:\n\ - Expected type: {}\n\ - Computed type: {}\n\ - Expression: {}\n\ - This likely indicates an internal error in type inference or computation.", + Expected type: {}\n\ + Computed type: {}\n\ + Expression: {}\n\ + This likely indicates an internal error in type inference or computation.", expected_field.dtype, series.field().dtype, expr diff --git a/tests/dataframe/test_stddev.py b/tests/dataframe/test_stddev.py new file mode 100644 index 0000000000..464d20bd41 --- /dev/null +++ b/tests/dataframe/test_stddev.py @@ -0,0 +1,144 @@ +import functools +import math +from typing import Any, List, Tuple + +import pandas as pd +import pytest + +import daft + + +def grouped_stddev(rows) -> Tuple[List[Any], List[Any]]: + map = {} + for key, data in rows: + if key not in map: + map[key] = [] + map[key].append(data) + + keys = [] + stddevs = [] + for key, nums in map.items(): + keys.append(key) + stddevs.append(stddev(nums)) + + return keys, stddevs + + +def stddev(nums) -> float: + nums = [num for num in nums if num is not None] + + if not nums: + return 0.0 + sum_: float = sum(nums) + count = len(nums) + mean = sum_ / count + + squared_sums = functools.reduce(lambda acc, num: acc + (num - mean) ** 2, nums, 0) + stddev = math.sqrt(squared_sums / count) + return stddev + + +TESTS = [ + [nums := [0], stddev(nums)], + [nums := [1], stddev(nums)], + [nums := [0, 1, 2], stddev(nums)], + [nums := [100, 100, 100], stddev(nums)], + [nums := [None, 100, None], stddev(nums)], + [nums := [None] * 10 + [100], stddev(nums)], +] + + +@pytest.mark.parametrize("data_and_expected", TESTS) +def test_stddev_with_single_partition(data_and_expected): + data, expected = data_and_expected + df = daft.from_pydict({"a": data}) + result = df.agg(daft.col("a").stddev()).collect() + rows = result.iter_rows() + stddev = next(rows) + try: + next(rows) + assert False + except StopIteration: + pass + + assert stddev["a"] == expected + + +@pytest.mark.parametrize("data_and_expected", TESTS) +def test_stddev_with_multiple_partitions(data_and_expected): + data, expected = data_and_expected + df = daft.from_pydict({"a": data}).into_partitions(2) + result = df.agg(daft.col("a").stddev()).collect() + rows = result.iter_rows() + stddev = next(rows) + try: + next(rows) + assert False + except StopIteration: + pass + + assert stddev["a"] == expected + + +GROUPED_TESTS = [ + [rows := [("k1", 0), ("k2", 1), ("k1", 1)], *grouped_stddev(rows)], + [rows := [("k0", 100), ("k1", 100), ("k2", 100)], *grouped_stddev(rows)], + [rows := [("k0", 100), ("k0", 100), ("k0", 100)], *grouped_stddev(rows)], + [rows := [("k0", 0), ("k0", 1), ("k0", 2)], *grouped_stddev(rows)], + [rows := [("k0", None), ("k0", None), ("k0", 100)], *grouped_stddev(rows)], +] + + +def unzip_rows(rows: list) -> Tuple[List, List]: + keys = [] + nums = [] + for key, data in rows: + keys.append(key) + nums.append(data) + return keys, nums + + +@pytest.mark.parametrize("data_and_expected", GROUPED_TESTS) +def test_grouped_stddev_with_single_partition(data_and_expected): + nums, expected_keys, expected_stddevs = data_and_expected + expected_df = daft.from_pydict({"keys": expected_keys, "data": expected_stddevs}) + keys, data = unzip_rows(nums) + df = daft.from_pydict({"keys": keys, "data": data}) + result_df = df.groupby("keys").agg(daft.col("data").stddev()).collect() + + result = result_df.to_pydict() + expected = expected_df.to_pydict() + + pd.testing.assert_series_equal( + pd.Series(result["keys"]).sort_values(), + pd.Series(expected["keys"]).sort_values(), + check_index=False, + ) + pd.testing.assert_series_equal( + pd.Series(result["data"]).sort_values(), + pd.Series(expected["data"]).sort_values(), + check_index=False, + ) + + +@pytest.mark.parametrize("data_and_expected", GROUPED_TESTS) +def test_grouped_stddev_with_multiple_partitions(data_and_expected): + nums, expected_keys, expected_stddevs = data_and_expected + expected_df = daft.from_pydict({"keys": expected_keys, "data": expected_stddevs}) + keys, data = unzip_rows(nums) + df = daft.from_pydict({"keys": keys, "data": data}).into_partitions(2) + result_df = df.groupby("keys").agg(daft.col("data").stddev()).collect() + + result = result_df.to_pydict() + expected = expected_df.to_pydict() + + pd.testing.assert_series_equal( + pd.Series(result["keys"]).sort_values(), + pd.Series(expected["keys"]).sort_values(), + check_index=False, + ) + pd.testing.assert_series_equal( + pd.Series(result["data"]).sort_values(), + pd.Series(expected["data"]).sort_values(), + check_index=False, + )