Skip to content

Commit

Permalink
Test metrics base classes
Browse files Browse the repository at this point in the history
  • Loading branch information
e10v committed Feb 6, 2024
1 parent abaaa65 commit c6f6587
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ select = [
ignore = ["ANN401", "PGH003", "SLF001", "TRY003"]

[tool.ruff.per-file-ignores]
"tests/*" = ["ANN201", "D100", "D103", "PLR2004", "PT001", "S101"]
"tests/*" = ["ANN201", "D100", "D103", "D104", "PLR2004", "PT001", "S101"]

[tool.ruff.isort]
force-sort-within-sections = true
Expand All @@ -92,7 +92,7 @@ testpaths = ["tests"]
[tool.coverage.run]
source = ["src/tea_tasting"]
[tool.coverage.report]
exclude_lines = ["if TYPE_CHECKING:", "pragma: no cover", "@overload"]
exclude_lines = ["if TYPE_CHECKING:", "pragma: no cover", "@overload", "@abc.abstractmethod"]


[tool.pyright]
Expand Down
Empty file added tests/metrics/__init__.py
Empty file.
147 changes: 147 additions & 0 deletions tests/metrics/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# pyright: reportPrivateUsage=false, reportUnknownMemberType=false

from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import tea_tasting.aggr
import tea_tasting.datasets
import tea_tasting.metrics.base


if TYPE_CHECKING:
from typing import Any, NamedTuple

import ibis.expr.types
import pandas as pd



def test_aggr_cols_or():
aggr_cols0 = tea_tasting.metrics.base.AggrCols(
has_count=False,
mean_cols=("a", "b"),
var_cols=("b", "c"),
cov_cols=(("a", "b"), ("c", "b")),
)

aggr_cols1 = tea_tasting.metrics.base.AggrCols(
has_count=True,
mean_cols=("b", "c"),
var_cols=("c", "d"),
cov_cols=(("b", "c"), ("d", "c")),
)

aggr_cols = aggr_cols0 | aggr_cols1

assert isinstance(aggr_cols, tea_tasting.metrics.base.AggrCols)
assert aggr_cols.has_count is True
assert set(aggr_cols.mean_cols) == {"a", "b", "c"}
assert len(aggr_cols.mean_cols) == 3
assert set(aggr_cols.var_cols) == {"b", "c", "d"}
assert len(aggr_cols.var_cols) == 3
assert set(aggr_cols.cov_cols) == {("a", "b"), ("b", "c"), ("c", "d")}
assert len(aggr_cols.cov_cols) == 3


@pytest.fixture
def data() -> ibis.expr.types.Table:
return tea_tasting.datasets.make_users_data(size=100, seed=42)


@pytest.fixture
def correct_aggrs(
data: ibis.expr.types.Table,
) -> dict[Any, tea_tasting.aggr.Aggregates]:
return tea_tasting.aggr.read_aggregates(
data,
group_col="variant",
has_count=True,
mean_cols=("visits", "orders"),
var_cols=("orders", "revenue"),
cov_cols=(("visits", "revenue"),),
)


@pytest.fixture
def aggr_metric() -> tea_tasting.metrics.base.MetricBaseAggr:
class AggrMetric(tea_tasting.metrics.base.MetricBaseAggr):
def __init__(self: tea_tasting.metrics.base.MetricBaseAggr) -> None:
return None

def analyze(
self: tea_tasting.metrics.base.MetricBaseAggr,
data: pd.DataFrame | ibis.expr.types.Table | dict[ # noqa: ARG002
Any, tea_tasting.aggr.Aggregates],
control: Any, # noqa: ARG002
treatment: Any, # noqa: ARG002
variant_col: str | None = None, # noqa: ARG002
) -> NamedTuple | dict[str, Any]:
return {}

@property
def aggr_cols(
self: tea_tasting.metrics.base.MetricBaseAggr,
) -> tea_tasting.metrics.base.AggrCols:
return tea_tasting.metrics.base.AggrCols(
has_count=True,
mean_cols=("visits", "orders"),
var_cols=("orders", "revenue"),
cov_cols=(("visits", "revenue"),),
)

return AggrMetric()


def test_metric_base_aggr_read_grouped_aggregates_table(
aggr_metric: tea_tasting.metrics.base.MetricBaseAggr,
data: ibis.expr.types.Table,
correct_aggrs: dict[Any, tea_tasting.aggr.Aggregates],
):
aggrs = aggr_metric.read_grouped_aggregates(data, variant_col="variant")
assert aggrs.keys() == correct_aggrs.keys()
for variant in aggrs:
aggr = aggrs[variant]
correct_aggr = correct_aggrs[variant]
assert aggr._count == correct_aggr._count
assert aggr._mean == correct_aggr._mean
assert aggr._var == correct_aggr._var
assert aggr._cov == correct_aggr._cov

def test_metric_base_aggr_read_grouped_aggregates_df(
aggr_metric: tea_tasting.metrics.base.MetricBaseAggr,
data: ibis.expr.types.Table,
correct_aggrs: dict[Any, tea_tasting.aggr.Aggregates],
):
aggrs = aggr_metric.read_grouped_aggregates(data.to_pandas(), variant_col="variant")
assert aggrs.keys() == correct_aggrs.keys()
for variant in aggrs:
aggr = aggrs[variant]
correct_aggr = correct_aggrs[variant]
assert aggr._count == correct_aggr._count
assert aggr._mean == correct_aggr._mean
assert aggr._var == correct_aggr._var
assert aggr._cov == correct_aggr._cov

def test_metric_base_aggr_read_grouped_aggregates_aggrs(
aggr_metric: tea_tasting.metrics.base.MetricBaseAggr,
correct_aggrs: dict[Any, tea_tasting.aggr.Aggregates],
):
aggrs = aggr_metric.read_grouped_aggregates(correct_aggrs)
assert aggrs.keys() == correct_aggrs.keys()
for variant in aggrs:
aggr = aggrs[variant]
correct_aggr = correct_aggrs[variant]
assert aggr._count == correct_aggr._count
assert aggr._mean == correct_aggr._mean
assert aggr._var == correct_aggr._var
assert aggr._cov == correct_aggr._cov

def test_metric_base_aggr_read_grouped_aggregates_raises(
aggr_metric: tea_tasting.metrics.base.MetricBaseAggr,
data: ibis.expr.types.Table,
):
with pytest.raises(ValueError, match="variant_col"):
aggr_metric.read_grouped_aggregates(data) # type: ignore

0 comments on commit c6f6587

Please sign in to comment.