diff --git a/baybe/campaign.py b/baybe/campaign.py
index bfa82b3be..8648858f0 100644
--- a/baybe/campaign.py
+++ b/baybe/campaign.py
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import json
-from typing import List
+from typing import List, Optional

 import cattrs
 import numpy as np
@@ -68,13 +68,11 @@ class Campaign(SerialMixin):
     """The number of fits already done."""

     # Private
-    _measurements_exp: pd.DataFrame = field(
-        factory=pd.DataFrame, eq=eq_dataframe, init=False
-    )
+    _measurements_exp: pd.DataFrame = field(eq=eq_dataframe, init=False)
     """The experimental representation of the conducted experiments."""

-    _cached_recommendation: pd.DataFrame = field(
-        factory=pd.DataFrame, eq=eq_dataframe, init=False
+    _cached_recommendation: Optional[pd.DataFrame] = field(
+        default=None, eq=eq_dataframe, init=False
     )
     """The cached recommendations."""

@@ -82,6 +80,13 @@ class Campaign(SerialMixin):
     numerical_measurements_must_be_within_tolerance: bool = field(default=None)
     """Deprecated! Raises an error when used."""

+    @_measurements_exp.default
+    def _default_measurements_exp(self):
+        """Provide an empty dataframe with the experimental columns as default."""
+        return pd.DataFrame(
+            columns=[p.name for p in self.parameters] + [t.name for t in self.targets]
+        )
+
     @numerical_measurements_must_be_within_tolerance.validator
     def _validate_tolerance_flag(self, _, value) -> None:
         """Raise a DeprecationError if the tolerance flag is used."""
@@ -110,15 +115,11 @@ def targets(self) -> List[Target]:
     @property
     def _measurements_parameters_comp(self) -> pd.DataFrame:
         """The computational representation of the measured parameters."""
-        if len(self._measurements_exp) < 1:
-            return pd.DataFrame()
         return self.searchspace.transform(self._measurements_exp)

     @property
     def _measurements_targets_comp(self) -> pd.DataFrame:
         """The computational representation of the measured targets."""
-        if len(self._measurements_exp) < 1:
-            return pd.DataFrame()
         return self.objective.transform(self._measurements_exp)

     @classmethod
@@ -181,7 +182,7 @@ def add_measurements(
             TypeError: If the target has non-numeric entries in the provided dataframe.
         """
         # Invalidate recommendation cache first (in case of uncaught exceptions below)
-        self._cached_recommendation = pd.DataFrame()
+        self._cached_recommendation = None

         # Check if all targets have valid values
         for target in self.targets:
@@ -266,7 +267,10 @@ def recommend(

         # If there are cached recommendations and the batch size of those is equal to
         # the previously requested one, we just return those
-        if len(self._cached_recommendation) == batch_size:
+        if (
+            self._cached_recommendation is not None
+            and len(self._cached_recommendation) == batch_size
+        ):
             return self._cached_recommendation

         # Update recommendation meta data
diff --git a/baybe/searchspace/discrete.py b/baybe/searchspace/discrete.py
index 2d7a45892..efaa202f6 100644
--- a/baybe/searchspace/discrete.py
+++ b/baybe/searchspace/discrete.py
@@ -449,7 +449,7 @@ def transform(
             A dataframe with the parameters in computational representation.
         """
         # If the transformed values are not required, return an empty dataframe
-        if self.empty_encoding or len(data) < 1:
+        if self.empty_encoding:
             comp_rep = pd.DataFrame(index=data.index)
             return comp_rep

diff --git a/baybe/surrogates/gaussian_process.py b/baybe/surrogates/gaussian_process.py
index f8835a0e4..fba2bd53a 100644
--- a/baybe/surrogates/gaussian_process.py
+++ b/baybe/surrogates/gaussian_process.py
@@ -147,8 +147,10 @@ def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> No
             covar_module=covar_module,
             likelihood=likelihood,
         )
-        mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model)
-        # IMPROVE: The step_limit=100 stems from the former (deprecated)
-        # `fit_gpytorch_torch` function, for which this was the default. Probably,
-        # one should use a smarter logic here.
-        fit_gpytorch_mll_torch(mll, step_limit=100)
+
+        if train_x.numel() > 0:
+            mll = ExactMarginalLogLikelihood(self._model.likelihood, self._model)
+            # IMPROVE: The step_limit=100 stems from the former (deprecated)
+            # `fit_gpytorch_torch` function, for which this was the default. Probably,
+            # one should use a smarter logic here.
+            fit_gpytorch_mll_torch(mll, step_limit=100)
diff --git a/baybe/surrogates/utils.py b/baybe/surrogates/utils.py
index d109d9cb4..09067f991 100644
--- a/baybe/surrogates/utils.py
+++ b/baybe/surrogates/utils.py
@@ -28,12 +28,7 @@ def _prepare_inputs(x: Tensor) -> Tensor:

     Returns:
         The prepared input.
-
-    Raises:
-        ValueError: If the model input is empty.
     """
-    if len(x) == 0:
-        raise ValueError("The model input must be non-empty.")
     return x.to(_DTYPE)


diff --git a/tests/test_iterations.py b/tests/test_iterations.py
index 36243f06e..d8b9fa7de 100644
--- a/tests/test_iterations.py
+++ b/tests/test_iterations.py
@@ -148,3 +148,9 @@ def test_iter_recommender_hybrid(campaign, n_iterations, batch_size):
 @pytest.mark.parametrize("strategy", valid_strategies, indirect=True)
 def test_strategies(campaign, n_iterations, batch_size):
     run_iterations(campaign, n_iterations, batch_size)
+
+
+@pytest.mark.parametrize("parameter_names", [["Num_disc_1", "Conti_finite3"]])
+def test_without_data(campaign, batch_size):
+    campaign.strategy = SequentialGreedyRecommender()
+    campaign.recommend(batch_size)