Skip to content

Commit

Permalink
Fix some pre-commit errors (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored May 30, 2024
2 parents eea3ecd + 10af1b4 commit 99d04c4
Show file tree
Hide file tree
Showing 78 changed files with 878 additions and 1,155 deletions.
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ repository = "https://github.com/CAREamics/careamics"
line-length = 88
target-version = "py38"
src = ["src"]
select = [
lint.select = [
"E", # style errors
"W", # style warnings
"F", # flakes
Expand All @@ -86,7 +86,7 @@ select = [
"A001", # flake8-builtins
"RUF", # ruff-specific rules
]
ignore = [
lint.ignore = [
"D100", # Missing docstring in public module
"D107", # Missing docstring in __init__
"D203", # 1 blank line required before class docstring
Expand All @@ -103,13 +103,13 @@ ignore = [
"UP006", # Replace typing.List by list, mandatory for py3.8
"UP007", # Replace Union by |, mandatory for py3.9
]
ignore-init-module-imports = true
lint.ignore-init-module-imports = true
show-fixes = true

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
"setup.py" = ["D"]

Expand Down
13 changes: 10 additions & 3 deletions src/careamics/callbacks/hyperparameters_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Callback saving CAREamics configuration as hyperparameters in the model."""

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback

Expand All @@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
This allows saving the configuration as dictionnary in the checkpoints, and
loading it subsequently in a CAREamist instance.
Parameters
----------
config : Configuration
CAREamics configuration to be saved as hyperparameter in the model.
Attributes
----------
config : Configuration
CAREamics configuration to be saved as hyperparameter in the model.
"""

def __init__(self, config: Configuration):
def __init__(self, config: Configuration) -> None:
"""
Constructor.
Expand All @@ -28,14 +35,14 @@ def __init__(self, config: Configuration):
"""
self.config = config

def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""
Update the hyperparameters of the model with the configuration on train start.
Parameters
----------
trainer : Trainer
PyTorch Lightning trainer.
PyTorch Lightning trainer, unused.
pl_module : LightningModule
PyTorch Lightning module.
"""
Expand Down
41 changes: 37 additions & 4 deletions src/careamics/callbacks/progress_bar_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Progressbar callback."""

import sys
from typing import Dict, Union

Expand All @@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
"""Progress bar for training and validation steps."""

def init_train_tqdm(self) -> tqdm:
"""Override this to customize the tqdm bar for training."""
"""Override this to customize the tqdm bar for training.
Returns
-------
tqdm
A tqdm bar.
"""
bar = tqdm(
desc="Training",
position=(2 * self.process_position),
Expand All @@ -23,7 +31,13 @@ def init_train_tqdm(self) -> tqdm:
return bar

def init_validation_tqdm(self) -> tqdm:
"""Override this to customize the tqdm bar for validation."""
"""Override this to customize the tqdm bar for validation.
Returns
-------
tqdm
A tqdm bar.
"""
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.train_progress_bar is not None
bar = tqdm(
Expand All @@ -37,7 +51,13 @@ def init_validation_tqdm(self) -> tqdm:
return bar

def init_test_tqdm(self) -> tqdm:
"""Override this to customize the tqdm bar for testing."""
"""Override this to customize the tqdm bar for testing.
Returns
-------
tqdm
A tqdm bar.
"""
bar = tqdm(
desc="Testing",
position=(2 * self.process_position),
Expand All @@ -52,6 +72,19 @@ def init_test_tqdm(self) -> tqdm:
def get_metrics(
self, trainer: Trainer, pl_module: LightningModule
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
"""Override this to customize the metrics displayed in the progress bar."""
"""Override this to customize the metrics displayed in the progress bar.
Parameters
----------
trainer : Trainer
The trainer object.
pl_module : LightningModule
The LightningModule object, unused.
Returns
-------
dict
A dictionary with the metrics to display in the progress bar.
"""
pbar_metrics = trainer.progress_bar_metrics
return {**pbar_metrics}
8 changes: 5 additions & 3 deletions src/careamics/config/algorithm_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Algorithm configuration."""

from __future__ import annotations

from pprint import pformat
Expand All @@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel):
training algorithm: which algorithm, loss function, model architecture, optimizer,
and learning rate scheduler to use.
Currently, we only support N2V and custom algorithms. The `n2v` algorithm is only
compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm allows
you to register your own architecture and select it using its name as
Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
allows you to register your own architecture and select it using its name as
`name` in the custom pydantic model.
Attributes
Expand Down
7 changes: 7 additions & 0 deletions src/careamics/config/architectures/architecture_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Base model for the various CAREamics architectures."""

from typing import Any, Dict

from pydantic import BaseModel
Expand All @@ -16,6 +18,11 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
"""
Dump the model as a dictionary, ignoring the architecture keyword.
Parameters
----------
**kwargs : Any
Additional keyword arguments from Pydantic BaseModel model_dump method.
Returns
-------
dict[str, Any]
Expand Down
9 changes: 8 additions & 1 deletion src/careamics/config/architectures/custom_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Custom architecture Pydantic model."""

from __future__ import annotations

from pprint import pformat
Expand Down Expand Up @@ -84,6 +86,11 @@ def custom_model_is_known(cls, value: str) -> str:
value : str
Name of the custom model as registered using the `@register_model`
decorator.
Returns
-------
str
The custom model name.
"""
# delegate error to get_custom_model
model = get_custom_model(value)
Expand Down Expand Up @@ -134,7 +141,7 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
Parameters
----------
kwargs : Any
**kwargs : Any
Additional keyword arguments from Pydantic BaseModel model_dump method.
Returns
Expand Down
4 changes: 3 additions & 1 deletion src/careamics/config/architectures/register_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Custom model registration utilities."""

from typing import Callable

from torch.nn import Module
Expand Down Expand Up @@ -53,7 +55,7 @@ def add_custom_model(model: Module) -> Module:
Parameters
----------
model : Module
Module class to register
Module class to register.
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions src/careamics/config/architectures/unet_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""UNet Pydantic model."""

from __future__ import annotations

from typing import Literal
Expand Down
2 changes: 2 additions & 0 deletions src/careamics/config/architectures/vae_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""VAE Pydantic model."""

from typing import Literal

from pydantic import (
Expand Down
18 changes: 3 additions & 15 deletions src/careamics/config/callback_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Checkpoint saving configuration."""
"""Callback Pydantic models."""

from __future__ import annotations

Expand All @@ -13,13 +13,7 @@


class CheckpointModel(BaseModel):
"""_summary_.
Parameters
----------
BaseModel : _type_
_description_
"""
"""Checkpoint saving callback Pydantic model."""

model_config = ConfigDict(
validate_assignment=True,
Expand All @@ -46,13 +40,7 @@ class CheckpointModel(BaseModel):


class EarlyStoppingModel(BaseModel):
"""_summary_.
Parameters
----------
BaseModel : _type_
_description_
"""
"""Early stopping callback Pydantic model."""

model_config = ConfigDict(
validate_assignment=True,
Expand Down
6 changes: 4 additions & 2 deletions src/careamics/config/configuration_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Example of configurations."""

from .algorithm_model import AlgorithmConfig
from .architectures import UNetModel
from .configuration_model import Configuration
Expand All @@ -19,7 +21,7 @@


def full_configuration_example() -> Configuration:
"""Returns a dictionnary representing a full configuration example.
"""Return a dictionnary representing a full configuration example.
Returns
-------
Expand Down Expand Up @@ -56,7 +58,7 @@ def full_configuration_example() -> Configuration:
"name": SupportedTransform.NORMALIZE.value,
},
{
"name": SupportedTransform.NDFLIP.value,
"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
Expand Down
4 changes: 2 additions & 2 deletions src/careamics/config/configuration_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _create_supervised_configuration(
"name": SupportedTransform.NORMALIZE.value,
},
{
"name": SupportedTransform.NDFLIP.value,
"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
Expand Down Expand Up @@ -526,7 +526,7 @@ def create_n2v_configuration(
"name": SupportedTransform.NORMALIZE.value,
},
{
"name": SupportedTransform.NDFLIP.value,
"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
Expand Down
14 changes: 7 additions & 7 deletions src/careamics/config/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

from .support import SupportedTransform
from .transformations.n2v_manipulate_model import N2VManipulateModel
from .transformations.nd_flip_model import NDFlipModel
from .transformations.normalize_model import NormalizeModel
from .transformations.xy_flip_model import XYFlipModel
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2

TRANSFORMS_UNION = Annotated[
Union[
NDFlipModel,
XYFlipModel,
XYRandomRotate90Model,
NormalizeModel,
N2VManipulateModel,
Expand Down Expand Up @@ -70,7 +70,7 @@ class DataConfig(BaseModel):
... "std": 47.2,
... },
... {
... "name": "NDFlip",
... "name": "XYFlip",
... }
... ]
... )
Expand All @@ -97,7 +97,7 @@ class DataConfig(BaseModel):
"name": SupportedTransform.NORMALIZE.value,
},
{
"name": SupportedTransform.NDFLIP.value,
"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
Expand Down Expand Up @@ -202,7 +202,7 @@ def validate_prediction_transforms(

if SupportedTransform.N2V_MANIPULATE in transform_list:
# multiple N2V_MANIPULATE
if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
raise ValueError(
f"Multiple instances of "
f"{SupportedTransform.N2V_MANIPULATE} transforms "
Expand All @@ -211,7 +211,7 @@ def validate_prediction_transforms(

# N2V_MANIPULATE not the last transform
elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
transform = transforms.pop(index)
transforms.append(transform)

Expand Down Expand Up @@ -250,7 +250,7 @@ def add_std_and_mean_to_normalize(self: Self) -> Self:
Self
Data model with mean and std added to the Normalize transform.
"""
if self.mean is not None or self.std is not None:
if self.mean is not None and self.std is not None:
# search in the transforms for Normalize and update parameters
for transform in self.transforms:
if transform.name == SupportedTransform.NORMALIZE.value:
Expand Down
Loading

0 comments on commit 99d04c4

Please sign in to comment.