From bb1fc3df862332d8026787f83ee53b7b17f0c191 Mon Sep 17 00:00:00 2001 From: Di Jin Date: Mon, 13 Jan 2025 22:37:36 +0100 Subject: [PATCH] Correct logic on boundary handling in activate paramter --- baybe/parameters/utils.py | 34 ++++++------- tests/utils/test_parameters.py | 90 +++++++++------------------------- 2 files changed, 36 insertions(+), 88 deletions(-) diff --git a/baybe/parameters/utils.py b/baybe/parameters/utils.py index c81b038ae..b6dc18c38 100644 --- a/baybe/parameters/utils.py +++ b/baybe/parameters/utils.py @@ -9,7 +9,6 @@ from baybe.parameters.base import Parameter from baybe.parameters.numerical import ( NumericalContinuousParameter, - _FixedNumericalContinuousParameter, ) from baybe.utils.interval import Interval @@ -98,7 +97,7 @@ def sort_parameters(parameters: Collection[Parameter]) -> tuple[Parameter, ...]: def activate_parameter( parameter: NumericalContinuousParameter, thresholds: Interval, -) -> NumericalContinuousParameter | _FixedNumericalContinuousParameter: +) -> NumericalContinuousParameter: """Activates a given parameter by moving its bounds away from zero. Important: @@ -135,16 +134,21 @@ def activate_parameter( f"given." ) + # Note that the definition on the boundary (lower/upper threshold) is vague. + # The value on the lower/upper boundary is determined as within inactive_range; + # while an activated parameter may take this boundary value (lower/upper + # threshold). We allow the misuse of boundary in the "in_inactive_range" and it + # is just an utils for checking condition. Ultimately, the "key" threshold + # boundary appears as a bound of the activated parameter and this is compatible + # with the thresholds defined in ContinuousCardinalityConstraint, as long as the + # "key" threshold boundary is not zero. The "key" threshold boundary is always + # non-zero when the thresholds are inferred from the bounds of this parameter. + def in_inactive_range(x: float) -> bool: """Return true when x is within the inactive range.""" - if thresholds.lower == 0.0: - return thresholds.lower <= x < thresholds.upper - if thresholds.upper == 0.0: - return thresholds.lower < x <= thresholds.upper - return thresholds.lower < x < thresholds.upper - - # Note: When both bounds in inactive range. This step must be checked first to catch - # all possible cases when a parameter cannot be activated. + return thresholds.lower <= x <= thresholds.upper + + # When both bounds in inactive range. if in_inactive_range(lower_bound) and in_inactive_range(upper_bound): raise ValueError( f"Parameter '{parameter.name}' cannot be set active since its " @@ -157,20 +161,10 @@ def in_inactive_range(x: float) -> bool: if lower_bound < thresholds.lower and in_inactive_range(upper_bound): return evolve(parameter, bounds=(lower_bound, thresholds.lower)) - if lower_bound == thresholds.lower and in_inactive_range(upper_bound): - return _FixedNumericalContinuousParameter( - name=parameter.name, value=thresholds.lower - ) - # When the lower bound is in inactive range, move it to the upper threshold of # the inactive region if upper_bound > thresholds.upper and in_inactive_range(lower_bound): return evolve(parameter, bounds=(thresholds.upper, upper_bound)) - if upper_bound == thresholds.upper and in_inactive_range(lower_bound): - return _FixedNumericalContinuousParameter( - name=parameter.name, value=thresholds.upper - ) - # Both bounds separated from inactive range return parameter diff --git a/tests/utils/test_parameters.py b/tests/utils/test_parameters.py index b038e4e9d..2c4a3af8c 100644 --- a/tests/utils/test_parameters.py +++ b/tests/utils/test_parameters.py @@ -22,100 +22,54 @@ def mirror_interval(interval: Interval) -> Interval: "expected_bounds", ), [ - # one-side bounds, two-side thresholds param( - Interval(lower=0.0, upper=1.0), - Interval(lower=-1.0, upper=1.5), + Interval(lower=-1.0, upper=1.0), + Interval(lower=-1.0, upper=1.0), False, None, - id="oneside_bounds_in_twoside_thresholds", + id="bounds_on_thresholds", ), param( - Interval(lower=0.0, upper=1.0), Interval(lower=-1.0, upper=1.0), - True, - Interval(lower=1.0, upper=1.0), - id="oneside_bounds_in_twoside_thresholds_fixed_value", - ), - param( - Interval(lower=0.0, upper=1.0), - Interval(lower=-1.0, upper=0.5), - True, - Interval(lower=0.5, upper=1.0), - id="oneside_bounds_intersected_with_twoside_thresholds", - ), - # one-side bounds, one-side thresholds - param( - Interval(lower=0.0, upper=1.0), - Interval(lower=-1.0, upper=0.0), - True, - Interval(lower=0.0, upper=1.0), - id="oneside_bounds_intersected_on_single_point_with_oneside_thresholds", - ), - param( - Interval(lower=0.0, upper=1.0), - Interval(lower=0.0, upper=0.5), - True, - Interval(lower=0.5, upper=1.0), - id="oneside_bounds_cover_oneside_thresholds", - ), - param( - Interval(lower=0.0, upper=1.0), - Interval(lower=0.0, upper=1.0), - True, - Interval(lower=1.0, upper=1.0), - id="oneside_bounds_match_oneside_thresholds", - ), - param( - Interval(lower=0.0, upper=1.0), - Interval(lower=0.0, upper=1.1), + Interval(lower=-1.5, upper=1.5), False, None, - id="oneside_bounds_in_oneside_thresholds", + id="bounds_in_thresholds", ), - # Two-side bounds. One-side thresholds do not differ from two-side threshold - # in these cases. Hence, use two-side thresholds. param( - Interval(lower=-0.5, upper=1.0), - Interval(lower=-1.0, upper=1.1), + Interval(lower=-1.0, upper=1.0), + Interval(lower=-1.5, upper=1.0), False, None, - id="twoside_bounds_in_twoside_thresholds", + id="bounds_in_thresholds_single_side_match", ), param( - Interval(lower=-0.5, upper=1.0), - Interval(lower=-0.5, upper=1.0), + Interval(lower=-1.0, upper=1.0), + Interval(lower=-0.5, upper=0.5), True, - Interval(lower=-0.5, upper=1.0), - id="twoside_bounds_match_twoside_thresholds", + Interval(lower=-1.0, upper=1.0), + id="thresholds_in_bounds", ), param( - Interval(lower=-0.6, upper=1.1), + Interval(lower=-1.0, upper=1.0), Interval(lower=-0.5, upper=1.0), True, - Interval(lower=-0.6, upper=1.1), - id="twoside_bounds_cover_twoside_thresholds", - ), - param( - Interval(lower=-0.6, upper=1.1), - Interval(lower=-1.0, upper=0.5), - True, - Interval(lower=0.5, upper=1.1), - id="twoside_bounds_intersected_with_twoside_thresholds", + Interval(lower=-1.0, upper=-0.5), + id="thresholds_in_bounds_single_side_match", ), param( - Interval(lower=-0.6, upper=0.5), + Interval(lower=-0.5, upper=1.0), Interval(lower=-1.0, upper=0.5), True, - Interval(lower=0.5, upper=0.5), - id="twoside_bounds_partial_in_twoside_thresholds", + Interval(lower=0.5, upper=1.0), + id="bounds_intersected_with_thresholds", ), param( - Interval(lower=-1.0, upper=0.5), - Interval(lower=-0.6, upper=0.5), + Interval(lower=0.0, upper=1.0), + Interval(lower=-1.0, upper=0.0), True, - Interval(lower=-1.0, upper=0.5), - id="twoside_bounds_partial_cover_twoside_thresholds", + Interval(lower=0.0, upper=1.0), + id="bounds_intersected_with_thresholds_on_one_point", ), ], )