Skip to content

Commit

Permalink
Merge pull request #18 from e10v/dev
Browse files Browse the repository at this point in the history
Change cols validation logic
  • Loading branch information
e10v authored Jan 29, 2024
2 parents 4c4b47d + 5c45082 commit 9c9cb92
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 48 deletions.
69 changes: 33 additions & 36 deletions src/tea_tasting/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ def __repr__(self: Aggregates) -> str:
f"var={self._var!r}, cov={self._cov!r})"
)

def filter(
self: Aggregates,
has_count: bool,
mean_cols: Sequence[str],
var_cols: Sequence[str],
cov_cols: Sequence[tuple[str, str]],
) -> Aggregates:
"""Filter aggregated statistics.
Args:
has_count: If True, keep sample size in the resulting object.
mean_cols: Sample means variable names.
var_cols: Sample variances variable names.
cov_cols: Sample covariances variable names.
Returns:
Filtered aggregated statistics.
"""
mean_cols, var_cols, cov_cols = _validate_aggr_cols(
mean_cols, var_cols, cov_cols)

return Aggregates(
count=self.count() if has_count else None,
mean={col: self.mean(col) for col in mean_cols},
var={col: self.var(col) for col in var_cols},
cov={cols: self.cov(*cols) for cols in cov_cols},
)

def count(self: Aggregates) -> int:
"""Sample size.
Expand Down Expand Up @@ -171,34 +199,6 @@ def ratio_cov(
* left_ratio_of_means * right_ratio_of_means
) / self.mean(left_denom) / self.mean(right_denom)

def filter(
self: Aggregates,
has_count: bool,
mean_cols: Sequence[str],
var_cols: Sequence[str],
cov_cols: Sequence[tuple[str, str]],
) -> Aggregates:
"""Filter aggregated statistics.
Args:
has_count: If True, keep sample size in the resulting object.
mean_cols: Sample means variable names.
var_cols: Sample variances variable names.
cov_cols: Sample covariances variable names.
Returns:
Filtered aggregated statistics.
"""
has_count, mean_cols, var_cols, cov_cols = _validate_aggr_cols(
has_count, mean_cols, var_cols, cov_cols)

return Aggregates(
count=self.count() if has_count else None,
mean={col: self.mean(col) for col in mean_cols},
var={col: self.var(col) for col in var_cols},
cov={cols: self.cov(*cols) for cols in cov_cols},
)

def __add__(self: Aggregates, other: Aggregates) -> Aggregates:
"""Calculate aggregated statistics of the concatenation of two samples.
Expand Down Expand Up @@ -290,8 +290,7 @@ def read_aggregates(
Returns:
Aggregated statistics from the Ibis Table.
"""
has_count, mean_cols, var_cols, cov_cols = _validate_aggr_cols(
has_count, mean_cols, var_cols, cov_cols)
mean_cols, var_cols, cov_cols = _validate_aggr_cols(mean_cols, var_cols, cov_cols)

demean_cols = tuple({*var_cols, *itertools.chain(*cov_cols)})
if len(demean_cols) > 0:
Expand Down Expand Up @@ -363,16 +362,14 @@ def _get_aggregates(


def _validate_aggr_cols(
has_count: bool,
mean_cols: Sequence[str],
var_cols: Sequence[str],
cov_cols: Sequence[tuple[str, str]],
) -> tuple[bool, tuple[str, ...], tuple[str, ...], tuple[tuple[str, str], ...]]:
has_count = has_count or len(var_cols) > 0 or len(cov_cols) > 0
mean_cols = tuple({*mean_cols, *var_cols, *itertools.chain(*cov_cols)})
var_cols = tuple(set(var_cols))
) -> tuple[tuple[str, ...], tuple[str, ...], tuple[tuple[str, str], ...]]:
mean_cols = tuple({*mean_cols})
var_cols = tuple({*var_cols})
cov_cols = tuple({
tea_tasting._utils.sorted_tuple(left, right)
for left, right in cov_cols
})
return has_count, mean_cols, var_cols, cov_cols
return mean_cols, var_cols, cov_cols
24 changes: 12 additions & 12 deletions tests/test_aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def test_aggregates_repr(aggr: tea_tasting.aggr.Aggregates):
assert aggr_repr._var == aggr._var
assert aggr_repr._cov == aggr._cov

def test_aggregates_filter(aggr: tea_tasting.aggr.Aggregates):
filtered_aggr = aggr.filter(
has_count=False,
mean_cols=("x", "x"),
var_cols=("x",),
cov_cols=(),
)
assert filtered_aggr._count is None
assert filtered_aggr._mean == {"x": MEAN["x"]}
assert filtered_aggr._var == {"x": VAR["x"]}
assert filtered_aggr._cov == {}

def test_aggregates_calls(aggr: tea_tasting.aggr.Aggregates):
assert aggr.count() == COUNT
assert aggr.mean("x") == MEAN["x"]
Expand Down Expand Up @@ -79,18 +91,6 @@ def test_aggregates_ratio_cov():
)
assert aggr.ratio_cov("a", "b", "c", "d") == pytest.approx(-0.0146938775510204)

def test_aggregates_filter(aggr: tea_tasting.aggr.Aggregates):
filtered_aggr = aggr.filter(
has_count=False,
mean_cols=("x",),
var_cols=("x",),
cov_cols=(),
)
assert filtered_aggr._count == COUNT
assert filtered_aggr._mean == {"x": MEAN["x"]}
assert filtered_aggr._var == {"x": VAR["x"]}
assert filtered_aggr._cov == {}

def test_aggregates_add(data: Table):
d = data.to_pandas()
aggr = tea_tasting.aggr.Aggregates(
Expand Down

0 comments on commit 9c9cb92

Please sign in to comment.