Skip to content

Commit

Permalink
Break up task validation and setting defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Nov 30, 2024
1 parent 18a91ce commit 5e2b7c0
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 61 deletions.
84 changes: 47 additions & 37 deletions amlb/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,30 +229,19 @@ def _benchmark_definition(
)
for task in tasks:
task |= defaults # add missing keys from hard defaults + defaults
Resources._validate_task(task, config_)
Resources._validate_task(task)
Resources._add_task_defaults(task, config_)

Resources._validate_task(defaults, config_, lenient=True)
Resources._add_task_defaults(defaults, config_)
defaults.enabled = False
tasks.append(defaults)
log.debug("Available task definitions:\n%s", tasks)
return tasks, benchmark_name, benchmark_path

@staticmethod
def _validate_task(task: Namespace, config_: Namespace, lenient: bool = False):
if not lenient and task["name"] is None:
raise ValueError(
f"`name` is mandatory but missing in task definition {task}."
)

def _add_task_defaults(task: Namespace, config_: Namespace):
if task["id"] is None:
task["id"] = Resources.generate_task_identifier(task)
if not lenient and task["id"] is None:
raise ValueError(
"task definition must contain an ID or one property "
"among ['openml_task_id', 'dataset'] to create an ID, "
"but task definition is {task}".format(task=str(task))
)

