Skip to content

Commit

Permalink
Allow for adhoc computation of multiple CVs at one time (facebook#3427)
Browse files Browse the repository at this point in the history
Summary:

This is the second UX improvement diff for adhoc cross validation computation. This allows multiple metrics to be computed at one time for a single adapter.

It does not currently tile or drop down the metric names. This is a nice improvement we'd like to make, but it wasn't immediately clear how to go about this.


Also fixed the CI bars per Bernie's catch.

Differential Revision: D70200056
  • Loading branch information
mgarrard authored and facebook-github-bot committed Mar 1, 2025
1 parent 107393d commit b243feb
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 26 deletions.
81 changes: 56 additions & 25 deletions ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import select_metric
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.exceptions.core import UserInputError
from ax.generation_strategy.generation_strategy import GenerationStrategy
Expand Down Expand Up @@ -106,20 +107,22 @@ def compute(
def _compute_adhoc(
self,
adapter: Adapter,
metric_name: str,
data: Data,
experiment: Experiment | None = None,
folds: int = -1,
untransform: bool = True,
) -> PlotlyAnalysisCard:
metric_name_mapping: dict[str, str] | None = None,
) -> list[PlotlyAnalysisCard]:
"""
Helper method to expose adhoc cross validation plotting. This overrides the
default assumption that the adapter from the generation strategy should be
used. Only for advanced users in a notebook setting.
Args:
adapter: The adapter that will be assessed during cross validation.
metric_name: The name of the metric to plot. Must be provided for adhoc
plotting.
data: The Data that was used to fit the model. Will be used in this
adhoc cross validation call to compute the cross validation for all
metrics in the Data object.
experiment: Experiment associated with this analysis. Used to determine
the priority of the analysis based on the metric importance in the
optimization config.
Expand All @@ -136,17 +139,34 @@ def _compute_adhoc(
regions where outliers have been removed, we have found it to better
reflect the how good the model used for candidate generation actually
is.
metric_name_mapping: Optional mapping from default metric names to more
readable metric names.
"""
return self._construct_plot(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
# trial_index argument is used with generation strategy since this is an
# adhoc plot call, this will be None.
trial_index=None,
experiment=experiment,
)
plots = []
# Get all unique metric names in the data object; CVs will be computed
# for each of these metrics
metric_names = list(data.df["metric_name"].unique())
for metric_name in metric_names:
# replace metric name with human readable name if mapping is provided
refined_metric_name = (
metric_name_mapping.get(metric_name, metric_name)
if metric_name_mapping
else metric_name
)
plots.append(
self._construct_plot(
adapter=adapter,
metric_name=metric_name,
folds=folds,
untransform=untransform,
# trial_index argument is used with generation strategy; since this
# is an adhoc plot call, this will be None.
trial_index=None,
experiment=experiment,
refined_metric_name=refined_metric_name,
)
)
return plots

def _construct_plot(
self,
Expand All @@ -156,6 +176,7 @@ def _construct_plot(
untransform: bool,
trial_index: int | None,
experiment: Experiment | None = None,
refined_metric_name: str | None = None,
) -> PlotlyAnalysisCard:
"""
Args:
Expand All @@ -181,6 +202,8 @@ def _construct_plot(
experiment: Optional Experiment associated with this analysis. Used to set
the priority of the analysis based on the metric importance in the
optimization config.
refined_metric_name: Optional replacement for raw metric name, useful for
improving readability of the plot title.
"""
df = _prepare_data(
adapter=adapter,
Expand Down Expand Up @@ -209,8 +232,11 @@ def _construct_plot(
else:
nudge = 0

# If a human readable metric name is provided, use it in the title
metric_title = refined_metric_name if refined_metric_name else metric_name

return self._create_plotly_analysis_card(
title=f"Cross Validation for {metric_name}",
title=f"Cross Validation for {metric_title}",
subtitle=f"Out-of-sample predictions using {k_folds_substring} CV",
level=AnalysisCardLevel.LOW.value + nudge,
df=df,
Expand Down Expand Up @@ -263,15 +289,20 @@ def _prepare_data(
"arm_name": observed.arm_name,
"observed": observed.data.means[observed_i],
"predicted": predicted.means[predicted_i],
# Take the square root of the SEM to get the standard deviation
"observed_sem": observed.data.covariance[observed_i][observed_i] ** 0.5,
"predicted_sem": predicted.covariance[predicted_i][predicted_i] ** 0.5,
# Compute the 95% confidence intervals for plotting purposes
"observed_95_ci": observed.data.covariance[observed_i][observed_i]
** 0.5
* 1.96,
"predicted_95_ci": predicted.covariance[predicted_i][predicted_i] ** 0.5
* 1.96,
}
records.append(record)
return pd.DataFrame.from_records(records)


def _prepare_plot(df: pd.DataFrame) -> go.Figure:
def _prepare_plot(
df: pd.DataFrame,
) -> go.Figure:
# Create a scatter plot using Plotly Graph Objects for more control
fig = go.Figure()
fig.add_trace(
Expand All @@ -284,13 +315,13 @@ def _prepare_plot(df: pd.DataFrame) -> go.Figure:
},
error_x={
"type": "data",
"array": df["observed_sem"],
"array": df["observed_95_ci"],
"visible": True,
"color": "rgba(0, 0, 255, 0.2)", # partially transparent blue
},
error_y={
"type": "data",
"array": df["predicted_sem"],
"array": df["predicted_95_ci"],
"visible": True,
"color": "rgba(0, 0, 255, 0.2)", # partially transparent blue
},
Expand All @@ -313,15 +344,15 @@ def _prepare_plot(df: pd.DataFrame) -> go.Figure:
# this line.
lower_bound = (
min(
(df["observed"] - df["observed_sem"].fillna(0)).min(),
(df["predicted"] - df["predicted_sem"].fillna(0)).min(),
(df["observed"] - df["observed_95_ci"].fillna(0)).min(),
(df["predicted"] - df["predicted_95_ci"].fillna(0)).min(),
)
* 0.999 # tight autozoom
)
upper_bound = (
max(
(df["observed"] + df["observed_sem"].fillna(0)).max(),
(df["predicted"] + df["predicted_sem"].fillna(0)).max(),
(df["observed"] + df["observed_95_ci"].fillna(0)).max(),
(df["predicted"] + df["predicted_95_ci"].fillna(0)).max(),
)
* 1.001 # tight autozoom
)
Expand Down
19 changes: 18 additions & 1 deletion ax/analysis/plotly/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ax.analysis.plotly.cross_validation import CrossValidationPlot
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.modelbridge.registry import Generators
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.common.testutils import TestCase
from ax.utils.testing.mock import mock_botorch_optimize
Expand Down Expand Up @@ -60,7 +61,7 @@ def test_compute(self) -> None:
self.assertEqual(card.category, AnalysisCardCategory.INSIGHT)
self.assertEqual(
{*card.df.columns},
{"arm_name", "observed", "observed_sem", "predicted", "predicted_sem"},
{"arm_name", "observed", "observed_95_ci", "predicted", "predicted_95_ci"},
)
self.assertIsNotNone(card.blob)
self.assertEqual(card.blob_annotation, "plotly")
Expand Down Expand Up @@ -98,3 +99,19 @@ def test_it_can_specify_trial_index_correctly(self) -> None:
arm_name,
card.df["arm_name"].unique(),
)

@mock_botorch_optimize
def test_compute_adhoc(self) -> None:
    """Adhoc CV returns one card per metric in the supplied Data, and uses
    the human-readable name from ``metric_name_mapping`` in the card title.
    """
    metric_mapping = {"bar": "spunky"}
    data = self.client.experiment.lookup_data()
    # Build the adapter directly (not via a generation strategy) to mirror
    # the advanced-user notebook workflow that _compute_adhoc targets.
    adapter = Generators.BOTORCH_MODULAR(
        experiment=self.client.experiment, data=data
    )
    analysis = CrossValidationPlot()._compute_adhoc(
        adapter=adapter, data=data, metric_name_mapping=metric_mapping
    )
    # presumably the experiment's data contains exactly one metric ("bar"),
    # so exactly one card is expected — confirm against the test fixture.
    self.assertEqual(len(analysis), 1)
    card = analysis[0]
    self.assertEqual(card.name, "CrossValidationPlot")
    # validate that the metric name replacement occurred
    self.assertEqual(card.title, "Cross Validation for spunky")

0 comments on commit b243feb

Please sign in to comment.