From 2fcb0c8cede292578036d726aeab902e4ebd09d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Jan 2024 21:32:28 +0000 Subject: [PATCH 1/3] [pre-commit.ci] pre-commit suggestions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/psf/black: 23.11.0 → 23.12.1](https://github.com/psf/black/compare/23.11.0...23.12.1) - [github.com/astral-sh/ruff-pre-commit: v0.1.6 → v0.1.11](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.6...v0.1.11) - [github.com/pre-commit/mirrors-prettier: v3.1.0 → v4.0.0-alpha.8](https://github.com/pre-commit/mirrors-prettier/compare/v3.1.0...v4.0.0-alpha.8) --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c02d5a2..edbb3db1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -40,7 +40,7 @@ repos: args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 23.12.1 hooks: - id: black name: Black code @@ -64,13 +64,13 @@ repos: - id: yesqa - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.6 + rev: v0.1.11 hooks: - id: ruff args: ["--fix"] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.1.0 + rev: v4.0.0-alpha.8 hooks: - id: prettier # https://prettier.io/docs/en/options.html#print-width From f7a326315cef9685fa244d7c2e324c357f8d5323 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 12 Jan 2024 17:38:28 +0100 Subject: [PATCH 2/3] Apply suggestions from code review --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index edbb3db1..c261f9cc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -70,7 +70,7 @@ repos: args: ["--fix"] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 + rev: v3.1.0 hooks: - id: prettier # https://prettier.io/docs/en/options.html#print-width From e7c554b63d209bb0e6ea44cabd14e3d86a42c0db Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 12 Jan 2024 17:46:43 +0100 Subject: [PATCH 3/3] long lines --- src/pytorch_tabular/tabular_model_sweep.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/pytorch_tabular/tabular_model_sweep.py b/src/pytorch_tabular/tabular_model_sweep.py index f618a71e..047806f7 100644 --- a/src/pytorch_tabular/tabular_model_sweep.py +++ b/src/pytorch_tabular/tabular_model_sweep.py @@ -76,9 +76,10 @@ def _validate_args( assert all( isinstance(m, (str, ModelConfig)) for m in model_list ), f"models must be a list of strings or ModelConfigs, but got {model_list}" - assert all( - task == m.task for m in model_list if isinstance(m, ModelConfig) - ), f"task must be the same as the task in ModelConfig, but got {task} and {[m.task for m in model_list if isinstance(m, ModelConfig)]}" + assert all(task == m.task for m in model_list if isinstance(m, ModelConfig)), ( + "task must be the same as the task in ModelConfig," + f" but got {task} and {[m.task for m in model_list if isinstance(m, ModelConfig)]}" + ) if metrics is not None: assert isinstance(metrics, list), f"metrics must be a list of strings or callables, but got {type(metrics)}" assert all( @@ -154,9 +155,9 @@ def model_sweep( trainer_config (Union[TrainerConfig, str]): TrainerConfig object or path to the yaml file. - model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare. This can be one of - the presets defined in ``pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS`` or a list of ``ModelConfig`` objects. - Defaults to "lite". + model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare. + This can be one of the presets defined in ``pytorch_tabular.tabular_model_sweep.MODEL_SWEEP_PRESETS`` + or a list of ``ModelConfig`` objects. Defaults to "lite". metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics should be one of the functional metrics implemented in ``torchmetrics``. By default, it is