Skip to content

Commit

Permalink
Simplify Adapter.status_quo_data_by_trial (#3435)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3435

This diff simplifies the implementation of `Adapter.status_quo_data_by_trial`. The helper was prioritizing `status_quo_name`, which always exists if the status quo is set. So, the code block for extracting it based on features was redundant.

It'd be great to extract this directly from the experiment. The main challange is that we currently rely on `Adapter._training_data`, which excludes out of design observations, and that Adapter can have `status_quo_features` that is different than SQ of the experiment (though this is not a critical issue -- typically only trial index differs).

Q: Do we care about excluding out-of-design observations in a world where we expand the model space by default?

Reviewed By: ItsMrLin

Differential Revision: D70336263

fbshipit-source-id: 9836658b816ea46f687e2b26f7695c7db4612b32
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Feb 28, 2025
1 parent 036d2bc commit d488141
Showing 1 changed file with 14 additions and 53 deletions.
67 changes: 14 additions & 53 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

# pyre-strict

import json
import time
from collections import OrderedDict
from collections.abc import Mapping, MutableMapping, Sequence
Expand Down Expand Up @@ -503,14 +502,20 @@ def _set_status_quo(

@property
def status_quo_data_by_trial(self) -> dict[int, ObservationData] | None:
"""A map of trial index to the status quo observation data of each trial"""
return _get_status_quo_by_trial(
observations=self._training_data,
status_quo_name=self.status_quo_name,
status_quo_features=(
None if self.status_quo is None else self.status_quo.features
),
)
"""A map of trial index to the status quo observation data of each trial.
If status quo does not exist, return None.
"""
# Status quo name will be set if status quo exists. We can just filter by name.
if self.status_quo_name is None:
return None
# Identify status quo data by arm name.
return {
# NOTE: casting to int here, in case the index is a numpy integer.
int(none_throws(obs.features.trial_index)): obs.data
for obs in self._training_data
if obs.arm_name == self.status_quo_name
}

@property
def status_quo(self) -> Observation | None:
Expand Down Expand Up @@ -1208,47 +1213,3 @@ def clamp_observation_features(
)
obsf.parameters[p.name] = p.upper
return observation_features


def _get_status_quo_by_trial(
observations: list[Observation],
status_quo_name: str | None = None,
status_quo_features: ObservationFeatures | None = None,
) -> dict[int, ObservationData] | None:
r"""
Given a status quo observation, return a dictionary of trial index to
the status quo observation data of each trial.
When either `status_quo_name` or `status_quo_features` exists, return the dict;
when both exist, use `status_quo_name`;
when neither exists, return None.
Args:
observations: List of observations.
status_quo_name: Name of the status quo.
status_quo_features: ObservationFeatures for the status quo.
Returns:
A map from trial index to status quo observation data, or None
"""
trial_idx_to_sq_data = None
if status_quo_name is not None:
# identify status quo by arm name
trial_idx_to_sq_data = {
int(none_throws(obs.features.trial_index)): obs.data
for obs in observations
if obs.arm_name == status_quo_name
}
elif status_quo_features is not None:
# identify status quo by (untransformed) feature
status_quo_signature = json.dumps(
status_quo_features.parameters, sort_keys=True
)
trial_idx_to_sq_data = {
int(none_throws(obs.features.trial_index)): obs.data
for obs in observations
if json.dumps(obs.features.parameters, sort_keys=True)
== status_quo_signature
}

return trial_idx_to_sq_data

0 comments on commit d488141

Please sign in to comment.