for conf in [
"max_runtime_seconds",
"cores",
Expand All @@ -265,41 +254,62 @@ def _validate_task(task: Namespace, config_: Namespace, lenient: bool = False):
task[conf] = config_.benchmarks.defaults[conf]
log.debug(
"Config `{config}` not set for task {name}, using default `{value}`.".format(
config=conf, name=task.name, value=task[conf]
config=conf, name=task["name"], value=task[conf]
)
)

conf = "ec2_instance_type"
if task[conf] is None:
i_series = config_.aws.ec2.instance_type.series
i_map = config_.aws.ec2.instance_type.map
if str(task.cores) in i_map:
i_size = i_map[str(task.cores)]
elif task.cores > 0:
supported_cores = list(
map(int, Namespace.dict(i_map).keys() - {"default"})
)
supported_cores.sort()
cores = next((c for c in supported_cores if c >= task.cores), "default")
i_size = i_map[str(cores)]
else:
i_size = i_map.default
task[conf] = ".".join([i_series, i_size])
if task["ec2_instance_type"] is None:
task["ec2_instance_type"] = Resources.lookup_ec2_instance_type(
config_, task.cores
)
log.debug(
"Config `{config}` not set for task {name}, using default selection `{value}`.".format(
config=conf, name=task.name, value=task[conf]
config=conf, name=task["name"], value=task["ec2_instance_type"]
)
)

conf = "ec2_volume_type"
if task[conf] is None:
task[conf] = config_.aws.ec2.volume_type
if task["ec2_volume_type"] is None:
task["ec2_volume_type"] = config_.aws.ec2.volume_type
log.debug(
"Config `{config}` not set for task {name}, using default `{value}`.".format(
config=conf, name=task.name, value=task[conf]
config=conf, name=task["name"], value=task["ec2_volume_type"]
)
)

@staticmethod
def _validate_task(task: Namespace) -> None:
"""Raises ValueError if task does not have a name and a way to generate an identifier."""
if task["name"] is None:
raise ValueError(
f"`name` is mandatory but missing in task definition {task}."
)
task_id = Namespace.get(task, "id", Resources.generate_task_identifier(task))
if task_id is None:
raise ValueError(
"task definition must contain an ID or one property "
"among ['openml_task_id', 'dataset'] to create an ID, "
"but task definition is {task}".format(task=str(task))
)

@staticmethod
def lookup_ec2_instance_type(config_: Namespace, cores: int) -> str:
i_series = config_.aws.ec2.instance_type.series
i_map = config_.aws.ec2.instance_type.map
i_size = Resources.lookup_suitable_instance_size(i_map, cores)
return f"{i_series}.{i_size}"

@staticmethod
def lookup_suitable_instance_size(cores_to_size: Namespace, cores: int) -> str:
if str(cores) in cores_to_size:
return cores_to_size[str(cores)]

supported_cores = list(map(int, set(dir(cores_to_size)) - {"default"}))
if cores <= 0 or cores > max(supported_cores):
return cores_to_size.default

cores = next((c for c in sorted(supported_cores) if c >= cores), "default")
return cores_to_size[str(cores)]

@staticmethod
def generate_task_identifier(task: Namespace) -> str | None:
if task["openml_task_id"] is not None:
Expand Down
39 changes: 15 additions & 24 deletions tests/unit/amlb/resources/test_benchmark_definition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from functools import partial

import pytest

from amlb import Resources
Expand Down Expand Up @@ -34,20 +32,13 @@ def amlb_dummy_configuration() -> Namespace:

def test_validate_task_strict_requires_name():
with pytest.raises(ValueError) as excinfo:
Resources._validate_task(
task=Namespace(),
config_=Namespace(),
lenient=False,
)
Resources._validate_task(task=Namespace())
assert "mandatory but missing" in excinfo.value.args[0]


def test_validate_task_strict_requires_id(amlb_dummy_configuration: Namespace):
strict_validate = partial(
Resources._validate_task, config_=amlb_dummy_configuration, lenient=False
)
with pytest.raises(ValueError) as excinfo:
strict_validate(task=Namespace(name="foo"))
Resources._validate_task(task=Namespace(name="foo"))
assert "must contain an ID or one property" in excinfo.value.args[0]


Expand All @@ -61,27 +52,27 @@ def test_validate_task_strict_requires_id(amlb_dummy_configuration: Namespace):
(Namespace(dataset=Namespace(id="bar")), "bar"),
],
)
def test_validate_task_id_formatting(
def test_add_task_defaults_formatting(
properties: Namespace, expected: str, amlb_dummy_configuration: Namespace
):
task = Namespace(name="foo") | properties
Resources._validate_task(task=task, config_=amlb_dummy_configuration)
Resources._add_task_defaults(task=task, config_=amlb_dummy_configuration)
assert task["id"] == expected


def test_validate_task_adds_benchmark_defaults(amlb_dummy_configuration: Namespace):
task = Namespace(name=None)
Resources._validate_task(task, amlb_dummy_configuration, lenient=True)
def test_add_task_defaults_sets_benchmark_defaults(amlb_dummy_configuration: Namespace):
task = Namespace()
Resources._add_task_defaults(task, amlb_dummy_configuration)

config = Namespace.dict(amlb_dummy_configuration, deep=True)
for setting, default in config["benchmarks"]["defaults"].items():
assert task[setting] == default
assert task["ec2_volume_type"] == amlb_dummy_configuration.aws.ec2.volume_type


def test_validate_task_does_not_overwrite(amlb_dummy_configuration: Namespace):
task = Namespace(name=None, cores=42)
Resources._validate_task(task, amlb_dummy_configuration, lenient=True)
def test_add_task_defaults_does_not_overwrite(amlb_dummy_configuration: Namespace):
task = Namespace(cores=42)
Resources._add_task_defaults(task, amlb_dummy_configuration)

config = Namespace.dict(amlb_dummy_configuration, deep=True)
assert task.cores == 42
Expand All @@ -90,31 +81,31 @@ def test_validate_task_does_not_overwrite(amlb_dummy_configuration: Namespace):
assert task[setting] == default


def test_validate_task_looks_up_instance_type(amlb_dummy_configuration: Namespace):
def test_add_task_defaults_looks_up_instance_type(amlb_dummy_configuration: Namespace):
instance_type = amlb_dummy_configuration.aws.ec2.instance_type
reverse_size_map = {v: k for k, v in Namespace.dict(instance_type.map).items()}
n_cores_for_small = int(reverse_size_map["small"])

task = Namespace(name="foo", cores=n_cores_for_small)
Resources._validate_task(task, amlb_dummy_configuration, lenient=True)
Resources._add_task_defaults(task, amlb_dummy_configuration)
assert (
task["ec2_instance_type"] == "m5.small"
), "Should resolve to the instance type with the exact amount of cores"

task = Namespace(name="foo", cores=n_cores_for_small - 1)
Resources._validate_task(task, amlb_dummy_configuration, lenient=True)
Resources._add_task_defaults(task, amlb_dummy_configuration)
assert (
task["ec2_instance_type"] == "m5.small"
), "If exact amount of cores are not available, should resolve to next biggest"

task = Namespace(name="foo", cores=n_cores_for_small + 1)
Resources._validate_task(task, amlb_dummy_configuration, lenient=True)
Resources._add_task_defaults(task, amlb_dummy_configuration)
assert (
task["ec2_instance_type"] == "m5.large"
), "If bigger than largest in map, should revert to default"

task = Namespace(name="foo", ec2_instance_type="bar")
Resources._validate_task(task, amlb_dummy_configuration, lenient=True)
Resources._add_task_defaults(task, amlb_dummy_configuration)
assert (
task["ec2_instance_type"] == "bar"
), "Should not overwrite explicit configuration"

0 comments on commit 5e2b7c0

Please sign in to comment.