Skip to content

Commit

Permalink
Prevent model predictions issues from erroring out Adapter.gen (#3442)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3442

Model prediction is just a side-effect of `Adapter.gen` call and is not crucial to producing a candidate. We should not error out `gen` due to a model prediction error.

Reviewed By: lena-kashtelyan

Differential Revision: D70425553

fbshipit-source-id: 73122774b0ae428ca2eee54e004640f2cbcc24c3
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Mar 1, 2025
1 parent a66b848 commit b3c32a0
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 5 deletions.
3 changes: 2 additions & 1 deletion ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,8 @@ def gen(
best_point_predictions = extract_arm_predictions(
model_predictions=self.predict([best_obsf]), arm_idx=0
)
except NotImplementedError:
except Exception as e:
logger.debug(f"Model predictions failed with error {e}.")
model_predictions = None

if best_obsf is None:
Expand Down
6 changes: 4 additions & 2 deletions ax/models/discrete/thompson.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy.typing as npt
from ax.core.types import TGenMetadata, TParamValue, TParamValueList
from ax.exceptions.constants import TS_MIN_WEIGHT_ERROR, TS_NO_FEASIBLE_ARMS_ERROR
from ax.exceptions.core import UnsupportedError
from ax.exceptions.model import ModelError
from ax.models.discrete_base import DiscreteGenerator
from ax.models.types import TConfig
Expand Down Expand Up @@ -154,8 +155,9 @@ def predict(
for j, x in enumerate(predictX):
# iterate through parameterizations at which to make predictions
if x not in X_to_Y_and_Yvar:
raise ValueError(
"ThompsonSampler does not support out-of-sample prediction."
raise UnsupportedError(
"ThompsonSampler does not support out-of-sample prediction. "
f"(X: {X[j]} - note that this is post-transform application)."
)
f[j, i], cov[j, i, i] = X_to_Y_and_Yvar[
assert_is_instance(x, TParamValue)
Expand Down
5 changes: 4 additions & 1 deletion ax/models/tests/test_eb_thompson.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from unittest.mock import patch

import numpy as np
from ax.exceptions.core import UnsupportedError
from ax.models.discrete.eb_thompson import EmpiricalBayesThompsonSampler
from ax.utils.common.random import set_rng_seed
from ax.utils.common.testutils import TestCase


Expand Down Expand Up @@ -80,6 +82,7 @@ def test_EmpiricalBayesThompsonSamplerGen(self) -> None:
self.assertAlmostEqual(weight, expected_weight, delta=0.1)

def test_EmpiricalBayesThompsonSamplerWarning(self) -> None:
set_rng_seed(0)
generator = EmpiricalBayesThompsonSampler(min_weight=0.0)
generator.fit(
Xs=[x[:-1] for x in self.Xs],
Expand Down Expand Up @@ -132,5 +135,5 @@ def test_EmpiricalBayesThompsonSamplerPredict(self) -> None:
)
)

with self.assertRaises(ValueError):
with self.assertRaisesRegex(UnsupportedError, "out-of-sample"):
generator.predict([[1, 2]])
3 changes: 2 additions & 1 deletion ax/models/tests/test_thompson.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from warnings import catch_warnings

import numpy as np
from ax.exceptions.core import UnsupportedError
from ax.exceptions.model import ModelError
from ax.models.discrete.thompson import ThompsonSampler
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -200,7 +201,7 @@ def test_ThompsonSamplerPredict(self) -> None:
self.assertTrue(np.array_equal(f, np.array([[1], [3]])))
self.assertTrue(np.array_equal(cov, np.ones((2, 1, 1))))

with self.assertRaises(ValueError):
with self.assertRaisesRegex(UnsupportedError, "out-of-sample"):
generator.predict([[1, 2]])

def test_ThompsonSamplerMultiObjectiveWarning(self) -> None:
Expand Down

0 comments on commit b3c32a0

Please sign in to comment.