Skip to content

Commit

Permalink
Extract posterior sampling code from Meridian model container into it…
Browse files Browse the repository at this point in the history
…s own module

PiperOrigin-RevId: 724961367
  • Loading branch information
santoso-wijaya authored and The Meridian Authors committed Feb 10, 2025
1 parent 4040463 commit 929d4fe
Show file tree
Hide file tree
Showing 9 changed files with 3,612 additions and 3,154 deletions.
10 changes: 6 additions & 4 deletions meridian/data/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,7 +1480,9 @@ def sample_coord_to_columns(
)


def sample_input_data_from_dataset(dataset: xr.Dataset, kpi_type: str):
def sample_input_data_from_dataset(
dataset: xr.Dataset, kpi_type: str
) -> input_data.InputData:
"""Generates a sample `InputData` from a full xarray Dataset."""
return input_data.InputData(
kpi=dataset.kpi,
Expand All @@ -1507,7 +1509,7 @@ def sample_input_data_revenue(
n_organic_media_channels: int | None = None,
n_organic_rf_channels: int | None = None,
seed: int = 0,
):
) -> input_data.InputData:
"""Generates sample InputData for `kpi_type='revenue'`."""
dataset = random_dataset(
n_geos=n_geos,
Expand Down Expand Up @@ -1555,7 +1557,7 @@ def sample_input_data_non_revenue_revenue_per_kpi(
n_organic_media_channels: int | None = None,
n_organic_rf_channels: int | None = None,
seed: int = 0,
):
) -> input_data.InputData:
"""Generates sample InputData for `non_revenue` KPI w/ revenue_per_kpi."""
dataset = random_dataset(
n_geos=n_geos,
Expand Down Expand Up @@ -1602,7 +1604,7 @@ def sample_input_data_non_revenue_no_revenue_per_kpi(
n_organic_media_channels: int | None = None,
n_organic_rf_channels: int | None = None,
seed: int = 0,
):
) -> input_data.InputData:
"""Generates sample InputData for `non_revenue` KPI w/o revenue_per_kpi."""
dataset = random_dataset(
n_geos=n_geos,
Expand Down
2 changes: 2 additions & 0 deletions meridian/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from meridian.model import knots
from meridian.model import media
from meridian.model import model
from meridian.model import posterior_sampler
from meridian.model import prior_distribution
from meridian.model import prior_sampler
from meridian.model import spec
from meridian.model import transformers
Loading

0 comments on commit 929d4fe

Please sign in to comment.