Skip to content

Commit

Permalink
Correct logic on boundary handling in activate paramter
Browse files Browse the repository at this point in the history
  • Loading branch information
Waschenbacher committed Jan 13, 2025
1 parent ea8b5bc commit bb1fc3d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 88 deletions.
34 changes: 14 additions & 20 deletions baybe/parameters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from baybe.parameters.base import Parameter
from baybe.parameters.numerical import (
NumericalContinuousParameter,
_FixedNumericalContinuousParameter,
)
from baybe.utils.interval import Interval

Expand Down Expand Up @@ -98,7 +97,7 @@ def sort_parameters(parameters: Collection[Parameter]) -> tuple[Parameter, ...]:
def activate_parameter(
parameter: NumericalContinuousParameter,
thresholds: Interval,
) -> NumericalContinuousParameter | _FixedNumericalContinuousParameter:
) -> NumericalContinuousParameter:
"""Activates a given parameter by moving its bounds away from zero.
Important:
Expand Down Expand Up @@ -135,16 +134,21 @@ def activate_parameter(
f"given."
)

# Note that the definition on the boundary (lower/upper threshold) is vague.
# The value on the lower/upper boundary is determined as within inactive_range;
# while an activated parameter may take this boundary value (lower/upper
# threshold). We allow the misuse of boundary in the "in_inactive_range" and it
# is just an utils for checking condition. Ultimately, the "key" threshold
# boundary appears as a bound of the activated parameter and this is compatible
# with the thresholds defined in ContinuousCardinalityConstraint, as long as the
# "key" threshold boundary is not zero. The "key" threshold boundary is always
# non-zero when the thresholds are inferred from the bounds of this parameter.

def in_inactive_range(x: float) -> bool:
"""Return true when x is within the inactive range."""
if thresholds.lower == 0.0:
return thresholds.lower <= x < thresholds.upper
if thresholds.upper == 0.0:
return thresholds.lower < x <= thresholds.upper
return thresholds.lower < x < thresholds.upper

# Note: When both bounds in inactive range. This step must be checked first to catch
# all possible cases when a parameter cannot be activated.
return thresholds.lower <= x <= thresholds.upper

# When both bounds in inactive range.
if in_inactive_range(lower_bound) and in_inactive_range(upper_bound):
raise ValueError(
f"Parameter '{parameter.name}' cannot be set active since its "
Expand All @@ -157,20 +161,10 @@ def in_inactive_range(x: float) -> bool:
if lower_bound < thresholds.lower and in_inactive_range(upper_bound):
return evolve(parameter, bounds=(lower_bound, thresholds.lower))

if lower_bound == thresholds.lower and in_inactive_range(upper_bound):
return _FixedNumericalContinuousParameter(
name=parameter.name, value=thresholds.lower
)

# When the lower bound is in inactive range, move it to the upper threshold of
# the inactive region
if upper_bound > thresholds.upper and in_inactive_range(lower_bound):
return evolve(parameter, bounds=(thresholds.upper, upper_bound))

if upper_bound == thresholds.upper and in_inactive_range(lower_bound):
return _FixedNumericalContinuousParameter(
name=parameter.name, value=thresholds.upper
)

# Both bounds separated from inactive range
return parameter
90 changes: 22 additions & 68 deletions tests/utils/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,100 +22,54 @@ def mirror_interval(interval: Interval) -> Interval:
"expected_bounds",
),
[
# one-side bounds, two-side thresholds
param(
Interval(lower=0.0, upper=1.0),
Interval(lower=-1.0, upper=1.5),
Interval(lower=-1.0, upper=1.0),
Interval(lower=-1.0, upper=1.0),
False,
None,
id="oneside_bounds_in_twoside_thresholds",
id="bounds_on_thresholds",
),
param(
Interval(lower=0.0, upper=1.0),
Interval(lower=-1.0, upper=1.0),
True,
Interval(lower=1.0, upper=1.0),
id="oneside_bounds_in_twoside_thresholds_fixed_value",
),
param(
Interval(lower=0.0, upper=1.0),
Interval(lower=-1.0, upper=0.5),
True,
Interval(lower=0.5, upper=1.0),
id="oneside_bounds_intersected_with_twoside_thresholds",
),
# one-side bounds, one-side thresholds
param(
Interval(lower=0.0, upper=1.0),
Interval(lower=-1.0, upper=0.0),
True,
Interval(lower=0.0, upper=1.0),
id="oneside_bounds_intersected_on_single_point_with_oneside_thresholds",
),
param(
Interval(lower=0.0, upper=1.0),
Interval(lower=0.0, upper=0.5),
True,
Interval(lower=0.5, upper=1.0),
id="oneside_bounds_cover_oneside_thresholds",
),
param(
Interval(lower=0.0, upper=1.0),
Interval(lower=0.0, upper=1.0),
True,
Interval(lower=1.0, upper=1.0),
id="oneside_bounds_match_oneside_thresholds",
),
param(
Interval(lower=0.0, upper=1.0),
Interval(lower=0.0, upper=1.1),
Interval(lower=-1.5, upper=1.5),
False,
None,
id="oneside_bounds_in_oneside_thresholds",
id="bounds_in_thresholds",
),
# Two-side bounds. One-side thresholds do not differ from two-side threshold
# in these cases. Hence, use two-side thresholds.
param(
Interval(lower=-0.5, upper=1.0),
Interval(lower=-1.0, upper=1.1),
Interval(lower=-1.0, upper=1.0),
Interval(lower=-1.5, upper=1.0),
False,
None,
id="twoside_bounds_in_twoside_thresholds",
id="bounds_in_thresholds_single_side_match",
),
param(
Interval(lower=-0.5, upper=1.0),
Interval(lower=-0.5, upper=1.0),
Interval(lower=-1.0, upper=1.0),
Interval(lower=-0.5, upper=0.5),
True,
Interval(lower=-0.5, upper=1.0),
id="twoside_bounds_match_twoside_thresholds",
Interval(lower=-1.0, upper=1.0),
id="thresholds_in_bounds",
),
param(
Interval(lower=-0.6, upper=1.1),
Interval(lower=-1.0, upper=1.0),
Interval(lower=-0.5, upper=1.0),
True,
Interval(lower=-0.6, upper=1.1),
id="twoside_bounds_cover_twoside_thresholds",
),
param(
Interval(lower=-0.6, upper=1.1),
Interval(lower=-1.0, upper=0.5),
True,
Interval(lower=0.5, upper=1.1),
id="twoside_bounds_intersected_with_twoside_thresholds",
Interval(lower=-1.0, upper=-0.5),
id="thresholds_in_bounds_single_side_match",
),
param(
Interval(lower=-0.6, upper=0.5),
Interval(lower=-0.5, upper=1.0),
Interval(lower=-1.0, upper=0.5),
True,
Interval(lower=0.5, upper=0.5),
id="twoside_bounds_partial_in_twoside_thresholds",
Interval(lower=0.5, upper=1.0),
id="bounds_intersected_with_thresholds",
),
param(
Interval(lower=-1.0, upper=0.5),
Interval(lower=-0.6, upper=0.5),
Interval(lower=0.0, upper=1.0),
Interval(lower=-1.0, upper=0.0),
True,
Interval(lower=-1.0, upper=0.5),
id="twoside_bounds_partial_cover_twoside_thresholds",
Interval(lower=0.0, upper=1.0),
id="bounds_intersected_with_thresholds_on_one_point",
),
],
)
Expand Down

0 comments on commit bb1fc3d

Please sign in to comment.