Skip to content

Commit

Permalink
[FEAT] Implement standard deviation (#3005)
Browse files Browse the repository at this point in the history
# Overview
- Add a standard deviation (`stddev`) aggregation function.
- Its implementation is similar to how `AggExpr::Count` and `AggExpr::Mean`
work.

## Notes
The implementation differs slightly between non-partitioned and
multi-partitioned DataFrames:

1. The non-partitioned implementation uses the simple, naive approach derived
from the definition of the (population) standard deviation, i.e.
`stddev(X) = sqrt(sum((x_i - mean(X))^2) / N)`.
2. The multi-partitioned implementation calculates
`stddev(X) = sqrt(E(X^2) - E(X)^2)`.
  • Loading branch information
Raunak Bhagat authored Oct 8, 2024
1 parent f995792 commit 64b8699
Show file tree
Hide file tree
Showing 39 changed files with 1,190 additions and 759 deletions.
2 changes: 2 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,7 @@ class PyExpr:
def approx_count_distinct(self) -> PyExpr: ...
def approx_percentiles(self, percentiles: float | list[float]) -> PyExpr: ...
def mean(self) -> PyExpr: ...
def stddev(self) -> PyExpr: ...
def min(self) -> PyExpr: ...
def max(self) -> PyExpr: ...
def any_value(self, ignore_nulls: bool) -> PyExpr: ...
Expand Down Expand Up @@ -1336,6 +1337,7 @@ class PySeries:
def count(self, mode: CountMode) -> PySeries: ...
def sum(self) -> PySeries: ...
def mean(self) -> PySeries: ...
def stddev(self) -> PySeries: ...
def min(self) -> PySeries: ...
def max(self) -> PySeries: ...
def agg_list(self) -> PySeries: ...
Expand Down
55 changes: 55 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,6 +2118,33 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame":
"""
return self._apply_agg_fn(Expression.mean, cols)

@DataframePublicAPI
def stddev(self, *cols: ColumnInputType) -> "DataFrame":
"""Performs a global standard deviation on the DataFrame

This is the population standard deviation: the squared deviations from the mean are
divided by N rather than N - 1, which is why ``[0, 1, 2]`` gives ``sqrt(2 / 3)`` in
the example below.

Example:
>>> import daft
>>> df = daft.from_pydict({"col_a":[0,1,2]})
>>> df = df.stddev("col_a")
>>> df.show()
╭───────────────────╮
│ col_a │
│ --- │
│ Float64 │
╞═══════════════════╡
│ 0.816496580927726 │
╰───────────────────╯
<BLANKLINE>
(Showing first 1 of 1 rows)

Args:
*cols (Union[str, Expression]): columns to compute the standard deviation of

Returns:
DataFrame: Globally aggregated standard deviation. Should be a single row.
"""
return self._apply_agg_fn(Expression.stddev, cols)

@DataframePublicAPI
def min(self, *cols: ColumnInputType) -> "DataFrame":
"""Performs a global min on the DataFrame
Expand Down Expand Up @@ -2856,6 +2883,34 @@ def mean(self, *cols: ColumnInputType) -> "DataFrame":
"""
return self.df._apply_agg_fn(Expression.mean, cols, self.group_by)

def stddev(self, *cols: ColumnInputType) -> "DataFrame":
"""Performs grouped standard deviation on this GroupedDataFrame.

The population standard deviation is computed for each group (squared deviations are
divided by the group size N rather than N - 1), so a group containing a single value,
such as ``"b"`` in the example below, yields 0.

Example:
>>> import daft
>>> df = daft.from_pydict({"keys": ["a", "a", "a", "b"], "col_a": [0,1,2,100]})
>>> df = df.groupby("keys").stddev()
>>> df.show()
╭──────┬───────────────────╮
│ keys ┆ col_a │
│ --- ┆ --- │
│ Utf8 ┆ Float64 │
╞══════╪═══════════════════╡
│ a ┆ 0.816496580927726 │
├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ b ┆ 0 │
╰──────┴───────────────────╯
<BLANKLINE>
(Showing first 2 of 2 rows)

Args:
*cols (Union[str, Expression]): columns to compute the standard deviation of

Returns:
DataFrame: DataFrame with grouped standard deviation.
"""
return self.df._apply_agg_fn(Expression.stddev, cols, self.group_by)

def min(self, *cols: ColumnInputType) -> "DataFrame":
"""Perform grouped min on this GroupedDataFrame.
Expand Down
5 changes: 5 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,11 @@ def mean(self) -> Expression:
expr = self._expr.mean()
return Expression._from_pyexpr(expr)

def stddev(self) -> Expression:
"""Calculates the standard deviation of the values in the expression

This is the population standard deviation (squared deviations are divided by N, not N - 1).
"""
expr = self._expr.stddev()
return Expression._from_pyexpr(expr)

def min(self) -> Expression:
"""Calculates the minimum value in the expression"""
expr = self._expr.min()
Expand Down
4 changes: 4 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,10 @@ def mean(self) -> Series:
assert self._series is not None
return Series._from_pyseries(self._series.mean())

def stddev(self) -> Series:
"""Returns a new Series wrapping the result of the underlying PySeries ``stddev`` call."""
# The wrapped PySeries must be set before any aggregation can be delegated to it.
assert self._series is not None
return Series._from_pyseries(self._series.stddev())

def sum(self) -> Series:
"""Returns a new Series wrapping the result of the underlying PySeries ``sum`` call."""
# The wrapped PySeries must be set before any aggregation can be delegated to it.
assert self._series is not None
return Series._from_pyseries(self._series.sum())
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ Aggregations
DataFrame.groupby
DataFrame.sum
DataFrame.mean
DataFrame.stddev
DataFrame.count
DataFrame.min
DataFrame.max
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ The following can be used with DataFrame.agg or GroupedDataFrame.agg
Expression.count
Expression.sum
Expression.mean
Expression.stddev
Expression.min
Expression.max
Expression.any_value
Expand Down
49 changes: 16 additions & 33 deletions src/daft-core/src/array/ops/mean.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,27 @@
use std::sync::Arc;

use arrow2::array::PrimitiveArray;
use common_error::DaftResult;

use super::{as_arrow::AsArrow, DaftCountAggable, DaftMeanAggable, DaftSumAggable};
use crate::{array::ops::GroupIndices, count_mode::CountMode, datatypes::*};
impl DaftMeanAggable for &DataArray<Float64Type> {
type Output = DaftResult<DataArray<Float64Type>>;
use crate::{
array::ops::{DaftMeanAggable, GroupIndices},
datatypes::*,
utils::stats,
};

fn mean(&self) -> Self::Output {
let sum_value = DaftSumAggable::sum(self)?.as_arrow().value(0);
let count_value = DaftCountAggable::count(self, CountMode::Valid)?
.as_arrow()
.value(0);

let result = match count_value {
0 => None,
count_value => Some(sum_value / count_value as f64),
};
let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result]));
impl DaftMeanAggable for DataArray<Float64Type> {
type Output = DaftResult<Self>;

DataArray::new(
Arc::new(Field::new(self.field.name.clone(), DataType::Float64)),
arrow_array,
)
fn mean(&self) -> Self::Output {
let stats = stats::calculate_stats(self)?;
let data = PrimitiveArray::from([stats.mean]).boxed();
let field = Arc::new(Field::new(self.field.name.clone(), DataType::Float64));
Self::new(field, data)
}

fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output {
use arrow2::array::PrimitiveArray;
let sum_values = self.grouped_sum(groups)?;
let count_values = self.grouped_count(groups, CountMode::Valid)?;
assert_eq!(sum_values.len(), count_values.len());
let mean_per_group = sum_values
.as_arrow()
.values_iter()
.zip(count_values.as_arrow().values_iter())
.map(|(s, c)| match (s, c) {
(_, 0) => None,
(s, c) => Some(s / (*c as f64)),
});
let mean_array = Box::new(PrimitiveArray::from_trusted_len_iter(mean_per_group));
Ok(DataArray::from((self.field.name.as_ref(), mean_array)))
let grouped_means = stats::grouped_stats(self, groups)?.map(|(stats, _)| stats.mean);
let data = Box::new(PrimitiveArray::from_iter(grouped_means));
Ok(Self::from((self.field.name.as_ref(), data)))
}
}
7 changes: 7 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod sketch_percentile;
mod sort;
pub(crate) mod sparse_tensor;
mod sqrt;
mod stddev;
mod struct_;
mod sum;
mod take;
Expand Down Expand Up @@ -189,6 +190,12 @@ pub trait DaftMeanAggable {
fn grouped_mean(&self, groups: &GroupIndices) -> Self::Output;
}

