diff --git a/meridian/analysis/analyzer.py b/meridian/analysis/analyzer.py index 03f7749e..50a9772d 100644 --- a/meridian/analysis/analyzer.py +++ b/meridian/analysis/analyzer.py @@ -2065,6 +2065,7 @@ def marginal_roi( new_data: DataTensors | None = None, selected_geos: Sequence[str] | None = None, selected_times: Sequence[str] | None = None, + media_selected_times: Sequence[str] | Sequence[bool] | None = None, aggregate_geos: bool = True, aggregate_times: bool = True, by_reach: bool = True, @@ -2103,6 +2104,9 @@ def marginal_roi( all geos are included. selected_times: Optional. Contains a subset of times to include. By default, all time periods are included. + media_selected_times: Optional list containing either a subset of dates to + include or booleans with length equal to the number of time periods in + `new_data`, if provided. aggregate_geos: If `True`, the expected revenue is summed over all of the regions. aggregate_times: If `True`, the expected revenue is summed over all of @@ -2136,6 +2140,7 @@ def marginal_roi( "use_kpi": use_kpi, "batch_size": batch_size, "include_non_paid_channels": False, + "media_selected_times": media_selected_times, } # TODO: Switch from PerformanceTensors to DataTensors. if new_data is None: @@ -2211,6 +2216,7 @@ def roi( new_data: DataTensors | None = None, selected_geos: Sequence[str] | None = None, selected_times: Sequence[str] | None = None, + media_selected_times: Sequence[str] | Sequence[bool] | None = None, aggregate_geos: bool = True, aggregate_times: bool = True, use_kpi: bool = False, @@ -2252,6 +2258,9 @@ def roi( default, all geos are included. selected_times: Optional list containing a subset of times to include. By default, all time periods are included. + media_selected_times: Optional list containing either a subset of dates to + include or booleans with length equal to the number of time periods in + `new_data`, if provided. aggregate_geos: Boolean. If `True`, the expected revenue is summed over all of the regions. aggregate_times: Boolean. If `True`, the expected revenue is summed over @@ -2282,6 +2291,7 @@ def roi( "use_kpi": use_kpi, "batch_size": batch_size, "include_non_paid_channels": False, + "media_selected_times": media_selected_times, } # TODO: Switch from PerformanceTensors to DataTensors. if new_data is None: @@ -2328,6 +2338,7 @@ def cpik( new_data: DataTensors | None = None, selected_geos: Sequence[str] | None = None, selected_times: Sequence[str] | None = None, + media_selected_times: Sequence[str] | Sequence[bool] | None = None, aggregate_geos: bool = True, aggregate_times: bool = True, batch_size: int = constants.DEFAULT_BATCH_SIZE, @@ -2363,6 +2374,9 @@ def cpik( default, all geos are included. selected_times: Optional list containing a subset of times to include. By default, all time periods are included. + media_selected_times: Optional list containing either a subset of dates to + include or booleans with length equal to the number of time periods in + `new_data`, if provided. aggregate_geos: Boolean. If `True`, the expected KPI is summed over all of the regions. aggregate_times: Boolean. If `True`, the expected KPI is summed over all @@ -2384,6 +2398,7 @@ def cpik( new_data=new_data, selected_geos=selected_geos, selected_times=selected_times, + media_selected_times=media_selected_times, aggregate_geos=aggregate_geos, aggregate_times=aggregate_times, batch_size=batch_size, diff --git a/meridian/analysis/analyzer_test.py b/meridian/analysis/analyzer_test.py index 9dc1e557..eed68edf 100644 --- a/meridian/analysis/analyzer_test.py +++ b/meridian/analysis/analyzer_test.py @@ -663,6 +663,14 @@ def test_incremental_outcome_media_and_rf_new_params(self): atol=1e-3, ) + def test_marginal_roi_media_selected_times_all_false_returns_zero( + self, + ): + no_media_times = self.analyzer_media_and_rf.marginal_roi( + media_selected_times=[False] * _N_MEDIA_TIMES + ) + self.assertAllEqual(no_media_times, tf.zeros_like(no_media_times)) + @parameterized.product( use_posterior=[False, True], aggregate_geos=[False, True], @@ -806,6 +814,14 @@ def test_roi_wrong_rf_spend_raises_exception(self): ), ) + def test_roi_media_selected_times_all_false_returns_zero( + self, + ): + no_media_times = self.analyzer_media_and_rf.roi( + media_selected_times=[False] * _N_MEDIA_TIMES + ) + self.assertAllEqual(no_media_times, tf.zeros_like(no_media_times)) + @parameterized.product( use_posterior=[False, True], aggregate_geos=[False, True], @@ -852,6 +868,14 @@ def test_roi_media_and_rf_default_returns_correct_value(self): ) self.assertAllClose(expected_roi, roi) + def test_cpik_media_selected_times_all_false_returns_zero( + self, + ): + no_media_times = self.analyzer_media_and_rf.cpik( + media_selected_times=[False] * _N_MEDIA_TIMES + ) + self.assertAllEqual(no_media_times, tf.zeros_like(no_media_times)) + @parameterized.product( use_posterior=[False, True], aggregate_geos=[False, True], diff --git a/meridian/analysis/optimizer_test.py b/meridian/analysis/optimizer_test.py index 9356af65..725a5900 100644 --- a/meridian/analysis/optimizer_test.py +++ b/meridian/analysis/optimizer_test.py @@ -3326,6 +3326,7 @@ def test_incremental_outcome_called_correct_optimize( # incremental_outcome() with the following arguments. selected_geos=None, selected_times=None, + media_selected_times=None, aggregate_geos=True, aggregate_times=True, inverse_transform_outcome=True,