Skip to content

Commit

Permalink
Merge pull request #17 from e10v/dev
Browse files Browse the repository at this point in the history
Numerically stable algorithms for var and cov
  • Loading branch information
e10v authored Jan 29, 2024
2 parents 45219c1 + 162940c commit 4c4b47d
Showing 1 changed file with 45 additions and 54 deletions.
99 changes: 45 additions & 54 deletions src/tea_tasting/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

# Names of the intermediate columns produced by the aggregation query.
# Templates containing `{}` are filled in with the source column name(s)
# via `str.format`.
_COUNT = "_count"
# Mean of a single column.
_MEAN = "_mean__{}"
# Mean of a column multiplied by itself.
_MEAN_OF_SQ = "_mean_of_sq__{}"
# Mean of the element-wise product of two columns.
_MEAN_OF_MUL = "_mean_of_mul__{}__{}"
# Sum of squared demeaned values of a column divided by (count - 1).
_VAR = "_var__{}"
# Sum of products of two demeaned columns divided by (count - 1).
_COV = "_cov__{}__{}"
# A column with its mean subtracted.
_DEMEAN = "_demean__{}"


class Aggregates:
Expand Down Expand Up @@ -223,28 +224,27 @@ def _add_mean(left: Aggregates, right: Aggregates, col: str) -> float | int:
return sum_ / count

def _add_var(left: Aggregates, right: Aggregates, col: str) -> float | int:
count = left.count() + right.count()
left_mean_of_sq = left.var(col)*(1 - 1/left.count()) + left.mean(col)**2
right_mean_of_sq = right.var(col)*(1 - 1/right.count()) + right.mean(col)**2
mean_of_sq = (left.count()*left_mean_of_sq + right.count()*right_mean_of_sq) / count
mean = _add_mean(left, right, col)
return (mean_of_sq - mean**2) * count / (count - 1)
left_n = left.count()
right_n = right.count()
total_n = left_n + right_n
diff_of_means = left.mean(col) - right.mean(col)
return (
left.var(col) * (left_n - 1)
+ right.var(col) * (right_n - 1)
+ diff_of_means * diff_of_means * left_n * right_n / total_n
) / (total_n - 1)

def _add_cov(left: Aggregates, right: Aggregates, cols: tuple[str, str]) -> float | int:
count = left.count() + right.count()
left_mean_of_mul = (
left.cov(*cols)*(1 - 1/left.count()) +
left.mean(cols[0])*left.mean(cols[1])
)
right_mean_of_mul = (
right.cov(*cols)*(1 - 1/right.count()) +
right.mean(cols[0])*right.mean(cols[1])
)
sum_of_mul = left.count()*left_mean_of_mul + right.count()*right_mean_of_mul
mean_of_mul = sum_of_mul / count
mean0 = _add_mean(left, right, cols[0])
mean1 = _add_mean(left, right, cols[1])
return (mean_of_mul - mean0*mean1) * count / (count - 1)
left_n = left.count()
right_n = right.count()
total_n = left_n + right_n
diff_of_means0 = left.mean(cols[0]) - right.mean(cols[0])
diff_of_means1 = left.mean(cols[1]) - right.mean(cols[1])
return (
left.cov(*cols) * (left_n - 1)
+ right.cov(*cols) * (right_n - 1)
+ diff_of_means0 * diff_of_means1 * left_n * right_n / total_n
) / (total_n - 1)


@overload
Expand Down Expand Up @@ -293,23 +293,36 @@ def read_aggregates(
has_count, mean_cols, var_cols, cov_cols = _validate_aggr_cols(
has_count, mean_cols, var_cols, cov_cols)

demean_cols = tuple({*var_cols, *itertools.chain(*cov_cols)})
if len(demean_cols) > 0:
demean_expr = {
_DEMEAN.format(col): data[col] - data[col].mean() # type: ignore
for col in demean_cols
}
grouped_data = data.group_by(group_col) if group_col is not None else data
data = grouped_data.mutate(**demean_expr)

count_expr = {_COUNT: data.count()} if has_count else {}
mean_expr = {_MEAN.format(col): data[col].mean() for col in mean_cols} # type: ignore
mean_of_sq_expr = {
_MEAN_OF_SQ.format(col): (data[col] * data[col]).mean() # type: ignore
var_expr = {
_VAR.format(col): (
data[_DEMEAN.format(col)] * data[_DEMEAN.format(col)]
).sum().cast("float") / (data.count() - 1) # type: ignore
for col in var_cols
}
mean_of_mul_expr = {
_MEAN_OF_MUL.format(left, right): (data[left] * data[right]).mean() # type: ignore
cov_expr = {
_COV.format(left, right): (
data[_DEMEAN.format(left)] * data[_DEMEAN.format(right)]
).sum().cast("float") / (data.count() - 1) # type: ignore
for left, right in cov_cols
}

grouped_data = data.group_by(group_col) if group_col is not None else data
aggr_data = grouped_data.aggregate(
**count_expr,
**mean_expr,
**mean_of_sq_expr,
**mean_of_mul_expr,
**var_expr,
**cov_expr,
).to_pandas()

if group_col is None:
Expand Down Expand Up @@ -341,33 +354,11 @@ def _get_aggregates(
cov_cols: Sequence[tuple[str, str]],
) -> Aggregates:
s = data.iloc[0]
mean = {col: s[_MEAN.format(col)] for col in mean_cols}

if has_count:
count = s[_COUNT]
bessel_factor = count / (count - 1)
var = {
col: (s[_MEAN_OF_SQ.format(col)] - s[_MEAN.format(col)]**2) * bessel_factor
for col in var_cols
}

cov = {
(left, right): (
s[_MEAN_OF_MUL.format(left, right)] -
s[_MEAN.format(left)]*s[_MEAN.format(right)]
) * bessel_factor
for left, right in cov_cols
}
else:
count = None
var = {}
cov = {}

return Aggregates(
count=count,
mean=mean,
var=var,
cov=cov,
count=s[_COUNT] if has_count else None,
mean={col: s[_MEAN.format(col)] for col in mean_cols},
var={col: s[_VAR.format(col)] for col in var_cols},
cov={cols: s[_COV.format(*cols)] for cols in cov_cols},
)


Expand Down

0 comments on commit 4c4b47d

Please sign in to comment.