/// Standard-deviation aggregation, available both over a whole array and per group.
pub trait DaftStddevAggable {
/// Type returned by both aggregation methods.
type Output;
/// Computes the standard deviation over all values of `self`.
fn stddev(&self) -> Self::Output;
/// Computes the standard deviation separately for each group of indices in `groups`.
fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output;
}

pub trait DaftCompareAggable {
type Output;
fn min(&self) -> Self::Output;
Expand Down
34 changes: 34 additions & 0 deletions src/daft-core/src/array/ops/stddev.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use arrow2::array::PrimitiveArray;
use common_error::DaftResult;

use crate::{
array::{
ops::{DaftStddevAggable, GroupIndices},
DataArray,
},
datatypes::Float64Type,
utils::stats,
};

/// Standard-deviation kernel for `Float64` arrays.
///
/// Both methods first obtain summary statistics from the `stats` helpers and then pass those
/// statistics, together with the non-null values, to `stats::calculate_stddev`.
impl DaftStddevAggable for DataArray<Float64Type> {
    type Output = DaftResult<Self>;

    /// Aggregates the whole array into a one-element array that reuses this array's field.
    fn stddev(&self) -> Self::Output {
        let summary = stats::calculate_stats(self)?;
        // `flatten` drops the `None` slots, so only valid values reach the calculation.
        let non_null_values = self.into_iter().flatten().copied();
        let result = stats::calculate_stddev(summary, non_null_values);
        let data = PrimitiveArray::<f64>::from([result]).boxed();
        Self::new(self.field.clone(), data)
    }

