Remove status_quo_name input from Adapter (#3431)
Summary:
Pull Request resolved: #3431

Since the experiment is always available on the `Adapter` (as of D70103442), we can extract the status quo name directly from the experiment. For now, the `status_quo_features` input (which is used through input constructors) remains; I'd like to transition that to the `experiment` as well.
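
For a concrete picture of the call-site change, here is a minimal sketch that mirrors the test updates in this diff (`exp`, the arm parameters, and the arm name are illustrative; `Adapter` and `Generator` are the classes exercised in the tests below, and the experiment is assumed to already have data for the status quo arm):

from ax.core.arm import Arm

# Before this change: the status quo arm was named explicitly on the Adapter.
# adapter = Adapter(experiment=exp, model=Generator(), status_quo_name="status_quo")

# After this change: attach the status quo to the experiment instead;
# the Adapter now reads it from `experiment.status_quo`.
exp.status_quo = Arm(parameters={"x": 3.0}, name="status_quo")
adapter = Adapter(experiment=exp, model=Generator())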

Updated the storage code to remove deprecated kwargs when loading `GenerationNode`s and `GeneratorRun`s from the DB.

This is a step toward eliminating redundant inputs in the modeling layer and toward relying on the experiment as the source of truth for inputs that are typically derived from it.

Reviewed By: ItsMrLin

Differential Revision: D70292112

fbshipit-source-id: 333ec162cf9322bff17dbacbe363539e45400df0
saitcakmak authored and facebook-github-bot committed Feb 28, 2025
1 parent ff1dd62 commit 036d2bc
Showing 18 changed files with 172 additions and 112 deletions.
1 change: 0 additions & 1 deletion ax/core/experiment.py
@@ -15,7 +15,6 @@
from collections.abc import Hashable, Iterable, Mapping
from datetime import datetime
from functools import partial, reduce

from typing import Any, cast

import ax.core.observation as observation
2 changes: 0 additions & 2 deletions ax/generation_strategy/tests/test_generation_strategy.py
@@ -569,7 +569,6 @@ def test_sobol_MBM_strategy(self) -> None:
{
"optimization_config": None,
"status_quo_features": None,
"status_quo_name": None,
"transform_configs": None,
"transforms": Cont_X_trans,
"fit_out_of_design": False,
@@ -1543,7 +1542,6 @@ def test_gs_with_generation_nodes(self) -> None:
{
"optimization_config": None,
"status_quo_features": None,
"status_quo_name": None,
"transform_configs": None,
"transforms": Cont_X_trans,
"fit_out_of_design": False,
52 changes: 16 additions & 36 deletions ax/modelbridge/base.py
@@ -103,7 +103,6 @@ def __init__(
data: Data | None = None,
transforms: Sequence[type[Transform]] | None = None,
transform_configs: Mapping[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
expand_model_space: bool = True,
@@ -138,11 +137,8 @@ def __init__(
the reverse order.
transform_configs: A dictionary from transform name to the
transform config dictionary.
status_quo_name: Name of the status quo arm. Can only be used if
Data has a single set of ObservationFeatures corresponding to
that arm.
status_quo_features: ObservationFeatures to use as status quo.
Either this or status_quo_name should be specified, not both.
status_quo_features: ObservationFeatures to use as status quo. If None,
the status quo will be extracted from the experiment, if it exists.
optimization_config: An optional ``OptimizationConfig`` defining how to
optimize the model. Defaults to `experiment.optimization_config`.
expand_model_space: If True, expand range parameter bounds in model
@@ -231,9 +227,7 @@ def __init__(
# Set model status quo.
# NOTE: training data must be set before setting the status quo.
self._set_status_quo(
experiment=experiment,
status_quo_name=status_quo_name,
status_quo_features=status_quo_features,
experiment=experiment, status_quo_features=status_quo_features
)

# Save model, apply terminal transform, and fit.
@@ -453,38 +447,25 @@ def _set_model_space(self, observations: list[Observation]) -> None:

def _set_status_quo(
self,
experiment: Experiment | None,
status_quo_name: str | None,
experiment: Experiment,
status_quo_features: ObservationFeatures | None,
) -> None:
"""Set model status quo by matching status_quo_name or status_quo_features.
"""Set model status quo by matching status_quo_features or
extracting from the experiment.
First checks for status quo in inputs status_quo_name and
status_quo_features. If neither of these is provided, checks the
experiment for a status quo. If that is set, it is handled by name in
the same way as input status_quo_name.
First checks the status_quo_features input. If not provided, checks the
experiment for a status quo. If either one exists, looks through the
training data for an observation with the same name or features.
Args:
experiment: Experiment that will be checked for status quo.
status_quo_name: Name of status quo arm.
status_quo_features: Features for status quo.
"""
self._status_quo: Observation | None = None
sq_obs = None

if (
status_quo_name is None
and status_quo_features is None
and experiment is not None
and experiment.status_quo is not None
):
if status_quo_features is None and experiment.status_quo is not None:
status_quo_name = experiment.status_quo.name

if status_quo_name is not None:
if status_quo_features is not None:
raise UserInputError(
"Specify either status_quo_name or status_quo_features, not both."
)
sq_obs = [
obs for obs in self._training_data if obs.arm_name == status_quo_name
]
@@ -495,8 +476,11 @@ def _set_status_quo(
if (obs.features.parameters == status_quo_features.parameters)
and (obs.features.trial_index == status_quo_features.trial_index)
]
status_quo_name = sq_obs[0].arm_name if sq_obs else None
else:
status_quo_name = None

# if status_quo_name or status_quo_features is used for matching status quo
# If a status quo was found in the training data.
if sq_obs is not None:
if len(sq_obs) == 0:
logger.warning(f"Status quo {status_quo_name} not present in data")
@@ -505,7 +489,7 @@
# observation features) should be consistent even if we have multiple
# observations of the status quo.
# This is useful for getting status_quo_data_by_trial
self._status_quo_name = sq_obs[0].arm_name
self._status_quo_name = status_quo_name
if len(sq_obs) > 1 and self._fit_only_completed_map_metrics:
# it is expected to have multiple observations for map data
logger.warning(
@@ -522,11 +506,7 @@ def status_quo_data_by_trial(self) -> dict[int, ObservationData] | None:
"""A map of trial index to the status quo observation data of each trial"""
return _get_status_quo_by_trial(
observations=self._training_data,
status_quo_name=(
self._status_quo_name
if self.status_quo is None
else self.status_quo.arm_name
),
status_quo_name=self.status_quo_name,
status_quo_features=(
None if self.status_quo is None else self.status_quo.features
),
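
For readers skimming this file's diff: the status quo resolution in `Adapter._set_status_quo` now reduces to roughly the order sketched below. This is a simplified, standalone illustration (the helper name `resolve_status_quo` and the bare `training_data` list are stand-ins for the Adapter internals), not the literal implementation:

def resolve_status_quo(experiment, status_quo_features, training_data):
    """Simplified sketch of the status quo lookup order after this change."""
    if status_quo_features is not None:
        # Explicit features take precedence: match on parameters and trial index.
        matches = [
            obs
            for obs in training_data
            if obs.features.parameters == status_quo_features.parameters
            and obs.features.trial_index == status_quo_features.trial_index
        ]
        name = matches[0].arm_name if matches else None
    elif experiment.status_quo is not None:
        # Otherwise fall back to the experiment's status quo and match by arm name.
        name = experiment.status_quo.name
        matches = [obs for obs in training_data if obs.arm_name == name]
    else:
        # No status quo provided anywhere; leave it unset.
        return None, None
    # The real code logs a warning when the arm is not found in the training data.
    return name, (matches[0] if matches else None)
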
2 changes: 0 additions & 2 deletions ax/modelbridge/discrete.py
@@ -55,7 +55,6 @@ def __init__(
data: Data | None = None,
transforms: Sequence[type[Transform]] | None = None,
transform_configs: Mapping[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
expand_model_space: bool = True,
@@ -75,7 +74,6 @@ def __init__(
data=data,
transforms=transforms,
transform_configs=transform_configs,
status_quo_name=status_quo_name,
status_quo_features=status_quo_features,
optimization_config=optimization_config,
expand_model_space=expand_model_space,
2 changes: 0 additions & 2 deletions ax/modelbridge/map_torch.py
@@ -63,7 +63,6 @@ def __init__(
data: Data | None = None,
transforms: Sequence[type[Transform]] | None = None,
transform_configs: Mapping[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
fit_out_of_design: bool = False,
@@ -102,7 +101,6 @@ def __init__(
transforms=transforms,
transform_configs=transform_configs,
torch_device=torch_device,
status_quo_name=status_quo_name,
status_quo_features=status_quo_features,
optimization_config=optimization_config,
fit_out_of_design=fit_out_of_design,
3 changes: 0 additions & 3 deletions ax/modelbridge/random.py
@@ -43,7 +43,6 @@ def __init__(
data: Data | None = None,
transforms: Sequence[type[Transform]] | None = None,
transform_configs: Mapping[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
fit_out_of_design: bool = False,
@@ -59,7 +58,6 @@ def __init__(
experiment=experiment,
data=data,
transform_configs=transform_configs,
status_quo_name=status_quo_name,
status_quo_features=status_quo_features,
optimization_config=optimization_config,
expand_model_space=False,
@@ -132,7 +130,6 @@ def _cross_validate(
def _set_status_quo(
self,
experiment: Experiment | None,
status_quo_name: str | None,
status_quo_features: ObservationFeatures | None,
) -> None:
pass
37 changes: 13 additions & 24 deletions ax/modelbridge/tests/test_base_modelbridge.py
@@ -472,11 +472,14 @@ def test_ood_gen(self, _) -> None:
@mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True)
def test_SetStatusQuo(self, _, __) -> None:
exp = get_experiment_for_value()
adapter = Adapter(experiment=exp, model=Generator(), status_quo_name="1_1")
# Specify through the experiment.
exp.status_quo = Arm(parameters={"x": 3.0}, name="1_1")
adapter = Adapter(experiment=exp, model=Generator())
self.assertEqual(adapter.status_quo, get_observation1())
self.assertEqual(adapter.status_quo_name, "1_1")

# Alternatively, we can specify by features
exp = get_experiment_for_value()
adapter = Adapter(
experiment=exp,
model=Generator(),
@@ -494,20 +497,10 @@ def test_SetStatusQuo(self, _, __) -> None:
self.assertEqual(adapter.status_quo, get_observation1())
self.assertEqual(adapter.status_quo_name, "1_1")

# Errors if features and name both specified
with self.assertRaisesRegex(
UserInputError,
"Specify either status_quo_name or status_quo_features, not both.",
):
adapter = Adapter(
experiment=exp,
model=Generator(),
status_quo_features=get_observation1().features,
status_quo_name="1_1",
)

# Left as None if features or name don't exist
adapter = Adapter(experiment=exp, model=Generator(), status_quo_name="1_0")
# Left as None if features or name don't exist in the data.
exp = get_experiment_for_value()
exp.status_quo = Arm(parameters={"x": 3.0}, name="1_0")
adapter = Adapter(experiment=exp, model=Generator())
self.assertIsNone(adapter.status_quo)
self.assertIsNone(adapter.status_quo_name)
adapter = Adapter(
@@ -603,11 +596,7 @@ def test_transform_observations(self) -> None:
def test_SetTrainingDataDupFeatures(self, _: Mock, __: Mock) -> None:
# Throws an error if repeated features in observations.
with self.assertRaises(ValueError):
Adapter(
experiment=get_experiment_for_value(),
model=Generator(),
status_quo_name="1_1",
)
Adapter(experiment=get_experiment_for_value(), model=Generator())

def test_UnwrapObservationData(self) -> None:
observation_data = [get_observation1().data, get_observation2().data]
@@ -952,11 +941,12 @@ def test_fit_only_completed_map_metrics(
) -> None:
# _prepare_observations is called in the constructor and itself calls
# observations_from_data with map_keys_as_parameters=True
experiment = get_experiment_for_value()
experiment.status_quo = Arm(name="1_1", parameters={"x": 3.0})
Adapter(
experiment=get_experiment_for_value(),
experiment=experiment,
model=Generator(),
data=MapData(),
status_quo_name="1_1",
fit_only_completed_map_metrics=False,
)
kwargs = mock_observations_from_data.call_args.kwargs
@@ -968,9 +958,8 @@
# calling without map data calls observations_from_data with
# map_keys_as_parameters=False even if fit_only_completed_map_metrics is False
Adapter(
experiment=get_experiment_for_value(),
experiment=experiment,
model=Generator(),
status_quo_name="1_1",
fit_only_completed_map_metrics=False,
)
kwargs = mock_observations_from_data.call_args.kwargs
3 changes: 0 additions & 3 deletions ax/modelbridge/tests/test_registry.py
@@ -163,7 +163,6 @@ def test_enum_sobol_legacy_GPEI(self) -> None:
{
"transform_configs": None,
"torch_device": None,
"status_quo_name": None,
"status_quo_features": None,
"optimization_config": None,
"transforms": Cont_X_trans + Y_trans,
@@ -262,7 +261,6 @@ def test_view_defaults(self) -> None:
"optimization_config": None,
"transforms": Cont_X_trans,
"transform_configs": None,
"status_quo_name": None,
"status_quo_features": None,
"fit_out_of_design": False,
"fit_abandoned": False,
@@ -285,7 +283,6 @@ def test_view_defaults(self) -> None:
"experiment",
"data",
"transform_configs",
"status_quo_name",
"status_quo_features",
"expand_model_space",
"fit_out_of_design",
7 changes: 5 additions & 2 deletions ax/modelbridge/tests/test_transform_utils.py
@@ -9,6 +9,7 @@
from unittest import mock

import numpy as np
from ax.core.arm import Arm
from ax.core.experiment import Experiment
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
@@ -68,10 +69,12 @@ def test_derelativize_optimization_config_with_raw_status_quo(self, _) -> None:
]
)
modelbridge = Adapter(
experiment=Experiment(search_space=dummy_search_space),
experiment=Experiment(
search_space=dummy_search_space,
status_quo=Arm(parameters={"x": 1.0, "y": 1.0}, name="1_1"),
),
model=Generator(),
optimization_config=optimization_config,
status_quo_name="1_1",
)
new_opt_config = derelativize_optimization_config_with_raw_status_quo(
optimization_config=optimization_config,
2 changes: 0 additions & 2 deletions ax/modelbridge/torch.py
@@ -100,7 +100,6 @@ def __init__(
data: Data | None = None,
transforms: Sequence[type[Transform]] | None = None,
transform_configs: Mapping[str, TConfig] | None = None,
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
expand_model_space: bool = True,
@@ -148,7 +147,6 @@ def __init__(
model=model,
transforms=transforms,
transform_configs=transform_configs,
status_quo_name=status_quo_name,
status_quo_features=status_quo_features,
optimization_config=optimization_config,
expand_model_space=expand_model_space,
19 changes: 13 additions & 6 deletions ax/modelbridge/transforms/tests/test_derelativize_transform.py
@@ -12,6 +12,7 @@
from unittest.mock import Mock, patch

import numpy as np
from ax.core.arm import Arm
from ax.core.experiment import Experiment
from ax.core.metric import Metric
from ax.core.objective import Objective
@@ -119,9 +120,11 @@ def _test_DerelativizeTransform(
]
)
g = Adapter(
experiment=Experiment(search_space=search_space),
experiment=Experiment(
search_space=search_space,
status_quo=Arm(parameters={"x": 1.0, "y": 1.0}, name="1_1"),
),
model=Generator(),
status_quo_name="1_1",
)

# Test with no relative constraints
@@ -199,9 +202,11 @@ def _test_DerelativizeTransform(
# Test with relative constraint, out-of-design status quo
mock_predict.side_effect = RuntimeError()
g = Adapter(
experiment=Experiment(search_space=search_space),
experiment=Experiment(
search_space=search_space,
status_quo=Arm(parameters={"x": 1.0, "y": 1.0}, name="1_2"),
),
model=Generator(),
status_quo_name="1_2",
)
oc = OptimizationConfig(
objective=objective,
@@ -246,9 +251,11 @@ def _test_DerelativizeTransform(

# Raises error if predict fails with in-design status quo
g = Adapter(
experiment=Experiment(search_space=search_space),
experiment=Experiment(
search_space=search_space,
status_quo=Arm(parameters={"x": 1.0, "y": 1.0}, name="1_1"),
),
model=Generator(),
status_quo_name="1_1",
)
oc = OptimizationConfig(
objective=objective,
@@ -78,7 +78,6 @@ def _refresh_modelbridge(self) -> None:
model=Generator(),
experiment=self.exp,
data=self.exp.lookup_data(),
status_quo_name="status_quo",
)

def test_modelbridge_without_status_quo_name(self) -> None: