diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c02d5a2..c261f9cc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 23.12.1 hooks: - id: black name: Black code @@ -64,7 +64,7 @@ repos: - id: yesqa - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.6 + rev: v0.1.11 hooks: - id: ruff args: ["--fix"] diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py index f618a71e..047806f7 100644 --- a/src/pytorch_tabular/tabular_model_sweep.py +++ b/src/pytorch_tabular/tabular_model_sweep.py @@ -76,9 +76,10 @@ def _validate_args( assert all( isinstance(m, (str, ModelConfig)) for m in model_list ), f"models must be a list of strings or ModelConfigs, but got {model_list}" - assert all( - task == m.task for m in model_list if isinstance(m, ModelConfig) - ), f"task must be the same as the task in ModelConfig, but got {task} and {[m.task for m in model_list if isinstance(m, ModelConfig)]}" + assert all(task == m.task for m in model_list if isinstance(m, ModelConfig)), ( + "task must be the same as the task in ModelConfig," + f" but got {task} and {[m.task for m in model_list if isinstance(m, ModelConfig)]}" + ) if metrics is not None: assert isinstance(metrics, list), f"metrics must be a list of strings or callables, but got {type(metrics)}" assert all( @@ -154,9 +155,9 @@ def model_sweep( trainer_config (Union[TrainerConfig, str]): TrainerConfig object or path to the yaml file. - model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare. This can be one of - the presets defined in ``pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS`` or a list of ``ModelConfig`` objects. - Defaults to "lite". + model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare. + This can be one of the presets defined in ``pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS`` + or a list of ``ModelConfig`` objects. Defaults to "lite". metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics should be one of the functional metrics implemented in ``torchmetrics``. By default, it is