
Fix some pre-commit errors #123

Merged · 15 commits · May 30, 2024
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -73,7 +73,7 @@ repository = "https://github.com/CAREamics/careamics"
line-length = 88
target-version = "py38"
src = ["src"]
-select = [
+lint.select = [
"E", # style errors
"W", # style warnings
"F", # flakes
@@ -86,7 +86,7 @@ select = [
"A001", # flake8-builtins
"RUF", # ruff-specific rules
]
-ignore = [
+lint.ignore = [
"D100", # Missing docstring in public module
"D107", # Missing docstring in __init__
"D203", # 1 blank line required before class docstring
@@ -103,13 +103,13 @@ ignore = [
"UP006", # Replace typing.List by list, mandatory for py3.8
"UP007", # Replace Union by |, mandatory for py3.9
]
-ignore-init-module-imports = true
+lint.ignore-init-module-imports = true
show-fixes = true

-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
convention = "numpy"

-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
"setup.py" = ["D"]

13 changes: 10 additions & 3 deletions src/careamics/callbacks/hyperparameters_callback.py
@@ -1,3 +1,5 @@
+"""Callback saving CAREamics configuration as hyperparameters in the model."""
+
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback

@@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
This allows saving the configuration as dictionary in the checkpoints, and
loading it subsequently in a CAREamist instance.

+Parameters
+----------
+config : Configuration
+CAREamics configuration to be saved as hyperparameter in the model.
+
Attributes
----------
config : Configuration
CAREamics configuration to be saved as hyperparameter in the model.
"""

-def __init__(self, config: Configuration):
+def __init__(self, config: Configuration) -> None:
"""
Constructor.

@@ -28,14 +35,14 @@ """
"""
self.config = config

-def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
+def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""
Update the hyperparameters of the model with the configuration on train start.

Parameters
----------
trainer : Trainer
-PyTorch Lightning trainer.
+PyTorch Lightning trainer, unused.
pl_module : LightningModule
PyTorch Lightning module.
"""
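Reviewer note: for readers new to this callback, here is a minimal usage sketch of the pattern it implements (stashing a configuration in the checkpoint's hyperparameters). `SaveConfigCallback` and the plain `dict` config are illustrative stand-ins, not the CAREamics API, and the `hparams.update` call is an assumption about how the stored config reaches the checkpoint:

```python
# Hypothetical sketch of the hyperparameter-saving pattern shown above.
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback


class SaveConfigCallback(Callback):
    """Copy a configuration dict into the module's hparams at train start."""

    def __init__(self, config: dict) -> None:
        self.config = config

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # Lightning serializes non-empty hparams into every checkpoint it
        # writes, so the config can be recovered on reload.
        pl_module.hparams.update(self.config)


# usage: Trainer(callbacks=[SaveConfigCallback({"algorithm": "n2v"})])
```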
41 changes: 37 additions & 4 deletions src/careamics/callbacks/progress_bar_callback.py
@@ -1,3 +1,5 @@
+"""Progressbar callback."""
+
import sys
from typing import Dict, Union

@@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
"""Progress bar for training and validation steps."""

def init_train_tqdm(self) -> tqdm:
-"""Override this to customize the tqdm bar for training."""
+"""Override this to customize the tqdm bar for training.
+
+Returns
+-------
+tqdm
+A tqdm bar.
+"""
bar = tqdm(
desc="Training",
position=(2 * self.process_position),
@@ -23,7 +31,13 @@ def init_train_tqdm(self) -> tqdm:
return bar

def init_validation_tqdm(self) -> tqdm:
-"""Override this to customize the tqdm bar for validation."""
+"""Override this to customize the tqdm bar for validation.
+
+Returns
+-------
+tqdm
+A tqdm bar.
+"""
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.train_progress_bar is not None
bar = tqdm(
@@ -37,7 +51,13 @@ def init_validation_tqdm(self) -> tqdm:
return bar

def init_test_tqdm(self) -> tqdm:
-"""Override this to customize the tqdm bar for testing."""
+"""Override this to customize the tqdm bar for testing.
+
+Returns
+-------
+tqdm
+A tqdm bar.
+"""
bar = tqdm(
desc="Testing",
position=(2 * self.process_position),
@@ -52,6 +72,19 @@ def init_test_tqdm(self) -> tqdm:
def get_metrics(
self, trainer: Trainer, pl_module: LightningModule
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
-"""Override this to customize the metrics displayed in the progress bar."""
+"""Override this to customize the metrics displayed in the progress bar.
+
+Parameters
+----------
+trainer : Trainer
+The trainer object.
+pl_module : LightningModule
+The LightningModule object, unused.
+
+Returns
+-------
+dict
+A dictionary with the metrics to display in the progress bar.
+"""
pbar_metrics = trainer.progress_bar_metrics
return {**pbar_metrics}
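Reviewer note: the new docstring spells out that `get_metrics` simply forwards `trainer.progress_bar_metrics`. For context, the usual reason to override this hook is to prune noisy entries from the bar; a generic sketch of that pattern (not the CAREamics code):

```python
# Generic TQDMProgressBar sketch: hide the logger's `v_num` entry.
from pytorch_lightning.callbacks import TQDMProgressBar


class QuietProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, pl_module):
        metrics = super().get_metrics(trainer, pl_module)
        metrics.pop("v_num", None)  # drop the version number from the bar
        return metrics
```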
8 changes: 5 additions & 3 deletions src/careamics/config/algorithm_model.py
@@ -1,3 +1,5 @@
+"""Algorithm configuration."""
+
from __future__ import annotations

from pprint import pformat
@@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel):
training algorithm: which algorithm, loss function, model architecture, optimizer,
and learning rate scheduler to use.

-Currently, we only support N2V and custom algorithms. The `n2v` algorithm is only
-compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm allows
-you to register your own architecture and select it using its name as
+Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is
+only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm
+allows you to register your own architecture and select it using its name as
`name` in the custom pydantic model.

Attributes
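Reviewer note: the compatibility rule this docstring states (`n2v` only with the `n2v` loss and a `UNet` architecture) is the kind of constraint a Pydantic validator enforces. An illustrative sketch with assumed field and value names, not the actual `AlgorithmConfig`:

```python
# Illustrative sketch of an algorithm/loss compatibility check.
from typing import Literal

from pydantic import BaseModel, model_validator


class AlgorithmSketch(BaseModel):
    algorithm: Literal["n2v", "care", "n2n", "custom"]
    loss: Literal["n2v", "mae", "mse"]

    @model_validator(mode="after")
    def algorithm_cross_validation(self) -> "AlgorithmSketch":
        # reject incompatible combinations at construction time
        if self.algorithm == "n2v" and self.loss != "n2v":
            raise ValueError("The `n2v` algorithm requires the `n2v` loss.")
        return self
```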
7 changes: 7 additions & 0 deletions src/careamics/config/architectures/architecture_model.py
@@ -1,3 +1,5 @@
+"""Base model for the various CAREamics architectures."""
+
from typing import Any, Dict

from pydantic import BaseModel
@@ -16,6 +18,11 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
"""
Dump the model as a dictionary, ignoring the architecture keyword.

+Parameters
+----------
+**kwargs : Any
+Additional keyword arguments from Pydantic BaseModel model_dump method.
+
Returns
-------
dict[str, Any]
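Reviewer note: the newly documented `**kwargs` feed straight into Pydantic's own `model_dump`. A self-contained sketch of the override pattern this class uses, with illustrative field names rather than the real CAREamics schema:

```python
# Sketch: strip a discriminator field when dumping a Pydantic model.
from typing import Any, Dict

from pydantic import BaseModel


class ArchSketch(BaseModel):
    architecture: str  # used only to discriminate the architecture union
    depth: int = 2

    def model_dump(self, **kwargs: Any) -> Dict[str, Any]:
        dump = super().model_dump(**kwargs)
        dump.pop("architecture", None)  # drop the discriminator keyword
        return dump


print(ArchSketch(architecture="UNet").model_dump())  # {'depth': 2}
```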
9 changes: 8 additions & 1 deletion src/careamics/config/architectures/custom_model.py
@@ -1,3 +1,5 @@
+"""Custom architecture Pydantic model."""
+
from __future__ import annotations

from pprint import pformat
@@ -84,6 +86,11 @@ def custom_model_is_known(cls, value: str) -> str:
value : str
Name of the custom model as registered using the `@register_model`
decorator.

+Returns
+-------
+str
+The custom model name.
"""
# delegate error to get_custom_model
model = get_custom_model(value)
@@ -134,7 +141,7 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]:

Parameters
----------
-kwargs : Any
+**kwargs : Any
Additional keyword arguments from Pydantic BaseModel model_dump method.

Returns
4 changes: 3 additions & 1 deletion src/careamics/config/architectures/register_model.py
@@ -1,3 +1,5 @@
+"""Custom model registration utilities."""
+
from typing import Callable

from torch.nn import Module
@@ -53,7 +55,7 @@ def add_custom_model(model: Module) -> Module:
Parameters
----------
model : Module
-Module class to register
+Module class to register.

Returns
-------
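Reviewer note: for readers unfamiliar with `@register_model`, this is a standard registry-decorator pattern; a compact, illustrative sketch (registry name and example class are assumptions, not the exact CAREamics implementation):

```python
# Registry-decorator sketch: record Module subclasses under a string name so
# configurations can refer to a model by `name`.
from typing import Callable, Dict, Type

from torch.nn import Module

CUSTOM_MODELS: Dict[str, Type[Module]] = {}


def register_model(name: str) -> Callable:
    def add_custom_model(model: Type[Module]) -> Type[Module]:
        CUSTOM_MODELS[name] = model  # configs can now select it by string
        return model

    return add_custom_model


@register_model(name="identity")
class Identity(Module):
    def forward(self, x):
        return x
```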
2 changes: 2 additions & 0 deletions src/careamics/config/architectures/unet_model.py
@@ -1,3 +1,5 @@
+"""UNet Pydantic model."""
+
from __future__ import annotations

from typing import Literal
2 changes: 2 additions & 0 deletions src/careamics/config/architectures/vae_model.py
@@ -1,3 +1,5 @@
+"""VAE Pydantic model."""
+
from typing import Literal

from pydantic import (
18 changes: 3 additions & 15 deletions src/careamics/config/callback_model.py
@@ -1,4 +1,4 @@
-"""Checkpoint saving configuration."""
+"""Callback Pydantic models."""

from __future__ import annotations

@@ -13,13 +13,7 @@


class CheckpointModel(BaseModel):
-"""_summary_.
-
-Parameters
-----------
-BaseModel : _type_
-_description_
-"""
+"""Checkpoint saving callback Pydantic model."""

model_config = ConfigDict(
validate_assignment=True,
@@ -46,13 +40,7 @@


class EarlyStoppingModel(BaseModel):
-"""_summary_.
-
-Parameters
-----------
-BaseModel : _type_
-_description_
-"""
+"""Early stopping callback Pydantic model."""

model_config = ConfigDict(
validate_assignment=True,
6 changes: 4 additions & 2 deletions src/careamics/config/configuration_example.py
@@ -1,3 +1,5 @@
+"""Example of configurations."""
+
from .algorithm_model import AlgorithmConfig
from .architectures import UNetModel
from .configuration_model import Configuration
@@ -19,7 +21,7 @@


def full_configuration_example() -> Configuration:
-"""Returns a dictionary representing a full configuration example.
+"""Return a dictionary representing a full configuration example.

Returns
-------
@@ -56,7 +58,7 @@ def full_configuration_example() -> Configuration:
"name": SupportedTransform.NORMALIZE.value,
},
{
-"name": SupportedTransform.NDFLIP.value,
+"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
4 changes: 2 additions & 2 deletions src/careamics/config/configuration_factory.py
@@ -111,7 +111,7 @@ def _create_supervised_configuration(
"name": SupportedTransform.NORMALIZE.value,
},
{
-"name": SupportedTransform.NDFLIP.value,
+"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -526,7 +526,7 @@ def create_n2v_configuration(
"name": SupportedTransform.NORMALIZE.value,
},
{
-"name": SupportedTransform.NDFLIP.value,
+"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
14 changes: 7 additions & 7 deletions src/careamics/config/data_model.py
@@ -17,14 +17,14 @@

from .support import SupportedTransform
from .transformations.n2v_manipulate_model import N2VManipulateModel
-from .transformations.nd_flip_model import NDFlipModel
from .transformations.normalize_model import NormalizeModel
+from .transformations.xy_flip_model import XYFlipModel
from .transformations.xy_random_rotate90_model import XYRandomRotate90Model
from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2

TRANSFORMS_UNION = Annotated[
Union[
-NDFlipModel,
+XYFlipModel,
XYRandomRotate90Model,
NormalizeModel,
N2VManipulateModel,
@@ -70,7 +70,7 @@ class DataConfig(BaseModel):
... "std": 47.2,
... },
... {
-... "name": "NDFlip",
+... "name": "XYFlip",
... }
... ]
... )
@@ -97,7 +97,7 @@
"name": SupportedTransform.NORMALIZE.value,
},
{
-"name": SupportedTransform.NDFLIP.value,
+"name": SupportedTransform.XY_FLIP.value,
},
{
"name": SupportedTransform.XY_RANDOM_ROTATE90.value,
@@ -202,7 +202,7 @@ def validate_prediction_transforms(

if SupportedTransform.N2V_MANIPULATE in transform_list:
# multiple N2V_MANIPULATE
-if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1:
+if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1:
raise ValueError(
f"Multiple instances of "
f"{SupportedTransform.N2V_MANIPULATE} transforms "
@@ -211,7 +211,7 @@

# N2V_MANIPULATE not the last transform
elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE:
-index = transform_list.index(SupportedTransform.N2V_MANIPULATE)
+index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value)
transform = transforms.pop(index)
transforms.append(transform)

@@ -250,7 +250,7 @@ def add_std_and_mean_to_normalize(self: Self) -> Self:
Self
Data model with mean and std added to the Normalize transform.
"""
-if self.mean is not None or self.std is not None:
+if self.mean is not None and self.std is not None:
# search in the transforms for Normalize and update parameters
for transform in self.transforms:
if transform.name == SupportedTransform.NORMALIZE.value:
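Reviewer note: the `or` to `and` change in the last hunk is a genuine bug fix: with `or`, a half-specified pair (say `mean` set but `std` still `None`) would be written into the Normalize transform. A reduced sketch of that validator, with models simplified for illustration (not the full `DataConfig`):

```python
# Reduced sketch of the fixed validator; field and class names simplified.
from typing import List, Optional

from pydantic import BaseModel, model_validator
from typing_extensions import Self


class NormalizeSketch(BaseModel):
    name: str = "Normalize"
    mean: Optional[float] = None
    std: Optional[float] = None


class DataSketch(BaseModel):
    mean: Optional[float] = None
    std: Optional[float] = None
    transforms: List[NormalizeSketch] = [NormalizeSketch()]

    @model_validator(mode="after")
    def add_std_and_mean_to_normalize(self: Self) -> Self:
        # `and`, not `or`: only propagate a complete mean/std pair
        if self.mean is not None and self.std is not None:
            for transform in self.transforms:
                if transform.name == "Normalize":
                    transform.mean = self.mean
                    transform.std = self.std
        return self
```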