From a25e4b14c8145a20d4912d7c6abe301941c733bc Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 24 May 2024 19:38:55 +0200 Subject: [PATCH 01/13] (chore): update ruff lint namespace --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bdddf8a95..578abb4bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ repository = "https://github.com/CAREamics/careamics" line-length = 88 target-version = "py38" src = ["src"] -select = [ +lint.select = [ "E", # style errors "W", # style warnings "F", # flakes @@ -87,7 +87,7 @@ select = [ "A001", # flake8-builtins "RUF", # ruff-specific rules ] -ignore = [ +lint.ignore = [ "D100", # Missing docstring in public module "D107", # Missing docstring in __init__ "D203", # 1 blank line required before class docstring @@ -104,13 +104,13 @@ ignore = [ "UP006", # Replace typing.List by list, mandatory for py3.8 "UP007", # Replace Union by |, mandatory for py3.9 ] -ignore-init-module-imports = true +lint.ignore-init-module-imports = true show-fixes = true -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "numpy" -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "tests/*.py" = ["D", "S"] "setup.py" = ["D"] From e68e292f847439eaa1a537465a62e019948d0623 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 24 May 2024 19:56:43 +0200 Subject: [PATCH 02/13] (chore): fix some ruff errors --- src/careamics/conftest.py | 12 ++++++++++++ src/careamics/losses/__init__.py | 5 ++--- src/careamics/losses/losses.py | 3 --- src/careamics/models/unet.py | 3 +-- src/careamics/utils/receptive_field.py | 22 ++++++++++++++-------- 5 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/careamics/conftest.py b/src/careamics/conftest.py index a192658ca..815bbd188 100644 --- a/src/careamics/conftest.py +++ b/src/careamics/conftest.py @@ -14,6 +14,18 @@ @pytest.fixture(scope="module") def my_path(tmpdir_factory: TempPathFactory) -> Path: + """Fixture used in doctest to create a temporary directory. + + Parameters + ---------- + tmpdir_factory : TempPathFactory + Temporary path factory from pytest. + + Returns + ------- + Path + Temporary directory path. + """ return tmpdir_factory.mktemp("my_path") diff --git a/src/careamics/losses/__init__.py b/src/careamics/losses/__init__.py index 07a8972e0..11ce03389 100644 --- a/src/careamics/losses/__init__.py +++ b/src/careamics/losses/__init__.py @@ -1,6 +1,5 @@ """Losses module.""" -from .loss_factory import loss_factory +__all__ = ["loss_factory"] -# from .noise_model_factory import noise_model_factory as noise_model_factory -# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel +from .loss_factory import loss_factory diff --git a/src/careamics/losses/losses.py b/src/careamics/losses/losses.py index b69937424..c6c3234ee 100644 --- a/src/careamics/losses/losses.py +++ b/src/careamics/losses/losses.py @@ -5,9 +5,6 @@ """ import torch - -# TODO if we are only using the DiceLoss, can we just implement it? 
-# from segmentation_models_pytorch.losses import DiceLoss from torch.nn import L1Loss, MSELoss diff --git a/src/careamics/models/unet.py b/src/careamics/models/unet.py index 9167756b9..2a9198d8d 100644 --- a/src/careamics/models/unet.py +++ b/src/careamics/models/unet.py @@ -253,8 +253,7 @@ def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor: """ Splits the tensors `A` and `B` into equally sized groups along the channel axis (axis=1); then concatenates the groups in alternating - order along the channel axis, starting with the first group from tensor - A. + order along the channel axis, starting with the first group from tensor A. Parameters ---------- diff --git a/src/careamics/utils/receptive_field.py b/src/careamics/utils/receptive_field.py index 05de04bd7..913b75650 100644 --- a/src/careamics/utils/receptive_field.py +++ b/src/careamics/utils/receptive_field.py @@ -1,5 +1,6 @@ """Receptive field calculation for computing the tile overlap.""" +# TODO better docstring and function names # Adapted from: https://github.com/frgfm/torch-scan import math @@ -21,14 +22,19 @@ def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]: """Estimate the spatial receptive field of the module. - Args: - module (torch.nn.Module): PyTorch module - inp (torch.Tensor): input to the module - out (torch.Tensor): output of the module - Returns: - receptive field - effective stride - effective padding + Parameters + ---------- + module : Module + Module to estimate the receptive field. + inp : Tensor + Input tensor. + out : Tensor + Output tensor. + + Returns + ------- + Tuple[float, float, float] + Receptive field, effective stride and padding. """ if isinstance( module, From 3b61129b56c06e5433fb173988aed29244db1bc0 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Mon, 27 May 2024 13:06:52 +0200 Subject: [PATCH 03/13] Ruff passing --- src/careamics/models/unet.py | 11 +- src/careamics/utils/receptive_field.py | 182 ++++++++++++------------- 2 files changed, 98 insertions(+), 95 deletions(-) diff --git a/src/careamics/models/unet.py b/src/careamics/models/unet.py index 2a9198d8d..11f29eb32 100644 --- a/src/careamics/models/unet.py +++ b/src/careamics/models/unet.py @@ -250,15 +250,18 @@ def forward(self, *features: torch.Tensor) -> torch.Tensor: @staticmethod def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor: - """ - Splits the tensors `A` and `B` into equally sized groups along the - channel axis (axis=1); then concatenates the groups in alternating - order along the channel axis, starting with the first group from tensor A. + """Interleave two tensors. + + Splits the tensors `A` and `B` into equally sized groups along the channel + axis (axis=1); then concatenates the groups in alternating order along the + channel axis, starting with the first group from tensor A. Parameters ---------- A: torch.Tensor + First tensor. B: torch.Tensor + Second tensor. groups: int The number of groups. 
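Note on the `_interleave` hunk above: the behaviour its docstring describes can be checked with a short standalone sketch. The helper name and the example shapes below are illustrative only and are not part of the patch:

import torch

def interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor:
    # Split A and B into `groups` equally sized chunks along the channel
    # axis (dim=1), then concatenate the chunks in alternating order,
    # starting with the first chunk of A, as the docstring describes.
    a_chunks = torch.chunk(A, groups, dim=1)
    b_chunks = torch.chunk(B, groups, dim=1)
    alternating = [chunk for pair in zip(a_chunks, b_chunks) for chunk in pair]
    return torch.cat(alternating, dim=1)

# With groups=2 and two 4-channel inputs, the output channel order is
# [a0, a1, b0, b1, a2, a3, b2, b3]:
A = torch.zeros(1, 4, 8, 8)
B = torch.ones(1, 4, 8, 8)
assert interleave(A, B, groups=2).shape == (1, 8, 8, 8)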
diff --git a/src/careamics/utils/receptive_field.py b/src/careamics/utils/receptive_field.py index 913b75650..fc4fca0d7 100644 --- a/src/careamics/utils/receptive_field.py +++ b/src/careamics/utils/receptive_field.py @@ -3,106 +3,106 @@ # TODO better docstring and function names # Adapted from: https://github.com/frgfm/torch-scan -import math -import warnings -from typing import Tuple, Union +# import math +# import warnings +# from typing import Tuple, Union -from torch import Tensor, nn -from torch.nn import Module -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd -from torch.nn.modules.pooling import ( - _AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd, - _AvgPoolNd, - _MaxPoolNd, -) +# from torch import Tensor, nn +# from torch.nn import Module +# from torch.nn.modules.batchnorm import _BatchNorm +# from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd +# from torch.nn.modules.pooling import ( +# _AdaptiveAvgPoolNd, +# _AdaptiveMaxPoolNd, +# _AvgPoolNd, +# _MaxPoolNd, +# ) -def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]: - """Estimate the spatial receptive field of the module. +# def module_rf(module: Module, inp: Tensor, out: Tensor) -> Tuple[float, float, float]: +# """Estimate the spatial receptive field of the module. - Parameters - ---------- - module : Module - Module to estimate the receptive field. - inp : Tensor - Input tensor. - out : Tensor - Output tensor. +# Parameters +# ---------- +# module : Module +# Module to estimate the receptive field. +# inp : Tensor +# Input tensor. +# out : Tensor +# Output tensor. - Returns - ------- - Tuple[float, float, float] - Receptive field, effective stride and padding. - """ - if isinstance( - module, - ( - nn.Identity, - nn.Flatten, - nn.ReLU, - nn.ELU, - nn.LeakyReLU, - nn.ReLU6, - nn.Tanh, - nn.Sigmoid, - _BatchNorm, - nn.Dropout, - nn.Linear, - ), - ): - return 1.0, 1.0, 0.0 - elif isinstance(module, _ConvTransposeNd): - return rf_convtransposend(module, inp, out) - elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)): - return rf_aggregnd(module, inp, out) - elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)): - return rf_adaptive_poolnd(module, inp, out) - else: - warnings.warn( - f"Module type not supported: {module.__class__.__name__}", stacklevel=1 - ) - return 1.0, 1.0, 0.0 +# Returns +# ------- +# Tuple[float, float, float] +# Receptive field, effective stride and padding. 
+# """ +# if isinstance( +# module, +# ( +# nn.Identity, +# nn.Flatten, +# nn.ReLU, +# nn.ELU, +# nn.LeakyReLU, +# nn.ReLU6, +# nn.Tanh, +# nn.Sigmoid, +# _BatchNorm, +# nn.Dropout, +# nn.Linear, +# ), +# ): +# return 1.0, 1.0, 0.0 +# elif isinstance(module, _ConvTransposeNd): +# return rf_convtransposend(module, inp, out) +# elif isinstance(module, (_ConvNd, _MaxPoolNd, _AvgPoolNd)): +# return rf_aggregnd(module, inp, out) +# elif isinstance(module, (_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd)): +# return rf_adaptive_poolnd(module, inp, out) +# else: +# warnings.warn( +# f"Module type not supported: {module.__class__.__name__}", stacklevel=1 +# ) +# return 1.0, 1.0, 0.0 -def rf_convtransposend( - module: _ConvTransposeNd, _: Tensor, __: Tensor -) -> Tuple[float, float, float]: - k = ( - module.kernel_size[0] - if isinstance(module.kernel_size, tuple) - else module.kernel_size - ) - s = module.stride[0] if isinstance(module.stride, tuple) else module.stride - return -k, 1.0 / s, 0.0 +# def rf_convtransposend( +# module: _ConvTransposeNd, _: Tensor, __: Tensor +# ) -> Tuple[float, float, float]: +# k = ( +# module.kernel_size[0] +# if isinstance(module.kernel_size, tuple) +# else module.kernel_size +# ) +# s = module.stride[0] if isinstance(module.stride, tuple) else module.stride +# return -k, 1.0 / s, 0.0 -def rf_aggregnd( - module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor -) -> Tuple[float, float, float]: - k = ( - module.kernel_size[0] - if isinstance(module.kernel_size, tuple) - else module.kernel_size - ) - if hasattr(module, "dilation"): - d = ( - module.dilation[0] - if isinstance(module.dilation, tuple) - else module.dilation - ) - k = d * (k - 1) + 1 - s = module.stride[0] if isinstance(module.stride, tuple) else module.stride - p = module.padding[0] if isinstance(module.padding, tuple) else module.padding - return k, s, p # type: ignore[return-value] +# def rf_aggregnd( +# module: Union[_ConvNd, _MaxPoolNd, _AvgPoolNd], _: Tensor, __: Tensor +# ) -> Tuple[float, float, float]: +# k = ( +# module.kernel_size[0] +# if isinstance(module.kernel_size, tuple) +# else module.kernel_size +# ) +# if hasattr(module, "dilation"): +# d = ( +# module.dilation[0] +# if isinstance(module.dilation, tuple) +# else module.dilation +# ) +# k = d * (k - 1) + 1 +# s = module.stride[0] if isinstance(module.stride, tuple) else module.stride +# p = module.padding[0] if isinstance(module.padding, tuple) else module.padding +# return k, s, p # type: ignore[return-value] -def rf_adaptive_poolnd( - _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor -) -> Tuple[int, int, float]: - stride = math.ceil(inp.shape[-1] / out.shape[-1]) - kernel_size = stride - padding = (inp.shape[-1] - kernel_size * stride) / 2 +# def rf_adaptive_poolnd( +# _: Union[_AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd], inp: Tensor, out: Tensor +# ) -> Tuple[int, int, float]: +# stride = math.ceil(inp.shape[-1] / out.shape[-1]) +# kernel_size = stride +# padding = (inp.shape[-1] - kernel_size * stride) / 2 - return kernel_size, stride, padding +# return kernel_size, stride, padding From 1b46cb6afc464acf96d85e337969a5423ffb5591 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Mon, 27 May 2024 18:05:11 +0200 Subject: [PATCH 04/13] Fix some mypy errors --- src/careamics/config/data_model.py | 6 +- src/careamics/config/noise_models.py | 324 ++--- .../dataset/patching/tiled_patching.py | 2 +- src/careamics/lightning_datamodule.py | 44 +- 
src/careamics/losses/noise_model_factory.py | 60 +- src/careamics/losses/noise_models.py | 1048 ++++++++--------- src/careamics/transforms/compose.py | 6 +- .../transforms/pixel_manipulation.py | 4 +- src/careamics/transforms/transform.py | 30 +- 9 files changed, 759 insertions(+), 765 deletions(-) diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index f1aacb7bd..862bc3c92 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -202,7 +202,7 @@ def validate_prediction_transforms( if SupportedTransform.N2V_MANIPULATE in transform_list: # multiple N2V_MANIPULATE - if transform_list.count(SupportedTransform.N2V_MANIPULATE) > 1: + if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1: raise ValueError( f"Multiple instances of " f"{SupportedTransform.N2V_MANIPULATE} transforms " @@ -211,7 +211,7 @@ def validate_prediction_transforms( # N2V_MANIPULATE not the last transform elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE: - index = transform_list.index(SupportedTransform.N2V_MANIPULATE) + index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value) transform = transforms.pop(index) transforms.append(transform) @@ -250,7 +250,7 @@ def add_std_and_mean_to_normalize(self: Self) -> Self: Self Data model with mean and std added to the Normalize transform. """ - if self.mean is not None or self.std is not None: + if self.mean is not None and self.std is not None: # search in the transforms for Normalize and update parameters for transform in self.transforms: if transform.name == SupportedTransform.NORMALIZE.value: diff --git a/src/careamics/config/noise_models.py b/src/careamics/config/noise_models.py index 6dd01fa49..2bdae9388 100644 --- a/src/careamics/config/noise_models.py +++ b/src/careamics/config/noise_models.py @@ -1,162 +1,162 @@ -from __future__ import annotations - -from enum import Enum -from typing import Dict, Union - -from pydantic import BaseModel, ConfigDict, Field, field_validator - - -class NoiseModelType(str, Enum): - """ - Available noise models. - - Currently supported noise models: - - - hist: Histogram noise model. - - gmm: Gaussian mixture model noise model.F - """ - - NONE = "none" - HIST = "hist" - GMM = "gmm" - - # TODO add validator decorator - @classmethod - def validate_noise_model_type( - cls, noise_model: Union[str, NoiseModel], parameters: dict - ) -> None: - """_summary_. - - Parameters - ---------- - noise_model : Union[str, NoiseModel] - _description_ - parameters : dict - _description_ - - Returns - ------- - BaseModel - _description_ - """ - if noise_model == NoiseModelType.HIST.value: - HistogramNoiseModel(**parameters) - return HistogramNoiseModel().model_dump() if not parameters else parameters - - elif noise_model == NoiseModelType.GMM.value: - GaussianMixtureNoiseModel(**parameters) - return ( - GaussianMixtureNoiseModel().model_dump() - if not parameters - else parameters - ) - - -class NoiseModel(BaseModel): - """_summary_. 
- - Parameters - ---------- - BaseModel : _type_ - _description_ - - Returns - ------- - _type_ - _description_ - - Raises - ------ - ValueError - _description_ - """ - - model_config = ConfigDict( - use_enum_values=True, - protected_namespaces=(), # allows to use model_* as a field name - validate_assignment=True, - ) - - model_type: NoiseModelType - parameters: Dict = Field(default_factory=dict, validate_default=True) - - @field_validator("parameters") - @classmethod - def validate_parameters(cls, data, values) -> Dict: - """_summary_. - - Parameters - ---------- - parameters : Dict - _description_ - - Returns - ------- - Dict - _description_ - """ - if values.data["model_type"] not in [NoiseModelType.GMM, NoiseModelType.HIST]: - raise ValueError( - f"Incorrect noise model {values.data['model_type']}." - f"Please refer to the documentation" # TODO add link to documentation - ) - - parameters = NoiseModelType.validate_noise_model_type( - values.data["model_type"], data - ) - return parameters - - -class HistogramNoiseModel(BaseModel): - """ - Histogram noise model. - - Attributes - ---------- - min_value : float - Minimum value in the input. - max_value : float - Maximum value in the input. - bins : int - Number of bins of the histogram. - """ - - min_value: float = Field(default=350.0, ge=0.0, le=65535.0) - max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) - bins: int = Field(default=256, ge=1) - - -class GaussianMixtureNoiseModel(BaseModel): - """ - Gaussian mixture model noise model. - - Attributes - ---------- - min_signal : float - Minimum signal intensity expected in the image. - max_signal : float - Maximum signal intensity expected in the image. - weight : array - A [3*n_gaussian, n_coeff] sized array containing the values of the weights - describing the noise model. - Each gaussian contributes three parameters (mean, standard deviation and weight), - hence the number of rows in `weight` are 3*n_gaussian. - If `weight = None`, the weight array is initialized using the `min_signal` and - `max_signal` parameters. - n_gaussian: int - Number of gaussians. - n_coeff: int - Number of coefficients to describe the functional relationship between gaussian - parameters and the signal. - 2 implies a linear relationship, 3 implies a quadratic relationship and so on. - device: device - GPU device. - min_sigma: int - """ - - num_components: int = Field(default=3, ge=1) - min_value: float = Field(default=350.0, ge=0.0, le=65535.0) - max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) - n_gaussian: int = Field(default=3, ge=1) - n_coeff: int = Field(default=2, ge=1) - min_sigma: int = Field(default=50, ge=1) +# from __future__ import annotations + +# from enum import Enum +# from typing import Dict, Union + +# from pydantic import BaseModel, ConfigDict, Field, field_validator + + +# class NoiseModelType(str, Enum): +# """ +# Available noise models. + +# Currently supported noise models: + +# - hist: Histogram noise model. +# - gmm: Gaussian mixture model noise model.F +# """ + +# NONE = "none" +# HIST = "hist" +# GMM = "gmm" + +# # TODO add validator decorator +# @classmethod +# def validate_noise_model_type( +# cls, noise_model: Union[str, NoiseModel], parameters: dict +# ) -> None: +# """_summary_. 
+ +# Parameters +# ---------- +# noise_model : Union[str, NoiseModel] +# _description_ +# parameters : dict +# _description_ + +# Returns +# ------- +# BaseModel +# _description_ +# """ +# if noise_model == NoiseModelType.HIST.value: +# HistogramNoiseModel(**parameters) +# return HistogramNoiseModel().model_dump() if not parameters else parameters + +# elif noise_model == NoiseModelType.GMM.value: +# GaussianMixtureNoiseModel(**parameters) +# return ( +# GaussianMixtureNoiseModel().model_dump() +# if not parameters +# else parameters +# ) + + +# class NoiseModel(BaseModel): +# """_summary_. + +# Parameters +# ---------- +# BaseModel : _type_ +# _description_ + +# Returns +# ------- +# _type_ +# _description_ + +# Raises +# ------ +# ValueError +# _description_ +# """ + +# model_config = ConfigDict( +# use_enum_values=True, +# protected_namespaces=(), # allows to use model_* as a field name +# validate_assignment=True, +# ) + +# model_type: NoiseModelType +# parameters: Dict = Field(default_factory=dict, validate_default=True) + +# @field_validator("parameters") +# @classmethod +# def validate_parameters(cls, data, values) -> Dict: +# """_summary_. + +# Parameters +# ---------- +# parameters : Dict +# _description_ + +# Returns +# ------- +# Dict +# _description_ +# """ +# if values.data["model_type"] not in [NoiseModelType.GMM, NoiseModelType.HIST]: +# raise ValueError( +# f"Incorrect noise model {values.data['model_type']}." +# f"Please refer to the documentation" # TODO add link to documentation +# ) + +# parameters = NoiseModelType.validate_noise_model_type( +# values.data["model_type"], data +# ) +# return parameters + + +# class HistogramNoiseModel(BaseModel): +# """ +# Histogram noise model. + +# Attributes +# ---------- +# min_value : float +# Minimum value in the input. +# max_value : float +# Maximum value in the input. +# bins : int +# Number of bins of the histogram. +# """ + +# min_value: float = Field(default=350.0, ge=0.0, le=65535.0) +# max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) +# bins: int = Field(default=256, ge=1) + + +# class GaussianMixtureNoiseModel(BaseModel): +# """ +# Gaussian mixture model noise model. + +# Attributes +# ---------- +# min_signal : float +# Minimum signal intensity expected in the image. +# max_signal : float +# Maximum signal intensity expected in the image. +# weight : array +# A [3*n_gaussian, n_coeff] sized array containing the values of the weights +# describing the noise model. +# Each gaussian contributes three parameters (mean, standard deviation and weight), +# hence the number of rows in `weight` are 3*n_gaussian. +# If `weight = None`, the weight array is initialized using the `min_signal` and +# `max_signal` parameters. +# n_gaussian: int +# Number of gaussians. +# n_coeff: int +# Number of coefficients to describe the functional relationship between gaussian +# parameters and the signal. +# 2 implies a linear relationship, 3 implies a quadratic relationship and so on. +# device: device +# GPU device. 
+# min_sigma: int +# """ + +# num_components: int = Field(default=3, ge=1) +# min_value: float = Field(default=350.0, ge=0.0, le=65535.0) +# max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) +# n_gaussian: int = Field(default=3, ge=1) +# n_coeff: int = Field(default=2, ge=1) +# min_sigma: int = Field(default=50, ge=1) diff --git a/src/careamics/dataset/patching/tiled_patching.py b/src/careamics/dataset/patching/tiled_patching.py index 04489fb86..ddd97c1b0 100644 --- a/src/careamics/dataset/patching/tiled_patching.py +++ b/src/careamics/dataset/patching/tiled_patching.py @@ -8,7 +8,7 @@ def _compute_crop_and_stitch_coords_1d( axis_size: int, tile_size: int, overlap: int -) -> Tuple[List[Tuple[int, ...]], ...]: +) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]], List[Tuple[int, int]]]: """ Compute the coordinates of each tile along an axis, given the overlap. diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index b9685c28f..8e5008c71 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -95,13 +95,13 @@ class CAREamicsTrainData(L.LightningDataModule): Batch size. use_in_memory : bool Whether to use in memory dataset if possible. - train_data : Union[Path, str, np.ndarray] + train_data : Union[Path, np.ndarray] Training data. - val_data : Optional[Union[Path, str, np.ndarray]] + val_data : Optional[Union[Path, np.ndarray]] Validation data. - train_data_target : Optional[Union[Path, str, np.ndarray]] + train_data_target : Optional[Union[Path, np.ndarray]] Training target data. - val_data_target : Optional[Union[Path, str, np.ndarray]] + val_data_target : Optional[Union[Path, np.ndarray]] Validation target data. val_percentage : float Percentage of the training data to use for validation, if no validation data is @@ -217,17 +217,33 @@ def __init__( ) # configuration - self.data_config = data_config - self.data_type = data_config.data_type - self.batch_size = data_config.batch_size - self.use_in_memory = use_in_memory + self.data_config: DataConfig = data_config + self.data_type: str = data_config.data_type + self.batch_size: int = data_config.batch_size + self.use_in_memory: bool = use_in_memory + + # data: make data Path or np.ndarray, use type annotations for mypy + self.train_data: Union[Path, np.ndarray] = ( + Path(train_data) if isinstance(train_data, str) else train_data + ) + + self.val_data: Union[Path, np.ndarray] = ( + Path(val_data) if isinstance(val_data, str) else val_data + ) - # data - self.train_data = train_data - self.val_data = val_data + self.train_data_target: Union[Path, np.ndarray] = ( + Path(train_data_target) + if isinstance(train_data_target, str) + else train_data_target + ) + + self.val_data_target: Union[Path, np.ndarray] = ( + Path(val_data_target) + if isinstance(val_data_target, str) + else val_data_target + ) - self.train_data_target = train_data_target - self.val_data_target = val_data_target + # validation split self.val_percentage = val_percentage self.val_minimum_split = val_minimum_split @@ -241,7 +257,7 @@ def __init__( elif data_config.data_type != SupportedData.ARRAY: self.read_source_func = get_read_func(data_config.data_type) - self.extension_filter = extension_filter + self.extension_filter: str = extension_filter # Pytorch dataloader parameters self.dataloader_params = ( diff --git a/src/careamics/losses/noise_model_factory.py b/src/careamics/losses/noise_model_factory.py index fdab1182c..56173e4b3 100644 --- 
a/src/careamics/losses/noise_model_factory.py +++ b/src/careamics/losses/noise_model_factory.py @@ -1,40 +1,40 @@ -from typing import Type, Union +# from typing import Type, Union -from ..config.noise_models import NoiseModel, NoiseModelType -from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel +# from ..config.noise_models import NoiseModel, NoiseModelType +# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel -def noise_model_factory( - noise_config: NoiseModel, -) -> Type[Union[HistogramNoiseModel, GaussianMixtureNoiseModel, None]]: - """Create loss model based on Configuration. +# def noise_model_factory( +# noise_config: NoiseModel, +# ) -> Type[Union[HistogramNoiseModel, GaussianMixtureNoiseModel, None]]: +# """Create loss model based on Configuration. - Parameters - ---------- - config : Configuration - Configuration. +# Parameters +# ---------- +# config : Configuration +# Configuration. - Returns - ------- - Noise model +# Returns +# ------- +# Noise model - Raises - ------ - NotImplementedError - If the noise model is unknown. - """ - noise_model_type = noise_config.model_type if noise_config else None +# Raises +# ------ +# NotImplementedError +# If the noise model is unknown. +# """ +# noise_model_type = noise_config.model_type if noise_config else None - if noise_model_type == NoiseModelType.HIST: - return HistogramNoiseModel +# if noise_model_type == NoiseModelType.HIST: +# return HistogramNoiseModel - elif noise_model_type == NoiseModelType.GMM: - return GaussianMixtureNoiseModel +# elif noise_model_type == NoiseModelType.GMM: +# return GaussianMixtureNoiseModel - elif noise_model_type is None: - return None +# elif noise_model_type is None: +# return None - else: - raise NotImplementedError( - f"Noise model {noise_model_type} is not yet supported." - ) +# else: +# raise NotImplementedError( +# f"Noise model {noise_model_type} is not yet supported." +# ) diff --git a/src/careamics/losses/noise_models.py b/src/careamics/losses/noise_models.py index 5f4fc8ef2..f43906a93 100644 --- a/src/careamics/losses/noise_models.py +++ b/src/careamics/losses/noise_models.py @@ -1,524 +1,524 @@ -from abc import ABC, abstractmethod - -import numpy as np -import torch - -from ..utils.logging import get_logger - -logger = get_logger(__name__) - - -# TODO here "Model" clashes a bit with the naming convention of the Pydantic Models -class NoiseModel(ABC): - """Base class for noise models.""" - - @abstractmethod - def instantiate(self): - """Instantiate the noise model. - - Method that should produce ready to use noise model. - """ - pass - - @abstractmethod - def likelihood(self, observations, signals): - """Function that returns the likelihood of observations given the signals.""" - pass - - -class HistogramNoiseModel(NoiseModel): - """Creates a NoiseModel object. - - Parameters - ---------- - histogram: numpy array - A histogram as create by the 'createHistogram(...)' method. - device: - The device your NoiseModel lives on, e.g. your GPU. - """ - - def __init__(self, **kwargs): - pass - - def instantiate(self, bins, min_value, max_value, observation, signal): - """Creates a nD histogram from 'observation' and 'signal'. - - Parameters - ---------- - bins: int - The number of bins in all dimensions. The total number of bins is - 'bins' ** number_of_dimensions. - min_value: float - the lower bound of the lowest bin. - max_value: float - the highest bound of the highest bin. - observation: np.array - A stack of noisy images. 
The number has to be divisible by the number of - images in signal. N subsequent images in observation belong to one image - in the signal. - signal: np.array - A stack of clean images. - - Returns - ------- - histogram: numpy array - A 3D array: - 'histogram[0,...]' holds the normalized nD counts. - Each row sums to 1, describing p(x_i|s_i). - 'histogram[1,...]' holds the lower boundaries of each bin in y. - 'histogram[2,...]' holds the upper boundaries of each bin in y. - The values for x can be obtained by transposing 'histogram[1,...]' - and 'histogram[2,...]'. - """ - img_factor = int(observation.shape[0] / signal.shape[0]) - histogram = np.zeros((3, bins, bins)) - value_range = [min_value, max_value] - - for i in range(observation.shape[0]): - observation_i = observation[i].copy().ravel() - - signal_i = (signal[i // img_factor].copy()).ravel() - - histogram_i = np.histogramdd( - (signal_i, observation_i), bins=bins, range=[value_range, value_range] - ) - # Adding a constant for numerical stability - histogram[0] = histogram[0] + histogram_i[0] + 1e-30 - - for i in range(bins): - # Exclude empty rows from normalization - if np.sum(histogram[0, i, :]) > 1e-20: - # Normalize each non-empty row - histogram[0, i, :] /= np.sum(histogram[0, i, :]) - - for i in range(bins): - # The lower boundaries of each bin in y are stored in dimension 1 - histogram[1, :, i] = histogram_i[1][:-1] - # The upper boundaries of each bin in y are stored in dimension 2 - histogram[2, :, i] = histogram_i[1][1:] - # The accordent numbers for x are just transposed. - - return histogram - - def likelihood(self, observed, signal): - """Calculate the likelihood using a histogram based noise model. - - For every pixel in a tensor, calculate (x_i|s_i). To ensure differentiability - in the direction of s_i, we linearly interpolate in this direction. - - Parameters - ---------- - observed: torch.Tensor - tensor holding your observed intesities x_i. - - signal: torch.Tensor - tensor holding hypotheses for the clean signal at every pixel s_i^k. - - Returns - ------- - Torch.tensor containing the observation likelihoods according to the - noise model. - """ - observed_float = self.get_index_observed_float(observed) - observed_long = observed_float.floor().long() - signal_float = self.get_index_signal_float(signal) - signal_long = signal_float.floor().long() - fact = signal_float - signal_long.float() - - # Finally we are looking ud the values and interpolate - return self.fullHist[signal_long, observed_long] * (1.0 - fact) + self.fullHist[ - torch.clamp((signal_long + 1).long(), 0, self.bins.long()), observed_long - ] * (fact) - - def get_index_observed_float(self, x: float): - """_summary_. - - Parameters - ---------- - x : _type_ - _description_ - - Returns - ------- - _type_ - _description_ - """ - return torch.clamp( - self.bins * (x - self.minv) / (self.maxv - self.minv), - min=0.0, - max=self.bins - 1 - 1e-3, - ) - - def get_index_signal_float(self, x): - """_summary_. - - Parameters - ---------- - x : _type_ - _description_ - - Returns - ------- - _type_ - _description_ - """ - return torch.clamp( - self.bins * (x - self.minv) / (self.maxv - self.minv), - min=0.0, - max=self.bins - 1 - 1e-3, - ) - - -# TODO refactor this into Pydantic model -class GaussianMixtureNoiseModel(NoiseModel): - """Describes a noise model parameterized as a mixture of gaussians. - - If you would like to initialize a new object from scratch, then set `params` = None - and specify the other parameters as keyword arguments. 
If you are instead loading - a model, use only `params`. - - Parameters - ---------- - **kwargs: keyworded, variable-length argument dictionary. - Arguments include: - min_signal : float - Minimum signal intensity expected in the image. - max_signal : float - Maximum signal intensity expected in the image. - weight : array - A [3*n_gaussian, n_coeff] sized array containing the values of the weights - describing the noise model. - Each gaussian contributes three parameters (mean, standard deviation and weight), - hence the number of rows in `weight` are 3*n_gaussian. - If `weight = None`, the weight array is initialized using the `min_signal` and - `max_signal` parameters. - n_gaussian: int - Number of gaussians. - n_coeff: int - Number of coefficients to describe the functional relationship between gaussian - parameters and the signal. - 2 implies a linear relationship, 3 implies a quadratic relationship and so on. - device: device - GPU device. - min_sigma: int - All values of sigma (`standard deviation`) below min_sigma are clamped to become - equal to min_sigma. - params: dictionary - Use `params` if one wishes to load a model with trained weights. - While initializing a new object of the class `GaussianMixtureNoiseModel` from - scratch, set this to `None`. - """ - - def __init__(self, **kwargs): - if kwargs.get("params") is None: - weight = kwargs.get("weight") - n_gaussian = kwargs.get("n_gaussian") - n_coeff = kwargs.get("n_coeff") - min_signal = kwargs.get("min_signal") - max_signal = kwargs.get("max_signal") - self.device = kwargs.get("device") - self.path = kwargs.get("path") - self.min_sigma = kwargs.get("min_sigma") - if weight is None: - weight = np.random.randn(n_gaussian * 3, n_coeff) - weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal) - weight = ( - torch.from_numpy(weight.astype(np.float32)).float().to(self.device) - ) - weight.requires_grad = True - self.n_gaussian = weight.shape[0] // 3 - self.n_coeff = weight.shape[1] - self.weight = weight - self.min_signal = torch.Tensor([min_signal]).to(self.device) - self.max_signal = torch.Tensor([max_signal]).to(self.device) - self.tol = torch.Tensor([1e-10]).to(self.device) - else: - params = kwargs.get("params") - self.device = kwargs.get("device") - - self.min_signal = torch.Tensor(params["min_signal"]).to(self.device) - self.max_signal = torch.Tensor(params["max_signal"]).to(self.device) - - self.weight = torch.Tensor(params["trained_weight"]).to(self.device) - self.min_sigma = np.ndarray.item(params["min_sigma"]) - self.n_gaussian = self.weight.shape[0] // 3 - self.n_coeff = self.weight.shape[1] - self.tol = torch.Tensor([1e-10]).to(self.device) - self.min_signal = torch.Tensor([self.min_signal]).to(self.device) - self.max_signal = torch.Tensor([self.max_signal]).to(self.device) - - def fast_shuffle(self, series, num): - """. - - Parameters - ---------- - series : _type_ - _description_ - num : _type_ - _description_ - - Returns - ------- - _type_ - _description_ - """ - length = series.shape[0] - for _i in range(num): - series = series[np.random.permutation(length), :] - return series - - def polynomial_regressor(self, weightParams, signals): - """Combines weight_parameters and signals to perform regression. 
- - Parameters - ---------- - weightParams : torch.cuda.FloatTensor - Corresponds to specific rows of the `self.weight' - - signals : torch.cuda.FloatTensor - Signals - - Returns - ------- - value : torch.cuda.FloatTensor - Corresponds to either of mean, standard deviation or weight, evaluated at - `signals` - """ - value = 0 - for i in range(weightParams.shape[0]): - value += weightParams[i] * ( - ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i - ) - return value - - def normal_density(self, x, m_=0.0, std_=None): - """Evaluates the normal probability density. - - Parameters - ---------- - x: torch.cuda.FloatTensor - Observations - m_: torch.cuda.FloatTensor - Mean - std_: torch.cuda.FloatTensor - Standard-deviation - - Returns - ------- - tmp: torch.cuda.FloatTensor - Normal probability density of `x` given `m_` and `std_` - - """ - tmp = -((x - m_) ** 2) - tmp = tmp / (2.0 * std_ * std_) - tmp = torch.exp(tmp) - tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_) - return tmp - - def likelihood(self, observations, signals): - """Evaluates the likelihood of observations. - - Given the signals and the corresponding gaussian parameters evaluates the - likelihood of observations. - - Parameters - ---------- - observations : torch.cuda.FloatTensor - Noisy observations - signals : torch.cuda.FloatTensor - Underlying signals - - Returns - ------- - value :p + self.tol - Likelihood of observations given the signals and the GMM noise model - - """ - gaussianParameters = self.getGaussianParameters(signals) - p = 0 - for gaussian in range(self.n_gaussian): - p += ( - self.normalDens( - observations, - gaussianParameters[gaussian], - gaussianParameters[self.n_gaussian + gaussian], - ) - * gaussianParameters[2 * self.n_gaussian + gaussian] - ) - return p + self.tol - - def get_gaussian_parameters(self, signals): - """Returns the noise model for given signals. - - Parameters - ---------- - signals : torch.cuda.FloatTensor - Underlying signals - - Returns - ------- - noiseModel: list of torch.cuda.FloatTensor - Contains a list of `mu`, `sigma` and `alpha` for the `signals` - - """ - noiseModel = [] - mu = [] - sigma = [] - alpha = [] - kernels = self.weight.shape[0] // 3 - for num in range(kernels): - mu.append(self.polynomialRegressor(self.weight[num, :], signals)) - - sigmaTemp = self.polynomialRegressor( - torch.exp(self.weight[kernels + num, :]), signals - ) - sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma) - sigma.append(torch.sqrt(sigmaTemp)) - alpha.append( - torch.exp( - self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) - + self.tol - ) - ) - - sum_alpha = 0 - for al in range(kernels): - sum_alpha = alpha[al] + sum_alpha - for ker in range(kernels): - alpha[ker] = alpha[ker] / sum_alpha - - sum_means = 0 - for ker in range(kernels): - sum_means = alpha[ker] * mu[ker] + sum_means - - for ker in range(kernels): - mu[ker] = mu[ker] - sum_means + signals - - for i in range(kernels): - noiseModel.append(mu[i]) - for j in range(kernels): - noiseModel.append(sigma[j]) - for k in range(kernels): - noiseModel.append(alpha[k]) - - return noiseModel - - def get_signal_observation_pairs(self, signal, observation, lowerClip, upperClip): - """Returns the Signal-Observation pixel intensities as a two-column array. - - Parameters - ---------- - signal : numpy array - Clean Signal Data - observation: numpy array - Noisy observation Data - lowerClip: float - Lower percentile bound for clipping. - upperClip: float - Upper percentile bound for clipping. 
- - Returns - ------- - noiseModel: list of torch floats - Contains a list of `mu`, `sigma` and `alpha` for the `signals` - - """ - lb = np.percentile(signal, lowerClip) - ub = np.percentile(signal, upperClip) - stepsize = observation[0].size - n_observations = observation.shape[0] - n_signals = signal.shape[0] - sig_obs_pairs = np.zeros((n_observations * stepsize, 2)) - - for i in range(n_observations): - j = i // (n_observations // n_signals) - sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel() - sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel() - sig_obs_pairs = sig_obs_pairs[ - (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub) - ] - return self.fast_shuffle(sig_obs_pairs, 2) - - def train( - self, - signal, - observation, - learning_rate=1e-1, - batchSize=250000, - n_epochs=2000, - name="GMMNoiseModel.npz", - lowerClip=0, - upperClip=100, - ): - """Training to learn the noise model from signal - observation pairs. - - Parameters - ---------- - signal: numpy array - Clean Signal Data - observation: numpy array - Noisy Observation Data - learning_rate: float - Learning rate. Default = 1e-1. - batchSize: int - Nini-batch size. Default = 250000. - n_epochs: int - Number of epochs. Default = 2000. - name: string - Model name. Default is `GMMNoiseModel`. This model after being trained is - saved at the location `path`. - - lowerClip : int - Lower percentile for clipping. Default is 0. - upperClip : int - Upper percentile for clipping. Default is 100. - - - """ - sig_obs_pairs = self.getSignalObservationPairs( - signal, observation, lowerClip, upperClip - ) - counter = 0 - optimizer = torch.optim.Adam([self.weight], lr=learning_rate) - for t in range(n_epochs): - jointLoss = 0 - if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]: - counter = 0 - sig_obs_pairs = self.fast_shuffle(sig_obs_pairs, 1) - - batch_vectors = sig_obs_pairs[ - counter * batchSize : (counter + 1) * batchSize, : - ] - observations = batch_vectors[:, 1].astype(np.float32) - signals = batch_vectors[:, 0].astype(np.float32) - observations = ( - torch.from_numpy(observations.astype(np.float32)) - .float() - .to(self.device) - ) - signals = torch.from_numpy(signals).float().to(self.device) - p = self.likelihood(observations, signals) - loss = torch.mean(-torch.log(p)) - jointLoss = jointLoss + loss - - if t % 100 == 0: - print(t, jointLoss.item()) - - if t % (int(n_epochs * 0.5)) == 0: - trained_weight = self.weight.cpu().detach().numpy() - min_signal = self.min_signal.cpu().detach().numpy() - max_signal = self.max_signal.cpu().detach().numpy() - np.savez( - self.path + name, - trained_weight=trained_weight, - min_signal=min_signal, - max_signal=max_signal, - min_sigma=self.min_sigma, - ) - - optimizer.zero_grad() - jointLoss.backward() - optimizer.step() - counter += 1 - - logger.info(f"The trained parameters {name} is saved at location: " + self.path) +# from abc import ABC, abstractmethod + +# import numpy as np +# import torch + +# from ..utils.logging import get_logger + +# logger = get_logger(__name__) + + +# # TODO here "Model" clashes a bit with the naming convention of the Pydantic Models +# class NoiseModel(ABC): +# """Base class for noise models.""" + +# @abstractmethod +# def instantiate(self): +# """Instantiate the noise model. + +# Method that should produce ready to use noise model. 
+# """ +# pass + +# @abstractmethod +# def likelihood(self, observations, signals): +# """Function that returns the likelihood of observations given the signals.""" +# pass + + +# class HistogramNoiseModel(NoiseModel): +# """Creates a NoiseModel object. + +# Parameters +# ---------- +# histogram: numpy array +# A histogram as create by the 'createHistogram(...)' method. +# device: +# The device your NoiseModel lives on, e.g. your GPU. +# """ + +# def __init__(self, **kwargs): +# pass + +# def instantiate(self, bins, min_value, max_value, observation, signal): +# """Creates a nD histogram from 'observation' and 'signal'. + +# Parameters +# ---------- +# bins: int +# The number of bins in all dimensions. The total number of bins is +# 'bins' ** number_of_dimensions. +# min_value: float +# the lower bound of the lowest bin. +# max_value: float +# the highest bound of the highest bin. +# observation: np.array +# A stack of noisy images. The number has to be divisible by the number of +# images in signal. N subsequent images in observation belong to one image +# in the signal. +# signal: np.array +# A stack of clean images. + +# Returns +# ------- +# histogram: numpy array +# A 3D array: +# 'histogram[0,...]' holds the normalized nD counts. +# Each row sums to 1, describing p(x_i|s_i). +# 'histogram[1,...]' holds the lower boundaries of each bin in y. +# 'histogram[2,...]' holds the upper boundaries of each bin in y. +# The values for x can be obtained by transposing 'histogram[1,...]' +# and 'histogram[2,...]'. +# """ +# img_factor = int(observation.shape[0] / signal.shape[0]) +# histogram = np.zeros((3, bins, bins)) +# value_range = [min_value, max_value] + +# for i in range(observation.shape[0]): +# observation_i = observation[i].copy().ravel() + +# signal_i = (signal[i // img_factor].copy()).ravel() + +# histogram_i = np.histogramdd( +# (signal_i, observation_i), bins=bins, range=[value_range, value_range] +# ) +# # Adding a constant for numerical stability +# histogram[0] = histogram[0] + histogram_i[0] + 1e-30 + +# for i in range(bins): +# # Exclude empty rows from normalization +# if np.sum(histogram[0, i, :]) > 1e-20: +# # Normalize each non-empty row +# histogram[0, i, :] /= np.sum(histogram[0, i, :]) + +# for i in range(bins): +# # The lower boundaries of each bin in y are stored in dimension 1 +# histogram[1, :, i] = histogram_i[1][:-1] +# # The upper boundaries of each bin in y are stored in dimension 2 +# histogram[2, :, i] = histogram_i[1][1:] +# # The accordent numbers for x are just transposed. + +# return histogram + +# def likelihood(self, observed, signal): +# """Calculate the likelihood using a histogram based noise model. + +# For every pixel in a tensor, calculate (x_i|s_i). To ensure differentiability +# in the direction of s_i, we linearly interpolate in this direction. + +# Parameters +# ---------- +# observed: torch.Tensor +# tensor holding your observed intesities x_i. + +# signal: torch.Tensor +# tensor holding hypotheses for the clean signal at every pixel s_i^k. + +# Returns +# ------- +# Torch.tensor containing the observation likelihoods according to the +# noise model. 
+# """ +# observed_float = self.get_index_observed_float(observed) +# observed_long = observed_float.floor().long() +# signal_float = self.get_index_signal_float(signal) +# signal_long = signal_float.floor().long() +# fact = signal_float - signal_long.float() + +# # Finally we are looking ud the values and interpolate +# return self.fullHist[signal_long, observed_long] * (1.0 - fact) + self.fullHist[ +# torch.clamp((signal_long + 1).long(), 0, self.bins.long()), observed_long +# ] * (fact) + +# def get_index_observed_float(self, x: float): +# """_summary_. + +# Parameters +# ---------- +# x : _type_ +# _description_ + +# Returns +# ------- +# _type_ +# _description_ +# """ +# return torch.clamp( +# self.bins * (x - self.minv) / (self.maxv - self.minv), +# min=0.0, +# max=self.bins - 1 - 1e-3, +# ) + +# def get_index_signal_float(self, x): +# """_summary_. + +# Parameters +# ---------- +# x : _type_ +# _description_ + +# Returns +# ------- +# _type_ +# _description_ +# """ +# return torch.clamp( +# self.bins * (x - self.minv) / (self.maxv - self.minv), +# min=0.0, +# max=self.bins - 1 - 1e-3, +# ) + + +# # TODO refactor this into Pydantic model +# class GaussianMixtureNoiseModel(NoiseModel): +# """Describes a noise model parameterized as a mixture of gaussians. + +# If you would like to initialize a new object from scratch, then set `params` = None +# and specify the other parameters as keyword arguments. If you are instead loading +# a model, use only `params`. + +# Parameters +# ---------- +# **kwargs: keyworded, variable-length argument dictionary. +# Arguments include: +# min_signal : float +# Minimum signal intensity expected in the image. +# max_signal : float +# Maximum signal intensity expected in the image. +# weight : array +# A [3*n_gaussian, n_coeff] sized array containing the values of the weights +# describing the noise model. +# Each gaussian contributes three parameters (mean, standard deviation and weight), +# hence the number of rows in `weight` are 3*n_gaussian. +# If `weight = None`, the weight array is initialized using the `min_signal` and +# `max_signal` parameters. +# n_gaussian: int +# Number of gaussians. +# n_coeff: int +# Number of coefficients to describe the functional relationship between gaussian +# parameters and the signal. +# 2 implies a linear relationship, 3 implies a quadratic relationship and so on. +# device: device +# GPU device. +# min_sigma: int +# All values of sigma (`standard deviation`) below min_sigma are clamped to become +# equal to min_sigma. +# params: dictionary +# Use `params` if one wishes to load a model with trained weights. +# While initializing a new object of the class `GaussianMixtureNoiseModel` from +# scratch, set this to `None`. 
+# """ + +# def __init__(self, **kwargs): +# if kwargs.get("params") is None: +# weight = kwargs.get("weight") +# n_gaussian = kwargs.get("n_gaussian") +# n_coeff = kwargs.get("n_coeff") +# min_signal = kwargs.get("min_signal") +# max_signal = kwargs.get("max_signal") +# self.device = kwargs.get("device") +# self.path = kwargs.get("path") +# self.min_sigma = kwargs.get("min_sigma") +# if weight is None: +# weight = np.random.randn(n_gaussian * 3, n_coeff) +# weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal) +# weight = ( +# torch.from_numpy(weight.astype(np.float32)).float().to(self.device) +# ) +# weight.requires_grad = True +# self.n_gaussian = weight.shape[0] // 3 +# self.n_coeff = weight.shape[1] +# self.weight = weight +# self.min_signal = torch.Tensor([min_signal]).to(self.device) +# self.max_signal = torch.Tensor([max_signal]).to(self.device) +# self.tol = torch.Tensor([1e-10]).to(self.device) +# else: +# params = kwargs.get("params") +# self.device = kwargs.get("device") + +# self.min_signal = torch.Tensor(params["min_signal"]).to(self.device) +# self.max_signal = torch.Tensor(params["max_signal"]).to(self.device) + +# self.weight = torch.Tensor(params["trained_weight"]).to(self.device) +# self.min_sigma = np.ndarray.item(params["min_sigma"]) +# self.n_gaussian = self.weight.shape[0] // 3 +# self.n_coeff = self.weight.shape[1] +# self.tol = torch.Tensor([1e-10]).to(self.device) +# self.min_signal = torch.Tensor([self.min_signal]).to(self.device) +# self.max_signal = torch.Tensor([self.max_signal]).to(self.device) + +# def fast_shuffle(self, series, num): +# """. + +# Parameters +# ---------- +# series : _type_ +# _description_ +# num : _type_ +# _description_ + +# Returns +# ------- +# _type_ +# _description_ +# """ +# length = series.shape[0] +# for _i in range(num): +# series = series[np.random.permutation(length), :] +# return series + +# def polynomial_regressor(self, weightParams, signals): +# """Combines weight_parameters and signals to perform regression. + +# Parameters +# ---------- +# weightParams : torch.cuda.FloatTensor +# Corresponds to specific rows of the `self.weight' + +# signals : torch.cuda.FloatTensor +# Signals + +# Returns +# ------- +# value : torch.cuda.FloatTensor +# Corresponds to either of mean, standard deviation or weight, evaluated at +# `signals` +# """ +# value = 0 +# for i in range(weightParams.shape[0]): +# value += weightParams[i] * ( +# ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i +# ) +# return value + +# def normal_density(self, x, m_=0.0, std_=None): +# """Evaluates the normal probability density. + +# Parameters +# ---------- +# x: torch.cuda.FloatTensor +# Observations +# m_: torch.cuda.FloatTensor +# Mean +# std_: torch.cuda.FloatTensor +# Standard-deviation + +# Returns +# ------- +# tmp: torch.cuda.FloatTensor +# Normal probability density of `x` given `m_` and `std_` + +# """ +# tmp = -((x - m_) ** 2) +# tmp = tmp / (2.0 * std_ * std_) +# tmp = torch.exp(tmp) +# tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_) +# return tmp + +# def likelihood(self, observations, signals): +# """Evaluates the likelihood of observations. + +# Given the signals and the corresponding gaussian parameters evaluates the +# likelihood of observations. 
+ +# Parameters +# ---------- +# observations : torch.cuda.FloatTensor +# Noisy observations +# signals : torch.cuda.FloatTensor +# Underlying signals + +# Returns +# ------- +# value :p + self.tol +# Likelihood of observations given the signals and the GMM noise model + +# """ +# gaussianParameters = self.getGaussianParameters(signals) +# p = 0 +# for gaussian in range(self.n_gaussian): +# p += ( +# self.normalDens( +# observations, +# gaussianParameters[gaussian], +# gaussianParameters[self.n_gaussian + gaussian], +# ) +# * gaussianParameters[2 * self.n_gaussian + gaussian] +# ) +# return p + self.tol + +# def get_gaussian_parameters(self, signals): +# """Returns the noise model for given signals. + +# Parameters +# ---------- +# signals : torch.cuda.FloatTensor +# Underlying signals + +# Returns +# ------- +# noiseModel: list of torch.cuda.FloatTensor +# Contains a list of `mu`, `sigma` and `alpha` for the `signals` + +# """ +# noiseModel = [] +# mu = [] +# sigma = [] +# alpha = [] +# kernels = self.weight.shape[0] // 3 +# for num in range(kernels): +# mu.append(self.polynomialRegressor(self.weight[num, :], signals)) + +# sigmaTemp = self.polynomialRegressor( +# torch.exp(self.weight[kernels + num, :]), signals +# ) +# sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma) +# sigma.append(torch.sqrt(sigmaTemp)) +# alpha.append( +# torch.exp( +# self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) +# + self.tol +# ) +# ) + +# sum_alpha = 0 +# for al in range(kernels): +# sum_alpha = alpha[al] + sum_alpha +# for ker in range(kernels): +# alpha[ker] = alpha[ker] / sum_alpha + +# sum_means = 0 +# for ker in range(kernels): +# sum_means = alpha[ker] * mu[ker] + sum_means + +# for ker in range(kernels): +# mu[ker] = mu[ker] - sum_means + signals + +# for i in range(kernels): +# noiseModel.append(mu[i]) +# for j in range(kernels): +# noiseModel.append(sigma[j]) +# for k in range(kernels): +# noiseModel.append(alpha[k]) + +# return noiseModel + +# def get_signal_observation_pairs(self, signal, observation, lowerClip, upperClip): +# """Returns the Signal-Observation pixel intensities as a two-column array. + +# Parameters +# ---------- +# signal : numpy array +# Clean Signal Data +# observation: numpy array +# Noisy observation Data +# lowerClip: float +# Lower percentile bound for clipping. +# upperClip: float +# Upper percentile bound for clipping. + +# Returns +# ------- +# noiseModel: list of torch floats +# Contains a list of `mu`, `sigma` and `alpha` for the `signals` + +# """ +# lb = np.percentile(signal, lowerClip) +# ub = np.percentile(signal, upperClip) +# stepsize = observation[0].size +# n_observations = observation.shape[0] +# n_signals = signal.shape[0] +# sig_obs_pairs = np.zeros((n_observations * stepsize, 2)) + +# for i in range(n_observations): +# j = i // (n_observations // n_signals) +# sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel() +# sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel() +# sig_obs_pairs = sig_obs_pairs[ +# (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub) +# ] +# return self.fast_shuffle(sig_obs_pairs, 2) + +# def train( +# self, +# signal, +# observation, +# learning_rate=1e-1, +# batchSize=250000, +# n_epochs=2000, +# name="GMMNoiseModel.npz", +# lowerClip=0, +# upperClip=100, +# ): +# """Training to learn the noise model from signal - observation pairs. 
+ +# Parameters +# ---------- +# signal: numpy array +# Clean Signal Data +# observation: numpy array +# Noisy Observation Data +# learning_rate: float +# Learning rate. Default = 1e-1. +# batchSize: int +# Nini-batch size. Default = 250000. +# n_epochs: int +# Number of epochs. Default = 2000. +# name: string +# Model name. Default is `GMMNoiseModel`. This model after being trained is +# saved at the location `path`. + +# lowerClip : int +# Lower percentile for clipping. Default is 0. +# upperClip : int +# Upper percentile for clipping. Default is 100. + + +# """ +# sig_obs_pairs = self.getSignalObservationPairs( +# signal, observation, lowerClip, upperClip +# ) +# counter = 0 +# optimizer = torch.optim.Adam([self.weight], lr=learning_rate) +# for t in range(n_epochs): +# jointLoss = 0 +# if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]: +# counter = 0 +# sig_obs_pairs = self.fast_shuffle(sig_obs_pairs, 1) + +# batch_vectors = sig_obs_pairs[ +# counter * batchSize : (counter + 1) * batchSize, : +# ] +# observations = batch_vectors[:, 1].astype(np.float32) +# signals = batch_vectors[:, 0].astype(np.float32) +# observations = ( +# torch.from_numpy(observations.astype(np.float32)) +# .float() +# .to(self.device) +# ) +# signals = torch.from_numpy(signals).float().to(self.device) +# p = self.likelihood(observations, signals) +# loss = torch.mean(-torch.log(p)) +# jointLoss = jointLoss + loss + +# if t % 100 == 0: +# print(t, jointLoss.item()) + +# if t % (int(n_epochs * 0.5)) == 0: +# trained_weight = self.weight.cpu().detach().numpy() +# min_signal = self.min_signal.cpu().detach().numpy() +# max_signal = self.max_signal.cpu().detach().numpy() +# np.savez( +# self.path + name, +# trained_weight=trained_weight, +# min_signal=min_signal, +# max_signal=max_signal, +# min_sigma=self.min_sigma, +# ) + +# optimizer.zero_grad() +# jointLoss.backward() +# optimizer.step() +# counter += 1 + +# logger.info(f"The trained parameters {name} is saved at location: " + self.path) diff --git a/src/careamics/transforms/compose.py b/src/careamics/transforms/compose.py index 2993ea340..30a34ba2e 100644 --- a/src/careamics/transforms/compose.py +++ b/src/careamics/transforms/compose.py @@ -1,6 +1,6 @@ """A class chaining transforms together.""" -from typing import Callable, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple import numpy as np @@ -20,7 +20,7 @@ } -def get_all_transforms() -> dict: +def get_all_transforms() -> Dict[str, type]: """Return all the transforms accepted by CAREamics. Returns @@ -68,7 +68,7 @@ def _chain_transforms(self, transforms: List[Transform]) -> Callable: def _chain( patch: np.ndarray, target: Optional[np.ndarray] - ) -> Tuple[np.ndarray, ...]: + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: params = (patch, target) for t in transforms: diff --git a/src/careamics/transforms/pixel_manipulation.py b/src/careamics/transforms/pixel_manipulation.py index 4cbfd9253..280a974f4 100644 --- a/src/careamics/transforms/pixel_manipulation.py +++ b/src/careamics/transforms/pixel_manipulation.py @@ -5,7 +5,7 @@ masked pixels. """ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import numpy as np @@ -98,7 +98,7 @@ def _odd_jitter_func(step: float, rng: np.random.Generator) -> np.ndarray: def _get_stratified_coords( - mask_pixel_perc: float, shape: Union[Tuple[int, int], Tuple[int, int, int]] + mask_pixel_perc: float, shape: Tuple[int, ...] ) -> np.ndarray: """ Generate coordinates of the pixels to mask. 
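Aside on the `compose.py` hunk earlier in this patch: `_chain_transforms` returns a closure that threads the `(patch, target)` pair through every transform in order, which is exactly what the sharpened `Tuple[np.ndarray, Optional[np.ndarray]]` annotation documents. A minimal sketch of that pattern follows; the loop body is not visible in the hunk, so `params = t(*params)` is an assumption, as are the example transform and helper names:

from typing import Callable, List, Optional, Tuple

import numpy as np

def chain_transforms(transforms: List[Callable]) -> Callable:
    # Build one callable that applies each transform to the running
    # (patch, target) pair, mirroring Compose._chain_transforms.
    def _chain(
        patch: np.ndarray, target: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        params = (patch, target)
        for t in transforms:
            params = t(*params)  # assumed: each transform maps a pair to a pair
        return params

    return _chain

def flip(
    patch: np.ndarray, target: Optional[np.ndarray]
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    return np.flip(patch, axis=-1), target

chained = chain_transforms([flip, flip])
patch, target = chained(np.arange(4.0).reshape(1, 2, 2), None)
assert np.array_equal(patch, np.arange(4.0).reshape(1, 2, 2))  # two flips cancel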
diff --git a/src/careamics/transforms/transform.py b/src/careamics/transforms/transform.py index 7d1c180ba..9798355df 100644 --- a/src/careamics/transforms/transform.py +++ b/src/careamics/transforms/transform.py @@ -1,33 +1,11 @@ """A general parent class for transforms.""" -from typing import Optional, Tuple - -import numpy as np +from typing import Any class Transform: """A general parent class for transforms.""" - def __call__( - self, patch: np.ndarray, target: Optional[np.ndarray] = None - ) -> Tuple[np.ndarray, ...]: - """Apply the transform to the input data. - - Parameters - ---------- - patch : np.ndarray - The input data to transform. - target : Optional[np.ndarray], optional - The target data to transform, by default None - - Returns - ------- - Tuple[np.ndarray, ...] - The output of the transformations. - - Raises - ------ - NotImplementedError - This method should be implemented in the child class. - """ - raise NotImplementedError + def __call__(self, *args: Any, **kwwargs: Any) -> Any: + """Apply the transform.""" + pass From da94cb416b42dd1b177e2fd41ae99b36e8ee35e2 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Tue, 28 May 2024 13:20:22 +0200 Subject: [PATCH 05/13] Fix one error --- src/careamics/dataset/in_memory_dataset.py | 48 +++++++++---------- src/careamics/dataset/iterable_dataset.py | 8 +++- src/careamics/lightning_datamodule.py | 8 ++-- .../lightning_prediction_datamodule.py | 5 +- tests/dataset/test_in_memory_dataset.py | 6 +-- 5 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index fb862f347..7c244a157 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -33,7 +33,7 @@ def __init__( self, data_config: DataConfig, inputs: Union[np.ndarray, List[Path]], - data_target: Optional[Union[np.ndarray, List[Path]]] = None, + input_target: Optional[Union[np.ndarray, List[Path]]] = None, read_source_func: Callable = read_tiff, **kwargs: Any, ) -> None: @@ -44,7 +44,7 @@ def __init__( """ self.data_config = data_config self.inputs = inputs - self.data_target = data_target + self.input_targets = input_target self.axes = self.data_config.axes self.patch_size = self.data_config.patch_size @@ -52,11 +52,11 @@ def __init__( self.read_source_func = read_source_func # Generate patches - supervised = self.data_target is not None - patches = self._prepare_patches(supervised) + supervised = self.input_targets is not None + patch_data = self._prepare_patches(supervised) # Add results to members - self.data, self.data_targets, computed_mean, computed_std = patches + self.patches, self.patch_targets, computed_mean, computed_std = patch_data if not self.data_config.mean or not self.data_config.std: self.mean, self.std = computed_mean, computed_std @@ -91,18 +91,18 @@ def _prepare_patches( """ if supervised: if isinstance(self.inputs, np.ndarray) and isinstance( - self.data_target, np.ndarray + self.input_targets, np.ndarray ): return prepare_patches_supervised_array( self.inputs, self.axes, - self.data_target, + self.input_targets, self.patch_size, ) - elif isinstance(self.inputs, list) and isinstance(self.data_target, list): + elif isinstance(self.inputs, list) and isinstance(self.input_targets, list): return prepare_patches_supervised( self.inputs, - self.data_target, + self.input_targets, self.axes, self.patch_size, self.read_source_func, @@ -111,7 +111,7 @@ def _prepare_patches( raise 
ValueError( f"Data and target must be of the same type, either both numpy " f"arrays or both lists of paths, got {type(self.inputs)} (data) " - f"and {type(self.data_target)} (target)." + f"and {type(self.input_targets)} (target)." ) else: if isinstance(self.inputs, np.ndarray): @@ -137,9 +137,9 @@ def __len__(self) -> int: int Length of the dataset. """ - return len(self.data) + return len(self.patches) - def __getitem__(self, index: int) -> Tuple[np.ndarray]: + def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]: """ Return the patch corresponding to the provided index. @@ -158,12 +158,12 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray]: ValueError If dataset mean and std are not set. """ - patch = self.data[index] + patch = self.patches[index] # if there is a target - if self.data_target is not None: + if self.patch_targets: # get target - target = self.data_targets[index] + target = self.patch_targets[index] return self.patch_transform(patch=patch, target=target) @@ -223,25 +223,25 @@ def split_dataset( indices = np.random.choice(total_patches, n_patches, replace=False) # extract patches - val_patches = self.data[indices] + val_patches = self.patches[indices] # remove patches from self.patch - self.data = np.delete(self.data, indices, axis=0) + self.patches = np.delete(self.patches, indices, axis=0) # same for targets - if self.data_targets is not None: - val_targets = self.data_targets[indices] - self.data_targets = np.delete(self.data_targets, indices, axis=0) + if self.patch_targets is not None: + val_targets = self.patch_targets[indices] + self.patch_targets = np.delete(self.patch_targets, indices, axis=0) # clone the dataset dataset = copy.deepcopy(self) # reassign patches - dataset.data = val_patches + dataset.patches = val_patches # reassign targets - if self.data_targets is not None: - dataset.data_targets = val_targets + if self.patch_targets is not None: + dataset.patch_targets = val_targets return dataset @@ -310,7 +310,7 @@ def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: # reshape array reshaped_sample = reshape_array(self.input_array, self.axes) - if self.tiling: + if self.tiling and self.tile_size is not None and self.tile_overlap is not None: # generate patches, which returns a generator patch_generator = extract_tiles( arr=reshaped_sample, diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 465e53e1a..b5e48dc01 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -147,7 +147,7 @@ def _iterate_over_files( def __iter__( self, - ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]: + ) -> Generator[Tuple[np.ndarray, ...], None, None]: """ Iterate over data source and yield single patch. 
@@ -339,7 +339,11 @@ def __iter__( # reshape array reshaped_sample = reshape_array(sample, self.axes) - if self.tile: + if ( + self.tile + and self.tile_size is not None + and self.tile_overlap is not None + ): # generate patches, return a generator patch_gen = extract_tiles( arr=reshaped_sample, diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index 8e5008c71..695ded065 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -329,7 +329,7 @@ def setup(self, *args: Any, **kwargs: Any) -> None: self.train_dataset: DatasetType = InMemoryDataset( data_config=self.data_config, inputs=self.train_data, - data_target=self.train_data_target, + input_target=self.train_data_target, ) # validation dataset @@ -338,7 +338,7 @@ def setup(self, *args: Any, **kwargs: Any) -> None: self.val_dataset: DatasetType = InMemoryDataset( data_config=self.data_config, inputs=self.val_data, - data_target=self.val_data_target, + input_target=self.val_data_target, ) else: # extract validation from the training patches @@ -357,7 +357,7 @@ def setup(self, *args: Any, **kwargs: Any) -> None: self.train_dataset = InMemoryDataset( data_config=self.data_config, inputs=self.train_files, - data_target=( + input_target=( self.train_target_files if self.train_data_target else None ), read_source_func=self.read_source_func, @@ -368,7 +368,7 @@ def setup(self, *args: Any, **kwargs: Any) -> None: self.val_dataset = InMemoryDataset( data_config=self.data_config, inputs=self.val_files, - data_target=( + input_target=( self.val_target_files if self.val_data_target else None ), read_source_func=self.read_source_func, diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py index 5a7cf70b6..88f416a2b 100644 --- a/src/careamics/lightning_prediction_datamodule.py +++ b/src/careamics/lightning_prediction_datamodule.py @@ -1,7 +1,7 @@ """Prediction Lightning data modules.""" from pathlib import Path -from typing import Any, Callable, List, Literal, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np import pytorch_lightning as L @@ -363,7 +363,7 @@ def __init__( """ if dataloader_params is None: dataloader_params = {} - prediction_dict = { + prediction_dict: Dict[str, Any] = { "data_type": data_type, "tile_size": tile_size, "tile_overlap": tile_overlap, @@ -372,6 +372,7 @@ def __init__( "std": std, "tta": tta_transforms, "batch_size": batch_size, + "transforms": [], } # if transforms are passed (otherwise it will use the default ones) diff --git a/tests/dataset/test_in_memory_dataset.py b/tests/dataset/test_in_memory_dataset.py index 9d78d9713..bdaecedc2 100644 --- a/tests/dataset/test_in_memory_dataset.py +++ b/tests/dataset/test_in_memory_dataset.py @@ -27,7 +27,7 @@ def test_number_of_patches(ordered_array): ) # check number of patches - assert len(dataset) == dataset.data.shape[0] + assert len(dataset) == dataset.patches.shape[0] def test_compute_mean_std_transform(ordered_array): @@ -69,7 +69,7 @@ def test_extracting_val_array(ordered_array, percentage): assert len(dataset) == total_n_patches - n_patches # check that none of the validation patch values are in the original dataset - assert np.in1d(valset.data, dataset.data).sum() == 0 + assert np.in1d(valset.patches, dataset.patches).sum() == 0 @pytest.mark.parametrize("percentage", [0.1, 0.6]) @@ -109,4 +109,4 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): 
assert len(dataset) == total_n_patches - n_patches # check that none of the validation patch values are in the original dataset - assert np.in1d(valset.data, dataset.data).sum() == 0 + assert np.in1d(valset.patches, dataset.patches).sum() == 0 From 3039971e77d382038348094d5afdc790f4ad7df5 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Tue, 28 May 2024 13:24:56 +0200 Subject: [PATCH 06/13] Fix checking of optional numpy array --- src/careamics/dataset/in_memory_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 7c244a157..6c0b94696 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -161,7 +161,7 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, ...]: patch = self.patches[index] # if there is a target - if self.patch_targets: + if self.patch_targets is not None: # get target target = self.patch_targets[index] From 2967a803c5e9975b4185b9cae99e995b10e8d971 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Tue, 28 May 2024 13:32:25 +0200 Subject: [PATCH 07/13] More mypy fixes --- src/careamics/lightning_datamodule.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index 695ded065..92dc7c86f 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -260,7 +260,7 @@ def __init__( self.extension_filter: str = extension_filter # Pytorch dataloader parameters - self.dataloader_params = ( + self.dataloader_params: Dict[str, Any] = ( data_config.dataloader_params if data_config.dataloader_params else {} ) @@ -325,6 +325,11 @@ def setup(self, *args: Any, **kwargs: Any) -> None: """ # if numpy array if self.data_type == SupportedData.ARRAY: + # mypy checks + assert isinstance(self.train_data, np.ndarray) + if self.train_data_target is not None: + assert isinstance(self.train_data_target, np.ndarray) + # train dataset self.train_dataset: DatasetType = InMemoryDataset( data_config=self.data_config, @@ -334,6 +339,11 @@ def setup(self, *args: Any, **kwargs: Any) -> None: # validation dataset if self.val_data is not None: + # mypy checks + assert isinstance(self.val_data, np.ndarray) + if self.val_data_target is not None: + assert isinstance(self.val_data_target, np.ndarray) + # create its own dataset self.val_dataset: DatasetType = InMemoryDataset( data_config=self.data_config, From cde561221f6e704fd0e2677f6ec0460332422ca8 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Tue, 28 May 2024 14:00:45 +0200 Subject: [PATCH 08/13] Rename NDFlip, fix some doc errors --- src/careamics/config/configuration_example.py | 2 +- src/careamics/config/configuration_factory.py | 4 +- src/careamics/config/data_model.py | 8 +-- .../config/support/supported_transforms.py | 4 +- .../config/transformations/__init__.py | 4 +- .../{nd_flip_model.py => xy_flip_model.py} | 10 +-- src/careamics/transforms/__init__.py | 4 +- src/careamics/transforms/compose.py | 4 +- src/careamics/transforms/n2v_manipulate.py | 61 +++++++++++++---- src/careamics/transforms/nd_flip.py | 67 ------------------- src/careamics/transforms/normalize.py | 57 +++++++++++++++- .../transforms/pixel_manipulation.py | 6 +- .../transforms/struct_mask_parameters.py | 3 +- 
src/careamics/transforms/transform.py | 17 ++++- .../transforms/xy_random_rotate90.py | 22 +++++- src/careamics/utils/base_enum.py | 26 +++++++ src/careamics/utils/path_utils.py | 1 + src/careamics/utils/ram.py | 1 + tests/config/test_configuration_factory.py | 4 +- tests/config/test_configuration_model.py | 2 +- tests/config/test_data_model.py | 24 +++---- tests/config/test_inference_model.py | 2 +- tests/transforms/test_compose.py | 14 ++-- tests/transforms/test_nd_flip.py | 6 +- 24 files changed, 215 insertions(+), 138 deletions(-) rename src/careamics/config/transformations/{nd_flip_model.py => xy_flip_model.py} (64%) delete mode 100644 src/careamics/transforms/nd_flip.py diff --git a/src/careamics/config/configuration_example.py b/src/careamics/config/configuration_example.py index d66c1331c..f601409e4 100644 --- a/src/careamics/config/configuration_example.py +++ b/src/careamics/config/configuration_example.py @@ -56,7 +56,7 @@ def full_configuration_example() -> Configuration: "name": SupportedTransform.NORMALIZE.value, }, { - "name": SupportedTransform.NDFLIP.value, + "name": SupportedTransform.XY_FLIP.value, }, { "name": SupportedTransform.XY_RANDOM_ROTATE90.value, diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index e50dbed39..d9c2f4f7e 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -111,7 +111,7 @@ def _create_supervised_configuration( "name": SupportedTransform.NORMALIZE.value, }, { - "name": SupportedTransform.NDFLIP.value, + "name": SupportedTransform.XY_FLIP.value, }, { "name": SupportedTransform.XY_RANDOM_ROTATE90.value, @@ -526,7 +526,7 @@ def create_n2v_configuration( "name": SupportedTransform.NORMALIZE.value, }, { - "name": SupportedTransform.NDFLIP.value, + "name": SupportedTransform.XY_FLIP.value, }, { "name": SupportedTransform.XY_RANDOM_ROTATE90.value, diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index 862bc3c92..28f365ade 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -17,14 +17,14 @@ from .support import SupportedTransform from .transformations.n2v_manipulate_model import N2VManipulateModel -from .transformations.nd_flip_model import NDFlipModel +from .transformations.xy_flip_model import XYFlipModel from .transformations.normalize_model import NormalizeModel from .transformations.xy_random_rotate90_model import XYRandomRotate90Model from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 TRANSFORMS_UNION = Annotated[ Union[ - NDFlipModel, + XYFlipModel, XYRandomRotate90Model, NormalizeModel, N2VManipulateModel, @@ -70,7 +70,7 @@ class DataConfig(BaseModel): ... "std": 47.2, ... }, ... { - ... "name": "NDFlip", + ... "name": "XYFlip", ... } ... ] ... 
) @@ -97,7 +97,7 @@ class DataConfig(BaseModel): "name": SupportedTransform.NORMALIZE.value, }, { - "name": SupportedTransform.NDFLIP.value, + "name": SupportedTransform.XY_FLIP.value, }, { "name": SupportedTransform.XY_RANDOM_ROTATE90.value, diff --git a/src/careamics/config/support/supported_transforms.py b/src/careamics/config/support/supported_transforms.py index b262c169b..ed61a2fc9 100644 --- a/src/careamics/config/support/supported_transforms.py +++ b/src/careamics/config/support/supported_transforms.py @@ -9,14 +9,14 @@ class SupportedTransform(str, BaseEnum): - XYRandomRotate90: #TODO - Normalize # TODO add details, in particular about the parameters - ManipulateN2V # TODO add details, in particular about the parameters - - NDFlip + - XYFlip Note that while any Albumentations (see https://albumentations.ai/) transform can be used in CAREamics, no check are implemented to verify the compatibility of any other transforms than the ones officially supported. """ - NDFLIP = "NDFlip" + XY_FLIP = "XYFlip" XY_RANDOM_ROTATE90 = "XYRandomRotate90" NORMALIZE = "Normalize" N2V_MANIPULATE = "N2VManipulate" diff --git a/src/careamics/config/transformations/__init__.py b/src/careamics/config/transformations/__init__.py index d5aaa92e0..c314be107 100644 --- a/src/careamics/config/transformations/__init__.py +++ b/src/careamics/config/transformations/__init__.py @@ -2,13 +2,13 @@ __all__ = [ "N2VManipulateModel", - "NDFlipModel", + "XYFlipModel", "NormalizeModel", "XYRandomRotate90Model", ] from .n2v_manipulate_model import N2VManipulateModel -from .nd_flip_model import NDFlipModel +from .xy_flip_model import XYFlipModel from .normalize_model import NormalizeModel from .xy_random_rotate90_model import XYRandomRotate90Model diff --git a/src/careamics/config/transformations/nd_flip_model.py b/src/careamics/config/transformations/xy_flip_model.py similarity index 64% rename from src/careamics/config/transformations/nd_flip_model.py rename to src/careamics/config/transformations/xy_flip_model.py index 806f8b7e1..e8133dce6 100644 --- a/src/careamics/config/transformations/nd_flip_model.py +++ b/src/careamics/config/transformations/xy_flip_model.py @@ -1,4 +1,4 @@ -"""Pydantic model for the NDFlip transform.""" +"""Pydantic model for the XYFlip transform.""" from typing import Literal, Optional @@ -7,13 +7,13 @@ from .transform_model import TransformModel -class NDFlipModel(TransformModel): +class XYFlipModel(TransformModel): """ - Pydantic model used to represent NDFlip transformation. + Pydantic model used to represent XYFlip transformation. Attributes ---------- - name : Literal["NDFlip"] + name : Literal["XYFlip"] Name of the transformation. seed : Optional[int] Seed for the random number generator. 
@@ -23,5 +23,5 @@ class NDFlipModel(TransformModel):
         validate_assignment=True,
     )
 
-    name: Literal["NDFlip"] = "NDFlip"
+    name: Literal["XYFlip"] = "XYFlip"
     seed: Optional[int] = None
diff --git a/src/careamics/transforms/__init__.py b/src/careamics/transforms/__init__.py
index 2a59f7d53..22aeb5677 100644
--- a/src/careamics/transforms/__init__.py
+++ b/src/careamics/transforms/__init__.py
@@ -3,7 +3,7 @@
 __all__ = [
     "get_all_transforms",
     "N2VManipulate",
-    "NDFlip",
+    "XYFlip",
     "XYRandomRotate90",
     "ImageRestorationTTA",
     "Denormalize",
@@ -14,7 +14,7 @@
 
 from .compose import Compose, get_all_transforms
 from .n2v_manipulate import N2VManipulate
-from .nd_flip import NDFlip
+from .xy_flip import XYFlip
 from .normalize import Denormalize, Normalize
 from .tta import ImageRestorationTTA
 from .xy_random_rotate90 import XYRandomRotate90
diff --git a/src/careamics/transforms/compose.py b/src/careamics/transforms/compose.py
index 30a34ba2e..1a528c6f7 100644
--- a/src/careamics/transforms/compose.py
+++ b/src/careamics/transforms/compose.py
@@ -7,7 +7,7 @@
 from careamics.config.data_model import TRANSFORMS_UNION
 
 from .n2v_manipulate import N2VManipulate
-from .nd_flip import NDFlip
+from .xy_flip import XYFlip
 from .normalize import Normalize
 from .transform import Transform
 from .xy_random_rotate90 import XYRandomRotate90
@@ -15,7 +15,7 @@
 ALL_TRANSFORMS = {
     "Normalize": Normalize,
     "N2VManipulate": N2VManipulate,
-    "NDFlip": NDFlip,
+    "XYFlip": XYFlip,
     "XYRandomRotate90": XYRandomRotate90,
 }
 
diff --git a/src/careamics/transforms/n2v_manipulate.py b/src/careamics/transforms/n2v_manipulate.py
index 959e55d39..40ce0795c 100644
--- a/src/careamics/transforms/n2v_manipulate.py
+++ b/src/careamics/transforms/n2v_manipulate.py
@@ -1,3 +1,4 @@
+"""N2V manipulation transform."""
 from typing import Any, Literal, Optional, Tuple
 
 import numpy as np
@@ -17,10 +18,35 @@ class N2VManipulate(Transform):
 
     Parameters
     ----------
-    mask_pixel_percentage : float
-        Approximate percentage of pixels to be masked.
+    roi_size : int, optional
+        Size of the replacement area, by default 11.
+    masked_pixel_percentage : float, optional
+        Percentage of pixels to mask, by default 0.2.
+    strategy : Literal[ "uniform", "median" ], optional
+        Replacement strategy, uniform or median, by default uniform.
+    remove_center : bool, optional
+        Whether to remove central pixel from patch, by default True.
+    struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
+        StructN2V mask axis, by default "none".
+    struct_mask_span : int, optional
+        StructN2V mask span, by default 5.
+    seed : Optional[int], optional
+        Random seed, by default None.
+
+    Attributes
+    ----------
+    masked_pixel_percentage : float
+        Percentage of pixels to mask.
     roi_size : int
-        Size of the ROI the new pixel value is sampled from, by default 11.
+        Size of the replacement area.
+    strategy : Literal[ "uniform", "median" ]
+        Replacement strategy, uniform or median.
+    remove_center : bool
+        Whether to remove central pixel from patch.
+    struct_mask : Optional[StructMaskParameters]
+        StructN2V mask parameters.
+    rng : Generator
+        Random number generator.
     """
 
     def __init__(
@@ -40,24 +66,24 @@ def __init__(
 
         Parameters
         ----------
         roi_size : int, optional
-            Size of the replacement area, by default 11
+            Size of the replacement area, by default 11.
         masked_pixel_percentage : float, optional
-            Percentage of pixels to mask, by default 0.2
+            Percentage of pixels to mask, by default 0.2.
strategy : Literal[ "uniform", "median" ], optional
-            Replaccement strategy, uniform or median, by default uniform
+            Replacement strategy, uniform or median, by default uniform.
         remove_center : bool, optional
-            Whether to remove central pixel from patch, by default True
+            Whether to remove central pixel from patch, by default True.
         struct_mask_axis : Literal["horizontal", "vertical", "none"], optional
-            StructN2V mask axis, by default "none"
+            StructN2V mask axis, by default "none".
         struct_mask_span : int, optional
-            StructN2V mask span, by default 5
+            StructN2V mask span, by default 5.
         seed : Optional[int], optional
-            Random seed, by default None
+            Random seed, by default None.
         """
         self.masked_pixel_percentage = masked_pixel_percentage
         self.roi_size = roi_size
         self.strategy = strategy
-        self.remove_center = remove_center
+        self.remove_center = remove_center # TODO is this ever used?
 
         if struct_mask_axis == SupportedStructAxis.NONE:
             self.struct_mask: Optional[StructMaskParameters] = None
@@ -77,8 +103,17 @@ def __call__(
 
         Parameters
         ----------
-        image : np.ndarray
-            Image or image patch, 2D or 3D, shape C(Z)YX.
+        patch : np.ndarray
+            Image patch, 2D or 3D, shape C(Z)YX.
+        *args : Any
+            Additional arguments, unused.
+        **kwargs : Any
+            Additional keyword arguments, unused.
+
+        Returns
+        -------
+        Tuple[np.ndarray, np.ndarray, np.ndarray]
+            Masked patch, original patch, and mask.
         """
         masked = np.zeros_like(patch)
         mask = np.zeros_like(patch)
diff --git a/src/careamics/transforms/nd_flip.py b/src/careamics/transforms/nd_flip.py
deleted file mode 100644
index 4f558ba3b..000000000
--- a/src/careamics/transforms/nd_flip.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from typing import Optional, Tuple
-
-import numpy as np
-
-from careamics.transforms.transform import Transform
-
-
-class NDFlip(Transform):
-    """Flip ND arrays on a single axis.
-
-    This transform ignores singleton axes and randomly flips one of the other
-    last two axes.
-
-    This transform expects C(Z)YX dimensions.
-    """
-
-    def __init__(self, seed: Optional[int] = None):
-        """Constructor.
-
-        Parameters
-        ----------
-        seed : Optional[int], optional
-            Random seed, by default None
-        """
-        # "flippable" axes
-        self.axis_indices = [-2, -1]
-
-        # numpy random generator
-        self.rng = np.random.default_rng(seed=seed)
-
-    def __call__(
-        self, patch: np.ndarray, target: Optional[np.ndarray] = None
-    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
-        """Apply the transform to the source patch and the target (optional).
-
-        Parameters
-        ----------
-        patch : np.ndarray
-            Patch, 2D or 3D, shape C(Z)YX.
-        target : Optional[np.ndarray], optional
-            Target for the patch, by default None
-
-        Returns
-        -------
-        Tuple[np.ndarray, Optional[np.ndarray]]
-            Transformed patch and target.
-        """
-        # choose an axis to flip
-        axis = self.rng.choice(self.axis_indices)
-
-        patch_transformed = self._apply(patch, axis)
-        target_transformed = self._apply(target, axis) if target is not None else None
-
-        return patch_transformed, target_transformed
-
-    def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray:
-        """Apply the transform to the image.
-
-        Parameters
-        ----------
-        patch : np.ndarray
-            Image or image patch, 2D or 3D, shape C(Z)YX.
-        axis : int
-            Axis to flip.
-        """
-        # TODO why ascontiguousarray?
-        return np.ascontiguousarray(np.flip(patch, axis=axis))
diff --git a/src/careamics/transforms/normalize.py b/src/careamics/transforms/normalize.py
index 19d637c2d..76377fe02 100644
--- a/src/careamics/transforms/normalize.py
+++ b/src/careamics/transforms/normalize.py
@@ -1,3 +1,4 @@
+"""Normalization and denormalization transforms for image patches."""
 from typing import Optional, Tuple
 
 import numpy as np
@@ -15,6 +16,13 @@ class Normalize(Transform):
     Not that an epsilon value of 1e-6 is added to the standard deviation to avoid
     division by zero and that it returns a float32 image.
 
+    Parameters
+    ----------
+    mean : float
+        Mean value.
+    std : float
+        Standard deviation value.
+
     Attributes
     ----------
     mean : float
@@ -28,6 +36,15 @@ def __init__(
         mean: float,
         std: float,
     ):
+        """Constructor.
+
+        Parameters
+        ----------
+        mean : float
+            Mean value.
+        std : float
+            Standard deviation value.
+        """
         self.mean = mean
         self.std = std
         self.eps = 1e-6
@@ -42,7 +59,7 @@ def __call__(
         patch : np.ndarray
             Patch, 2D or 3D, shape C(Z)YX.
         target : Optional[np.ndarray], optional
-            Target for the patch, by default None
+            Target for the patch, by default None.
 
         Returns
         -------
@@ -55,6 +72,19 @@ def __call__(
         return norm_patch, norm_target
 
     def _apply(self, patch: np.ndarray) -> np.ndarray:
+        """
+        Apply the transform to the image.
+
+        Parameters
+        ----------
+        patch : np.ndarray
+            Image patch, 2D or 3D, shape C(Z)YX.
+
+        Returns
+        -------
+        np.ndarray
+            Normalized image patch.
+        """
         return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32)
 
 
@@ -69,6 +99,13 @@ class Denormalize:
     division by zero during the normalization step, which is taken into account
     during denormalization.
 
+    Parameters
+    ----------
+    mean : float
+        Mean value.
+    std : float
+        Standard deviation value.
+
     Attributes
     ----------
     mean : float
@@ -82,6 +119,15 @@ def __init__(
         mean: float,
         std: float,
     ):
+        """Constructor.
+
+        Parameters
+        ----------
+        mean : float
+            Mean.
+        std : float
+            Standard deviation.
+        """
         self.mean = mean
         self.std = std
         self.eps = 1e-6
@@ -96,7 +142,7 @@ def __call__(
         patch : np.ndarray
             Patch, 2D or 3D, shape C(Z)YX.
         target : Optional[np.ndarray], optional
-            Target for the patch, by default None
+            Target for the patch, by default None.
 
         Returns
         -------
@@ -115,6 +161,11 @@ def _apply(self, patch: np.ndarray) -> np.ndarray:
         Parameters
         ----------
         patch : np.ndarray
-            Image or image patch, 2D or 3D, shape C(Z)YX.
+            Image patch, 2D or 3D, shape C(Z)YX.
+
+        Returns
+        -------
+        np.ndarray
+            Denormalized image patch.
         """
         return patch * (self.std + self.eps) + self.mean
diff --git a/src/careamics/transforms/pixel_manipulation.py b/src/careamics/transforms/pixel_manipulation.py
index 280a974f4..9c8f0f76e 100644
--- a/src/careamics/transforms/pixel_manipulation.py
+++ b/src/careamics/transforms/pixel_manipulation.py
@@ -15,7 +15,7 @@
 def _apply_struct_mask(
     patch: np.ndarray, coords: np.ndarray, struct_params: StructMaskParameters
 ) -> np.ndarray:
-    """Applies structN2V masks to patch.
+    """Apply structN2V masks to patch.
 
     Each point in `coords` corresponds to the center of a mask, masks are paremeterized
     by `struct_params` and pixels in the mask (with respect to `coords`) are replaced by
@@ -248,7 +248,7 @@ def uniform_manipulate(
         Size of the subpatch the new pixel value is sampled from, by default 11.
     remove_center : bool
         Whether to remove the center pixel from the subpatch, by default False.
- struct_params: Optional[StructMaskParameters] + struct_params : Optional[StructMaskParameters] Parameters for the structN2V mask (axis and span). Returns @@ -322,7 +322,7 @@ def median_manipulate( Approximate percentage of pixels to be masked. subpatch_size : int Size of the subpatch the new pixel value is sampled from, by default 11. - struct_params: Optional[StructMaskParameters] + struct_params : Optional[StructMaskParameters] Parameters for the structN2V mask (axis and span). Returns diff --git a/src/careamics/transforms/struct_mask_parameters.py b/src/careamics/transforms/struct_mask_parameters.py index 48e49ef77..0d3abd73c 100644 --- a/src/careamics/transforms/struct_mask_parameters.py +++ b/src/careamics/transforms/struct_mask_parameters.py @@ -1,3 +1,4 @@ +"""Class representing the parameters of structN2V masks.""" from dataclasses import dataclass from typing import Literal @@ -6,7 +7,7 @@ class StructMaskParameters: """Parameters of structN2V masks. - Parameters + Attributes ---------- axis : Literal[0, 1] Axis along which to apply the mask, horizontal (0) or vertical (1). diff --git a/src/careamics/transforms/transform.py b/src/careamics/transforms/transform.py index 9798355df..640d1b5da 100644 --- a/src/careamics/transforms/transform.py +++ b/src/careamics/transforms/transform.py @@ -6,6 +6,19 @@ class Transform: """A general parent class for transforms.""" - def __call__(self, *args: Any, **kwwargs: Any) -> Any: - """Apply the transform.""" + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Apply the transform. + + Parameters + ---------- + *args : Any + Arguments. + **kwargs : Any + Keyword arguments. + + Returns + ------- + Any + Transformed data. + """ pass diff --git a/src/careamics/transforms/xy_random_rotate90.py b/src/careamics/transforms/xy_random_rotate90.py index d970d24bf..25cbc7f98 100644 --- a/src/careamics/transforms/xy_random_rotate90.py +++ b/src/careamics/transforms/xy_random_rotate90.py @@ -1,3 +1,4 @@ +"""Patch transform applying XY random 90 degrees rotations.""" from typing import Optional, Tuple import numpy as np @@ -9,6 +10,16 @@ class XYRandomRotate90(Transform): """Applies random 90 degree rotations to the YX axis. This transform expects C(Z)YX dimensions. + + Attributes + ---------- + rng : np.random.Generator + Random number generator. + + Parameters + ---------- + seed : Optional[int] + Random seed, by default None. """ def __init__(self, seed: Optional[int] = None): @@ -16,8 +27,8 @@ def __init__(self, seed: Optional[int] = None): Parameters ---------- - seed : Optional[int], optional - Random seed, by default None + seed : Optional[int] + Random seed, by default None. """ # numpy random generator self.rng = np.random.default_rng(seed=seed) @@ -32,7 +43,7 @@ def __call__( patch : np.ndarray Patch, 2D or 3D, shape C(Z)YX. target : Optional[np.ndarray], optional - Target for the patch, by default None + Target for the patch, by default None. Returns ------- @@ -63,6 +74,11 @@ def _apply( Number of 90 degree rotations. axes : Tuple[int, int] Axes along which to rotate the patch. + + Returns + ------- + np.ndarray + Transformed patch. """ # TODO why ascontiguousarray? 
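+        # editor's note answering the TODO above: np.rot90 returns a non-contiguous
+        # view (negative strides); the copy made by ascontiguousarray likely avoids
+        # downstream issues, e.g. torch.from_numpy rejecting negative strides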
return np.ascontiguousarray(np.rot90(patch, k=n_rot, axes=axes)) diff --git a/src/careamics/utils/base_enum.py b/src/careamics/utils/base_enum.py index 8ff8bb6c4..db54d6ae2 100644 --- a/src/careamics/utils/base_enum.py +++ b/src/careamics/utils/base_enum.py @@ -1,9 +1,23 @@ +"""A base class for Enum that allows checking if a value is in the Enum.""" from enum import Enum, EnumMeta from typing import Any class _ContainerEnum(EnumMeta): + """Metaclass for Enum with __contains__ method.""" def __contains__(cls, item: Any) -> bool: + """Check if an item is in the Enum. + + Parameters + ---------- + item : Any + Item to check. + + Returns + ------- + bool + True if the item is in the Enum, False otherwise. + """ try: cls(item) except ValueError: @@ -12,6 +26,18 @@ def __contains__(cls, item: Any) -> bool: @classmethod def has_value(cls, value: Any) -> bool: + """Check if a value is in the Enum. + + Parameters + ---------- + value : Any + Value to check. + + Returns + ------- + bool + True if the value is in the Enum, False otherwise. + """ return value in cls._value2member_map_ diff --git a/src/careamics/utils/path_utils.py b/src/careamics/utils/path_utils.py index 61bb744a0..6ea25ec48 100644 --- a/src/careamics/utils/path_utils.py +++ b/src/careamics/utils/path_utils.py @@ -1,3 +1,4 @@ +"""Utility functions for paths.""" from pathlib import Path from typing import Union diff --git a/src/careamics/utils/ram.py b/src/careamics/utils/ram.py index 2a26c7811..1fc4a6556 100644 --- a/src/careamics/utils/ram.py +++ b/src/careamics/utils/ram.py @@ -1,3 +1,4 @@ +"""Utility function to get RAM size.""" import psutil diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factory.py index 51e9b9805..4ba0a7206 100644 --- a/tests/config/test_configuration_factory.py +++ b/tests/config/test_configuration_factory.py @@ -24,7 +24,7 @@ def test_n2n_configuration(): ) assert config.data_config.transforms[0].name == SupportedTransform.NORMALIZE.value - assert config.data_config.transforms[1].name == SupportedTransform.NDFLIP.value + assert config.data_config.transforms[1].name == SupportedTransform.XY_FLIP.value assert ( config.data_config.transforms[2].name == SupportedTransform.XY_RANDOM_ROTATE90.value @@ -152,7 +152,7 @@ def test_care_configuration(): ) assert config.data_config.transforms[0].name == SupportedTransform.NORMALIZE.value - assert config.data_config.transforms[1].name == SupportedTransform.NDFLIP.value + assert config.data_config.transforms[1].name == SupportedTransform.XY_FLIP.value assert ( config.data_config.transforms[2].name == SupportedTransform.XY_RANDOM_ROTATE90.value diff --git a/tests/config/test_configuration_model.py b/tests/config/test_configuration_model.py index 353184cf6..0b0853243 100644 --- a/tests/config/test_configuration_model.py +++ b/tests/config/test_configuration_model.py @@ -106,7 +106,7 @@ def test_n2v2_and_transforms(minimum_configuration: dict, algorithm, strategy): # missing ManipulateN2V minimum_configuration["data_config"]["transforms"] = [ - {"name": SupportedTransform.NDFLIP.value} + {"name": SupportedTransform.XY_FLIP.value} ] config = Configuration(**minimum_configuration) assert len(config.data_config.transforms) == 2 diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py index ceee24030..aa0920e35 100644 --- a/tests/config/test_data_model.py +++ b/tests/config/test_data_model.py @@ -8,7 +8,7 @@ ) from careamics.config.transformations import ( N2VManipulateModel, - NDFlipModel, + XYFlipModel, 
NormalizeModel, XYRandomRotate90Model, ) @@ -140,15 +140,15 @@ def test_set_3d(minimum_data: dict): "transforms", [ [ - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, {"name": SupportedTransform.N2V_MANIPULATE.value}, ], [ - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, ], [ {"name": SupportedTransform.NORMALIZE.value}, - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, {"name": SupportedTransform.N2V_MANIPULATE.value}, ], @@ -160,7 +160,7 @@ def test_passing_supported_transforms(minimum_data: dict, transforms): model = DataConfig(**minimum_data) supported = { - "NDFlip": NDFlipModel, + "XYFlip": XYFlipModel, "XYRandomRotate90": XYRandomRotate90Model, "Normalize": NormalizeModel, "N2VManipulate": N2VManipulateModel, @@ -176,14 +176,14 @@ def test_passing_supported_transforms(minimum_data: dict, transforms): [ [ {"name": SupportedTransform.N2V_MANIPULATE.value}, - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, ], [ {"name": SupportedTransform.N2V_MANIPULATE.value}, ], [ {"name": SupportedTransform.NORMALIZE.value}, - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, {"name": SupportedTransform.N2V_MANIPULATE.value}, {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, ], @@ -209,19 +209,19 @@ def test_multiple_n2v_manipulate(minimum_data: dict): def test_remove_n2v_manipulate(minimum_data: dict): """Test that N2V Manipulate can be removed.""" minimum_data["transforms"] = [ - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, {"name": SupportedTransform.N2V_MANIPULATE.value}, ] model = DataConfig(**minimum_data) model.remove_n2v_manipulate() assert len(model.transforms) == 1 - assert model.transforms[-1].name == SupportedTransform.NDFLIP.value + assert model.transforms[-1].name == SupportedTransform.XY_FLIP.value def test_add_n2v_manipulate(minimum_data: dict): """Test that N2V Manipulate can be added.""" minimum_data["transforms"] = [ - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, ] model = DataConfig(**minimum_data) model.add_n2v_manipulate() @@ -242,7 +242,7 @@ def test_correct_transform_parameters(minimum_data: dict): """ minimum_data["transforms"] = [ {"name": SupportedTransform.NORMALIZE.value}, - {"name": SupportedTransform.NDFLIP.value}, + {"name": SupportedTransform.XY_FLIP.value}, {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, {"name": SupportedTransform.N2V_MANIPULATE.value}, ] @@ -272,7 +272,7 @@ def test_passing_incorrect_element(minimum_data: dict): """Test that incorrect element in the list of transforms raises an error ( e.g. passing un object rather than a string).""" minimum_data["transforms"] = [ - {"name": get_all_transforms()[SupportedTransform.NDFLIP.value]()}, + {"name": get_all_transforms()[SupportedTransform.XY_FLIP.value]()}, ] with pytest.raises(ValueError): DataConfig(**minimum_data) diff --git a/tests/config/test_inference_model.py b/tests/config/test_inference_model.py index ef91620b1..732507077 100644 --- a/tests/config/test_inference_model.py +++ b/tests/config/test_inference_model.py @@ -154,7 +154,7 @@ def test_passing_incorrect_element(minimum_inference: dict): """Test that incorrect element in the list of transforms raises an error ( e.g. 
passing un object rather than a string).""" minimum_inference["transforms"] = [ - {"name": get_all_transforms()[SupportedTransform.NDFLIP.value]()}, + {"name": get_all_transforms()[SupportedTransform.XY_FLIP.value]()}, ] with pytest.raises(ValueError): InferenceConfig(**minimum_inference) diff --git a/tests/transforms/test_compose.py b/tests/transforms/test_compose.py index aaa8c3684..822e04937 100644 --- a/tests/transforms/test_compose.py +++ b/tests/transforms/test_compose.py @@ -2,11 +2,11 @@ from careamics.config.transformations import ( N2VManipulateModel, - NDFlipModel, + XYFlipModel, NormalizeModel, XYRandomRotate90Model, ) -from careamics.transforms import Compose, NDFlip, Normalize, XYRandomRotate90 +from careamics.transforms import Compose, XYFlip, Normalize, XYRandomRotate90 def test_empty_compose(ordered_array): @@ -32,9 +32,9 @@ def test_compose_with_target(ordered_array): target = array[:2, ...] # transform lists - transform_list = [NDFlip(seed=seed), XYRandomRotate90(seed=seed)] + transform_list = [XYFlip(seed=seed), XYRandomRotate90(seed=seed)] transform_list_pydantic = [ - NDFlipModel(name="NDFlip", seed=seed), + XYFlipModel(name="XYFlip", seed=seed), XYRandomRotate90Model(name="XYRandomRotate90", seed=seed), ] @@ -62,16 +62,16 @@ def test_compose_n2v(ordered_array): transform_list_pydantic = [ NormalizeModel(mean=mean, std=std), - NDFlipModel(seed=seed), + XYFlipModel(seed=seed), XYRandomRotate90Model(seed=seed), N2VManipulateModel(), ] # apply the transforms normalize = Normalize(mean=mean, std=std) - ndflip = NDFlip(seed=seed) + xyflip = XYFlip(seed=seed) xyrotate = XYRandomRotate90(seed=seed) - array_aug, _ = xyrotate(*ndflip(*normalize(array))) + array_aug, _ = xyrotate(*xyflip(*normalize(array))) # instantiate Compose compose = Compose(transform_list_pydantic) diff --git a/tests/transforms/test_nd_flip.py b/tests/transforms/test_nd_flip.py index eb1ed46a9..e04e0777a 100644 --- a/tests/transforms/test_nd_flip.py +++ b/tests/transforms/test_nd_flip.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from careamics.transforms import NDFlip +from careamics.transforms import XYFlip @pytest.mark.parametrize( @@ -21,7 +21,7 @@ def test_flip_nd(ordered_array, shape): array: np.ndarray = ordered_array(shape) # create augmentation - aug = NDFlip(seed=42) + aug = XYFlip(seed=42) r = np.random.default_rng(seed=42) # potential flips @@ -43,7 +43,7 @@ def test_flip_mask(ordered_array): array = array[:2, ...] # create augmentation - aug = NDFlip(seed=42) + aug = XYFlip(seed=42) r = np.random.default_rng(seed=42) # potential flips on Y and X axes From 04c871c327d6a34aee4fc54b2e3c6bd78a942f7d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 May 2024 12:01:13 +0000 Subject: [PATCH 09/13] style(pre-commit.ci): auto fixes [...] 
--- src/careamics/config/data_model.py | 2 +- src/careamics/config/transformations/__init__.py | 2 +- src/careamics/transforms/__init__.py | 2 +- src/careamics/transforms/compose.py | 2 +- src/careamics/transforms/n2v_manipulate.py | 3 ++- src/careamics/transforms/normalize.py | 1 + src/careamics/transforms/struct_mask_parameters.py | 1 + src/careamics/transforms/transform.py | 4 ++-- src/careamics/transforms/xy_random_rotate90.py | 1 + src/careamics/utils/base_enum.py | 2 ++ src/careamics/utils/path_utils.py | 1 + src/careamics/utils/ram.py | 1 + tests/config/test_data_model.py | 2 +- tests/transforms/test_compose.py | 4 ++-- 14 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index 28f365ade..202aee507 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -17,8 +17,8 @@ from .support import SupportedTransform from .transformations.n2v_manipulate_model import N2VManipulateModel -from .transformations.xy_flip_model import XYFlipModel from .transformations.normalize_model import NormalizeModel +from .transformations.xy_flip_model import XYFlipModel from .transformations.xy_random_rotate90_model import XYRandomRotate90Model from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 diff --git a/src/careamics/config/transformations/__init__.py b/src/careamics/config/transformations/__init__.py index c314be107..b083130fe 100644 --- a/src/careamics/config/transformations/__init__.py +++ b/src/careamics/config/transformations/__init__.py @@ -9,6 +9,6 @@ from .n2v_manipulate_model import N2VManipulateModel -from .xy_flip_model import XYFlipModel from .normalize_model import NormalizeModel +from .xy_flip_model import XYFlipModel from .xy_random_rotate90_model import XYRandomRotate90Model diff --git a/src/careamics/transforms/__init__.py b/src/careamics/transforms/__init__.py index 22aeb5677..76a53b130 100644 --- a/src/careamics/transforms/__init__.py +++ b/src/careamics/transforms/__init__.py @@ -14,7 +14,7 @@ from .compose import Compose, get_all_transforms from .n2v_manipulate import N2VManipulate -from .xy_flip import XYFlip from .normalize import Denormalize, Normalize from .tta import ImageRestorationTTA +from .xy_flip import XYFlip from .xy_random_rotate90 import XYRandomRotate90 diff --git a/src/careamics/transforms/compose.py b/src/careamics/transforms/compose.py index 1a528c6f7..7bb2e64d7 100644 --- a/src/careamics/transforms/compose.py +++ b/src/careamics/transforms/compose.py @@ -7,9 +7,9 @@ from careamics.config.data_model import TRANSFORMS_UNION from .n2v_manipulate import N2VManipulate -from .xy_flip import XYFlip from .normalize import Normalize from .transform import Transform +from .xy_flip import XYFlip from .xy_random_rotate90 import XYRandomRotate90 ALL_TRANSFORMS = { diff --git a/src/careamics/transforms/n2v_manipulate.py b/src/careamics/transforms/n2v_manipulate.py index 40ce0795c..0737f4396 100644 --- a/src/careamics/transforms/n2v_manipulate.py +++ b/src/careamics/transforms/n2v_manipulate.py @@ -1,4 +1,5 @@ """N2V manipulation transform.""" + from typing import Any, Literal, Optional, Tuple import numpy as np @@ -83,7 +84,7 @@ def __init__( self.masked_pixel_percentage = masked_pixel_percentage self.roi_size = roi_size self.strategy = strategy - self.remove_center = remove_center # TODO is this ever used? + self.remove_center = remove_center # TODO is this ever used? 
if struct_mask_axis == SupportedStructAxis.NONE: self.struct_mask: Optional[StructMaskParameters] = None diff --git a/src/careamics/transforms/normalize.py b/src/careamics/transforms/normalize.py index 76377fe02..1e24afd5b 100644 --- a/src/careamics/transforms/normalize.py +++ b/src/careamics/transforms/normalize.py @@ -1,4 +1,5 @@ """Normalization and denormalization transforms for image patches.""" + from typing import Optional, Tuple import numpy as np diff --git a/src/careamics/transforms/struct_mask_parameters.py b/src/careamics/transforms/struct_mask_parameters.py index 0d3abd73c..184ba5d0b 100644 --- a/src/careamics/transforms/struct_mask_parameters.py +++ b/src/careamics/transforms/struct_mask_parameters.py @@ -1,4 +1,5 @@ """Class representing the parameters of structN2V masks.""" + from dataclasses import dataclass from typing import Literal diff --git a/src/careamics/transforms/transform.py b/src/careamics/transforms/transform.py index 640d1b5da..0777d022c 100644 --- a/src/careamics/transforms/transform.py +++ b/src/careamics/transforms/transform.py @@ -8,14 +8,14 @@ class Transform: def __call__(self, *args: Any, **kwargs: Any) -> Any: """Apply the transform. - + Parameters ---------- *args : Any Arguments. **kwargs : Any Keyword arguments. - + Returns ------- Any diff --git a/src/careamics/transforms/xy_random_rotate90.py b/src/careamics/transforms/xy_random_rotate90.py index 25cbc7f98..21d168428 100644 --- a/src/careamics/transforms/xy_random_rotate90.py +++ b/src/careamics/transforms/xy_random_rotate90.py @@ -1,4 +1,5 @@ """Patch transform applying XY random 90 degrees rotations.""" + from typing import Optional, Tuple import numpy as np diff --git a/src/careamics/utils/base_enum.py b/src/careamics/utils/base_enum.py index db54d6ae2..1385478f7 100644 --- a/src/careamics/utils/base_enum.py +++ b/src/careamics/utils/base_enum.py @@ -1,10 +1,12 @@ """A base class for Enum that allows checking if a value is in the Enum.""" + from enum import Enum, EnumMeta from typing import Any class _ContainerEnum(EnumMeta): """Metaclass for Enum with __contains__ method.""" + def __contains__(cls, item: Any) -> bool: """Check if an item is in the Enum. 
diff --git a/src/careamics/utils/path_utils.py b/src/careamics/utils/path_utils.py index 6ea25ec48..3803c7281 100644 --- a/src/careamics/utils/path_utils.py +++ b/src/careamics/utils/path_utils.py @@ -1,4 +1,5 @@ """Utility functions for paths.""" + from pathlib import Path from typing import Union diff --git a/src/careamics/utils/ram.py b/src/careamics/utils/ram.py index 1fc4a6556..258ebc824 100644 --- a/src/careamics/utils/ram.py +++ b/src/careamics/utils/ram.py @@ -1,4 +1,5 @@ """Utility function to get RAM size.""" + import psutil diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py index aa0920e35..97ca23174 100644 --- a/tests/config/test_data_model.py +++ b/tests/config/test_data_model.py @@ -8,8 +8,8 @@ ) from careamics.config.transformations import ( N2VManipulateModel, - XYFlipModel, NormalizeModel, + XYFlipModel, XYRandomRotate90Model, ) from careamics.transforms import get_all_transforms diff --git a/tests/transforms/test_compose.py b/tests/transforms/test_compose.py index 822e04937..2903d9926 100644 --- a/tests/transforms/test_compose.py +++ b/tests/transforms/test_compose.py @@ -2,11 +2,11 @@ from careamics.config.transformations import ( N2VManipulateModel, - XYFlipModel, NormalizeModel, + XYFlipModel, XYRandomRotate90Model, ) -from careamics.transforms import Compose, XYFlip, Normalize, XYRandomRotate90 +from careamics.transforms import Compose, Normalize, XYFlip, XYRandomRotate90 def test_empty_compose(ordered_array): From d688dc7b0d70f1755b9b51ff2b2c61c4805534c6 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Tue, 28 May 2024 14:11:03 +0200 Subject: [PATCH 10/13] Missing file, more doc fix --- .../callbacks/hyperparameters_callback.py | 13 ++- .../callbacks/progress_bar_callback.py | 41 ++++++++- src/careamics/config/algorithm_model.py | 8 +- .../architectures/architecture_model.py | 2 + src/careamics/models/unet.py | 14 ++- src/careamics/prediction/stitch_prediction.py | 8 +- src/careamics/transforms/compose.py | 30 ++++++- src/careamics/transforms/xy_flip.py | 86 +++++++++++++++++++ 8 files changed, 182 insertions(+), 20 deletions(-) create mode 100644 src/careamics/transforms/xy_flip.py diff --git a/src/careamics/callbacks/hyperparameters_callback.py b/src/careamics/callbacks/hyperparameters_callback.py index d06090770..7de432e64 100644 --- a/src/careamics/callbacks/hyperparameters_callback.py +++ b/src/careamics/callbacks/hyperparameters_callback.py @@ -1,3 +1,5 @@ +"""Callback saving CAREamics configuration as hyperparameters in the model.""" + from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback @@ -11,13 +13,18 @@ class HyperParametersCallback(Callback): This allows saving the configuration as dictionnary in the checkpoints, and loading it subsequently in a CAREamist instance. + Parameters + ---------- + config : Configuration + CAREamics configuration to be saved as hyperparameter in the model. + Attributes ---------- config : Configuration CAREamics configuration to be saved as hyperparameter in the model. """ - def __init__(self, config: Configuration): + def __init__(self, config: Configuration) -> None: """ Constructor. @@ -28,14 +35,14 @@ def __init__(self, config: Configuration): """ self.config = config - def on_train_start(self, trainer: Trainer, pl_module: LightningModule): + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """ Update the hyperparameters of the model with the configuration on train start. 
Parameters ---------- trainer : Trainer - PyTorch Lightning trainer. + PyTorch Lightning trainer, unused. pl_module : LightningModule PyTorch Lightning module. """ diff --git a/src/careamics/callbacks/progress_bar_callback.py b/src/careamics/callbacks/progress_bar_callback.py index d7862091c..29c3cd0cf 100644 --- a/src/careamics/callbacks/progress_bar_callback.py +++ b/src/careamics/callbacks/progress_bar_callback.py @@ -1,3 +1,5 @@ +"""Progressbar callback.""" + import sys from typing import Dict, Union @@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar): """Progress bar for training and validation steps.""" def init_train_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for training.""" + """Override this to customize the tqdm bar for training. + + Returns + ------- + tqdm + A tqdm bar. + """ bar = tqdm( desc="Training", position=(2 * self.process_position), @@ -23,7 +31,13 @@ def init_train_tqdm(self) -> tqdm: return bar def init_validation_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for validation.""" + """Override this to customize the tqdm bar for validation. + + Returns + ------- + tqdm + A tqdm bar. + """ # The main progress bar doesn't exist in `trainer.validate()` has_main_bar = self.train_progress_bar is not None bar = tqdm( @@ -37,7 +51,13 @@ def init_validation_tqdm(self) -> tqdm: return bar def init_test_tqdm(self) -> tqdm: - """Override this to customize the tqdm bar for testing.""" + """Override this to customize the tqdm bar for testing. + + Returns + ------- + tqdm + A tqdm bar. + """ bar = tqdm( desc="Testing", position=(2 * self.process_position), @@ -52,6 +72,19 @@ def init_test_tqdm(self) -> tqdm: def get_metrics( self, trainer: Trainer, pl_module: LightningModule ) -> Dict[str, Union[int, str, float, Dict[str, float]]]: - """Override this to customize the metrics displayed in the progress bar.""" + """Override this to customize the metrics displayed in the progress bar. + + Parameters + ---------- + trainer : Trainer + The trainer object. + pl_module : LightningModule + The LightningModule object, unused. + + Returns + ------- + dict + A dictionary with the metrics to display in the progress bar. + """ pbar_metrics = trainer.progress_bar_metrics return {**pbar_metrics} diff --git a/src/careamics/config/algorithm_model.py b/src/careamics/config/algorithm_model.py index 77ea835a2..ddfe81e4a 100644 --- a/src/careamics/config/algorithm_model.py +++ b/src/careamics/config/algorithm_model.py @@ -1,3 +1,5 @@ +"""Algorithm configuration.""" + from __future__ import annotations from pprint import pformat @@ -17,9 +19,9 @@ class AlgorithmConfig(BaseModel): training algorithm: which algorithm, loss function, model architecture, optimizer, and learning rate scheduler to use. - Currently, we only support N2V and custom algorithms. The `n2v` algorithm is only - compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm allows - you to register your own architecture and select it using its name as + Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is + only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm + allows you to register your own architecture and select it using its name as `name` in the custom pydantic model. 
Attributes diff --git a/src/careamics/config/architectures/architecture_model.py b/src/careamics/config/architectures/architecture_model.py index 28113112c..6ea3f122f 100644 --- a/src/careamics/config/architectures/architecture_model.py +++ b/src/careamics/config/architectures/architecture_model.py @@ -1,3 +1,5 @@ +"""Base model for the various CAREamics architectures.""" + from typing import Any, Dict from pydantic import BaseModel diff --git a/src/careamics/models/unet.py b/src/careamics/models/unet.py index 11f29eb32..a8836798a 100644 --- a/src/careamics/models/unet.py +++ b/src/careamics/models/unet.py @@ -324,8 +324,14 @@ class UNet(nn.Module): Dropout probability, by default 0.0. pool_kernel : int, optional Kernel size of the pooling layers, by default 2. - last_activation : Optional[Callable], optional + final_activation : Optional[Callable], optional Activation function to use for the last layer, by default None. + n2v2 : bool, optional + Whether to use N2V2 architecture, by default False. + independent_channels : bool + Whether to train the channels independently, by default True. + **kwargs : Any + Additional keyword arguments, unused. """ def __init__( @@ -364,11 +370,15 @@ def __init__( Dropout probability, by default 0.0. pool_kernel : int, optional Kernel size of the pooling layers, by default 2. - last_activation : Optional[Callable], optional + final_activation : Optional[Callable], optional Activation function to use for the last layer, by default None. + n2v2 : bool, optional + Whether to use N2V2 architecture, by default False. independent_channels : bool Whether to train parallel independent networks for each channel, by default True. + **kwargs : Any + Additional keyword arguments, unused. """ super().__init__() diff --git a/src/careamics/prediction/stitch_prediction.py b/src/careamics/prediction/stitch_prediction.py index b849f3b55..1b4fe9690 100644 --- a/src/careamics/prediction/stitch_prediction.py +++ b/src/careamics/prediction/stitch_prediction.py @@ -1,8 +1,4 @@ -""" -Prediction convenience functions. - -These functions are used during prediction. -""" +"""Prediction utility functions.""" from typing import List @@ -21,7 +17,7 @@ def stitch_prediction( ---------- tiles : List[torch.Tensor] Cropped tiles and their respective stitching coordinates. - stitching_coords : List + stitching_data : List List of information and coordinates obtained from `dataset.tiled_patching.extract_tiles`. diff --git a/src/careamics/transforms/compose.py b/src/careamics/transforms/compose.py index 7bb2e64d7..9210c4c47 100644 --- a/src/careamics/transforms/compose.py +++ b/src/careamics/transforms/compose.py @@ -33,7 +33,19 @@ def get_all_transforms() -> Dict[str, type]: class Compose: - """A class chaining transforms together.""" + """A class chaining transforms together. + + Parameters + ---------- + transform_list : List[TRANSFORMS_UNION] + A list of dictionaries where each dictionary contains the name of a + transform and its parameters. + + Attributes + ---------- + _callable_transforms : Callable + A callable that applies the transforms to the input data. + """ def __init__(self, transform_list: List[TRANSFORMS_UNION]) -> None: """Instantiate a Compose object. @@ -69,6 +81,20 @@ def _chain_transforms(self, transforms: List[Transform]) -> Callable: def _chain( patch: np.ndarray, target: Optional[np.ndarray] ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Chain transforms on the input data. + + Parameters + ---------- + patch : np.ndarray + Input data. 
+ target : Optional[np.ndarray] + Target data, by default None. + + Returns + ------- + Tuple[np.ndarray, Optional[np.ndarray]] + The output of the transformations. + """ params = (patch, target) for t in transforms: @@ -88,7 +114,7 @@ def __call__( patch : np.ndarray The input data. target : Optional[np.ndarray], optional - Target data, by default None + Target data, by default None. Returns ------- diff --git a/src/careamics/transforms/xy_flip.py b/src/careamics/transforms/xy_flip.py new file mode 100644 index 000000000..de26a8588 --- /dev/null +++ b/src/careamics/transforms/xy_flip.py @@ -0,0 +1,86 @@ +"""XY flip transform.""" + +from typing import Optional, Tuple + +import numpy as np + +from careamics.transforms.transform import Transform + + +class XYFlip(Transform): + """Flip image along X or Y axis. + + This transform ignores singleton axes and randomly flips one of the other + last two axes. + + This transform expects C(Z)YX dimensions. + + Attributes + ---------- + axis_indices : List[int] + Indices of the axes that can be flipped. + rng : np.random.Generator + Random number generator. + + Parameters + ---------- + seed : Optional[int], optional + Random seed, by default None. + """ + + def __init__(self, seed: Optional[int] = None) -> None: + """Constructor. + + Parameters + ---------- + seed : Optional[int], optional + Random seed, by default None. + """ + # "flippable" axes + self.axis_indices = [-2, -1] + + # numpy random generator + self.rng = np.random.default_rng(seed=seed) + + def __call__( + self, patch: np.ndarray, target: Optional[np.ndarray] = None + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """Apply the transform to the source patch and the target (optional). + + Parameters + ---------- + patch : np.ndarray + Patch, 2D or 3D, shape C(Z)YX. + target : Optional[np.ndarray], optional + Target for the patch, by default None. + + Returns + ------- + Tuple[np.ndarray, Optional[np.ndarray]] + Transformed patch and target. + """ + # choose an axis to flip + axis = self.rng.choice(self.axis_indices) + + patch_transformed = self._apply(patch, axis) + target_transformed = self._apply(target, axis) if target is not None else None + + return patch_transformed, target_transformed + + def _apply(self, patch: np.ndarray, axis: int) -> np.ndarray: + """Apply the transform to the image. + + Parameters + ---------- + patch : np.ndarray + Image patch, 2D or 3D, shape C(Z)YX. + axis : int + Axis to flip. + + Returns + ------- + np.ndarray + Flipped image patch. + """ + # TODO why ascontiguousarray? 
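+        # np.flip returns a view with negative strides; ascontiguousarray makes
+        # a C-contiguous copy, which is likely required downstream (e.g.
+        # torch.from_numpy does not support negative strides).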
+ return np.ascontiguousarray(np.flip(patch, axis=axis)) From a340500cbbcc72b167faf3c65d17280bd8973f29 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 29 May 2024 12:21:53 +0200 Subject: [PATCH 11/13] (chore): add docs --- src/careamics/models/activation.py | 2 + src/careamics/models/layers.py | 146 +++++++++++++++++++++----- src/careamics/models/model_factory.py | 2 +- src/careamics/models/unet.py | 23 ++-- 4 files changed, 140 insertions(+), 33 deletions(-) diff --git a/src/careamics/models/activation.py b/src/careamics/models/activation.py index c102fbc90..83d320d10 100644 --- a/src/careamics/models/activation.py +++ b/src/careamics/models/activation.py @@ -1,3 +1,5 @@ +"""Activations for CAREamics models.""" + from typing import Callable, Union import torch.nn as nn diff --git a/src/careamics/models/layers.py b/src/careamics/models/layers.py index 0fa846835..3cc621fe3 100644 --- a/src/careamics/models/layers.py +++ b/src/careamics/models/layers.py @@ -162,6 +162,18 @@ def _unpack_kernel_size( """Unpack kernel_size to a tuple of ints. Inspired by Kornia implementation. TODO: link + + Parameters + ---------- + kernel_size : Union[Tuple[int, ...], int] + Kernel size. + dim : int + Number of dimensions. + + Returns + ------- + Tuple[int, ...] + Kernel size tuple. """ if isinstance(kernel_size, int): kernel_dims = tuple([kernel_size for _ in range(dim)]) @@ -173,7 +185,20 @@ def _unpack_kernel_size( def _compute_zero_padding( kernel_size: Union[Tuple[int, ...], int], dim: int ) -> Tuple[int, ...]: - """Utility function that computes zero padding tuple.""" + """Utility function that computes zero padding tuple. + + Parameters + ---------- + kernel_size : Union[Tuple[int, ...], int] + Kernel size. + dim : int + Number of dimensions. + + Returns + ------- + Tuple[int, ...] + Zero padding tuple. + """ kernel_dims = _unpack_kernel_size(kernel_size, dim) return tuple([(kd - 1) // 2 for kd in kernel_dims]) @@ -191,14 +216,19 @@ def get_pascal_kernel_1d( Parameters ---------- - kernel_size: height and width of the kernel. - norm: if to normalize the kernel or not. Default: False. - device: tensor device - dtype: tensor dtype + kernel_size : int + Kernel size. + norm : bool + Normalize the kernel, by default False. + device : Optional[torch.device] + Device of the tensor, by default None. + dtype : Optional[torch.dtype] + Data type of the tensor, by default None. Returns ------- - kernel shaped as :math:`(kernel_size,)` + torch.Tensor + Pascal kernel. Examples -------- @@ -245,19 +275,28 @@ def _get_pascal_kernel_nd( ) -> torch.Tensor: """Generate pascal filter kernel by kernel size. + If kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size) + otherwise the kernel will be shaped as kernel_size + Inspired by Kornia implementation. Parameters ---------- - kernel_size: height and width of the kernel. - norm: if to normalize the kernel or not. Default: True. - device: tensor device - dtype: tensor dtype + kernel_size : Union[Tuple[int, int], int] + Kernel size for the pascal kernel. + norm : bool + Normalize the kernel, by default True. + dim : int + Number of dimensions, by default 2. + device : Optional[torch.device] + Device of the tensor, by default None. + dtype : Optional[torch.dtype] + Data type of the tensor, by default None. Returns ------- - if kernel_size is an integer the kernel will be shaped as (kernel_size, kernel_size) - otherwise the kernel will be shaped as kernel_size + torch.Tensor + Pascal kernel. 
Examples -------- @@ -303,6 +342,24 @@ def _max_blur_pool_by_kernel2d( """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxN` kernel. Inspired by Kornia implementation. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + kernel : torch.Tensor + Kernel tensor. + stride : int + Stride. + max_pool_size : int + Maximum pool size. + ceil_mode : bool + Ceil mode, by default False. Set to True to match output size of conv2d. + + Returns + ------- + torch.Tensor + Output tensor. """ # compute local maxima x = F.max_pool2d( @@ -323,6 +380,24 @@ def _max_blur_pool_by_kernel3d( """Compute max_blur_pool by a given :math:`CxC_(out, None)xNxNxN` kernel. Inspired by Kornia implementation. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + kernel : torch.Tensor + Kernel tensor. + stride : int + Stride. + max_pool_size : int + Maximum pool size. + ceil_mode : bool + Ceil mode, by default False. Set to True to match output size of conv2d. + + Returns + ------- + torch.Tensor + Output tensor. """ # compute local maxima x = F.max_pool3d( @@ -343,21 +418,16 @@ class MaxBlurPool(nn.Module): Parameters ---------- - dim: int - Toggles between 2D and 3D - kernel_size: Union[Tuple[int, int], int] + dim : int + Toggles between 2D and 3D. + kernel_size : Union[Tuple[int, int], int] Kernel size for max pooling. - stride: int + stride : int Stride for pooling. - max_pool_size: int + max_pool_size : int Max kernel size for max pooling. - ceil_mode: bool - Should be true to match output size of conv2d with same kernel size. - - Returns - ------- - torch.Tensor - The pooled and blurred tensor. + ceil_mode : bool + Ceil mode, by default False. Set to True to match output size of conv2d. """ def __init__( @@ -368,6 +438,21 @@ def __init__( max_pool_size: int = 2, ceil_mode: bool = False, ) -> None: + """Constructor. + + Parameters + ---------- + dim : int + Dimension of the convolution. + kernel_size : Union[Tuple[int, int], int] + Kernel size for max pooling. + stride : int, optional + Stride, by default 2. + max_pool_size : int, optional + Maximum pool size, by default 2. + ceil_mode : bool, optional + Ceil mode, by default False. Set to True to match output size of conv2d. + """ super().__init__() self.dim = dim self.kernel_size = kernel_size @@ -377,7 +462,18 @@ def __init__( self.kernel = _get_pascal_kernel_nd(kernel_size, norm=True, dim=self.dim) def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass of the function.""" + """Forward pass of the function. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ self.kernel = torch.as_tensor(self.kernel, device=x.device, dtype=x.dtype) if self.dim == 2: return _max_blur_pool_by_kernel2d( diff --git a/src/careamics/models/model_factory.py b/src/careamics/models/model_factory.py index 40a2f9013..0debed149 100644 --- a/src/careamics/models/model_factory.py +++ b/src/careamics/models/model_factory.py @@ -27,7 +27,7 @@ def model_factory( Parameters ---------- model_configuration : Union[UNetModel, VAEModel] - Model configuration + Model configuration. Returns ------- diff --git a/src/careamics/models/unet.py b/src/careamics/models/unet.py index a8836798a..13b5da0af 100644 --- a/src/careamics/models/unet.py +++ b/src/careamics/models/unet.py @@ -34,7 +34,9 @@ class UnetEncoder(nn.Module): Dropout probability, by default 0.0. pool_kernel : int, optional Kernel size for the max pooling layers, by default 2. 
- groups: int, optional + n2v2 : bool, optional + Whether to use N2V2 architecture, by default False. + groups : int, optional Number of blocked connections from input channels to output channels, by default 1. """ @@ -70,7 +72,9 @@ def __init__( Dropout probability, by default 0.0. pool_kernel : int, optional Kernel size for the max pooling layers, by default 2. - groups: int, optional + n2v2 : bool, optional + Whether to use N2V2 architecture, by default False. + groups : int, optional Number of blocked connections from input channels to output channels, by default 1. """ @@ -140,7 +144,9 @@ class UnetDecoder(nn.Module): Whether to use batch normalization, by default True. dropout : float, optional Dropout probability, by default 0.0. - groups: int, optional + n2v2 : bool, optional + Whether to use N2V2 architecture, by default False. + groups : int, optional Number of blocked connections from input channels to output channels, by default 1. """ @@ -170,7 +176,9 @@ def __init__( Whether to use batch normalization, by default True. dropout : float, optional Dropout probability, by default 0.0. - groups: int, optional + n2v2 : bool, optional + Whether to use N2V2 architecture, by default False. + groups : int, optional Number of blocked connections from input channels to output channels, by default 1. """ @@ -258,16 +266,17 @@ def _interleave(A: torch.Tensor, B: torch.Tensor, groups: int) -> torch.Tensor: Parameters ---------- - A: torch.Tensor + A : torch.Tensor First tensor. - B: torch.Tensor + B : torch.Tensor Second tensor. - groups: int + groups : int The number of groups. Returns ------- torch.Tensor + Interleaved tensor. Raises ------ From a04961f4a8ab7ba11c2645538f7f5e3f3ccee166 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 29 May 2024 12:24:17 +0200 Subject: [PATCH 12/13] (refac): remove unused noise models --- src/careamics/config/noise_models.py | 162 ------ src/careamics/losses/noise_model_factory.py | 40 -- src/careamics/losses/noise_models.py | 524 -------------------- 3 files changed, 726 deletions(-) delete mode 100644 src/careamics/config/noise_models.py delete mode 100644 src/careamics/losses/noise_model_factory.py delete mode 100644 src/careamics/losses/noise_models.py diff --git a/src/careamics/config/noise_models.py b/src/careamics/config/noise_models.py deleted file mode 100644 index 2bdae9388..000000000 --- a/src/careamics/config/noise_models.py +++ /dev/null @@ -1,162 +0,0 @@ -# from __future__ import annotations - -# from enum import Enum -# from typing import Dict, Union - -# from pydantic import BaseModel, ConfigDict, Field, field_validator - - -# class NoiseModelType(str, Enum): -# """ -# Available noise models. - -# Currently supported noise models: - -# - hist: Histogram noise model. -# - gmm: Gaussian mixture model noise model.F -# """ - -# NONE = "none" -# HIST = "hist" -# GMM = "gmm" - -# # TODO add validator decorator -# @classmethod -# def validate_noise_model_type( -# cls, noise_model: Union[str, NoiseModel], parameters: dict -# ) -> None: -# """_summary_. 
- -# Parameters -# ---------- -# noise_model : Union[str, NoiseModel] -# _description_ -# parameters : dict -# _description_ - -# Returns -# ------- -# BaseModel -# _description_ -# """ -# if noise_model == NoiseModelType.HIST.value: -# HistogramNoiseModel(**parameters) -# return HistogramNoiseModel().model_dump() if not parameters else parameters - -# elif noise_model == NoiseModelType.GMM.value: -# GaussianMixtureNoiseModel(**parameters) -# return ( -# GaussianMixtureNoiseModel().model_dump() -# if not parameters -# else parameters -# ) - - -# class NoiseModel(BaseModel): -# """_summary_. - -# Parameters -# ---------- -# BaseModel : _type_ -# _description_ - -# Returns -# ------- -# _type_ -# _description_ - -# Raises -# ------ -# ValueError -# _description_ -# """ - -# model_config = ConfigDict( -# use_enum_values=True, -# protected_namespaces=(), # allows to use model_* as a field name -# validate_assignment=True, -# ) - -# model_type: NoiseModelType -# parameters: Dict = Field(default_factory=dict, validate_default=True) - -# @field_validator("parameters") -# @classmethod -# def validate_parameters(cls, data, values) -> Dict: -# """_summary_. - -# Parameters -# ---------- -# parameters : Dict -# _description_ - -# Returns -# ------- -# Dict -# _description_ -# """ -# if values.data["model_type"] not in [NoiseModelType.GMM, NoiseModelType.HIST]: -# raise ValueError( -# f"Incorrect noise model {values.data['model_type']}." -# f"Please refer to the documentation" # TODO add link to documentation -# ) - -# parameters = NoiseModelType.validate_noise_model_type( -# values.data["model_type"], data -# ) -# return parameters - - -# class HistogramNoiseModel(BaseModel): -# """ -# Histogram noise model. - -# Attributes -# ---------- -# min_value : float -# Minimum value in the input. -# max_value : float -# Maximum value in the input. -# bins : int -# Number of bins of the histogram. -# """ - -# min_value: float = Field(default=350.0, ge=0.0, le=65535.0) -# max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) -# bins: int = Field(default=256, ge=1) - - -# class GaussianMixtureNoiseModel(BaseModel): -# """ -# Gaussian mixture model noise model. - -# Attributes -# ---------- -# min_signal : float -# Minimum signal intensity expected in the image. -# max_signal : float -# Maximum signal intensity expected in the image. -# weight : array -# A [3*n_gaussian, n_coeff] sized array containing the values of the weights -# describing the noise model. -# Each gaussian contributes three parameters (mean, standard deviation and weight), -# hence the number of rows in `weight` are 3*n_gaussian. -# If `weight = None`, the weight array is initialized using the `min_signal` and -# `max_signal` parameters. -# n_gaussian: int -# Number of gaussians. -# n_coeff: int -# Number of coefficients to describe the functional relationship between gaussian -# parameters and the signal. -# 2 implies a linear relationship, 3 implies a quadratic relationship and so on. -# device: device -# GPU device. 
-# min_sigma: int -# """ - -# num_components: int = Field(default=3, ge=1) -# min_value: float = Field(default=350.0, ge=0.0, le=65535.0) -# max_value: float = Field(default=6500.0, ge=0.0, le=65535.0) -# n_gaussian: int = Field(default=3, ge=1) -# n_coeff: int = Field(default=2, ge=1) -# min_sigma: int = Field(default=50, ge=1) diff --git a/src/careamics/losses/noise_model_factory.py b/src/careamics/losses/noise_model_factory.py deleted file mode 100644 index 56173e4b3..000000000 --- a/src/careamics/losses/noise_model_factory.py +++ /dev/null @@ -1,40 +0,0 @@ -# from typing import Type, Union - -# from ..config.noise_models import NoiseModel, NoiseModelType -# from .noise_models import GaussianMixtureNoiseModel, HistogramNoiseModel - - -# def noise_model_factory( -# noise_config: NoiseModel, -# ) -> Type[Union[HistogramNoiseModel, GaussianMixtureNoiseModel, None]]: -# """Create loss model based on Configuration. - -# Parameters -# ---------- -# config : Configuration -# Configuration. - -# Returns -# ------- -# Noise model - -# Raises -# ------ -# NotImplementedError -# If the noise model is unknown. -# """ -# noise_model_type = noise_config.model_type if noise_config else None - -# if noise_model_type == NoiseModelType.HIST: -# return HistogramNoiseModel - -# elif noise_model_type == NoiseModelType.GMM: -# return GaussianMixtureNoiseModel - -# elif noise_model_type is None: -# return None - -# else: -# raise NotImplementedError( -# f"Noise model {noise_model_type} is not yet supported." -# ) diff --git a/src/careamics/losses/noise_models.py b/src/careamics/losses/noise_models.py deleted file mode 100644 index f43906a93..000000000 --- a/src/careamics/losses/noise_models.py +++ /dev/null @@ -1,524 +0,0 @@ -# from abc import ABC, abstractmethod - -# import numpy as np -# import torch - -# from ..utils.logging import get_logger - -# logger = get_logger(__name__) - - -# # TODO here "Model" clashes a bit with the naming convention of the Pydantic Models -# class NoiseModel(ABC): -# """Base class for noise models.""" - -# @abstractmethod -# def instantiate(self): -# """Instantiate the noise model. - -# Method that should produce ready to use noise model. -# """ -# pass - -# @abstractmethod -# def likelihood(self, observations, signals): -# """Function that returns the likelihood of observations given the signals.""" -# pass - - -# class HistogramNoiseModel(NoiseModel): -# """Creates a NoiseModel object. - -# Parameters -# ---------- -# histogram: numpy array -# A histogram as create by the 'createHistogram(...)' method. -# device: -# The device your NoiseModel lives on, e.g. your GPU. -# """ - -# def __init__(self, **kwargs): -# pass - -# def instantiate(self, bins, min_value, max_value, observation, signal): -# """Creates a nD histogram from 'observation' and 'signal'. - -# Parameters -# ---------- -# bins: int -# The number of bins in all dimensions. The total number of bins is -# 'bins' ** number_of_dimensions. -# min_value: float -# the lower bound of the lowest bin. -# max_value: float -# the highest bound of the highest bin. -# observation: np.array -# A stack of noisy images. The number has to be divisible by the number of -# images in signal. N subsequent images in observation belong to one image -# in the signal. -# signal: np.array -# A stack of clean images. - -# Returns -# ------- -# histogram: numpy array -# A 3D array: -# 'histogram[0,...]' holds the normalized nD counts. -# Each row sums to 1, describing p(x_i|s_i). 
-# 'histogram[1,...]' holds the lower boundaries of each bin in y. -# 'histogram[2,...]' holds the upper boundaries of each bin in y. -# The values for x can be obtained by transposing 'histogram[1,...]' -# and 'histogram[2,...]'. -# """ -# img_factor = int(observation.shape[0] / signal.shape[0]) -# histogram = np.zeros((3, bins, bins)) -# value_range = [min_value, max_value] - -# for i in range(observation.shape[0]): -# observation_i = observation[i].copy().ravel() - -# signal_i = (signal[i // img_factor].copy()).ravel() - -# histogram_i = np.histogramdd( -# (signal_i, observation_i), bins=bins, range=[value_range, value_range] -# ) -# # Adding a constant for numerical stability -# histogram[0] = histogram[0] + histogram_i[0] + 1e-30 - -# for i in range(bins): -# # Exclude empty rows from normalization -# if np.sum(histogram[0, i, :]) > 1e-20: -# # Normalize each non-empty row -# histogram[0, i, :] /= np.sum(histogram[0, i, :]) - -# for i in range(bins): -# # The lower boundaries of each bin in y are stored in dimension 1 -# histogram[1, :, i] = histogram_i[1][:-1] -# # The upper boundaries of each bin in y are stored in dimension 2 -# histogram[2, :, i] = histogram_i[1][1:] -# # The accordent numbers for x are just transposed. - -# return histogram - -# def likelihood(self, observed, signal): -# """Calculate the likelihood using a histogram based noise model. - -# For every pixel in a tensor, calculate (x_i|s_i). To ensure differentiability -# in the direction of s_i, we linearly interpolate in this direction. - -# Parameters -# ---------- -# observed: torch.Tensor -# tensor holding your observed intesities x_i. - -# signal: torch.Tensor -# tensor holding hypotheses for the clean signal at every pixel s_i^k. - -# Returns -# ------- -# Torch.tensor containing the observation likelihoods according to the -# noise model. -# """ -# observed_float = self.get_index_observed_float(observed) -# observed_long = observed_float.floor().long() -# signal_float = self.get_index_signal_float(signal) -# signal_long = signal_float.floor().long() -# fact = signal_float - signal_long.float() - -# # Finally we are looking ud the values and interpolate -# return self.fullHist[signal_long, observed_long] * (1.0 - fact) + self.fullHist[ -# torch.clamp((signal_long + 1).long(), 0, self.bins.long()), observed_long -# ] * (fact) - -# def get_index_observed_float(self, x: float): -# """_summary_. - -# Parameters -# ---------- -# x : _type_ -# _description_ - -# Returns -# ------- -# _type_ -# _description_ -# """ -# return torch.clamp( -# self.bins * (x - self.minv) / (self.maxv - self.minv), -# min=0.0, -# max=self.bins - 1 - 1e-3, -# ) - -# def get_index_signal_float(self, x): -# """_summary_. - -# Parameters -# ---------- -# x : _type_ -# _description_ - -# Returns -# ------- -# _type_ -# _description_ -# """ -# return torch.clamp( -# self.bins * (x - self.minv) / (self.maxv - self.minv), -# min=0.0, -# max=self.bins - 1 - 1e-3, -# ) - - -# # TODO refactor this into Pydantic model -# class GaussianMixtureNoiseModel(NoiseModel): -# """Describes a noise model parameterized as a mixture of gaussians. - -# If you would like to initialize a new object from scratch, then set `params` = None -# and specify the other parameters as keyword arguments. If you are instead loading -# a model, use only `params`. - -# Parameters -# ---------- -# **kwargs: keyworded, variable-length argument dictionary. -# Arguments include: -# min_signal : float -# Minimum signal intensity expected in the image. 
-# max_signal : float -# Maximum signal intensity expected in the image. -# weight : array -# A [3*n_gaussian, n_coeff] sized array containing the values of the weights -# describing the noise model. -# Each gaussian contributes three parameters (mean, standard deviation and weight), -# hence the number of rows in `weight` are 3*n_gaussian. -# If `weight = None`, the weight array is initialized using the `min_signal` and -# `max_signal` parameters. -# n_gaussian: int -# Number of gaussians. -# n_coeff: int -# Number of coefficients to describe the functional relationship between gaussian -# parameters and the signal. -# 2 implies a linear relationship, 3 implies a quadratic relationship and so on. -# device: device -# GPU device. -# min_sigma: int -# All values of sigma (`standard deviation`) below min_sigma are clamped to become -# equal to min_sigma. -# params: dictionary -# Use `params` if one wishes to load a model with trained weights. -# While initializing a new object of the class `GaussianMixtureNoiseModel` from -# scratch, set this to `None`. -# """ - -# def __init__(self, **kwargs): -# if kwargs.get("params") is None: -# weight = kwargs.get("weight") -# n_gaussian = kwargs.get("n_gaussian") -# n_coeff = kwargs.get("n_coeff") -# min_signal = kwargs.get("min_signal") -# max_signal = kwargs.get("max_signal") -# self.device = kwargs.get("device") -# self.path = kwargs.get("path") -# self.min_sigma = kwargs.get("min_sigma") -# if weight is None: -# weight = np.random.randn(n_gaussian * 3, n_coeff) -# weight[n_gaussian : 2 * n_gaussian, 1] = np.log(max_signal - min_signal) -# weight = ( -# torch.from_numpy(weight.astype(np.float32)).float().to(self.device) -# ) -# weight.requires_grad = True -# self.n_gaussian = weight.shape[0] // 3 -# self.n_coeff = weight.shape[1] -# self.weight = weight -# self.min_signal = torch.Tensor([min_signal]).to(self.device) -# self.max_signal = torch.Tensor([max_signal]).to(self.device) -# self.tol = torch.Tensor([1e-10]).to(self.device) -# else: -# params = kwargs.get("params") -# self.device = kwargs.get("device") - -# self.min_signal = torch.Tensor(params["min_signal"]).to(self.device) -# self.max_signal = torch.Tensor(params["max_signal"]).to(self.device) - -# self.weight = torch.Tensor(params["trained_weight"]).to(self.device) -# self.min_sigma = np.ndarray.item(params["min_sigma"]) -# self.n_gaussian = self.weight.shape[0] // 3 -# self.n_coeff = self.weight.shape[1] -# self.tol = torch.Tensor([1e-10]).to(self.device) -# self.min_signal = torch.Tensor([self.min_signal]).to(self.device) -# self.max_signal = torch.Tensor([self.max_signal]).to(self.device) - -# def fast_shuffle(self, series, num): -# """. - -# Parameters -# ---------- -# series : _type_ -# _description_ -# num : _type_ -# _description_ - -# Returns -# ------- -# _type_ -# _description_ -# """ -# length = series.shape[0] -# for _i in range(num): -# series = series[np.random.permutation(length), :] -# return series - -# def polynomial_regressor(self, weightParams, signals): -# """Combines weight_parameters and signals to perform regression. 
- -# Parameters -# ---------- -# weightParams : torch.cuda.FloatTensor -# Corresponds to specific rows of the `self.weight' - -# signals : torch.cuda.FloatTensor -# Signals - -# Returns -# ------- -# value : torch.cuda.FloatTensor -# Corresponds to either of mean, standard deviation or weight, evaluated at -# `signals` -# """ -# value = 0 -# for i in range(weightParams.shape[0]): -# value += weightParams[i] * ( -# ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i -# ) -# return value - -# def normal_density(self, x, m_=0.0, std_=None): -# """Evaluates the normal probability density. - -# Parameters -# ---------- -# x: torch.cuda.FloatTensor -# Observations -# m_: torch.cuda.FloatTensor -# Mean -# std_: torch.cuda.FloatTensor -# Standard-deviation - -# Returns -# ------- -# tmp: torch.cuda.FloatTensor -# Normal probability density of `x` given `m_` and `std_` - -# """ -# tmp = -((x - m_) ** 2) -# tmp = tmp / (2.0 * std_ * std_) -# tmp = torch.exp(tmp) -# tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_) -# return tmp - -# def likelihood(self, observations, signals): -# """Evaluates the likelihood of observations. - -# Given the signals and the corresponding gaussian parameters evaluates the -# likelihood of observations. - -# Parameters -# ---------- -# observations : torch.cuda.FloatTensor -# Noisy observations -# signals : torch.cuda.FloatTensor -# Underlying signals - -# Returns -# ------- -# value :p + self.tol -# Likelihood of observations given the signals and the GMM noise model - -# """ -# gaussianParameters = self.getGaussianParameters(signals) -# p = 0 -# for gaussian in range(self.n_gaussian): -# p += ( -# self.normalDens( -# observations, -# gaussianParameters[gaussian], -# gaussianParameters[self.n_gaussian + gaussian], -# ) -# * gaussianParameters[2 * self.n_gaussian + gaussian] -# ) -# return p + self.tol - -# def get_gaussian_parameters(self, signals): -# """Returns the noise model for given signals. - -# Parameters -# ---------- -# signals : torch.cuda.FloatTensor -# Underlying signals - -# Returns -# ------- -# noiseModel: list of torch.cuda.FloatTensor -# Contains a list of `mu`, `sigma` and `alpha` for the `signals` - -# """ -# noiseModel = [] -# mu = [] -# sigma = [] -# alpha = [] -# kernels = self.weight.shape[0] // 3 -# for num in range(kernels): -# mu.append(self.polynomialRegressor(self.weight[num, :], signals)) - -# sigmaTemp = self.polynomialRegressor( -# torch.exp(self.weight[kernels + num, :]), signals -# ) -# sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma) -# sigma.append(torch.sqrt(sigmaTemp)) -# alpha.append( -# torch.exp( -# self.polynomialRegressor(self.weight[2 * kernels + num, :], signals) -# + self.tol -# ) -# ) - -# sum_alpha = 0 -# for al in range(kernels): -# sum_alpha = alpha[al] + sum_alpha -# for ker in range(kernels): -# alpha[ker] = alpha[ker] / sum_alpha - -# sum_means = 0 -# for ker in range(kernels): -# sum_means = alpha[ker] * mu[ker] + sum_means - -# for ker in range(kernels): -# mu[ker] = mu[ker] - sum_means + signals - -# for i in range(kernels): -# noiseModel.append(mu[i]) -# for j in range(kernels): -# noiseModel.append(sigma[j]) -# for k in range(kernels): -# noiseModel.append(alpha[k]) - -# return noiseModel - -# def get_signal_observation_pairs(self, signal, observation, lowerClip, upperClip): -# """Returns the Signal-Observation pixel intensities as a two-column array. 
- -# Parameters -# ---------- -# signal : numpy array -# Clean Signal Data -# observation: numpy array -# Noisy observation Data -# lowerClip: float -# Lower percentile bound for clipping. -# upperClip: float -# Upper percentile bound for clipping. - -# Returns -# ------- -# noiseModel: list of torch floats -# Contains a list of `mu`, `sigma` and `alpha` for the `signals` - -# """ -# lb = np.percentile(signal, lowerClip) -# ub = np.percentile(signal, upperClip) -# stepsize = observation[0].size -# n_observations = observation.shape[0] -# n_signals = signal.shape[0] -# sig_obs_pairs = np.zeros((n_observations * stepsize, 2)) - -# for i in range(n_observations): -# j = i // (n_observations // n_signals) -# sig_obs_pairs[stepsize * i : stepsize * (i + 1), 0] = signal[j].ravel() -# sig_obs_pairs[stepsize * i : stepsize * (i + 1), 1] = observation[i].ravel() -# sig_obs_pairs = sig_obs_pairs[ -# (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub) -# ] -# return self.fast_shuffle(sig_obs_pairs, 2) - -# def train( -# self, -# signal, -# observation, -# learning_rate=1e-1, -# batchSize=250000, -# n_epochs=2000, -# name="GMMNoiseModel.npz", -# lowerClip=0, -# upperClip=100, -# ): -# """Training to learn the noise model from signal - observation pairs. - -# Parameters -# ---------- -# signal: numpy array -# Clean Signal Data -# observation: numpy array -# Noisy Observation Data -# learning_rate: float -# Learning rate. Default = 1e-1. -# batchSize: int -# Nini-batch size. Default = 250000. -# n_epochs: int -# Number of epochs. Default = 2000. -# name: string -# Model name. Default is `GMMNoiseModel`. This model after being trained is -# saved at the location `path`. - -# lowerClip : int -# Lower percentile for clipping. Default is 0. -# upperClip : int -# Upper percentile for clipping. Default is 100. 
- - -# """ -# sig_obs_pairs = self.getSignalObservationPairs( -# signal, observation, lowerClip, upperClip -# ) -# counter = 0 -# optimizer = torch.optim.Adam([self.weight], lr=learning_rate) -# for t in range(n_epochs): -# jointLoss = 0 -# if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]: -# counter = 0 -# sig_obs_pairs = self.fast_shuffle(sig_obs_pairs, 1) - -# batch_vectors = sig_obs_pairs[ -# counter * batchSize : (counter + 1) * batchSize, : -# ] -# observations = batch_vectors[:, 1].astype(np.float32) -# signals = batch_vectors[:, 0].astype(np.float32) -# observations = ( -# torch.from_numpy(observations.astype(np.float32)) -# .float() -# .to(self.device) -# ) -# signals = torch.from_numpy(signals).float().to(self.device) -# p = self.likelihood(observations, signals) -# loss = torch.mean(-torch.log(p)) -# jointLoss = jointLoss + loss - -# if t % 100 == 0: -# print(t, jointLoss.item()) - -# if t % (int(n_epochs * 0.5)) == 0: -# trained_weight = self.weight.cpu().detach().numpy() -# min_signal = self.min_signal.cpu().detach().numpy() -# max_signal = self.max_signal.cpu().detach().numpy() -# np.savez( -# self.path + name, -# trained_weight=trained_weight, -# min_signal=min_signal, -# max_signal=max_signal, -# min_sigma=self.min_sigma, -# ) - -# optimizer.zero_grad() -# jointLoss.backward() -# optimizer.step() -# counter += 1 - -# logger.info(f"The trained parameters {name} is saved at location: " + self.path) From 10af1b4bf346ca045f6bbfcaeba50d473e2d0c70 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 29 May 2024 20:17:10 +0200 Subject: [PATCH 13/13] (chore): add more doc fixes --- .../architectures/architecture_model.py | 5 ++ .../config/architectures/custom_model.py | 9 ++- .../config/architectures/register_model.py | 4 +- .../config/architectures/unet_model.py | 2 + .../config/architectures/vae_model.py | 2 + src/careamics/config/callback_model.py | 18 +---- src/careamics/config/configuration_example.py | 4 +- src/careamics/config/optimizer_models.py | 8 +-- src/careamics/config/support/__init__.py | 2 - .../config/support/supported_activations.py | 2 + .../config/support/supported_algorithms.py | 4 +- .../config/support/supported_architectures.py | 2 + .../config/support/supported_data.py | 2 + .../supported_extraction_strategies.py | 25 ------- .../config/support/supported_loggers.py | 2 + .../config/support/supported_losses.py | 2 + .../config/support/supported_optimizers.py | 2 + .../support/supported_pixel_manipulations.py | 6 +- .../config/support/supported_struct_axis.py | 2 + .../config/support/supported_transforms.py | 17 +---- src/careamics/config/tile_information.py | 2 + .../dataset/dataset_utils/dataset_utils.py | 8 +-- .../dataset/dataset_utils/file_utils.py | 7 +- .../dataset/dataset_utils/read_tiff.py | 4 +- .../dataset/dataset_utils/read_utils.py | 2 + .../dataset/dataset_utils/read_zarr.py | 18 +++-- src/careamics/dataset/iterable_dataset.py | 2 + src/careamics/dataset/patching/patching.py | 71 +++++++++++++++---- .../dataset/patching/random_patching.py | 10 ++- .../dataset/patching/sequential_patching.py | 22 +++--- .../dataset/patching/tiled_patching.py | 2 + .../patching/validate_patch_dimension.py | 2 + src/careamics/dataset/zarr_dataset.py | 2 + src/careamics/lightning_module.py | 10 ++- src/careamics/lightning_prediction_loop.py | 13 ++-- src/careamics/losses/loss_factory.py | 2 +- src/careamics/losses/losses.py | 15 ++-- 37 files changed, 191 insertions(+), 121 deletions(-) delete mode 100644 
src/careamics/config/support/supported_extraction_strategies.py diff --git a/src/careamics/config/architectures/architecture_model.py b/src/careamics/config/architectures/architecture_model.py index 6ea3f122f..7f836e6b5 100644 --- a/src/careamics/config/architectures/architecture_model.py +++ b/src/careamics/config/architectures/architecture_model.py @@ -18,6 +18,11 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]: """ Dump the model as a dictionary, ignoring the architecture keyword. + Parameters + ---------- + **kwargs : Any + Additional keyword arguments from Pydantic BaseModel model_dump method. + Returns ------- dict[str, Any] diff --git a/src/careamics/config/architectures/custom_model.py b/src/careamics/config/architectures/custom_model.py index 3290344a1..462108bf3 100644 --- a/src/careamics/config/architectures/custom_model.py +++ b/src/careamics/config/architectures/custom_model.py @@ -1,3 +1,5 @@ +"""Custom architecture Pydantic model.""" + from __future__ import annotations from pprint import pformat @@ -84,6 +86,11 @@ def custom_model_is_known(cls, value: str) -> str: value : str Name of the custom model as registered using the `@register_model` decorator. + + Returns + ------- + str + The custom model name. """ # delegate error to get_custom_model model = get_custom_model(value) @@ -134,7 +141,7 @@ def model_dump(self, **kwargs: Any) -> Dict[str, Any]: Parameters ---------- - kwargs : Any + **kwargs : Any Additional keyword arguments from Pydantic BaseModel model_dump method. Returns diff --git a/src/careamics/config/architectures/register_model.py b/src/careamics/config/architectures/register_model.py index f35b0b88c..89d6896d8 100644 --- a/src/careamics/config/architectures/register_model.py +++ b/src/careamics/config/architectures/register_model.py @@ -1,3 +1,5 @@ +"""Custom model registration utilities.""" + from typing import Callable from torch.nn import Module @@ -53,7 +55,7 @@ def add_custom_model(model: Module) -> Module: Parameters ---------- model : Module - Module class to register + Module class to register. Returns ------- diff --git a/src/careamics/config/architectures/unet_model.py b/src/careamics/config/architectures/unet_model.py index 0af639e98..272becc72 100644 --- a/src/careamics/config/architectures/unet_model.py +++ b/src/careamics/config/architectures/unet_model.py @@ -1,3 +1,5 @@ +"""UNet Pydantic model.""" + from __future__ import annotations from typing import Literal diff --git a/src/careamics/config/architectures/vae_model.py b/src/careamics/config/architectures/vae_model.py index 03c7eb60b..0865c9e50 100644 --- a/src/careamics/config/architectures/vae_model.py +++ b/src/careamics/config/architectures/vae_model.py @@ -1,3 +1,5 @@ +"""VAE Pydantic model.""" + from typing import Literal from pydantic import ( diff --git a/src/careamics/config/callback_model.py b/src/careamics/config/callback_model.py index 666e82533..fcb1920ec 100644 --- a/src/careamics/config/callback_model.py +++ b/src/careamics/config/callback_model.py @@ -1,4 +1,4 @@ -"""Checkpoint saving configuration.""" +"""Callback Pydantic models.""" from __future__ import annotations @@ -13,13 +13,7 @@ class CheckpointModel(BaseModel): - """_summary_. - - Parameters - ---------- - BaseModel : _type_ - _description_ - """ + """Checkpoint saving callback Pydantic model.""" model_config = ConfigDict( validate_assignment=True, @@ -46,13 +40,7 @@ class CheckpointModel(BaseModel): class EarlyStoppingModel(BaseModel): - """_summary_. 
-
-    Parameters
-    ----------
-    BaseModel : _type_
-        _description_
-    """
+    """Early stopping callback Pydantic model."""
 
     model_config = ConfigDict(
         validate_assignment=True,
diff --git a/src/careamics/config/configuration_example.py b/src/careamics/config/configuration_example.py
index f601409e4..37fae9c72 100644
--- a/src/careamics/config/configuration_example.py
+++ b/src/careamics/config/configuration_example.py
@@ -1,3 +1,5 @@
+"""Example of configurations."""
+
 from .algorithm_model import AlgorithmConfig
 from .architectures import UNetModel
 from .configuration_model import Configuration
@@ -19,7 +21,7 @@
 
 
 def full_configuration_example() -> Configuration:
-    """Returns a dictionnary representing a full configuration example.
+    """Return a dictionary representing a full configuration example.
 
     Returns
     -------
diff --git a/src/careamics/config/optimizer_models.py b/src/careamics/config/optimizer_models.py
index 43acbd6e9..1a709511c 100644
--- a/src/careamics/config/optimizer_models.py
+++ b/src/careamics/config/optimizer_models.py
@@ -1,3 +1,5 @@
+"""Optimizers and schedulers Pydantic models."""
+
 from __future__ import annotations
 
 from typing import Dict, Literal
@@ -19,8 +21,7 @@
 
 
 class OptimizerModel(BaseModel):
-    """
-    Torch optimizer.
+    """Torch optimizer Pydantic model.
 
     Only parameters supported by the corresponding torch optimizer will be taken
     into account. For more details, check:
@@ -115,8 +116,7 @@ def sgd_lr_parameter(self) -> Self:
 
 
 class LrSchedulerModel(BaseModel):
-    """
-    Torch learning rate scheduler.
+    """Torch learning rate scheduler Pydantic model.
 
     Only parameters supported by the corresponding torch lr scheduler will be taken
     into account. For more details, check:
diff --git a/src/careamics/config/support/__init__.py b/src/careamics/config/support/__init__.py
index abb8284cc..db3ab620e 100644
--- a/src/careamics/config/support/__init__.py
+++ b/src/careamics/config/support/__init__.py
@@ -14,7 +14,6 @@
     "SupportedPixelManipulation",
     "SupportedTransform",
     "SupportedData",
-    "SupportedExtractionStrategy",
     "SupportedStructAxis",
     "SupportedLogger",
 ]
@@ -24,7 +23,6 @@
 from .supported_algorithms import SupportedAlgorithm
 from .supported_architectures import SupportedArchitecture
 from .supported_data import SupportedData
-from .supported_extraction_strategies import SupportedExtractionStrategy
 from .supported_loggers import SupportedLogger
 from .supported_losses import SupportedLoss
 from .supported_optimizers import SupportedOptimizer, SupportedScheduler
diff --git a/src/careamics/config/support/supported_activations.py b/src/careamics/config/support/supported_activations.py
index d7c84ae3c..56970a955 100644
--- a/src/careamics/config/support/supported_activations.py
+++ b/src/careamics/config/support/supported_activations.py
@@ -1,3 +1,5 @@
+"""Activations supported by CAREamics."""
+
 from careamics.utils import BaseEnum
 
 
diff --git a/src/careamics/config/support/supported_algorithms.py b/src/careamics/config/support/supported_algorithms.py
index a44e179b3..678d5b560 100644
--- a/src/careamics/config/support/supported_algorithms.py
+++ b/src/careamics/config/support/supported_algorithms.py
@@ -1,3 +1,5 @@
+"""Algorithms supported by CAREamics."""
+
 from __future__ import annotations
 
 from careamics.utils import BaseEnum
@@ -10,9 +12,9 @@
     """
 
     N2V = "n2v"
-    CUSTOM = "custom"
     CARE = "care"
     N2N = "n2n"
+    CUSTOM = "custom"
    # PN2V = "pn2v"
    # HDN = "hdn"
    # SEG = "segmentation"
diff --git a/src/careamics/config/support/supported_architectures.py
b/src/careamics/config/support/supported_architectures.py index 8246cf9cd..5a2c4eafe 100644 --- a/src/careamics/config/support/supported_architectures.py +++ b/src/careamics/config/support/supported_architectures.py @@ -1,3 +1,5 @@ +"""Architectures supported by CAREamics.""" + from careamics.utils import BaseEnum diff --git a/src/careamics/config/support/supported_data.py b/src/careamics/config/support/supported_data.py index cb9fd7004..73f32975a 100644 --- a/src/careamics/config/support/supported_data.py +++ b/src/careamics/config/support/supported_data.py @@ -1,3 +1,5 @@ +"""Data supported by CAREamics.""" + from __future__ import annotations from typing import Union diff --git a/src/careamics/config/support/supported_extraction_strategies.py b/src/careamics/config/support/supported_extraction_strategies.py deleted file mode 100644 index 759f0156f..000000000 --- a/src/careamics/config/support/supported_extraction_strategies.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Extraction strategy module. - -This module defines the various extraction strategies available in CAREamics. -""" - -from careamics.utils import BaseEnum - - -class SupportedExtractionStrategy(str, BaseEnum): - """ - Available extraction strategies. - - Currently supported: - - random: random extraction. - # TODO - - sequential: grid extraction, can miss edge values. - - tiled: tiled extraction, covers the whole image. - """ - - RANDOM = "random" - RANDOM_ZARR = "random_zarr" - SEQUENTIAL = "sequential" - TILED = "tiled" - NONE = "none" diff --git a/src/careamics/config/support/supported_loggers.py b/src/careamics/config/support/supported_loggers.py index b4d4842f1..e6169932c 100644 --- a/src/careamics/config/support/supported_loggers.py +++ b/src/careamics/config/support/supported_loggers.py @@ -1,3 +1,5 @@ +"""Logger supported by CAREamics.""" + from careamics.utils import BaseEnum diff --git a/src/careamics/config/support/supported_losses.py b/src/careamics/config/support/supported_losses.py index 4235034ec..d730439d6 100644 --- a/src/careamics/config/support/supported_losses.py +++ b/src/careamics/config/support/supported_losses.py @@ -1,3 +1,5 @@ +"""Losses supported by CAREamics.""" + from careamics.utils import BaseEnum diff --git a/src/careamics/config/support/supported_optimizers.py b/src/careamics/config/support/supported_optimizers.py index 40adb00ad..85e922cc8 100644 --- a/src/careamics/config/support/supported_optimizers.py +++ b/src/careamics/config/support/supported_optimizers.py @@ -1,3 +1,5 @@ +"""Optimizers and schedulers supported by CAREamics.""" + from careamics.utils import BaseEnum diff --git a/src/careamics/config/support/supported_pixel_manipulations.py b/src/careamics/config/support/supported_pixel_manipulations.py index 84db6d05a..65f8d48fb 100644 --- a/src/careamics/config/support/supported_pixel_manipulations.py +++ b/src/careamics/config/support/supported_pixel_manipulations.py @@ -1,15 +1,15 @@ +"""Pixel manipulation methods supported by CAREamics.""" + from careamics.utils import BaseEnum class SupportedPixelManipulation(str, BaseEnum): - """_summary_. + """Supported Noise2Void pixel manipulations. - Uniform: Replace masked pixel value by a (uniformly) randomly selected neighbor pixel value. - Median: Replace masked pixel value by the mean of the neighborhood. 
""" - # TODO docs - UNIFORM = "uniform" MEDIAN = "median" diff --git a/src/careamics/config/support/supported_struct_axis.py b/src/careamics/config/support/supported_struct_axis.py index 4d82307d4..9278ffee4 100644 --- a/src/careamics/config/support/supported_struct_axis.py +++ b/src/careamics/config/support/supported_struct_axis.py @@ -1,3 +1,5 @@ +"""StructN2V axes supported by CAREamics.""" + from careamics.utils import BaseEnum diff --git a/src/careamics/config/support/supported_transforms.py b/src/careamics/config/support/supported_transforms.py index ed61a2fc9..04f9ba7f5 100644 --- a/src/careamics/config/support/supported_transforms.py +++ b/src/careamics/config/support/supported_transforms.py @@ -1,23 +1,12 @@ +"""Transforms supported by CAREamics.""" + from careamics.utils import BaseEnum class SupportedTransform(str, BaseEnum): - """Transforms officially supported by CAREamics. - - - Flip: from Albumentations, randomly flip the input horizontally, vertically or - both, parameter `p` can be used to set the probability to apply the transform. - - XYRandomRotate90: #TODO - - Normalize # TODO add details, in particular about the parameters - - ManipulateN2V # TODO add details, in particular about the parameters - - XYFlip - - Note that while any Albumentations (see https://albumentations.ai/) transform can be - used in CAREamics, no check are implemented to verify the compatibility of any other - transforms than the ones officially supported. - """ + """Transforms officially supported by CAREamics.""" XY_FLIP = "XYFlip" XY_RANDOM_ROTATE90 = "XYRandomRotate90" NORMALIZE = "Normalize" N2V_MANIPULATE = "N2VManipulate" - # CUSTOM = "Custom" diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py index e018e0f16..3fc6a3468 100644 --- a/src/careamics/config/tile_information.py +++ b/src/careamics/config/tile_information.py @@ -1,3 +1,5 @@ +"""Pydantic model representing the metadata of a prediction tile.""" + from __future__ import annotations from typing import Optional, Tuple diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py index 4ce652752..ebaed0d46 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -1,4 +1,4 @@ -"""Convenience methods for datasets.""" +"""Dataset utilities.""" from typing import List, Tuple @@ -17,12 +17,12 @@ def _get_shape_order( Parameters ---------- - shape_in : Tuple + shape_in : Tuple[int, ...] Input shape. - ref_axes : str - Reference axes. axes_in : str Input axes. + ref_axes : str + Reference axes. Returns ------- diff --git a/src/careamics/dataset/dataset_utils/file_utils.py b/src/careamics/dataset/dataset_utils/file_utils.py index 67b65f409..949b588e3 100644 --- a/src/careamics/dataset/dataset_utils/file_utils.py +++ b/src/careamics/dataset/dataset_utils/file_utils.py @@ -1,3 +1,5 @@ +"""File utilities.""" + from fnmatch import fnmatch from pathlib import Path from typing import List, Union @@ -11,8 +13,7 @@ def get_files_size(files: List[Path]) -> float: - """ - Get files size in MB. + """Get files size in MB. Parameters ---------- @@ -32,7 +33,7 @@ def list_files( data_type: Union[str, SupportedData], extension_filter: str = "", ) -> List[Path]: - """Creates a recursive list of files in `data_path`. + """Create a recursive list of files in `data_path`. 
    If `data_path` is a file, its name is validated against the `data_type` using
    `fnmatch`, and the method returns `data_path` itself.
diff --git a/src/careamics/dataset/dataset_utils/read_tiff.py b/src/careamics/dataset/dataset_utils/read_tiff.py
index 7b4dd8e02..b0015600f 100644
--- a/src/careamics/dataset/dataset_utils/read_tiff.py
+++ b/src/careamics/dataset/dataset_utils/read_tiff.py
@@ -1,3 +1,5 @@
+"""Functions to read tiff images."""
+
 import logging
 from fnmatch import fnmatch
 from pathlib import Path
@@ -19,8 +21,6 @@ def read_tiff(file_path: Path, *args: list, **kwargs: dict) -> np.ndarray:
     ----------
     file_path : Path
         Path to a file.
-    axes : str
-        Description of axes in format STCZYX.
 
     Returns
     -------
diff --git a/src/careamics/dataset/dataset_utils/read_utils.py b/src/careamics/dataset/dataset_utils/read_utils.py
index 558626a8e..753732538 100644
--- a/src/careamics/dataset/dataset_utils/read_utils.py
+++ b/src/careamics/dataset/dataset_utils/read_utils.py
@@ -1,3 +1,5 @@
+"""Read function utilities."""
+
 from typing import Callable, Union
 
 from careamics.config.support import SupportedData
diff --git a/src/careamics/dataset/dataset_utils/read_zarr.py b/src/careamics/dataset/dataset_utils/read_zarr.py
index 5878a1cf0..d5153ce7a 100644
--- a/src/careamics/dataset/dataset_utils/read_zarr.py
+++ b/src/careamics/dataset/dataset_utils/read_zarr.py
@@ -1,3 +1,5 @@
+"""Function to read zarr images."""
+
 from typing import Union
 
 from zarr import Group, core, hierarchy, storage
@@ -6,26 +8,28 @@
 def read_zarr(
     zarr_source: Group, axes: str
 ) -> Union[core.Array, storage.DirectoryStore, hierarchy.Group]:
-    """Reads a file and returns a pointer.
+    """Read a file and return a pointer.
 
     Parameters
     ----------
-    file_path : Path
-        pathlib.Path object containing a path to a file
+    zarr_source : Group
+        Zarr storage.
+    axes : str
+        Axes of the data.
 
     Returns
     -------
     np.ndarray
-        Pointer to zarr storage
+        Pointer to zarr storage.
 
     Raises
     ------
     ValueError, OSError
-        if a file is not a valid tiff or damaged
+        if a file is not a valid tiff or damaged.
     ValueError
-        if data dimensions are not 2, 3 or 4
+        if data dimensions are not 2, 3 or 4.
     ValueError
-        if axes parameter from config is not consistent with data dimensions
+        if axes parameter from config is not consistent with data dimensions.
     """
     if isinstance(zarr_source, hierarchy.Group):
         array = zarr_source[0]
diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py
index b5e48dc01..0e6835f5f 100644
--- a/src/careamics/dataset/iterable_dataset.py
+++ b/src/careamics/dataset/iterable_dataset.py
@@ -1,3 +1,5 @@
+"""Iterable dataset used to load data file by file."""
+
 from __future__ import annotations
 
 import copy
diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py
index 49e3d3f75..d445c0ec3 100644
--- a/src/careamics/dataset/patching/patching.py
+++ b/src/careamics/dataset/patching/patching.py
@@ -1,8 +1,4 @@
-"""
-Tiling submodule.
-
-These functions are used to tile images into patches or tiles.
-"""
+"""Patching functions."""
 
 from pathlib import Path
 from typing import Callable, List, Tuple, Union
@@ -21,12 +17,25 @@ def prepare_patches_supervised(
     train_files: List[Path],
     target_files: List[Path],
     axes: str,
-    patch_size: Union[List[int], Tuple[int]],
+    patch_size: Union[List[int], Tuple[int, ...]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, np.ndarray, float, float]:
     """
     Iterate over data source and create an array of patches and corresponding targets.
+
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    target_files : List[Path]
+        List of paths to target data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
+
     Returns
     -------
     np.ndarray
@@ -95,13 +104,25 @@ def prepare_patches_unsupervised(
     patch_size: Union[List[int], Tuple[int]],
     read_source_func: Callable,
 ) -> Tuple[np.ndarray, None, float, float]:
-    """
-    Iterate over data source and create an array of patches.
+    """Iterate over data source and create an array of patches.
+
+    This method returns the mean and standard deviation of the image.
+
+    Parameters
+    ----------
+    train_files : List[Path]
+        List of paths to training data.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+    read_source_func : Callable
+        Function to read the data.
 
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, None, float, float]
+        Source and target patches, mean and standard deviation.
     """
     means, stds, num_samples = 0, 0, 0
     all_patches = []
@@ -150,10 +171,21 @@
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    data_target : np.ndarray
+        Target data array.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, np.ndarray, float, float]
+        Source and target patches, mean and standard deviation.
     """
     # compute statistics
     mean = data.mean()
@@ -195,10 +227,19 @@
     Patches returned are of shape SC(Z)YX, where S is now the patches dimension.
 
+    Parameters
+    ----------
+    data : np.ndarray
+        Input data array.
+    axes : str
+        Axes of the data.
+    patch_size : Union[List[int], Tuple[int]]
+        Size of the patches.
+
     Returns
     -------
-    np.ndarray
-        Array of patches.
+    Tuple[np.ndarray, None, float, float]
+        Source patches, mean and standard deviation.
     """
     # calculate mean and std
     mean = data.mean()
@@ -210,4 +251,4 @@
     # generate patches, return a generator
     patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size)
 
-    return patches, _, mean, std # TODO inelegant, replace by dataclass?
+    return patches, _, mean, std  # TODO inelegant, replace by dataclass?
diff --git a/src/careamics/dataset/patching/random_patching.py b/src/careamics/dataset/patching/random_patching.py
index c06c5bbd6..e71c54a6f 100644
--- a/src/careamics/dataset/patching/random_patching.py
+++ b/src/careamics/dataset/patching/random_patching.py
@@ -1,3 +1,5 @@
+"""Random patching utilities."""
+
 from typing import Generator, List, Optional, Tuple, Union
 
 import numpy as np
@@ -30,6 +32,8 @@
         Input image array.
     patch_size : Tuple[int]
         Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Yields
     ------
@@ -120,10 +124,12 @@
     ----------
     arr : np.ndarray
         Input image array.
-    patch_size : Tuple[int]
+    patch_size : Union[List[int], Tuple[int, ...]]
         Patch sizes in each dimension.
-    chunk_size : Tuple[int]
+    chunk_size : Union[List[int], Tuple[int, ...]]
         Chunk sizes to load from the input array.
+    chunk_limit : Optional[int], optional
+        Number of chunks to load, by default None.
     Yields
     ------
diff --git a/src/careamics/dataset/patching/sequential_patching.py b/src/careamics/dataset/patching/sequential_patching.py
index 6b09cd08d..1d1df044b 100644
--- a/src/careamics/dataset/patching/sequential_patching.py
+++ b/src/careamics/dataset/patching/sequential_patching.py
@@ -1,3 +1,5 @@
+"""Sequential patching functions."""
+
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
@@ -14,14 +16,14 @@
 
     Parameters
     ----------
-    arr : Tuple[int, ...]
+    arr_shape : Tuple[int, ...]
         Shape of the input array.
-    patch_sizes : Tuple[int]
+    patch_sizes : Union[List[int], Tuple[int, ...]]
         Shape of the patches.
 
     Returns
     -------
-    Tuple[int]
+    Tuple[int, ...]
         Number of patches in each dimension.
     """
     if len(arr_shape) != len(patch_sizes):
@@ -55,14 +57,14 @@
 
     Parameters
     ----------
-    arr : Tuple[int, ...]
+    arr_shape : Tuple[int, ...]
         Input array shape.
-    patch_sizes : Tuple[int]
+    patch_sizes : Union[List[int], Tuple[int, ...]]
         Size of the patches.
 
     Returns
     -------
-    Tuple[int]
+    Tuple[int, ...]
         Overlap between patches in each dimension.
     """
     n_patches = _compute_number_of_patches(arr_shape, patch_sizes)
@@ -123,6 +125,8 @@
         Steps between views.
     output_shape : Tuple[int]
         Shape of the output array.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Returns
     -------
@@ -161,11 +165,13 @@
         Input image array.
     patch_size : Tuple[int]
         Patch sizes in each dimension.
+    target : Optional[np.ndarray], optional
+        Target array, by default None.
 
     Returns
     -------
-    Generator[Tuple[np.ndarray, ...], None, None]
-        Generator of patches.
+    Tuple[np.ndarray, Optional[np.ndarray]]
+        Patches.
     """
     is_3d_patch = len(patch_size) == 3
 
diff --git a/src/careamics/dataset/patching/tiled_patching.py b/src/careamics/dataset/patching/tiled_patching.py
index ddd97c1b0..890c7f616 100644
--- a/src/careamics/dataset/patching/tiled_patching.py
+++ b/src/careamics/dataset/patching/tiled_patching.py
@@ -1,3 +1,5 @@
+"""Tiled patching utilities."""
+
 import itertools
 from typing import Generator, List, Tuple, Union
 
diff --git a/src/careamics/dataset/patching/validate_patch_dimension.py b/src/careamics/dataset/patching/validate_patch_dimension.py
index ffd013b01..8174493a4 100644
--- a/src/careamics/dataset/patching/validate_patch_dimension.py
+++ b/src/careamics/dataset/patching/validate_patch_dimension.py
@@ -1,3 +1,5 @@
+"""Patch validation functions."""
+
 from typing import List, Tuple, Union
 
 import numpy as np
diff --git a/src/careamics/dataset/zarr_dataset.py b/src/careamics/dataset/zarr_dataset.py
index ee54fdd26..7ec334aca 100644
--- a/src/careamics/dataset/zarr_dataset.py
+++ b/src/careamics/dataset/zarr_dataset.py
@@ -1,3 +1,5 @@
+"""Zarr dataset."""
+
 # from itertools import islice
 # from typing import Callable, Dict, List, Optional, Tuple, Union
 
diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py
index 18a07bf16..e16ea39e4 100644
--- a/src/careamics/lightning_module.py
+++ b/src/careamics/lightning_module.py
@@ -1,3 +1,5 @@
+"""CAREamics Lightning module."""
+
 from typing import Any, Optional, Union
 
 import pytorch_lightning as L
@@ -24,6 +26,11 @@ class CAREamicsModule(L.LightningModule):
     This class encapsulates a PyTorch model along with the training, validation,
     and testing logic. It is configured using an `AlgorithmModel` Pydantic class.
 
+    Parameters
+    ----------
+    algorithm_config : Union[AlgorithmModel, dict]
+        Algorithm configuration.
+
     Attributes
     ----------
     model : nn.Module
@@ -39,8 +46,7 @@
     """
 
     def __init__(self, algorithm_config: Union[AlgorithmConfig, dict]) -> None:
-        """
-        CAREamics Lightning module.
+        """Lightning module for CAREamics.
 
         This class encapsulates a PyTorch model along with the training,
         validation, and testing logic. It is configured using an `AlgorithmModel`
         Pydantic class.
diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py
index c7e00fd2e..46d41ca99 100644
--- a/src/careamics/lightning_prediction_loop.py
+++ b/src/careamics/lightning_prediction_loop.py
@@ -1,3 +1,5 @@
+"""Lightning prediction loop allowing tiling."""
+
 from typing import Optional
 
 import pytorch_lightning as L
@@ -18,14 +20,14 @@
     """
 
     def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
-        """
-        Calls `on_predict_epoch_end` hook.
+        """Call `on_predict_epoch_end` hook.
 
         Adapted from the parent method.
 
         Returns
         -------
-        the results for all dataloaders
+        Optional[_PREDICT_OUTPUT]
+            Prediction output.
         """
         trainer = self.trainer
         call._call_callback_hooks(trainer, "on_predict_epoch_end")
@@ -45,15 +47,14 @@
 
     @_no_grad_context
     def run(self) -> Optional[_PREDICT_OUTPUT]:
-        """
-        Runs the prediction loop.
+        """Run the prediction loop.
 
         Adapted from the parent method in order to stitch the predictions.
 
         Returns
         -------
         Optional[_PREDICT_OUTPUT]
-            Prediction output
+            Prediction output.
         """
         self.setup_data()
         if self.skip:
diff --git a/src/careamics/losses/loss_factory.py b/src/careamics/losses/loss_factory.py
index 80fe66e7e..7376cc41c 100644
--- a/src/careamics/losses/loss_factory.py
+++ b/src/careamics/losses/loss_factory.py
@@ -17,7 +17,7 @@ def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
 
     Parameters
     ----------
-    loss: SupportedLoss
+    loss : Union[SupportedLoss, str]
         Requested loss.
 
     Returns
diff --git a/src/careamics/losses/losses.py b/src/careamics/losses/losses.py
index c6c3234ee..e29675f0c 100644
--- a/src/careamics/losses/losses.py
+++ b/src/careamics/losses/losses.py
@@ -8,17 +8,24 @@
 from torch.nn import L1Loss, MSELoss
 
 
-def mse_loss(samples: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+def mse_loss(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     """
     Mean squared error loss.
 
+    Parameters
+    ----------
+    source : torch.Tensor
+        Source patches.
+    target : torch.Tensor
+        Target patches.
+
     Returns
     -------
     torch.Tensor
         Loss value.
     """
     loss = MSELoss()
-    return loss(samples, labels)
+    return loss(source, target)
 
 
 def n2v_loss(
@@ -31,9 +38,9 @@
 
     Parameters
     ----------
-    samples : torch.Tensor
+    manipulated_patches : torch.Tensor
         Patches with manipulated pixels.
-    labels : torch.Tensor
+    original_patches : torch.Tensor
         Noisy patches.
     masks : torch.Tensor
         Array containing masked pixel locations.