    /// Aggregates each group in `groups` into one output slot, reusing this array's field.
    fn grouped_stddev(&self, groups: &GroupIndices) -> Self::Output {
        let per_group = stats::grouped_stats(self, groups)?;
        let stddevs = per_group.map(|(summary, indices)| {
            // Rows for which `get` returns `None` are left out of the group's values.
            let members = indices.iter().filter_map(|&row| self.get(row as _));
            stats::calculate_stddev(summary, members)
        });
        let data = PrimitiveArray::<f64>::from_iter(stddevs).boxed();
        Self::new(self.field.clone(), data)
    }
}
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/agg_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
}

/// Get the data type that the mean or standard deviation of a column of the given data type should be cast to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
pub fn try_mean_stddev_aggregation_supertype(dtype: &DataType) -> DaftResult<DataType> {
if dtype.is_numeric() {
Ok(DataType::Float64)
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use infer_datatype::InferDataType;
pub mod prelude;
use std::ops::{Add, Div, Mul, Rem, Sub};

pub use agg_ops::{try_mean_supertype, try_sum_supertype};
pub use agg_ops::{try_mean_stddev_aggregation_supertype, try_sum_supertype};
use arrow2::{
compute::comparison::Simd8,
types::{simd::Simd, NativeType},
Expand Down
60 changes: 30 additions & 30 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use logical::Decimal128Array;

use crate::{
array::{
ops::{DaftHllMergeAggable, GroupIndices},
ops::{
DaftApproxSketchAggable, DaftHllMergeAggable, DaftMeanAggable, DaftStddevAggable,
DaftSumAggable, GroupIndices,
},
ListArray,
},
count_mode::CountMode,
Expand All @@ -26,12 +29,10 @@ impl Series {
}

pub fn sum(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftSumAggable, datatypes::DataType::*};

match self.data_type() {
// intX -> int64 (in line with numpy)
Int8 | Int16 | Int32 | Int64 => {
let casted = self.cast(&Int64)?;
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
let casted = self.cast(&DataType::Int64)?;
match groups {
Some(groups) => {
Ok(DaftSumAggable::grouped_sum(&casted.i64()?, groups)?.into_series())
Expand All @@ -40,8 +41,8 @@ impl Series {
}
}
// uintX -> uint64 (in line with numpy)
UInt8 | UInt16 | UInt32 | UInt64 => {
let casted = self.cast(&UInt64)?;
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
let casted = self.cast(&DataType::UInt64)?;
match groups {
Some(groups) => {
Ok(DaftSumAggable::grouped_sum(&casted.u64()?, groups)?.into_series())
Expand All @@ -50,23 +51,23 @@ impl Series {
}
}
// floatX -> floatX (in line with numpy)
Float32 => match groups {
DataType::Float32 => match groups {
Some(groups) => Ok(DaftSumAggable::grouped_sum(
&self.downcast::<Float32Array>()?,
groups,
)?
.into_series()),
None => Ok(DaftSumAggable::sum(&self.downcast::<Float32Array>()?)?.into_series()),
},
Float64 => match groups {
DataType::Float64 => match groups {
Some(groups) => Ok(DaftSumAggable::grouped_sum(
&self.downcast::<Float64Array>()?,
groups,
)?
.into_series()),
None => Ok(DaftSumAggable::sum(&self.downcast::<Float64Array>()?)?.into_series()),
},
Decimal128(_, _) => match groups {
DataType::Decimal128(_, _) => match groups {
Some(groups) => Ok(Decimal128Array::new(
Field {
dtype: try_sum_supertype(self.data_type())?,
Expand Down Expand Up @@ -95,12 +96,10 @@ impl Series {
}

pub fn approx_sketch(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftApproxSketchAggable, datatypes::DataType::*};

// Upcast all numeric types to float64 and compute approx_sketch.
match self.data_type() {
dt if dt.is_numeric() => {
let casted = self.cast(&Float64)?;
let casted = self.cast(&DataType::Float64)?;
match groups {
Some(groups) => Ok(DaftApproxSketchAggable::grouped_approx_sketch(
&casted.f64()?,
Expand Down Expand Up @@ -149,24 +148,25 @@ impl Series {
}

pub fn mean(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
use crate::{array::ops::DaftMeanAggable, datatypes::DataType::*};

// Upcast all numeric types to float64 and use f64 mean kernel.
match self.data_type() {
dt if dt.is_numeric() => {
let casted = self.cast(&Float64)?;
match groups {
Some(groups) => {
Ok(DaftMeanAggable::grouped_mean(&casted.f64()?, groups)?.into_series())
}
None => Ok(DaftMeanAggable::mean(&casted.f64()?)?.into_series()),
}
}
other => Err(DaftError::TypeError(format!(
"Numeric mean is not implemented for type {}",
other
))),
}
self.data_type().assert_is_numeric()?;
let casted = self.cast(&DataType::Float64)?;
let casted = casted.f64()?;
let series = groups
.map_or_else(|| casted.mean(), |groups| casted.grouped_mean(groups))?
.into_series();
Ok(series)
}

/// Standard deviation of this series, computed globally or once per group.
///
/// Propagates the error from `assert_is_numeric()` (presumably raised for non-numeric
/// types); otherwise the values are cast to `Float64` and handed to the f64 stddev kernel.
pub fn stddev(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
    self.data_type().assert_is_numeric()?;
    let as_float = self.cast(&DataType::Float64)?;
    let array = as_float.f64()?;
    // With no groups, aggregate the whole array; otherwise aggregate each group.
    let aggregated = match groups {
        Some(indices) => array.grouped_stddev(indices),
        None => array.stddev(),
    }?;
    Ok(aggregated.into_series())
}

pub fn min(&self, groups: Option<&GroupIndices>) -> DaftResult<Self> {
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pub mod arrow;
pub mod display;
pub mod dyn_compare;
pub mod identity_hash_set;
pub mod stats;
pub mod supertype;
Loading

0 comments on commit 64b8699

Please sign in to comment.