diff --git a/meridian/data/test_utils.py b/meridian/data/test_utils.py index 3d3073f9..62f69511 100644 --- a/meridian/data/test_utils.py +++ b/meridian/data/test_utils.py @@ -1480,7 +1480,9 @@ def sample_coord_to_columns( ) -def sample_input_data_from_dataset(dataset: xr.Dataset, kpi_type: str): +def sample_input_data_from_dataset( + dataset: xr.Dataset, kpi_type: str +) -> input_data.InputData: """Generates a sample `InputData` from a full xarray Dataset.""" return input_data.InputData( kpi=dataset.kpi, @@ -1507,7 +1509,7 @@ def sample_input_data_revenue( n_organic_media_channels: int | None = None, n_organic_rf_channels: int | None = None, seed: int = 0, -): +) -> input_data.InputData: """Generates sample InputData for `kpi_type='revenue'`.""" dataset = random_dataset( n_geos=n_geos, @@ -1555,7 +1557,7 @@ def sample_input_data_non_revenue_revenue_per_kpi( n_organic_media_channels: int | None = None, n_organic_rf_channels: int | None = None, seed: int = 0, -): +) -> input_data.InputData: """Generates sample InputData for `non_revenue` KPI w/ revenue_per_kpi.""" dataset = random_dataset( n_geos=n_geos, @@ -1602,7 +1604,7 @@ def sample_input_data_non_revenue_no_revenue_per_kpi( n_organic_media_channels: int | None = None, n_organic_rf_channels: int | None = None, seed: int = 0, -): +) -> input_data.InputData: """Generates sample InputData for `non_revenue` KPI w/o revenue_per_kpi.""" dataset = random_dataset( n_geos=n_geos, diff --git a/meridian/model/__init__.py b/meridian/model/__init__.py index 9fe8261d..eb0550f9 100644 --- a/meridian/model/__init__.py +++ b/meridian/model/__init__.py @@ -18,6 +18,8 @@ from meridian.model import knots from meridian.model import media from meridian.model import model +from meridian.model import posterior_sampler from meridian.model import prior_distribution +from meridian.model import prior_sampler from meridian.model import spec from meridian.model import transformers diff --git a/meridian/model/model.py b/meridian/model/model.py 
index b5281eda..ca8a4c64 100644 --- a/meridian/model/model.py +++ b/meridian/model/model.py @@ -27,12 +27,13 @@ from meridian.model import adstock_hill from meridian.model import knots from meridian.model import media +from meridian.model import posterior_sampler from meridian.model import prior_distribution +from meridian.model import prior_sampler from meridian.model import spec from meridian.model import transformers import numpy as np import tensorflow as tf -import tensorflow_probability as tfp __all__ = [ @@ -49,12 +50,8 @@ class NotFittedModelError(Exception): """Model has not been fitted.""" -class MCMCSamplingError(Exception): - """The Markov Chain Monte Carlo (MCMC) sampling failed.""" - - -class MCMCOOMError(Exception): - """The Markov Chain Monte Carlo (MCMC) exceeds memory limits.""" +MCMCSamplingError = posterior_sampler.MCMCSamplingError +MCMCOOMError = posterior_sampler.MCMCOOMError def _warn_setting_national_args(**kwargs): @@ -70,43 +67,6 @@ def _warn_setting_national_args(**kwargs): ) -def _get_tau_g( - tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int -) -> tfp.distributions.Distribution: - """Computes `tau_g` from `tau_g_excl_baseline`. - - This function computes `tau_g` by inserting a column of zeros at the - `baseline_geo` position in `tau_g_excl_baseline`. - - Args: - tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the - user-defined dimensions of the `tau_g` parameter distribution. - baseline_geo_idx: The index of the baseline geo to be set to zero. - - Returns: - A tensor of shape `[..., n_geos]` with the final distribution of the `tau_g` - parameter with zero at position `baseline_geo_idx` and matching - `tau_g_excl_baseline` elsewhere. 
- """ - rank = len(tau_g_excl_baseline.shape) - shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1 - tau_g = tf.concat( - [ - tau_g_excl_baseline[..., :baseline_geo_idx], - tf.zeros(shape, dtype=tau_g_excl_baseline.dtype), - tau_g_excl_baseline[..., baseline_geo_idx:], - ], - axis=rank - 1, - ) - return tfp.distributions.Deterministic(tau_g, name="tau_g") - - -@tf.function(autograph=False, jit_compile=True) -def _xla_windowed_adaptive_nuts(**kwargs): - """XLA wrapper for windowed_adaptive_nuts.""" - return tfp.experimental.mcmc.windowed_adaptive_nuts(**kwargs) - - class Meridian: """Contains the main functionality for fitting the Meridian MMM model. @@ -452,6 +412,14 @@ def prior_broadcast(self) -> prior_distribution.PriorDistribution: total_spend=agg_total_spend, ) + @functools.cached_property + def prior_sampler(self) -> prior_sampler.PriorSampler: + return prior_sampler.PriorSampler(self) + + @functools.cached_property + def posterior_sampler(self) -> posterior_sampler.PosteriorSampler: + return posterior_sampler.PosteriorSampler(self) + def expand_selected_time_dims( self, start_date: tc.Date | None = None, @@ -720,7 +688,7 @@ def _validate_paid_media_prior_type(self): raise ValueError( f"Custom priors should be set on `{constants.MROI_M}` and" f" `{constants.MROI_RF}` when KPI is non-revenue and revenue per kpi" - f" data is missing." + " data is missing." ) def _validate_geo_invariants(self): @@ -955,143 +923,6 @@ def adstock_hill_rf( return rf_out - def _get_roi_prior_beta_m_value( - self, - alpha_m: tf.Tensor, - beta_gm_dev: tf.Tensor, - ec_m: tf.Tensor, - eta_m: tf.Tensor, - roi_or_mroi_m: tf.Tensor, - slope_m: tf.Tensor, - media_transformed: tf.Tensor, - ) -> tf.Tensor: - """Returns a tensor to be used in `beta_m`.""" - # The `roi_or_mroi_m` parameter represents either ROI or mROI. For reach & - # frequency channels, marginal ROI priors are defined as "mROI by reach", - # which is equivalent to ROI. 
- media_spend = self.media_tensors.media_spend - media_spend_counterfactual = self.media_tensors.media_spend_counterfactual - media_counterfactual_scaled = self.media_tensors.media_counterfactual_scaled - # If we got here, then we should already have media tensors derived from - # non-None InputData.media data. - assert media_spend is not None - assert media_spend_counterfactual is not None - assert media_counterfactual_scaled is not None - - # Use absolute value here because this difference will be negative for - # marginal ROI priors. - inc_revenue_m = roi_or_mroi_m * tf.reduce_sum( - tf.abs(media_spend - media_spend_counterfactual), - range(media_spend.ndim - 1), - ) - - if ( - self.model_spec.roi_calibration_period is None - and self.model_spec.paid_media_prior_type - == constants.PAID_MEDIA_PRIOR_TYPE_ROI - ): - # We can skip the adstock/hill computation step in this case. - media_counterfactual_transformed = tf.zeros_like(media_transformed) - else: - media_counterfactual_transformed = self.adstock_hill_media( - media=media_counterfactual_scaled, - alpha=alpha_m, - ec=ec_m, - slope=slope_m, - ) - - revenue_per_kpi = self.revenue_per_kpi - if self.input_data.revenue_per_kpi is None: - revenue_per_kpi = tf.ones([self.n_geos, self.n_times], dtype=tf.float32) - # Note: use absolute value here because this difference will be negative for - # marginal ROI priors. - media_contrib_gm = tf.einsum( - "...gtm,g,,gt->...gm", - tf.abs(media_transformed - media_counterfactual_transformed), - self.population, - self.kpi_transformer.population_scaled_stdev, - revenue_per_kpi, - ) - - if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL: - media_contrib_m = tf.einsum("...gm->...m", media_contrib_gm) - random_effect_m = tf.einsum( - "...m,...gm,...gm->...m", eta_m, beta_gm_dev, media_contrib_gm - ) - return (inc_revenue_m - random_effect_m) / media_contrib_m - else: - # For log_normal, beta_m and eta_m are not mean & std. 
- # The parameterization is beta_gm ~ exp(beta_m + eta_m * N(0, 1)). - random_effect_m = tf.einsum( - "...gm,...gm->...m", - tf.math.exp(beta_gm_dev * eta_m[..., tf.newaxis, :]), - media_contrib_gm, - ) - return tf.math.log(inc_revenue_m) - tf.math.log(random_effect_m) - - def _get_roi_prior_beta_rf_value( - self, - alpha_rf: tf.Tensor, - beta_grf_dev: tf.Tensor, - ec_rf: tf.Tensor, - eta_rf: tf.Tensor, - roi_or_mroi_rf: tf.Tensor, - slope_rf: tf.Tensor, - rf_transformed: tf.Tensor, - ) -> tf.Tensor: - """Returns a tensor to be used in `beta_rf`.""" - rf_spend = self.rf_tensors.rf_spend - rf_spend_counterfactual = self.rf_tensors.rf_spend_counterfactual - reach_counterfactual_scaled = self.rf_tensors.reach_counterfactual_scaled - frequency = self.rf_tensors.frequency - # If we got here, then we should already have RF media tensors derived from - # non-None InputData.reach data. - assert rf_spend is not None - assert rf_spend_counterfactual is not None - assert reach_counterfactual_scaled is not None - assert frequency is not None - - inc_revenue_rf = roi_or_mroi_rf * tf.reduce_sum( - rf_spend - rf_spend_counterfactual, - range(rf_spend.ndim - 1), - ) - if self.model_spec.rf_roi_calibration_period is not None: - rf_counterfactual_transformed = self.adstock_hill_rf( - reach=reach_counterfactual_scaled, - frequency=frequency, - alpha=alpha_rf, - ec=ec_rf, - slope=slope_rf, - ) - else: - rf_counterfactual_transformed = tf.zeros_like(rf_transformed) - revenue_per_kpi = self.revenue_per_kpi - if self.input_data.revenue_per_kpi is None: - revenue_per_kpi = tf.ones([self.n_geos, self.n_times], dtype=tf.float32) - - media_contrib_grf = tf.einsum( - "...gtm,g,,gt->...gm", - rf_transformed - rf_counterfactual_transformed, - self.population, - self.kpi_transformer.population_scaled_stdev, - revenue_per_kpi, - ) - if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL: - media_contrib_rf = tf.einsum("...gm->...m", media_contrib_grf) - random_effect_rf = tf.einsum( - 
"...m,...gm,...gm->...m", eta_rf, beta_grf_dev, media_contrib_grf - ) - return (inc_revenue_rf - random_effect_rf) / media_contrib_rf - else: - # For log_normal, beta_rf and eta_rf are not mean & std. - # The parameterization is beta_grf ~ exp(beta_rf + eta_rf * N(0, 1)). - random_effect_rf = tf.einsum( - "...gm,...gm->...m", - tf.math.exp(beta_grf_dev * eta_rf[..., tf.newaxis, :]), - media_contrib_grf, - ) - return tf.math.log(inc_revenue_rf) - tf.math.log(random_effect_rf) - def populate_cached_properties(self): """Eagerly activates all cached properties. @@ -1111,301 +942,7 @@ def populate_cached_properties(self): for attr in cached_properties: _ = getattr(self, attr) - def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution: - """Returns JointDistributionCoroutineAutoBatched function for MCMC.""" - - self.populate_cached_properties() - - # This lists all the derived properties and states of this Meridian object - # that are referenced by the joint distribution coroutine. - # That is, these are the list of captured parameters. 
- prior_broadcast = self.prior_broadcast - baseline_geo_idx = self.baseline_geo_idx - knot_info = self.knot_info - n_geos = self.n_geos - n_times = self.n_times - n_media_channels = self.n_media_channels - n_rf_channels = self.n_rf_channels - n_organic_media_channels = self.n_organic_media_channels - n_organic_rf_channels = self.n_organic_rf_channels - n_controls = self.n_controls - n_non_media_channels = self.n_non_media_channels - holdout_id = self.holdout_id - media_tensors = self.media_tensors - rf_tensors = self.rf_tensors - organic_media_tensors = self.organic_media_tensors - organic_rf_tensors = self.organic_rf_tensors - controls_scaled = self.controls_scaled - non_media_treatments_scaled = self.non_media_treatments_scaled - media_effects_dist = self.media_effects_dist - adstock_hill_media_fn = self.adstock_hill_media - adstock_hill_rf_fn = self.adstock_hill_rf - get_roi_prior_beta_m_value_fn = self._get_roi_prior_beta_m_value - get_roi_prior_beta_rf_value_fn = self._get_roi_prior_beta_rf_value - - # TODO: Extract this coroutine to be unittestable on its own. - # This MCMC sampling technique is complex enough to have its own abstraction - # and testable API, rather than being embedded as a private method in the - # Meridian class. - @tfp.distributions.JointDistributionCoroutineAutoBatched - def joint_dist_unpinned(): - # Sample directly from prior. 
- knot_values = yield prior_broadcast.knot_values - gamma_c = yield prior_broadcast.gamma_c - xi_c = yield prior_broadcast.xi_c - sigma = yield prior_broadcast.sigma - - tau_g_excl_baseline = yield tfp.distributions.Sample( - prior_broadcast.tau_g_excl_baseline, - name=constants.TAU_G_EXCL_BASELINE, - ) - tau_g = yield _get_tau_g( - tau_g_excl_baseline=tau_g_excl_baseline, - baseline_geo_idx=baseline_geo_idx, - ) - mu_t = yield tfp.distributions.Deterministic( - tf.einsum( - "k,kt->t", - knot_values, - tf.convert_to_tensor(knot_info.weights), - ), - name=constants.MU_T, - ) - - tau_gt = tau_g[:, tf.newaxis] + mu_t - combined_media_transformed = tf.zeros( - shape=(n_geos, n_times, 0), dtype=tf.float32 - ) - combined_beta = tf.zeros(shape=(n_geos, 0), dtype=tf.float32) - if media_tensors.media is not None: - alpha_m = yield prior_broadcast.alpha_m - ec_m = yield prior_broadcast.ec_m - eta_m = yield prior_broadcast.eta_m - slope_m = yield prior_broadcast.slope_m - beta_gm_dev = yield tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [n_geos, n_media_channels], - name=constants.BETA_GM_DEV, - ) - media_transformed = adstock_hill_media_fn( - media=media_tensors.media_scaled, - alpha=alpha_m, - ec=ec_m, - slope=slope_m, - ) - prior_type = self.model_spec.paid_media_prior_type - if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES: - if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: - roi_or_mroi_m = yield prior_broadcast.roi_m - else: - roi_or_mroi_m = yield prior_broadcast.mroi_m - beta_m_value = get_roi_prior_beta_m_value_fn( - alpha_m, - beta_gm_dev, - ec_m, - eta_m, - roi_or_mroi_m, - slope_m, - media_transformed, - ) - beta_m = yield tfp.distributions.Deterministic( - beta_m_value, name=constants.BETA_M - ) - else: - beta_m = yield prior_broadcast.beta_m - - beta_eta_combined = beta_m + eta_m * beta_gm_dev - beta_gm_value = ( - beta_eta_combined - if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL - else tf.math.exp(beta_eta_combined) - ) - 
beta_gm = yield tfp.distributions.Deterministic( - beta_gm_value, name=constants.BETA_GM - ) - combined_media_transformed = tf.concat( - [combined_media_transformed, media_transformed], axis=-1 - ) - combined_beta = tf.concat([combined_beta, beta_gm], axis=-1) - - if rf_tensors.reach is not None: - alpha_rf = yield prior_broadcast.alpha_rf - ec_rf = yield prior_broadcast.ec_rf - eta_rf = yield prior_broadcast.eta_rf - slope_rf = yield prior_broadcast.slope_rf - beta_grf_dev = yield tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [n_geos, n_rf_channels], - name=constants.BETA_GRF_DEV, - ) - rf_transformed = adstock_hill_rf_fn( - reach=rf_tensors.reach_scaled, - frequency=rf_tensors.frequency, - alpha=alpha_rf, - ec=ec_rf, - slope=slope_rf, - ) - - prior_type = self.model_spec.paid_media_prior_type - if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES: - if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: - roi_or_mroi_rf = yield prior_broadcast.roi_rf - else: - roi_or_mroi_rf = yield prior_broadcast.mroi_rf - beta_rf_value = get_roi_prior_beta_rf_value_fn( - alpha_rf, - beta_grf_dev, - ec_rf, - eta_rf, - roi_or_mroi_rf, - slope_rf, - rf_transformed, - ) - beta_rf = yield tfp.distributions.Deterministic( - beta_rf_value, - name=constants.BETA_RF, - ) - else: - beta_rf = yield prior_broadcast.beta_rf - - beta_eta_combined = beta_rf + eta_rf * beta_grf_dev - beta_grf_value = ( - beta_eta_combined - if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL - else tf.math.exp(beta_eta_combined) - ) - beta_grf = yield tfp.distributions.Deterministic( - beta_grf_value, name=constants.BETA_GRF - ) - combined_media_transformed = tf.concat( - [combined_media_transformed, rf_transformed], axis=-1 - ) - combined_beta = tf.concat([combined_beta, beta_grf], axis=-1) - - if organic_media_tensors.organic_media is not None: - alpha_om = yield prior_broadcast.alpha_om - ec_om = yield prior_broadcast.ec_om - eta_om = yield prior_broadcast.eta_om - slope_om = yield 
prior_broadcast.slope_om - beta_gom_dev = yield tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [n_geos, n_organic_media_channels], - name=constants.BETA_GOM_DEV, - ) - organic_media_transformed = adstock_hill_media_fn( - media=organic_media_tensors.organic_media_scaled, - alpha=alpha_om, - ec=ec_om, - slope=slope_om, - ) - beta_om = yield prior_broadcast.beta_om - - beta_eta_combined = beta_om + eta_om * beta_gom_dev - beta_gom_value = ( - beta_eta_combined - if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL - else tf.math.exp(beta_eta_combined) - ) - beta_gom = yield tfp.distributions.Deterministic( - beta_gom_value, name=constants.BETA_GOM - ) - combined_media_transformed = tf.concat( - [combined_media_transformed, organic_media_transformed], axis=-1 - ) - combined_beta = tf.concat([combined_beta, beta_gom], axis=-1) - - if organic_rf_tensors.organic_reach is not None: - alpha_orf = yield prior_broadcast.alpha_orf - ec_orf = yield prior_broadcast.ec_orf - eta_orf = yield prior_broadcast.eta_orf - slope_orf = yield prior_broadcast.slope_orf - beta_gorf_dev = yield tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [n_geos, n_organic_rf_channels], - name=constants.BETA_GORF_DEV, - ) - organic_rf_transformed = adstock_hill_rf_fn( - reach=organic_rf_tensors.organic_reach_scaled, - frequency=organic_rf_tensors.organic_frequency, - alpha=alpha_orf, - ec=ec_orf, - slope=slope_orf, - ) - - beta_orf = yield prior_broadcast.beta_orf - - beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev - beta_gorf_value = ( - beta_eta_combined - if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL - else tf.math.exp(beta_eta_combined) - ) - beta_gorf = yield tfp.distributions.Deterministic( - beta_gorf_value, name=constants.BETA_GORF - ) - combined_media_transformed = tf.concat( - [combined_media_transformed, organic_rf_transformed], axis=-1 - ) - combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1) - - sigma_gt = 
tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos])) - gamma_gc_dev = yield tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [n_geos, n_controls], - name=constants.GAMMA_GC_DEV, - ) - gamma_gc = yield tfp.distributions.Deterministic( - gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC - ) - y_pred_combined_media = ( - tau_gt - + tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta) - + tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc) - ) - - if self.non_media_treatments is not None: - gamma_n = yield prior_broadcast.gamma_n - xi_n = yield prior_broadcast.xi_n - gamma_gn_dev = yield tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [n_geos, n_non_media_channels], - name=constants.GAMMA_GN_DEV, - ) - gamma_gn = yield tfp.distributions.Deterministic( - gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN - ) - y_pred = y_pred_combined_media + tf.einsum( - "gtn,gn->gt", non_media_treatments_scaled, gamma_gn - ) - else: - y_pred = y_pred_combined_media - - # If there are any holdout observations, the holdout KPI values will - # be replaced with zeros using `experimental_pin`. For these - # observations, we set the posterior mean equal to zero and standard - # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the - # sampled posterior parameter values. 
- if holdout_id is not None: - y_pred_holdout = tf.where(holdout_id, 0.0, y_pred) - test_sd = tf.cast(1.0 / np.sqrt(2.0 * np.pi), tf.float32) - sigma_gt_holdout = tf.where(holdout_id, test_sd, sigma_gt) - yield tfp.distributions.Normal( - y_pred_holdout, sigma_gt_holdout, name="y" - ) - else: - yield tfp.distributions.Normal(y_pred, sigma_gt, name="y") - - return joint_dist_unpinned - - def _get_joint_dist(self) -> tfp.distributions.Distribution: - y = ( - tf.where(self.holdout_id, 0.0, self.kpi_scaled) - if self.holdout_id is not None - else self.kpi_scaled - ) - return self._get_joint_dist_unpinned().experimental_pin(y=y) - - def _create_inference_data_coords( + def create_inference_data_coords( self, n_chains: int, n_draws: int ) -> Mapping[str, np.ndarray | Sequence[str]]: """Creates data coordinates for inference data.""" @@ -1449,7 +986,7 @@ def _create_inference_data_coords( constants.ORGANIC_RF_CHANNEL: organic_rf_channel_values, } - def _create_inference_data_dims(self) -> Mapping[str, Sequence[str]]: + def create_inference_data_dims(self) -> Mapping[str, Sequence[str]]: inference_dims = dict(constants.INFERENCE_DIMS) if self.unique_sigma_for_each_geo: inference_dims[constants.SIGMA] = [constants.GEO] @@ -1461,396 +998,6 @@ def _create_inference_data_dims(self) -> Mapping[str, Sequence[str]]: for param, dims in inference_dims.items() } - def _sample_media_priors( - self, - n_draws: int, - seed: int | None = None, - ) -> Mapping[str, tf.Tensor]: - """Draws samples from the prior distributions of the media variables. - - Args: - n_draws: Number of samples drawn from the prior distribution. - seed: Used to set the seed for reproducible results. For more information, - see [PRNGS and seeds] - (https://github.com/tensorflow/probability/blob/main/PRNGS.md). - - Returns: - A mapping of media parameter names to a tensor of shape [n_draws, n_geos, - n_media_channels] or [n_draws, n_media_channels] containing the - samples. 
- """ - prior = self.prior_broadcast - sample_shape = [1, n_draws] - sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} - media_vars = { - constants.ALPHA_M: prior.alpha_m.sample(**sample_kwargs), - constants.EC_M: prior.ec_m.sample(**sample_kwargs), - constants.ETA_M: prior.eta_m.sample(**sample_kwargs), - constants.SLOPE_M: prior.slope_m.sample(**sample_kwargs), - } - beta_gm_dev = tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [self.n_geos, self.n_media_channels], - name=constants.BETA_GM_DEV, - ).sample(**sample_kwargs) - media_transformed = self.adstock_hill_media( - media=self.media_tensors.media_scaled, - alpha=media_vars[constants.ALPHA_M], - ec=media_vars[constants.EC_M], - slope=media_vars[constants.SLOPE_M], - ) - - prior_type = self.model_spec.paid_media_prior_type - if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: - roi_m = prior.roi_m.sample(**sample_kwargs) - beta_m_value = self._get_roi_prior_beta_m_value( - beta_gm_dev=beta_gm_dev, - media_transformed=media_transformed, - roi_or_mroi_m=roi_m, - **media_vars, - ) - media_vars[constants.ROI_M] = roi_m - media_vars[constants.BETA_M] = tfp.distributions.Deterministic( - beta_m_value, name=constants.BETA_M - ).sample() - elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: - mroi_m = prior.mroi_m.sample(**sample_kwargs) - beta_m_value = self._get_roi_prior_beta_m_value( - beta_gm_dev=beta_gm_dev, - media_transformed=media_transformed, - roi_or_mroi_m=mroi_m, - **media_vars, - ) - media_vars[constants.MROI_M] = mroi_m - media_vars[constants.BETA_M] = tfp.distributions.Deterministic( - beta_m_value, name=constants.BETA_M - ).sample() - else: - media_vars[constants.BETA_M] = prior.beta_m.sample(**sample_kwargs) - - beta_eta_combined = ( - media_vars[constants.BETA_M][..., tf.newaxis, :] - + media_vars[constants.ETA_M][..., tf.newaxis, :] * beta_gm_dev - ) - beta_gm_value = ( - beta_eta_combined - if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL 
- else tf.math.exp(beta_eta_combined) - ) - media_vars[constants.BETA_GM] = tfp.distributions.Deterministic( - beta_gm_value, name=constants.BETA_GM - ).sample() - - return media_vars - - def _sample_rf_priors( - self, - n_draws: int, - seed: int | None = None, - ) -> Mapping[str, tf.Tensor]: - """Draws samples from the prior distributions of the RF variables. - - Args: - n_draws: Number of samples drawn from the prior distribution. - seed: Used to set the seed for reproducible results. For more information, - see [PRNGS and seeds] - (https://github.com/tensorflow/probability/blob/main/PRNGS.md). - - Returns: - A mapping of RF parameter names to a tensor of shape [n_draws, n_geos, - n_rf_channels] or [n_draws, n_rf_channels] containing the samples. - """ - prior = self.prior_broadcast - sample_shape = [1, n_draws] - sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} - rf_vars = { - constants.ALPHA_RF: prior.alpha_rf.sample(**sample_kwargs), - constants.EC_RF: prior.ec_rf.sample(**sample_kwargs), - constants.ETA_RF: prior.eta_rf.sample(**sample_kwargs), - constants.SLOPE_RF: prior.slope_rf.sample(**sample_kwargs), - } - beta_grf_dev = tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [self.n_geos, self.n_rf_channels], - name=constants.BETA_GRF_DEV, - ).sample(**sample_kwargs) - rf_transformed = self.adstock_hill_rf( - reach=self.rf_tensors.reach_scaled, - frequency=self.rf_tensors.frequency, - alpha=rf_vars[constants.ALPHA_RF], - ec=rf_vars[constants.EC_RF], - slope=rf_vars[constants.SLOPE_RF], - ) - - prior_type = self.model_spec.paid_media_prior_type - if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: - roi_rf = prior.roi_rf.sample(**sample_kwargs) - beta_rf_value = self._get_roi_prior_beta_rf_value( - beta_grf_dev=beta_grf_dev, - rf_transformed=rf_transformed, - roi_or_mroi_rf=roi_rf, - **rf_vars, - ) - rf_vars[constants.ROI_RF] = roi_rf - rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic( - beta_rf_value, - 
name=constants.BETA_RF, - ).sample() - elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: - mroi_rf = prior.mroi_rf.sample(**sample_kwargs) - beta_rf_value = self._get_roi_prior_beta_rf_value( - beta_grf_dev=beta_grf_dev, - rf_transformed=rf_transformed, - roi_or_mroi_rf=mroi_rf, - **rf_vars, - ) - rf_vars[constants.MROI_RF] = mroi_rf - rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic( - beta_rf_value, - name=constants.BETA_RF, - ).sample() - else: - rf_vars[constants.BETA_RF] = prior.beta_rf.sample(**sample_kwargs) - - beta_eta_combined = ( - rf_vars[constants.BETA_RF][..., tf.newaxis, :] - + rf_vars[constants.ETA_RF][..., tf.newaxis, :] * beta_grf_dev - ) - beta_grf_value = ( - beta_eta_combined - if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL - else tf.math.exp(beta_eta_combined) - ) - rf_vars[constants.BETA_GRF] = tfp.distributions.Deterministic( - beta_grf_value, name=constants.BETA_GRF - ).sample() - - return rf_vars - - def _sample_organic_media_priors( - self, - n_draws: int, - seed: int | None = None, - ) -> Mapping[str, tf.Tensor]: - """Draws samples from the prior distributions of organic media variables. - - Args: - n_draws: Number of samples drawn from the prior distribution. - seed: Used to set the seed for reproducible results. For more information, - see [PRNGS and seeds] - (https://github.com/tensorflow/probability/blob/main/PRNGS.md). - - Returns: - A mapping of organic media parameter names to a tensor of shape [n_draws, - n_geos, n_organic_media_channels] or [n_draws, n_organic_media_channels] - containing the samples. 
- """ - prior = self.prior_broadcast - sample_shape = [1, n_draws] - sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} - organic_media_vars = { - constants.ALPHA_OM: prior.alpha_om.sample(**sample_kwargs), - constants.EC_OM: prior.ec_om.sample(**sample_kwargs), - constants.ETA_OM: prior.eta_om.sample(**sample_kwargs), - constants.SLOPE_OM: prior.slope_om.sample(**sample_kwargs), - } - beta_gom_dev = tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [self.n_geos, self.n_organic_media_channels], - name=constants.BETA_GOM_DEV, - ).sample(**sample_kwargs) - - organic_media_vars[constants.BETA_OM] = prior.beta_om.sample( - **sample_kwargs - ) - - beta_eta_combined = ( - organic_media_vars[constants.BETA_OM][..., tf.newaxis, :] - + organic_media_vars[constants.ETA_OM][..., tf.newaxis, :] - * beta_gom_dev - ) - beta_gom_value = ( - beta_eta_combined - if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL - else tf.math.exp(beta_eta_combined) - ) - organic_media_vars[constants.BETA_GOM] = tfp.distributions.Deterministic( - beta_gom_value, name=constants.BETA_GOM - ).sample() - - return organic_media_vars - - def _sample_organic_rf_priors( - self, - n_draws: int, - seed: int | None = None, - ) -> Mapping[str, tf.Tensor]: - """Draws samples from the prior distributions of the organic RF variables. - - Args: - n_draws: Number of samples drawn from the prior distribution. - seed: Used to set the seed for reproducible results. For more information, - see [PRNGS and seeds] - (https://github.com/tensorflow/probability/blob/main/PRNGS.md). - - Returns: - A mapping of organic RF parameter names to a tensor of shape [n_draws, - n_geos, n_organic_rf_channels] or [n_draws, n_organic_rf_channels] - containing the samples. 
- """ - prior = self.prior_broadcast - sample_shape = [1, n_draws] - sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} - organic_rf_vars = { - constants.ALPHA_ORF: prior.alpha_orf.sample(**sample_kwargs), - constants.EC_ORF: prior.ec_orf.sample(**sample_kwargs), - constants.ETA_ORF: prior.eta_orf.sample(**sample_kwargs), - constants.SLOPE_ORF: prior.slope_orf.sample(**sample_kwargs), - } - beta_gorf_dev = tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [self.n_geos, self.n_organic_rf_channels], - name=constants.BETA_GORF_DEV, - ).sample(**sample_kwargs) - - organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(**sample_kwargs) - - beta_eta_combined = ( - organic_rf_vars[constants.BETA_ORF][..., tf.newaxis, :] - + organic_rf_vars[constants.ETA_ORF][..., tf.newaxis, :] * beta_gorf_dev - ) - beta_gorf_value = ( - beta_eta_combined - if self.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL - else tf.math.exp(beta_eta_combined) - ) - organic_rf_vars[constants.BETA_GORF] = tfp.distributions.Deterministic( - beta_gorf_value, name=constants.BETA_GORF - ).sample() - - return organic_rf_vars - - def _sample_non_media_treatments_priors( - self, - n_draws: int, - seed: int | None = None, - ) -> Mapping[str, tf.Tensor]: - """Draws from the prior distributions of the non-media treatment variables. - - Args: - n_draws: Number of samples drawn from the prior distribution. - seed: Used to set the seed for reproducible results. For more information, - see [PRNGS and seeds] - (https://github.com/tensorflow/probability/blob/main/PRNGS.md). - - Returns: - A mapping of non-media treatment parameter names to a tensor of shape - [n_draws, - n_geos, n_non_media_channels] or [n_draws, n_non_media_channels] - containing the samples. 
- """ - prior = self.prior_broadcast - sample_shape = [1, n_draws] - sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} - non_media_treatments_vars = { - constants.GAMMA_N: prior.gamma_n.sample(**sample_kwargs), - constants.XI_N: prior.xi_n.sample(**sample_kwargs), - } - gamma_gn_dev = tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [self.n_geos, self.n_non_media_channels], - name=constants.GAMMA_GN_DEV, - ).sample(**sample_kwargs) - non_media_treatments_vars[constants.GAMMA_GN] = ( - tfp.distributions.Deterministic( - non_media_treatments_vars[constants.GAMMA_N][..., tf.newaxis, :] - + non_media_treatments_vars[constants.XI_N][..., tf.newaxis, :] - * gamma_gn_dev, - name=constants.GAMMA_GN, - ).sample() - ) - return non_media_treatments_vars - - def _sample_prior_fn( - self, - n_draws: int, - seed: int | None = None, - ) -> Mapping[str, tf.Tensor]: - """Returns a mapping of prior parameters to tensors of the samples.""" - # For stateful sampling, the random seed must be set to ensure that any - # random numbers that are generated are deterministic. 
- if seed is not None: - tf.keras.utils.set_random_seed(1) - prior = self.prior_broadcast - sample_shape = [1, n_draws] - sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} - - tau_g_excl_baseline = prior.tau_g_excl_baseline.sample(**sample_kwargs) - base_vars = { - constants.KNOT_VALUES: prior.knot_values.sample(**sample_kwargs), - constants.GAMMA_C: prior.gamma_c.sample(**sample_kwargs), - constants.XI_C: prior.xi_c.sample(**sample_kwargs), - constants.SIGMA: prior.sigma.sample(**sample_kwargs), - constants.TAU_G: _get_tau_g( - tau_g_excl_baseline=tau_g_excl_baseline, - baseline_geo_idx=self.baseline_geo_idx, - ).sample(), - } - base_vars[constants.MU_T] = tfp.distributions.Deterministic( - tf.einsum( - "...k,kt->...t", - base_vars[constants.KNOT_VALUES], - tf.convert_to_tensor(self.knot_info.weights), - ), - name=constants.MU_T, - ).sample() - - gamma_gc_dev = tfp.distributions.Sample( - tfp.distributions.Normal(0, 1), - [self.n_geos, self.n_controls], - name=constants.GAMMA_GC_DEV, - ).sample(**sample_kwargs) - base_vars[constants.GAMMA_GC] = tfp.distributions.Deterministic( - base_vars[constants.GAMMA_C][..., tf.newaxis, :] - + base_vars[constants.XI_C][..., tf.newaxis, :] * gamma_gc_dev, - name=constants.GAMMA_GC, - ).sample() - - media_vars = ( - self._sample_media_priors(n_draws, seed) - if self.media_tensors.media is not None - else {} - ) - rf_vars = ( - self._sample_rf_priors(n_draws, seed) - if self.rf_tensors.reach is not None - else {} - ) - organic_media_vars = ( - self._sample_organic_media_priors(n_draws, seed) - if self.organic_media_tensors.organic_media is not None - else {} - ) - organic_rf_vars = ( - self._sample_organic_rf_priors(n_draws, seed) - if self.organic_rf_tensors.organic_reach is not None - else {} - ) - non_media_treatments_vars = ( - self._sample_non_media_treatments_priors(n_draws, seed) - if self.non_media_treatments_scaled is not None - else {} - ) - - return ( - base_vars - | media_vars - | rf_vars - 
| organic_media_vars - | organic_rf_vars - | non_media_treatments_vars - ) - def sample_prior(self, n_draws: int, seed: int | None = None): """Draws samples from the prior distributions. @@ -1860,13 +1007,7 @@ def sample_prior(self, n_draws: int, seed: int | None = None): see [PRNGS and seeds] (https://github.com/tensorflow/probability/blob/main/PRNGS.md). """ - prior_draws = self._sample_prior_fn(n_draws, seed=seed) - # Create Arviz InferenceData for prior draws. - prior_coords = self._create_inference_data_coords(1, n_draws) - prior_dims = self._create_inference_data_dims() - prior_inference_data = az.convert_to_inference_data( - prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR - ) + prior_inference_data = self.prior_sampler(n_draws, seed) self.inference_data.extend(prior_inference_data, join="right") def sample_posterior( @@ -1943,112 +1084,20 @@ def sample_posterior( [ResourceExhaustedError when running Meridian.sample_posterior] (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error). """ - seed = tfp.random.sanitize_seed(seed) if seed else None - n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains - total_chains = np.sum(n_chains_list) - - states = [] - traces = [] - for n_chains_batch in n_chains_list: - try: - mcmc = _xla_windowed_adaptive_nuts( - n_draws=n_burnin + n_keep, - joint_dist=self._get_joint_dist(), - n_chains=n_chains_batch, - num_adaptation_steps=n_adapt, - current_state=current_state, - init_step_size=init_step_size, - dual_averaging_kwargs=dual_averaging_kwargs, - max_tree_depth=max_tree_depth, - max_energy_diff=max_energy_diff, - unrolled_leapfrog_steps=unrolled_leapfrog_steps, - parallel_iterations=parallel_iterations, - seed=seed, - **pins, - ) - except tf.errors.ResourceExhaustedError as error: - raise MCMCOOMError( - "ERROR: Out of memory. 
Try reducing `n_keep` or pass a list of" - " integers as `n_chains` to sample chains serially (see" - " https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error)" - ) from error - states.append(mcmc.all_states._asdict()) - traces.append(mcmc.trace) - - mcmc_states = { - k: tf.einsum( - "ij...->ji...", - tf.concat([state[k] for state in states], axis=1)[n_burnin:, ...], - ) - for k in states[0].keys() - if k not in constants.UNSAVED_PARAMETERS - } - # Create Arviz InferenceData for posterior draws. - posterior_coords = self._create_inference_data_coords(total_chains, n_keep) - posterior_dims = self._create_inference_data_dims() - infdata_posterior = az.convert_to_inference_data( - mcmc_states, coords=posterior_coords, dims=posterior_dims - ) - - # Save trace metrics in InferenceData. - mcmc_trace = {} - for k in traces[0].keys(): - if k not in constants.IGNORED_TRACE_METRICS: - mcmc_trace[k] = tf.concat( - [ - tf.broadcast_to( - tf.transpose(trace[k][n_burnin:, ...]), - [n_chains_list[i], n_keep], - ) - for i, trace in enumerate(traces) - ], - axis=0, - ) - - trace_coords = { - constants.CHAIN: np.arange(total_chains), - constants.DRAW: np.arange(n_keep), - } - trace_dims = { - k: [constants.CHAIN, constants.DRAW] for k in mcmc_trace.keys() - } - infdata_trace = az.convert_to_inference_data( - mcmc_trace, coords=trace_coords, dims=trace_dims, group="trace" - ) - - # Create Arviz InferenceData for divergent transitions and other sampling - # statistics. Note that InferenceData has a different naming convention - # than Tensorflow, and only certain variables are recongnized. - # https://arviz-devs.github.io/arviz/schema/schema.html#sample-stats - # The list of values returned by windowed_adaptive_nuts() is the following: - # 'step_size', 'tune', 'target_log_prob', 'diverging', 'accept_ratio', - # 'variance_scaling', 'n_steps', 'is_accepted'. 
- - sample_stats = { - constants.SAMPLE_STATS_METRICS[k]: v - for k, v in mcmc_trace.items() - if k in constants.SAMPLE_STATS_METRICS - } - sample_stats_dims = { - constants.SAMPLE_STATS_METRICS[k]: v - for k, v in trace_dims.items() - if k in constants.SAMPLE_STATS_METRICS - } - # Tensorflow does not include a "draw" dimension on step size metric if same - # step size is used for all chains. Step size must be broadcast to the - # correct shape. - sample_stats[constants.STEP_SIZE] = tf.broadcast_to( - sample_stats[constants.STEP_SIZE], [total_chains, n_keep] - ) - sample_stats_dims[constants.STEP_SIZE] = [constants.CHAIN, constants.DRAW] - infdata_sample_stats = az.convert_to_inference_data( - sample_stats, - coords=trace_coords, - dims=sample_stats_dims, - group="sample_stats", - ) - posterior_inference_data = az.concat( - infdata_posterior, infdata_trace, infdata_sample_stats + posterior_inference_data = self.posterior_sampler( + n_chains, + n_adapt, + n_burnin, + n_keep, + current_state, + init_step_size, + dual_averaging_kwargs, + max_tree_depth, + max_energy_diff, + unrolled_leapfrog_steps, + parallel_iterations, + seed, + **pins, ) self.inference_data.extend(posterior_inference_data, join="right") diff --git a/meridian/model/model_test.py b/meridian/model/model_test.py index cb9de4af..5a3e87f7 100644 --- a/meridian/model/model_test.py +++ b/meridian/model/model_test.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit tests for model.py.""" - -import collections from collections.abc import Collection, Mapping, Sequence import os from unittest import mock @@ -30,6 +27,7 @@ from meridian.model import adstock_hill from meridian.model import knots as knots_module from meridian.model import model +from meridian.model import model_test_data from meridian.model import prior_distribution from meridian.model import spec import numpy as np @@ -38,267 +36,17 @@ import xarray as xr -def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor: - """Converts a DataArray to a tf.Tensor with the correct MCMC format. - - This function converts a DataArray to tf.Tensor, swaps first two dimensions - and adds the burnin part. This is needed to properly mock the - _xla_windowed_adaptive_nuts() function output in the sample_posterior - tests. - - Args: - array: The array to be converted. - n_burnin: The number of extra draws to be padded with as the 'burnin' part. - - Returns: - A tensor in the same format as returned by the _xla_windowed_adaptive_nuts() - function. - """ - tensor = tf.convert_to_tensor(array) - perm = [1, 0] + [i for i in range(2, len(tensor.shape))] - transposed_tensor = tf.transpose(tensor, perm=perm) - - # Add the "burnin" part to the mocked output of _xla_windowed_adaptive_nuts - # to make sure sample_posterior returns the correct "keep" part. - if array.dtype == bool: - pad_value = False - else: - pad_value = 0.0 if array.dtype.kind == "f" else 0 - - burnin = tf.fill([n_burnin] + transposed_tensor.shape[1:], pad_value) - return tf.concat( - [burnin, transposed_tensor], - axis=0, - ) - - -class ModelTest(tf.test.TestCase, parameterized.TestCase): - # TODO: Update the sample data to span over 1 or 2 year(s). 
- _TEST_DIR = os.path.join(os.path.dirname(__file__), "test_data") - _TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH = os.path.join( - _TEST_DIR, - "sample_prior_media_and_rf.nc", - ) - _TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH = os.path.join( - _TEST_DIR, - "sample_prior_media_only.nc", - ) - _TEST_SAMPLE_PRIOR_RF_ONLY_PATH = os.path.join( - _TEST_DIR, - "sample_prior_rf_only.nc", - ) - _TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH = os.path.join( - _TEST_DIR, - "sample_posterior_media_and_rf.nc", - ) - _TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH = os.path.join( - _TEST_DIR, - "sample_posterior_media_only.nc", - ) - _TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH = os.path.join( - _TEST_DIR, - "sample_posterior_rf_only.nc", - ) - _TEST_SAMPLE_TRACE_PATH = os.path.join( - _TEST_DIR, - "sample_trace.nc", - ) +class ModelTest( + tf.test.TestCase, + parameterized.TestCase, + model_test_data.WithInputDataSamples, +): - # Data dimensions for sample input. - _N_CHAINS = 2 - _N_ADAPT = 2 - _N_BURNIN = 5 - _N_KEEP = 10 - _N_DRAWS = 10 - _N_GEOS = 5 - _N_GEOS_NATIONAL = 1 - _N_TIMES = 200 - _N_TIMES_SHORT = 49 - _N_MEDIA_TIMES = 203 - _N_MEDIA_TIMES_SHORT = 52 - _N_MEDIA_CHANNELS = 3 - _N_RF_CHANNELS = 2 - _N_CONTROLS = 2 - _ROI_CALIBRATION_PERIOD = tf.cast( - tf.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)), - dtype=tf.bool, - ) - _RF_ROI_CALIBRATION_PERIOD = tf.cast( - tf.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)), - dtype=tf.bool, - ) + IDS = model_test_data.WithInputDataSamples def setUp(self): super().setUp() - self.input_data_non_revenue_no_revenue_per_kpi = ( - test_utils.sample_input_data_non_revenue_no_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - n_controls=self._N_CONTROLS, - n_media_channels=self._N_MEDIA_CHANNELS, - seed=0, - ) - ) - self.input_data_with_media_only = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - 
n_controls=self._N_CONTROLS, - n_media_channels=self._N_MEDIA_CHANNELS, - seed=0, - ) - ) - self.input_data_with_rf_only = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - n_controls=self._N_CONTROLS, - n_rf_channels=self._N_RF_CHANNELS, - seed=0, - ) - ) - self.input_data_with_media_and_rf = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - n_controls=self._N_CONTROLS, - n_media_channels=self._N_MEDIA_CHANNELS, - n_rf_channels=self._N_RF_CHANNELS, - seed=0, - ) - ) - self.short_input_data_with_media_only = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES_SHORT, - n_media_times=self._N_MEDIA_TIMES_SHORT, - n_controls=self._N_CONTROLS, - n_media_channels=self._N_MEDIA_CHANNELS, - seed=0, - ) - ) - self.short_input_data_with_rf_only = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES_SHORT, - n_media_times=self._N_MEDIA_TIMES_SHORT, - n_controls=self._N_CONTROLS, - n_rf_channels=self._N_RF_CHANNELS, - seed=0, - ) - ) - self.short_input_data_with_media_and_rf = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES_SHORT, - n_media_times=self._N_MEDIA_TIMES_SHORT, - n_controls=self._N_CONTROLS, - n_media_channels=self._N_MEDIA_CHANNELS, - n_rf_channels=self._N_RF_CHANNELS, - seed=0, - ) - ) - self.national_input_data_media_only = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS_NATIONAL, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - n_controls=self._N_CONTROLS, - n_media_channels=self._N_MEDIA_CHANNELS, - seed=0, - ) - ) - self.national_input_data_media_and_rf = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - 
n_geos=self._N_GEOS_NATIONAL, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - n_controls=self._N_CONTROLS, - n_media_channels=self._N_MEDIA_CHANNELS, - n_rf_channels=self._N_RF_CHANNELS, - seed=0, - ) - ) - - test_prior_media_and_rf = xr.open_dataset( - self._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH - ) - test_prior_media_only = xr.open_dataset( - self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH - ) - test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH) - self.test_dist_media_and_rf = collections.OrderedDict({ - param: tf.convert_to_tensor(test_prior_media_and_rf[param]) - for param in constants.COMMON_PARAMETER_NAMES - + constants.MEDIA_PARAMETER_NAMES - + constants.RF_PARAMETER_NAMES - }) - self.test_dist_media_only = collections.OrderedDict({ - param: tf.convert_to_tensor(test_prior_media_only[param]) - for param in constants.COMMON_PARAMETER_NAMES - + constants.MEDIA_PARAMETER_NAMES - }) - self.test_dist_rf_only = collections.OrderedDict({ - param: tf.convert_to_tensor(test_prior_rf_only[param]) - for param in constants.COMMON_PARAMETER_NAMES - + constants.RF_PARAMETER_NAMES - }) - - test_posterior_media_and_rf = xr.open_dataset( - self._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH - ) - test_posterior_media_only = xr.open_dataset( - self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH - ) - test_posterior_rf_only = xr.open_dataset( - self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH - ) - posterior_params_to_tensors_media_and_rf = { - param: _convert_with_swap( - test_posterior_media_and_rf[param], n_burnin=self._N_BURNIN - ) - for param in constants.COMMON_PARAMETER_NAMES - + constants.MEDIA_PARAMETER_NAMES - + constants.RF_PARAMETER_NAMES - } - posterior_params_to_tensors_media_only = { - param: _convert_with_swap( - test_posterior_media_only[param], n_burnin=self._N_BURNIN - ) - for param in constants.COMMON_PARAMETER_NAMES - + constants.MEDIA_PARAMETER_NAMES - } - posterior_params_to_tensors_rf_only = { - param: _convert_with_swap( - 
test_posterior_rf_only[param], n_burnin=self._N_BURNIN - ) - for param in constants.COMMON_PARAMETER_NAMES - + constants.RF_PARAMETER_NAMES - } - self.test_posterior_states_media_and_rf = collections.namedtuple( - "StructTuple", - constants.COMMON_PARAMETER_NAMES - + constants.MEDIA_PARAMETER_NAMES - + constants.RF_PARAMETER_NAMES, - )(**posterior_params_to_tensors_media_and_rf) - self.test_posterior_states_media_only = collections.namedtuple( - "StructTuple", - constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES, - )(**posterior_params_to_tensors_media_only) - self.test_posterior_states_rf_only = collections.namedtuple( - "StructTuple", - constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES, - )(**posterior_params_to_tensors_rf_only) - - test_trace = xr.open_dataset(self._TEST_SAMPLE_TRACE_PATH) - self.test_trace = { - param: _convert_with_swap(test_trace[param], n_burnin=self._N_BURNIN) - for param in test_trace.data_vars - } + model_test_data.WithInputDataSamples.setUp(self) @parameterized.named_parameters( dict( @@ -654,7 +402,7 @@ def test_init_national_args_with_model_spec_warnings(self): ), constants.MROI_RF: tfp.distributions.LogNormal( 0.2, 0.8, name=constants.MROI_RF - ) + ), }, ignored_priors="mroi_m, mroi_rf", paid_media_prior_type=constants.PAID_MEDIA_PRIOR_TYPE_ROI, @@ -670,7 +418,7 @@ def test_init_national_args_with_model_spec_warnings(self): ), constants.ROI_M: tfp.distributions.LogNormal( 0.2, 0.1, name=constants.ROI_M - ) + ), }, ignored_priors="beta_m, beta_rf, roi_m", paid_media_prior_type=constants.PAID_MEDIA_PRIOR_TYPE_MROI, @@ -740,19 +488,20 @@ def test_base_national_properties(self): dict( testcase_name="media_only", data=test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_media_channels=_N_MEDIA_CHANNELS + n_media_channels=IDS._N_MEDIA_CHANNELS ), ), dict( testcase_name="rf_only", data=test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_rf_channels=_N_RF_CHANNELS + 
n_rf_channels=IDS._N_RF_CHANNELS ), ), dict( testcase_name="rf_and_media", data=test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_media_channels=_N_MEDIA_CHANNELS, n_rf_channels=_N_RF_CHANNELS + n_media_channels=IDS._N_MEDIA_CHANNELS, + n_rf_channels=IDS._N_RF_CHANNELS, ), ), ) @@ -824,25 +573,25 @@ def test_input_data_tensor_properties(self, data): @parameterized.named_parameters( dict( testcase_name="geo_normal", - n_geos=_N_GEOS, + n_geos=IDS._N_GEOS, media_effects_dist=constants.MEDIA_EFFECTS_NORMAL, expected_media_effects_dist=constants.MEDIA_EFFECTS_NORMAL, ), dict( testcase_name="geo_log_normal", - n_geos=_N_GEOS, + n_geos=IDS._N_GEOS, media_effects_dist=constants.MEDIA_EFFECTS_LOG_NORMAL, expected_media_effects_dist=constants.MEDIA_EFFECTS_LOG_NORMAL, ), dict( testcase_name="national_normal", - n_geos=_N_GEOS_NATIONAL, + n_geos=IDS._N_GEOS_NATIONAL, media_effects_dist=constants.MEDIA_EFFECTS_NORMAL, expected_media_effects_dist=constants.MEDIA_EFFECTS_NORMAL, ), dict( testcase_name="national_log_normal", - n_geos=_N_GEOS_NATIONAL, + n_geos=IDS._N_GEOS_NATIONAL, media_effects_dist=constants.MEDIA_EFFECTS_LOG_NORMAL, expected_media_effects_dist=constants.MEDIA_EFFECTS_NORMAL, ), @@ -861,25 +610,25 @@ def test_media_effects_dist_property( @parameterized.named_parameters( dict( testcase_name="geo_unique_sigma_for_each_geo_true", - n_geos=_N_GEOS, + n_geos=IDS._N_GEOS, unique_sigma_for_each_geo=True, expected_unique_sigma_for_each_geo=True, ), dict( testcase_name="geo_unique_sigma_for_each_geo_false", - n_geos=_N_GEOS, + n_geos=IDS._N_GEOS, unique_sigma_for_each_geo=False, expected_unique_sigma_for_each_geo=False, ), dict( testcase_name="national_unique_sigma_for_each_geo_true", - n_geos=_N_GEOS_NATIONAL, + n_geos=IDS._N_GEOS_NATIONAL, unique_sigma_for_each_geo=True, expected_unique_sigma_for_each_geo=False, ), dict( testcase_name="national_unique_sigma_for_each_geo_false", - n_geos=_N_GEOS_NATIONAL, + n_geos=IDS._N_GEOS_NATIONAL, 
unique_sigma_for_each_geo=False, expected_unique_sigma_for_each_geo=False, ), @@ -1366,1894 +1115,113 @@ def test_adstock_hill_rf( mocks_called_names = [mc[0] for mc in manager.mock_calls] self.assertEqual(mocks_called_names, expected_called_names) - def test_get_joint_dist_zeros(self): - model_spec = spec.ModelSpec( - prior=prior_distribution.PriorDistribution( - knot_values=tfp.distributions.Deterministic(0), - tau_g_excl_baseline=tfp.distributions.Deterministic(0), - beta_m=tfp.distributions.Deterministic(0), - beta_rf=tfp.distributions.Deterministic(0), - eta_m=tfp.distributions.Deterministic(0), - eta_rf=tfp.distributions.Deterministic(0), - gamma_c=tfp.distributions.Deterministic(0), - xi_c=tfp.distributions.Deterministic(0), - alpha_m=tfp.distributions.Deterministic(0), - alpha_rf=tfp.distributions.Deterministic(0), - ec_m=tfp.distributions.Deterministic(0), - ec_rf=tfp.distributions.Deterministic(0), - slope_m=tfp.distributions.Deterministic(0), - slope_rf=tfp.distributions.Deterministic(0), - sigma=tfp.distributions.Deterministic(0), - roi_m=tfp.distributions.Deterministic(0), - roi_rf=tfp.distributions.Deterministic(0), - ) - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - sample = meridian._get_joint_dist_unpinned().sample(self._N_DRAWS) - self.assertAllEqual( - sample.y, - tf.zeros(shape=(self._N_DRAWS, self._N_GEOS, self._N_TIMES_SHORT)), - ) - - @parameterized.product( - paid_media_prior_type=[ - constants.PAID_MEDIA_PRIOR_TYPE_ROI, - constants.PAID_MEDIA_PRIOR_TYPE_MROI, - constants.PAID_MEDIA_PRIOR_TYPE_COEFFICIENT, - ], - media_effects_dist=[ - constants.MEDIA_EFFECTS_NORMAL, - constants.MEDIA_EFFECTS_LOG_NORMAL, - ], - ) - def test_get_joint_dist_with_log_prob_media_only( - self, - paid_media_prior_type: str, - media_effects_dist: str, - ): - model_spec = spec.ModelSpec( - paid_media_prior_type=paid_media_prior_type, - media_effects_dist=media_effects_dist, - ) - meridian = 
model.Meridian( - model_spec=model_spec, - input_data=self.short_input_data_with_media_only, - ) - - # Take a single draw of all parameters from the prior distribution. - par_structtuple = meridian._get_joint_dist_unpinned().sample(1) - par = par_structtuple._asdict() - - # Note that "y" is a draw from the prior predictive (transformed) outcome - # distribution. We drop it because "y" is already "pinned" in - # meridian._get_joint_dist() and is not actually a parameter. - del par["y"] + def test_save_and_load_works(self): + # The create_tempdir() method below internally uses command line flag + # (--test_tmpdir) and such flags are not marked as parsed by default + # when running with pytest. Marking as parsed directly here to make the + # pytest run pass. + flags.FLAGS.mark_as_parsed() + file_path = os.path.join(self.create_tempdir().full_path, "joblib") + mmm = model.Meridian(input_data=self.input_data_with_media_and_rf) + model.save_mmm(mmm, str(file_path)) + self.assertTrue(os.path.exists(file_path)) + new_mmm = model.load_mmm(file_path) + for attr in dir(mmm): + if isinstance(getattr(mmm, attr), (int, bool)): + with self.subTest(name=attr): + self.assertEqual(getattr(mmm, attr), getattr(new_mmm, attr)) + elif isinstance(getattr(mmm, attr), tf.Tensor): + with self.subTest(name=attr): + self.assertAllClose(getattr(mmm, attr), getattr(new_mmm, attr)) - # Note that the actual (transformed) outcome data is "pinned" as "y". 
- log_prob_parts_structtuple = meridian._get_joint_dist().log_prob_parts(par) - log_prob_parts = { - k: v._asdict() for k, v in log_prob_parts_structtuple._asdict().items() - } + def test_load_error(self): + with self.assertRaisesWithLiteralMatch( + FileNotFoundError, "No such file or directory: this/path/does/not/exist" + ): + model.load_mmm("this/path/does/not/exist") - derived_params = [ - constants.BETA_GM, - constants.GAMMA_GC, - constants.MU_T, - constants.TAU_G, - ] - prior_distribution_params = [ - constants.KNOT_VALUES, - constants.ETA_M, - constants.GAMMA_C, - constants.XI_C, - constants.ALPHA_M, - constants.EC_M, - constants.SLOPE_M, - constants.SIGMA, - ] - if paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: - derived_params.append(constants.BETA_M) - prior_distribution_params.append(constants.ROI_M) - elif paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: - derived_params.append(constants.BETA_M) - prior_distribution_params.append(constants.MROI_M) - else: - prior_distribution_params.append(constants.BETA_M) +class NonPaidModelTest( + tf.test.TestCase, + parameterized.TestCase, + model_test_data.WithInputDataSamples, +): - # Parameters that are derived from other parameters via Deterministic() - # should have zero contribution to log_prob. 
- for parname in derived_params: - self.assertAllEqual(log_prob_parts["unpinned"][parname][0], 0) + IDS = model_test_data.WithInputDataSamples - prior_distribution_logprobs = {} - for parname in prior_distribution_params: - prior_distribution_logprobs[parname] = tf.reduce_sum( - getattr(meridian.prior_broadcast, parname).log_prob(par[parname]) - ) - self.assertAllClose( - prior_distribution_logprobs[parname], - log_prob_parts["unpinned"][parname][0], - ) + def setUp(self): + super().setUp() + model_test_data.WithInputDataSamples.setUp(self) - coef_params = [ - constants.BETA_GM_DEV, - constants.GAMMA_GC_DEV, - ] - coef_logprobs = {} - for parname in coef_params: - coef_logprobs[parname] = tf.reduce_sum( - tfp.distributions.Normal(0, 1).log_prob(par[parname]) - ) - self.assertAllClose( - coef_logprobs[parname], log_prob_parts["unpinned"][parname][0] - ) - transformed_media = meridian.adstock_hill_media( - media=meridian.media_tensors.media_scaled, - alpha=par[constants.ALPHA_M], - ec=par[constants.EC_M], - slope=par[constants.SLOPE_M], - )[0, :, :, :] - beta_m = par[constants.BETA_GM][0, :, :] - y_means = ( - par[constants.TAU_G][0, :, None] - + par[constants.MU_T][0, None, :] - + tf.einsum("gtm,gm->gt", transformed_media, beta_m) - + tf.einsum( - "gtc,gc->gt", - meridian.controls_scaled, - par[constants.GAMMA_GC][0, :, :], - ) - ) - y_means_logprob = tf.reduce_sum( - tfp.distributions.Normal(y_means, par[constants.SIGMA]).log_prob( - meridian.kpi_scaled - ) + def test_init_with_wrong_non_media_population_scaling_id_shape_fails(self): + model_spec = spec.ModelSpec( + non_media_population_scaling_id=np.ones((7), dtype=bool) ) - self.assertAllClose(y_means_logprob, log_prob_parts["pinned"]["y"][0]) + with self.assertRaisesWithLiteralMatch( + ValueError, + "The shape of `non_media_population_scaling_id` (7,) is different from" + " `(n_non_media_channels,) = (2,)`.", + ): + model.Meridian( + input_data=self.input_data_non_media_and_organic, + model_spec=model_spec, + ) 
- tau_g_logprob = tf.reduce_sum( - getattr( - meridian.prior_broadcast, constants.TAU_G_EXCL_BASELINE - ).log_prob(par[constants.TAU_G_EXCL_BASELINE]) - ) - self.assertAllClose( - tau_g_logprob, - log_prob_parts["unpinned"][constants.TAU_G_EXCL_BASELINE][0], - ) + def test_base_geo_properties(self): + meridian = model.Meridian(input_data=self.input_data_non_media_and_organic) + self.assertEqual(meridian.n_geos, self._N_GEOS) + self.assertEqual(meridian.n_controls, self._N_CONTROLS) + self.assertEqual(meridian.n_non_media_channels, self._N_NON_MEDIA_CHANNELS) + self.assertEqual(meridian.n_times, self._N_TIMES) + self.assertEqual(meridian.n_media_times, self._N_MEDIA_TIMES) + self.assertFalse(meridian.is_national) + self.assertIsNotNone(meridian.prior_broadcast) + self.assertIsNotNone(meridian.inference_data) + self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs) + self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs) - posterior_unnormalized_logprob = ( - sum(prior_distribution_logprobs.values()) - + sum(coef_logprobs.values()) - + y_means_logprob - + tau_g_logprob - ) - self.assertAllClose( - posterior_unnormalized_logprob, - meridian._get_joint_dist().log_prob(par)[0], + def test_base_national_properties(self): + meridian = model.Meridian( + input_data=self.national_input_data_non_media_and_organic ) - - @parameterized.product( - paid_media_prior_type=[ - constants.PAID_MEDIA_PRIOR_TYPE_ROI, - constants.PAID_MEDIA_PRIOR_TYPE_MROI, - constants.PAID_MEDIA_PRIOR_TYPE_COEFFICIENT, - ], - media_effects_dist=[ - constants.MEDIA_EFFECTS_NORMAL, - constants.MEDIA_EFFECTS_LOG_NORMAL, - ], - ) - def test_get_joint_dist_with_log_prob_rf_only( - self, - paid_media_prior_type: str, - media_effects_dist: str, - ): - model_spec = spec.ModelSpec( - paid_media_prior_type=paid_media_prior_type, - media_effects_dist=media_effects_dist, - ) - meridian = model.Meridian( - model_spec=model_spec, - input_data=self.short_input_data_with_rf_only, - ) - - # Take 
a single draw of all parameters from the prior distribution. - par_structtuple = meridian._get_joint_dist_unpinned().sample(1) - par = par_structtuple._asdict() - - # Note that "y" is a draw from the prior predictive (transformed) outcome - # distribution. We drop it because "y" is already "pinned" in - # meridian._get_joint_dist() and is not actually a parameter. - del par["y"] - - # Note that the actual (transformed) outcome data is "pinned" as "y". - log_prob_parts_structtuple = meridian._get_joint_dist().log_prob_parts(par) - log_prob_parts = { - k: v._asdict() for k, v in log_prob_parts_structtuple._asdict().items() - } - - derived_params = [ - constants.BETA_GRF, - constants.GAMMA_GC, - constants.MU_T, - constants.TAU_G, - ] - prior_distribution_params = [ - constants.KNOT_VALUES, - constants.ETA_RF, - constants.GAMMA_C, - constants.XI_C, - constants.ALPHA_RF, - constants.EC_RF, - constants.SLOPE_RF, - constants.SIGMA, - ] - - if paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: - derived_params.append(constants.BETA_RF) - prior_distribution_params.append(constants.ROI_RF) - elif paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: - derived_params.append(constants.BETA_RF) - prior_distribution_params.append(constants.MROI_RF) - else: - prior_distribution_params.append(constants.BETA_RF) - - # Parameters that are derived from other parameters via Deterministic() - # should have zero contribution to log_prob. 
- for parname in derived_params: - self.assertAllEqual(log_prob_parts["unpinned"][parname][0], 0) - - prior_distribution_logprobs = {} - for parname in prior_distribution_params: - prior_distribution_logprobs[parname] = tf.reduce_sum( - getattr(meridian.prior_broadcast, parname).log_prob(par[parname]) - ) - self.assertAllClose( - prior_distribution_logprobs[parname], - log_prob_parts["unpinned"][parname][0], - ) - - coef_params = [ - constants.BETA_GRF_DEV, - constants.GAMMA_GC_DEV, - ] - coef_logprobs = {} - for parname in coef_params: - coef_logprobs[parname] = tf.reduce_sum( - tfp.distributions.Normal(0, 1).log_prob(par[parname]) - ) - self.assertAllClose( - coef_logprobs[parname], log_prob_parts["unpinned"][parname][0] - ) - transformed_reach = meridian.adstock_hill_rf( - reach=meridian.rf_tensors.reach_scaled, - frequency=meridian.rf_tensors.frequency, - alpha=par[constants.ALPHA_RF], - ec=par[constants.EC_RF], - slope=par[constants.SLOPE_RF], - )[0, :, :, :] - beta_rf = par[constants.BETA_GRF][0, :, :] - y_means = ( - par[constants.TAU_G][0, :, None] - + par[constants.MU_T][0, None, :] - + tf.einsum("gtm,gm->gt", transformed_reach, beta_rf) - + tf.einsum( - "gtc,gc->gt", - meridian.controls_scaled, - par[constants.GAMMA_GC][0, :, :], - ) - ) - y_means_logprob = tf.reduce_sum( - tfp.distributions.Normal(y_means, par[constants.SIGMA]).log_prob( - meridian.kpi_scaled - ) - ) - self.assertAllClose(y_means_logprob, log_prob_parts["pinned"]["y"][0]) - - tau_g_logprob = tf.reduce_sum( - getattr( - meridian.prior_broadcast, constants.TAU_G_EXCL_BASELINE - ).log_prob(par[constants.TAU_G_EXCL_BASELINE]) - ) - self.assertAllClose( - tau_g_logprob, - log_prob_parts["unpinned"][constants.TAU_G_EXCL_BASELINE][0], - ) - - posterior_unnormalized_logprob = ( - sum(prior_distribution_logprobs.values()) - + sum(coef_logprobs.values()) - + y_means_logprob - + tau_g_logprob - ) - self.assertAllClose( - posterior_unnormalized_logprob, - meridian._get_joint_dist().log_prob(par)[0], 
- ) - - # TODO: Add test for holdout_id. - @parameterized.product( - paid_media_prior_type=[ - constants.PAID_MEDIA_PRIOR_TYPE_ROI, - constants.PAID_MEDIA_PRIOR_TYPE_MROI, - constants.PAID_MEDIA_PRIOR_TYPE_COEFFICIENT, - ], - media_effects_dist=[ - constants.MEDIA_EFFECTS_NORMAL, - constants.MEDIA_EFFECTS_LOG_NORMAL, - ], - ) - def test_get_joint_dist_with_log_prob_media_and_rf( - self, - paid_media_prior_type: str, - media_effects_dist: str, - ): - model_spec = spec.ModelSpec( - paid_media_prior_type=paid_media_prior_type, - media_effects_dist=media_effects_dist, - ) - meridian = model.Meridian( - model_spec=model_spec, - input_data=self.short_input_data_with_media_and_rf, - ) - - # Take a single draw of all parameters from the prior distribution. - par_structtuple = meridian._get_joint_dist_unpinned().sample(1) - par = par_structtuple._asdict() - - # Note that "y" is a draw from the prior predictive (transformed) outcome - # distribution. We drop it because "y" is already "pinned" in - # meridian._get_joint_dist() and is not actually a parameter. - del par["y"] - - # Note that the actual (transformed) outcome data is "pinned" as "y". 
- log_prob_parts_structtuple = meridian._get_joint_dist().log_prob_parts(par) - log_prob_parts = { - k: v._asdict() for k, v in log_prob_parts_structtuple._asdict().items() - } - - derived_params = [ - constants.BETA_GM, - constants.BETA_GRF, - constants.GAMMA_GC, - constants.MU_T, - constants.TAU_G, - ] - prior_distribution_params = [ - constants.KNOT_VALUES, - constants.ETA_M, - constants.ETA_RF, - constants.GAMMA_C, - constants.XI_C, - constants.ALPHA_M, - constants.ALPHA_RF, - constants.EC_M, - constants.EC_RF, - constants.SLOPE_M, - constants.SLOPE_RF, - constants.SIGMA, - ] - - if paid_media_prior_type in constants.PAID_MEDIA_PRIOR_TYPE_ROI: - derived_params.append(constants.BETA_M) - derived_params.append(constants.BETA_RF) - prior_distribution_params.append(constants.ROI_M) - prior_distribution_params.append(constants.ROI_RF) - elif paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: - derived_params.append(constants.BETA_M) - derived_params.append(constants.BETA_RF) - prior_distribution_params.append(constants.MROI_M) - prior_distribution_params.append(constants.MROI_RF) - else: - prior_distribution_params.append(constants.BETA_M) - prior_distribution_params.append(constants.BETA_RF) - - # Parameters that are derived from other parameters via Deterministic() - # should have zero contribution to log_prob. 
- for parname in derived_params: - self.assertAllEqual(log_prob_parts["unpinned"][parname][0], 0) - - prior_distribution_logprobs = {} - for parname in prior_distribution_params: - prior_distribution_logprobs[parname] = tf.reduce_sum( - getattr(meridian.prior_broadcast, parname).log_prob(par[parname]) - ) - self.assertAllClose( - prior_distribution_logprobs[parname], - log_prob_parts["unpinned"][parname][0], - ) - - coef_params = [ - constants.BETA_GM_DEV, - constants.BETA_GRF_DEV, - constants.GAMMA_GC_DEV, - ] - coef_logprobs = {} - for parname in coef_params: - coef_logprobs[parname] = tf.reduce_sum( - tfp.distributions.Normal(0, 1).log_prob(par[parname]) - ) - self.assertAllClose( - coef_logprobs[parname], log_prob_parts["unpinned"][parname][0] - ) - transformed_media = meridian.adstock_hill_media( - media=meridian.media_tensors.media_scaled, - alpha=par[constants.ALPHA_M], - ec=par[constants.EC_M], - slope=par[constants.SLOPE_M], - )[0, :, :, :] - transformed_reach = meridian.adstock_hill_rf( - reach=meridian.rf_tensors.reach_scaled, - frequency=meridian.rf_tensors.frequency, - alpha=par[constants.ALPHA_RF], - ec=par[constants.EC_RF], - slope=par[constants.SLOPE_RF], - )[0, :, :, :] - combined_transformed_media = tf.concat( - [transformed_media, transformed_reach], axis=-1 - ) - - combined_beta = tf.concat( - [par[constants.BETA_GM][0, :, :], par[constants.BETA_GRF][0, :, :]], - axis=-1, - ) - y_means = ( - par[constants.TAU_G][0, :, None] - + par[constants.MU_T][0, None, :] - + tf.einsum("gtm,gm->gt", combined_transformed_media, combined_beta) - + tf.einsum( - "gtc,gc->gt", - meridian.controls_scaled, - par[constants.GAMMA_GC][0, :, :], - ) - ) - y_means_logprob = tf.reduce_sum( - tfp.distributions.Normal(y_means, par[constants.SIGMA]).log_prob( - meridian.kpi_scaled - ) - ) - self.assertAllClose(y_means_logprob, log_prob_parts["pinned"]["y"][0]) - - tau_g_logprob = tf.reduce_sum( - getattr( - meridian.prior_broadcast, constants.TAU_G_EXCL_BASELINE - 
).log_prob(par[constants.TAU_G_EXCL_BASELINE]) - ) - self.assertAllClose( - tau_g_logprob, - log_prob_parts["unpinned"][constants.TAU_G_EXCL_BASELINE][0], - ) - - posterior_unnormalized_logprob = ( - sum(prior_distribution_logprobs.values()) - + sum(coef_logprobs.values()) - + y_means_logprob - + tau_g_logprob - ) - self.assertAllClose( - posterior_unnormalized_logprob, - meridian._get_joint_dist().log_prob(par)[0], - ) - - def test_sample_prior_seed_same_seed(self): - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS, seed=1) - meridian2 = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - meridian2.sample_prior(n_draws=self._N_DRAWS, seed=1) - self.assertEqual( - meridian.inference_data.prior, meridian2.inference_data.prior - ) - - def test_sample_prior_different_seed(self): - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS, seed=1) - meridian2 = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - meridian2.sample_prior(n_draws=self._N_DRAWS, seed=2) - - self.assertNotEqual( - meridian.inference_data.prior, meridian2.inference_data.prior - ) - - def test_sample_prior_media_and_rf_returns_correct_shape(self): - self.enter_context( - mock.patch.object( - model.Meridian, - "_sample_prior_fn", - autospec=True, - return_value=self.test_dist_media_and_rf, - ) - ) - - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS) - knots_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) - control_shape = (1, self._N_DRAWS, self._N_CONTROLS) - media_channel_shape = (1, 
self._N_DRAWS, self._N_MEDIA_CHANNELS) - rf_channel_shape = (1, self._N_DRAWS, self._N_RF_CHANNELS) - sigma_shape = ( - (1, self._N_DRAWS, self._N_GEOS) - if meridian.unique_sigma_for_each_geo - else (1, self._N_DRAWS, 1) - ) - geo_shape = (1, self._N_DRAWS, self._N_GEOS) - time_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) - geo_control_shape = geo_shape + (self._N_CONTROLS,) - geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) - geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) - - media_parameters = list(constants.MEDIA_PARAMETER_NAMES) - media_parameters.remove(constants.BETA_GM) - rf_parameters = list(constants.RF_PARAMETER_NAMES) - rf_parameters.remove(constants.BETA_GRF) - - prior = meridian.inference_data.prior - shape_to_params = { - knots_shape: [ - getattr(prior, attr) for attr in constants.KNOTS_PARAMETERS - ], - media_channel_shape: [ - getattr(prior, attr) for attr in media_parameters - ], - rf_channel_shape: [ - getattr(prior, attr) for attr in rf_parameters - ], - control_shape: [ - getattr(prior, attr) for attr in constants.CONTROL_PARAMETERS - ], - sigma_shape: [ - getattr(prior, attr) for attr in constants.SIGMA_PARAMETERS - ], - geo_shape: [getattr(prior, attr) for attr in constants.GEO_PARAMETERS], - time_shape: [ - getattr(prior, attr) for attr in constants.TIME_PARAMETERS - ], - geo_control_shape: [ - getattr(prior, attr) for attr in constants.GEO_CONTROL_PARAMETERS - ], - geo_media_channel_shape: [ - getattr(prior, attr) for attr in constants.GEO_MEDIA_PARAMETERS - ], - geo_rf_channel_shape: [ - getattr(prior, attr) for attr in constants.GEO_RF_PARAMETERS - ], - } - for shape, params in shape_to_params.items(): - for param in params: - self.assertEqual(param.shape, shape) - - def test_sample_prior_media_only_returns_correct_shape(self): - self.enter_context( - mock.patch.object( - model.Meridian, - "_sample_prior_fn", - autospec=True, - return_value=self.test_dist_media_only, - ) - ) - - model_spec = spec.ModelSpec() - 
meridian = model.Meridian( - input_data=self.short_input_data_with_media_only, - model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS) - knots_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) - control_shape = (1, self._N_DRAWS, self._N_CONTROLS) - media_channel_shape = (1, self._N_DRAWS, self._N_MEDIA_CHANNELS) - sigma_shape = ( - (1, self._N_DRAWS, self._N_GEOS) - if meridian.unique_sigma_for_each_geo - else (1, self._N_DRAWS, 1) - ) - geo_shape = (1, self._N_DRAWS, self._N_GEOS) - time_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) - geo_control_shape = geo_shape + (self._N_CONTROLS,) - geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) - - media_parameters = list(constants.MEDIA_PARAMETER_NAMES) - media_parameters.remove(constants.BETA_GM) - - prior = meridian.inference_data.prior - shape_to_params = { - knots_shape: [ - getattr(prior, attr) for attr in constants.KNOTS_PARAMETERS - ], - media_channel_shape: [ - getattr(prior, attr) for attr in media_parameters - ], - control_shape: [ - getattr(prior, attr) for attr in constants.CONTROL_PARAMETERS - ], - sigma_shape: [ - getattr(prior, attr) for attr in constants.SIGMA_PARAMETERS - ], - geo_shape: [getattr(prior, attr) for attr in constants.GEO_PARAMETERS], - time_shape: [ - getattr(prior, attr) for attr in constants.TIME_PARAMETERS - ], - geo_control_shape: [ - getattr(prior, attr) for attr in constants.GEO_CONTROL_PARAMETERS - ], - geo_media_channel_shape: [ - getattr(prior, attr) for attr in constants.GEO_MEDIA_PARAMETERS - ], - } - for shape, params in shape_to_params.items(): - for param in params: - self.assertEqual(param.shape, shape) - - def test_sample_prior_rf_only_returns_correct_shape(self): - self.enter_context( - mock.patch.object( - model.Meridian, - "_sample_prior_fn", - autospec=True, - return_value=self.test_dist_rf_only, - ) - ) - - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_rf_only, - 
model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS) - knots_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) - control_shape = (1, self._N_DRAWS, self._N_CONTROLS) - rf_channel_shape = (1, self._N_DRAWS, self._N_RF_CHANNELS) - sigma_shape = ( - (1, self._N_DRAWS, self._N_GEOS) - if meridian.unique_sigma_for_each_geo - else (1, self._N_DRAWS, 1) - ) - geo_shape = (1, self._N_DRAWS, self._N_GEOS) - time_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) - geo_control_shape = geo_shape + (self._N_CONTROLS,) - geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) - - prior = meridian.inference_data.prior - shape_to_params = { - knots_shape: [ - getattr(prior, attr) for attr in constants.KNOTS_PARAMETERS - ], - rf_channel_shape: [ - getattr(prior, attr) for attr in constants.RF_PARAMETER_NAMES - ], - control_shape: [ - getattr(prior, attr) for attr in constants.CONTROL_PARAMETERS - ], - sigma_shape: [ - getattr(prior, attr) for attr in constants.SIGMA_PARAMETERS - ], - geo_shape: [getattr(prior, attr) for attr in constants.GEO_PARAMETERS], - time_shape: [ - getattr(prior, attr) for attr in constants.TIME_PARAMETERS - ], - geo_control_shape: [ - getattr(prior, attr) for attr in constants.GEO_CONTROL_PARAMETERS - ], - geo_rf_channel_shape: [ - getattr(prior, attr) for attr in constants.GEO_RF_PARAMETERS - ], - } - for shape, params in shape_to_params.items(): - for param in params: - self.assertEqual(param.shape, shape) - - def test_sample_posterior_media_and_rf_returns_correct_shape(self): - mock_sample_posterior = self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - return_value=collections.namedtuple( - "StatesAndTrace", ["all_states", "trace"] - )( - all_states=self.test_posterior_states_media_and_rf, - trace=self.test_trace, - ), - ) - ) - model_spec = spec.ModelSpec( - roi_calibration_period=self._ROI_CALIBRATION_PERIOD, - rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, - ) - meridian 
= model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - - meridian.sample_posterior( - n_chains=self._N_CHAINS, - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - mock_sample_posterior.assert_called_with( - n_draws=self._N_BURNIN + self._N_KEEP, - joint_dist=mock.ANY, - n_chains=self._N_CHAINS, - num_adaptation_steps=self._N_ADAPT, - current_state=None, - init_step_size=None, - dual_averaging_kwargs=None, - max_tree_depth=10, - max_energy_diff=500.0, - unrolled_leapfrog_steps=1, - parallel_iterations=10, - seed=None, - ) - knots_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) - control_shape = (self._N_CHAINS, self._N_KEEP, self._N_CONTROLS) - media_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_MEDIA_CHANNELS) - rf_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_RF_CHANNELS) - sigma_shape = ( - (self._N_CHAINS, self._N_KEEP, self._N_GEOS) - if meridian.unique_sigma_for_each_geo - else (self._N_CHAINS, self._N_KEEP, 1) - ) - geo_shape = (self._N_CHAINS, self._N_KEEP, self._N_GEOS) - time_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) - geo_control_shape = geo_shape + (self._N_CONTROLS,) - geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) - geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) - - media_parameters = list(constants.MEDIA_PARAMETER_NAMES) - media_parameters.remove(constants.BETA_GM) - rf_parameters = list(constants.RF_PARAMETER_NAMES) - rf_parameters.remove(constants.BETA_GRF) - - posterior = meridian.inference_data.posterior - shape_to_params = { - knots_shape: [ - getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS - ], - control_shape: [ - getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS - ], - media_channel_shape: [ - getattr(posterior, attr) for attr in media_parameters - ], - rf_channel_shape: [ - getattr(posterior, attr) for attr in rf_parameters - ], - sigma_shape: [ - 
getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS - ], - geo_shape: [ - getattr(posterior, attr) for attr in constants.GEO_PARAMETERS - ], - time_shape: [ - getattr(posterior, attr) for attr in constants.TIME_PARAMETERS - ], - geo_control_shape: [ - getattr(posterior, attr) - for attr in constants.GEO_CONTROL_PARAMETERS - ], - geo_media_channel_shape: [ - getattr(posterior, attr) for attr in constants.GEO_MEDIA_PARAMETERS - ], - geo_rf_channel_shape: [ - getattr(posterior, attr) for attr in constants.GEO_RF_PARAMETERS - ], - } - for shape, params in shape_to_params.items(): - for param in params: - self.assertEqual(param.shape, shape) - - for attr in [ - constants.STEP_SIZE, - constants.TARGET_LOG_PROBABILITY_ARVIZ, - constants.DIVERGING, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.sample_stats, attr).shape, - ( - self._N_CHAINS, - self._N_KEEP, - ), - ) - for attr in [ - constants.STEP_SIZE, - constants.TUNE, - constants.TARGET_LOG_PROBABILITY_TF, - constants.DIVERGING, - constants.ACCEPT_RATIO, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.trace, attr).shape, - ( - self._N_CHAINS, - self._N_KEEP, - ), - ) - - def test_sample_posterior_media_only_returns_correct_shape(self): - mock_sample_posterior = self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - return_value=collections.namedtuple( - "StatesAndTrace", ["all_states", "trace"] - )( - all_states=self.test_posterior_states_media_only, - trace=self.test_trace, - ), - ) - ) - model_spec = spec.ModelSpec( - roi_calibration_period=self._ROI_CALIBRATION_PERIOD, - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_media_only, - model_spec=model_spec, - ) - - meridian.sample_posterior( - n_chains=self._N_CHAINS, - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - mock_sample_posterior.assert_called_with( - n_draws=self._N_BURNIN + 
self._N_KEEP, - joint_dist=mock.ANY, - n_chains=self._N_CHAINS, - num_adaptation_steps=self._N_ADAPT, - current_state=None, - init_step_size=None, - dual_averaging_kwargs=None, - max_tree_depth=10, - max_energy_diff=500.0, - unrolled_leapfrog_steps=1, - parallel_iterations=10, - seed=None, - ) - knots_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) - control_shape = (self._N_CHAINS, self._N_KEEP, self._N_CONTROLS) - media_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_MEDIA_CHANNELS) - sigma_shape = ( - (self._N_CHAINS, self._N_KEEP, self._N_GEOS) - if meridian.unique_sigma_for_each_geo - else (self._N_CHAINS, self._N_KEEP, 1) - ) - geo_shape = (self._N_CHAINS, self._N_KEEP, self._N_GEOS) - time_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) - geo_control_shape = geo_shape + (self._N_CONTROLS,) - geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) - - media_parameters = list(constants.MEDIA_PARAMETER_NAMES) - media_parameters.remove(constants.BETA_GM) - - posterior = meridian.inference_data.posterior - shape_to_params = { - knots_shape: [ - getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS - ], - control_shape: [ - getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS - ], - media_channel_shape: [ - getattr(posterior, attr) for attr in media_parameters - ], - sigma_shape: [ - getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS - ], - geo_shape: [ - getattr(posterior, attr) for attr in constants.GEO_PARAMETERS - ], - time_shape: [ - getattr(posterior, attr) for attr in constants.TIME_PARAMETERS - ], - geo_control_shape: [ - getattr(posterior, attr) - for attr in constants.GEO_CONTROL_PARAMETERS - ], - geo_media_channel_shape: [ - getattr(posterior, attr) for attr in constants.GEO_MEDIA_PARAMETERS - ], - } - for shape, params in shape_to_params.items(): - for param in params: - self.assertEqual(param.shape, shape) - - for attr in [ - constants.STEP_SIZE, - 
constants.TARGET_LOG_PROBABILITY_ARVIZ, - constants.DIVERGING, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.sample_stats, attr).shape, - ( - self._N_CHAINS, - self._N_KEEP, - ), - ) - for attr in [ - constants.STEP_SIZE, - constants.TUNE, - constants.TARGET_LOG_PROBABILITY_TF, - constants.DIVERGING, - constants.ACCEPT_RATIO, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.trace, attr).shape, - ( - self._N_CHAINS, - self._N_KEEP, - ), - ) - - def test_sample_posterior_rf_only_returns_correct_shape(self): - mock_sample_posterior = self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - return_value=collections.namedtuple( - "StatesAndTrace", ["all_states", "trace"] - )( - all_states=self.test_posterior_states_rf_only, - trace=self.test_trace, - ), - ) - ) - model_spec = spec.ModelSpec( - rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_rf_only, - model_spec=model_spec, - ) - - meridian.sample_posterior( - n_chains=self._N_CHAINS, - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - mock_sample_posterior.assert_called_with( - n_draws=self._N_BURNIN + self._N_KEEP, - joint_dist=mock.ANY, - n_chains=self._N_CHAINS, - num_adaptation_steps=self._N_ADAPT, - current_state=None, - init_step_size=None, - dual_averaging_kwargs=None, - max_tree_depth=10, - max_energy_diff=500.0, - unrolled_leapfrog_steps=1, - parallel_iterations=10, - seed=None, - ) - knots_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) - control_shape = (self._N_CHAINS, self._N_KEEP, self._N_CONTROLS) - rf_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_RF_CHANNELS) - sigma_shape = ( - (self._N_CHAINS, self._N_KEEP, self._N_GEOS) - if meridian.unique_sigma_for_each_geo - else (self._N_CHAINS, self._N_KEEP, 1) - ) - geo_shape = (self._N_CHAINS, self._N_KEEP, self._N_GEOS) 
- time_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) - geo_control_shape = geo_shape + (self._N_CONTROLS,) - geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) - - rf_parameters = list(constants.RF_PARAMETER_NAMES) - rf_parameters.remove(constants.BETA_GRF) - - posterior = meridian.inference_data.posterior - shape_to_params = { - knots_shape: [ - getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS - ], - control_shape: [ - getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS - ], - rf_channel_shape: [ - getattr(posterior, attr) for attr in rf_parameters - ], - sigma_shape: [ - getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS - ], - geo_shape: [ - getattr(posterior, attr) for attr in constants.GEO_PARAMETERS - ], - time_shape: [ - getattr(posterior, attr) for attr in constants.TIME_PARAMETERS - ], - geo_control_shape: [ - getattr(posterior, attr) - for attr in constants.GEO_CONTROL_PARAMETERS - ], - geo_rf_channel_shape: [ - getattr(posterior, attr) for attr in constants.GEO_RF_PARAMETERS - ], - } - for shape, params in shape_to_params.items(): - for param in params: - self.assertEqual(param.shape, shape) - - for attr in [ - constants.STEP_SIZE, - constants.TARGET_LOG_PROBABILITY_ARVIZ, - constants.DIVERGING, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.sample_stats, attr).shape, - ( - self._N_CHAINS, - self._N_KEEP, - ), - ) - for attr in [ - constants.STEP_SIZE, - constants.TUNE, - constants.TARGET_LOG_PROBABILITY_TF, - constants.DIVERGING, - constants.ACCEPT_RATIO, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.trace, attr).shape, - ( - self._N_CHAINS, - self._N_KEEP, - ), - ) - - def test_sample_posterior_media_and_rf_sequential_returns_correct_shape(self): - mock_sample_posterior = self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - return_value=collections.namedtuple( - 
"StatesAndTrace", ["all_states", "trace"] - )( - all_states=self.test_posterior_states_media_and_rf, - trace=self.test_trace, - ), - ) - ) - model_spec = spec.ModelSpec( - roi_calibration_period=self._ROI_CALIBRATION_PERIOD, - rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - - meridian.sample_posterior( - n_chains=[self._N_CHAINS, self._N_CHAINS], - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - mock_sample_posterior.assert_called_with( - n_draws=self._N_BURNIN + self._N_KEEP, - joint_dist=mock.ANY, - n_chains=self._N_CHAINS, - num_adaptation_steps=self._N_ADAPT, - current_state=None, - init_step_size=None, - dual_averaging_kwargs=None, - max_tree_depth=10, - max_energy_diff=500.0, - unrolled_leapfrog_steps=1, - parallel_iterations=10, - seed=None, - ) - n_total_chains = self._N_CHAINS * 2 - knots_shape = (n_total_chains, self._N_KEEP, self._N_TIMES_SHORT) - control_shape = (n_total_chains, self._N_KEEP, self._N_CONTROLS) - media_channel_shape = (n_total_chains, self._N_KEEP, self._N_MEDIA_CHANNELS) - rf_channel_shape = (n_total_chains, self._N_KEEP, self._N_RF_CHANNELS) - sigma_shape = ( - (n_total_chains, self._N_KEEP, self._N_GEOS) - if meridian.unique_sigma_for_each_geo - else (n_total_chains, self._N_KEEP, 1) - ) - geo_shape = (n_total_chains, self._N_KEEP, self._N_GEOS) - time_shape = (n_total_chains, self._N_KEEP, self._N_TIMES_SHORT) - geo_control_shape = geo_shape + (self._N_CONTROLS,) - geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) - geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) - - media_parameters = list(constants.MEDIA_PARAMETER_NAMES) - media_parameters.remove(constants.BETA_GM) - rf_parameters = list(constants.RF_PARAMETER_NAMES) - rf_parameters.remove(constants.BETA_GRF) - - posterior = meridian.inference_data.posterior - shape_to_params = { - knots_shape: [ - 
getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS - ], - control_shape: [ - getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS - ], - media_channel_shape: [ - getattr(posterior, attr) for attr in media_parameters - ], - rf_channel_shape: [ - getattr(posterior, attr) for attr in rf_parameters - ], - sigma_shape: [ - getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS - ], - geo_shape: [ - getattr(posterior, attr) for attr in constants.GEO_PARAMETERS - ], - time_shape: [ - getattr(posterior, attr) for attr in constants.TIME_PARAMETERS - ], - geo_control_shape: [ - getattr(posterior, attr) - for attr in constants.GEO_CONTROL_PARAMETERS - ], - geo_media_channel_shape: [ - getattr(posterior, attr) for attr in constants.GEO_MEDIA_PARAMETERS - ], - geo_rf_channel_shape: [ - getattr(posterior, attr) for attr in constants.GEO_RF_PARAMETERS - ], - } - for shape, params in shape_to_params.items(): - for param in params: - self.assertEqual(param.shape, shape) - - for attr in [ - constants.STEP_SIZE, - constants.TARGET_LOG_PROBABILITY_ARVIZ, - constants.DIVERGING, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.sample_stats, attr).shape, - ( - n_total_chains, - self._N_KEEP, - ), - ) - for attr in [ - constants.STEP_SIZE, - constants.TUNE, - constants.TARGET_LOG_PROBABILITY_TF, - constants.DIVERGING, - constants.ACCEPT_RATIO, - constants.N_STEPS, - ]: - self.assertEqual( - getattr(meridian.inference_data.trace, attr).shape, - ( - n_total_chains, - self._N_KEEP, - ), - ) - - def test_sample_posterior_raises_oom_error_when_limits_exceeded(self): - self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - side_effect=tf.errors.ResourceExhaustedError( - None, None, "Resource exhausted" - ), - ) - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=spec.ModelSpec(), - ) - - with self.assertRaises(model.MCMCOOMError): - 
meridian.sample_posterior( - n_chains=self._N_CHAINS, - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - - def test_save_and_load_works(self): - # The create_tempdir() method below internally uses command line flag - # (--test_tmpdir) and such flags are not marked as parsed by default - # when running with pytest. Marking as parsed directly here to make the - # pytest run pass. - flags.FLAGS.mark_as_parsed() - file_path = os.path.join(self.create_tempdir().full_path, "joblib") - mmm = model.Meridian(input_data=self.input_data_with_media_and_rf) - model.save_mmm(mmm, str(file_path)) - self.assertTrue(os.path.exists(file_path)) - new_mmm = model.load_mmm(file_path) - for attr in dir(mmm): - if isinstance(getattr(mmm, attr), (int, bool)): - with self.subTest(name=attr): - self.assertEqual(getattr(mmm, attr), getattr(new_mmm, attr)) - elif isinstance(getattr(mmm, attr), tf.Tensor): - with self.subTest(name=attr): - self.assertAllClose(getattr(mmm, attr), getattr(new_mmm, attr)) - - def test_load_error(self): - with self.assertRaisesWithLiteralMatch( - FileNotFoundError, "No such file or directory: this/path/does/not/exist" - ): - model.load_mmm("this/path/does/not/exist") - - def test_injected_sample_prior_media_and_rf_returns_correct_shape(self): - """Checks validation passes with correct shapes.""" - self.enter_context( - mock.patch.object( - model.Meridian, - "_sample_prior_fn", - autospec=True, - return_value=self.test_dist_media_and_rf, - ) - ) - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS) - inference_data = meridian.inference_data - - meridian_with_inference_data = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - inference_data=inference_data, - ) - - self.assertEqual( - meridian_with_inference_data.inference_data, inference_data - ) - - def 
test_injected_sample_posterior_media_and_rf_returns_correct_shape(self): - """Checks validation passes with correct shapes.""" - self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - return_value=collections.namedtuple( - "StatesAndTrace", ["all_states", "trace"] - )( - all_states=self.test_posterior_states_media_and_rf, - trace=self.test_trace, - ), - ) - ) - model_spec = spec.ModelSpec( - roi_calibration_period=self._ROI_CALIBRATION_PERIOD, - rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - - meridian.sample_posterior( - n_chains=self._N_CHAINS, - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - inference_data = meridian.inference_data - meridian_with_inference_data = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - inference_data=inference_data, - ) - - self.assertEqual( - meridian_with_inference_data.inference_data, inference_data - ) - - def test_injected_sample_prior_media_only_returns_correct_shape(self): - """Checks validation passes with correct shapes.""" - self.enter_context( - mock.patch.object( - model.Meridian, - "_sample_prior_fn", - autospec=True, - return_value=self.test_dist_media_only, - ) - ) - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_media_only, - model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS) - inference_data = meridian.inference_data - - meridian_with_inference_data = model.Meridian( - input_data=self.short_input_data_with_media_only, - model_spec=model_spec, - inference_data=inference_data, - ) - - self.assertEqual( - meridian_with_inference_data.inference_data, inference_data - ) - - def test_injected_sample_posterior_media_only_returns_correct_shape(self): - """Checks validation passes with correct 
shapes.""" - self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - return_value=collections.namedtuple( - "StatesAndTrace", ["all_states", "trace"] - )( - all_states=self.test_posterior_states_media_only, - trace=self.test_trace, - ), - ) - ) - model_spec = spec.ModelSpec( - roi_calibration_period=self._ROI_CALIBRATION_PERIOD, - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_media_only, - model_spec=model_spec, - ) - - meridian.sample_posterior( - n_chains=self._N_CHAINS, - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - inference_data = meridian.inference_data - meridian_with_inference_data = model.Meridian( - input_data=self.short_input_data_with_media_only, - model_spec=model_spec, - inference_data=inference_data, - ) - - self.assertEqual( - meridian_with_inference_data.inference_data, inference_data - ) - - def test_injected_sample_prior_rf_only_returns_correct_shape(self): - """Checks validation passes with correct shapes.""" - self.enter_context( - mock.patch.object( - model.Meridian, - "_sample_prior_fn", - autospec=True, - return_value=self.test_dist_rf_only, - ) - ) - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_rf_only, - model_spec=model_spec, - ) - meridian.sample_prior(n_draws=self._N_DRAWS) - inference_data = meridian.inference_data - - meridian_with_inference_data = model.Meridian( - input_data=self.short_input_data_with_rf_only, - model_spec=model_spec, - inference_data=inference_data, - ) - - self.assertEqual( - meridian_with_inference_data.inference_data, inference_data - ) - - @parameterized.named_parameters( - dict( - testcase_name="control_variables", - coord=constants.CONTROL_VARIABLE, - mismatched_priors={ - constants.GAMMA_C: (1, _N_DRAWS, _N_CONTROLS + 1), - constants.GAMMA_GC: (1, _N_DRAWS, _N_GEOS, _N_CONTROLS + 1), - constants.XI_C: (1, _N_DRAWS, _N_CONTROLS + 1), - }, - 
mismatched_coord_size=_N_CONTROLS + 1, - expected_coord_size=_N_CONTROLS, - ), - dict( - testcase_name="geos", - coord=constants.GEO, - mismatched_priors={ - constants.BETA_GM: (1, _N_DRAWS, _N_GEOS + 1, _N_MEDIA_CHANNELS), - constants.BETA_GRF: ( - 1, - _N_DRAWS, - _N_GEOS + 1, - _N_RF_CHANNELS, - ), - constants.GAMMA_GC: ( - 1, - _N_DRAWS, - _N_GEOS + 1, - _N_CONTROLS, - ), - constants.TAU_G: (1, _N_DRAWS, _N_GEOS + 1), - }, - mismatched_coord_size=_N_GEOS + 1, - expected_coord_size=_N_GEOS, - ), - dict( - testcase_name="knots", - coord=constants.KNOTS, - mismatched_priors={ - constants.KNOT_VALUES: ( - 1, - _N_DRAWS, - _N_TIMES_SHORT + 1, - ), - }, - mismatched_coord_size=_N_TIMES_SHORT + 1, - expected_coord_size=_N_TIMES_SHORT, - ), - dict( - testcase_name="times", - coord=constants.TIME, - mismatched_priors={ - constants.MU_T: (1, _N_DRAWS, _N_TIMES_SHORT + 1), - }, - mismatched_coord_size=_N_TIMES_SHORT + 1, - expected_coord_size=_N_TIMES_SHORT, - ), - dict( - testcase_name="sigma_dims", - coord=constants.SIGMA_DIM, - mismatched_priors={ - constants.SIGMA: (1, _N_DRAWS, _N_GEOS_NATIONAL + 1), - }, - mismatched_coord_size=_N_GEOS_NATIONAL + 1, - expected_coord_size=_N_GEOS_NATIONAL, - ), - dict( - testcase_name="media_channels", - coord=constants.MEDIA_CHANNEL, - mismatched_priors={ - constants.ALPHA_M: (1, _N_DRAWS, _N_MEDIA_CHANNELS + 1), - constants.BETA_GM: ( - 1, - _N_DRAWS, - _N_GEOS, - _N_MEDIA_CHANNELS + 1, - ), - constants.BETA_M: (1, _N_DRAWS, _N_MEDIA_CHANNELS + 1), - constants.EC_M: (1, _N_DRAWS, _N_MEDIA_CHANNELS + 1), - constants.ETA_M: (1, _N_DRAWS, _N_MEDIA_CHANNELS + 1), - constants.ROI_M: (1, _N_DRAWS, _N_MEDIA_CHANNELS + 1), - constants.SLOPE_M: (1, _N_DRAWS, _N_MEDIA_CHANNELS + 1), - }, - mismatched_coord_size=_N_MEDIA_CHANNELS + 1, - expected_coord_size=_N_MEDIA_CHANNELS, - ), - dict( - testcase_name="rf_channels", - coord=constants.RF_CHANNEL, - mismatched_priors={ - constants.ALPHA_RF: (1, _N_DRAWS, _N_RF_CHANNELS + 1), - 
constants.BETA_GRF: ( - 1, - _N_DRAWS, - _N_GEOS, - _N_RF_CHANNELS + 1, - ), - constants.BETA_RF: (1, _N_DRAWS, _N_RF_CHANNELS + 1), - constants.EC_RF: (1, _N_DRAWS, _N_RF_CHANNELS + 1), - constants.ETA_RF: (1, _N_DRAWS, _N_RF_CHANNELS + 1), - constants.ROI_RF: (1, _N_DRAWS, _N_RF_CHANNELS + 1), - constants.SLOPE_RF: (1, _N_DRAWS, _N_RF_CHANNELS + 1), - }, - mismatched_coord_size=_N_RF_CHANNELS + 1, - expected_coord_size=_N_RF_CHANNELS, - ), - ) - def test_validate_injected_inference_data_prior_incorrect_coordinates( - self, coord, mismatched_priors, mismatched_coord_size, expected_coord_size - ): - """Checks prior validation fails with incorrect coordinates.""" - model_spec = spec.ModelSpec() - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - prior_samples = meridian._sample_prior_fn(self._N_DRAWS) - prior_coords = meridian._create_inference_data_coords(1, self._N_DRAWS) - prior_dims = meridian._create_inference_data_dims() - - prior_samples = dict(prior_samples) - for param in mismatched_priors: - prior_samples[param] = tf.zeros(mismatched_priors[param]) - prior_coords = dict(prior_coords) - prior_coords[coord] = np.arange(mismatched_coord_size) - - inference_data = az.convert_to_inference_data( - prior_samples, - coords=prior_coords, - dims=prior_dims, - group=constants.PRIOR, - ) - - with self.assertRaisesRegex( - ValueError, - f"Injected inference data {constants.PRIOR} has incorrect coordinate" - f" '{coord}': expected {expected_coord_size}, got" - f" {mismatched_coord_size}", - ): - _ = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - inference_data=inference_data, - ) - - @parameterized.named_parameters( - dict( - testcase_name="control_variables", - coord=constants.CONTROL_VARIABLE, - mismatched_posteriors={ - constants.GAMMA_C: (_N_CHAINS, _N_KEEP, _N_CONTROLS + 1), - constants.GAMMA_GC: ( - _N_CHAINS, - _N_KEEP, - _N_GEOS, - _N_CONTROLS + 1, - 
), - constants.XI_C: (_N_CHAINS, _N_KEEP, _N_CONTROLS + 1), - }, - mismatched_coord_size=_N_CONTROLS + 1, - expected_coord_size=_N_CONTROLS, - ), - dict( - testcase_name="geos", - coord=constants.GEO, - mismatched_posteriors={ - constants.BETA_GM: ( - _N_CHAINS, - _N_KEEP, - _N_GEOS + 1, - _N_MEDIA_CHANNELS, - ), - constants.BETA_GRF: ( - _N_CHAINS, - _N_KEEP, - _N_GEOS + 1, - _N_RF_CHANNELS, - ), - constants.GAMMA_GC: ( - _N_CHAINS, - _N_KEEP, - _N_GEOS + 1, - _N_CONTROLS, - ), - constants.TAU_G: (_N_CHAINS, _N_KEEP, _N_GEOS + 1), - }, - mismatched_coord_size=_N_GEOS + 1, - expected_coord_size=_N_GEOS, - ), - dict( - testcase_name="knots", - coord=constants.KNOTS, - mismatched_posteriors={ - constants.KNOT_VALUES: ( - _N_CHAINS, - _N_KEEP, - _N_TIMES_SHORT + 1, - ), - }, - mismatched_coord_size=_N_TIMES_SHORT + 1, - expected_coord_size=_N_TIMES_SHORT, - ), - dict( - testcase_name="times", - coord=constants.TIME, - mismatched_posteriors={ - constants.MU_T: (_N_CHAINS, _N_KEEP, _N_TIMES_SHORT + 1), - }, - mismatched_coord_size=_N_TIMES_SHORT + 1, - expected_coord_size=_N_TIMES_SHORT, - ), - dict( - testcase_name="sigma_dims", - coord=constants.SIGMA_DIM, - mismatched_posteriors={ - constants.SIGMA: (_N_CHAINS, _N_KEEP, _N_GEOS_NATIONAL + 1), - }, - mismatched_coord_size=_N_GEOS_NATIONAL + 1, - expected_coord_size=_N_GEOS_NATIONAL, - ), - dict( - testcase_name="media_channels", - coord=constants.MEDIA_CHANNEL, - mismatched_posteriors={ - constants.ALPHA_M: (_N_CHAINS, _N_KEEP, _N_MEDIA_CHANNELS + 1), - constants.BETA_GM: ( - _N_CHAINS, - _N_KEEP, - _N_GEOS, - _N_MEDIA_CHANNELS + 1, - ), - constants.BETA_M: (_N_CHAINS, _N_KEEP, _N_MEDIA_CHANNELS + 1), - constants.EC_M: (_N_CHAINS, _N_KEEP, _N_MEDIA_CHANNELS + 1), - constants.ETA_M: (_N_CHAINS, _N_KEEP, _N_MEDIA_CHANNELS + 1), - constants.ROI_M: (_N_CHAINS, _N_KEEP, _N_MEDIA_CHANNELS + 1), - constants.SLOPE_M: (_N_CHAINS, _N_KEEP, _N_MEDIA_CHANNELS + 1), - }, - mismatched_coord_size=_N_MEDIA_CHANNELS + 1, - 
expected_coord_size=_N_MEDIA_CHANNELS, - ), - dict( - testcase_name="rf_channels", - coord=constants.RF_CHANNEL, - mismatched_posteriors={ - constants.ALPHA_RF: (_N_CHAINS, _N_KEEP, _N_RF_CHANNELS + 1), - constants.BETA_GRF: ( - _N_CHAINS, - _N_KEEP, - _N_GEOS, - _N_RF_CHANNELS + 1, - ), - constants.BETA_RF: (_N_CHAINS, _N_KEEP, _N_RF_CHANNELS + 1), - constants.EC_RF: (_N_CHAINS, _N_KEEP, _N_RF_CHANNELS + 1), - constants.ETA_RF: (_N_CHAINS, _N_KEEP, _N_RF_CHANNELS + 1), - constants.ROI_RF: (_N_CHAINS, _N_KEEP, _N_RF_CHANNELS + 1), - constants.SLOPE_RF: (_N_CHAINS, _N_KEEP, _N_RF_CHANNELS + 1), - }, - mismatched_coord_size=_N_RF_CHANNELS + 1, - expected_coord_size=_N_RF_CHANNELS, - ), - ) - def test_validate_injected_inference_data_posterior_incorrect_coordinates( - self, - coord, - mismatched_posteriors, - mismatched_coord_size, - expected_coord_size, - ): - """Checks posterior validation fails with incorrect coordinates.""" - self.enter_context( - mock.patch.object( - model, - "_xla_windowed_adaptive_nuts", - autospec=True, - return_value=collections.namedtuple( - "StatesAndTrace", ["all_states", "trace"] - )( - all_states=self.test_posterior_states_media_and_rf, - trace=self.test_trace, - ), - ) - ) - model_spec = spec.ModelSpec( - roi_calibration_period=self._ROI_CALIBRATION_PERIOD, - rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, - ) - meridian = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - ) - - meridian.sample_posterior( - n_chains=self._N_CHAINS, - n_adapt=self._N_ADAPT, - n_burnin=self._N_BURNIN, - n_keep=self._N_KEEP, - ) - - posterior_coords = meridian._create_inference_data_coords( - self._N_CHAINS, self._N_KEEP - ) - posterior_dims = meridian._create_inference_data_dims() - posterior_samples = dict(meridian.inference_data.posterior) - for posterior in mismatched_posteriors: - posterior_samples[posterior] = tf.zeros(mismatched_posteriors[posterior]) - - posterior_coords = 
dict(posterior_coords) - posterior_coords[coord] = np.arange(mismatched_coord_size) - - inference_data = az.convert_to_inference_data( - posterior_samples, - coords=posterior_coords, - dims=posterior_dims, - group=constants.POSTERIOR, - ) - - with self.assertRaisesRegex( - ValueError, - f"Injected inference data {constants.POSTERIOR} has incorrect" - f" coordinate '{coord}': expected" - f" {expected_coord_size}, got {mismatched_coord_size}", - ): - _ = model.Meridian( - input_data=self.short_input_data_with_media_and_rf, - model_spec=model_spec, - inference_data=inference_data, - ) - - -class NonPaidModelTest(tf.test.TestCase, parameterized.TestCase): - - # Data dimensions for sample input. - _N_CHAINS = 2 - _N_ADAPT = 2 - _N_BURNIN = 5 - _N_KEEP = 10 - _N_DRAWS = 10 - _N_GEOS = 5 - _N_GEOS_NATIONAL = 1 - _N_TIMES = 200 - _N_TIMES_SHORT = 49 - _N_MEDIA_TIMES = 203 - _N_MEDIA_TIMES_SHORT = 52 - _N_MEDIA_CHANNELS = 3 - _N_RF_CHANNELS = 2 - _N_ORGANIC_MEDIA_CHANNELS = 4 - _N_ORGANIC_RF_CHANNELS = 1 - _N_CONTROLS = 2 - _N_NON_MEDIA_CHANNELS = 2 - - def setUp(self): - super().setUp() - - self.national_input_data_non_media_and_organic = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS_NATIONAL, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - n_controls=self._N_CONTROLS, - n_non_media_channels=self._N_NON_MEDIA_CHANNELS, - n_media_channels=self._N_MEDIA_CHANNELS, - n_rf_channels=self._N_RF_CHANNELS, - n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS, - n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS, - seed=0, - ) - ) - - self.input_data_non_media_and_organic = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES, - n_media_times=self._N_MEDIA_TIMES, - n_controls=self._N_CONTROLS, - n_non_media_channels=self._N_NON_MEDIA_CHANNELS, - n_media_channels=self._N_MEDIA_CHANNELS, - n_rf_channels=self._N_RF_CHANNELS, - 
n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS, - n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS, - seed=0, - ) - ) - self.short_input_data_non_media_and_organic = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES_SHORT, - n_media_times=self._N_MEDIA_TIMES_SHORT, - n_controls=self._N_CONTROLS, - n_non_media_channels=self._N_NON_MEDIA_CHANNELS, - n_media_channels=self._N_MEDIA_CHANNELS, - n_rf_channels=self._N_RF_CHANNELS, - n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS, - n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS, - seed=0, - ) - ) - self.short_input_data_non_media = ( - test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_geos=self._N_GEOS, - n_times=self._N_TIMES_SHORT, - n_media_times=self._N_MEDIA_TIMES_SHORT, - n_controls=self._N_CONTROLS, - n_non_media_channels=self._N_NON_MEDIA_CHANNELS, - n_media_channels=self._N_MEDIA_CHANNELS, - n_rf_channels=self._N_RF_CHANNELS, - n_organic_media_channels=0, - n_organic_rf_channels=0, - seed=0, - ) - ) - - def test_init_with_wrong_non_media_population_scaling_id_shape_fails(self): - model_spec = spec.ModelSpec( - non_media_population_scaling_id=np.ones((7), dtype=bool) - ) - with self.assertRaisesWithLiteralMatch( - ValueError, - "The shape of `non_media_population_scaling_id` (7,) is different from" - " `(n_non_media_channels,) = (2,)`.", - ): - model.Meridian( - input_data=self.input_data_non_media_and_organic, - model_spec=model_spec, - ) - - def test_base_geo_properties(self): - meridian = model.Meridian(input_data=self.input_data_non_media_and_organic) - self.assertEqual(meridian.n_geos, self._N_GEOS) - self.assertEqual(meridian.n_controls, self._N_CONTROLS) - self.assertEqual(meridian.n_non_media_channels, self._N_NON_MEDIA_CHANNELS) - self.assertEqual(meridian.n_times, self._N_TIMES) - self.assertEqual(meridian.n_media_times, self._N_MEDIA_TIMES) - self.assertFalse(meridian.is_national) - 
self.assertIsNotNone(meridian.prior_broadcast) - self.assertIsNotNone(meridian.inference_data) - self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs) - self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs) - - def test_base_national_properties(self): - meridian = model.Meridian( - input_data=self.national_input_data_non_media_and_organic - ) - self.assertEqual(meridian.n_geos, self._N_GEOS_NATIONAL) - self.assertEqual(meridian.n_controls, self._N_CONTROLS) - self.assertEqual(meridian.n_non_media_channels, self._N_NON_MEDIA_CHANNELS) - self.assertEqual(meridian.n_times, self._N_TIMES) - self.assertEqual(meridian.n_media_times, self._N_MEDIA_TIMES) - self.assertTrue(meridian.is_national) - self.assertIsNotNone(meridian.prior_broadcast) - self.assertIsNotNone(meridian.inference_data) - self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs) - self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs) + self.assertEqual(meridian.n_geos, self._N_GEOS_NATIONAL) + self.assertEqual(meridian.n_controls, self._N_CONTROLS) + self.assertEqual(meridian.n_non_media_channels, self._N_NON_MEDIA_CHANNELS) + self.assertEqual(meridian.n_times, self._N_TIMES) + self.assertEqual(meridian.n_media_times, self._N_MEDIA_TIMES) + self.assertTrue(meridian.is_national) + self.assertIsNotNone(meridian.prior_broadcast) + self.assertIsNotNone(meridian.inference_data) + self.assertNotIn(constants.PRIOR, meridian.inference_data.attrs) + self.assertNotIn(constants.POSTERIOR, meridian.inference_data.attrs) @parameterized.named_parameters( dict( testcase_name="media_non_media_and_organic", data=test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_media_channels=_N_MEDIA_CHANNELS, - n_non_media_channels=_N_NON_MEDIA_CHANNELS, - n_organic_media_channels=_N_ORGANIC_MEDIA_CHANNELS, - n_organic_rf_channels=_N_ORGANIC_RF_CHANNELS, + n_media_channels=IDS._N_MEDIA_CHANNELS, + n_non_media_channels=IDS._N_NON_MEDIA_CHANNELS, + 
n_organic_media_channels=IDS._N_ORGANIC_MEDIA_CHANNELS, + n_organic_rf_channels=IDS._N_ORGANIC_RF_CHANNELS, ), ), dict( testcase_name="rf_non_media_and_organic", data=test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_rf_channels=_N_RF_CHANNELS, - n_non_media_channels=_N_NON_MEDIA_CHANNELS, - n_organic_media_channels=_N_ORGANIC_MEDIA_CHANNELS, - n_organic_rf_channels=_N_ORGANIC_RF_CHANNELS, + n_rf_channels=IDS._N_RF_CHANNELS, + n_non_media_channels=IDS._N_NON_MEDIA_CHANNELS, + n_organic_media_channels=IDS._N_ORGANIC_MEDIA_CHANNELS, + n_organic_rf_channels=IDS._N_ORGANIC_RF_CHANNELS, ), ), dict( testcase_name="media_rf_non_media_and_organic", data=test_utils.sample_input_data_non_revenue_revenue_per_kpi( - n_media_channels=_N_MEDIA_CHANNELS, - n_rf_channels=_N_RF_CHANNELS, - n_non_media_channels=_N_NON_MEDIA_CHANNELS, - n_organic_media_channels=_N_ORGANIC_MEDIA_CHANNELS, - n_organic_rf_channels=_N_ORGANIC_RF_CHANNELS, + n_media_channels=IDS._N_MEDIA_CHANNELS, + n_rf_channels=IDS._N_RF_CHANNELS, + n_non_media_channels=IDS._N_NON_MEDIA_CHANNELS, + n_organic_media_channels=IDS._N_ORGANIC_MEDIA_CHANNELS, + n_organic_rf_channels=IDS._N_ORGANIC_RF_CHANNELS, ), ), ) @@ -3550,7 +1518,9 @@ def test_get_joint_dist_zeros(self): input_data=self.short_input_data_non_media, model_spec=model_spec, ) - sample = meridian._get_joint_dist_unpinned().sample(self._N_DRAWS) + sample = meridian.posterior_sampler._get_joint_dist_unpinned().sample( + self._N_DRAWS + ) self.assertAllEqual( sample.y, tf.zeros(shape=(self._N_DRAWS, self._N_GEOS, self._N_TIMES_SHORT)), @@ -3572,7 +1542,7 @@ def test_get_joint_dist_with_log_prob_non_media( ): model_spec = spec.ModelSpec( paid_media_prior_type=paid_media_prior_type, - media_effects_dist=media_effects_dist + media_effects_dist=media_effects_dist, ) meridian = model.Meridian( model_spec=model_spec, @@ -3580,7 +1550,9 @@ def test_get_joint_dist_with_log_prob_non_media( ) # Take a single draw of all parameters from the prior distribution. 
- par_structtuple = meridian._get_joint_dist_unpinned().sample(1) + par_structtuple = ( + meridian.posterior_sampler._get_joint_dist_unpinned().sample(1) + ) par = par_structtuple._asdict() # Note that "y" is a draw from the prior predictive (transformed) outcome @@ -3589,7 +1561,9 @@ def test_get_joint_dist_with_log_prob_non_media( del par["y"] # Note that the actual (transformed) outcome data is "pinned" as "y". - log_prob_parts_structtuple = meridian._get_joint_dist().log_prob_parts(par) + log_prob_parts_structtuple = ( + meridian.posterior_sampler._get_joint_dist().log_prob_parts(par) + ) log_prob_parts = { k: v._asdict() for k, v in log_prob_parts_structtuple._asdict().items() } @@ -3760,7 +1734,7 @@ def test_get_joint_dist_with_log_prob_non_media( ) self.assertAllClose( posterior_unnormalized_logprob, - meridian._get_joint_dist().log_prob(par)[0], + meridian.posterior_sampler._get_joint_dist().log_prob(par)[0], rtol=1e-3, ) @@ -3771,10 +1745,10 @@ def test_inference_data_non_paid_correct_dims(self): model_spec=model_spec, ) n_draws = 7 - prior_draws = mmm._sample_prior_fn(n_draws, seed=1) + prior_draws = mmm.prior_sampler._sample_prior(n_draws, seed=1) # Create Arviz InferenceData for prior draws. 
- prior_coords = mmm._create_inference_data_coords(1, n_draws) - prior_dims = mmm._create_inference_data_dims() + prior_coords = mmm.create_inference_data_coords(1, n_draws) + prior_dims = mmm.create_inference_data_dims() for param, tensor in prior_draws.items(): self.assertIn(param, prior_dims) @@ -3794,9 +1768,9 @@ def test_validate_injected_inference_data_correct_shapes(self): ) n_chains = 1 n_draws = 10 - prior_samples = meridian._sample_prior_fn(n_draws) - prior_coords = meridian._create_inference_data_coords(n_chains, n_draws) - prior_dims = meridian._create_inference_data_dims() + prior_samples = meridian.prior_sampler._sample_prior(n_draws) + prior_coords = meridian.create_inference_data_coords(n_chains, n_draws) + prior_dims = meridian.create_inference_data_dims() inference_data = az.convert_to_inference_data( prior_samples, coords=prior_coords, @@ -3822,15 +1796,19 @@ def test_validate_injected_inference_data_correct_shapes(self): mismatched_priors={ constants.GAMMA_GN: ( 1, - _N_DRAWS, - _N_GEOS, - _N_NON_MEDIA_CHANNELS + 1, + IDS._N_DRAWS, + IDS._N_GEOS, + IDS._N_NON_MEDIA_CHANNELS + 1, ), - constants.GAMMA_N: (1, _N_DRAWS, _N_NON_MEDIA_CHANNELS + 1), - constants.XI_N: (1, _N_DRAWS, _N_NON_MEDIA_CHANNELS + 1), + constants.GAMMA_N: ( + 1, + IDS._N_DRAWS, + IDS._N_NON_MEDIA_CHANNELS + 1, + ), + constants.XI_N: (1, IDS._N_DRAWS, IDS._N_NON_MEDIA_CHANNELS + 1), }, - mismatched_coord_size=_N_NON_MEDIA_CHANNELS + 1, - expected_coord_size=_N_NON_MEDIA_CHANNELS, + mismatched_coord_size=IDS._N_NON_MEDIA_CHANNELS + 1, + expected_coord_size=IDS._N_NON_MEDIA_CHANNELS, ), dict( testcase_name="organic_rf_channels", @@ -3838,26 +1816,38 @@ def test_validate_injected_inference_data_correct_shapes(self): mismatched_priors={ constants.ALPHA_ORF: ( 1, - _N_DRAWS, - _N_ORGANIC_RF_CHANNELS + 1, + IDS._N_DRAWS, + IDS._N_ORGANIC_RF_CHANNELS + 1, ), constants.BETA_ORF: ( 1, - _N_DRAWS, - _N_ORGANIC_RF_CHANNELS + 1, + IDS._N_DRAWS, + IDS._N_ORGANIC_RF_CHANNELS + 1, ), 
constants.BETA_GORF: ( 1, - _N_DRAWS, - _N_GEOS, - _N_ORGANIC_RF_CHANNELS + 1, + IDS._N_DRAWS, + IDS._N_GEOS, + IDS._N_ORGANIC_RF_CHANNELS + 1, + ), + constants.EC_ORF: ( + 1, + IDS._N_DRAWS, + IDS._N_ORGANIC_RF_CHANNELS + 1, + ), + constants.ETA_ORF: ( + 1, + IDS._N_DRAWS, + IDS._N_ORGANIC_RF_CHANNELS + 1, + ), + constants.SLOPE_ORF: ( + 1, + IDS._N_DRAWS, + IDS._N_ORGANIC_RF_CHANNELS + 1, ), - constants.EC_ORF: (1, _N_DRAWS, _N_ORGANIC_RF_CHANNELS + 1), - constants.ETA_ORF: (1, _N_DRAWS, _N_ORGANIC_RF_CHANNELS + 1), - constants.SLOPE_ORF: (1, _N_DRAWS, _N_ORGANIC_RF_CHANNELS + 1), }, - mismatched_coord_size=_N_ORGANIC_RF_CHANNELS + 1, - expected_coord_size=_N_ORGANIC_RF_CHANNELS, + mismatched_coord_size=IDS._N_ORGANIC_RF_CHANNELS + 1, + expected_coord_size=IDS._N_ORGANIC_RF_CHANNELS, ), ) def test_validate_injected_inference_data_prior_incorrect_coordinates( @@ -3869,9 +1859,9 @@ def test_validate_injected_inference_data_prior_incorrect_coordinates( input_data=self.input_data_non_media_and_organic, model_spec=model_spec, ) - prior_samples = meridian._sample_prior_fn(self._N_DRAWS) - prior_coords = meridian._create_inference_data_coords(1, self._N_DRAWS) - prior_dims = meridian._create_inference_data_dims() + prior_samples = meridian.prior_sampler._sample_prior(self._N_DRAWS) + prior_coords = meridian.create_inference_data_coords(1, self._N_DRAWS) + prior_dims = meridian.create_inference_data_dims() prior_samples = dict(prior_samples) for param in mismatched_priors: diff --git a/meridian/model/model_test_data.py b/meridian/model/model_test_data.py new file mode 100644 index 00000000..d5106af4 --- /dev/null +++ b/meridian/model/model_test_data.py @@ -0,0 +1,350 @@ +# Copyright 2024 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared test data samples.""" + +import collections +import os + +from meridian import constants +from meridian.data import test_utils +import tensorflow as tf +import xarray as xr + + +def _convert_with_swap(array: xr.DataArray, n_burnin: int) -> tf.Tensor: + """Converts a DataArray to a tf.Tensor with the correct MCMC format. + + This function converts a DataArray to tf.Tensor, swaps first two dimensions + and adds the burnin part. This is needed to properly mock the + _xla_windowed_adaptive_nuts() function output in the sample_posterior + tests. + + Args: + array: The array to be converted. + n_burnin: The number of extra draws to be padded with as the 'burnin' part. + + Returns: + A tensor in the same format as returned by the _xla_windowed_adaptive_nuts() + function. + """ + tensor = tf.convert_to_tensor(array) + perm = [1, 0] + [i for i in range(2, len(tensor.shape))] + transposed_tensor = tf.transpose(tensor, perm=perm) + + # Add the "burnin" part to the mocked output of _xla_windowed_adaptive_nuts + # to make sure sample_posterior returns the correct "keep" part. + if array.dtype == bool: + pad_value = False + else: + pad_value = 0.0 if array.dtype.kind == "f" else 0 + + burnin = tf.fill([n_burnin] + transposed_tensor.shape[1:], pad_value) + return tf.concat( + [burnin, transposed_tensor], + axis=0, + ) + + +class WithInputDataSamples: + """A mixin to inject test data samples to a unit test class.""" + + # TODO: Update the sample data to span over 1 or 2 year(s). 
+ _TEST_DIR = os.path.join(os.path.dirname(__file__), "test_data") + _TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH = os.path.join( + _TEST_DIR, + "sample_prior_media_and_rf.nc", + ) + _TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH = os.path.join( + _TEST_DIR, + "sample_prior_media_only.nc", + ) + _TEST_SAMPLE_PRIOR_RF_ONLY_PATH = os.path.join( + _TEST_DIR, + "sample_prior_rf_only.nc", + ) + _TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH = os.path.join( + _TEST_DIR, + "sample_posterior_media_and_rf.nc", + ) + _TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH = os.path.join( + _TEST_DIR, + "sample_posterior_media_only.nc", + ) + _TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH = os.path.join( + _TEST_DIR, + "sample_posterior_rf_only.nc", + ) + _TEST_SAMPLE_TRACE_PATH = os.path.join( + _TEST_DIR, + "sample_trace.nc", + ) + + # Data dimensions for sample input. + _N_CHAINS = 2 + _N_ADAPT = 2 + _N_BURNIN = 5 + _N_KEEP = 10 + _N_DRAWS = 10 + _N_GEOS = 5 + _N_GEOS_NATIONAL = 1 + _N_TIMES = 200 + _N_TIMES_SHORT = 49 + _N_MEDIA_TIMES = 203 + _N_MEDIA_TIMES_SHORT = 52 + _N_MEDIA_CHANNELS = 3 + _N_RF_CHANNELS = 2 + _N_CONTROLS = 2 + _ROI_CALIBRATION_PERIOD = tf.cast( + tf.ones((_N_MEDIA_TIMES_SHORT, _N_MEDIA_CHANNELS)), + dtype=tf.bool, + ) + _RF_ROI_CALIBRATION_PERIOD = tf.cast( + tf.ones((_N_MEDIA_TIMES_SHORT, _N_RF_CHANNELS)), + dtype=tf.bool, + ) + _N_ORGANIC_MEDIA_CHANNELS = 4 + _N_ORGANIC_RF_CHANNELS = 1 + _N_NON_MEDIA_CHANNELS = 2 + + def setUp(self): + self.input_data_non_revenue_no_revenue_per_kpi = ( + test_utils.sample_input_data_non_revenue_no_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + n_controls=self._N_CONTROLS, + n_media_channels=self._N_MEDIA_CHANNELS, + seed=0, + ) + ) + self.input_data_with_media_only = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + n_controls=self._N_CONTROLS, + n_media_channels=self._N_MEDIA_CHANNELS, + seed=0, + ) + ) + 
self.input_data_with_rf_only = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + n_controls=self._N_CONTROLS, + n_rf_channels=self._N_RF_CHANNELS, + seed=0, + ) + ) + self.input_data_with_media_and_rf = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + n_controls=self._N_CONTROLS, + n_media_channels=self._N_MEDIA_CHANNELS, + n_rf_channels=self._N_RF_CHANNELS, + seed=0, + ) + ) + self.short_input_data_with_media_only = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES_SHORT, + n_media_times=self._N_MEDIA_TIMES_SHORT, + n_controls=self._N_CONTROLS, + n_media_channels=self._N_MEDIA_CHANNELS, + seed=0, + ) + ) + self.short_input_data_with_rf_only = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES_SHORT, + n_media_times=self._N_MEDIA_TIMES_SHORT, + n_controls=self._N_CONTROLS, + n_rf_channels=self._N_RF_CHANNELS, + seed=0, + ) + ) + self.short_input_data_with_media_and_rf = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES_SHORT, + n_media_times=self._N_MEDIA_TIMES_SHORT, + n_controls=self._N_CONTROLS, + n_media_channels=self._N_MEDIA_CHANNELS, + n_rf_channels=self._N_RF_CHANNELS, + seed=0, + ) + ) + self.national_input_data_media_only = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS_NATIONAL, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + n_controls=self._N_CONTROLS, + n_media_channels=self._N_MEDIA_CHANNELS, + seed=0, + ) + ) + self.national_input_data_media_and_rf = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS_NATIONAL, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + 
n_controls=self._N_CONTROLS, + n_media_channels=self._N_MEDIA_CHANNELS, + n_rf_channels=self._N_RF_CHANNELS, + seed=0, + ) + ) + + test_prior_media_and_rf = xr.open_dataset( + self._TEST_SAMPLE_PRIOR_MEDIA_AND_RF_PATH + ) + test_prior_media_only = xr.open_dataset( + self._TEST_SAMPLE_PRIOR_MEDIA_ONLY_PATH + ) + test_prior_rf_only = xr.open_dataset(self._TEST_SAMPLE_PRIOR_RF_ONLY_PATH) + self.test_dist_media_and_rf = collections.OrderedDict({ + param: tf.convert_to_tensor(test_prior_media_and_rf[param]) + for param in constants.COMMON_PARAMETER_NAMES + + constants.MEDIA_PARAMETER_NAMES + + constants.RF_PARAMETER_NAMES + }) + self.test_dist_media_only = collections.OrderedDict({ + param: tf.convert_to_tensor(test_prior_media_only[param]) + for param in constants.COMMON_PARAMETER_NAMES + + constants.MEDIA_PARAMETER_NAMES + }) + self.test_dist_rf_only = collections.OrderedDict({ + param: tf.convert_to_tensor(test_prior_rf_only[param]) + for param in constants.COMMON_PARAMETER_NAMES + + constants.RF_PARAMETER_NAMES + }) + + test_posterior_media_and_rf = xr.open_dataset( + self._TEST_SAMPLE_POSTERIOR_MEDIA_AND_RF_PATH + ) + test_posterior_media_only = xr.open_dataset( + self._TEST_SAMPLE_POSTERIOR_MEDIA_ONLY_PATH + ) + test_posterior_rf_only = xr.open_dataset( + self._TEST_SAMPLE_POSTERIOR_RF_ONLY_PATH + ) + posterior_params_to_tensors_media_and_rf = { + param: _convert_with_swap( + test_posterior_media_and_rf[param], n_burnin=self._N_BURNIN + ) + for param in constants.COMMON_PARAMETER_NAMES + + constants.MEDIA_PARAMETER_NAMES + + constants.RF_PARAMETER_NAMES + } + posterior_params_to_tensors_media_only = { + param: _convert_with_swap( + test_posterior_media_only[param], n_burnin=self._N_BURNIN + ) + for param in constants.COMMON_PARAMETER_NAMES + + constants.MEDIA_PARAMETER_NAMES + } + posterior_params_to_tensors_rf_only = { + param: _convert_with_swap( + test_posterior_rf_only[param], n_burnin=self._N_BURNIN + ) + for param in constants.COMMON_PARAMETER_NAMES + + 
constants.RF_PARAMETER_NAMES + } + self.test_posterior_states_media_and_rf = collections.namedtuple( + "StructTuple", + constants.COMMON_PARAMETER_NAMES + + constants.MEDIA_PARAMETER_NAMES + + constants.RF_PARAMETER_NAMES, + )(**posterior_params_to_tensors_media_and_rf) + self.test_posterior_states_media_only = collections.namedtuple( + "StructTuple", + constants.COMMON_PARAMETER_NAMES + constants.MEDIA_PARAMETER_NAMES, + )(**posterior_params_to_tensors_media_only) + self.test_posterior_states_rf_only = collections.namedtuple( + "StructTuple", + constants.COMMON_PARAMETER_NAMES + constants.RF_PARAMETER_NAMES, + )(**posterior_params_to_tensors_rf_only) + + test_trace = xr.open_dataset(self._TEST_SAMPLE_TRACE_PATH) + self.test_trace = { + param: _convert_with_swap(test_trace[param], n_burnin=self._N_BURNIN) + for param in test_trace.data_vars + } + + # The following are input data samples with non-paid channels. + + self.national_input_data_non_media_and_organic = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS_NATIONAL, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + n_controls=self._N_CONTROLS, + n_non_media_channels=self._N_NON_MEDIA_CHANNELS, + n_media_channels=self._N_MEDIA_CHANNELS, + n_rf_channels=self._N_RF_CHANNELS, + n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS, + n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS, + seed=0, + ) + ) + + self.input_data_non_media_and_organic = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES, + n_media_times=self._N_MEDIA_TIMES, + n_controls=self._N_CONTROLS, + n_non_media_channels=self._N_NON_MEDIA_CHANNELS, + n_media_channels=self._N_MEDIA_CHANNELS, + n_rf_channels=self._N_RF_CHANNELS, + n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS, + n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS, + seed=0, + ) + ) + self.short_input_data_non_media_and_organic = ( + 
test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES_SHORT, + n_media_times=self._N_MEDIA_TIMES_SHORT, + n_controls=self._N_CONTROLS, + n_non_media_channels=self._N_NON_MEDIA_CHANNELS, + n_media_channels=self._N_MEDIA_CHANNELS, + n_rf_channels=self._N_RF_CHANNELS, + n_organic_media_channels=self._N_ORGANIC_MEDIA_CHANNELS, + n_organic_rf_channels=self._N_ORGANIC_RF_CHANNELS, + seed=0, + ) + ) + self.short_input_data_non_media = ( + test_utils.sample_input_data_non_revenue_revenue_per_kpi( + n_geos=self._N_GEOS, + n_times=self._N_TIMES_SHORT, + n_media_times=self._N_MEDIA_TIMES_SHORT, + n_controls=self._N_CONTROLS, + n_non_media_channels=self._N_NON_MEDIA_CHANNELS, + n_media_channels=self._N_MEDIA_CHANNELS, + n_rf_channels=self._N_RF_CHANNELS, + n_organic_media_channels=0, + n_organic_rf_channels=0, + seed=0, + ) + ) diff --git a/meridian/model/posterior_sampler.py b/meridian/model/posterior_sampler.py new file mode 100644 index 00000000..b47997b9 --- /dev/null +++ b/meridian/model/posterior_sampler.py @@ -0,0 +1,560 @@ +# Copyright 2024 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Module for MCMC sampling of posterior distributions in a Meridian model.""" + +from collections.abc import Mapping, Sequence + +import arviz as az +from meridian import constants +import numpy as np +import tensorflow as tf +import tensorflow_probability as tfp + + +__all__ = [ + "MCMCSamplingError", + "MCMCOOMError", + "PosteriorSampler", +] + + +class MCMCSamplingError(Exception): + """The Markov Chain Monte Carlo (MCMC) sampling failed.""" + + +class MCMCOOMError(Exception): + """The Markov Chain Monte Carlo (MCMC) exceeds memory limits.""" + + +def _get_tau_g( + tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int +) -> tfp.distributions.Distribution: + """Computes `tau_g` from `tau_g_excl_baseline`. + + This function computes `tau_g` by inserting a column of zeros at the + `baseline_geo` position in `tau_g_excl_baseline`. + + Args: + tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the + user-defined dimensions of the `tau_g` parameter distribution. + baseline_geo_idx: The index of the baseline geo to be set to zero. + + Returns: + A tensor of shape `[..., n_geos]` with the final distribution of the `tau_g` + parameter with zero at position `baseline_geo_idx` and matching + `tau_g_excl_baseline` elsewhere. 
+ """ + rank = len(tau_g_excl_baseline.shape) + shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1 + tau_g = tf.concat( + [ + tau_g_excl_baseline[..., :baseline_geo_idx], + tf.zeros(shape, dtype=tau_g_excl_baseline.dtype), + tau_g_excl_baseline[..., baseline_geo_idx:], + ], + axis=rank - 1, + ) + return tfp.distributions.Deterministic(tau_g, name="tau_g") + + +@tf.function(autograph=False, jit_compile=True) +def _xla_windowed_adaptive_nuts(**kwargs): + """XLA wrapper for windowed_adaptive_nuts.""" + return tfp.experimental.mcmc.windowed_adaptive_nuts(**kwargs) + + +class PosteriorSampler: + """Posterior sampler for distributions in a Meridian model.""" + + def __init__(self, meridian): # meridian: model.Meridian + self._meridian = meridian + + def _get_joint_dist_unpinned(self) -> tfp.distributions.Distribution: + """Returns a `JointDistributionCoroutineAutoBatched` function for MCMC.""" + mmm = self._meridian + mmm.populate_cached_properties() + + # This lists all the derived properties and states of this Meridian object + # that are referenced by the joint distribution coroutine. + # That is, these are the list of captured parameters. 
+ prior_broadcast = mmm.prior_broadcast + baseline_geo_idx = mmm.baseline_geo_idx + knot_info = mmm.knot_info + n_geos = mmm.n_geos + n_times = mmm.n_times + n_media_channels = mmm.n_media_channels + n_rf_channels = mmm.n_rf_channels + n_organic_media_channels = mmm.n_organic_media_channels + n_organic_rf_channels = mmm.n_organic_rf_channels + n_controls = mmm.n_controls + n_non_media_channels = mmm.n_non_media_channels + holdout_id = mmm.holdout_id + media_tensors = mmm.media_tensors + rf_tensors = mmm.rf_tensors + organic_media_tensors = mmm.organic_media_tensors + organic_rf_tensors = mmm.organic_rf_tensors + controls_scaled = mmm.controls_scaled + non_media_treatments_scaled = mmm.non_media_treatments_scaled + media_effects_dist = mmm.media_effects_dist + adstock_hill_media_fn = mmm.adstock_hill_media + adstock_hill_rf_fn = mmm.adstock_hill_rf + get_roi_prior_beta_m_value_fn = mmm.prior_sampler.get_roi_prior_beta_m_value + get_roi_prior_beta_rf_value_fn = ( + mmm.prior_sampler.get_roi_prior_beta_rf_value + ) + + @tfp.distributions.JointDistributionCoroutineAutoBatched + def joint_dist_unpinned(): + # Sample directly from prior. 
+ knot_values = yield prior_broadcast.knot_values + gamma_c = yield prior_broadcast.gamma_c + xi_c = yield prior_broadcast.xi_c + sigma = yield prior_broadcast.sigma + + tau_g_excl_baseline = yield tfp.distributions.Sample( + prior_broadcast.tau_g_excl_baseline, + name=constants.TAU_G_EXCL_BASELINE, + ) + tau_g = yield _get_tau_g( + tau_g_excl_baseline=tau_g_excl_baseline, + baseline_geo_idx=baseline_geo_idx, + ) + mu_t = yield tfp.distributions.Deterministic( + tf.einsum( + "k,kt->t", + knot_values, + tf.convert_to_tensor(knot_info.weights), + ), + name=constants.MU_T, + ) + + tau_gt = tau_g[:, tf.newaxis] + mu_t + combined_media_transformed = tf.zeros( + shape=(n_geos, n_times, 0), dtype=tf.float32 + ) + combined_beta = tf.zeros(shape=(n_geos, 0), dtype=tf.float32) + if media_tensors.media is not None: + alpha_m = yield prior_broadcast.alpha_m + ec_m = yield prior_broadcast.ec_m + eta_m = yield prior_broadcast.eta_m + slope_m = yield prior_broadcast.slope_m + beta_gm_dev = yield tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [n_geos, n_media_channels], + name=constants.BETA_GM_DEV, + ) + media_transformed = adstock_hill_media_fn( + media=media_tensors.media_scaled, + alpha=alpha_m, + ec=ec_m, + slope=slope_m, + ) + prior_type = mmm.model_spec.paid_media_prior_type + if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES: + if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: + roi_or_mroi_m = yield prior_broadcast.roi_m + else: + roi_or_mroi_m = yield prior_broadcast.mroi_m + beta_m_value = get_roi_prior_beta_m_value_fn( + alpha_m, + beta_gm_dev, + ec_m, + eta_m, + roi_or_mroi_m, + slope_m, + media_transformed, + ) + beta_m = yield tfp.distributions.Deterministic( + beta_m_value, name=constants.BETA_M + ) + else: + beta_m = yield prior_broadcast.beta_m + + beta_eta_combined = beta_m + eta_m * beta_gm_dev + beta_gm_value = ( + beta_eta_combined + if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + 
beta_gm = yield tfp.distributions.Deterministic( + beta_gm_value, name=constants.BETA_GM + ) + combined_media_transformed = tf.concat( + [combined_media_transformed, media_transformed], axis=-1 + ) + combined_beta = tf.concat([combined_beta, beta_gm], axis=-1) + + if rf_tensors.reach is not None: + alpha_rf = yield prior_broadcast.alpha_rf + ec_rf = yield prior_broadcast.ec_rf + eta_rf = yield prior_broadcast.eta_rf + slope_rf = yield prior_broadcast.slope_rf + beta_grf_dev = yield tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [n_geos, n_rf_channels], + name=constants.BETA_GRF_DEV, + ) + rf_transformed = adstock_hill_rf_fn( + reach=rf_tensors.reach_scaled, + frequency=rf_tensors.frequency, + alpha=alpha_rf, + ec=ec_rf, + slope=slope_rf, + ) + + prior_type = mmm.model_spec.paid_media_prior_type + if prior_type in constants.PAID_MEDIA_ROI_PRIOR_TYPES: + if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: + roi_or_mroi_rf = yield prior_broadcast.roi_rf + else: + roi_or_mroi_rf = yield prior_broadcast.mroi_rf + beta_rf_value = get_roi_prior_beta_rf_value_fn( + alpha_rf, + beta_grf_dev, + ec_rf, + eta_rf, + roi_or_mroi_rf, + slope_rf, + rf_transformed, + ) + beta_rf = yield tfp.distributions.Deterministic( + beta_rf_value, + name=constants.BETA_RF, + ) + else: + beta_rf = yield prior_broadcast.beta_rf + + beta_eta_combined = beta_rf + eta_rf * beta_grf_dev + beta_grf_value = ( + beta_eta_combined + if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + beta_grf = yield tfp.distributions.Deterministic( + beta_grf_value, name=constants.BETA_GRF + ) + combined_media_transformed = tf.concat( + [combined_media_transformed, rf_transformed], axis=-1 + ) + combined_beta = tf.concat([combined_beta, beta_grf], axis=-1) + + if organic_media_tensors.organic_media is not None: + alpha_om = yield prior_broadcast.alpha_om + ec_om = yield prior_broadcast.ec_om + eta_om = yield prior_broadcast.eta_om + slope_om = yield 
prior_broadcast.slope_om + beta_gom_dev = yield tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [n_geos, n_organic_media_channels], + name=constants.BETA_GOM_DEV, + ) + organic_media_transformed = adstock_hill_media_fn( + media=organic_media_tensors.organic_media_scaled, + alpha=alpha_om, + ec=ec_om, + slope=slope_om, + ) + beta_om = yield prior_broadcast.beta_om + + beta_eta_combined = beta_om + eta_om * beta_gom_dev + beta_gom_value = ( + beta_eta_combined + if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + beta_gom = yield tfp.distributions.Deterministic( + beta_gom_value, name=constants.BETA_GOM + ) + combined_media_transformed = tf.concat( + [combined_media_transformed, organic_media_transformed], axis=-1 + ) + combined_beta = tf.concat([combined_beta, beta_gom], axis=-1) + + if organic_rf_tensors.organic_reach is not None: + alpha_orf = yield prior_broadcast.alpha_orf + ec_orf = yield prior_broadcast.ec_orf + eta_orf = yield prior_broadcast.eta_orf + slope_orf = yield prior_broadcast.slope_orf + beta_gorf_dev = yield tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [n_geos, n_organic_rf_channels], + name=constants.BETA_GORF_DEV, + ) + organic_rf_transformed = adstock_hill_rf_fn( + reach=organic_rf_tensors.organic_reach_scaled, + frequency=organic_rf_tensors.organic_frequency, + alpha=alpha_orf, + ec=ec_orf, + slope=slope_orf, + ) + + beta_orf = yield prior_broadcast.beta_orf + + beta_eta_combined = beta_orf + eta_orf * beta_gorf_dev + beta_gorf_value = ( + beta_eta_combined + if media_effects_dist == constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + beta_gorf = yield tfp.distributions.Deterministic( + beta_gorf_value, name=constants.BETA_GORF + ) + combined_media_transformed = tf.concat( + [combined_media_transformed, organic_rf_transformed], axis=-1 + ) + combined_beta = tf.concat([combined_beta, beta_gorf], axis=-1) + + sigma_gt = 
tf.transpose(tf.broadcast_to(sigma, [n_times, n_geos])) + gamma_gc_dev = yield tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [n_geos, n_controls], + name=constants.GAMMA_GC_DEV, + ) + gamma_gc = yield tfp.distributions.Deterministic( + gamma_c + xi_c * gamma_gc_dev, name=constants.GAMMA_GC + ) + y_pred_combined_media = ( + tau_gt + + tf.einsum("gtm,gm->gt", combined_media_transformed, combined_beta) + + tf.einsum("gtc,gc->gt", controls_scaled, gamma_gc) + ) + + if mmm.non_media_treatments is not None: + gamma_n = yield prior_broadcast.gamma_n + xi_n = yield prior_broadcast.xi_n + gamma_gn_dev = yield tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [n_geos, n_non_media_channels], + name=constants.GAMMA_GN_DEV, + ) + gamma_gn = yield tfp.distributions.Deterministic( + gamma_n + xi_n * gamma_gn_dev, name=constants.GAMMA_GN + ) + y_pred = y_pred_combined_media + tf.einsum( + "gtn,gn->gt", non_media_treatments_scaled, gamma_gn + ) + else: + y_pred = y_pred_combined_media + + # If there are any holdout observations, the holdout KPI values will + # be replaced with zeros using `experimental_pin`. For these + # observations, we set the posterior mean equal to zero and standard + # deviation to `1/sqrt(2pi)`, so the log-density is 0 regardless of the + # sampled posterior parameter values. 
+ if holdout_id is not None: + y_pred_holdout = tf.where(holdout_id, 0.0, y_pred) + test_sd = tf.cast(1.0 / np.sqrt(2.0 * np.pi), tf.float32) + sigma_gt_holdout = tf.where(holdout_id, test_sd, sigma_gt) + yield tfp.distributions.Normal( + y_pred_holdout, sigma_gt_holdout, name="y" + ) + else: + yield tfp.distributions.Normal(y_pred, sigma_gt, name="y") + + return joint_dist_unpinned + + def _get_joint_dist(self) -> tfp.distributions.Distribution: + mmm = self._meridian + y = ( + tf.where(mmm.holdout_id, 0.0, mmm.kpi_scaled) + if mmm.holdout_id is not None + else mmm.kpi_scaled + ) + return self._get_joint_dist_unpinned().experimental_pin(y=y) + + def __call__( + self, + n_chains: Sequence[int] | int, + n_adapt: int, + n_burnin: int, + n_keep: int, + current_state: Mapping[str, tf.Tensor] | None = None, + init_step_size: int | None = None, + dual_averaging_kwargs: Mapping[str, int] | None = None, + max_tree_depth: int = 10, + max_energy_diff: float = 500.0, + unrolled_leapfrog_steps: int = 1, + parallel_iterations: int = 10, + seed: Sequence[int] | None = None, + **pins, + ) -> az.InferenceData: + """Runs Markov Chain Monte Carlo (MCMC) sampling of posterior distributions. + + For more information about the arguments, see [`windowed_adaptive_nuts`] + (https://www.tensorflow.org/probability/api_docs/python/tfp/experimental/mcmc/windowed_adaptive_nuts). + + Args: + n_chains: Number of MCMC chains. Given a sequence of integers, + `windowed_adaptive_nuts` will be called once for each element. The + `n_chains` argument of each `windowed_adaptive_nuts` call will be equal + to the respective integer element. Using a list of integers, one can + split the chains of a `windowed_adaptive_nuts` call into multiple calls + with fewer chains per call. This can reduce memory usage. This might + require an increased number of adaptation steps for convergence, as the + optimization is occurring across fewer chains per sampling call. + n_adapt: Number of adaptation draws per chain. 
+ n_burnin: Number of burn-in draws per chain. Burn-in draws occur after + adaptation draws and before the kept draws. + n_keep: Integer number of draws per chain to keep for inference. + current_state: Optional structure of tensors at which to initialize + sampling. Use the same shape and structure as + `model.experimental_pin(**pins).sample(n_chains)`. + init_step_size: Optional integer determining where to initialize the step + size for the leapfrog integrator. The structure must broadcast with + `current_state`. For example, if the initial state is: ``` { 'a': + tf.zeros(n_chains), 'b': tf.zeros([n_chains, n_features]), } ``` then + any of `1.`, `{'a': 1., 'b': 1.}`, or `{'a': tf.ones(n_chains), 'b': + tf.ones([n_chains, n_features])}` will work. Defaults to the dimension + of the log density to the ¼ power. + dual_averaging_kwargs: Optional dict keyword arguments to pass to + `tfp.mcmc.DualAveragingStepSizeAdaptation`. By default, a + `target_accept_prob` of `0.85` is set, acceptance probabilities across + chains are reduced using a harmonic mean, and the class defaults are + used otherwise. + max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The + maximum number of leapfrog steps is bounded by `2**max_tree_depth`, for + example, the number of nodes in a binary tree `max_tree_depth` nodes + deep. The default setting of `10` takes up to 1024 leapfrog steps. + max_energy_diff: Scalar threshold of energy differences at each leapfrog, + divergence samples are defined as leapfrog steps that exceed this + threshold. Default is `500`. + unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree + expansion step. Applies a direct linear multiplier to the maximum + trajectory length implied by `max_tree_depth`. Default is `1`. + parallel_iterations: Number of iterations allowed to run in parallel. Must + be a positive integer. For more information, see `tf.while_loop`. + seed: Used to set the seed for reproducible results.
For more information, + see [PRNGS and seeds] + (https://github.com/tensorflow/probability/blob/main/PRNGS.md). + **pins: These are used to condition the provided joint distribution, and + are passed directly to `joint_dist.experimental_pin(**pins)`. + + Returns: + An Arviz `InferenceData` object containing posterior samples only. + + Throws: + MCMCOOMError: If the model is out of memory. Try reducing `n_keep` or pass + a list of integers as `n_chains` to sample chains serially. For more + information, see + [ResourceExhaustedError when running Meridian.sample_posterior] + (https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error). + """ + seed = tfp.random.sanitize_seed(seed) if seed else None + n_chains_list = [n_chains] if isinstance(n_chains, int) else n_chains + total_chains = np.sum(n_chains_list) + + states = [] + traces = [] + for n_chains_batch in n_chains_list: + try: + mcmc = _xla_windowed_adaptive_nuts( + n_draws=n_burnin + n_keep, + joint_dist=self._get_joint_dist(), + n_chains=n_chains_batch, + num_adaptation_steps=n_adapt, + current_state=current_state, + init_step_size=init_step_size, + dual_averaging_kwargs=dual_averaging_kwargs, + max_tree_depth=max_tree_depth, + max_energy_diff=max_energy_diff, + unrolled_leapfrog_steps=unrolled_leapfrog_steps, + parallel_iterations=parallel_iterations, + seed=seed, + **pins, + ) + except tf.errors.ResourceExhaustedError as error: + raise MCMCOOMError( + "ERROR: Out of memory. 
Try reducing `n_keep` or pass a list of" + " integers as `n_chains` to sample chains serially (see" + " https://developers.google.com/meridian/docs/advanced-modeling/model-debugging#gpu-oom-error)" + ) from error + states.append(mcmc.all_states._asdict()) + traces.append(mcmc.trace) + + mcmc_states = { + k: tf.einsum( + "ij...->ji...", + tf.concat([state[k] for state in states], axis=1)[n_burnin:, ...], + ) + for k in states[0].keys() + if k not in constants.UNSAVED_PARAMETERS + } + # Create Arviz InferenceData for posterior draws. + posterior_coords = self._meridian.create_inference_data_coords( + total_chains, n_keep + ) + posterior_dims = self._meridian.create_inference_data_dims() + infdata_posterior = az.convert_to_inference_data( + mcmc_states, coords=posterior_coords, dims=posterior_dims + ) + + # Save trace metrics in InferenceData. + mcmc_trace = {} + for k in traces[0].keys(): + if k not in constants.IGNORED_TRACE_METRICS: + mcmc_trace[k] = tf.concat( + [ + tf.broadcast_to( + tf.transpose(trace[k][n_burnin:, ...]), + [n_chains_list[i], n_keep], + ) + for i, trace in enumerate(traces) + ], + axis=0, + ) + + trace_coords = { + constants.CHAIN: np.arange(total_chains), + constants.DRAW: np.arange(n_keep), + } + trace_dims = { + k: [constants.CHAIN, constants.DRAW] for k in mcmc_trace.keys() + } + infdata_trace = az.convert_to_inference_data( + mcmc_trace, coords=trace_coords, dims=trace_dims, group="trace" + ) + + # Create Arviz InferenceData for divergent transitions and other sampling + # statistics. Note that InferenceData has a different naming convention + # than Tensorflow, and only certain variables are recognized. + # https://arviz-devs.github.io/arviz/schema/schema.html#sample-stats + # The list of values returned by windowed_adaptive_nuts() is the following: + # 'step_size', 'tune', 'target_log_prob', 'diverging', 'accept_ratio', + # 'variance_scaling', 'n_steps', 'is_accepted'.
+ + sample_stats = { + constants.SAMPLE_STATS_METRICS[k]: v + for k, v in mcmc_trace.items() + if k in constants.SAMPLE_STATS_METRICS + } + sample_stats_dims = { + constants.SAMPLE_STATS_METRICS[k]: v + for k, v in trace_dims.items() + if k in constants.SAMPLE_STATS_METRICS + } + # Tensorflow does not include a "draw" dimension on step size metric if same + # step size is used for all chains. Step size must be broadcast to the + # correct shape. + sample_stats[constants.STEP_SIZE] = tf.broadcast_to( + sample_stats[constants.STEP_SIZE], [total_chains, n_keep] + ) + sample_stats_dims[constants.STEP_SIZE] = [constants.CHAIN, constants.DRAW] + infdata_sample_stats = az.convert_to_inference_data( + sample_stats, + coords=trace_coords, + dims=sample_stats_dims, + group="sample_stats", + ) + return az.concat(infdata_posterior, infdata_trace, infdata_sample_stats) diff --git a/meridian/model/posterior_sampler_test.py b/meridian/model/posterior_sampler_test.py new file mode 100644 index 00000000..892fac81 --- /dev/null +++ b/meridian/model/posterior_sampler_test.py @@ -0,0 +1,1374 @@ +# Copyright 2024 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import collections +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import arviz as az +from meridian import constants +from meridian.model import model +from meridian.model import model_test_data +from meridian.model import posterior_sampler +from meridian.model import prior_distribution +from meridian.model import spec +import numpy as np +import tensorflow as tf +import tensorflow_probability as tfp + + +class PosteriorSamplerTest( + tf.test.TestCase, + parameterized.TestCase, + model_test_data.WithInputDataSamples, +): + + IDS = model_test_data.WithInputDataSamples + + def setUp(self): + super().setUp() + model_test_data.WithInputDataSamples.setUp(self) + + def test_get_joint_dist_zeros(self): + model_spec = spec.ModelSpec( + prior=prior_distribution.PriorDistribution( + knot_values=tfp.distributions.Deterministic(0), + tau_g_excl_baseline=tfp.distributions.Deterministic(0), + beta_m=tfp.distributions.Deterministic(0), + beta_rf=tfp.distributions.Deterministic(0), + eta_m=tfp.distributions.Deterministic(0), + eta_rf=tfp.distributions.Deterministic(0), + gamma_c=tfp.distributions.Deterministic(0), + xi_c=tfp.distributions.Deterministic(0), + alpha_m=tfp.distributions.Deterministic(0), + alpha_rf=tfp.distributions.Deterministic(0), + ec_m=tfp.distributions.Deterministic(0), + ec_rf=tfp.distributions.Deterministic(0), + slope_m=tfp.distributions.Deterministic(0), + slope_rf=tfp.distributions.Deterministic(0), + sigma=tfp.distributions.Deterministic(0), + roi_m=tfp.distributions.Deterministic(0), + roi_rf=tfp.distributions.Deterministic(0), + ) + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + sample = meridian.posterior_sampler._get_joint_dist_unpinned().sample( + self._N_DRAWS + ) + self.assertAllEqual( + sample.y, + tf.zeros(shape=(self._N_DRAWS, self._N_GEOS, self._N_TIMES_SHORT)), + ) + + @parameterized.product( + 
paid_media_prior_type=[ + constants.PAID_MEDIA_PRIOR_TYPE_ROI, + constants.PAID_MEDIA_PRIOR_TYPE_MROI, + constants.PAID_MEDIA_PRIOR_TYPE_COEFFICIENT, + ], + media_effects_dist=[ + constants.MEDIA_EFFECTS_NORMAL, + constants.MEDIA_EFFECTS_LOG_NORMAL, + ], + ) + def test_get_joint_dist_with_log_prob_media_only( + self, + paid_media_prior_type: str, + media_effects_dist: str, + ): + model_spec = spec.ModelSpec( + paid_media_prior_type=paid_media_prior_type, + media_effects_dist=media_effects_dist, + ) + meridian = model.Meridian( + model_spec=model_spec, + input_data=self.short_input_data_with_media_only, + ) + + # Take a single draw of all parameters from the prior distribution. + par_structtuple = ( + meridian.posterior_sampler._get_joint_dist_unpinned().sample(1) + ) + par = par_structtuple._asdict() + + # Note that "y" is a draw from the prior predictive (transformed) outcome + # distribution. We drop it because "y" is already "pinned" in + # meridian._get_joint_dist() and is not actually a parameter. + del par["y"] + + # Note that the actual (transformed) outcome data is "pinned" as "y". 
+ log_prob_parts_structtuple = ( + meridian.posterior_sampler._get_joint_dist().log_prob_parts(par) + ) + log_prob_parts = { + k: v._asdict() for k, v in log_prob_parts_structtuple._asdict().items() + } + + derived_params = [ + constants.BETA_GM, + constants.GAMMA_GC, + constants.MU_T, + constants.TAU_G, + ] + prior_distribution_params = [ + constants.KNOT_VALUES, + constants.ETA_M, + constants.GAMMA_C, + constants.XI_C, + constants.ALPHA_M, + constants.EC_M, + constants.SLOPE_M, + constants.SIGMA, + ] + + if paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: + derived_params.append(constants.BETA_M) + prior_distribution_params.append(constants.ROI_M) + elif paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: + derived_params.append(constants.BETA_M) + prior_distribution_params.append(constants.MROI_M) + else: + prior_distribution_params.append(constants.BETA_M) + + # Parameters that are derived from other parameters via Deterministic() + # should have zero contribution to log_prob. 
+ for parname in derived_params: + self.assertAllEqual(log_prob_parts["unpinned"][parname][0], 0) + + prior_distribution_logprobs = {} + for parname in prior_distribution_params: + prior_distribution_logprobs[parname] = tf.reduce_sum( + getattr(meridian.prior_broadcast, parname).log_prob(par[parname]) + ) + self.assertAllClose( + prior_distribution_logprobs[parname], + log_prob_parts["unpinned"][parname][0], + ) + + coef_params = [ + constants.BETA_GM_DEV, + constants.GAMMA_GC_DEV, + ] + coef_logprobs = {} + for parname in coef_params: + coef_logprobs[parname] = tf.reduce_sum( + tfp.distributions.Normal(0, 1).log_prob(par[parname]) + ) + self.assertAllClose( + coef_logprobs[parname], log_prob_parts["unpinned"][parname][0] + ) + transformed_media = meridian.adstock_hill_media( + media=meridian.media_tensors.media_scaled, + alpha=par[constants.ALPHA_M], + ec=par[constants.EC_M], + slope=par[constants.SLOPE_M], + )[0, :, :, :] + beta_m = par[constants.BETA_GM][0, :, :] + y_means = ( + par[constants.TAU_G][0, :, None] + + par[constants.MU_T][0, None, :] + + tf.einsum("gtm,gm->gt", transformed_media, beta_m) + + tf.einsum( + "gtc,gc->gt", + meridian.controls_scaled, + par[constants.GAMMA_GC][0, :, :], + ) + ) + y_means_logprob = tf.reduce_sum( + tfp.distributions.Normal(y_means, par[constants.SIGMA]).log_prob( + meridian.kpi_scaled + ) + ) + self.assertAllClose(y_means_logprob, log_prob_parts["pinned"]["y"][0]) + + tau_g_logprob = tf.reduce_sum( + getattr( + meridian.prior_broadcast, constants.TAU_G_EXCL_BASELINE + ).log_prob(par[constants.TAU_G_EXCL_BASELINE]) + ) + self.assertAllClose( + tau_g_logprob, + log_prob_parts["unpinned"][constants.TAU_G_EXCL_BASELINE][0], + ) + + posterior_unnormalized_logprob = ( + sum(prior_distribution_logprobs.values()) + + sum(coef_logprobs.values()) + + y_means_logprob + + tau_g_logprob + ) + self.assertAllClose( + posterior_unnormalized_logprob, + meridian.posterior_sampler._get_joint_dist().log_prob(par)[0], + ) + + 
@parameterized.product( + paid_media_prior_type=[ + constants.PAID_MEDIA_PRIOR_TYPE_ROI, + constants.PAID_MEDIA_PRIOR_TYPE_MROI, + constants.PAID_MEDIA_PRIOR_TYPE_COEFFICIENT, + ], + media_effects_dist=[ + constants.MEDIA_EFFECTS_NORMAL, + constants.MEDIA_EFFECTS_LOG_NORMAL, + ], + ) + def test_get_joint_dist_with_log_prob_rf_only( + self, + paid_media_prior_type: str, + media_effects_dist: str, + ): + model_spec = spec.ModelSpec( + paid_media_prior_type=paid_media_prior_type, + media_effects_dist=media_effects_dist, + ) + meridian = model.Meridian( + model_spec=model_spec, + input_data=self.short_input_data_with_rf_only, + ) + + # Take a single draw of all parameters from the prior distribution. + par_structtuple = ( + meridian.posterior_sampler._get_joint_dist_unpinned().sample(1) + ) + par = par_structtuple._asdict() + + # Note that "y" is a draw from the prior predictive (transformed) outcome + # distribution. We drop it because "y" is already "pinned" in + # meridian._get_joint_dist() and is not actually a parameter. + del par["y"] + + # Note that the actual (transformed) outcome data is "pinned" as "y". 
+ log_prob_parts_structtuple = ( + meridian.posterior_sampler._get_joint_dist().log_prob_parts(par) + ) + log_prob_parts = { + k: v._asdict() for k, v in log_prob_parts_structtuple._asdict().items() + } + + derived_params = [ + constants.BETA_GRF, + constants.GAMMA_GC, + constants.MU_T, + constants.TAU_G, + ] + prior_distribution_params = [ + constants.KNOT_VALUES, + constants.ETA_RF, + constants.GAMMA_C, + constants.XI_C, + constants.ALPHA_RF, + constants.EC_RF, + constants.SLOPE_RF, + constants.SIGMA, + ] + + if paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: + derived_params.append(constants.BETA_RF) + prior_distribution_params.append(constants.ROI_RF) + elif paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: + derived_params.append(constants.BETA_RF) + prior_distribution_params.append(constants.MROI_RF) + else: + prior_distribution_params.append(constants.BETA_RF) + + # Parameters that are derived from other parameters via Deterministic() + # should have zero contribution to log_prob. 
+ for parname in derived_params: + self.assertAllEqual(log_prob_parts["unpinned"][parname][0], 0) + + prior_distribution_logprobs = {} + for parname in prior_distribution_params: + prior_distribution_logprobs[parname] = tf.reduce_sum( + getattr(meridian.prior_broadcast, parname).log_prob(par[parname]) + ) + self.assertAllClose( + prior_distribution_logprobs[parname], + log_prob_parts["unpinned"][parname][0], + ) + + coef_params = [ + constants.BETA_GRF_DEV, + constants.GAMMA_GC_DEV, + ] + coef_logprobs = {} + for parname in coef_params: + coef_logprobs[parname] = tf.reduce_sum( + tfp.distributions.Normal(0, 1).log_prob(par[parname]) + ) + self.assertAllClose( + coef_logprobs[parname], log_prob_parts["unpinned"][parname][0] + ) + transformed_reach = meridian.adstock_hill_rf( + reach=meridian.rf_tensors.reach_scaled, + frequency=meridian.rf_tensors.frequency, + alpha=par[constants.ALPHA_RF], + ec=par[constants.EC_RF], + slope=par[constants.SLOPE_RF], + )[0, :, :, :] + beta_rf = par[constants.BETA_GRF][0, :, :] + y_means = ( + par[constants.TAU_G][0, :, None] + + par[constants.MU_T][0, None, :] + + tf.einsum("gtm,gm->gt", transformed_reach, beta_rf) + + tf.einsum( + "gtc,gc->gt", + meridian.controls_scaled, + par[constants.GAMMA_GC][0, :, :], + ) + ) + y_means_logprob = tf.reduce_sum( + tfp.distributions.Normal(y_means, par[constants.SIGMA]).log_prob( + meridian.kpi_scaled + ) + ) + self.assertAllClose(y_means_logprob, log_prob_parts["pinned"]["y"][0]) + + tau_g_logprob = tf.reduce_sum( + getattr( + meridian.prior_broadcast, constants.TAU_G_EXCL_BASELINE + ).log_prob(par[constants.TAU_G_EXCL_BASELINE]) + ) + self.assertAllClose( + tau_g_logprob, + log_prob_parts["unpinned"][constants.TAU_G_EXCL_BASELINE][0], + ) + + posterior_unnormalized_logprob = ( + sum(prior_distribution_logprobs.values()) + + sum(coef_logprobs.values()) + + y_means_logprob + + tau_g_logprob + ) + self.assertAllClose( + posterior_unnormalized_logprob, + 
meridian.posterior_sampler._get_joint_dist().log_prob(par)[0], + ) + + # TODO: Add test for holdout_id. + @parameterized.product( + paid_media_prior_type=[ + constants.PAID_MEDIA_PRIOR_TYPE_ROI, + constants.PAID_MEDIA_PRIOR_TYPE_MROI, + constants.PAID_MEDIA_PRIOR_TYPE_COEFFICIENT, + ], + media_effects_dist=[ + constants.MEDIA_EFFECTS_NORMAL, + constants.MEDIA_EFFECTS_LOG_NORMAL, + ], + ) + def test_get_joint_dist_with_log_prob_media_and_rf( + self, + paid_media_prior_type: str, + media_effects_dist: str, + ): + model_spec = spec.ModelSpec( + paid_media_prior_type=paid_media_prior_type, + media_effects_dist=media_effects_dist, + ) + meridian = model.Meridian( + model_spec=model_spec, + input_data=self.short_input_data_with_media_and_rf, + ) + + # Take a single draw of all parameters from the prior distribution. + par_structtuple = ( + meridian.posterior_sampler._get_joint_dist_unpinned().sample(1) + ) + par = par_structtuple._asdict() + + # Note that "y" is a draw from the prior predictive (transformed) outcome + # distribution. We drop it because "y" is already "pinned" in + # meridian._get_joint_dist() and is not actually a parameter. + del par["y"] + + # Note that the actual (transformed) outcome data is "pinned" as "y". 
+ log_prob_parts_structtuple = ( + meridian.posterior_sampler._get_joint_dist().log_prob_parts(par) + ) + log_prob_parts = { + k: v._asdict() for k, v in log_prob_parts_structtuple._asdict().items() + } + + derived_params = [ + constants.BETA_GM, + constants.BETA_GRF, + constants.GAMMA_GC, + constants.MU_T, + constants.TAU_G, + ] + prior_distribution_params = [ + constants.KNOT_VALUES, + constants.ETA_M, + constants.ETA_RF, + constants.GAMMA_C, + constants.XI_C, + constants.ALPHA_M, + constants.ALPHA_RF, + constants.EC_M, + constants.EC_RF, + constants.SLOPE_M, + constants.SLOPE_RF, + constants.SIGMA, + ] + + if paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: + derived_params.append(constants.BETA_M) + derived_params.append(constants.BETA_RF) + prior_distribution_params.append(constants.ROI_M) + prior_distribution_params.append(constants.ROI_RF) + elif paid_media_prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: + derived_params.append(constants.BETA_M) + derived_params.append(constants.BETA_RF) + prior_distribution_params.append(constants.MROI_M) + prior_distribution_params.append(constants.MROI_RF) + else: + prior_distribution_params.append(constants.BETA_M) + prior_distribution_params.append(constants.BETA_RF) + + # Parameters that are derived from other parameters via Deterministic() + # should have zero contribution to log_prob.
+ for parname in derived_params: + self.assertAllEqual(log_prob_parts["unpinned"][parname][0], 0) + + prior_distribution_logprobs = {} + for parname in prior_distribution_params: + prior_distribution_logprobs[parname] = tf.reduce_sum( + getattr(meridian.prior_broadcast, parname).log_prob(par[parname]) + ) + self.assertAllClose( + prior_distribution_logprobs[parname], + log_prob_parts["unpinned"][parname][0], + ) + + coef_params = [ + constants.BETA_GM_DEV, + constants.BETA_GRF_DEV, + constants.GAMMA_GC_DEV, + ] + coef_logprobs = {} + for parname in coef_params: + coef_logprobs[parname] = tf.reduce_sum( + tfp.distributions.Normal(0, 1).log_prob(par[parname]) + ) + self.assertAllClose( + coef_logprobs[parname], log_prob_parts["unpinned"][parname][0] + ) + transformed_media = meridian.adstock_hill_media( + media=meridian.media_tensors.media_scaled, + alpha=par[constants.ALPHA_M], + ec=par[constants.EC_M], + slope=par[constants.SLOPE_M], + )[0, :, :, :] + transformed_reach = meridian.adstock_hill_rf( + reach=meridian.rf_tensors.reach_scaled, + frequency=meridian.rf_tensors.frequency, + alpha=par[constants.ALPHA_RF], + ec=par[constants.EC_RF], + slope=par[constants.SLOPE_RF], + )[0, :, :, :] + combined_transformed_media = tf.concat( + [transformed_media, transformed_reach], axis=-1 + ) + + combined_beta = tf.concat( + [par[constants.BETA_GM][0, :, :], par[constants.BETA_GRF][0, :, :]], + axis=-1, + ) + y_means = ( + par[constants.TAU_G][0, :, None] + + par[constants.MU_T][0, None, :] + + tf.einsum("gtm,gm->gt", combined_transformed_media, combined_beta) + + tf.einsum( + "gtc,gc->gt", + meridian.controls_scaled, + par[constants.GAMMA_GC][0, :, :], + ) + ) + y_means_logprob = tf.reduce_sum( + tfp.distributions.Normal(y_means, par[constants.SIGMA]).log_prob( + meridian.kpi_scaled + ) + ) + self.assertAllClose(y_means_logprob, log_prob_parts["pinned"]["y"][0]) + + tau_g_logprob = tf.reduce_sum( + getattr( + meridian.prior_broadcast, constants.TAU_G_EXCL_BASELINE + 
).log_prob(par[constants.TAU_G_EXCL_BASELINE]) + ) + self.assertAllClose( + tau_g_logprob, + log_prob_parts["unpinned"][constants.TAU_G_EXCL_BASELINE][0], + ) + + posterior_unnormalized_logprob = ( + sum(prior_distribution_logprobs.values()) + + sum(coef_logprobs.values()) + + y_means_logprob + + tau_g_logprob + ) + self.assertAllClose( + posterior_unnormalized_logprob, + meridian.posterior_sampler._get_joint_dist().log_prob(par)[0], + ) + + def test_sample_posterior_media_and_rf_returns_correct_shape(self): + mock_sample_posterior = self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + return_value=collections.namedtuple( + "StatesAndTrace", ["all_states", "trace"] + )( + all_states=self.test_posterior_states_media_and_rf, + trace=self.test_trace, + ), + ) + ) + model_spec = spec.ModelSpec( + roi_calibration_period=self._ROI_CALIBRATION_PERIOD, + rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + + meridian.sample_posterior( + n_chains=self._N_CHAINS, + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + mock_sample_posterior.assert_called_with( + n_draws=self._N_BURNIN + self._N_KEEP, + joint_dist=mock.ANY, + n_chains=self._N_CHAINS, + num_adaptation_steps=self._N_ADAPT, + current_state=None, + init_step_size=None, + dual_averaging_kwargs=None, + max_tree_depth=10, + max_energy_diff=500.0, + unrolled_leapfrog_steps=1, + parallel_iterations=10, + seed=None, + ) + knots_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) + control_shape = (self._N_CHAINS, self._N_KEEP, self._N_CONTROLS) + media_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_MEDIA_CHANNELS) + rf_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_RF_CHANNELS) + sigma_shape = ( + (self._N_CHAINS, self._N_KEEP, self._N_GEOS) + if meridian.unique_sigma_for_each_geo + else 
(self._N_CHAINS, self._N_KEEP, 1) + ) + geo_shape = (self._N_CHAINS, self._N_KEEP, self._N_GEOS) + time_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) + geo_control_shape = geo_shape + (self._N_CONTROLS,) + geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) + geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) + + media_parameters = list(constants.MEDIA_PARAMETER_NAMES) + media_parameters.remove(constants.BETA_GM) + rf_parameters = list(constants.RF_PARAMETER_NAMES) + rf_parameters.remove(constants.BETA_GRF) + + posterior = meridian.inference_data.posterior + shape_to_params = { + knots_shape: [ + getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS + ], + control_shape: [ + getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS + ], + media_channel_shape: [ + getattr(posterior, attr) for attr in media_parameters + ], + rf_channel_shape: [getattr(posterior, attr) for attr in rf_parameters], + sigma_shape: [ + getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS + ], + geo_shape: [ + getattr(posterior, attr) for attr in constants.GEO_PARAMETERS + ], + time_shape: [ + getattr(posterior, attr) for attr in constants.TIME_PARAMETERS + ], + geo_control_shape: [ + getattr(posterior, attr) + for attr in constants.GEO_CONTROL_PARAMETERS + ], + geo_media_channel_shape: [ + getattr(posterior, attr) for attr in constants.GEO_MEDIA_PARAMETERS + ], + geo_rf_channel_shape: [ + getattr(posterior, attr) for attr in constants.GEO_RF_PARAMETERS + ], + } + for shape, params in shape_to_params.items(): + for param in params: + self.assertEqual(param.shape, shape) + + for attr in [ + constants.STEP_SIZE, + constants.TARGET_LOG_PROBABILITY_ARVIZ, + constants.DIVERGING, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.sample_stats, attr).shape, + ( + self._N_CHAINS, + self._N_KEEP, + ), + ) + for attr in [ + constants.STEP_SIZE, + constants.TUNE, + constants.TARGET_LOG_PROBABILITY_TF, + 
constants.DIVERGING, + constants.ACCEPT_RATIO, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.trace, attr).shape, + ( + self._N_CHAINS, + self._N_KEEP, + ), + ) + + def test_sample_posterior_media_only_returns_correct_shape(self): + mock_sample_posterior = self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + return_value=collections.namedtuple( + "StatesAndTrace", ["all_states", "trace"] + )( + all_states=self.test_posterior_states_media_only, + trace=self.test_trace, + ), + ) + ) + model_spec = spec.ModelSpec( + roi_calibration_period=self._ROI_CALIBRATION_PERIOD, + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_only, + model_spec=model_spec, + ) + + meridian.sample_posterior( + n_chains=self._N_CHAINS, + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + mock_sample_posterior.assert_called_with( + n_draws=self._N_BURNIN + self._N_KEEP, + joint_dist=mock.ANY, + n_chains=self._N_CHAINS, + num_adaptation_steps=self._N_ADAPT, + current_state=None, + init_step_size=None, + dual_averaging_kwargs=None, + max_tree_depth=10, + max_energy_diff=500.0, + unrolled_leapfrog_steps=1, + parallel_iterations=10, + seed=None, + ) + knots_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) + control_shape = (self._N_CHAINS, self._N_KEEP, self._N_CONTROLS) + media_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_MEDIA_CHANNELS) + sigma_shape = ( + (self._N_CHAINS, self._N_KEEP, self._N_GEOS) + if meridian.unique_sigma_for_each_geo + else (self._N_CHAINS, self._N_KEEP, 1) + ) + geo_shape = (self._N_CHAINS, self._N_KEEP, self._N_GEOS) + time_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) + geo_control_shape = geo_shape + (self._N_CONTROLS,) + geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) + + media_parameters = list(constants.MEDIA_PARAMETER_NAMES) + media_parameters.remove(constants.BETA_GM) 
+ + posterior = meridian.inference_data.posterior + shape_to_params = { + knots_shape: [ + getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS + ], + control_shape: [ + getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS + ], + media_channel_shape: [ + getattr(posterior, attr) for attr in media_parameters + ], + sigma_shape: [ + getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS + ], + geo_shape: [ + getattr(posterior, attr) for attr in constants.GEO_PARAMETERS + ], + time_shape: [ + getattr(posterior, attr) for attr in constants.TIME_PARAMETERS + ], + geo_control_shape: [ + getattr(posterior, attr) + for attr in constants.GEO_CONTROL_PARAMETERS + ], + geo_media_channel_shape: [ + getattr(posterior, attr) for attr in constants.GEO_MEDIA_PARAMETERS + ], + } + for shape, params in shape_to_params.items(): + for param in params: + self.assertEqual(param.shape, shape) + + for attr in [ + constants.STEP_SIZE, + constants.TARGET_LOG_PROBABILITY_ARVIZ, + constants.DIVERGING, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.sample_stats, attr).shape, + ( + self._N_CHAINS, + self._N_KEEP, + ), + ) + for attr in [ + constants.STEP_SIZE, + constants.TUNE, + constants.TARGET_LOG_PROBABILITY_TF, + constants.DIVERGING, + constants.ACCEPT_RATIO, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.trace, attr).shape, + ( + self._N_CHAINS, + self._N_KEEP, + ), + ) + + def test_sample_posterior_rf_only_returns_correct_shape(self): + mock_sample_posterior = self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + return_value=collections.namedtuple( + "StatesAndTrace", ["all_states", "trace"] + )( + all_states=self.test_posterior_states_rf_only, + trace=self.test_trace, + ), + ) + ) + model_spec = spec.ModelSpec( + rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, + ) + meridian = model.Meridian( + 
input_data=self.short_input_data_with_rf_only, + model_spec=model_spec, + ) + + meridian.sample_posterior( + n_chains=self._N_CHAINS, + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + mock_sample_posterior.assert_called_with( + n_draws=self._N_BURNIN + self._N_KEEP, + joint_dist=mock.ANY, + n_chains=self._N_CHAINS, + num_adaptation_steps=self._N_ADAPT, + current_state=None, + init_step_size=None, + dual_averaging_kwargs=None, + max_tree_depth=10, + max_energy_diff=500.0, + unrolled_leapfrog_steps=1, + parallel_iterations=10, + seed=None, + ) + knots_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) + control_shape = (self._N_CHAINS, self._N_KEEP, self._N_CONTROLS) + rf_channel_shape = (self._N_CHAINS, self._N_KEEP, self._N_RF_CHANNELS) + sigma_shape = ( + (self._N_CHAINS, self._N_KEEP, self._N_GEOS) + if meridian.unique_sigma_for_each_geo + else (self._N_CHAINS, self._N_KEEP, 1) + ) + geo_shape = (self._N_CHAINS, self._N_KEEP, self._N_GEOS) + time_shape = (self._N_CHAINS, self._N_KEEP, self._N_TIMES_SHORT) + geo_control_shape = geo_shape + (self._N_CONTROLS,) + geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) + + rf_parameters = list(constants.RF_PARAMETER_NAMES) + rf_parameters.remove(constants.BETA_GRF) + + posterior = meridian.inference_data.posterior + shape_to_params = { + knots_shape: [ + getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS + ], + control_shape: [ + getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS + ], + rf_channel_shape: [getattr(posterior, attr) for attr in rf_parameters], + sigma_shape: [ + getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS + ], + geo_shape: [ + getattr(posterior, attr) for attr in constants.GEO_PARAMETERS + ], + time_shape: [ + getattr(posterior, attr) for attr in constants.TIME_PARAMETERS + ], + geo_control_shape: [ + getattr(posterior, attr) + for attr in constants.GEO_CONTROL_PARAMETERS + ], + geo_rf_channel_shape: [ + 
getattr(posterior, attr) for attr in constants.GEO_RF_PARAMETERS + ], + } + for shape, params in shape_to_params.items(): + for param in params: + self.assertEqual(param.shape, shape) + + for attr in [ + constants.STEP_SIZE, + constants.TARGET_LOG_PROBABILITY_ARVIZ, + constants.DIVERGING, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.sample_stats, attr).shape, + ( + self._N_CHAINS, + self._N_KEEP, + ), + ) + for attr in [ + constants.STEP_SIZE, + constants.TUNE, + constants.TARGET_LOG_PROBABILITY_TF, + constants.DIVERGING, + constants.ACCEPT_RATIO, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.trace, attr).shape, + ( + self._N_CHAINS, + self._N_KEEP, + ), + ) + + def test_sample_posterior_media_and_rf_sequential_returns_correct_shape(self): + mock_sample_posterior = self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + return_value=collections.namedtuple( + "StatesAndTrace", ["all_states", "trace"] + )( + all_states=self.test_posterior_states_media_and_rf, + trace=self.test_trace, + ), + ) + ) + model_spec = spec.ModelSpec( + roi_calibration_period=self._ROI_CALIBRATION_PERIOD, + rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + + meridian.sample_posterior( + n_chains=[self._N_CHAINS, self._N_CHAINS], + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + mock_sample_posterior.assert_called_with( + n_draws=self._N_BURNIN + self._N_KEEP, + joint_dist=mock.ANY, + n_chains=self._N_CHAINS, + num_adaptation_steps=self._N_ADAPT, + current_state=None, + init_step_size=None, + dual_averaging_kwargs=None, + max_tree_depth=10, + max_energy_diff=500.0, + unrolled_leapfrog_steps=1, + parallel_iterations=10, + seed=None, + ) + n_total_chains = self._N_CHAINS * 2 + knots_shape = (n_total_chains, 
self._N_KEEP, self._N_TIMES_SHORT) + control_shape = (n_total_chains, self._N_KEEP, self._N_CONTROLS) + media_channel_shape = (n_total_chains, self._N_KEEP, self._N_MEDIA_CHANNELS) + rf_channel_shape = (n_total_chains, self._N_KEEP, self._N_RF_CHANNELS) + sigma_shape = ( + (n_total_chains, self._N_KEEP, self._N_GEOS) + if meridian.unique_sigma_for_each_geo + else (n_total_chains, self._N_KEEP, 1) + ) + geo_shape = (n_total_chains, self._N_KEEP, self._N_GEOS) + time_shape = (n_total_chains, self._N_KEEP, self._N_TIMES_SHORT) + geo_control_shape = geo_shape + (self._N_CONTROLS,) + geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) + geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) + + media_parameters = list(constants.MEDIA_PARAMETER_NAMES) + media_parameters.remove(constants.BETA_GM) + rf_parameters = list(constants.RF_PARAMETER_NAMES) + rf_parameters.remove(constants.BETA_GRF) + + posterior = meridian.inference_data.posterior + shape_to_params = { + knots_shape: [ + getattr(posterior, attr) for attr in constants.KNOTS_PARAMETERS + ], + control_shape: [ + getattr(posterior, attr) for attr in constants.CONTROL_PARAMETERS + ], + media_channel_shape: [ + getattr(posterior, attr) for attr in media_parameters + ], + rf_channel_shape: [getattr(posterior, attr) for attr in rf_parameters], + sigma_shape: [ + getattr(posterior, attr) for attr in constants.SIGMA_PARAMETERS + ], + geo_shape: [ + getattr(posterior, attr) for attr in constants.GEO_PARAMETERS + ], + time_shape: [ + getattr(posterior, attr) for attr in constants.TIME_PARAMETERS + ], + geo_control_shape: [ + getattr(posterior, attr) + for attr in constants.GEO_CONTROL_PARAMETERS + ], + geo_media_channel_shape: [ + getattr(posterior, attr) for attr in constants.GEO_MEDIA_PARAMETERS + ], + geo_rf_channel_shape: [ + getattr(posterior, attr) for attr in constants.GEO_RF_PARAMETERS + ], + } + for shape, params in shape_to_params.items(): + for param in params: + self.assertEqual(param.shape, shape) 
+ + for attr in [ + constants.STEP_SIZE, + constants.TARGET_LOG_PROBABILITY_ARVIZ, + constants.DIVERGING, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.sample_stats, attr).shape, + ( + n_total_chains, + self._N_KEEP, + ), + ) + for attr in [ + constants.STEP_SIZE, + constants.TUNE, + constants.TARGET_LOG_PROBABILITY_TF, + constants.DIVERGING, + constants.ACCEPT_RATIO, + constants.N_STEPS, + ]: + self.assertEqual( + getattr(meridian.inference_data.trace, attr).shape, + ( + n_total_chains, + self._N_KEEP, + ), + ) + + def test_sample_posterior_raises_oom_error_when_limits_exceeded(self): + self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + side_effect=tf.errors.ResourceExhaustedError( + None, None, "Resource exhausted" + ), + ) + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=spec.ModelSpec(), + ) + + with self.assertRaises(model.MCMCOOMError): + meridian.sample_posterior( + n_chains=self._N_CHAINS, + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + + def test_injected_sample_posterior_media_and_rf_returns_correct_shape(self): + """Checks validation passes with correct shapes.""" + self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + return_value=collections.namedtuple( + "StatesAndTrace", ["all_states", "trace"] + )( + all_states=self.test_posterior_states_media_and_rf, + trace=self.test_trace, + ), + ) + ) + model_spec = spec.ModelSpec( + roi_calibration_period=self._ROI_CALIBRATION_PERIOD, + rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + + meridian.sample_posterior( + n_chains=self._N_CHAINS, + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + inference_data = 
meridian.inference_data + meridian_with_inference_data = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + inference_data=inference_data, + ) + + self.assertEqual( + meridian_with_inference_data.inference_data, inference_data + ) + + def test_injected_sample_posterior_media_only_returns_correct_shape(self): + """Checks validation passes with correct shapes.""" + self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + return_value=collections.namedtuple( + "StatesAndTrace", ["all_states", "trace"] + )( + all_states=self.test_posterior_states_media_only, + trace=self.test_trace, + ), + ) + ) + model_spec = spec.ModelSpec( + roi_calibration_period=self._ROI_CALIBRATION_PERIOD, + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_only, + model_spec=model_spec, + ) + + meridian.sample_posterior( + n_chains=self._N_CHAINS, + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + inference_data = meridian.inference_data + meridian_with_inference_data = model.Meridian( + input_data=self.short_input_data_with_media_only, + model_spec=model_spec, + inference_data=inference_data, + ) + + self.assertEqual( + meridian_with_inference_data.inference_data, inference_data + ) + + @parameterized.named_parameters( + dict( + testcase_name="control_variables", + coord=constants.CONTROL_VARIABLE, + mismatched_posteriors={ + constants.GAMMA_C: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_CONTROLS + 1, + ), + constants.GAMMA_GC: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_GEOS, + IDS._N_CONTROLS + 1, + ), + constants.XI_C: (IDS._N_CHAINS, IDS._N_KEEP, IDS._N_CONTROLS + 1), + }, + mismatched_coord_size=IDS._N_CONTROLS + 1, + expected_coord_size=IDS._N_CONTROLS, + ), + dict( + testcase_name="geos", + coord=constants.GEO, + mismatched_posteriors={ + constants.BETA_GM: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_GEOS + 1, + 
IDS._N_MEDIA_CHANNELS, + ), + constants.BETA_GRF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_GEOS + 1, + IDS._N_RF_CHANNELS, + ), + constants.GAMMA_GC: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_GEOS + 1, + IDS._N_CONTROLS, + ), + constants.TAU_G: (IDS._N_CHAINS, IDS._N_KEEP, IDS._N_GEOS + 1), + }, + mismatched_coord_size=IDS._N_GEOS + 1, + expected_coord_size=IDS._N_GEOS, + ), + dict( + testcase_name="knots", + coord=constants.KNOTS, + mismatched_posteriors={ + constants.KNOT_VALUES: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_TIMES_SHORT + 1, + ), + }, + mismatched_coord_size=IDS._N_TIMES_SHORT + 1, + expected_coord_size=IDS._N_TIMES_SHORT, + ), + dict( + testcase_name="times", + coord=constants.TIME, + mismatched_posteriors={ + constants.MU_T: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_TIMES_SHORT + 1, + ), + }, + mismatched_coord_size=IDS._N_TIMES_SHORT + 1, + expected_coord_size=IDS._N_TIMES_SHORT, + ), + dict( + testcase_name="sigma_dims", + coord=constants.SIGMA_DIM, + mismatched_posteriors={ + constants.SIGMA: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_GEOS_NATIONAL + 1, + ), + }, + mismatched_coord_size=IDS._N_GEOS_NATIONAL + 1, + expected_coord_size=IDS._N_GEOS_NATIONAL, + ), + dict( + testcase_name="media_channels", + coord=constants.MEDIA_CHANNEL, + mismatched_posteriors={ + constants.ALPHA_M: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_MEDIA_CHANNELS + 1, + ), + constants.BETA_GM: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_GEOS, + IDS._N_MEDIA_CHANNELS + 1, + ), + constants.BETA_M: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_MEDIA_CHANNELS + 1, + ), + constants.EC_M: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_MEDIA_CHANNELS + 1, + ), + constants.ETA_M: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_MEDIA_CHANNELS + 1, + ), + constants.ROI_M: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_MEDIA_CHANNELS + 1, + ), + constants.SLOPE_M: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_MEDIA_CHANNELS + 1, + ), + }, + mismatched_coord_size=IDS._N_MEDIA_CHANNELS + 1, + 
expected_coord_size=IDS._N_MEDIA_CHANNELS, + ), + dict( + testcase_name="rf_channels", + coord=constants.RF_CHANNEL, + mismatched_posteriors={ + constants.ALPHA_RF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_RF_CHANNELS + 1, + ), + constants.BETA_GRF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_GEOS, + IDS._N_RF_CHANNELS + 1, + ), + constants.BETA_RF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_RF_CHANNELS + 1, + ), + constants.EC_RF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_RF_CHANNELS + 1, + ), + constants.ETA_RF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_RF_CHANNELS + 1, + ), + constants.ROI_RF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_RF_CHANNELS + 1, + ), + constants.SLOPE_RF: ( + IDS._N_CHAINS, + IDS._N_KEEP, + IDS._N_RF_CHANNELS + 1, + ), + }, + mismatched_coord_size=IDS._N_RF_CHANNELS + 1, + expected_coord_size=IDS._N_RF_CHANNELS, + ), + ) + def test_validate_injected_inference_data_posterior_incorrect_coordinates( + self, + coord, + mismatched_posteriors, + mismatched_coord_size, + expected_coord_size, + ): + """Checks posterior validation fails with incorrect coordinates.""" + self.enter_context( + mock.patch.object( + posterior_sampler, + "_xla_windowed_adaptive_nuts", + autospec=True, + return_value=collections.namedtuple( + "StatesAndTrace", ["all_states", "trace"] + )( + all_states=self.test_posterior_states_media_and_rf, + trace=self.test_trace, + ), + ) + ) + model_spec = spec.ModelSpec( + roi_calibration_period=self._ROI_CALIBRATION_PERIOD, + rf_roi_calibration_period=self._RF_ROI_CALIBRATION_PERIOD, + ) + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + + meridian.sample_posterior( + n_chains=self._N_CHAINS, + n_adapt=self._N_ADAPT, + n_burnin=self._N_BURNIN, + n_keep=self._N_KEEP, + ) + + posterior_coords = meridian.create_inference_data_coords( + self._N_CHAINS, self._N_KEEP + ) + posterior_dims = meridian.create_inference_data_dims() + posterior_samples = 
dict(meridian.inference_data.posterior) + for posterior in mismatched_posteriors: + posterior_samples[posterior] = tf.zeros(mismatched_posteriors[posterior]) + + posterior_coords = dict(posterior_coords) + posterior_coords[coord] = np.arange(mismatched_coord_size) + + inference_data = az.convert_to_inference_data( + posterior_samples, + coords=posterior_coords, + dims=posterior_dims, + group=constants.POSTERIOR, + ) + + with self.assertRaisesRegex( + ValueError, + f"Injected inference data {constants.POSTERIOR} has incorrect" + f" coordinate '{coord}': expected" + f" {expected_coord_size}, got {mismatched_coord_size}", + ): + _ = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + inference_data=inference_data, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/meridian/model/prior_sampler.py b/meridian/model/prior_sampler.py new file mode 100644 index 00000000..4c8d352b --- /dev/null +++ b/meridian/model/prior_sampler.py @@ -0,0 +1,629 @@ +# Copyright 2024 The Meridian Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Module for sampling prior distributions in a Meridian model.""" + +from collections.abc import Mapping + +import arviz as az +from meridian import constants +import tensorflow as tf +import tensorflow_probability as tfp + + +__all__ = [ + "PriorSampler", +] + + +def _get_tau_g( + tau_g_excl_baseline: tf.Tensor, baseline_geo_idx: int +) -> tfp.distributions.Distribution: + """Computes `tau_g` from `tau_g_excl_baseline`. + + This function computes `tau_g` by inserting a column of zeros at the + `baseline_geo` position in `tau_g_excl_baseline`. + + Args: + tau_g_excl_baseline: A tensor of shape `[..., n_geos - 1]` for the + user-defined dimensions of the `tau_g` parameter distribution. + baseline_geo_idx: The index of the baseline geo to be set to zero. + + Returns: + A tensor of shape `[..., n_geos]` with the final distribution of the `tau_g` + parameter with zero at position `baseline_geo_idx` and matching + `tau_g_excl_baseline` elsewhere. + """ + rank = len(tau_g_excl_baseline.shape) + shape = tau_g_excl_baseline.shape[:-1] + [1] if rank != 1 else 1 + tau_g = tf.concat( + [ + tau_g_excl_baseline[..., :baseline_geo_idx], + tf.zeros(shape, dtype=tau_g_excl_baseline.dtype), + tau_g_excl_baseline[..., baseline_geo_idx:], + ], + axis=rank - 1, + ) + return tfp.distributions.Deterministic(tau_g, name="tau_g") + + +class PriorSampler: + """Samples from prior distributions in a Meridian model.""" + + def __init__(self, meridian): # meridian: model.Meridian + self._meridian = meridian + + def get_roi_prior_beta_m_value( + self, + alpha_m: tf.Tensor, + beta_gm_dev: tf.Tensor, + ec_m: tf.Tensor, + eta_m: tf.Tensor, + roi_or_mroi_m: tf.Tensor, + slope_m: tf.Tensor, + media_transformed: tf.Tensor, + ) -> tf.Tensor: + """Returns a tensor to be used in `beta_m`.""" + mmm = self._meridian + + # The `roi_or_mroi_m` parameter represents either ROI or mROI. For reach & + # frequency channels, marginal ROI priors are defined as "mROI by reach", + # which is equivalent to ROI. 
+ media_spend = mmm.media_tensors.media_spend + media_spend_counterfactual = mmm.media_tensors.media_spend_counterfactual + media_counterfactual_scaled = mmm.media_tensors.media_counterfactual_scaled + # If we got here, then we should already have media tensors derived from + # non-None InputData.media data. + assert media_spend is not None + assert media_spend_counterfactual is not None + assert media_counterfactual_scaled is not None + + # Use absolute value here because this difference will be negative for + # marginal ROI priors. + inc_revenue_m = roi_or_mroi_m * tf.reduce_sum( + tf.abs(media_spend - media_spend_counterfactual), + range(media_spend.ndim - 1), + ) + + if ( + mmm.model_spec.roi_calibration_period is None + and mmm.model_spec.paid_media_prior_type + == constants.PAID_MEDIA_PRIOR_TYPE_ROI + ): + # We can skip the adstock/hill computation step in this case. + media_counterfactual_transformed = tf.zeros_like(media_transformed) + else: + media_counterfactual_transformed = mmm.adstock_hill_media( + media=media_counterfactual_scaled, + alpha=alpha_m, + ec=ec_m, + slope=slope_m, + ) + + revenue_per_kpi = mmm.revenue_per_kpi + if mmm.input_data.revenue_per_kpi is None: + revenue_per_kpi = tf.ones([mmm.n_geos, mmm.n_times], dtype=tf.float32) + # Note: use absolute value here because this difference will be negative for + # marginal ROI priors. + media_contrib_gm = tf.einsum( + "...gtm,g,,gt->...gm", + tf.abs(media_transformed - media_counterfactual_transformed), + mmm.population, + mmm.kpi_transformer.population_scaled_stdev, + revenue_per_kpi, + ) + + if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL: + media_contrib_m = tf.einsum("...gm->...m", media_contrib_gm) + random_effect_m = tf.einsum( + "...m,...gm,...gm->...m", eta_m, beta_gm_dev, media_contrib_gm + ) + return (inc_revenue_m - random_effect_m) / media_contrib_m + else: + # For log_normal, beta_m and eta_m are not mean & std. 
+ # The parameterization is beta_gm ~ exp(beta_m + eta_m * N(0, 1)). + random_effect_m = tf.einsum( + "...gm,...gm->...m", + tf.math.exp(beta_gm_dev * eta_m[..., tf.newaxis, :]), + media_contrib_gm, + ) + return tf.math.log(inc_revenue_m) - tf.math.log(random_effect_m) + + def get_roi_prior_beta_rf_value( + self, + alpha_rf: tf.Tensor, + beta_grf_dev: tf.Tensor, + ec_rf: tf.Tensor, + eta_rf: tf.Tensor, + roi_or_mroi_rf: tf.Tensor, + slope_rf: tf.Tensor, + rf_transformed: tf.Tensor, + ) -> tf.Tensor: + """Returns a tensor to be used in `beta_rf`.""" + mmm = self._meridian + + rf_spend = mmm.rf_tensors.rf_spend + rf_spend_counterfactual = mmm.rf_tensors.rf_spend_counterfactual + reach_counterfactual_scaled = mmm.rf_tensors.reach_counterfactual_scaled + frequency = mmm.rf_tensors.frequency + # If we got here, then we should already have RF media tensors derived from + # non-None InputData.reach data. + assert rf_spend is not None + assert rf_spend_counterfactual is not None + assert reach_counterfactual_scaled is not None + assert frequency is not None + + inc_revenue_rf = roi_or_mroi_rf * tf.reduce_sum( + rf_spend - rf_spend_counterfactual, + range(rf_spend.ndim - 1), + ) + if mmm.model_spec.rf_roi_calibration_period is not None: + rf_counterfactual_transformed = mmm.adstock_hill_rf( + reach=reach_counterfactual_scaled, + frequency=frequency, + alpha=alpha_rf, + ec=ec_rf, + slope=slope_rf, + ) + else: + rf_counterfactual_transformed = tf.zeros_like(rf_transformed) + revenue_per_kpi = mmm.revenue_per_kpi + if mmm.input_data.revenue_per_kpi is None: + revenue_per_kpi = tf.ones([mmm.n_geos, mmm.n_times], dtype=tf.float32) + + media_contrib_grf = tf.einsum( + "...gtm,g,,gt->...gm", + rf_transformed - rf_counterfactual_transformed, + mmm.population, + mmm.kpi_transformer.population_scaled_stdev, + revenue_per_kpi, + ) + if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL: + media_contrib_rf = tf.einsum("...gm->...m", media_contrib_grf) + random_effect_rf = 
tf.einsum( + "...m,...gm,...gm->...m", eta_rf, beta_grf_dev, media_contrib_grf + ) + return (inc_revenue_rf - random_effect_rf) / media_contrib_rf + else: + # For log_normal, beta_rf and eta_rf are not mean & std. + # The parameterization is beta_grf ~ exp(beta_rf + eta_rf * N(0, 1)). + random_effect_rf = tf.einsum( + "...gm,...gm->...m", + tf.math.exp(beta_grf_dev * eta_rf[..., tf.newaxis, :]), + media_contrib_grf, + ) + return tf.math.log(inc_revenue_rf) - tf.math.log(random_effect_rf) + + def _sample_media_priors( + self, + n_draws: int, + seed: int | None = None, + ) -> Mapping[str, tf.Tensor]: + """Draws samples from the prior distributions of the media variables. + + Args: + n_draws: Number of samples drawn from the prior distribution. + seed: Used to set the seed for reproducible results. For more information, + see [PRNGS and seeds] + (https://github.com/tensorflow/probability/blob/main/PRNGS.md). + + Returns: + A mapping of media parameter names to a tensor of shape `[n_draws, n_geos, + n_media_channels]` or `[n_draws, n_media_channels]` containing the + samples. 
+ """ + mmm = self._meridian + + prior = mmm.prior_broadcast + sample_shape = [1, n_draws] + sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} + media_vars = { + constants.ALPHA_M: prior.alpha_m.sample(**sample_kwargs), + constants.EC_M: prior.ec_m.sample(**sample_kwargs), + constants.ETA_M: prior.eta_m.sample(**sample_kwargs), + constants.SLOPE_M: prior.slope_m.sample(**sample_kwargs), + } + beta_gm_dev = tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [mmm.n_geos, mmm.n_media_channels], + name=constants.BETA_GM_DEV, + ).sample(**sample_kwargs) + media_transformed = mmm.adstock_hill_media( + media=mmm.media_tensors.media_scaled, + alpha=media_vars[constants.ALPHA_M], + ec=media_vars[constants.EC_M], + slope=media_vars[constants.SLOPE_M], + ) + + prior_type = mmm.model_spec.paid_media_prior_type + if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: + roi_m = prior.roi_m.sample(**sample_kwargs) + beta_m_value = self.get_roi_prior_beta_m_value( + beta_gm_dev=beta_gm_dev, + media_transformed=media_transformed, + roi_or_mroi_m=roi_m, + **media_vars, + ) + media_vars[constants.ROI_M] = roi_m + media_vars[constants.BETA_M] = tfp.distributions.Deterministic( + beta_m_value, name=constants.BETA_M + ).sample() + elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: + mroi_m = prior.mroi_m.sample(**sample_kwargs) + beta_m_value = self.get_roi_prior_beta_m_value( + beta_gm_dev=beta_gm_dev, + media_transformed=media_transformed, + roi_or_mroi_m=mroi_m, + **media_vars, + ) + media_vars[constants.MROI_M] = mroi_m + media_vars[constants.BETA_M] = tfp.distributions.Deterministic( + beta_m_value, name=constants.BETA_M + ).sample() + else: + media_vars[constants.BETA_M] = prior.beta_m.sample(**sample_kwargs) + + beta_eta_combined = ( + media_vars[constants.BETA_M][..., tf.newaxis, :] + + media_vars[constants.ETA_M][..., tf.newaxis, :] * beta_gm_dev + ) + beta_gm_value = ( + beta_eta_combined + if mmm.media_effects_dist == 
constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + media_vars[constants.BETA_GM] = tfp.distributions.Deterministic( + beta_gm_value, name=constants.BETA_GM + ).sample() + + return media_vars + + def _sample_rf_priors( + self, + n_draws: int, + seed: int | None = None, + ) -> Mapping[str, tf.Tensor]: + """Draws samples from the prior distributions of the RF variables. + + Args: + n_draws: Number of samples drawn from the prior distribution. + seed: Used to set the seed for reproducible results. For more information, + see [PRNGS and seeds] + (https://github.com/tensorflow/probability/blob/main/PRNGS.md). + + Returns: + A mapping of RF parameter names to a tensor of shape `[n_draws, n_geos, + n_rf_channels]` or `[n_draws, n_rf_channels]` containing the samples. + """ + mmm = self._meridian + + prior = mmm.prior_broadcast + sample_shape = [1, n_draws] + sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} + rf_vars = { + constants.ALPHA_RF: prior.alpha_rf.sample(**sample_kwargs), + constants.EC_RF: prior.ec_rf.sample(**sample_kwargs), + constants.ETA_RF: prior.eta_rf.sample(**sample_kwargs), + constants.SLOPE_RF: prior.slope_rf.sample(**sample_kwargs), + } + beta_grf_dev = tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [mmm.n_geos, mmm.n_rf_channels], + name=constants.BETA_GRF_DEV, + ).sample(**sample_kwargs) + rf_transformed = mmm.adstock_hill_rf( + reach=mmm.rf_tensors.reach_scaled, + frequency=mmm.rf_tensors.frequency, + alpha=rf_vars[constants.ALPHA_RF], + ec=rf_vars[constants.EC_RF], + slope=rf_vars[constants.SLOPE_RF], + ) + + prior_type = mmm.model_spec.paid_media_prior_type + if prior_type == constants.PAID_MEDIA_PRIOR_TYPE_ROI: + roi_rf = prior.roi_rf.sample(**sample_kwargs) + beta_rf_value = self.get_roi_prior_beta_rf_value( + beta_grf_dev=beta_grf_dev, + rf_transformed=rf_transformed, + roi_or_mroi_rf=roi_rf, + **rf_vars, + ) + rf_vars[constants.ROI_RF] = roi_rf + rf_vars[constants.BETA_RF] = 
tfp.distributions.Deterministic( + beta_rf_value, + name=constants.BETA_RF, + ).sample() + elif prior_type == constants.PAID_MEDIA_PRIOR_TYPE_MROI: + mroi_rf = prior.mroi_rf.sample(**sample_kwargs) + beta_rf_value = self.get_roi_prior_beta_rf_value( + beta_grf_dev=beta_grf_dev, + rf_transformed=rf_transformed, + roi_or_mroi_rf=mroi_rf, + **rf_vars, + ) + rf_vars[constants.MROI_RF] = mroi_rf + rf_vars[constants.BETA_RF] = tfp.distributions.Deterministic( + beta_rf_value, + name=constants.BETA_RF, + ).sample() + else: + rf_vars[constants.BETA_RF] = prior.beta_rf.sample(**sample_kwargs) + + beta_eta_combined = ( + rf_vars[constants.BETA_RF][..., tf.newaxis, :] + + rf_vars[constants.ETA_RF][..., tf.newaxis, :] * beta_grf_dev + ) + beta_grf_value = ( + beta_eta_combined + if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + rf_vars[constants.BETA_GRF] = tfp.distributions.Deterministic( + beta_grf_value, name=constants.BETA_GRF + ).sample() + + return rf_vars + + def _sample_organic_media_priors( + self, + n_draws: int, + seed: int | None = None, + ) -> Mapping[str, tf.Tensor]: + """Draws samples from the prior distributions of organic media variables. + + Args: + n_draws: Number of samples drawn from the prior distribution. + seed: Used to set the seed for reproducible results. For more information, + see [PRNGS and seeds] + (https://github.com/tensorflow/probability/blob/main/PRNGS.md). + + Returns: + A mapping of organic media parameter names to a tensor of shape [n_draws, + n_geos, n_organic_media_channels] or [n_draws, n_organic_media_channels] + containing the samples. 
+ """ + mmm = self._meridian + + prior = mmm.prior_broadcast + sample_shape = [1, n_draws] + sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} + organic_media_vars = { + constants.ALPHA_OM: prior.alpha_om.sample(**sample_kwargs), + constants.EC_OM: prior.ec_om.sample(**sample_kwargs), + constants.ETA_OM: prior.eta_om.sample(**sample_kwargs), + constants.SLOPE_OM: prior.slope_om.sample(**sample_kwargs), + } + beta_gom_dev = tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [mmm.n_geos, mmm.n_organic_media_channels], + name=constants.BETA_GOM_DEV, + ).sample(**sample_kwargs) + + organic_media_vars[constants.BETA_OM] = prior.beta_om.sample( + **sample_kwargs + ) + + beta_eta_combined = ( + organic_media_vars[constants.BETA_OM][..., tf.newaxis, :] + + organic_media_vars[constants.ETA_OM][..., tf.newaxis, :] + * beta_gom_dev + ) + beta_gom_value = ( + beta_eta_combined + if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + organic_media_vars[constants.BETA_GOM] = tfp.distributions.Deterministic( + beta_gom_value, name=constants.BETA_GOM + ).sample() + + return organic_media_vars + + def _sample_organic_rf_priors( + self, + n_draws: int, + seed: int | None = None, + ) -> Mapping[str, tf.Tensor]: + """Draws samples from the prior distributions of the organic RF variables. + + Args: + n_draws: Number of samples drawn from the prior distribution. + seed: Used to set the seed for reproducible results. For more information, + see [PRNGS and seeds] + (https://github.com/tensorflow/probability/blob/main/PRNGS.md). + + Returns: + A mapping of organic RF parameter names to a tensor of shape [n_draws, + n_geos, n_organic_rf_channels] or [n_draws, n_organic_rf_channels] + containing the samples. 
+ """ + mmm = self._meridian + + prior = mmm.prior_broadcast + sample_shape = [1, n_draws] + sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} + organic_rf_vars = { + constants.ALPHA_ORF: prior.alpha_orf.sample(**sample_kwargs), + constants.EC_ORF: prior.ec_orf.sample(**sample_kwargs), + constants.ETA_ORF: prior.eta_orf.sample(**sample_kwargs), + constants.SLOPE_ORF: prior.slope_orf.sample(**sample_kwargs), + } + beta_gorf_dev = tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [mmm.n_geos, mmm.n_organic_rf_channels], + name=constants.BETA_GORF_DEV, + ).sample(**sample_kwargs) + + organic_rf_vars[constants.BETA_ORF] = prior.beta_orf.sample(**sample_kwargs) + + beta_eta_combined = ( + organic_rf_vars[constants.BETA_ORF][..., tf.newaxis, :] + + organic_rf_vars[constants.ETA_ORF][..., tf.newaxis, :] * beta_gorf_dev + ) + beta_gorf_value = ( + beta_eta_combined + if mmm.media_effects_dist == constants.MEDIA_EFFECTS_NORMAL + else tf.math.exp(beta_eta_combined) + ) + organic_rf_vars[constants.BETA_GORF] = tfp.distributions.Deterministic( + beta_gorf_value, name=constants.BETA_GORF + ).sample() + + return organic_rf_vars + + def _sample_non_media_treatments_priors( + self, + n_draws: int, + seed: int | None = None, + ) -> Mapping[str, tf.Tensor]: + """Draws from the prior distributions of the non-media treatment variables. + + Args: + n_draws: Number of samples drawn from the prior distribution. + seed: Used to set the seed for reproducible results. For more information, + see [PRNGS and seeds] + (https://github.com/tensorflow/probability/blob/main/PRNGS.md). + + Returns: + A mapping of non-media treatment parameter names to a tensor of shape + [n_draws, + n_geos, n_non_media_channels] or [n_draws, n_non_media_channels] + containing the samples. 
+ """ + mmm = self._meridian + + prior = mmm.prior_broadcast + sample_shape = [1, n_draws] + sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} + non_media_treatments_vars = { + constants.GAMMA_N: prior.gamma_n.sample(**sample_kwargs), + constants.XI_N: prior.xi_n.sample(**sample_kwargs), + } + gamma_gn_dev = tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [mmm.n_geos, mmm.n_non_media_channels], + name=constants.GAMMA_GN_DEV, + ).sample(**sample_kwargs) + non_media_treatments_vars[constants.GAMMA_GN] = ( + tfp.distributions.Deterministic( + non_media_treatments_vars[constants.GAMMA_N][..., tf.newaxis, :] + + non_media_treatments_vars[constants.XI_N][..., tf.newaxis, :] + * gamma_gn_dev, + name=constants.GAMMA_GN, + ).sample() + ) + return non_media_treatments_vars + + def _sample_prior( + self, + n_draws: int, + seed: int | None = None, + ) -> Mapping[str, tf.Tensor]: + """Returns a mapping of prior parameters to tensors of the samples.""" + mmm = self._meridian + + # For stateful sampling, the random seed must be set to ensure that any + # random numbers that are generated are deterministic. 
+ if seed is not None: + tf.keras.utils.set_random_seed(1) + + prior = mmm.prior_broadcast + sample_shape = [1, n_draws] + sample_kwargs = {constants.SAMPLE_SHAPE: sample_shape, constants.SEED: seed} + + tau_g_excl_baseline = prior.tau_g_excl_baseline.sample(**sample_kwargs) + base_vars = { + constants.KNOT_VALUES: prior.knot_values.sample(**sample_kwargs), + constants.GAMMA_C: prior.gamma_c.sample(**sample_kwargs), + constants.XI_C: prior.xi_c.sample(**sample_kwargs), + constants.SIGMA: prior.sigma.sample(**sample_kwargs), + constants.TAU_G: _get_tau_g( + tau_g_excl_baseline=tau_g_excl_baseline, + baseline_geo_idx=mmm.baseline_geo_idx, + ).sample(), + } + base_vars[constants.MU_T] = tfp.distributions.Deterministic( + tf.einsum( + "...k,kt->...t", + base_vars[constants.KNOT_VALUES], + tf.convert_to_tensor(mmm.knot_info.weights), + ), + name=constants.MU_T, + ).sample() + + gamma_gc_dev = tfp.distributions.Sample( + tfp.distributions.Normal(0, 1), + [mmm.n_geos, mmm.n_controls], + name=constants.GAMMA_GC_DEV, + ).sample(**sample_kwargs) + base_vars[constants.GAMMA_GC] = tfp.distributions.Deterministic( + base_vars[constants.GAMMA_C][..., tf.newaxis, :] + + base_vars[constants.XI_C][..., tf.newaxis, :] * gamma_gc_dev, + name=constants.GAMMA_GC, + ).sample() + + media_vars = ( + self._sample_media_priors(n_draws, seed) + if mmm.media_tensors.media is not None + else {} + ) + rf_vars = ( + self._sample_rf_priors(n_draws, seed) + if mmm.rf_tensors.reach is not None + else {} + ) + organic_media_vars = ( + self._sample_organic_media_priors(n_draws, seed) + if mmm.organic_media_tensors.organic_media is not None + else {} + ) + organic_rf_vars = ( + self._sample_organic_rf_priors(n_draws, seed) + if mmm.organic_rf_tensors.organic_reach is not None + else {} + ) + non_media_treatments_vars = ( + self._sample_non_media_treatments_priors(n_draws, seed) + if mmm.non_media_treatments_scaled is not None + else {} + ) + + return ( + base_vars + | media_vars + | rf_vars + | 
organic_media_vars + | organic_rf_vars + | non_media_treatments_vars + ) + + def __call__(self, n_draws: int, seed: int | None = None) -> az.InferenceData: + """Draws samples from prior distributions. + + Args: + n_draws: Number of samples drawn from the prior distribution. + seed: Used to set the seed for reproducible results. For more information, + see [PRNGS and seeds] + (https://github.com/tensorflow/probability/blob/main/PRNGS.md). + + Returns: + An Arviz `InferenceData` object containing prior samples only. + """ + prior_draws = self._sample_prior(n_draws, seed=seed) + # Create Arviz InferenceData for prior draws. + prior_coords = self._meridian.create_inference_data_coords(1, n_draws) + prior_dims = self._meridian.create_inference_data_dims() + return az.convert_to_inference_data( + prior_draws, coords=prior_coords, dims=prior_dims, group=constants.PRIOR + ) diff --git a/meridian/model/prior_sampler_test.py b/meridian/model/prior_sampler_test.py new file mode 100644 index 00000000..fe1d7054 --- /dev/null +++ b/meridian/model/prior_sampler_test.py @@ -0,0 +1,502 @@ +# Copyright 2024 The Meridian Authors. + +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import arviz as az +from meridian import constants +from meridian.model import model +from meridian.model import model_test_data +from meridian.model import prior_sampler +from meridian.model import spec +import numpy as np +import tensorflow as tf + + +class PriorSamplerTest( + tf.test.TestCase, + parameterized.TestCase, + model_test_data.WithInputDataSamples, +): + + IDS = model_test_data.WithInputDataSamples + + def setUp(self): + super().setUp() + model_test_data.WithInputDataSamples.setUp(self) + + def test_sample_prior_seed_same_seed(self): + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + meridian.sample_prior(n_draws=self._N_DRAWS, seed=1) + meridian2 = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + meridian2.sample_prior(n_draws=self._N_DRAWS, seed=1) + self.assertEqual( + meridian.inference_data.prior, meridian2.inference_data.prior + ) + + def test_sample_prior_different_seed(self): + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + meridian.sample_prior(n_draws=self._N_DRAWS, seed=1) + meridian2 = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + meridian2.sample_prior(n_draws=self._N_DRAWS, seed=2) + + self.assertNotEqual( + meridian.inference_data.prior, meridian2.inference_data.prior + ) + + def test_sample_prior_media_and_rf_returns_correct_shape(self): + self.enter_context( + mock.patch.object( + prior_sampler.PriorSampler, + "_sample_prior", + autospec=True, + return_value=self.test_dist_media_and_rf, + ) + ) + + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) 
+ meridian.sample_prior(n_draws=self._N_DRAWS) + knots_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) + control_shape = (1, self._N_DRAWS, self._N_CONTROLS) + media_channel_shape = (1, self._N_DRAWS, self._N_MEDIA_CHANNELS) + rf_channel_shape = (1, self._N_DRAWS, self._N_RF_CHANNELS) + sigma_shape = ( + (1, self._N_DRAWS, self._N_GEOS) + if meridian.unique_sigma_for_each_geo + else (1, self._N_DRAWS, 1) + ) + geo_shape = (1, self._N_DRAWS, self._N_GEOS) + time_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) + geo_control_shape = geo_shape + (self._N_CONTROLS,) + geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) + geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) + + media_parameters = list(constants.MEDIA_PARAMETER_NAMES) + media_parameters.remove(constants.BETA_GM) + rf_parameters = list(constants.RF_PARAMETER_NAMES) + rf_parameters.remove(constants.BETA_GRF) + + prior = meridian.inference_data.prior + shape_to_params = { + knots_shape: [ + getattr(prior, attr) for attr in constants.KNOTS_PARAMETERS + ], + media_channel_shape: [ + getattr(prior, attr) for attr in media_parameters + ], + rf_channel_shape: [getattr(prior, attr) for attr in rf_parameters], + control_shape: [ + getattr(prior, attr) for attr in constants.CONTROL_PARAMETERS + ], + sigma_shape: [ + getattr(prior, attr) for attr in constants.SIGMA_PARAMETERS + ], + geo_shape: [getattr(prior, attr) for attr in constants.GEO_PARAMETERS], + time_shape: [ + getattr(prior, attr) for attr in constants.TIME_PARAMETERS + ], + geo_control_shape: [ + getattr(prior, attr) for attr in constants.GEO_CONTROL_PARAMETERS + ], + geo_media_channel_shape: [ + getattr(prior, attr) for attr in constants.GEO_MEDIA_PARAMETERS + ], + geo_rf_channel_shape: [ + getattr(prior, attr) for attr in constants.GEO_RF_PARAMETERS + ], + } + for shape, params in shape_to_params.items(): + for param in params: + self.assertEqual(param.shape, shape) + + def test_sample_prior_media_only_returns_correct_shape(self): + 
self.enter_context( + mock.patch.object( + prior_sampler.PriorSampler, + "_sample_prior", + autospec=True, + return_value=self.test_dist_media_only, + ) + ) + + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_media_only, + model_spec=model_spec, + ) + meridian.sample_prior(n_draws=self._N_DRAWS) + knots_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) + control_shape = (1, self._N_DRAWS, self._N_CONTROLS) + media_channel_shape = (1, self._N_DRAWS, self._N_MEDIA_CHANNELS) + sigma_shape = ( + (1, self._N_DRAWS, self._N_GEOS) + if meridian.unique_sigma_for_each_geo + else (1, self._N_DRAWS, 1) + ) + geo_shape = (1, self._N_DRAWS, self._N_GEOS) + time_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) + geo_control_shape = geo_shape + (self._N_CONTROLS,) + geo_media_channel_shape = geo_shape + (self._N_MEDIA_CHANNELS,) + + media_parameters = list(constants.MEDIA_PARAMETER_NAMES) + media_parameters.remove(constants.BETA_GM) + + prior = meridian.inference_data.prior + shape_to_params = { + knots_shape: [ + getattr(prior, attr) for attr in constants.KNOTS_PARAMETERS + ], + media_channel_shape: [ + getattr(prior, attr) for attr in media_parameters + ], + control_shape: [ + getattr(prior, attr) for attr in constants.CONTROL_PARAMETERS + ], + sigma_shape: [ + getattr(prior, attr) for attr in constants.SIGMA_PARAMETERS + ], + geo_shape: [getattr(prior, attr) for attr in constants.GEO_PARAMETERS], + time_shape: [ + getattr(prior, attr) for attr in constants.TIME_PARAMETERS + ], + geo_control_shape: [ + getattr(prior, attr) for attr in constants.GEO_CONTROL_PARAMETERS + ], + geo_media_channel_shape: [ + getattr(prior, attr) for attr in constants.GEO_MEDIA_PARAMETERS + ], + } + for shape, params in shape_to_params.items(): + for param in params: + self.assertEqual(param.shape, shape) + + def test_sample_prior_rf_only_returns_correct_shape(self): + self.enter_context( + mock.patch.object( + prior_sampler.PriorSampler, + 
"_sample_prior", + autospec=True, + return_value=self.test_dist_rf_only, + ) + ) + + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_rf_only, + model_spec=model_spec, + ) + meridian.sample_prior(n_draws=self._N_DRAWS) + knots_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) + control_shape = (1, self._N_DRAWS, self._N_CONTROLS) + rf_channel_shape = (1, self._N_DRAWS, self._N_RF_CHANNELS) + sigma_shape = ( + (1, self._N_DRAWS, self._N_GEOS) + if meridian.unique_sigma_for_each_geo + else (1, self._N_DRAWS, 1) + ) + geo_shape = (1, self._N_DRAWS, self._N_GEOS) + time_shape = (1, self._N_DRAWS, self._N_TIMES_SHORT) + geo_control_shape = geo_shape + (self._N_CONTROLS,) + geo_rf_channel_shape = geo_shape + (self._N_RF_CHANNELS,) + + prior = meridian.inference_data.prior + shape_to_params = { + knots_shape: [ + getattr(prior, attr) for attr in constants.KNOTS_PARAMETERS + ], + rf_channel_shape: [ + getattr(prior, attr) for attr in constants.RF_PARAMETER_NAMES + ], + control_shape: [ + getattr(prior, attr) for attr in constants.CONTROL_PARAMETERS + ], + sigma_shape: [ + getattr(prior, attr) for attr in constants.SIGMA_PARAMETERS + ], + geo_shape: [getattr(prior, attr) for attr in constants.GEO_PARAMETERS], + time_shape: [ + getattr(prior, attr) for attr in constants.TIME_PARAMETERS + ], + geo_control_shape: [ + getattr(prior, attr) for attr in constants.GEO_CONTROL_PARAMETERS + ], + geo_rf_channel_shape: [ + getattr(prior, attr) for attr in constants.GEO_RF_PARAMETERS + ], + } + for shape, params in shape_to_params.items(): + for param in params: + self.assertEqual(param.shape, shape) + + def test_injected_sample_prior_media_and_rf_returns_correct_shape(self): + """Checks validation passes with correct shapes.""" + self.enter_context( + mock.patch.object( + prior_sampler.PriorSampler, + "_sample_prior", + autospec=True, + return_value=self.test_dist_media_and_rf, + ) + ) + model_spec = spec.ModelSpec() + meridian = 
model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + meridian.sample_prior(n_draws=self._N_DRAWS) + inference_data = meridian.inference_data + + meridian_with_inference_data = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + inference_data=inference_data, + ) + + self.assertEqual( + meridian_with_inference_data.inference_data, inference_data + ) + + def test_injected_sample_prior_media_only_returns_correct_shape(self): + """Checks validation passes with correct shapes.""" + self.enter_context( + mock.patch.object( + prior_sampler.PriorSampler, + "_sample_prior", + autospec=True, + return_value=self.test_dist_media_only, + ) + ) + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_media_only, + model_spec=model_spec, + ) + meridian.sample_prior(n_draws=self._N_DRAWS) + inference_data = meridian.inference_data + + meridian_with_inference_data = model.Meridian( + input_data=self.short_input_data_with_media_only, + model_spec=model_spec, + inference_data=inference_data, + ) + + self.assertEqual( + meridian_with_inference_data.inference_data, inference_data + ) + + def test_injected_sample_prior_rf_only_returns_correct_shape(self): + """Checks validation passes with correct shapes.""" + self.enter_context( + mock.patch.object( + prior_sampler.PriorSampler, + "_sample_prior", + autospec=True, + return_value=self.test_dist_rf_only, + ) + ) + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_rf_only, + model_spec=model_spec, + ) + meridian.sample_prior(n_draws=self._N_DRAWS) + inference_data = meridian.inference_data + + meridian_with_inference_data = model.Meridian( + input_data=self.short_input_data_with_rf_only, + model_spec=model_spec, + inference_data=inference_data, + ) + + self.assertEqual( + meridian_with_inference_data.inference_data, inference_data + ) + + 
@parameterized.named_parameters( + dict( + testcase_name="control_variables", + coord=constants.CONTROL_VARIABLE, + mismatched_priors={ + constants.GAMMA_C: (1, IDS._N_DRAWS, IDS._N_CONTROLS + 1), + constants.GAMMA_GC: ( + 1, + IDS._N_DRAWS, + IDS._N_GEOS, + IDS._N_CONTROLS + 1, + ), + constants.XI_C: (1, IDS._N_DRAWS, IDS._N_CONTROLS + 1), + }, + mismatched_coord_size=IDS._N_CONTROLS + 1, + expected_coord_size=IDS._N_CONTROLS, + ), + dict( + testcase_name="geos", + coord=constants.GEO, + mismatched_priors={ + constants.BETA_GM: ( + 1, + IDS._N_DRAWS, + IDS._N_GEOS + 1, + IDS._N_MEDIA_CHANNELS, + ), + constants.BETA_GRF: ( + 1, + IDS._N_DRAWS, + IDS._N_GEOS + 1, + IDS._N_RF_CHANNELS, + ), + constants.GAMMA_GC: ( + 1, + IDS._N_DRAWS, + IDS._N_GEOS + 1, + IDS._N_CONTROLS, + ), + constants.TAU_G: (1, IDS._N_DRAWS, IDS._N_GEOS + 1), + }, + mismatched_coord_size=IDS._N_GEOS + 1, + expected_coord_size=IDS._N_GEOS, + ), + dict( + testcase_name="knots", + coord=constants.KNOTS, + mismatched_priors={ + constants.KNOT_VALUES: ( + 1, + IDS._N_DRAWS, + IDS._N_TIMES_SHORT + 1, + ), + }, + mismatched_coord_size=IDS._N_TIMES_SHORT + 1, + expected_coord_size=IDS._N_TIMES_SHORT, + ), + dict( + testcase_name="times", + coord=constants.TIME, + mismatched_priors={ + constants.MU_T: (1, IDS._N_DRAWS, IDS._N_TIMES_SHORT + 1), + }, + mismatched_coord_size=IDS._N_TIMES_SHORT + 1, + expected_coord_size=IDS._N_TIMES_SHORT, + ), + dict( + testcase_name="sigma_dims", + coord=constants.SIGMA_DIM, + mismatched_priors={ + constants.SIGMA: (1, IDS._N_DRAWS, IDS._N_GEOS_NATIONAL + 1), + }, + mismatched_coord_size=IDS._N_GEOS_NATIONAL + 1, + expected_coord_size=IDS._N_GEOS_NATIONAL, + ), + dict( + testcase_name="media_channels", + coord=constants.MEDIA_CHANNEL, + mismatched_priors={ + constants.ALPHA_M: (1, IDS._N_DRAWS, IDS._N_MEDIA_CHANNELS + 1), + constants.BETA_GM: ( + 1, + IDS._N_DRAWS, + IDS._N_GEOS, + IDS._N_MEDIA_CHANNELS + 1, + ), + constants.BETA_M: (1, IDS._N_DRAWS, IDS._N_MEDIA_CHANNELS 
+ 1), + constants.EC_M: (1, IDS._N_DRAWS, IDS._N_MEDIA_CHANNELS + 1), + constants.ETA_M: (1, IDS._N_DRAWS, IDS._N_MEDIA_CHANNELS + 1), + constants.ROI_M: (1, IDS._N_DRAWS, IDS._N_MEDIA_CHANNELS + 1), + constants.SLOPE_M: (1, IDS._N_DRAWS, IDS._N_MEDIA_CHANNELS + 1), + }, + mismatched_coord_size=IDS._N_MEDIA_CHANNELS + 1, + expected_coord_size=IDS._N_MEDIA_CHANNELS, + ), + dict( + testcase_name="rf_channels", + coord=constants.RF_CHANNEL, + mismatched_priors={ + constants.ALPHA_RF: (1, IDS._N_DRAWS, IDS._N_RF_CHANNELS + 1), + constants.BETA_GRF: ( + 1, + IDS._N_DRAWS, + IDS._N_GEOS, + IDS._N_RF_CHANNELS + 1, + ), + constants.BETA_RF: (1, IDS._N_DRAWS, IDS._N_RF_CHANNELS + 1), + constants.EC_RF: (1, IDS._N_DRAWS, IDS._N_RF_CHANNELS + 1), + constants.ETA_RF: (1, IDS._N_DRAWS, IDS._N_RF_CHANNELS + 1), + constants.ROI_RF: (1, IDS._N_DRAWS, IDS._N_RF_CHANNELS + 1), + constants.SLOPE_RF: (1, IDS._N_DRAWS, IDS._N_RF_CHANNELS + 1), + }, + mismatched_coord_size=IDS._N_RF_CHANNELS + 1, + expected_coord_size=IDS._N_RF_CHANNELS, + ), + ) + def test_validate_injected_inference_data_prior_incorrect_coordinates( + self, coord, mismatched_priors, mismatched_coord_size, expected_coord_size + ): + """Checks prior validation fails with incorrect coordinates.""" + model_spec = spec.ModelSpec() + meridian = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + ) + prior_samples = meridian.prior_sampler._sample_prior(self._N_DRAWS) + prior_coords = meridian.create_inference_data_coords(1, self._N_DRAWS) + prior_dims = meridian.create_inference_data_dims() + + prior_samples = dict(prior_samples) + for param in mismatched_priors: + prior_samples[param] = tf.zeros(mismatched_priors[param]) + prior_coords = dict(prior_coords) + prior_coords[coord] = np.arange(mismatched_coord_size) + + inference_data = az.convert_to_inference_data( + prior_samples, + coords=prior_coords, + dims=prior_dims, + group=constants.PRIOR, + ) + + with 
self.assertRaisesRegex( + ValueError, + f"Injected inference data {constants.PRIOR} has incorrect coordinate" + f" '{coord}': expected {expected_coord_size}, got" + f" {mismatched_coord_size}", + ): + _ = model.Meridian( + input_data=self.short_input_data_with_media_and_rf, + model_spec=model_spec, + inference_data=inference_data, + ) + + +if __name__ == "__main__": + absltest.main()