diff --git a/.github/dependabot.yml b/.github/dependabot.yml index dbe5ef0e0..e189a531c 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -8,3 +8,4 @@ updates: time: "08:00" ignore: - dependency-name: "numpy" + - dependency-name: "zarr" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 96d9f1b83..c2ed145dd 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,33 +1,80 @@ -### Description +## Description -Please provide a brief description of the changes in this PR. Include any relevant context or background information. + -- **What**: Clearly and concisely describe what changes you have made. -- **Why**: Explain the reasoning behind these changes. What problem are you solving? Why is this change necessary? -- **How**: Describe how you implemented these changes. Provide an overview of the approach and any important implementation details. +> [!NOTE] +> **tldr**: -### Changes Made -- **Added**: List new features or files added. -- **Modified**: Describe existing features or files modified. -- **Removed**: Detail features or files that were removed. +### Background - why do we need this PR? -### Related Issues + -Link to any related issues or discussions. Use keywords like "Fixes", "Resolves", or "Closes" to link to issues automatically. +### Overview - what changed? -- Fixes # -- Resolves # -- Closes # + + +### Implementation - how did you implement the changes? + + + + +## Changes Made + + + +### New features or files + + +- -### Breaking changes +### Modified features or files + + +- + +### Removed features or files + + +- + +## How has this been tested? + + + + +## Related Issues + + + +- Resolves # -Describe any breaking change. +## Breaking changes + -### Additional Notes and Examples +## Additional Notes and Examples -Include any additional notes or context that reviewers should be aware of, including snippets of code illustrating your new feature. 
+ --- diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index db5d716ee..b3a381b3f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: validate-pyproject - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.8.1 + rev: v0.8.6 hooks: - id: ruff exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^scripts/.*" @@ -26,14 +26,15 @@ repos: - id: black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.13.0 + rev: v1.14.1 hooks: - id: mypy files: "^src/" exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/config/likelihood_model.py|^src/careamics/losses/loss_factory.py|^src/careamics/losses/lvae/losses.py" args: ["--config-file", "mypy.ini"] additional_dependencies: - - numpy + - numpy<2.0.0 + - pydantic - types-PyYAML - types-setuptools diff --git a/mypy.ini b/mypy.ini index f87d73eb1..cffb58834 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,6 @@ [mypy] ignore_missing_imports = True +plugins = pydantic.mypy [mypy-careamics.lvae_training.*] follow_imports = skip diff --git a/pyproject.toml b/pyproject.toml index 38f2a3961..cef46b433 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,18 +43,20 @@ classifiers = [ ] dependencies = [ 'numpy<2.0.0', - 'torch>=2.0,<=2.5.1', + 'torch>=2.0,<=2.6.0', + 'torchvision<=0.21.0', 'torchvision<=0.20.1', 'bioimageio.core==0.7', - 'tifffile<=2024.12.12', - 'psutil<=6.1', + 'tifffile<=2025.1.10', + 'psutil<=6.1.1', 'pydantic>=2.5,<2.11', - 'pytorch_lightning>=2.2,<=2.4', + 'pytorch_lightning>=2.2,<=2.5.0.post0', 'pyyaml<=6.0.2,!=6.0.0', 'typer>=0.12.3,<=0.15.1', - 'scikit-image<=0.25.0', + 'scikit-image<=0.25.1', 'zarr<3.0.0', - 'pillow<=11.0.0', + 'pillow<=11.1.0', + 'matplotlib<=3.10.0' ] [project.optional-dependencies] diff --git a/src/careamics/__init__.py b/src/careamics/__init__.py index 975020433..0253d94b3 100644 --- a/src/careamics/__init__.py +++ b/src/careamics/__init__.py @@ -7,7 +7,22 @@ except PackageNotFoundError: __version__ = "uninstalled" -__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"] +__all__ = [ + "CAREamist", + "Configuration", + "algorithm_factory", + "configuration_factory", + "data_factory", + "load_configuration", + "save_configuration", +] from .careamist import CAREamist -from .config import Configuration, load_configuration, save_configuration +from .config import ( + Configuration, + algorithm_factory, + configuration_factory, + data_factory, + load_configuration, + save_configuration, +) diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 315f3a60a..b0aa83512 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -13,7 +13,7 @@ ) from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger -from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration +from careamics.config import Configuration, UNetBasedAlgorithm, load_configuration from careamics.config.support import ( SupportedAlgorithm, SupportedArchitecture, @@ -137,7 +137,7 @@ def __init__( self.cfg = source # instantiate model - if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig): + if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm): self.model = FCNModule( algorithm_config=self.cfg.algorithm_config, ) @@ -157,7 +157,8 @@ def __init__( self.cfg = load_configuration(source) # instantiate model - if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig): + # TODO call model factory here + if 
isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm): self.model = FCNModule( algorithm_config=self.cfg.algorithm_config, ) # type: ignore diff --git a/src/careamics/config/__init__.py b/src/careamics/config/__init__.py index c204377e6..4bcc4d683 100644 --- a/src/careamics/config/__init__.py +++ b/src/careamics/config/__init__.py @@ -1,41 +1,63 @@ -"""Configuration module.""" +"""CAREamics Pydantic configuration models. + +To maintain clarity at the module level, we use the following naming conventions: +`*_model` is specific to sub-configurations (e.g. architecture, data, algorithm), +while `*_configuration` is reserved for the main configuration models, including the +`Configuration` base class and its algorithm-specific child classes. +""" __all__ = [ + "CAREAlgorithm", + "CAREConfiguration", "CheckpointModel", "Configuration", - "CustomModel", "DataConfig", - "FCNAlgorithmConfig", "GaussianMixtureNMConfig", + "GeneralDataConfig", "InferenceConfig", "LVAELossConfig", "MultiChannelNMConfig", + "N2NAlgorithm", + "N2NConfiguration", + "N2VAlgorithm", + "N2VConfiguration", + "N2VDataConfig", "TrainingConfig", - "VAEAlgorithmConfig", - "clear_custom_models", + "UNetBasedAlgorithm", + "VAEBasedAlgorithm", + "algorithm_factory", + "configuration_factory", "create_care_configuration", "create_n2n_configuration", "create_n2v_configuration", + "data_factory", "load_configuration", - "register_model", "save_configuration", ] -from .architectures import CustomModel, clear_custom_models, register_model + +from .algorithms import ( + CAREAlgorithm, + N2NAlgorithm, + N2VAlgorithm, + UNetBasedAlgorithm, + VAEBasedAlgorithm, +) from .callback_model import CheckpointModel -from .configuration_factory import ( +from .care_configuration import CAREConfiguration +from .configuration import Configuration +from .configuration_factories import ( + algorithm_factory, + configuration_factory, create_care_configuration, create_n2n_configuration, create_n2v_configuration, + data_factory, ) -from .configuration_model import ( - Configuration, - load_configuration, - save_configuration, -) -from .data_model import DataConfig -from .fcn_algorithm_model import FCNAlgorithmConfig +from .configuration_io import load_configuration, save_configuration +from .data import DataConfig, GeneralDataConfig, N2VDataConfig from .inference_model import InferenceConfig from .loss_model import LVAELossConfig +from .n2n_configuration import N2NConfiguration +from .n2v_configuration import N2VConfiguration from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig from .training_model import TrainingConfig -from .vae_algorithm_model import VAEAlgorithmConfig diff --git a/src/careamics/config/algorithms/__init__.py b/src/careamics/config/algorithms/__init__.py new file mode 100644 index 000000000..3ab358258 --- /dev/null +++ b/src/careamics/config/algorithms/__init__.py @@ -0,0 +1,15 @@ +"""Algorithm configurations.""" + +__all__ = [ + "CAREAlgorithm", + "N2NAlgorithm", + "N2VAlgorithm", + "UNetBasedAlgorithm", + "VAEBasedAlgorithm", +] + +from .care_algorithm_model import CAREAlgorithm +from .n2n_algorithm_model import N2NAlgorithm +from .n2v_algorithm_model import N2VAlgorithm +from .unet_algorithm_model import UNetBasedAlgorithm +from .vae_algorithm_model import VAEBasedAlgorithm
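The new `algorithms` package re-exports the algorithm models, and the `algorithm_factory` introduced in `configuration_factories.py` (further down in this diff) validates plain dictionaries against them. A minimal sketch of the intended usage, assuming this branch of CAREamics is installed and that `UNetModel` validates from its `architecture` field alone:

```python
from careamics.config import algorithm_factory

# a minimal N2V dictionary; the remaining UNet fields use their defaults
algorithm = algorithm_factory(
    {
        "algorithm": "n2v",
        "loss": "n2v",
        "model": {"architecture": "UNet"},
    }
)
print(type(algorithm).__name__)  # N2VAlgorithm
```

diff --git a/src/careamics/config/algorithms/care_algorithm_model.py b/src/careamics/config/algorithms/care_algorithm_model.py new file mode 100644 index 000000000..57435112a --- /dev/null +++ 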
b/src/careamics/config/algorithms/care_algorithm_model.py @@ -0,0 +1,38 @@ +"""CARE algorithm configuration.""" + +from typing import Annotated, Literal + +from pydantic import AfterValidator + +from careamics.config.architectures import UNetModel +from careamics.config.validators import ( + model_without_final_activation, + model_without_n2v2, +) + +from .unet_algorithm_model import UNetBasedAlgorithm + + +class CAREAlgorithm(UNetBasedAlgorithm): + """CARE algorithm configuration. + + Attributes + ---------- + algorithm : "care" + CARE Algorithm name. + loss : {"mae", "mse"} + CARE-compatible loss function. + """ + + algorithm: Literal["care"] = "care" + """CARE Algorithm name.""" + + loss: Literal["mae", "mse"] = "mae" + """CARE-compatible loss function.""" + + model: Annotated[ + UNetModel, + AfterValidator(model_without_n2v2), + AfterValidator(model_without_final_activation), + ] + """UNet without a final activation function and without the `n2v2` modifications.""" diff --git a/src/careamics/config/algorithms/n2n_algorithm_model.py b/src/careamics/config/algorithms/n2n_algorithm_model.py new file mode 100644 index 000000000..08bdbaa72 --- /dev/null +++ b/src/careamics/config/algorithms/n2n_algorithm_model.py @@ -0,0 +1,30 @@ +"""N2N Algorithm configuration.""" + +from typing import Annotated, Literal + +from pydantic import AfterValidator + +from careamics.config.architectures import UNetModel +from careamics.config.validators import ( + model_without_final_activation, + model_without_n2v2, +) + +from .unet_algorithm_model import UNetBasedAlgorithm + + +class N2NAlgorithm(UNetBasedAlgorithm): + """Noise2Noise Algorithm configuration.""" + + algorithm: Literal["n2n"] = "n2n" + """N2N Algorithm name.""" + + loss: Literal["mae", "mse"] = "mae" + """N2N-compatible loss function.""" + + model: Annotated[ + UNetModel, + AfterValidator(model_without_n2v2), + AfterValidator(model_without_final_activation), + ] + """UNet without a final activation function and without the `n2v2` modifications.""" diff --git a/src/careamics/config/algorithms/n2v_algorithm_model.py b/src/careamics/config/algorithms/n2v_algorithm_model.py new file mode 100644 index 000000000..235b7b2a8 --- /dev/null +++ b/src/careamics/config/algorithms/n2v_algorithm_model.py @@ -0,0 +1,29 @@ +"""N2V Algorithm configuration.""" + +from typing import Annotated, Literal + +from pydantic import AfterValidator + +from careamics.config.architectures import UNetModel +from careamics.config.validators import ( + model_matching_in_out_channels, + model_without_final_activation, +) + +from .unet_algorithm_model import UNetBasedAlgorithm + + +class N2VAlgorithm(UNetBasedAlgorithm): + """N2V Algorithm configuration.""" + + algorithm: Literal["n2v"] = "n2v" + """N2V Algorithm name.""" + + loss: Literal["n2v"] = "n2v" + """N2V loss function.""" + + model: Annotated[ + UNetModel, + AfterValidator(model_matching_in_out_channels), + AfterValidator(model_without_final_activation), + ]
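Each child class pins `algorithm` and `loss` to `Literal` values and stacks `AfterValidator`s on the `model` field, so incompatible combinations fail at validation time rather than at training time. A short sketch of that behavior, again assuming `UNetModel` validates from its `architecture` field alone:

```python
from pydantic import ValidationError

from careamics.config.algorithms import CAREAlgorithm, N2VAlgorithm

# algorithm="care" and loss="mae" are filled in by the field defaults
care = CAREAlgorithm(model={"architecture": "UNet"})
print(care.loss)  # mae

# N2V only accepts the "n2v" loss, so "mse" is rejected by the Literal field
try:
    N2VAlgorithm(loss="mse", model={"architecture": "UNet"})
except ValidationError as err:
    print(f"rejected with {err.error_count()} validation error(s)")
```

diff --git a/src/careamics/config/algorithms/unet_algorithm_model.py b/src/careamics/config/algorithms/unet_algorithm_model.py new file mode 100644 index 000000000..d1ca22557 --- /dev/null +++ b/src/careamics/config/algorithms/unet_algorithm_model.py @@ -0,0 +1,88 @@ +"""UNet-based algorithm Pydantic model.""" + +from pprint import pformat +from typing import Literal + +from pydantic import BaseModel, ConfigDict + +from careamics.config.architectures import UNetModel +from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel + + +class 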
UNetBasedAlgorithm(BaseModel): + """General UNet-based algorithm configuration. + + This Pydantic model validates the parameters governing the components of the + training algorithm: which algorithm, loss function, model architecture, optimizer, + and learning rate scheduler to use. + + Currently, we only support N2V, CARE, and N2N algorithms. In order to train these + algorithms, use the corresponding configuration child classes (e.g. + `N2VAlgorithm`) to ensure coherent parameters (e.g. specific losses). + + + Attributes + ---------- + algorithm : {"n2v", "care", "n2n"} + Algorithm to use. + loss : {"n2v", "mae", "mse"} + Loss function to use. + model : UNetModel + Model architecture to use. + optimizer : OptimizerModel, optional + Optimizer to use. + lr_scheduler : LrSchedulerModel, optional + Learning rate scheduler to use. + + Raises + ------ + ValueError + Algorithm parameter type validation errors. + ValueError + If the algorithm, loss and model are not compatible. + """ + + # Pydantic class configuration + model_config = ConfigDict( + protected_namespaces=(), # allows to use model_* as a field name + validate_assignment=True, + extra="allow", + ) + + # Mandatory fields + algorithm: Literal["n2v", "care", "n2n"] + """Algorithm name, as defined in SupportedAlgorithm.""" + + loss: Literal["n2v", "mae", "mse"] + """Loss function to use, as defined in SupportedLoss.""" + + model: UNetModel + """UNet model configuration.""" + + # Optional fields + optimizer: OptimizerModel = OptimizerModel() + """Optimizer to use, defined in SupportedOptimizer.""" + + lr_scheduler: LrSchedulerModel = LrSchedulerModel() + """Learning rate scheduler to use, defined in SupportedLrScheduler.""" + + def __str__(self) -> str: + """Pretty string representing the configuration. + + Returns + ------- + str + Pretty string. + """ + return pformat(self.model_dump()) + + @classmethod + def get_compatible_algorithms(cls) -> list[str]: + """Get the list of compatible algorithms. + + Returns + ------- + list of str + List of compatible algorithms. 
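+ """ + return ["n2v", "care", "n2n"]

With `unet_algorithm_model.py` complete, the base class can be exercised directly. A minimal sketch, assuming the `OptimizerModel()` and `LrSchedulerModel()` defaults used as field defaults above validate without arguments:

```python
from careamics.config.algorithms import UNetBasedAlgorithm

print(UNetBasedAlgorithm.get_compatible_algorithms())  # ['n2v', 'care', 'n2n']

# optimizer and lr_scheduler fall back to OptimizerModel() / LrSchedulerModel()
algorithm = UNetBasedAlgorithm(
    algorithm="care",
    loss="mae",
    model={"architecture": "UNet"},
)
print(algorithm)  # __str__ pretty-prints the validated configuration via pformat
```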
+ """ + return ["n2v", "care", "n2n"] diff --git a/src/careamics/config/vae_algorithm_model.py b/src/careamics/config/algorithms/vae_algorithm_model.py similarity index 87% rename from src/careamics/config/vae_algorithm_model.py rename to src/careamics/config/algorithms/vae_algorithm_model.py index fffcdb801..a77ac361d 100644 --- a/src/careamics/config/vae_algorithm_model.py +++ b/src/careamics/config/algorithms/vae_algorithm_model.py @@ -1,24 +1,26 @@ -"""Algorithm configuration.""" +"""VAE-based algorithm Pydantic model.""" from __future__ import annotations from pprint import pformat -from typing import Literal, Optional, Union +from typing import Literal, Optional -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, model_validator from typing_extensions import Self +from careamics.config.architectures import LVAEModel +from careamics.config.likelihood_model import ( + GaussianLikelihoodConfig, + NMLikelihoodConfig, +) +from careamics.config.loss_model import LVAELossConfig +from careamics.config.nm_model import MultiChannelNMConfig +from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel from careamics.config.support import SupportedAlgorithm, SupportedLoss -from .architectures import CustomModel, LVAEModel -from .likelihood_model import GaussianLikelihoodConfig, NMLikelihoodConfig -from .loss_model import LVAELossConfig -from .nm_model import MultiChannelNMConfig -from .optimizer_models import LrSchedulerModel, OptimizerModel - -class VAEAlgorithmConfig(BaseModel): - """Algorithm configuration. +class VAEBasedAlgorithm(BaseModel): + """VAE-based algorithm configuration. # TODO @@ -42,7 +44,7 @@ class VAEAlgorithmConfig(BaseModel): # NOTE: these are all configs (pydantic models) loss: LVAELossConfig - model: Union[LVAEModel, CustomModel] = Field(discriminator="architecture") + model: LVAEModel noise_model: Optional[MultiChannelNMConfig] = None noise_model_likelihood: Optional[NMLikelihoodConfig] = None gaussian_likelihood: Optional[GaussianLikelihoodConfig] = None diff --git a/src/careamics/config/architectures/__init__.py b/src/careamics/config/architectures/__init__.py index 148f4e29e..4ad4dfc67 100644 --- a/src/careamics/config/architectures/__init__.py +++ b/src/careamics/config/architectures/__init__.py @@ -1,17 +1,7 @@ """Deep-learning model configurations.""" -__all__ = [ - "ArchitectureModel", - "CustomModel", - "LVAEModel", - "UNetModel", - "clear_custom_models", - "get_custom_model", - "register_model", -] +__all__ = ["ArchitectureModel", "LVAEModel", "UNetModel"] from .architecture_model import ArchitectureModel -from .custom_model import CustomModel from .lvae_model import LVAEModel -from .register_model import clear_custom_models, get_custom_model, register_model from .unet_model import UNetModel diff --git a/src/careamics/config/architectures/custom_model.py b/src/careamics/config/architectures/custom_model.py deleted file mode 100644 index fbda9ef8d..000000000 --- a/src/careamics/config/architectures/custom_model.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Custom architecture Pydantic model.""" - -from __future__ import annotations - -import inspect -from pprint import pformat -from typing import Any, Literal - -from pydantic import ConfigDict, field_validator, model_validator -from torch.nn import Module -from typing_extensions import Self - -from .architecture_model import ArchitectureModel -from .register_model import get_custom_model - - -class CustomModel(ArchitectureModel): - """Custom 
model configuration. - - This Pydantic model allows storing parameters for a custom model. In order for the - model to be valid, the specific model needs to be registered using the - `register_model` decorator, and its name correctly passed to this model - configuration (see Examples). - - Attributes - ---------- - architecture : Literal["custom"] - Discriminator for the custom model, must be set to "custom". - name : str - Name of the custom model. - parameters : CustomParametersModel - All parameters, required for the initialization of the torch module have to be - passed here. - - Raises - ------ - ValueError - If the custom model `name` is unknown. - ValueError - If the custom model is not a torch Module subclass. - ValueError - If the custom model parameters are not valid. - - Examples - -------- - >>> from torch import nn, ones - >>> from careamics.config import CustomModel, register_model - >>> # Register a custom model - >>> @register_model(name="my_linear") - ... class LinearModel(nn.Module): - ... def __init__(self, in_features, out_features, *args, **kwargs): - ... super().__init__() - ... self.in_features = in_features - ... self.out_features = out_features - ... self.weight = nn.Parameter(ones(in_features, out_features)) - ... self.bias = nn.Parameter(ones(out_features)) - ... def forward(self, input): - ... return (input @ self.weight) + self.bias - ... - >>> # Create a configuration - >>> config_dict = { - ... "architecture": "custom", - ... "name": "my_linear", - ... "in_features": 10, - ... "out_features": 5, - ... } - >>> config = CustomModel(**config_dict) - """ - - # pydantic model config - model_config = ConfigDict( - extra="allow", - ) - - # discriminator used for choosing the pydantic model in Model - architecture: Literal["custom"] - """Name of the architecture.""" - - name: str - """Name of the custom model.""" - - @field_validator("name") - @classmethod - def custom_model_is_known(cls, value: str) -> str: - """Check whether the custom model is known. - - Parameters - ---------- - value : str - Name of the custom model as registered using the `@register_model` - decorator. - - Returns - ------- - str - The custom model name. - """ - # delegate error to get_custom_model - model = get_custom_model(value) - - # check if it is a torch Module subclass - if not issubclass(model, Module): - raise ValueError( - f'Retrieved class {model} with name "{value}" is not a ' - f"torch.nn.Module subclass." - ) - - return value - - @model_validator(mode="after") - def check_parameters(self: Self) -> Self: - """Validate model by instantiating the model with the parameters. - - Returns - ------- - Self - The validated model. - """ - # instantiate model - try: - get_custom_model(self.name)(**self.model_dump()) - except Exception as e: - raise ValueError( - f"while passing parameters to the model {e}. Verify that all " - f"mandatory parameters are provided, and that either the {e} accepts " - f"*args and **kwargs in its __init__() method, or that no additional" - f"parameter is provided. Trace: " - f"filename: {inspect.trace()[-1].filename}, function: " - f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}" - ) from None - - return self - - def __str__(self) -> str: - """Pretty string representing the configuration. - - Returns - ------- - str - Pretty string. - """ - return pformat(self.model_dump()) - - def model_dump(self, **kwargs: Any) -> dict[str, Any]: - """Dump the model configuration. 
- - Parameters - ---------- - **kwargs : Any - Additional keyword arguments from Pydantic BaseModel model_dump method. - - Returns - ------- - dict[str, Any] - Model configuration. - """ - model_dict = super().model_dump() - - # remove the name key - model_dict.pop("name") - - return model_dict diff --git a/src/careamics/config/architectures/lvae_model.py b/src/careamics/config/architectures/lvae_model.py index 8881845f0..0075d23b6 100644 --- a/src/careamics/config/architectures/lvae_model.py +++ b/src/careamics/config/architectures/lvae_model.py @@ -15,12 +15,17 @@ class LVAEModel(ArchitectureModel): model_config = ConfigDict(validate_assignment=True, validate_default=True) architecture: Literal["LVAE"] - input_shape: list[int] = Field(default=(64, 64), validate_default=True) + """Name of the architecture.""" + + input_shape: list[int] = Field(default=[64, 64], validate_default=True) """Shape of the input patch (C, Z, Y, X) or (C, Y, X) if the data is 2D.""" + encoder_conv_strides: list = Field(default=[2, 2], validate_default=True) + # TODO make this per hierarchy step ? decoder_conv_strides: list = Field(default=[2, 2], validate_default=True) """Dimensions (2D or 3D) of the convolutional layers.""" + multiscale_count: int = Field(default=1) # TODO there should be a check for multiscale_count in dataset !! diff --git a/src/careamics/config/architectures/register_model.py b/src/careamics/config/architectures/register_model.py deleted file mode 100644 index 89d6896d8..000000000 --- a/src/careamics/config/architectures/register_model.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Custom model registration utilities.""" - -from typing import Callable - -from torch.nn import Module - -CUSTOM_MODELS = {} # dictionary of custom models {"name": __class__} - - -def register_model(name: str) -> Callable: - """Decorator used to register a torch.nn.Module class with a given `name`. - - Parameters - ---------- - name : str - Name of the model. - - Returns - ------- - Callable - Function allowing to instantiate the wrapped Module class. - - Raises - ------ - ValueError - If a model is already registered with that name. - - Examples - -------- - ```python - @register_model(name="linear") - class LinearModel(nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - - self.weight = nn.Parameter(ones(in_features, out_features)) - self.bias = nn.Parameter(ones(out_features)) - - def forward(self, input): - return (input @ self.weight) + self.bias - ``` - """ - if name is None or name == "": - raise ValueError("Model name cannot be empty.") - - if name in CUSTOM_MODELS: - raise ValueError( - f"Model {name} already exists. Choose a different name or run " - f"`clear_custom_models()` to empty the registry." - ) - - def add_custom_model(model: Module) -> Module: - """Add a custom model to the registry and return it. - - Parameters - ---------- - model : Module - Module class to register. - - Returns - ------- - Module - The registered model. - """ - # add model to the registry - CUSTOM_MODELS[name] = model - - return model - - return add_custom_model - - -def get_custom_model(name: str) -> Module: - """Get the custom model corresponding to `name` from the registry. - - Parameters - ---------- - name : str - Name of the model to retrieve. - - Returns - ------- - Module - The requested model. - - Raises - ------ - ValueError - If the model is not registered. - """ - if name not in CUSTOM_MODELS: - raise ValueError( - f"Model {name} is unknown. 
Have you registered it using " - f'@register_model("{name}") as decorator?' - ) - - return CUSTOM_MODELS[name] - - -def clear_custom_models() -> None: - """Clear the custom models registry.""" - # clear dictionary - CUSTOM_MODELS.clear() diff --git a/src/careamics/config/architectures/unet_model.py b/src/careamics/config/architectures/unet_model.py index 6ad1b334d..d1a5d129d 100644 --- a/src/careamics/config/architectures/unet_model.py +++ b/src/careamics/config/architectures/unet_model.py @@ -48,6 +48,7 @@ class UNetModel(ArchitectureModel): num_channels_init: int = Field(default=32, ge=8, le=1024, validate_default=True) """Number of convolutional filters in the first layer of the UNet.""" + # TODO we are not using this, so why make it a choice? final_activation: Literal[ "None", "Sigmoid", "Softmax", "Tanh", "ReLU", "LeakyReLU" ] = Field(default="None", validate_default=True) diff --git a/src/careamics/config/care_configuration.py b/src/careamics/config/care_configuration.py new file mode 100644 index 000000000..251b619f0 --- /dev/null +++ b/src/careamics/config/care_configuration.py @@ -0,0 +1,100 @@ +"""CARE Pydantic configuration.""" + +from bioimageio.spec.generic.v0_3 import CiteEntry + +from careamics.config.algorithms.care_algorithm_model import CAREAlgorithm +from careamics.config.configuration import Configuration +from careamics.config.data import DataConfig + +CARE = "CARE" + +CARE_DESCRIPTION = ( + "Content-aware image restoration (CARE) is a deep-learning-based " + "algorithm that uses a U-Net architecture to restore images. CARE " + "is a supervised algorithm that requires pairs of noisy and " + "clean images to train the network. The algorithm learns to " + "predict the clean image from the noisy image. CARE is " + "particularly useful for denoising images acquired in low-light " + "conditions, such as fluorescence microscopy images." +) +CARE_REF = CiteEntry( + text='Weigert, Martin, et al. "Content-aware image restoration: pushing the ' + 'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.', + doi="10.1038/s41592-018-0216-7", +) + + +class CAREConfiguration(Configuration): + """CARE configuration.""" + + algorithm_config: CAREAlgorithm + """Algorithm configuration.""" + + data_config: DataConfig + """Data configuration.""" + + def get_algorithm_friendly_name(self) -> str: + """ + Get the algorithm friendly name. + + Returns + ------- + str + Friendly name of the algorithm. + """ + return CARE + + def get_algorithm_keywords(self) -> list[str]: + """ + Get algorithm keywords. + + Returns + ------- + list[str] + List of keywords. + """ + return [ + "restoration", + "UNet", + "3D" if "Z" in self.data_config.axes else "2D", + "CAREamics", + "pytorch", + CARE, + ] + + def get_algorithm_references(self) -> str: + """ + Get the algorithm references. + + This is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Algorithm references. + """ + return CARE_REF.text + " doi: " + CARE_REF.doi + + def get_algorithm_citations(self) -> list[CiteEntry]: + """ + Return a list of citation entries of the current algorithm. + + This is used to generate the model description for the BioImage Model Zoo. + + Returns + ------- + List[CiteEntry] + List of citation entries. + """ + return [CARE_REF] + + def get_algorithm_description(self) -> str: + """ + Get the algorithm description. + + Returns + ------- + str + Algorithm description. 
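+ """ + return CARE_DESCRIPTION

That closes `care_configuration.py`, whose metadata accessors feed the BioImage Model Zoo export. A hedged sketch of how they might be called, assuming the sub-configuration defaults visible in this diff suffice for a minimal instance:

```python
from careamics.config import CAREConfiguration

config = CAREConfiguration(
    experiment_name="care_demo",
    algorithm_config={
        "algorithm": "care",
        "loss": "mae",
        "model": {"architecture": "UNet"},
    },
    data_config={"data_type": "tiff", "patch_size": [64, 64], "axes": "SYX"},
    training_config={"num_epochs": 100},
)
print(config.get_algorithm_friendly_name())
# CARE
print(config.get_algorithm_keywords())
# ['restoration', 'UNet', '2D', 'CAREamics', 'pytorch', 'CARE']
```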
+ """ + return CARE_DESCRIPTION diff --git a/src/careamics/config/configuration.py b/src/careamics/config/configuration.py new file mode 100644 index 000000000..b4641d559 --- /dev/null +++ b/src/careamics/config/configuration.py @@ -0,0 +1,354 @@ +"""Pydantic CAREamics configuration.""" + +from __future__ import annotations + +import re +from pprint import pformat +from typing import Any, Literal, Union + +from bioimageio.spec.generic.v0_3 import CiteEntry +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from typing_extensions import Self + +from careamics.config.algorithms import UNetBasedAlgorithm, VAEBasedAlgorithm +from careamics.config.data import GeneralDataConfig +from careamics.config.training_model import TrainingConfig + + +class Configuration(BaseModel): + """ + CAREamics configuration. + + The configuration defines all parameters used to build and train a CAREamics model. + These parameters are validated to ensure that they are compatible with each other. + + It contains three sub-configurations: + + - AlgorithmModel: configuration for the algorithm training, which includes the + architecture, loss function, optimizer, and other hyperparameters. + - DataModel: configuration for the dataloader, which includes the type of data, + transformations, mean/std and other parameters. + - TrainingModel: configuration for the training, which includes the number of + epochs or the callbacks. + + Attributes + ---------- + experiment_name : str + Name of the experiment, used when saving logs and checkpoints. + algorithm : AlgorithmModel + Algorithm configuration. + data : DataModel + Data configuration. + training : TrainingModel + Training configuration. + + Methods + ------- + set_3D(is_3D: bool, axes: str, patch_size: List[int]) -> None + Switch configuration between 2D and 3D. + model_dump( + exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: Dict + ) -> Dict + Export configuration to a dictionary. + + Raises + ------ + ValueError + Configuration parameter type validation errors. + ValueError + If the experiment name contains invalid characters or is empty. + ValueError + If the algorithm is 3D but there is not "Z" in the data axes, or 2D algorithm + with "Z" in data axes. + ValueError + Algorithm, data or training validation errors. + + Notes + ----- + We provide convenience methods to create standards configurations, for instance: + >>> from careamics.config import create_n2v_configuration + >>> config = create_n2v_configuration( + ... experiment_name="n2v_experiment", + ... data_type="array", + ... axes="YX", + ... patch_size=[64, 64], + ... batch_size=32, + ... num_epochs=100 + ... ) + + The configuration can be exported to a dictionary using the model_dump method: + >>> config_dict = config.model_dump() + + Configurations can also be exported or imported from yaml files: + >>> from careamics.config import save_configuration, load_configuration + >>> path_to_config = save_configuration(config, my_path / "config.yml") + >>> other_config = load_configuration(path_to_config) + + Examples + -------- + Minimum example: + >>> from careamics import configuration_factory + >>> config_dict = { + ... "experiment_name": "N2V_experiment", + ... "algorithm_config": { + ... "algorithm": "n2v", + ... "loss": "n2v", + ... "model": { + ... "architecture": "UNet", + ... }, + ... }, + ... "training_config": { + ... "num_epochs": 200, + ... }, + ... "data_config": { + ... "data_type": "tiff", + ... "patch_size": [64, 64], + ... "axes": "SYX", + ... 
}, + ... } + >>> config = configuration_factory(config_dict) + """ + + model_config = ConfigDict( + validate_assignment=True, + arbitrary_types_allowed=True, + ) + + # version + version: Literal["0.1.0"] = "0.1.0" + """CAREamics configuration version.""" + + # required parameters + experiment_name: str + """Name of the experiment, used to name logs and checkpoints.""" + + # Sub-configurations + algorithm_config: Union[UNetBasedAlgorithm, VAEBasedAlgorithm] = Field( + discriminator="algorithm" + ) + """Algorithm configuration, holding all parameters required to configure the + model.""" + + data_config: GeneralDataConfig + """Data configuration, holding all parameters required to configure the training + data loader.""" + + training_config: TrainingConfig + """Training configuration, holding all parameters required to configure the + training process.""" + + @field_validator("experiment_name") + @classmethod + def no_symbol(cls, name: str) -> str: + """ + Validate experiment name. + + A valid experiment name is a non-empty string that contains only letters, + numbers, underscores, dashes and spaces. + + Parameters + ---------- + name : str + Name to validate. + + Returns + ------- + str + Validated name. + + Raises + ------ + ValueError + If the name is empty or contains invalid characters. + """ + if len(name) == 0 or name.isspace(): + raise ValueError("Experiment name is empty.") + + # Validate using a regex that it contains only letters, numbers, underscores, + # dashes and spaces + if not re.match(r"^[a-zA-Z0-9_\- ]*$", name): + raise ValueError( + f"Experiment name contains invalid characters (got {name}). " + f"Only letters, numbers, underscores, dashes and spaces are allowed." + ) + + return name + + @model_validator(mode="after") + def validate_3D(self: Self) -> Self: + """ + Change algorithm dimensions to match data.axes. + + Returns + ------- + Self + Validated configuration. + """ + if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D(): + # change algorithm to 3D + self.algorithm_config.model.set_3D(True) + elif "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D(): + # change algorithm to 2D + self.algorithm_config.model.set_3D(False) + + return self + + def __str__(self) -> str: + """ + Pretty string representing the configuration. + + Returns + ------- + str + Pretty string. + """ + return pformat(self.model_dump()) + + def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None: + """ + Set 3D flag and axes. + + Parameters + ---------- + is_3D : bool + Whether the algorithm is 3D or not. + axes : str + Axes of the data. + patch_size : list[int] + Patch size. + """ + # set the flag and axes (this will not trigger validation at the config level) + self.algorithm_config.model.set_3D(is_3D) + self.data_config.set_3D(axes, patch_size) + + # cheap hack: trigger validation + self.algorithm_config = self.algorithm_config + + def get_algorithm_friendly_name(self) -> str: + """ + Get the algorithm name. + + Returns + ------- + str + Algorithm name. + """ + raise ValueError("Unknown algorithm.") + + def get_algorithm_description(self) -> str: + """ + Return a description of the algorithm. + + This method is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Description of the algorithm. + """ + raise ValueError("No algorithm description available.") + + def get_algorithm_citations(self) -> list[CiteEntry]: + """ + Return a list of citation entries of the current algorithm. 
+ + This is used to generate the model description for the BioImage Model Zoo. + + Returns + ------- + List[CiteEntry] + List of citation entries. + """ + raise ValueError("No algorithm citations available.") + + def get_algorithm_references(self) -> str: + """ + Get the algorithm references. + + This is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Algorithm references. + """ + raise ValueError("No algorithm references available.") + + def get_algorithm_keywords(self) -> list[str]: + """ + Get algorithm keywords. + + Returns + ------- + list[str] + List of keywords. + """ + return ["CAREamics"] + + def model_dump( + self, + *, + mode: Literal["json", "python"] | str = "python", + include: Any | None = None, + exclude: Any | None = None, + context: Any | None = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = True, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + serialize_as_any: bool = False, + ) -> dict: + """ + Override model_dump method in order to set default values. + + As opposed to the parent model_dump method, this method sets `exclude_none` to + `True` by default. + + Parameters + ---------- + mode : Literal['json', 'python'] | str, default='python' + The serialization format. + include : Any | None, default=None + Attributes to include. + exclude : Any | None, default=None + Attributes to exclude. + context : Any | None, default=None + Additional context to pass to the serialization functions. + by_alias : bool, default=False + Whether to use attribute aliases. + exclude_unset : bool, default=False + Whether to exclude fields that are not set. + exclude_defaults : bool, default=False + Whether to exclude fields that have default values. + exclude_none : bool, default=True + Whether to exclude fields that have None values. + round_trip : bool, default=False + Whether to dump and load the data to ensure that the output is a valid + representation. + warnings : bool | Literal['none', 'warn', 'error'], default=True + Whether to emit warnings. + serialize_as_any : bool, default=False + Whether to serialize all types as Any. + + Returns + ------- + dict + Dictionary containing the model parameters. 
+ """ + dictionary = super().model_dump( + mode=mode, + include=include, + exclude=exclude, + context=context, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + serialize_as_any=serialize_as_any, + ) + + return dictionary diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factories.py similarity index 87% rename from src/careamics/config/configuration_factory.py rename to src/careamics/config/configuration_factories.py index 59f1955b0..45759e7bd 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factories.py @@ -1,27 +1,120 @@ """Convenience functions to create configurations for training and inference.""" -from typing import Any, Literal, Optional, Union - -from .architectures import UNetModel -from .configuration_model import Configuration -from .data_model import DataConfig -from .fcn_algorithm_model import FCNAlgorithmConfig -from .support import ( +from typing import Annotated, Any, Literal, Optional, Union + +from pydantic import Discriminator, Tag, TypeAdapter + +from careamics.config.algorithms import CAREAlgorithm, N2NAlgorithm, N2VAlgorithm +from careamics.config.architectures import UNetModel +from careamics.config.care_configuration import CAREConfiguration +from careamics.config.configuration import Configuration +from careamics.config.data import DataConfig, N2VDataConfig +from careamics.config.n2n_configuration import N2NConfiguration +from careamics.config.n2v_configuration import N2VConfiguration +from careamics.config.support import ( + SupportedAlgorithm, SupportedArchitecture, SupportedPixelManipulation, SupportedTransform, ) -from .training_model import TrainingConfig -from .transformations import ( +from careamics.config.training_model import TrainingConfig +from careamics.config.transformations import ( + N2V_TRANSFORMS_UNION, + SPATIAL_TRANSFORMS_UNION, N2VManipulateModel, XYFlipModel, XYRandomRotate90Model, ) -def _list_augmentations( - augmentations: Optional[list[Union[XYFlipModel, XYRandomRotate90Model]]], -) -> list[Union[XYFlipModel, XYRandomRotate90Model]]: +def _algorithm_config_discriminator(value: Union[dict, Configuration]) -> str: + """Discriminate algorithm-specific configurations based on the algorithm. + + Parameters + ---------- + value : Any + Value to discriminate. + + Returns + ------- + str + Discriminator value. + """ + if isinstance(value, dict): + return value["algorithm_config"]["algorithm"] + return value.algorithm_config.algorithm + + +def configuration_factory( + configuration: dict[str, Any] +) -> Union[N2VConfiguration, N2NConfiguration, CAREConfiguration]: + """ + Create a configuration for training CAREamics. + + Parameters + ---------- + configuration : dict + Configuration dictionary. + + Returns + ------- + N2VConfiguration or N2NConfiguration or CAREConfiguration + Configuration for training CAREamics. 
+ """ + adapter: TypeAdapter = TypeAdapter( + Annotated[ + Union[ + Annotated[N2VConfiguration, Tag(SupportedAlgorithm.N2V.value)], + Annotated[N2NConfiguration, Tag(SupportedAlgorithm.N2N.value)], + Annotated[CAREConfiguration, Tag(SupportedAlgorithm.CARE.value)], + ], + Discriminator(_algorithm_config_discriminator), + ] + ) + return adapter.validate_python(configuration) + + +def algorithm_factory( + algorithm: dict[str, Any] +) -> Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]: + """ + Create an algorithm model for training CAREamics. + + Parameters + ---------- + algorithm : dict + Algorithm dictionary. + + Returns + ------- + N2VAlgorithm or N2NAlgorithm or CAREAlgorithm + Algorithm model for training CAREamics. + """ + adapter: TypeAdapter = TypeAdapter(Union[N2VAlgorithm, N2NAlgorithm, CAREAlgorithm]) + return adapter.validate_python(algorithm) + + +def data_factory(data: dict[str, Any]) -> Union[DataConfig, N2VDataConfig]: + """ + Create a data model for training CAREamics. + + Parameters + ---------- + data : dict + Data dictionary. + + Returns + ------- + DataConfig or N2VDataConfig + Data model for training CAREamics. + """ + adapter: TypeAdapter = TypeAdapter(Union[DataConfig, N2VDataConfig]) + return adapter.validate_python(data) + + +def _list_spatial_augmentations( + augmentations: Optional[list[SPATIAL_TRANSFORMS_UNION]], +) -> list[SPATIAL_TRANSFORMS_UNION]: """ List the augmentations to apply. @@ -44,7 +137,7 @@ def _list_augmentations( If there are duplicate transforms. """ if augmentations is None: - transform_list: list[Union[XYFlipModel, XYRandomRotate90Model]] = [ + transform_list: list[SPATIAL_TRANSFORMS_UNION] = [ XYFlipModel(), XYRandomRotate90Model(), ] @@ -123,7 +216,7 @@ def _create_configuration( patch_size: list[int], batch_size: int, num_epochs: int, - augmentations: list[Union[XYFlipModel, XYRandomRotate90Model]], + augmentations: Union[list[N2V_TRANSFORMS_UNION], list[SPATIAL_TRANSFORMS_UNION]], independent_channels: bool, loss: Literal["n2v", "mae", "mse"], n_channels_in: int, @@ -188,21 +281,21 @@ def _create_configuration( ) # algorithm model - algorithm_config = FCNAlgorithmConfig( - algorithm=algorithm, - loss=loss, - model=unet_model, - ) + algorithm_config = { + "algorithm": algorithm, + "loss": loss, + "model": unet_model, + } # data model - data = DataConfig( - data_type=data_type, - axes=axes, - patch_size=patch_size, - batch_size=batch_size, - transforms=augmentations, - dataloader_params=dataloader_params, - ) + data = { + "data_type": data_type, + "axes": axes, + "patch_size": patch_size, + "batch_size": batch_size, + "transforms": augmentations, + "dataloader_params": dataloader_params, + } # training model training = TrainingConfig( @@ -212,14 +305,14 @@ def _create_configuration( ) # create configuration - configuration = Configuration( - experiment_name=experiment_name, - algorithm_config=algorithm_config, - data_config=data, - training_config=training, - ) + configuration = { + "experiment_name": experiment_name, + "algorithm_config": algorithm_config, + "data_config": data, + "training_config": training, + } - return configuration + return configuration_factory(configuration) # TODO reconsider naming once we officially support LVAE approaches @@ -306,7 +399,7 @@ def _create_supervised_configuration( n_channels_out = n_channels_in # augmentations - transform_list = _list_augmentations(augmentations) + spatial_transform_list = _list_spatial_augmentations(augmentations) return _create_configuration( algorithm=algorithm, @@ -316,7 +409,7 
@@ def _create_supervised_configuration( patch_size=patch_size, batch_size=batch_size, num_epochs=num_epochs, - augmentations=transform_list, + augmentations=spatial_transform_list, independent_channels=independent_channels, loss=loss, n_channels_in=n_channels_in, @@ -853,7 +946,7 @@ def create_n2v_configuration( n_channels = 1 # augmentations - transform_list = _list_augmentations(augmentations) + spatial_transforms = _list_spatial_augmentations(augmentations) # create the N2VManipulate transform using the supplied parameters n2v_transform = N2VManipulateModel( @@ -868,7 +961,7 @@ def create_n2v_configuration( struct_mask_axis=struct_n2v_axis, struct_mask_span=struct_n2v_span, ) - transform_list.append(n2v_transform) + transform_list: list[N2V_TRANSFORMS_UNION] = spatial_transforms + [n2v_transform] return _create_configuration( algorithm="n2v", diff --git a/src/careamics/config/configuration_io.py b/src/careamics/config/configuration_io.py new file mode 100644 index 000000000..62a3adabe --- /dev/null +++ b/src/careamics/config/configuration_io.py @@ -0,0 +1,85 @@ +"""I/O functions for Configuration objects.""" + +from pathlib import Path +from typing import Union + +import yaml + +from careamics.config import Configuration, configuration_factory + + +def load_configuration(path: Union[str, Path]) -> Configuration: + """ + Load configuration from a yaml file. + + Parameters + ---------- + path : str or Path + Path to the configuration. + + Returns + ------- + Configuration + Configuration. + + Raises + ------ + FileNotFoundError + If the configuration file does not exist. + """ + # load dictionary from yaml + if not Path(path).exists(): + raise FileNotFoundError( + f"Configuration file {path} does not exist in " f" {Path.cwd()!s}" + ) + + dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader) + + return configuration_factory(dictionary) + + +def save_configuration(config: Configuration, path: Union[str, Path]) -> Path: + """ + Save configuration to path. + + Parameters + ---------- + config : Configuration + Configuration to save. + path : str or Path + Path to an existing folder in which to save the configuration, or to a valid + configuration file path (uses a .yml or .yaml extension). + + Returns + ------- + Path + Path object representing the configuration. + + Raises + ------ + ValueError + If the path does not point to an existing directory or .yml file. + """ + # make sure path is a Path object + config_path = Path(path) + + # check if path is pointing to an existing directory or .yml file + if config_path.exists(): + if config_path.is_dir(): + config_path = Path(config_path, "config.yml") + elif config_path.suffix != ".yml" and config_path.suffix != ".yaml": + raise ValueError( + f"Path must be a directory or .yml or .yaml file (got {config_path})." + ) + else: + if config_path.suffix != ".yml" and config_path.suffix != ".yaml": + raise ValueError( + f"Path must be a directory or .yml or .yaml file (got {config_path})." 
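) + + # save configuration as dictionary to yaml + with open(config_path, "w") as f: + # dump configuration + yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False) + + return config_path

With `configuration_io.py` complete, a YAML round trip is straightforward. A minimal sketch, assuming write access to the working directory:

```python
from pathlib import Path

from careamics import configuration_factory, load_configuration, save_configuration

config = configuration_factory(
    {
        "experiment_name": "io_demo",
        "algorithm_config": {
            "algorithm": "n2v",
            "loss": "n2v",
            "model": {"architecture": "UNet"},
        },
        "data_config": {"data_type": "tiff", "patch_size": [64, 64], "axes": "SYX"},
        "training_config": {"num_epochs": 10},
    }
)

# passing a directory saves to <directory>/config.yml
path = save_configuration(config, Path.cwd())
restored = load_configuration(path)  # re-validated through configuration_factory
print(type(restored).__name__)  # N2VConfiguration
```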
diff --git a/src/careamics/config/configuration_model.py b/src/careamics/config/configuration_model.py deleted file mode 100644 index c666138ed..000000000 --- a/src/careamics/config/configuration_model.py +++ /dev/null @@ -1,603 +0,0 @@ -"""Pydantic CAREamics configuration.""" - -from __future__ import annotations - -import re -from pathlib import Path -from pprint import pformat -from typing import Literal, Union - -import yaml -from bioimageio.spec.generic.v0_3 import CiteEntry -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from typing_extensions import Self - -from .data_model import DataConfig -from .fcn_algorithm_model import FCNAlgorithmConfig -from .references import ( - CARE, - CUSTOM, - N2N, - N2V, - N2V2, - STRUCT_N2V, - STRUCT_N2V2, - CAREDescription, - CARERef, - N2NDescription, - N2NRef, - N2V2Description, - N2V2Ref, - N2VDescription, - N2VRef, - StructN2V2Description, - StructN2VDescription, - StructN2VRef, -) -from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform -from .training_model import TrainingConfig -from .transformations.n2v_manipulate_model import ( - N2VManipulateModel, -) -from .vae_algorithm_model import VAEAlgorithmConfig - - -class Configuration(BaseModel): - """ - CAREamics configuration. - - The configuration defines all parameters used to build and train a CAREamics model. - These parameters are validated to ensure that they are compatible with each other. - - It contains three sub-configurations: - - - AlgorithmModel: configuration for the algorithm training, which includes the - architecture, loss function, optimizer, and other hyperparameters. - - DataModel: configuration for the dataloader, which includes the type of data, - transformations, mean/std and other parameters. - - TrainingModel: configuration for the training, which includes the number of - epochs or the callbacks. - - Attributes - ---------- - experiment_name : str - Name of the experiment, used when saving logs and checkpoints. - algorithm : AlgorithmModel - Algorithm configuration. - data : DataModel - Data configuration. - training : TrainingModel - Training configuration. - - Methods - ------- - set_3D(is_3D: bool, axes: str, patch_size: List[int]) -> None - Switch configuration between 2D and 3D. - set_N2V2(use_n2v2: bool) -> None - Switch N2V algorithm between N2V and N2V2. - set_structN2V( - mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int) -> None - Set StructN2V parameters. - model_dump( - exclude_defaults: bool = False, exclude_none: bool = True, **kwargs: Dict - ) -> Dict - Export configuration to a dictionary. - - Raises - ------ - ValueError - Configuration parameter type validation errors. - ValueError - If the experiment name contains invalid characters or is empty. - ValueError - If the algorithm is 3D but there is not "Z" in the data axes, or 2D algorithm - with "Z" in data axes. - ValueError - Algorithm, data or training validation errors. - - Notes - ----- - We provide convenience methods to create standards configurations, for instance - for N2V, in the `careamics.config.configuration_factory` module. - >>> from careamics.config.configuration_factory import create_n2v_configuration - >>> config = create_n2v_configuration( - ... experiment_name="n2v_experiment", - ... 
data_type="array", - ... axes="YX", - ... patch_size=[64, 64], - ... batch_size=32, - ... num_epochs=100 - ... ) - - The configuration can be exported to a dictionary using the model_dump method: - >>> config_dict = config.model_dump() - - Configurations can also be exported or imported from yaml files: - >>> from careamics.config import save_configuration, load_configuration - >>> path_to_config = save_configuration(config, my_path / "config.yml") - >>> other_config = load_configuration(path_to_config) - - Examples - -------- - Minimum example: - >>> from careamics.config import Configuration - >>> config_dict = { - ... "experiment_name": "N2V_experiment", - ... "algorithm_config": { - ... "algorithm": "n2v", - ... "loss": "n2v", - ... "model": { - ... "architecture": "UNet", - ... }, - ... }, - ... "training_config": { - ... "num_epochs": 200, - ... }, - ... "data_config": { - ... "data_type": "tiff", - ... "patch_size": [64, 64], - ... "axes": "SYX", - ... }, - ... } - >>> config = Configuration(**config_dict) - """ - - model_config = ConfigDict( - validate_assignment=True, - set_arbitrary_types_allowed=True, - ) - - # version - version: Literal["0.1.0"] = "0.1.0" - """CAREamics configuration version.""" - - # required parameters - experiment_name: str - """Name of the experiment, used to name logs and checkpoints.""" - - # Sub-configurations - algorithm_config: Union[FCNAlgorithmConfig, VAEAlgorithmConfig] = Field( - discriminator="algorithm" - ) - """Algorithm configuration, holding all parameters required to configure the - model.""" - - data_config: DataConfig - """Data configuration, holding all parameters required to configure the training - data loader.""" - - training_config: TrainingConfig - """Training configuration, holding all parameters required to configure the - training process.""" - - @field_validator("experiment_name") - @classmethod - def no_symbol(cls, name: str) -> str: - """ - Validate experiment name. - - A valid experiment name is a non-empty string with only contains letters, - numbers, underscores, dashes and spaces. - - Parameters - ---------- - name : str - Name to validate. - - Returns - ------- - str - Validated name. - - Raises - ------ - ValueError - If the name is empty or contains invalid characters. - """ - if len(name) == 0 or name.isspace(): - raise ValueError("Experiment name is empty.") - - # Validate using a regex that it contains only letters, numbers, underscores, - # dashes and spaces - if not re.match(r"^[a-zA-Z0-9_\- ]*$", name): - raise ValueError( - f"Experiment name contains invalid characters (got {name}). " - f"Only letters, numbers, underscores, dashes and spaces are allowed." - ) - - return name - - @model_validator(mode="after") - def validate_3D(self: Self) -> Self: - """ - Change algorithm dimensions to match data.axes. - - Only for non-custom algorithms. - - Returns - ------- - Self - Validated configuration. - """ - if self.algorithm_config.algorithm != SupportedAlgorithm.CUSTOM: - if "Z" in self.data_config.axes and not self.algorithm_config.model.is_3D(): - # change algorithm to 3D - self.algorithm_config.model.set_3D(True) - elif ( - "Z" not in self.data_config.axes and self.algorithm_config.model.is_3D() - ): - # change algorithm to 2D - self.algorithm_config.model.set_3D(False) - - return self - - @model_validator(mode="after") - def validate_algorithm_and_data(self: Self) -> Self: - """ - Validate algorithm and data compatibility. 
- - In particular, the validation does the following: - - - If N2V is used, it enforces the presence of N2V_Maniuplate in the transforms - - If N2V2 is used, it enforces the correct manipulation strategy - - Returns - ------- - Self - Validated configuration. - """ - if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: - # missing N2V_MANIPULATE - if not self.data_config.has_n2v_manipulate(): - self.data_config.transforms.append( - N2VManipulateModel( - name=SupportedTransform.N2V_MANIPULATE.value, - ) - ) - - median = SupportedPixelManipulation.MEDIAN.value - uniform = SupportedPixelManipulation.UNIFORM.value - strategy = median if self.algorithm_config.model.n2v2 else uniform - self.data_config.set_N2V2_strategy(strategy) - else: - # remove N2V manipulate if present - if self.data_config.has_n2v_manipulate(): - self.data_config.remove_n2v_manipulate() - - return self - - def __str__(self) -> str: - """ - Pretty string reprensenting the configuration. - - Returns - ------- - str - Pretty string. - """ - return pformat(self.model_dump()) - - def set_3D(self, is_3D: bool, axes: str, patch_size: list[int]) -> None: - """ - Set 3D flag and axes. - - Parameters - ---------- - is_3D : bool - Whether the algorithm is 3D or not. - axes : str - Axes of the data. - patch_size : list[int] - Patch size. - """ - # set the flag and axes (this will not trigger validation at the config level) - self.algorithm_config.model.set_3D(is_3D) - self.data_config.set_3D(axes, patch_size) - - # cheap hack: trigger validation - self.algorithm_config = self.algorithm_config - - def set_N2V2(self, use_n2v2: bool) -> None: - """ - Switch N2V algorithm between N2V and N2V2. - - Parameters - ---------- - use_n2v2 : bool - Whether to use N2V2 or not. - - Raises - ------ - ValueError - If the algorithm is not N2V. - """ - if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: - self.algorithm_config.model.n2v2 = use_n2v2 - strategy = ( - SupportedPixelManipulation.MEDIAN.value - if use_n2v2 - else SupportedPixelManipulation.UNIFORM.value - ) - self.data_config.set_N2V2_strategy(strategy) - else: - raise ValueError("N2V2 can only be set for N2V algorithm.") - - def set_structN2V( - self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int - ) -> None: - """ - Set StructN2V parameters. - - Parameters - ---------- - mask_axis : Literal["horizontal", "vertical", "none"] - Axis of the structural mask. - mask_span : int - Span of the structural mask. - """ - self.data_config.set_structN2V_mask(mask_axis, mask_span) - - def get_algorithm_flavour(self) -> str: - """ - Get the algorithm name. - - Returns - ------- - str - Algorithm name. - """ - if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: - use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" - - # return the n2v flavour - if use_n2v2 and use_structN2V: - return STRUCT_N2V2 - elif use_n2v2: - return N2V2 - elif use_structN2V: - return STRUCT_N2V - else: - return N2V - elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: - return N2N - elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: - return CARE - else: - return CUSTOM - - def get_algorithm_description(self) -> str: - """ - Return a description of the algorithm. - - This method is used to generate the README of the BioImage Model Zoo export. - - Returns - ------- - str - Description of the algorithm. 
- """ - algorithm_flavour = self.get_algorithm_flavour() - - if algorithm_flavour == CUSTOM: - return f"Custom algorithm, named {self.algorithm_config.model.name}" - else: # currently only N2V flavours - if algorithm_flavour == N2V: - return N2VDescription().description - elif algorithm_flavour == N2V2: - return N2V2Description().description - elif algorithm_flavour == STRUCT_N2V: - return StructN2VDescription().description - elif algorithm_flavour == STRUCT_N2V2: - return StructN2V2Description().description - elif algorithm_flavour == N2N: - return N2NDescription().description - elif algorithm_flavour == CARE: - return CAREDescription().description - - return "" - - def get_algorithm_citations(self) -> list[CiteEntry]: - """ - Return a list of citation entries of the current algorithm. - - This is used to generate the model description for the BioImage Model Zoo. - - Returns - ------- - List[CiteEntry] - List of citation entries. - """ - if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: - use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" - - # return the (struct)N2V(2) references - if use_n2v2 and use_structN2V: - return [N2VRef, N2V2Ref, StructN2VRef] - elif use_n2v2: - return [N2VRef, N2V2Ref] - elif use_structN2V: - return [N2VRef, StructN2VRef] - else: - return [N2VRef] - elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: - return [N2NRef] - elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: - return [CARERef] - - raise ValueError("Citation not available for custom algorithm.") - - def get_algorithm_references(self) -> str: - """ - Get the algorithm references. - - This is used to generate the README of the BioImage Model Zoo export. - - Returns - ------- - str - Algorithm references. - """ - if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: - use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" - - references = [ - N2VRef.text + " doi: " + N2VRef.doi, - N2V2Ref.text + " doi: " + N2V2Ref.doi, - StructN2VRef.text + " doi: " + StructN2VRef.doi, - ] - - # return the (struct)N2V(2) references - if use_n2v2 and use_structN2V: - return "".join(references) - elif use_n2v2: - references.pop(-1) - return "".join(references) - elif use_structN2V: - references.pop(-2) - return "".join(references) - else: - return references[0] - - return "" - - def get_algorithm_keywords(self) -> list[str]: - """ - Get algorithm keywords. - - Returns - ------- - list[str] - List of keywords. - """ - if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: - use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" - - keywords = [ - "denoising", - "restoration", - "UNet", - "3D" if "Z" in self.data_config.axes else "2D", - "CAREamics", - "pytorch", - N2V, - ] - - if use_n2v2: - keywords.append(N2V2) - if use_structN2V: - keywords.append(STRUCT_N2V) - else: - keywords = ["CAREamics"] - - return keywords - - def model_dump( - self, - exclude_defaults: bool = False, - exclude_none: bool = True, - **kwargs: dict, - ) -> dict: - """ - Override model_dump method in order to set default values. - - Parameters - ---------- - exclude_defaults : bool, optional - Whether to exclude fields with default values or not, by default - True. - exclude_none : bool, optional - Whether to exclude fields with None values or not, by default True. 
- **kwargs : dict - Keyword arguments. - - Returns - ------- - dict - Dictionary containing the model parameters. - """ - dictionary = super().model_dump( - exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs - ) - - return dictionary - - -def load_configuration(path: Union[str, Path]) -> Configuration: - """ - Load configuration from a yaml file. - - Parameters - ---------- - path : str or Path - Path to the configuration. - - Returns - ------- - Configuration - Configuration. - - Raises - ------ - FileNotFoundError - If the configuration file does not exist. - """ - # load dictionary from yaml - if not Path(path).exists(): - raise FileNotFoundError( - f"Configuration file {path} does not exist in " f" {Path.cwd()!s}" - ) - - dictionary = yaml.load(Path(path).open("r"), Loader=yaml.SafeLoader) - - return Configuration(**dictionary) - - -def save_configuration(config: Configuration, path: Union[str, Path]) -> Path: - """ - Save configuration to path. - - Parameters - ---------- - config : Configuration - Configuration to save. - path : str or Path - Path to a existing folder in which to save the configuration, or to a valid - configuration file path (uses a .yml or .yaml extension). - - Returns - ------- - Path - Path object representing the configuration. - - Raises - ------ - ValueError - If the path does not point to an existing directory or .yml file. - """ - # make sure path is a Path object - config_path = Path(path) - - # check if path is pointing to an existing directory or .yml file - if config_path.exists(): - if config_path.is_dir(): - config_path = Path(config_path, "config.yml") - elif config_path.suffix != ".yml" and config_path.suffix != ".yaml": - raise ValueError( - f"Path must be a directory or .yml or .yaml file (got {config_path})." - ) - else: - if config_path.suffix != ".yml" and config_path.suffix != ".yaml": - raise ValueError( - f"Path must be a directory or .yml or .yaml file (got {config_path})." 
- ) - - # save configuration as dictionary to yaml - with open(config_path, "w") as f: - # dump configuration - yaml.dump(config.model_dump(), f, default_flow_style=False, sort_keys=False) - - return config_path diff --git a/src/careamics/config/data/__init__.py b/src/careamics/config/data/__init__.py new file mode 100644 index 000000000..631c3ecfa --- /dev/null +++ b/src/careamics/config/data/__init__.py @@ -0,0 +1,10 @@ +"""Data Pydantic configuration models.""" + +__all__ = [ + "DataConfig", + "GeneralDataConfig", + "N2VDataConfig", +] + +from .data_model import DataConfig, GeneralDataConfig +from .n2v_data_model import N2VDataConfig diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data/data_model.py similarity index 65% rename from src/careamics/config/data_model.py rename to src/careamics/config/data/data_model.py index 25a134f03..b93263ee5 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data/data_model.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Sequence from pprint import pformat from typing import Annotated, Any, Literal, Optional, Union @@ -17,9 +18,8 @@ ) from typing_extensions import Self -from .support import SupportedTransform -from .transformations import TRANSFORMS_UNION, N2VManipulateModel -from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 +from ..transformations import N2V_TRANSFORMS_UNION, XYFlipModel, XYRandomRotate90Model +from ..validators import check_axes_validity, patch_size_ge_than_8_power_of_2 def np_float_to_scientific_str(x: float) -> str: @@ -45,47 +45,8 @@ def np_float_to_scientific_str(x: float) -> str: """Annotated float type, used to serialize floats to strings.""" -class DataConfig(BaseModel): - """ - Data configuration. - - If std is specified, mean must be specified as well. Note that setting the std first - and then the mean (if they were both `None` before) will raise a validation error. - Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected - to be lists of floats, one for each channel. For supervised tasks, the mean and std - of the target could be different from the input data. - - All supported transforms are defined in the SupportedTransform enum. - - Examples - -------- - Minimum example: - - >>> data = DataConfig( - ... data_type="array", # defined in SupportedData - ... patch_size=[128, 128], - ... batch_size=4, - ... axes="YX" - ... ) - - To change the image_means and image_stds of the data: - >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5]) - - One can pass also a list of transformations, by keyword, using the - SupportedTransform value: - >>> from careamics.config.support import SupportedTransform - >>> data = DataConfig( - ... data_type="tiff", - ... patch_size=[128, 128], - ... batch_size=4, - ... axes="YX", - ... transforms=[ - ... { - ... "name": "XYFlip", - ... } - ... ] - ... 
) - """ +class GeneralDataConfig(BaseModel): + """General data configuration.""" # Pydantic class configuration model_config = ConfigDict( @@ -126,22 +87,18 @@ class DataConfig(BaseModel): """Standard deviations of the target data across channels, used for normalization.""" - transforms: list[TRANSFORMS_UNION] = Field( + # defining as Sequence allows assigning subclasses of TransformModel without mypy + # complaining, this is important for instance to differentiate N2VDataConfig and + # DataConfig + transforms: Sequence[N2V_TRANSFORMS_UNION] = Field( default=[ - { - "name": SupportedTransform.XY_FLIP.value, - }, - { - "name": SupportedTransform.XY_RANDOM_ROTATE90.value, - }, - { - "name": SupportedTransform.N2V_MANIPULATE.value, - }, + XYFlipModel(), + XYRandomRotate90Model(), ], validate_default=True, ) """List of transformations to apply to the data, available transforms are defined - in SupportedTransform. The default values are set for Noise2Void.""" + in SupportedTransform.""" dataloader_params: Optional[dict] = None """Dictionary of PyTorch dataloader parameters.""" @@ -210,48 +167,6 @@ def axes_valid(cls, axes: str) -> str: return axes - @field_validator("transforms") - @classmethod - def validate_prediction_transforms( - cls, transforms: list[TRANSFORMS_UNION] - ) -> list[TRANSFORMS_UNION]: - """ - Validate N2VManipulate transform position in the transform list. - - Parameters - ---------- - transforms : list[Transformations_Union] - Transforms. - - Returns - ------- - list of transforms - Validated transforms. - - Raises - ------ - ValueError - If multiple instances of N2VManipulate are found. - """ - transform_list = [t.name for t in transforms] - - if SupportedTransform.N2V_MANIPULATE in transform_list: - # multiple N2V_MANIPULATE - if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1: - raise ValueError( - f"Multiple instances of " - f"{SupportedTransform.N2V_MANIPULATE} transforms " - f"are not allowed." - ) - - # N2V_MANIPULATE not the last transform - elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE: - index = transform_list.index(SupportedTransform.N2V_MANIPULATE.value) - transform = transforms.pop(index) - transforms.append(transform) - - return transforms - @model_validator(mode="after") def std_only_with_mean(self: Self) -> Self: """ @@ -350,32 +265,6 @@ def _update(self, **kwargs: Any) -> None: self.__dict__.update(kwargs) self.__class__.model_validate(self.__dict__) - def has_n2v_manipulate(self) -> bool: - """ - Check if the transforms contain N2VManipulate. - - Returns - ------- - bool - True if the transforms contain N2VManipulate, False otherwise. - """ - return any( - transform.name == SupportedTransform.N2V_MANIPULATE.value - for transform in self.transforms - ) - - def add_n2v_manipulate(self) -> None: - """Add N2VManipulate to the transforms.""" - if not self.has_n2v_manipulate(): - self.transforms.append( - N2VManipulateModel(name=SupportedTransform.N2V_MANIPULATE.value) - ) - - def remove_n2v_manipulate(self) -> None: - """Remove N2VManipulate from the transforms.""" - if self.has_n2v_manipulate(): - self.transforms.pop(-1) - def set_means_and_stds( self, image_means: Union[NDArray, tuple, list, None], @@ -430,84 +319,55 @@ def set_3D(self, axes: str, patch_size: list[int]) -> None: """ self._update(axes=axes, patch_size=patch_size) - def set_N2V2(self, use_n2v2: bool) -> None: - """ - Set N2V2. - - Parameters - ---------- - use_n2v2 : bool - Whether to use N2V2. 
- - Raises - ------ - ValueError - If the N2V pixel manipulate transform is not found in the transforms. - """ - if use_n2v2: - self.set_N2V2_strategy("median") - else: - self.set_N2V2_strategy("uniform") - - def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: - """ - Set N2V2 strategy. - - Parameters - ---------- - strategy : Literal["uniform", "median"] - Strategy to use for N2V2. - Raises - ------ - ValueError - If the N2V pixel manipulate transform is not found in the transforms. - """ - found_n2v = False - - for transform in self.transforms: - if transform.name == SupportedTransform.N2V_MANIPULATE.value: - transform.strategy = strategy - found_n2v = True +class DataConfig(GeneralDataConfig): + """ + Data configuration. - if not found_n2v: - transforms = [t.name for t in self.transforms] - raise ValueError( - f"N2V_Manipulate transform not found in the transforms list " - f"({transforms})." - ) + If std is specified, mean must be specified as well. Note that setting the std first + and then the mean (if they were both `None` before) will raise a validation error. + Prefer instead `set_means_and_stds` to set both at once. Means and stds are expected + to be lists of floats, one for each channel. For supervised tasks, the mean and std + of the target could be different from the input data. - def set_structN2V_mask( - self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int - ) -> None: - """ - Set structN2V mask parameters. + All supported transforms are defined in the SupportedTransform enum. - Setting `mask_axis` to `none` will disable structN2V. + Examples + -------- + Minimum example: - Parameters - ---------- - mask_axis : Literal["horizontal", "vertical", "none"] - Axis along which to apply the mask. `none` will disable structN2V. - mask_span : int - Total span of the mask in pixels. + >>> data = DataConfig( + ... data_type="array", # defined in SupportedData + ... patch_size=[128, 128], + ... batch_size=4, + ... axes="YX" + ... ) - Raises - ------ - ValueError - If the N2V pixel manipulate transform is not found in the transforms. - """ - found_n2v = False + To change the image_means and image_stds of the data: + >>> data.set_means_and_stds(image_means=[214.3], image_stds=[84.5]) - for transform in self.transforms: - if transform.name == SupportedTransform.N2V_MANIPULATE.value: - transform.struct_mask_axis = mask_axis - transform.struct_mask_span = mask_span - found_n2v = True + One can also pass a list of transformations, by keyword, using the + SupportedTransform value: + >>> from careamics.config.support import SupportedTransform + >>> data = DataConfig( + ... data_type="tiff", + ... patch_size=[128, 128], + ... batch_size=4, + ... axes="YX", + ... transforms=[ + ... { + ... "name": "XYFlip", + ... } + ... ] + ... ) + """ - if not found_n2v: - transforms = [t.name for t in self.transforms] - raise ValueError( - f"N2V pixel manipulate transform not found in the transforms " - f"({transforms})." - ) + transforms: Sequence[Union[XYFlipModel, XYRandomRotate90Model]] = Field( + default=[ + XYFlipModel(), + XYRandomRotate90Model(), + ], + validate_default=True, + ) + """List of transformations to apply to the data, available transforms are defined + in SupportedTransform.
This excludes N2V specific transformations.""" diff --git a/src/careamics/config/data/n2v_data_model.py b/src/careamics/config/data/n2v_data_model.py new file mode 100644 index 000000000..24522becd --- /dev/null +++ b/src/careamics/config/data/n2v_data_model.py @@ -0,0 +1,193 @@ +"""Noise2Void specific data configuration model.""" + +from collections.abc import Sequence +from typing import Literal + +from pydantic import Field, field_validator + +from careamics.config.data.data_model import GeneralDataConfig +from careamics.config.support import SupportedTransform +from careamics.config.transformations import ( + N2V_TRANSFORMS_UNION, + N2VManipulateModel, + XYFlipModel, + XYRandomRotate90Model, +) + + +class N2VDataConfig(GeneralDataConfig): + """N2V specific data configuration model.""" + + transforms: Sequence[N2V_TRANSFORMS_UNION] = Field( + default=[XYFlipModel(), XYRandomRotate90Model(), N2VManipulateModel()], + validate_default=True, + ) + """N2V compatible transforms. N2VManipulate should be the last transform.""" + + @field_validator("transforms") + @classmethod + def validate_transforms( + cls, transforms: list[N2V_TRANSFORMS_UNION] + ) -> list[N2V_TRANSFORMS_UNION]: + """ + Validate N2VManipulate transform position in the transform list. + + Parameters + ---------- + transforms : list of transforms compatible with N2V + Transforms. + + Returns + ------- + list of transforms + Validated transforms. + + Raises + ------ + ValueError + If multiple instances of N2VManipulate are found or if it is not the last + transform. + """ + transform_list = [t.name for t in transforms] + + if SupportedTransform.N2V_MANIPULATE in transform_list: + # multiple N2V_MANIPULATE + if transform_list.count(SupportedTransform.N2V_MANIPULATE.value) > 1: + raise ValueError( + f"Multiple instances of " + f"{SupportedTransform.N2V_MANIPULATE} transforms " + f"are not allowed." + ) + + # N2V_MANIPULATE not the last transform + elif transform_list[-1] != SupportedTransform.N2V_MANIPULATE: + raise ValueError( + f"{SupportedTransform.N2V_MANIPULATE} transform " + f"should be the last transform." + ) + + else: + raise ValueError( + f"{SupportedTransform.N2V_MANIPULATE} transform " + f"is required for N2V training." + ) + + return transforms + + def set_n2v2(self, use_n2v2: bool) -> None: + """ + Set the N2V transform to the N2V2 version. + + Parameters + ---------- + use_n2v2 : bool + Whether to use N2V2. + + Raises + ------ + ValueError + If the N2V pixel manipulate transform is not found in the transforms. + """ + if use_n2v2: + self.set_masking_strategy("median") + else: + self.set_masking_strategy("uniform") + + def set_masking_strategy(self, strategy: Literal["uniform", "median"]) -> None: + """ + Set masking strategy. + + Parameters + ---------- + strategy : "uniform" or "median" + Strategy to use for N2V2. + + Raises + ------ + ValueError + If the N2V pixel manipulate transform is not found in the transforms. + """ + found_n2v = False + + for transform in self.transforms: + if transform.name == SupportedTransform.N2V_MANIPULATE.value: + transform.strategy = strategy + found_n2v = True + + if not found_n2v: + transforms = [t.name for t in self.transforms] + raise ValueError( + f"N2V_Manipulate transform not found in the transforms list " + f"({transforms})." + ) + + def get_masking_strategy(self) -> Literal["uniform", "median"]: + """ + Get the masking strategy. + + Returns + ------- + "uniform" or "median" + Strategy used for N2V2.
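The position rules enforced by `validate_transforms` above can be summarised over plain transform names. The sketch below uses the literal string "N2VManipulate" in place of `SupportedTransform.N2V_MANIPULATE.value`:

```python
def check_n2v_transforms(names: list[str]) -> None:
    # the same three rules as N2VDataConfig.validate_transforms
    count = names.count("N2VManipulate")
    if count == 0:
        raise ValueError("N2VManipulate is required for N2V training.")
    if count > 1:
        raise ValueError("Multiple N2VManipulate transforms are not allowed.")
    if names[-1] != "N2VManipulate":
        raise ValueError("N2VManipulate should be the last transform.")

check_n2v_transforms(["XYFlip", "XYRandomRotate90", "N2VManipulate"])  # passes
```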
+ """ + for transform in self.transforms: + if transform.name == SupportedTransform.N2V_MANIPULATE.value: + return transform.strategy + + raise ValueError( + f"{SupportedTransform.N2V_MANIPULATE} transform " + f"is required for N2V training." + ) + + def set_structN2V_mask( + self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int + ) -> None: + """ + Set structN2V mask parameters. + + Setting `mask_axis` to `none` will disable structN2V. + + Parameters + ---------- + mask_axis : Literal["horizontal", "vertical", "none"] + Axis along which to apply the mask. `none` will disable structN2V. + mask_span : int + Total span of the mask in pixels. + + Raises + ------ + ValueError + If the N2V pixel manipulate transform is not found in the transforms. + """ + found_n2v = False + + for transform in self.transforms: + if transform.name == SupportedTransform.N2V_MANIPULATE.value: + transform.struct_mask_axis = mask_axis + transform.struct_mask_span = mask_span + found_n2v = True + + if not found_n2v: + transforms = [t.name for t in self.transforms] + raise ValueError( + f"N2V pixel manipulate transform not found in the transforms " + f"({transforms})." + ) + + def is_using_struct_n2v(self) -> bool: + """ + Check if structN2V is enabled. + + Returns + ------- + bool + Whether structN2V is enabled or not. + """ + for transform in self.transforms: + if transform.name == SupportedTransform.N2V_MANIPULATE.value: + return transform.struct_mask_axis != "none" + + raise ValueError( + f"N2V pixel manipulate transform not found in the transforms " + f"({self.transforms})." + ) diff --git a/src/careamics/config/fcn_algorithm_model.py b/src/careamics/config/fcn_algorithm_model.py deleted file mode 100644 index 4755878cc..000000000 --- a/src/careamics/config/fcn_algorithm_model.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Module containing `FCNAlgorithmConfig` class.""" - -from pprint import pformat -from typing import Literal, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Self - -from careamics.config.architectures import CustomModel, UNetModel -from careamics.config.optimizer_models import LrSchedulerModel, OptimizerModel - - -class FCNAlgorithmConfig(BaseModel): - """Algorithm configuration. - - This Pydantic model validates the parameters governing the components of the - training algorithm: which algorithm, loss function, model architecture, optimizer, - and learning rate scheduler to use. - - Currently, we only support N2V, CARE, N2N and custom models. The `n2v` algorithm is - only compatible with `n2v` loss and `UNet` architecture. The `custom` algorithm - allows you to register your own architecture and select it using its name as - `name` in the custom pydantic model. - - Attributes - ---------- - algorithm : {"n2v", "care", "n2n", "custom"} - Algorithm to use. - loss : {"n2v", "mae", "mse"} - Loss function to use. - model : UNetModel or CustomModel - Model architecture to use. - optimizer : OptimizerModel, optional - Optimizer to use. - lr_scheduler : LrSchedulerModel, optional - Learning rate scheduler to use. - - Raises - ------ - ValueError - Algorithm parameter type validation errors. - ValueError - If the algorithm, loss and model are not compatible. - - Examples - -------- - Minimum example: - >>> from careamics.config import FCNAlgorithmConfig - >>> config_dict = { - ... "algorithm": "n2v", - ... "loss": "n2v", - ... "model": { - ... "architecture": "UNet", - ... } - ... 
} - >>> config = FCNAlgorithmConfig(**config_dict) - """ - - # Pydantic class configuration - model_config = ConfigDict( - protected_namespaces=(), # allows to use model_* as a field name - validate_assignment=True, - extra="allow", - ) - - # Mandatory fields - algorithm: Literal["n2v", "care", "n2n", "custom"] - """Name of the algorithm, as defined in SupportedAlgorithm. Use `custom` for custom - model architecture.""" - - loss: Literal["n2v", "mae", "mse"] - """Loss function to use, as defined in SupportedLoss.""" - - model: Union[UNetModel, CustomModel] = Field(discriminator="architecture") - """Model architecture to use, along with its parameters. Compatible architectures - are defined in SupportedArchitecture, and their Pydantic models in - `careamics.config.architectures`.""" - # TODO supported architectures are now all the architectures but does not warn users - # of the compatibility with the algorithm - - # Optional fields - optimizer: OptimizerModel = OptimizerModel() - """Optimizer to use, defined in SupportedOptimizer.""" - - lr_scheduler: LrSchedulerModel = LrSchedulerModel() - """Learning rate scheduler to use, defined in SupportedLrScheduler.""" - - @model_validator(mode="after") - def algorithm_cross_validation(self: Self) -> Self: - """Validate the algorithm model based on `algorithm`. - - N2V: - - loss must be n2v - - model must be a `UNetModel` - - Returns - ------- - Self - The validated model. - """ - # N2V - if self.algorithm == "n2v": - # n2v is only compatible with the n2v loss - if self.loss != "n2v": - raise ValueError( - f"Algorithm {self.algorithm} only supports loss `n2v`." - ) - - # n2v is only compatible with the UNet model - if not isinstance(self.model, UNetModel): - raise ValueError( - f"Model for algorithm {self.algorithm} must be a `UNetModel`." - ) - - # n2v requires the number of input and output channels to be the same - if self.model.in_channels != self.model.num_classes: - raise ValueError( - "N2V requires the same number of input and output channels. Make " - "sure that `in_channels` and `num_classes` are the same." - ) - - if self.algorithm == "care" or self.algorithm == "n2n": - if self.loss == "n2v": - raise ValueError("Supervised algorithms do not support loss `n2v`.") - - if (self.algorithm == "custom") != (self.model.architecture == "custom"): - raise ValueError( - "Algorithm and model architecture must be both `custom` or not." - ) - - return self - - def __str__(self) -> str: - """Pretty string representing the configuration. - - Returns - ------- - str - Pretty string. - """ - return pformat(self.model_dump()) - - @classmethod - def get_compatible_algorithms(cls) -> list[str]: - """Get the list of compatible algorithms. - - Returns - ------- - list of str - List of compatible algorithms. - """ - return ["n2v", "care", "n2n"] diff --git a/src/careamics/config/n2n_configuration.py b/src/careamics/config/n2n_configuration.py new file mode 100644 index 000000000..fb31f5ea3 --- /dev/null +++ b/src/careamics/config/n2n_configuration.py @@ -0,0 +1,101 @@ +"""N2N configuration.""" + +from bioimageio.spec.generic.v0_3 import CiteEntry + +from careamics.config.algorithms import N2NAlgorithm +from careamics.config.configuration import Configuration +from careamics.config.data import DataConfig + +N2N = "Noise2Noise" + +N2N_DESCRIPTION = ( + "Noise2Noise is a deep-learning-based algorithm that uses a U-Net " + "architecture to restore images. Noise2Noise is a self-supervised " + "algorithm that requires only noisy images to train the network. 
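As a toy illustration of the Noise2Noise principle stated in this description (not CAREamics code): fitting against an independently noisy view of the same signal with an L2 loss converges to the clean signal, because the minimiser of the expected squared error is the mean.

```python
import torch

torch.manual_seed(0)
clean = torch.full((4096,), 3.0)
noisy_target = clean + torch.randn_like(clean)  # independently noisy "target" view

estimate = torch.zeros(1, requires_grad=True)   # one-parameter "network"
optimizer = torch.optim.SGD([estimate], lr=0.5)
for _ in range(200):
    optimizer.zero_grad()
    loss = torch.mean((estimate - noisy_target) ** 2)
    loss.backward()
    optimizer.step()

print(estimate.item())  # ~3.0: the clean value, never observed directly
```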
" + "The algorithm learns to predict the clean image from the noisy " + "image. Noise2Noise is particularly useful when clean images are " + "not available for training." +) + +N2N_REF = CiteEntry( + text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., " + 'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration ' + 'without clean data". arXiv preprint arXiv:1803.04189.', + doi="10.48550/arXiv.1803.04189", +) + + +class N2NConfiguration(Configuration): + """Noise2Noise configuration.""" + + algorithm_config: N2NAlgorithm + """Algorithm configuration.""" + + data_config: DataConfig + """Data configuration.""" + + def get_algorithm_friendly_name(self) -> str: + """ + Get the algorithm friendly name. + + Returns + ------- + str + Friendly name of the algorithm. + """ + return N2N + + def get_algorithm_keywords(self) -> list[str]: + """ + Get algorithm keywords. + + Returns + ------- + list[str] + List of keywords. + """ + return [ + "restoration", + "UNet", + "3D" if "Z" in self.data_config.axes else "2D", + "CAREamics", + "pytorch", + N2N, + ] + + def get_algorithm_references(self) -> str: + """ + Get the algorithm references. + + This is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Algorithm references. + """ + return N2N_REF.text + " doi: " + N2N_REF.doi + + def get_algorithm_citations(self) -> list[CiteEntry]: + """ + Return a list of citation entries of the current algorithm. + + This is used to generate the model description for the BioImage Model Zoo. + + Returns + ------- + List[CiteEntry] + List of citation entries. + """ + return [N2N_REF] + + def get_algorithm_description(self) -> str: + """ + Get the algorithm description. + + Returns + ------- + str + Algorithm description. + """ + return N2N_DESCRIPTION diff --git a/src/careamics/config/n2v_configuration.py b/src/careamics/config/n2v_configuration.py new file mode 100644 index 000000000..9a6586115 --- /dev/null +++ b/src/careamics/config/n2v_configuration.py @@ -0,0 +1,266 @@ +"""N2V configuration.""" + +from bioimageio.spec.generic.v0_3 import CiteEntry +from pydantic import model_validator +from typing_extensions import Self + +from careamics.config.algorithms import N2VAlgorithm +from careamics.config.configuration import Configuration +from careamics.config.data.n2v_data_model import N2VDataConfig +from careamics.config.support import SupportedPixelManipulation + +N2V = "Noise2Void" +N2V2 = "N2V2" +STRUCT_N2V = "StructN2V" +STRUCT_N2V2 = "StructN2V2" + +N2V_REF = CiteEntry( + text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' + 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' + "conference on computer vision and pattern recognition (pp. 2129-2137).", + doi="10.1109/cvpr.2019.00223", +) + +N2V2_REF = CiteEntry( + text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " + '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' + 'sampling strategies and a tweaked network architecture". In European ' + "Conference on Computer Vision (pp. 503-518).", + doi="10.1007/978-3-031-25069-9_33", +) + +STRUCTN2V_REF = CiteEntry( + text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." + '"Removing structured noise with self-supervised blind-spot ' + 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' + "Imaging (ISBI) (pp. 
159-163).", + doi="10.1109/isbi45749.2020.9098336", +) + +N2V_DESCRIPTION = ( + "Noise2Void is a UNet-based self-supervised algorithm that " + "uses blind-spot training to denoise images. In short, in every " + "patches during training, random pixels are selected and their " + "value replaced by a neighboring pixel value. The network is then " + "trained to predict the original pixel value. The algorithm " + "relies on the continuity of the signal (neighboring pixels have " + "similar values) and the pixel-wise independence of the noise " + "(the noise in a pixel is not correlated with the noise in " + "neighboring pixels)." +) + +N2V2_DESCRIPTION = ( + "N2V2 is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nN2V2 introduces blur-pool layers and removed skip " + "connections in the UNet architecture to remove checkboard " + "artefacts, a common artefacts ocurring in Noise2Void." +) + +STR_N2V_DESCRIPTION = ( + "StructN2V is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nStructN2V uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." +) + +STR_N2V2_DESCRIPTION = ( + "StructN2V2 is a a variant of Noise2Void that uses both " + "structN2V and N2V2. " + + N2V_DESCRIPTION + + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." + "\nN2V2 introduces blur-pool layers and removed skip connections in " + "the UNet architecture to remove checkboard artefacts, a common " + "artefacts ocurring in Noise2Void." +) + + +class N2VConfiguration(Configuration): + """N2V configuration.""" + + algorithm_config: N2VAlgorithm + + data_config: N2VDataConfig + + @model_validator(mode="after") + def validate_n2v2(self) -> Self: + """Validate that the N2V2 strategy and models are set correctly. + + Returns + ------- + Self + The validateed configuration. + + + Raises + ------ + ValueError + If N2V2 is used with the wrong pixel manipulation strategy. + """ + if self.algorithm_config.model.n2v2: + if ( + self.data_config.get_masking_strategy() + != SupportedPixelManipulation.MEDIAN.value + ): + raise ValueError( + f"N2V2 can only be used with the " + f"{SupportedPixelManipulation.MEDIAN} pixel manipulation strategy" + f". Change the N2VManipulate transform strategy." + ) + else: + if ( + self.data_config.get_masking_strategy() + != SupportedPixelManipulation.UNIFORM.value + ): + raise ValueError( + f"N2V can only be used with the " + f"{SupportedPixelManipulation.UNIFORM} pixel manipulation strategy" + f". Change the N2VManipulate transform strategy." + ) + return self + + def set_n2v2(self, use_n2v2: bool) -> None: + """ + Set the configuration to use N2V2 or the vanilla Noise2Void. + + Parameters + ---------- + use_n2v2 : bool + Whether to use N2V2. + """ + self.data_config.set_n2v2(use_n2v2) + self.algorithm_config.model.n2v2 = use_n2v2 + + def get_algorithm_friendly_name(self) -> str: + """ + Get the friendly name of the algorithm. + + Returns + ------- + str + Friendly name. 
+ """ + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.is_using_struct_n2v() + + if use_n2v2 and use_structN2V: + return STRUCT_N2V2 + elif use_n2v2: + return N2V2 + elif use_structN2V: + return STRUCT_N2V + else: + return N2V + + def get_algorithm_keywords(self) -> list[str]: + """ + Get algorithm keywords. + + Returns + ------- + list[str] + List of keywords. + """ + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.is_using_struct_n2v() + + keywords = [ + "denoising", + "restoration", + "UNet", + "3D" if "Z" in self.data_config.axes else "2D", + "CAREamics", + "pytorch", + N2V, + ] + + if use_n2v2: + keywords.append(N2V2) + if use_structN2V: + keywords.append(STRUCT_N2V) + + return keywords + + def get_algorithm_references(self) -> str: + """ + Get the algorithm references. + + This is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Algorithm references. + """ + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.is_using_struct_n2v() + + references = [ + N2V_REF.text + " doi: " + N2V_REF.doi, + N2V2_REF.text + " doi: " + N2V2_REF.doi, + STRUCTN2V_REF.text + " doi: " + STRUCTN2V_REF.doi, + ] + + # return the (struct)N2V(2) references + if use_n2v2 and use_structN2V: + return "\n".join(references) + elif use_n2v2: + references.pop(-1) + return "\n".join(references) + elif use_structN2V: + references.pop(-2) + return "\n".join(references) + else: + return references[0] + + def get_algorithm_citations(self) -> list[CiteEntry]: + """ + Return a list of citation entries of the current algorithm. + + This is used to generate the model description for the BioImage Model Zoo. + + Returns + ------- + List[CiteEntry] + List of citation entries. + """ + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.is_using_struct_n2v() + + references = [N2V_REF] + + if use_n2v2: + references.append(N2V2_REF) + + if use_structN2V: + references.append(STRUCTN2V_REF) + + return references + + def get_algorithm_description(self) -> str: + """ + Return a description of the algorithm. + + This method is used to generate the README of the BioImage Model Zoo export. + + Returns + ------- + str + Description of the algorithm. 
+ """ + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.is_using_struct_n2v() + + if use_n2v2 and use_structN2V: + return STR_N2V2_DESCRIPTION + elif use_n2v2: + return N2V2_DESCRIPTION + elif use_structN2V: + return STR_N2V_DESCRIPTION + else: + return N2V_DESCRIPTION diff --git a/src/careamics/config/references/__init__.py b/src/careamics/config/references/__init__.py deleted file mode 100644 index 4fc141138..000000000 --- a/src/careamics/config/references/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Module containing references to the algorithm used in CAREamics.""" - -__all__ = [ - "CARE", - "CUSTOM", - "N2N", - "N2V", - "N2V2", - "STRUCT_N2V", - "STRUCT_N2V2", - "CAREDescription", - "CARERef", - "N2NDescription", - "N2NRef", - "N2V2Description", - "N2V2Ref", - "N2VDescription", - "N2VRef", - "StructN2V2Description", - "StructN2VDescription", - "StructN2VRef", -] - -from .algorithm_descriptions import ( - CARE, - CUSTOM, - N2N, - N2V, - N2V2, - STRUCT_N2V, - STRUCT_N2V2, - CAREDescription, - N2NDescription, - N2V2Description, - N2VDescription, - StructN2V2Description, - StructN2VDescription, -) -from .references import ( - CARERef, - N2NRef, - N2V2Ref, - N2VRef, - StructN2VRef, -) diff --git a/src/careamics/config/references/algorithm_descriptions.py b/src/careamics/config/references/algorithm_descriptions.py deleted file mode 100644 index e60c9164d..000000000 --- a/src/careamics/config/references/algorithm_descriptions.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Descriptions of the algorithms used in CAREmics.""" - -from pydantic import BaseModel - -CUSTOM = "Custom" -N2V = "Noise2Void" -N2V2 = "N2V2" -STRUCT_N2V = "StructN2V" -STRUCT_N2V2 = "StructN2V2" -N2N = "Noise2Noise" -CARE = "CARE" - - -N2V_DESCRIPTION = ( - "Noise2Void is a UNet-based self-supervised algorithm that " - "uses blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels)." -) - - -class AlgorithmDescription(BaseModel): - """Description of an algorithm. - - Attributes - ---------- - description : str - Description of the algorithm. - """ - - description: str - - -class N2VDescription(AlgorithmDescription): - """Description of Noise2Void. - - Attributes - ---------- - description : str - Description of Noise2Void. - """ - - description: str = N2V_DESCRIPTION - - -class N2V2Description(AlgorithmDescription): - """Description of N2V2. - - Attributes - ---------- - description : str - Description of N2V2. - """ - - description: str = ( - "N2V2 is a variant of Noise2Void. " - + N2V_DESCRIPTION - + "\nN2V2 introduces blur-pool layers and removed skip " - "connections in the UNet architecture to remove checkboard " - "artefacts, a common artefacts ocurring in Noise2Void." - ) - - -class StructN2VDescription(AlgorithmDescription): - """Description of StructN2V. - - Attributes - ---------- - description : str - Description of StructN2V. - """ - - description: str = ( - "StructN2V is a variant of Noise2Void. " - + N2V_DESCRIPTION - + "\nStructN2V uses a linear mask (horizontal or vertical) to replace " - "the pixel values of neighbors of the masked pixels by a random " - "value. 
Such masking allows removing 1D structured noise from the " - "the images, the main failure case of the original N2V." - ) - - -class StructN2V2Description(AlgorithmDescription): - """Description of StructN2V2. - - Attributes - ---------- - description : str - Description of StructN2V2. - """ - - description: str = ( - "StructN2V2 is a a variant of Noise2Void that uses both " - "structN2V and N2V2. " - + N2V_DESCRIPTION - + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace " - "the pixel values of neighbors of the masked pixels by a random " - "value. Such masking allows removing 1D structured noise from the " - "the images, the main failure case of the original N2V." - "\nN2V2 introduces blur-pool layers and removed skip connections in " - "the UNet architecture to remove checkboard artefacts, a common " - "artefacts ocurring in Noise2Void." - ) - - -class N2NDescription(AlgorithmDescription): - """Description of Noise2Noise. - - Attributes - ---------- - description : str - Description of Noise2Noise. - """ - - description: str = "Noise2Noise" # TODO - - -class CAREDescription(AlgorithmDescription): - """Description of CARE. - - Attributes - ---------- - description : str - Description of CARE. - """ - - description: str = "CARE" # TODO diff --git a/src/careamics/config/references/references.py b/src/careamics/config/references/references.py deleted file mode 100644 index 1d13cfad4..000000000 --- a/src/careamics/config/references/references.py +++ /dev/null @@ -1,39 +0,0 @@ -"""References for the CAREamics algorithms.""" - -from bioimageio.spec.generic.v0_3 import CiteEntry - -N2VRef = CiteEntry( - text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' - 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' - "conference on computer vision and pattern recognition (pp. 2129-2137).", - doi="10.1109/cvpr.2019.00223", -) - -N2V2Ref = CiteEntry( - text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " - '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' - 'sampling strategies and a tweaked network architecture". In European ' - "Conference on Computer Vision (pp. 503-518).", - doi="10.1007/978-3-031-25069-9_33", -) - -StructN2VRef = CiteEntry( - text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." - '"Removing structured noise with self-supervised blind-spot ' - 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' - "Imaging (ISBI) (pp. 159-163).", - doi="10.1109/isbi45749.2020.9098336", -) - -N2NRef = CiteEntry( - text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., " - 'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration ' - 'without clean data". arXiv preprint arXiv:1803.04189.', - doi="10.48550/arXiv.1803.04189", -) - -CARERef = CiteEntry( - text='Weigert, Martin, et al. "Content-aware image restoration: pushing the ' - 'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.', - doi="10.1038/s41592-018-0216-7", -) diff --git a/src/careamics/config/support/supported_algorithms.py b/src/careamics/config/support/supported_algorithms.py index db6ce326e..15b30274b 100644 --- a/src/careamics/config/support/supported_algorithms.py +++ b/src/careamics/config/support/supported_algorithms.py @@ -6,7 +6,11 @@ class SupportedAlgorithm(str, BaseEnum): - """Algorithms available in CAREamics.""" + """Algorithms available in CAREamics. 
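Since `SupportedAlgorithm` mixes `str` into the enum (via `BaseEnum`), its members compare equal to their literal values, which is what allows checks like `self.algorithm_config.algorithm == SupportedAlgorithm.N2V` to also work against plain strings. A standard-library reproduction:

```python
from enum import Enum

class AlgoSketch(str, Enum):  # stand-in for SupportedAlgorithm(str, BaseEnum)
    N2V = "n2v"
    CARE = "care"

assert AlgoSketch.N2V == "n2v"          # str mixin: member equals its value
assert "care" == AlgoSketch.CARE.value
```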
+ + These definitions are the same as the keyword `name` of the algorithm + configurations. + """ N2V = "n2v" """Noise2Void algorithm, a self-supervised approach based on blind denoising.""" @@ -25,9 +29,6 @@ class SupportedAlgorithm(str, BaseEnum): DENOISPLIT = "denoisplit" """An image splitting and denoising approach based on ladder VAE architectures.""" - CUSTOM = "custom" - """Custom algorithm, used for cases where a custom architecture is provided.""" - # PN2V = "pn2v" # HDN = "hdn" # SEG = "segmentation" diff --git a/src/careamics/config/support/supported_architectures.py b/src/careamics/config/support/supported_architectures.py index 52312dab4..83ec535e6 100644 --- a/src/careamics/config/support/supported_architectures.py +++ b/src/careamics/config/support/supported_architectures.py @@ -11,7 +11,3 @@ class SupportedArchitecture(str, BaseEnum): LVAE = "LVAE" """Ladder Variational Autoencoder used for muSplit and denoiSplit.""" - - CUSTOM = "custom" - """Keyword used for custom architectures provided by users and only compatible - with `FCNAlgorithmConfig` configuration.""" diff --git a/src/careamics/config/transformations/__init__.py b/src/careamics/config/transformations/__init__.py index 2b6f7cc1e..14b5abedb 100644 --- a/src/careamics/config/transformations/__init__.py +++ b/src/careamics/config/transformations/__init__.py @@ -1,7 +1,9 @@ """CAREamics transformation Pydantic models.""" __all__ = [ - "TRANSFORMS_UNION", + "N2V_TRANSFORMS_UNION", + "NORM_AND_SPATIAL_UNION", + "SPATIAL_TRANSFORMS_UNION", "N2VManipulateModel", "NormalizeModel", "TransformModel", @@ -13,6 +15,10 @@ from .n2v_manipulate_model import N2VManipulateModel from .normalize_model import NormalizeModel from .transform_model import TransformModel -from .transform_union import TRANSFORMS_UNION +from .transform_unions import ( + N2V_TRANSFORMS_UNION, + NORM_AND_SPATIAL_UNION, + SPATIAL_TRANSFORMS_UNION, +) from .xy_flip_model import XYFlipModel from .xy_random_rotate90_model import XYRandomRotate90Model diff --git a/src/careamics/config/transformations/transform_union.py b/src/careamics/config/transformations/transform_union.py deleted file mode 100644 index 810d3b75e..000000000 --- a/src/careamics/config/transformations/transform_union.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Type used to represent all transformations users can create.""" - -from typing import Annotated, Union - -from pydantic import Discriminator - -from .n2v_manipulate_model import N2VManipulateModel -from .xy_flip_model import XYFlipModel -from .xy_random_rotate90_model import XYRandomRotate90Model - -TRANSFORMS_UNION = Annotated[ - Union[ - XYFlipModel, - XYRandomRotate90Model, - N2VManipulateModel, - ], - Discriminator("name"), # used to tell the different transform models apart -] -"""Available transforms in CAREamics.""" diff --git a/src/careamics/config/transformations/transform_unions.py b/src/careamics/config/transformations/transform_unions.py new file mode 100644 index 000000000..2d9198de8 --- /dev/null +++ b/src/careamics/config/transformations/transform_unions.py @@ -0,0 +1,42 @@ +"""Type used to represent all transformations users can create.""" + +from typing import Annotated, Union + +from pydantic import Discriminator + +from .n2v_manipulate_model import N2VManipulateModel +from .normalize_model import NormalizeModel +from .xy_flip_model import XYFlipModel +from .xy_random_rotate90_model import XYRandomRotate90Model + +NORM_AND_SPATIAL_UNION = Annotated[ + Union[ + NormalizeModel, + XYFlipModel, + XYRandomRotate90Model, + 
N2VManipulateModel, + ], + Discriminator("name"), # used to tell the different transform models apart +] +"""All transforms including normalization.""" + + +SPATIAL_TRANSFORMS_UNION = Annotated[ + Union[ + XYFlipModel, + XYRandomRotate90Model, + ], + Discriminator("name"), # used to tell the different transform models apart +] +"""Available spatial transforms in CAREamics.""" + + +N2V_TRANSFORMS_UNION = Annotated[ + Union[ + XYFlipModel, + XYRandomRotate90Model, + N2VManipulateModel, + ], + Discriminator("name"), # used to tell the different transform models apart +] +"""Available N2V-compatible transforms in CAREamics.""" diff --git a/src/careamics/config/validators/__init__.py b/src/careamics/config/validators/__init__.py index 53ddbf8db..06a0f103f 100644 --- a/src/careamics/config/validators/__init__.py +++ b/src/careamics/config/validators/__init__.py @@ -1,5 +1,16 @@ """Validator utilities.""" -__all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"] +__all__ = [ + "check_axes_validity", + "model_matching_in_out_channels", + "model_without_final_activation", + "model_without_n2v2", + "patch_size_ge_than_8_power_of_2", +] +from .model_validators import ( + model_matching_in_out_channels, + model_without_final_activation, + model_without_n2v2, +) from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 diff --git a/src/careamics/config/validators/model_validators.py b/src/careamics/config/validators/model_validators.py new file mode 100644 index 000000000..66f627855 --- /dev/null +++ b/src/careamics/config/validators/model_validators.py @@ -0,0 +1,84 @@ +"""Architecture model validators.""" + +from careamics.config.architectures import UNetModel + + +def model_without_n2v2(model: UNetModel) -> UNetModel: + """Validate that the UNet model does not have the `n2v2` attribute enabled. + + Parameters + ---------- + model : UNetModel + Model to validate. + + Returns + ------- + UNetModel + The validated model. + + Raises + ------ + ValueError + If the model has the `n2v2` attribute set to `True`. + """ + if model.n2v2: + raise ValueError( + "The algorithm does not support the `n2v2` attribute in the model. " + "Set it to `False`." + ) + + return model + + +def model_without_final_activation(model: UNetModel) -> UNetModel: + """Validate that the UNet model does not have a `final_activation` set. + + Parameters
 ---------- + model : UNetModel + Model to validate. + + Returns + ------- + UNetModel + The validated model. + + Raises + ------ + ValueError + If the model has the final_activation attribute set. + """ + if model.final_activation != "None": + raise ValueError( + "The algorithm does not support a `final_activation` in the model. " + 'Set it to `"None"`.' + ) + + return model + + +def model_matching_in_out_channels(model: UNetModel) -> UNetModel: + """Validate that the UNet model has the same number of input and output channels. + + Parameters + ---------- + model : UNetModel + Model to validate. + + Returns + ------- + UNetModel + Validated model. + + Raises + ------ + ValueError + If the model has a different number of input and output channels. + """ + if model.num_classes != model.in_channels: + raise ValueError( + "The algorithm requires the same number of input and output channels. " + "Make sure that `in_channels` and `num_classes` are equal."
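The `Discriminator("name")` annotations above let pydantic pick the right transform model from a plain dict by looking only at its `name` key. A self-contained sketch with hypothetical stand-ins for the transform models:

```python
from typing import Annotated, Literal, Union
from pydantic import BaseModel, Discriminator, TypeAdapter

class FlipSketch(BaseModel):
    name: Literal["XYFlip"] = "XYFlip"

class ManipulateSketch(BaseModel):
    name: Literal["N2VManipulate"] = "N2VManipulate"

UnionSketch = Annotated[Union[FlipSketch, ManipulateSketch], Discriminator("name")]

parsed = TypeAdapter(UnionSketch).validate_python({"name": "XYFlip"})
assert isinstance(parsed, FlipSketch)  # resolved via the "name" discriminator
```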
+ ) + + return model diff --git a/src/careamics/dataset/dataset_utils/iterate_over_files.py b/src/careamics/dataset/dataset_utils/iterate_over_files.py index a6497555a..c067cd5f0 100644 --- a/src/careamics/dataset/dataset_utils/iterate_over_files.py +++ b/src/careamics/dataset/dataset_utils/iterate_over_files.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from torch.utils.data import get_worker_info -from careamics.config import DataConfig, InferenceConfig +from careamics.config import GeneralDataConfig, InferenceConfig from careamics.file_io.read import read_tiff from careamics.utils.logging import get_logger @@ -19,7 +19,7 @@ def iterate_over_files( - data_config: Union[DataConfig, InferenceConfig], + data_config: Union[GeneralDataConfig, InferenceConfig], data_files: list[Path], target_files: Optional[list[Path]] = None, read_source_func: Callable = read_tiff, diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 0918513a6..4703cf8bd 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -9,13 +9,9 @@ import numpy as np from torch.utils.data import Dataset -from careamics.file_io.read import read_tiff -from careamics.transforms import Compose - -from ..config import DataConfig -from ..config.transformations import NormalizeModel -from ..utils.logging import get_logger -from .patching.patching import ( +from careamics.config import GeneralDataConfig, N2VDataConfig +from careamics.config.transformations import NormalizeModel +from careamics.dataset.patching.patching import ( PatchedOutput, Stats, prepare_patches_supervised, @@ -23,6 +19,9 @@ prepare_patches_unsupervised, prepare_patches_unsupervised_array, ) +from careamics.file_io.read import read_tiff +from careamics.transforms import Compose +from careamics.utils.logging import get_logger logger = get_logger(__name__) @@ -47,7 +46,7 @@ class InMemoryDataset(Dataset): def __init__( self, - data_config: DataConfig, + data_config: GeneralDataConfig, inputs: Union[np.ndarray, list[Path]], input_target: Optional[Union[np.ndarray, list[Path]]] = None, read_source_func: Callable = read_tiff, @@ -58,7 +57,7 @@ def __init__( Parameters ---------- - data_config : DataConfig + data_config : GeneralDataConfig Data configuration. inputs : numpy.ndarray or list[pathlib.Path] Input data. @@ -124,7 +123,7 @@ def __init__( target_stds=self.target_stats.stds, ) ] - + self.data_config.transforms, + + list(self.data_config.transforms), ) def _prepare_patches(self, supervised: bool) -> PatchedOutput: @@ -219,12 +218,12 @@ def __getitem__(self, index: int) -> tuple[np.ndarray, ...]: return self.patch_transform(patch=patch, target=target) - elif self.data_config.has_n2v_manipulate(): # TODO not compatible with HDN + elif isinstance(self.data_config, N2VDataConfig): return self.patch_transform(patch=patch) else: raise ValueError( "Something went wrong! No target provided (not supervised training) " - "and no N2V manipulation (no N2V training)." + "while the algorithm is not Noise2Void." 
) def get_data_statistics(self) -> tuple[list[float], list[float]]: diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 680355e36..13b949eeb 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -10,7 +10,7 @@ import numpy as np from torch.utils.data import IterableDataset -from careamics.config import DataConfig +from careamics.config import GeneralDataConfig from careamics.config.transformations import NormalizeModel from careamics.file_io.read import read_tiff from careamics.transforms import Compose @@ -49,7 +49,7 @@ class PathIterableDataset(IterableDataset): def __init__( self, - data_config: DataConfig, + data_config: GeneralDataConfig, src_files: list[Path], target_files: Optional[list[Path]] = None, read_source_func: Callable = read_tiff, @@ -58,7 +58,7 @@ def __init__( Parameters ---------- - data_config : DataConfig + data_config : GeneralDataConfig Data configuration. src_files : list[Path] List of data files. @@ -115,7 +115,7 @@ def __init__( target_stds=self.target_stats.stds, ) ] - + data_config.transforms + + list(data_config.transforms) ) def _calculate_mean_and_std(self) -> tuple[Stats, Stats]: diff --git a/src/careamics/lightning/lightning_module.py b/src/careamics/lightning/lightning_module.py index d9b65ff8d..81492831e 100644 --- a/src/careamics/lightning/lightning_module.py +++ b/src/careamics/lightning/lightning_module.py @@ -6,7 +6,7 @@ import pytorch_lightning as L from torch import Tensor, nn -from careamics.config import FCNAlgorithmConfig, VAEAlgorithmConfig +from careamics.config import UNetBasedAlgorithm, VAEBasedAlgorithm from careamics.config.support import ( SupportedAlgorithm, SupportedArchitecture, @@ -34,6 +34,7 @@ NoiseModel = Union[GaussianMixtureNoiseModel, MultiChannelNoiseModel] +# TODO rename to UNetModule class FCNModule(L.LightningModule): """ CAREamics Lightning module. @@ -60,7 +61,7 @@ class FCNModule(L.LightningModule): Learning rate scheduler name. """ - def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None: + def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None: """Lightning module for CAREamics. This class encapsulates a PyTorch model along with the training, validation, @@ -74,7 +75,9 @@ def __init__(self, algorithm_config: Union[FCNAlgorithmConfig, dict]) -> None: super().__init__() # if loading from a checkpoint, AlgorithmModel needs to be instantiated if isinstance(algorithm_config, dict): - algorithm_config = FCNAlgorithmConfig(**algorithm_config) + algorithm_config = UNetBasedAlgorithm( + **algorithm_config + ) # TODO this needs to be updated using the algorithm-specific class # create model and loss function self.model: nn.Module = model_factory(algorithm_config.model) @@ -266,7 +269,7 @@ class VAEModule(L.LightningModule): Learning rate scheduler name. """ - def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None: + def __init__(self, algorithm_config: Union[VAEBasedAlgorithm, dict]) -> None: """Lightning module for CAREamics.
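A small point worth noting in the dataset hunks above: because `transforms` is now typed as a `Sequence`, the datasets wrap it in `list(...)` before concatenating it to the `[NormalizeModel(...)]` prefix, since `list + Sequence` is not defined. In miniature:

```python
from collections.abc import Sequence

def build_pipeline(transforms: Sequence[str]) -> list[str]:
    # list(...) makes the concatenation legal for any Sequence (tuple, list, ...)
    return ["Normalize"] + list(transforms)

assert build_pipeline(("XYFlip", "N2VManipulate")) == [
    "Normalize", "XYFlip", "N2VManipulate",
]
```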
This class encapsulates a PyTorch model along with the training, validation, @@ -280,7 +283,7 @@ def __init__(self, algorithm_config: Union[VAEAlgorithmConfig, dict]) -> None: super().__init__() # if loading from a checkpoint, AlgorithmModel needs to be instantiated self.algorithm_config = ( - VAEAlgorithmConfig(**algorithm_config) + VAEBasedAlgorithm(**algorithm_config) if isinstance(algorithm_config, dict) else algorithm_config ) @@ -656,9 +659,10 @@ def create_careamics_module( algorithm_configuration["model"] = model_configuration # call the parent init using an AlgorithmModel instance + # TODO broken by new configurations! algorithm_str = algorithm_configuration["algorithm"] - if algorithm_str in FCNAlgorithmConfig.get_compatible_algorithms(): - return FCNModule(FCNAlgorithmConfig(**algorithm_configuration)) + if algorithm_str in UNetBasedAlgorithm.get_compatible_algorithms(): + return FCNModule(UNetBasedAlgorithm(**algorithm_configuration)) else: raise NotImplementedError( f"Model {algorithm_str} is not implemented or unknown." diff --git a/src/careamics/lightning/train_data_module.py b/src/careamics/lightning/train_data_module.py index 1663ef890..6df263fa0 100644 --- a/src/careamics/lightning/train_data_module.py +++ b/src/careamics/lightning/train_data_module.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from torch.utils.data import DataLoader, IterableDataset -from careamics.config import DataConfig +from careamics.config.data import DataConfig, GeneralDataConfig, N2VDataConfig from careamics.config.support import SupportedData from careamics.config.transformations import TransformModel from careamics.dataset.dataset_utils import ( @@ -119,7 +119,7 @@ class TrainDataModule(L.LightningDataModule): def __init__( self, - data_config: DataConfig, + data_config: GeneralDataConfig, train_data: Union[Path, str, NDArray], val_data: Optional[Union[Path, str, NDArray]] = None, train_data_target: Optional[Union[Path, str, NDArray]] = None, @@ -219,7 +219,7 @@ def __init__( ) # configuration - self.data_config: DataConfig = data_config + self.data_config: GeneralDataConfig = data_config self.data_type: str = data_config.data_type self.batch_size: int = data_config.batch_size self.use_in_memory: bool = use_in_memory @@ -502,12 +502,23 @@ def create_train_datamodule( """Create a TrainDataModule. This function is used to explicitly pass the parameters usually contained in a - `data_model` configuration to a TrainDataModule. + `GeneralDataConfig` to a TrainDataModule. Since the lightning datamodule has no access to the model, make sure that the parameters passed to the datamodule are consistent with the model's requirements and are coherent. + By default, the train DataModule will be set for Noise2Void if no target data is + provided. That means that it will add an `N2VManipulateModel` transformation to the + list of augmentations. The default augmentations are XY flip, XY rotation, and N2V + pixel manipulation. If you pass training target data, the default behaviour is to + train a supervised model. It will use the default XY flip and rotation + augmentations. + + To use a different set of transformations, you can pass a list of transforms to + `transforms`. Note that if you intend to use Noise2Void, you should add + `N2VManipulateModel` as the last transform in the list of transformations. + The data module can be used with Path, str or numpy arrays. In the case of numpy arrays, it loads and computes all the patches in memory.
For Path and str inputs, it calculates the total file size and estimates whether it can fit in @@ -518,11 +529,6 @@ def create_train_datamodule( To use array data, set `data_type` to `array` and pass a numpy array to `train_data`. - In particular, N2V requires a specific transformation (N2V manipulates), which is - not compatible with supervised training. The default transformations applied to the - training patches are defined in `careamics.config.data_model`. To use different - transformations, pass a list of transforms. See examples for more details. - By default, CAREamics only supports types defined in `careamics.config.support.SupportedData`. To read custom data types, you can set `data_type` to `custom` and provide a function that returns a numpy array from a @@ -627,12 +633,12 @@ def create_train_datamodule( transforms: >>> import numpy as np >>> from careamics.lightning import create_train_datamodule + >>> from careamics.config.transformations import XYFlipModel, N2VManipulateModel >>> from careamics.config.support import SupportedTransform >>> my_array = np.arange(256).reshape(16, 16) >>> my_transforms = [ - ... { - ... "name": SupportedTransform.XY_FLIP.value, - ... } + ... XYFlipModel(flip_y=False), + ... N2VManipulateModel() ... ] >>> data_module = create_train_datamodule( ... train_data=my_array, @@ -659,21 +665,15 @@ def create_train_datamodule( if transforms is not None: data_dict["transforms"] = transforms - # validate configuration - data_config = DataConfig(**data_dict) + # TODO not compatible with HDN, consider adding an argument for n2v/hdn + if train_target_data is None: + data_config: GeneralDataConfig = N2VDataConfig(**data_dict) + assert isinstance(data_config, N2VDataConfig) - # N2V specific checks, N2V, structN2V, and transforms - if data_config.has_n2v_manipulate(): - # there is not target, n2v2 and structN2V can be changed - if train_target_data is None: - data_config.set_N2V2(use_n2v2) - data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span) - else: - raise ValueError( - "Cannot have both supervised training (target data) and " - "N2V manipulation in the transforms. Pass a list of transforms " - "that is compatible with your supervised training."
- ) + data_config.set_n2v2(use_n2v2) + data_config.set_structN2V_mask(struct_n2v_axis, struct_n2v_span) + else: + data_config = DataConfig(**data_dict) # sanity check on the dataloader parameters if "batch_size" in dataloader_params: diff --git a/src/careamics/model_io/bioimage/_readme_factory.py b/src/careamics/model_io/bioimage/_readme_factory.py index 7db30c6fe..63d57927a 100644 --- a/src/careamics/model_io/bioimage/_readme_factory.py +++ b/src/careamics/model_io/bioimage/_readme_factory.py @@ -55,7 +55,7 @@ def readme_factory( readme.touch() # algorithm pretty name - algorithm_flavour = config.get_algorithm_flavour() + algorithm_flavour = config.get_algorithm_friendly_name() algorithm_pretty_name = algorithm_flavour + " - CAREamics" description = [f"# {algorithm_pretty_name}\n\n"] diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index 82f5082fe..b40894d6f 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -28,14 +28,14 @@ WeightsDescr, ) -from careamics.config import Configuration, DataConfig +from careamics.config import Configuration, GeneralDataConfig from ._readme_factory import readme_factory def _create_axes( array: np.ndarray, - data_config: DataConfig, + data_config: GeneralDataConfig, channel_names: Optional[list[str]] = None, is_input: bool = True, ) -> list[AxisBase]: @@ -102,7 +102,7 @@ def _create_axes( def _create_inputs_ouputs( input_array: np.ndarray, output_array: np.ndarray, - data_config: DataConfig, + data_config: GeneralDataConfig, input_path: Union[Path, str], output_path: Union[Path, str], channel_names: Optional[list[str]] = None, diff --git a/src/careamics/model_io/bmz_io.py b/src/careamics/model_io/bmz_io.py index 5e17297ab..bc92e6741 100644 --- a/src/careamics/model_io/bmz_io.py +++ b/src/careamics/model_io/bmz_io.py @@ -127,20 +127,9 @@ def export_to_bmz( Channel names, by default None. model_version : str, default="0.1.0" Model version. - - Raises - ------ - ValueError - If the model is a Custom model. """ path_to_archive = Path(path_to_archive) - # method is not compatible with Custom models - if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM: - raise ValueError( - "Exporting Custom models to BioImage Model Zoo format is not supported." - ) - if path_to_archive.suffix != ".zip": raise ValueError( f"Path to archive must point to a zip file, got {path_to_archive}." 
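The docstring above fully specifies the new default behaviour of `create_train_datamodule`; as a quick reference, here is a minimal usage sketch (not part of the diff; the array shape and parameter values are illustrative and follow the doctest in this file):

```python
# Minimal sketch of the updated create_train_datamodule defaults, based on the
# docstring and doctest in this diff; array shape and parameters are illustrative.
import numpy as np

from careamics.config.transformations import N2VManipulateModel, XYFlipModel
from careamics.lightning import create_train_datamodule

my_array = np.arange(256.0).reshape(16, 16)

# No target data: the datamodule builds an N2VDataConfig internally, so
# N2V pixel manipulation is appended to the default XY flip/rotation.
n2v_module = create_train_datamodule(
    train_data=my_array,
    data_type="array",
    patch_size=(16, 16),
    axes="YX",
    batch_size=2,
)

# Custom transforms for Noise2Void: N2VManipulateModel must come last.
custom_module = create_train_datamodule(
    train_data=my_array,
    data_type="array",
    patch_size=(16, 16),
    axes="YX",
    batch_size=2,
    transforms=[XYFlipModel(flip_y=False), N2VManipulateModel()],
)
```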
diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py index 565a0ca2b..bd140b4f7 100644 --- a/src/careamics/model_io/model_io_utils.py +++ b/src/careamics/model_io/model_io_utils.py @@ -5,7 +5,7 @@ import torch -from careamics.config import Configuration +from careamics.config import Configuration, configuration_factory from careamics.lightning.lightning_module import FCNModule, VAEModule from careamics.model_io.bmz_io import load_from_bmz from careamics.utils import check_path_exists @@ -92,4 +92,4 @@ def _load_checkpoint( f"{cfg_dict['algorithm_config']['model']['architecture']}" ) - return model, Configuration(**cfg_dict) + return model, configuration_factory(cfg_dict) diff --git a/src/careamics/models/lvae/likelihoods.py b/src/careamics/models/lvae/likelihoods.py index 51c5fbef2..a38b0dbb3 100644 --- a/src/careamics/models/lvae/likelihoods.py +++ b/src/careamics/models/lvae/likelihoods.py @@ -324,6 +324,8 @@ def _set_params_to_same_device_as( if self.data_mean.device != correct_device_tensor.device: self.data_mean = self.data_mean.to(correct_device_tensor.device) self.data_std = self.data_std.to(correct_device_tensor.device) + if correct_device_tensor.device != self.noiseModel.device: + self.noiseModel.to_device(correct_device_tensor.device) def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]: return x, None diff --git a/src/careamics/models/lvae/lvae.py b/src/careamics/models/lvae/lvae.py index 6a608f2c8..7f496ae38 100644 --- a/src/careamics/models/lvae/lvae.py +++ b/src/careamics/models/lvae/lvae.py @@ -12,8 +12,6 @@ import torch import torch.nn as nn -from careamics.config.architectures import register_model - from ..activation import get_activation from .layers import ( BottomUpDeterministicResBlock, @@ -25,7 +23,6 @@ from .utils import Interpolate, ModelType, crop_img_tensor -@register_model("LVAE") class LadderVAE(nn.Module): """ Constructor. diff --git a/src/careamics/models/lvae/noise_models.py b/src/careamics/models/lvae/noise_models.py index 6bd1044c9..e89814de3 100644 --- a/src/careamics/models/lvae/noise_models.py +++ b/src/careamics/models/lvae/noise_models.py @@ -3,6 +3,7 @@ import os from typing import TYPE_CHECKING, Optional +from numpy.typing import NDArray import numpy as np import torch import torch.nn as nn @@ -13,63 +14,59 @@ # TODO this module shouldn't be in lvae folder -def create_histogram(bins, min_val, max_val, observation, signal): +def create_histogram( + bins: int, min_val: float, max_val: float, observation: NDArray, signal: NDArray +) -> NDArray: """ Creates a 2D histogram from 'observation' and 'signal'. Parameters ---------- - bins: int - The number of bins in x and y. The total number of 2D bins is 'bins'**2. - min_val: float - the lower bound of the lowest bin in x and y. - max_val: float - the highest bound of the highest bin in x and y. - observation: numpy array - A 3D numpy array that is interpretted as a stack of 2D images. - The number of images has to be divisible by the number of images in 'signal'. - It is assumed that n subsequent images in observation belong to one image image in 'signal'. - signal: numpy array - A 3D numpy array that is interpretted as a stack of 2D images. + bins : int + Number of bins in x and y. + min_val : float + Lower bound of the lowest bin in x and y. + max_val : float + Upper bound of the highest bin in x and y. + observation : np.ndarray + 3D numpy array (stack of 2D images). + Observation.shape[0] must be divisible by signal.shape[0]. 
+ Assumes that n subsequent images in observation belong to one image in 'signal'. + signal : np.ndarray + 3D numpy array (stack of 2D images). Returns ------- - histogram: numpy array + histogram : np.ndarray A 3D array: - 'histogram[0,...]' holds the normalized 2D counts. - Each row sums to 1, describing p(x_i|s_i). - 'histogram[1,...]' holds the lower boundaries of each bin in y. - 'histogram[2,...]' holds the upper boundaries of each bin in y. - The values for x can be obtained by transposing 'histogram[1,...]' and 'histogram[2,...]'. + - histogram[0]: Normalized 2D counts. + - histogram[1]: Lower boundaries of bins along y. + - histogram[2]: Upper boundaries of bins along y. + The values for x can be obtained by transposing 'histogram[1]' and 'histogram[2]'. """ - # TODO refactor this function - img_factor = int(observation.shape[0] / signal.shape[0]) histogram = np.zeros((3, bins, bins)) - ra = [min_val, max_val] - - for i in range(observation.shape[0]): - observation_ = observation[i].copy().ravel() - - signal_ = (signal[i // img_factor].copy()).ravel() - a = np.histogram2d(signal_, observation_, bins=bins, range=[ra, ra]) - histogram[0] = histogram[0] + a[0] + 1e-30 # This is for numerical stability - - for i in range(bins): - if ( - np.sum(histogram[0, i, :]) > 1e-20 - ): # We exclude empty rows from normalization - histogram[0, i, :] /= np.sum( - histogram[0, i, :] - ) # we normalize each non-empty row - - for i in range(bins): - histogram[1, :, i] = a[1][ - :-1 - ] # The lower boundaries of each bin in y are stored in dimension 1 - histogram[2, :, i] = a[1][ - 1: - ] # The upper boundaries of each bin in y are stored in dimension 2 - # The accordent numbers for x are just transopsed. + + value_range = [min_val, max_val] + + # Compute mapping factor between observation and signal samples + obs_to_signal_shape_factor = int(observation.shape[0] / signal.shape[0]) + + # Flatten arrays and align signal values + signal_indices = np.arange(observation.shape[0]) // obs_to_signal_shape_factor + signal_values = signal[signal_indices].ravel() + observation_values = observation.ravel() + + count_histogram, signal_edges, _ = np.histogram2d( + signal_values, observation_values, bins=bins, range=[value_range, value_range] + ) + + # Normalize rows to obtain probabilities + row_sums = count_histogram.sum(axis=1, keepdims=True) + count_histogram /= np.clip(row_sums, a_min=1e-20, a_max=None) + + histogram[0] = count_histogram + histogram[1] = signal_edges[:-1][..., np.newaxis] + histogram[2] = signal_edges[1:][..., np.newaxis] return histogram @@ -111,8 +108,11 @@ def noise_model_factory( # TODO train a new model. Config should always be provided? if nm.model_type == "GaussianMixtureNoiseModel": # TODO one model for each channel, or make this choice inside the model? - trained_nm = train_gm_noise_model(nm) - noise_models.append(trained_nm) + # trained_nm = train_gm_noise_model(nm) + # noise_models.append(trained_nm) + raise NotImplementedError( + "GaussianMixtureNoiseModel model training is not implemented." + ) else: raise NotImplementedError( f"Model {nm.model_type} is not implemented" @@ -163,6 +163,8 @@ def __init__(self, nmodels: list[GaussianMixtureNoiseModel]): List of noise models, one for each output channel. """ super().__init__() + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + for i, nmodel in enumerate(nmodels): # TODO refactor this !!!
if nmodel is not None: self.add_module( @@ -176,6 +178,13 @@ def __init__(self, nmodels: list[GaussianMixtureNoiseModel]): print(f"[{self.__class__.__name__}] Nmodels count:{self._nm_cnt}") + def to_device(self, device: torch.device): + self.device = device + self.to(device) + for ch_idx in range(self._nm_cnt): + nmodel = getattr(self, f"nmodel_{ch_idx}") + nmodel.to_device(device) + def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor: """Compute the likelihood of observations given signals for each channel. @@ -212,28 +221,6 @@ def likelihood(self, obs: torch.Tensor, signal: torch.Tensor) -> torch.Tensor: return torch.cat(ll_list, dim=1) -# TODO: is this needed? -def fastShuffle(series, num): - """_summary_. - - Parameters - ---------- - series : _type_ - _description_ - num : _type_ - _description_ - - Returns - ------- - _type_ - _description_ - """ - length = series.shape[0] - for _ in range(num): - series = series[np.random.permutation(length), :] - return series - - class GaussianMixtureNoiseModel(nn.Module): """Define a noise model parameterized as a mixture of gaussians. @@ -276,166 +263,176 @@ class GaussianMixtureNoiseModel(nn.Module): """ # TODO training a NM relies on getting clean data (e.g. N2V) - def __init__(self, config: GaussianMixtureNMConfig): + def __init__(self, config: GaussianMixtureNMConfig) -> None: super().__init__() + self.device = torch.device("cpu") - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if config.path is None: - self.mode = "train" - # TODO this is (probably) to train a nm. We leave it for later refactoring - weight = config.weight - n_gaussian = config.n_gaussian - n_coeff = config.n_coeff - min_signal = torch.Tensor([config.min_signal]) - max_signal = torch.Tensor([config.max_signal]) - # TODO min_sigma cant be None ? - self.min_sigma = config.min_sigma - if weight is None: - weight = torch.randn(n_gaussian * 3, n_coeff) - weight[n_gaussian : 2 * n_gaussian, 1] = ( - torch.log(max_signal - min_signal).float().to(self.device) - ) - weight.requires_grad = True - - self.n_gaussian = weight.shape[0] // 3 - self.n_coeff = weight.shape[1] - self.weight = weight - self.min_signal = torch.Tensor([min_signal]).to(self.device) - self.max_signal = torch.Tensor([max_signal]).to(self.device) - self.tol = torch.tensor([1e-10]).to(self.device) - # TODO refactor to train on CPU! - else: + if config.path is not None: params = np.load(config.path) - self.mode = "inference" # TODO better name?
+ else: + params = config.model_dump(exclude_none=True) + + min_sigma = torch.tensor(params["min_sigma"]) + min_signal = torch.tensor(params["min_signal"]) + max_signal = torch.tensor(params["max_signal"]) + self.register_buffer("min_signal", min_signal) + self.register_buffer("max_signal", max_signal) + self.register_buffer("min_sigma", min_sigma) + self.register_buffer("tolerance", torch.tensor([1e-10])) + + if "trained_weight" in params: + weight = torch.tensor(params["trained_weight"]) + elif "weight" in params and params["weight"] is not None: + weight = torch.tensor(params["weight"]) + else: + weight = self._initialize_weights( + params["n_gaussian"], params["n_coeff"], max_signal, min_signal + ) - self.min_signal = torch.Tensor(params["min_signal"]) - self.max_signal = torch.Tensor(params["max_signal"]) + self.n_gaussian = weight.shape[0] // 3 + self.n_coeff = weight.shape[1] - self.weight = torch.Tensor(params["trained_weight"]) - self.min_sigma = params["min_sigma"].item() - self.n_gaussian = self.weight.shape[0] // 3 # TODO why // 3 ? - self.n_coeff = self.weight.shape[1] - self.tol = torch.Tensor([1e-10]) - self.min_signal = torch.Tensor([self.min_signal]) - self.max_signal = torch.Tensor([self.max_signal]) + self.register_parameter("weight", nn.Parameter(weight)) + self._set_model_mode(mode="prediction") print(f"[{self.__class__.__name__}] min_sigma: {self.min_sigma}") - def polynomialRegressor(self, weightParams, signals): - """Combines `weightParams` and signal `signals` to regress for the gaussian parameter values. + def _initialize_weights( + self, + n_gaussian: int, + n_coeff: int, + max_signal: torch.Tensor, + min_signal: torch.Tensor, + ) -> torch.Tensor: + """Create random weight initialization.""" + weight = torch.randn(n_gaussian * 3, n_coeff) + weight[n_gaussian : 2 * n_gaussian, 1] = torch.log( + max_signal - min_signal + ).float() + return weight + + def to_device(self, device: torch.device): + self.device = device + self.to(device) + + def _set_model_mode(self, mode: str) -> None: + """Set the weights' requires_grad attribute depending on the mode.""" + if mode == "train": + self.weight.requires_grad = True + else: + self.weight.requires_grad = False + + def polynomial_regressor( + self, weight_params: torch.Tensor, signals: torch.Tensor + ) -> torch.Tensor: + """Combines `weight_params` and signal `signals` to regress for the gaussian parameter values. Parameters ---------- - weightParams : torch.cuda.FloatTensor + weight_params : Tensor Corresponds to specific rows of the `self.weight` - signals : torch.cuda.FloatTensor + signals : Tensor Signals Returns ------- - value : torch.cuda.FloatTensor + value : Tensor Corresponds to either of mean, standard deviation or weight, evaluated at `signals` """ - value = 0 - for i in range(weightParams.shape[0]): - value += weightParams[i] * ( + value = torch.zeros_like(signals) + for i in range(weight_params.shape[0]): + value += weight_params[i] * ( ((signals - self.min_signal) / (self.max_signal - self.min_signal)) ** i ) return value - def normalDens(self, x, m_=0.0, std_=None): - """Evaluates the normal probability density at `x` given the mean `m` and standard deviation `std`. + def normal_density( + self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor + ) -> torch.Tensor: + """ + Evaluates the normal probability density at `x` given the mean `mean` and standard deviation `std`.
Parameters ---------- - x: torch.cuda.FloatTensor + x: Tensor Observations - m_: torch.cuda.FloatTensor + mean: Tensor Mean - std_: torch.cuda.FloatTensor + std: Tensor Standard-deviation Returns ------- - tmp: torch.cuda.FloatTensor - Normal probability density of `x` given `m_` and `std_` - + tmp: Tensor + Normal probability density of `x` given `mean` and `std` """ - tmp = -((x - m_) ** 2) - tmp = tmp / (2.0 * std_ * std_) + tmp = -((x - mean) ** 2) + tmp = tmp / (2.0 * std * std) tmp = torch.exp(tmp) - tmp = tmp / torch.sqrt((2.0 * np.pi) * std_ * std_) - # print(tmp.min().item(), tmp.mean().item(), tmp.max().item(), tmp.shape) + tmp = tmp / torch.sqrt((2.0 * np.pi) * std * std) return tmp - def likelihood(self, observations, signals): - """Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters. + def likelihood( + self, observations: torch.Tensor, signals: torch.Tensor + ) -> torch.Tensor: + """ + Evaluates the likelihood of observations given the signals and the corresponding gaussian parameters. Parameters ---------- - observations : torch.cuda.FloatTensor + observations : Tensor Noisy observations - signals : torch.cuda.FloatTensor + signals : Tensor Underlying signals Returns ------- - value :p + self.tol + value : torch.Tensor Likelihood of observations given the signals and the GMM noise model """ - if self.mode != "train": - signals = signals.cpu() - observations = observations.cpu() - self.weight = self.weight.to(signals.device) - self.min_signal = self.min_signal.to(signals.device) - self.max_signal = self.max_signal.to(signals.device) - self.tol = self.tol.to(signals.device) - - gaussianParameters = self.getGaussianParameters(signals) + gaussian_parameters: list[torch.Tensor] = self.get_gaussian_parameters(signals) p = 0 for gaussian in range(self.n_gaussian): p += ( - self.normalDens( + self.normal_density( observations, - gaussianParameters[gaussian], - gaussianParameters[self.n_gaussian + gaussian], + gaussian_parameters[gaussian], + gaussian_parameters[self.n_gaussian + gaussian], ) - * gaussianParameters[2 * self.n_gaussian + gaussian] + * gaussian_parameters[2 * self.n_gaussian + gaussian] ) - return p + self.tol + return p + self.tolerance - def getGaussianParameters(self, signals): - """Returns the noise model for given signals + def get_gaussian_parameters(self, signals: torch.Tensor) -> list[torch.Tensor]: + """ + Returns the noise model for given signals Parameters ---------- - signals : torch.cuda.FloatTensor + signals : Tensor Underlying signals Returns ------- - noiseModel: list of torch.cuda.FloatTensor + noise_model: list of Tensor Contains a list of `mu`, `sigma` and `alpha` for the `signals` """ - noiseModel = [] + noise_model = [] mu = [] sigma = [] alpha = [] kernels = self.weight.shape[0] // 3 for num in range(kernels): - mu.append(self.polynomialRegressor(self.weight[num, :], signals)) - # expval = torch.exp(torch.clamp(self.weight[kernels + num, :], max=MAX_VAR_W)) + mu.append(self.polynomial_regressor(self.weight[num, :], signals)) expval = torch.exp(self.weight[kernels + num, :]) - # self.maxval = max(self.maxval, expval.max().item()) - sigmaTemp = self.polynomialRegressor(expval, signals) - sigmaTemp = torch.clamp(sigmaTemp, min=self.min_sigma) - sigma.append(torch.sqrt(sigmaTemp)) + sigma_temp = self.polynomial_regressor(expval, signals) + sigma_temp = torch.clamp(sigma_temp, min=self.min_sigma) + sigma.append(torch.sqrt(sigma_temp)) expval = torch.exp( - self.polynomialRegressor(self.weight[2 *
kernels + num, :], signals) - + self.tol + self.polynomial_regressor(self.weight[2 * kernels + num, :], signals) + + self.tolerance ) alpha.append(expval) @@ -459,15 +456,30 @@ def getGaussianParameters(self, signals): mu[ker] = mu[ker] - sum_means + signals for i in range(kernels): - noiseModel.append(mu[i]) + noise_model.append(mu[i]) for j in range(kernels): - noiseModel.append(sigma[j]) + noise_model.append(sigma[j]) for k in range(kernels): - noiseModel.append(alpha[k]) + noise_model.append(alpha[k]) + + return noise_model - return noiseModel + @staticmethod + def _fast_shuffle(series: torch.Tensor, num: int) -> torch.Tensor: + """Shuffle the inputs randomly `num` times.""" + length = series.shape[0] + for _ in range(num): + idx = torch.randperm(length) + series = series[idx, :] + return series - def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip): + def get_signal_observation_pairs( + self, + signal: NDArray, + observation: NDArray, + lower_clip: float, + upper_clip: float, + ) -> torch.Tensor: """Returns the Signal-Observation pixel intensities as a two-column array Parameters @@ -476,19 +488,18 @@ def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip): Clean Signal Data observation: numpy array Noisy observation Data - lowerClip: float + lower_clip: float Lower percentile bound for clipping. - upperClip: float + upper_clip: float Upper percentile bound for clipping. Returns ------- - noiseModel: list of torch floats + sig_obs_pairs : torch.Tensor Shuffled two-column tensor of (signal, observation) pixel pairs - """ - lb = np.percentile(signal, lowerClip) - ub = np.percentile(signal, upperClip) + lb = np.percentile(signal, lower_clip) + ub = np.percentile(signal, upper_clip) stepsize = observation[0].size n_observations = observation.shape[0] n_signals = signal.shape[0] @@ -501,19 +512,20 @@ def getSignalObservationPairs(self, signal, observation, lowerClip, upperClip): sig_obs_pairs = sig_obs_pairs[ (sig_obs_pairs[:, 0] > lb) & (sig_obs_pairs[:, 0] < ub) ] - return fastShuffle(sig_obs_pairs, 2) + sig_obs_pairs = sig_obs_pairs.astype(np.float32) + sig_obs_pairs = torch.from_numpy(sig_obs_pairs) + return self._fast_shuffle(sig_obs_pairs, 2) def fit( self, - signal, - observation, - learning_rate=1e-1, - batchSize=250000, - n_epochs=2000, - name="GMMNoiseModel.npz", - lowerClip=0, - upperClip=100, - ): + signal: NDArray, + observation: NDArray, + learning_rate: float = 1e-1, + batch_size: int = 250000, + n_epochs: int = 2000, + lower_clip: float = 0.0, + upper_clip: float = 100.0, + ) -> list[float]: """Training to learn the noise model from signal - observation pairs. Parameters @@ -524,49 +536,42 @@ def fit( Noisy Observation Data learning_rate: float Learning rate. Default = 1e-1. - batchSize: int + batch_size: int Mini-batch size. Default = 250000. n_epochs: int Number of epochs. Default = 2000. - name: string - - Model name. Default is `GMMNoiseModel`. This model after being trained is saved at the location `path`. - - lowerClip : int + lower_clip : int Lower percentile for clipping. Default is 0. - upperClip : int + upper_clip : int Upper percentile for clipping. Default is 100.
- - """ - sig_obs_pairs = self.getSignalObservationPairs( - signal, observation, lowerClip, upperClip - ) - counter = 0 + self._set_model_mode(mode="train") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.to_device(device) optimizer = torch.optim.Adam([self.weight], lr=learning_rate) - loss_arr = [] + sig_obs_pairs = self.get_signal_observation_pairs( + signal, observation, lower_clip, upper_clip + ) + + train_losses = [] + counter = 0 for t in range(n_epochs): - if (counter + 1) * batchSize >= sig_obs_pairs.shape[0]: + if (counter + 1) * batch_size >= sig_obs_pairs.shape[0]: counter = 0 - sig_obs_pairs = fastShuffle(sig_obs_pairs, 1) + sig_obs_pairs = self._fast_shuffle(sig_obs_pairs, 1) batch_vectors = sig_obs_pairs[ - counter * batchSize : (counter + 1) * batchSize, : + counter * batch_size : (counter + 1) * batch_size, : ] - observations = batch_vectors[:, 1].astype(np.float32) - signals = batch_vectors[:, 0].astype(np.float32) - observations = ( - torch.from_numpy(observations.astype(np.float32)) - .float() - .to(self.device) - ) - signals = torch.from_numpy(signals).float().to(self.device) + observations = batch_vectors[:, 1].to(self.device) + signals = batch_vectors[:, 0].to(self.device) p = self.likelihood(observations, signals) - jointLoss = torch.mean(-torch.log(p)) - loss_arr.append(jointLoss.item()) + joint_loss = torch.mean(-torch.log(p)) + train_losses.append(joint_loss.item()) + if self.weight.isnan().any() or self.weight.isinf().any(): print( "NaN or Inf detected in the weights. Aborting training at epoch: ", @@ -575,19 +580,77 @@ def fit( break if t % 100 == 0: - print(t, np.mean(loss_arr)) + last_losses = train_losses[-100:] + print(t, np.mean(last_losses)) optimizer.zero_grad() - jointLoss.backward() + joint_loss.backward() optimizer.step() counter += 1 - self.trained_weight = self.weight.cpu().detach().numpy() - self.min_signal = self.min_signal.cpu().detach().numpy() - self.max_signal = self.max_signal.cpu().detach().numpy() + self._set_model_mode(mode="prediction") + self.to_device(torch.device("cpu")) print("===================\n") + return train_losses + + def sample_observation_from_signal(self, signal: NDArray) -> NDArray: + """ + Sample an instance of observation based on an input signal using a + learned Gaussian Mixture Model. For each pixel in the input signal, + samples a corresponding noisy pixel. + + Parameters + ---------- + signal: numpy array + Clean 2D signal data. + + Returns + ------- + observation: numpy array + An instance of noisy observation data based on the input signal. + """ + assert len(signal.shape) == 2, "Only 2D inputs are supported." 
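+        # Note on the sampling below (descriptive comment only): first the
+        # per-pixel GMM parameters (means, stds, mixture weights alpha) are
+        # evaluated from the signal, then one mixture component is picked per
+        # pixel by comparing a uniform draw against the cumulative alphas, and
+        # the observation is drawn from the selected gaussian.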
+ + signal_tensor = torch.from_numpy(signal).to(torch.float32) + height, width = signal_tensor.shape + + with torch.no_grad(): + # Get gaussian parameters for each pixel + gaussian_params = self.get_gaussian_parameters(signal_tensor) + means = np.array(gaussian_params[: self.n_gaussian]) + stds = np.array(gaussian_params[self.n_gaussian : self.n_gaussian * 2]) + alphas = np.array(gaussian_params[self.n_gaussian * 2 :]) + + if self.n_gaussian == 1: + # Single gaussian case + observation = np.random.normal( + loc=means[0], scale=stds[0], size=(height, width) + ) + else: + # Multiple gaussians: sample component for each pixel + uniform = np.random.rand(1, height, width) + # Compute cumulative probabilities for component selection + cumulative_alphas = np.cumsum( + alphas, axis=0 + ) # Shape: (n_gaussian, height, width) + selected_component = np.argmax( + uniform < cumulative_alphas, axis=0, keepdims=True + ) + + # For every pixel, choose the corresponding gaussian + # and get the learned mu and sigma + selected_mus = np.take_along_axis(means, selected_component, axis=0) + selected_stds = np.take_along_axis(stds, selected_component, axis=0) + selected_mus = selected_mus.squeeze(0) + selected_stds = selected_stds.squeeze(0) + + # Sample from the normal distribution with learned mu and sigma + observation = np.random.normal( + selected_mus, selected_stds, size=(height, width) + ) + return observation - def save(self, path: str, name: str): + def save(self, path: str, name: str) -> None: """Save the trained parameters on the noise model. Parameters @@ -600,9 +663,9 @@ def save(self, path: str, name: str): os.makedirs(path, exist_ok=True) np.savez( os.path.join(path, name), - trained_weight=self.trained_weight, - min_signal=self.min_signal, - max_signal=self.max_signal, + trained_weight=self.weight.numpy(), + min_signal=self.min_signal.numpy(), + max_signal=self.max_signal.numpy(), min_sigma=self.min_sigma, ) print("The trained parameters (" + name + ") is saved at location: " + path) diff --git a/src/careamics/models/model_factory.py b/src/careamics/models/model_factory.py index c235a6bc5..533ca9412 100644 --- a/src/careamics/models/model_factory.py +++ b/src/careamics/models/model_factory.py @@ -1,8 +1,4 @@ -""" -Model factory. - -Model creation factory functions. -""" +"""Model creation factory functions.""" from __future__ import annotations @@ -10,10 +6,6 @@ import torch -from careamics.config.architectures import ( - CustomModel, - get_custom_model, -) from careamics.config.support import SupportedArchitecture from careamics.models.lvae import LadderVAE as LVAE from careamics.models.unet import UNet @@ -21,7 +13,6 @@ if TYPE_CHECKING: from careamics.config.architectures import ( - CustomModel, LVAEModel, UNetModel, ) @@ -31,7 +22,7 @@ def model_factory( - model_configuration: Union[UNetModel, LVAEModel, CustomModel], + model_configuration: Union[UNetModel, LVAEModel], ) -> torch.nn.Module: """ Deep learning model factory. @@ -57,10 +48,6 @@ def model_factory( return UNet(**model_configuration.model_dump()) elif model_configuration.architecture == SupportedArchitecture.LVAE: return LVAE(**model_configuration.model_dump()) - elif model_configuration.architecture == SupportedArchitecture.CUSTOM: - assert isinstance(model_configuration, CustomModel) - model = get_custom_model(model_configuration.name) - return model(**model_configuration.model_dump()) else: raise NotImplementedError( f"Model {model_configuration.architecture} is not implemented or unknown." 
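With the Custom architecture gone, `model_factory` now only dispatches on UNet and LVAE configurations. A minimal sketch of the remaining dispatch (not part of the diff; the `UNetModel` parameters shown are illustrative defaults):

```python
# Sketch of the reduced model_factory dispatch after removing Custom models:
# UNetModel -> UNet, LVAEModel -> LVAE, anything else raises NotImplementedError.
from careamics.config.architectures import UNetModel
from careamics.models.model_factory import model_factory

unet_config = UNetModel(architecture="UNet")  # illustrative, default parameters
model = model_factory(unet_config)  # returns a careamics.models.unet.UNet instance
```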
diff --git a/src/careamics/transforms/compose.py b/src/careamics/transforms/compose.py index 6a386c930..4e1fef675 100644 --- a/src/careamics/transforms/compose.py +++ b/src/careamics/transforms/compose.py @@ -4,7 +4,7 @@ from numpy.typing import NDArray -from careamics.config.transformations import TransformModel +from careamics.config.transformations import NORM_AND_SPATIAL_UNION from .n2v_manipulate import N2VManipulate from .normalize import Normalize @@ -47,12 +47,12 @@ class Compose: A callable that applies the transforms to the input data. """ - def __init__(self, transform_list: list[TransformModel]) -> None: + def __init__(self, transform_list: list[NORM_AND_SPATIAL_UNION]) -> None: """Instantiate a Compose object. Parameters ---------- - transform_list : list[TransformModel] + transform_list : list[NORM_AND_SPATIAL_UNION] A list of dictionaries where each dictionary contains the name of a transform and its parameters. """ diff --git a/src/careamics/utils/plotting.py b/src/careamics/utils/plotting.py new file mode 100644 index 000000000..5ec78dc5d --- /dev/null +++ b/src/careamics/utils/plotting.py @@ -0,0 +1,78 @@ +"""Plotting utilities.""" + +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np +import torch +from numpy.typing import NDArray + +from careamics.models.lvae.noise_models import GaussianMixtureNoiseModel + + +def plot_noise_model_probability_distribution( + noise_model: GaussianMixtureNoiseModel, + signalBinIndex: int, + histogram: NDArray, + channel: Optional[str] = None, + number_of_bins: int = 100, +) -> None: + """Plot probability distribution P(x|s) for a certain ground truth signal. + + Predictions from both Histogram and GMM-based + Noise models are displayed for comparison. + + Parameters + ---------- + noise_model : GaussianMixtureNoiseModel + Trained GaussianMixtureNoiseModel. + signalBinIndex : int + Index of signal bin. Values go from 0 to number of bins (`n_bin`). + histogram : NDArray + Histogram based noise model. + channel : Optional[str], optional + Channel name used for plotting. Default is None. + number_of_bins : int, optional + Number of bins in the resulting histogram. Default is 100. 
+ """ + min_signal = noise_model.min_signal.item() + max_signal = noise_model.max_signal.item() + bin_size = (max_signal - min_signal) / number_of_bins + + query_signal_normalized = signalBinIndex / number_of_bins + query_signal = query_signal_normalized * (max_signal - min_signal) + min_signal + query_signal += bin_size / 2 + query_signal = torch.tensor(query_signal) + + query_observations = torch.arange(min_signal, max_signal, bin_size) + query_observations += bin_size / 2 + + likelihoods = noise_model.likelihood( + observations=query_observations, signals=query_signal + ).numpy() + + plt.figure(figsize=(12, 5)) + if channel: + plt.suptitle(f"Noise model for channel {channel}") + else: + plt.suptitle("Noise model") + + plt.subplot(1, 2, 1) + plt.xlabel("Observation Bin") + plt.ylabel("Signal Bin") + plt.imshow(histogram**0.25, cmap="gray") + plt.axhline(y=signalBinIndex + 0.5, linewidth=5, color="blue", alpha=0.5) + + plt.subplot(1, 2, 2) + plt.plot( + query_observations, + likelihoods, + label="GMM : " + " signal = " + str(np.round(query_signal, 2)), + marker=".", + color="red", + linewidth=2, + ) + plt.xlabel("Observations (x) for signal s = " + str(query_signal)) + plt.ylabel("Probability Density") + plt.title("Probability Distribution P(x|s) at signal =" + str(query_signal)) + plt.legend() diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py index 01f4794e0..bbf9c913e 100644 --- a/tests/cli/test_main.py +++ b/tests/cli/test_main.py @@ -6,17 +6,17 @@ from careamics import CAREamist from careamics.cli.main import app -from careamics.config import Configuration, save_configuration +from careamics.config import configuration_factory, save_configuration from careamics.config.support import SupportedData runner = CliRunner() -def test_train(tmp_path: Path, minimum_configuration: dict): +def test_train(tmp_path: Path, minimum_n2v_configuration: dict): # create & save config config_path = tmp_path / "config.yaml" - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.data_config.data_type = SupportedData.TIFF.value save_configuration(config, config_path) @@ -43,11 +43,11 @@ def test_train(tmp_path: Path, minimum_configuration: dict): assert result.exit_code == 0 -def test_predict_single_file(tmp_path: Path, minimum_configuration: dict): +def test_predict_single_file(tmp_path: Path, minimum_n2v_configuration: dict): # create & save config config_path = tmp_path / "config.yaml" - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.data_config.data_type = SupportedData.TIFF.value save_configuration(config, config_path) @@ -70,11 +70,11 @@ def test_predict_single_file(tmp_path: Path, minimum_configuration: dict): assert result.exit_code == 0 -def test_predict_directory(tmp_path: Path, minimum_configuration: dict): +def test_predict_directory(tmp_path: Path, minimum_n2v_configuration: dict): # create & save config config_path = tmp_path / "config.yaml" - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.data_config.data_type = SupportedData.TIFF.value save_configuration(config, config_path) diff --git a/tests/config/algorithms/test_care_algorithm_model.py b/tests/config/algorithms/test_care_algorithm_model.py new file mode 100644 index 000000000..aab357c3b --- /dev/null +++ b/tests/config/algorithms/test_care_algorithm_model.py @@ -0,0 +1,32 @@ +import pytest + +from careamics.config.algorithms import 
CAREAlgorithm + + +def test_instantiation(): + """Test the instantiation of the CAREAlgorithm class.""" + model = { + "architecture": "UNet", + } + CAREAlgorithm(model=model) + + +def test_no_n2v2(): + """Check that an error is raised if the model is set for n2v2.""" + model = { + "architecture": "UNet", + "n2v2": True, + } + + with pytest.raises(ValueError): + CAREAlgorithm(model=model) + + +def test_no_final_activation(minimum_algorithm_supervised: dict): + """Check that an error is raised if the model has a final activation.""" + minimum_algorithm_supervised["model"] = { + "architecture": "UNet", + "final_activation": "ReLU", + } + with pytest.raises(ValueError): + CAREAlgorithm(**minimum_algorithm_supervised) diff --git a/tests/config/algorithms/test_n2n_algorithm_model.py b/tests/config/algorithms/test_n2n_algorithm_model.py new file mode 100644 index 000000000..ce439f4c5 --- /dev/null +++ b/tests/config/algorithms/test_n2n_algorithm_model.py @@ -0,0 +1,32 @@ +import pytest + +from careamics.config.algorithms import N2NAlgorithm + + +def test_instantiation(): + """Test the instantiation of the N2NAlgorithm class.""" + model = { + "architecture": "UNet", + } + N2NAlgorithm(model=model) + + +def test_no_n2v2(): + """Check that an error is raised if the model is set for n2v2.""" + model = { + "architecture": "UNet", + "n2v2": True, + } + + with pytest.raises(ValueError): + N2NAlgorithm(model=model) + + +def test_no_final_activation(minimum_algorithm_supervised: dict): + """Check that an error is raised if the model has a final activation.""" + minimum_algorithm_supervised["model"] = { + "architecture": "UNet", + "final_activation": "ReLU", + } + with pytest.raises(ValueError): + N2NAlgorithm(**minimum_algorithm_supervised) diff --git a/tests/config/algorithms/test_n2v_algorithm_model.py b/tests/config/algorithms/test_n2v_algorithm_model.py new file mode 100644 index 000000000..ad971896c --- /dev/null +++ b/tests/config/algorithms/test_n2v_algorithm_model.py @@ -0,0 +1,43 @@ +import pytest + +from careamics.config.algorithms import N2VAlgorithm + + +def test_n_channels_n2v(): + """Check that an error is raised if n2v has different number of channels in + input and output.""" + model = { + "architecture": "UNet", + "in_channels": 1, + "num_classes": 2, + "n2v2": False, + } + loss = "n2v" + + with pytest.raises(ValueError): + N2VAlgorithm(algorithm="n2v", loss=loss, model=model) + + +def test_channels(minimum_algorithm_n2v: dict): + """Check that error is thrown if the number of channels are different.""" + minimum_algorithm_n2v["model"] = { + "architecture": "UNet", + "in_channels": 2, + "num_classes": 2, + "n2v2": False, + } + N2VAlgorithm(**minimum_algorithm_n2v) + + minimum_algorithm_n2v["model"]["num_classes"] = 3 + with pytest.raises(ValueError): + N2VAlgorithm(**minimum_algorithm_n2v) + + +def test_no_final_activation(minimum_algorithm_n2v: dict): + """Check that an error is raised if the model has a final activation.""" + minimum_algorithm_n2v["model"] = { + "architecture": "UNet", + "final_activation": "ReLU", + } + with pytest.raises(ValueError): + N2VAlgorithm(**minimum_algorithm_n2v) diff --git a/tests/config/architectures/test_custom_model.py b/tests/config/architectures/test_custom_model.py deleted file mode 100644 index 93aa7c1ac..000000000 --- a/tests/config/architectures/test_custom_model.py +++ /dev/null @@ -1,82 +0,0 @@ -import pytest - -from careamics.config.architectures import CustomModel, get_custom_model, register_model -from careamics.config.support import 
SupportedArchitecture - - -@register_model(name="not_a_model") -class NotAModel: - def __init__(self, id): - self.id = id - - def forward(self, input): - return input - - -def test_any_custom_parameters(): - """Test that the custom model can have any fields. - - Note that those fields are validated by instantiating the - model. - """ - CustomModel( - architecture=SupportedArchitecture.CUSTOM.value, - name="linear", - in_features=10, - out_features=5, - ) - - -def test_linear_model(custom_model_name): - """Test that the model can be retrieved and instantiated.""" - model = get_custom_model(custom_model_name) - model(in_features=10, out_features=5) - - -def test_not_a_model(): - """Test that the model can be retrieved and instantiated.""" - model = get_custom_model("not_a_model") - model(3) - - -def test_custom_model(custom_model_parameters): - """Test that the custom model can be instantiated.""" - - # create Pydantic model - pydantic_model = CustomModel(**custom_model_parameters) - - # instantiate model - model_class = get_custom_model(pydantic_model.name) - model = model_class(**pydantic_model.model_dump()) - - assert model.in_features == 10 - assert model.out_features == 5 - - -def test_custom_model_wrong_class(): - """Test that the Pydantic custom model raises an error if the model is not a - torch.nn.Module subclass.""" - # prepare model dictionary - model_dict = { - "architecture": SupportedArchitecture.CUSTOM.value, - "name": "not_a_model", - "parameters": {"id": 3}, - } - - # create Pydantic model - with pytest.raises(ValueError): - CustomModel(**model_dict) - - -def test_wrong_parameters(custom_model_name): - """Test that the custom model raises an error if the parameters are not valid.""" - # prepare model dictionary - model_dict = { - "architecture": SupportedArchitecture.CUSTOM.value, - "name": custom_model_name, - "parameters": {"in_features": 10}, - } - - # create Pydantic model - with pytest.raises(ValueError): - CustomModel(**model_dict) diff --git a/tests/config/architectures/test_lvae_model.py b/tests/config/architectures/test_lvae_model.py index 6ae87b431..0c7ef6e1a 100644 --- a/tests/config/architectures/test_lvae_model.py +++ b/tests/config/architectures/test_lvae_model.py @@ -113,25 +113,3 @@ def test_parameters_wrong_values_by_assigment(): model.encoder_n_filters = model_params["encoder_n_filters"] with pytest.raises(ValueError): model.encoder_n_filters = 2 - - -def test_model_dump(): - """Test that default values are excluded from model dump.""" - model_params = { - "architecture": "LVAE", # default value - "z_dims": (128, 128, 128, 128), # default value - "nonlinearity": "ReLU", # non-default value - "decoder_n_filters": 32, # non-default value - } - model = LVAEModel(**model_params) - - # dump model - model_dict = model.model_dump(exclude_defaults=True) - - # check that default values are excluded except the architecture - assert "architecture" not in model_dict - assert len(model_dict) == 3 # TODO not sure it's hardcoded? 
- - # check that we get all the optional values with the exclude_defaults flag - model_dict = model.model_dump(exclude_defaults=False) - assert len(model_dict) == len(dict(model)) - 1 diff --git a/tests/config/architectures/test_register_model.py b/tests/config/architectures/test_register_model.py deleted file mode 100644 index 00022c325..000000000 --- a/tests/config/architectures/test_register_model.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest - -from careamics.config.architectures import ( - clear_custom_models, - get_custom_model, - register_model, -) - - -# register a model -@register_model(name="mymodel") -class MyModel: - model_name: str - model_id: int - - -def test_register_model(): - """Test the register_model decorator.""" - - # get custom model - model = get_custom_model("mymodel") - - # check if it is a subclass of MyModel - assert issubclass(model, MyModel) - - -def test_wrong_model(): - """Test that an error is raised if an unknown model is requested.""" - get_custom_model("mymodel") - - with pytest.raises(ValueError): - get_custom_model("unknown_model") - - -@pytest.mark.skip("This tests prevents other tests with custom models to pass.") -def test_clear_custom_models(): - """Test that the custom models are cleared.""" - # retrieve model - get_custom_model("mymodel") - - # clear custom models - clear_custom_models() - - # request the model again - with pytest.raises(ValueError): - get_custom_model("mymodel") diff --git a/tests/config/data/test_data_model.py b/tests/config/data/test_data_model.py new file mode 100644 index 000000000..785317494 --- /dev/null +++ b/tests/config/data/test_data_model.py @@ -0,0 +1,174 @@ +import numpy as np +import pytest +import yaml + +from careamics.config.data.data_model import DataConfig +from careamics.config.support import ( + SupportedTransform, +) +from careamics.config.transformations import NormalizeModel +from careamics.transforms import get_all_transforms + + +@pytest.mark.parametrize("ext", ["nd2", "jpg", "png ", "zarr", "npy"]) +def test_wrong_extensions(minimum_data: dict, ext: str): + """Test that the data model raises ValueError for unsupported extensions.""" + minimum_data["data_type"] = ext + + # instantiate DataModel model + with pytest.raises(ValueError): + DataConfig(**minimum_data) + + +@pytest.mark.parametrize("mean, std", [(0, 124.5), (12.6, 0.1)]) +def test_mean_std_non_negative(minimum_data: dict, mean, std): + """Test that non-negative mean and std are accepted.""" + minimum_data["image_means"] = [mean] + minimum_data["image_stds"] = [std] + minimum_data["target_means"] = [mean] + minimum_data["target_stds"] = [std] + + data_model = DataConfig(**minimum_data) + assert data_model.image_means == [mean] + assert data_model.image_stds == [std] + assert data_model.target_means == [mean] + assert data_model.target_stds == [std] + + +def test_mean_std_both_specified_or_none(minimum_data: dict): + """Test an error is raised if std is specified but mean is None.""" + # No error if both are None + DataConfig(**minimum_data) + + # Error if only mean is defined + minimum_data["image_means"] = [10.4] + with pytest.raises(ValueError): + DataConfig(**minimum_data) + + # Error if only std is defined + minimum_data.pop("image_means") + minimum_data["image_stds"] = [10.4] + with pytest.raises(ValueError): + DataConfig(**minimum_data) + + # No error if both are specified + minimum_data["image_means"] = [10.4] + minimum_data["image_stds"] = [10.4] + DataConfig(**minimum_data) + + # Error if target std is defined but target mean is
None + minimum_data["target_stds"] = [10.4, 11] + with pytest.raises(ValueError): + DataConfig(**minimum_data) + + +def test_set_mean_and_std(minimum_data: dict): + """Test that mean and std can be set after initialization.""" + # both can be set when they are None + mean = [4.07] + std = [14.07] + data = DataConfig(**minimum_data) + data.set_means_and_stds(mean, std) + assert data.image_means == mean + assert data.image_stds == std + + # Set also target mean and std + data.set_means_and_stds(mean, std, mean, std) + assert data.target_means == mean + assert data.target_stds == std + + +def test_normalize_not_accepted(minimum_data: dict): + """Test that normalize is not accepted, because it is mandatory and applied + elsewhere.""" + minimum_data["image_means"] = [10.4] + minimum_data["image_stds"] = [3.2] + minimum_data["transforms"] = [ + NormalizeModel(image_means=[0.485], image_stds=[0.229]) + ] + + with pytest.raises(ValueError): + DataConfig(**minimum_data) + + +def test_patch_size(minimum_data: dict): + """Test that non-zero even patch sizes are accepted.""" + # 2D + data_model = DataConfig(**minimum_data) + + # 3D + minimum_data["patch_size"] = [16, 8, 8] + minimum_data["axes"] = "ZYX" + + data_model = DataConfig(**minimum_data) + assert data_model.patch_size == minimum_data["patch_size"] + + +@pytest.mark.parametrize( + "patch_size", [[12], [0, 12, 12], [12, 12, 13], [16, 10, 16], [12, 12, 12, 12]] +) +def test_wrong_patch_size(minimum_data: dict, patch_size): + """Test that wrong patch sizes are not accepted (zero or odd, dims 1 or > 3).""" + minimum_data["axes"] = "ZYX" if len(patch_size) == 3 else "YX" + minimum_data["patch_size"] = patch_size + + with pytest.raises(ValueError): + DataConfig(**minimum_data) + + +def test_set_3d(minimum_data: dict): + """Test that 3D can be set.""" + data = DataConfig(**minimum_data) + assert "Z" not in data.axes + assert len(data.patch_size) == 2 + + # error if changing Z manually + with pytest.raises(ValueError): + data.axes = "ZYX" + + # or patch size + data = DataConfig(**minimum_data) + with pytest.raises(ValueError): + data.patch_size = [64, 64, 64] + + # set 3D + data = DataConfig(**minimum_data) + data.set_3D("ZYX", [64, 64, 64]) + assert "Z" in data.axes + assert len(data.patch_size) == 3 + + +def test_passing_empty_transforms(minimum_data: dict): + """Test that empty list of transforms can be passed.""" + minimum_data["transforms"] = [] + DataConfig(**minimum_data) + + +def test_passing_incorrect_element(minimum_data: dict): + """Test that an incorrect element in the list of transforms raises an error ( + e.g.
passing an object rather than a string).""" + minimum_data["transforms"] = [ + {"name": get_all_transforms()[SupportedTransform.XY_FLIP.value]()}, + ] + with pytest.raises(ValueError): + DataConfig(**minimum_data) + + +def test_export_to_yaml_float32_stats(tmp_path, minimum_data: dict): + """Test exporting and loading the pydantic model when the statistics are + np.float32.""" + data = DataConfig(**minimum_data) + + # set np.float32 stats values + data.set_means_and_stds([np.float32(1234.5678)], [np.float32(21.73)]) + + # export to yaml + config_path = tmp_path / "data_config.yaml" + with open(config_path, "w") as f: + # dump configuration + yaml.dump(data.model_dump(), f, default_flow_style=False, sort_keys=False) + + # load model + dictionary = yaml.load(config_path.open("r"), Loader=yaml.SafeLoader) + read_data = DataConfig(**dictionary) + assert read_data.model_dump() == data.model_dump() diff --git a/tests/config/data/test_n2v_data_model.py b/tests/config/data/test_n2v_data_model.py new file mode 100644 index 000000000..2cebde5df --- /dev/null +++ b/tests/config/data/test_n2v_data_model.py @@ -0,0 +1,128 @@ +import pytest + +from careamics.config.data import N2VDataConfig +from careamics.config.support import ( + SupportedPixelManipulation, + SupportedStructAxis, + SupportedTransform, +) + + +def test_error_no_manipulate(minimum_data: dict): + """Test that an error is raised if no N2VManipulate transform is passed.""" + minimum_data["transforms"] = [ + {"name": SupportedTransform.XY_FLIP.value}, + {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, + ] + with pytest.raises(ValueError): + N2VDataConfig(**minimum_data) + + +@pytest.mark.parametrize( + "transforms", + [ + [ + {"name": SupportedTransform.N2V_MANIPULATE.value}, + {"name": SupportedTransform.XY_FLIP.value}, + ], + [ + {"name": SupportedTransform.XY_FLIP.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, + ], + ], +) +def test_n2vmanipulate_not_last_transform(minimum_data: dict, transforms): + """Test that N2V Manipulate not in the last position raises an error.""" + minimum_data["transforms"] = transforms + with pytest.raises(ValueError): + N2VDataConfig(**minimum_data) + + +def test_multiple_n2v_manipulate(minimum_data: dict): + """Test that passing multiple n2v manipulate raises an error.""" + minimum_data["transforms"] = [ + {"name": SupportedTransform.N2V_MANIPULATE.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ] + with pytest.raises(ValueError): + N2VDataConfig(**minimum_data) + + +def test_correct_transform_parameters(minimum_data: dict): + """Test that the transforms have the correct parameters. + + This is important to check that the transforms are not all instantiated as + a generic transform.
+ """ + minimum_data["transforms"] = [ + {"name": SupportedTransform.XY_FLIP.value}, + {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, + {"name": SupportedTransform.N2V_MANIPULATE.value}, + ] + model = N2VDataConfig(**minimum_data) + + # N2VManipulate + params = model.transforms[-1].model_dump() + assert "roi_size" in params + assert "masked_pixel_percentage" in params + assert "strategy" in params + assert "struct_mask_axis" in params + assert "struct_mask_span" in params + + +def test_set_n2v_strategy(minimum_data: dict): + """Test that the N2V strategy can be set.""" + uniform = SupportedPixelManipulation.UNIFORM.value + median = SupportedPixelManipulation.MEDIAN.value + + data = N2VDataConfig(**minimum_data) + assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + assert data.transforms[-1].strategy == uniform + + data.set_masking_strategy(median) + assert data.transforms[-1].strategy == median + + data.set_masking_strategy(uniform) + assert data.transforms[-1].strategy == uniform + + +def test_set_n2v_strategy_wrong_value(minimum_data: dict): + """Test that passing a wrong strategy raises an error.""" + data = N2VDataConfig(**minimum_data) + with pytest.raises(ValueError): + data.set_masking_strategy("wrong_value") + + +def test_set_struct_mask(minimum_data: dict): + """Test that the struct mask can be set.""" + none = SupportedStructAxis.NONE.value + vertical = SupportedStructAxis.VERTICAL.value + horizontal = SupportedStructAxis.HORIZONTAL.value + + data = N2VDataConfig(**minimum_data) + assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 5 + + data.set_structN2V_mask(vertical, 3) + assert data.transforms[-1].struct_mask_axis == vertical + assert data.transforms[-1].struct_mask_span == 3 + + data.set_structN2V_mask(horizontal, 7) + assert data.transforms[-1].struct_mask_axis == horizontal + assert data.transforms[-1].struct_mask_span == 7 + + data.set_structN2V_mask(none, 11) + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 11 + + +def test_set_struct_mask_wrong_value(minimum_data: dict): + """Test that passing a wrong struct mask axis raises an error.""" + data = N2VDataConfig(**minimum_data) + with pytest.raises(ValueError): + data.set_structN2V_mask("wrong_value", 3) + + with pytest.raises(ValueError): + data.set_structN2V_mask(SupportedStructAxis.VERTICAL.value, 1) diff --git a/tests/config/test_configuration.py b/tests/config/test_configuration.py new file mode 100644 index 000000000..139ea216c --- /dev/null +++ b/tests/config/test_configuration.py @@ -0,0 +1,53 @@ +import pytest + +from careamics.config import Configuration + + +@pytest.mark.parametrize("name", ["Sn4K3", "C4_M e-L"]) +def test_valid_names(minimum_supervised_configuration: dict, name: str): + """Test valid names (letters, numbers, spaces, dashes and underscores).""" + minimum_supervised_configuration["experiment_name"] = name + myconf = Configuration(**minimum_supervised_configuration) + assert myconf.experiment_name == name + + +@pytest.mark.parametrize("name", ["", " ", "#", "/", "^", "%", ",", ".", "a=b"]) +def test_invalid_names(minimum_supervised_configuration: dict, name: str): + """Test that invalid names raise an error.""" + minimum_supervised_configuration["experiment_name"] = name + with pytest.raises(ValueError): + Configuration(**minimum_supervised_configuration) + + +def 
test_3D_algorithm_and_data_compatibility(minimum_supervised_configuration: dict): + """Test that `conv_dims` is corrected when algorithm and data axes are + incompatible. + """ + # 3D but no Z in axes + minimum_supervised_configuration["algorithm_config"]["model"]["conv_dims"] = 3 + config = Configuration(**minimum_supervised_configuration) + assert config.algorithm_config.model.conv_dims == 2 + + # 2D but Z in axes + minimum_supervised_configuration["algorithm_config"]["model"]["conv_dims"] = 2 + minimum_supervised_configuration["data_config"]["axes"] = "ZYX" + minimum_supervised_configuration["data_config"]["patch_size"] = [64, 64, 64] + config = Configuration(**minimum_supervised_configuration) + assert config.algorithm_config.model.conv_dims == 3 + + +def test_set_3D(minimum_supervised_configuration: dict): + """Test the set 3D method.""" + conf = Configuration(**minimum_supervised_configuration) + + # set to 3D + conf.set_3D(True, "ZYX", [64, 64, 64]) + assert conf.data_config.axes == "ZYX" + assert conf.data_config.patch_size == [64, 64, 64] + assert conf.algorithm_config.model.conv_dims == 3 + + # set to 2D + conf.set_3D(False, "SYX", [64, 64]) + assert conf.data_config.axes == "SYX" + assert conf.data_config.patch_size == [64, 64] + assert conf.algorithm_config.model.conv_dims == 2 diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factories.py similarity index 79% rename from tests/config/test_configuration_factory.py rename to tests/config/test_configuration_factories.py index e45ad85c7..3d7b637a5 100644 --- a/tests/config/test_configuration_factory.py +++ b/tests/config/test_configuration_factories.py @@ -1,17 +1,30 @@ import pytest from careamics.config import ( + CAREAlgorithm, + CAREConfiguration, + DataConfig, + N2NAlgorithm, + N2NConfiguration, + N2VAlgorithm, + N2VConfiguration, + algorithm_factory, create_care_configuration, create_n2n_configuration, create_n2v_configuration, ) -from careamics.config.configuration_factory import ( +from careamics.config.configuration_factories import ( + _algorithm_config_discriminator, _create_configuration, _create_supervised_configuration, _create_unet_configuration, - _list_augmentations, + _list_spatial_augmentations, + configuration_factory, + data_factory, ) +from careamics.config.data import N2VDataConfig from careamics.config.support import ( + SupportedAlgorithm, SupportedPixelManipulation, SupportedStructAxis, SupportedTransform, @@ -23,9 +36,82 @@ ) +def test_algorithm_discriminator_n2v(minimum_n2v_configuration): + """Test that the N2V configuration is discriminated correctly.""" + tag = _algorithm_config_discriminator(minimum_n2v_configuration) + assert tag == SupportedAlgorithm.N2V.value + + +@pytest.mark.parametrize( + "algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value] +) +def test_algorithm_discriminator_supervised( + minimum_supervised_configuration, algorithm +): + """Test that the supervised configuration is discriminated correctly.""" + minimum_supervised_configuration["algorithm_config"]["algorithm"] = algorithm + tag = _algorithm_config_discriminator(minimum_supervised_configuration) + assert tag == algorithm + + +def test_careamics_config_n2v(minimum_n2v_configuration): + """Test that the N2V configuration is created correctly.""" + configuration = configuration_factory(minimum_n2v_configuration) + assert isinstance(configuration, N2VConfiguration) + + +@pytest.mark.parametrize( + "algorithm", [SupportedAlgorithm.N2N.value, SupportedAlgorithm.CARE.value]
+)
+def test_careamics_config_supervised(minimum_supervised_configuration, algorithm):
+    """Test that the supervised configuration is created correctly."""
+    min_config = minimum_supervised_configuration
+    min_config["algorithm_config"]["algorithm"] = algorithm
+
+    config = configuration_factory(min_config)
+
+    exp_class = N2NConfiguration if algorithm == "n2n" else CAREConfiguration
+    assert isinstance(config, exp_class)
+
+
+def test_data_factory_n2v(minimum_data):
+    """Test that having N2VManipulate yields an N2VDataConfig."""
+    minimum_data["transforms"] = [
+        {
+            "name": SupportedTransform.N2V_MANIPULATE.value,
+        }
+    ]
+    data = data_factory(minimum_data)
+    assert isinstance(data, N2VDataConfig)
+
+
+def test_data_factory_supervised(minimum_data):
+    """Test that the normal configuration yields a DataConfig."""
+    data = data_factory(minimum_data)
+    assert isinstance(data, DataConfig)
+
+
+def test_algorithm_factory_n2v(minimum_algorithm_n2v):
+    """Test that the N2V algorithm configuration is created correctly."""
+    algorithm = algorithm_factory(minimum_algorithm_n2v)
+    assert isinstance(algorithm, N2VAlgorithm)
+
+
+@pytest.mark.parametrize("algorithm", ["n2n", "care"])
+def test_algorithm_factory_supervised(minimum_algorithm_supervised, algorithm):
+    """Test that the supervised algorithm configuration is created correctly."""
+    min_config = minimum_algorithm_supervised
+    min_config["algorithm"] = algorithm
+
+    algorithm_config = algorithm_factory(min_config)
+
+    exp_class = N2NAlgorithm if algorithm == "n2n" else CAREAlgorithm
+    assert isinstance(algorithm_config, exp_class)
+
+
 def test_list_aug_default():
     """Test that the default augmentations are present."""
-    list_aug = _list_augmentations(augmentations=None)
+    list_aug = _list_spatial_augmentations(augmentations=None)
 
     assert len(list_aug) == 2
     assert list_aug[0].name == SupportedTransform.XY_FLIP.value
@@ -34,14 +120,14 @@ def test_list_aug_default():
 
 def test_list_aug_no_aug():
     """Test that disabling augmentation results in empty transform list."""
-    list_aug = _list_augmentations(augmentations=[])
+    list_aug = _list_spatial_augmentations(augmentations=[])
     assert len(list_aug) == 0
 
 
 def test_list_aug_error_duplicate_transforms():
     """Test that an error is raised when there are duplicate transforms."""
     with pytest.raises(ValueError):
-        _list_augmentations(
+        _list_spatial_augmentations(
             augmentations=[XYFlipModel(), XYRandomRotate90Model(), XYFlipModel()],
         )
 
@@ -49,7 +135,7 @@ def test_list_aug_error_duplicate_transforms():
 def test_list_aug_error_wrong_transform():
     """Test that an error is raised when the wrong transform is passed."""
     with pytest.raises(ValueError):
-        _list_augmentations(
+        _list_spatial_augmentations(
             augmentations=[XYFlipModel(), N2VManipulateModel()],
         )
diff --git a/tests/config/test_configuration_io.py b/tests/config/test_configuration_io.py
new file mode 100644
index 000000000..8a63ff513
--- /dev/null
+++ b/tests/config/test_configuration_io.py
@@ -0,0 +1,43 @@
+from pathlib import Path
+
+import pytest
+
+from careamics.config import (
+    configuration_factory,
+    load_configuration,
+    save_configuration,
+)
+
+
+def test_config_to_yaml(tmp_path: Path, minimum_supervised_configuration: dict):
+    """Test that we can export a config to yaml and load it back."""
+
+    # test that we can instantiate a config
+    myconf = configuration_factory(minimum_supervised_configuration)
+
+    # export to yaml
+    yaml_path = save_configuration(myconf, tmp_path)
+    assert yaml_path.exists()
+
+    # load from yaml
+    my_other_conf = load_configuration(yaml_path)
+    assert my_other_conf == myconf
+
+
+def test_config_to_yaml_wrong_path(
+    tmp_path: Path, minimum_supervised_configuration: dict
+):
+    """Test that an error is raised when the path is neither a directory nor a
+    .yml file."""
+
+    # test that we can instantiate a config
+    myconf = configuration_factory(minimum_supervised_configuration)
+
+    # export to yaml
+    yaml_path = tmp_path / "tmp.txt"
+    with pytest.raises(ValueError):
+        save_configuration(myconf, yaml_path)
+
+    # existing file
+    yaml_path.touch()
+    with pytest.raises(ValueError):
+        save_configuration(myconf, yaml_path)
diff --git a/tests/config/test_configuration_model.py b/tests/config/test_configuration_model.py
deleted file mode 100644
index 0b0853243..000000000
--- a/tests/config/test_configuration_model.py
+++ /dev/null
@@ -1,191 +0,0 @@
-from pathlib import Path
-
-import pytest
-
-from careamics.config import (
-    Configuration,
-    load_configuration,
-    save_configuration,
-)
-from careamics.config.support import (
-    SupportedAlgorithm,
-    SupportedPixelManipulation,
-    SupportedTransform,
-)
-
-
-@pytest.mark.parametrize("name", ["Sn4K3", "C4_M e-L"])
-def test_valid_names(minimum_configuration: dict, name: str):
-    """Test valid names (letters, numbers, spaces, dashes and underscores)."""
-    minimum_configuration["experiment_name"] = name
-    myconf = Configuration(**minimum_configuration)
-    assert myconf.experiment_name == name
-
-
-@pytest.mark.parametrize("name", ["", " ", "#", "/", "^", "%", ",", ".", "a=b"])
-def test_invalid_names(minimum_configuration: dict, name: str):
-    """Test that invalid names raise an error."""
-    minimum_configuration["experiment_name"] = name
-    with pytest.raises(ValueError):
-        Configuration(**minimum_configuration)
-
-
-def test_3D_algorithm_and_data_compatibility(minimum_configuration: dict):
-    """Test that errors are raised if algorithm `is_3D` and data axes are
-    incompatible.
- """ - # 3D but no Z in axes - minimum_configuration["algorithm_config"]["model"]["conv_dims"] = 3 - config = Configuration(**minimum_configuration) - assert config.algorithm_config.model.conv_dims == 2 - - # 2D but Z in axes - minimum_configuration["algorithm_config"]["model"]["conv_dims"] = 2 - minimum_configuration["data_config"]["axes"] = "ZYX" - minimum_configuration["data_config"]["patch_size"] = [64, 64, 64] - config = Configuration(**minimum_configuration) - assert config.algorithm_config.model.conv_dims == 3 - - -def test_set_3D(minimum_configuration: dict): - """Test the set 3D method.""" - conf = Configuration(**minimum_configuration) - - # set to 3D - conf.set_3D(True, "ZYX", [64, 64, 64]) - assert conf.data_config.axes == "ZYX" - assert conf.data_config.patch_size == [64, 64, 64] - assert conf.algorithm_config.model.conv_dims == 3 - - # set to 2D - conf.set_3D(False, "SYX", [64, 64]) - assert conf.data_config.axes == "SYX" - assert conf.data_config.patch_size == [64, 64] - assert conf.algorithm_config.model.conv_dims == 2 - - -def test_algorithm_and_data_default_transforms(minimum_configuration: dict): - """Test that the default data transforms are compatible with n2v.""" - minimum_configuration["algorithm_config"] = { - "algorithm": "n2v", - "loss": "n2v", - "model": { - "architecture": "UNet", - }, - } - Configuration(**minimum_configuration) - - -@pytest.mark.parametrize( - "algorithm, strategy", - [ - ("n2v", SupportedPixelManipulation.UNIFORM.value), - ("n2v", SupportedPixelManipulation.MEDIAN.value), - ("n2v2", SupportedPixelManipulation.UNIFORM.value), - ("n2v2", SupportedPixelManipulation.MEDIAN.value), - ], -) -def test_n2v2_and_transforms(minimum_configuration: dict, algorithm, strategy): - """Test that the manipulation strategy is corrected if the data transforms are - incompatible with n2v2.""" - use_n2v2 = algorithm == "n2v2" - minimum_configuration["algorithm_config"] = { - "algorithm": "n2v", - "loss": "n2v", - "model": { - "architecture": "UNet", - "n2v2": use_n2v2, - }, - } - - expected_strategy = ( - SupportedPixelManipulation.MEDIAN.value - if use_n2v2 - else SupportedPixelManipulation.UNIFORM.value - ) - - # missing ManipulateN2V - minimum_configuration["data_config"]["transforms"] = [ - {"name": SupportedTransform.XY_FLIP.value} - ] - config = Configuration(**minimum_configuration) - assert len(config.data_config.transforms) == 2 - assert ( - config.data_config.transforms[-1].name - == SupportedTransform.N2V_MANIPULATE.value - ) - assert config.data_config.transforms[-1].strategy == expected_strategy - - # passing ManipulateN2V with the wrong strategy - minimum_configuration["data_config"]["transforms"] = [ - { - "name": SupportedTransform.N2V_MANIPULATE.value, - "strategy": strategy, - } - ] - config = Configuration(**minimum_configuration) - assert config.data_config.transforms[-1].strategy == expected_strategy - - -def test_setting_n2v2(minimum_configuration: dict): - # make sure we use n2v - minimum_configuration["algorithm_config"][ - "algorithm" - ] = SupportedAlgorithm.N2V.value - - # test config - config = Configuration(**minimum_configuration) - assert config.algorithm_config.algorithm == SupportedAlgorithm.N2V.value - assert not config.algorithm_config.model.n2v2 - assert ( - config.data_config.transforms[-1].strategy - == SupportedPixelManipulation.UNIFORM.value - ) - - # set N2V2 - config.set_N2V2(True) - assert config.algorithm_config.model.n2v2 - assert ( - config.data_config.transforms[-1].strategy - == 
SupportedPixelManipulation.MEDIAN.value - ) - - # set back to N2V - config.set_N2V2(False) - assert not config.algorithm_config.model.n2v2 - assert ( - config.data_config.transforms[-1].strategy - == SupportedPixelManipulation.UNIFORM.value - ) - - -def test_config_to_yaml(tmp_path: Path, minimum_configuration: dict): - """Test that we can export a config to yaml and load it back""" - - # test that we can instantiate a config - myconf = Configuration(**minimum_configuration) - - # export to yaml - yaml_path = save_configuration(myconf, tmp_path) - assert yaml_path.exists() - - # load from yaml - my_other_conf = load_configuration(yaml_path) - assert my_other_conf == myconf - - -def test_config_to_yaml_wrong_path(tmp_path: Path, minimum_configuration: dict): - """Test that an error is raised when the path is not a directory and not a .yml""" - - # test that we can instantiate a config - myconf = Configuration(**minimum_configuration) - - # export to yaml - yaml_path = tmp_path / "tmp.txt" - with pytest.raises(ValueError): - save_configuration(myconf, yaml_path) - - # existing file - yaml_path.touch() - with pytest.raises(ValueError): - save_configuration(myconf, yaml_path) diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py deleted file mode 100644 index f1e490222..000000000 --- a/tests/config/test_data_model.py +++ /dev/null @@ -1,355 +0,0 @@ -import numpy as np -import pytest -import yaml - -from careamics.config.data_model import DataConfig -from careamics.config.support import ( - SupportedPixelManipulation, - SupportedStructAxis, - SupportedTransform, -) -from careamics.config.transformations import ( - N2VManipulateModel, - NormalizeModel, - XYFlipModel, - XYRandomRotate90Model, -) -from careamics.transforms import get_all_transforms - - -@pytest.mark.parametrize("ext", ["nd2", "jpg", "png ", "zarr", "npy"]) -def test_wrong_extensions(minimum_data: dict, ext: str): - """Test that supported model raises ValueError for unsupported extensions.""" - minimum_data["data_type"] = ext - - # instantiate DataModel model - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - -@pytest.mark.parametrize("mean, std", [(0, 124.5), (12.6, 0.1)]) -def test_mean_std_non_negative(minimum_data: dict, mean, std): - """Test that non negative mean and std are accepted.""" - minimum_data["image_means"] = [mean] - minimum_data["image_stds"] = [std] - minimum_data["target_means"] = [mean] - minimum_data["target_stds"] = [std] - - data_model = DataConfig(**minimum_data) - assert data_model.image_means == [mean] - assert data_model.image_stds == [std] - assert data_model.target_means == [mean] - assert data_model.target_stds == [std] - - -def test_mean_std_both_specified_or_none(minimum_data: dict): - """Test an error is raised if std is specified but mean is None.""" - # No error if both are None - DataConfig(**minimum_data) - - # Error if only mean is defined - minimum_data["image_means"] = [10.4] - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - # Error if only std is defined - minimum_data.pop("image_means") - minimum_data["image_stds"] = [10.4] - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - # No error if both are specified - minimum_data["image_means"] = [10.4] - minimum_data["image_stds"] = [10.4] - DataConfig(**minimum_data) - - # Error if target mean is defined but target std is None - minimum_data["target_stds"] = [10.4, 11] - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - -def 
test_set_mean_and_std(minimum_data: dict): - """Test that mean and std can be set after initialization.""" - # they can be set both, when they None - mean = [4.07] - std = [14.07] - data = DataConfig(**minimum_data) - data.set_means_and_stds(mean, std) - assert data.image_means == mean - assert data.image_stds == std - - # Set also target mean and std - data.set_means_and_stds(mean, std, mean, std) - assert data.target_means == mean - assert data.target_stds == std - - -def test_normalize_not_accepted(minimum_data: dict): - """Test that normalize is not accepted, because it is mandatory and applied else - where.""" - minimum_data["image_means"] = [10.4] - minimum_data["image_stds"] = [3.2] - minimum_data["transforms"] = [ - NormalizeModel(image_means=[0.485], image_stds=[0.229]) - ] - - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - -def test_patch_size(minimum_data: dict): - """Test that non-zero even patch size are accepted.""" - # 2D - data_model = DataConfig(**minimum_data) - - # 3D - minimum_data["patch_size"] = [16, 8, 8] - minimum_data["axes"] = "ZYX" - - data_model = DataConfig(**minimum_data) - assert data_model.patch_size == minimum_data["patch_size"] - - -@pytest.mark.parametrize( - "patch_size", [[12], [0, 12, 12], [12, 12, 13], [16, 10, 16], [12, 12, 12, 12]] -) -def test_wrong_patch_size(minimum_data: dict, patch_size): - """Test that wrong patch sizes are not accepted (zero or odd, dims 1 or > 3).""" - minimum_data["axes"] = "ZYX" if len(patch_size) == 3 else "YX" - minimum_data["patch_size"] = patch_size - - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - -def test_set_3d(minimum_data: dict): - """Test that 3D can be set.""" - data = DataConfig(**minimum_data) - assert "Z" not in data.axes - assert len(data.patch_size) == 2 - - # error if changing Z manually - with pytest.raises(ValueError): - data.axes = "ZYX" - - # or patch size - data = DataConfig(**minimum_data) - with pytest.raises(ValueError): - data.patch_size = [64, 64, 64] - - # set 3D - data = DataConfig(**minimum_data) - data.set_3D("ZYX", [64, 64, 64]) - assert "Z" in data.axes - assert len(data.patch_size) == 3 - - -@pytest.mark.parametrize( - "transforms", - [ - [ - {"name": SupportedTransform.XY_FLIP.value}, - {"name": SupportedTransform.N2V_MANIPULATE.value}, - ], - [ - {"name": SupportedTransform.XY_FLIP.value}, - ], - [ - {"name": SupportedTransform.XY_FLIP.value}, - {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, - {"name": SupportedTransform.N2V_MANIPULATE.value}, - ], - ], -) -def test_passing_supported_transforms(minimum_data: dict, transforms): - """Test that list of supported transforms can be passed.""" - minimum_data["transforms"] = transforms - model = DataConfig(**minimum_data) - - supported = { - "XYFlip": XYFlipModel, - "XYRandomRotate90": XYRandomRotate90Model, - "N2VManipulate": N2VManipulateModel, - } - - for ind, t in enumerate(transforms): - assert t["name"] == model.transforms[ind].name - assert isinstance(model.transforms[ind], supported[t["name"]]) - - -@pytest.mark.parametrize( - "transforms", - [ - [ - {"name": SupportedTransform.N2V_MANIPULATE.value}, - {"name": SupportedTransform.XY_FLIP.value}, - ], - [ - {"name": SupportedTransform.N2V_MANIPULATE.value}, - ], - [ - {"name": SupportedTransform.XY_FLIP.value}, - {"name": SupportedTransform.N2V_MANIPULATE.value}, - {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, - ], - ], -) -def test_n2vmanipulate_last_transform(minimum_data: dict, transforms): - """Test that N2V Manipulate is moved 
to the last position if it is not.""" - minimum_data["transforms"] = transforms - model = DataConfig(**minimum_data) - assert model.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - - -def test_multiple_n2v_manipulate(minimum_data: dict): - """Test that passing multiple n2v manipulate raises an error.""" - minimum_data["transforms"] = [ - {"name": SupportedTransform.N2V_MANIPULATE.value}, - {"name": SupportedTransform.N2V_MANIPULATE.value}, - ] - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - -def test_remove_n2v_manipulate(minimum_data: dict): - """Test that N2V Manipulate can be removed.""" - minimum_data["transforms"] = [ - {"name": SupportedTransform.XY_FLIP.value}, - {"name": SupportedTransform.N2V_MANIPULATE.value}, - ] - model = DataConfig(**minimum_data) - model.remove_n2v_manipulate() - assert len(model.transforms) == 1 - assert model.transforms[-1].name == SupportedTransform.XY_FLIP.value - - -def test_add_n2v_manipulate(minimum_data: dict): - """Test that N2V Manipulate can be added.""" - minimum_data["transforms"] = [ - {"name": SupportedTransform.XY_FLIP.value}, - ] - model = DataConfig(**minimum_data) - model.add_n2v_manipulate() - assert len(model.transforms) == 2 - assert model.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - - # test that adding twice doesn't change anything - model.add_n2v_manipulate() - assert len(model.transforms) == 2 - assert model.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - - -def test_correct_transform_parameters(minimum_data: dict): - """Test that the transforms have the correct parameters. - - This is important to know that the transforms are not all instantiated as - a generic transform. - """ - minimum_data["transforms"] = [ - {"name": SupportedTransform.XY_FLIP.value}, - {"name": SupportedTransform.XY_RANDOM_ROTATE90.value}, - {"name": SupportedTransform.N2V_MANIPULATE.value}, - ] - model = DataConfig(**minimum_data) - - # N2VManipulate - params = model.transforms[-1].model_dump() - assert "roi_size" in params - assert "masked_pixel_percentage" in params - assert "strategy" in params - assert "struct_mask_axis" in params - assert "struct_mask_span" in params - - -def test_passing_empty_transforms(minimum_data: dict): - """Test that empty list of transforms can be passed.""" - minimum_data["transforms"] = [] - DataConfig(**minimum_data) - - -def test_passing_incorrect_element(minimum_data: dict): - """Test that incorrect element in the list of transforms raises an error ( - e.g. 
passing un object rather than a string).""" - minimum_data["transforms"] = [ - {"name": get_all_transforms()[SupportedTransform.XY_FLIP.value]()}, - ] - with pytest.raises(ValueError): - DataConfig(**minimum_data) - - -def test_set_n2v_strategy(minimum_data: dict): - """Test that the N2V strategy can be set.""" - uniform = SupportedPixelManipulation.UNIFORM.value - median = SupportedPixelManipulation.MEDIAN.value - - data = DataConfig(**minimum_data) - assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - assert data.transforms[-1].strategy == uniform - - data.set_N2V2_strategy(median) - assert data.transforms[-1].strategy == median - - data.set_N2V2_strategy(uniform) - assert data.transforms[-1].strategy == uniform - - -def test_set_n2v_strategy_wrong_value(minimum_data: dict): - """Test that passing a wrong strategy raises an error.""" - data = DataConfig(**minimum_data) - with pytest.raises(ValueError): - data.set_N2V2_strategy("wrong_value") - - -def test_set_struct_mask(minimum_data: dict): - """Test that the struct mask can be set.""" - none = SupportedStructAxis.NONE.value - vertical = SupportedStructAxis.VERTICAL.value - horizontal = SupportedStructAxis.HORIZONTAL.value - - data = DataConfig(**minimum_data) - assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - assert data.transforms[-1].struct_mask_axis == none - assert data.transforms[-1].struct_mask_span == 5 - - data.set_structN2V_mask(vertical, 3) - assert data.transforms[-1].struct_mask_axis == vertical - assert data.transforms[-1].struct_mask_span == 3 - - data.set_structN2V_mask(horizontal, 7) - assert data.transforms[-1].struct_mask_axis == horizontal - assert data.transforms[-1].struct_mask_span == 7 - - data.set_structN2V_mask(none, 11) - assert data.transforms[-1].struct_mask_axis == none - assert data.transforms[-1].struct_mask_span == 11 - - -def test_set_struct_mask_wrong_value(minimum_data: dict): - """Test that passing a wrong struct mask axis raises an error.""" - data = DataConfig(**minimum_data) - with pytest.raises(ValueError): - data.set_structN2V_mask("wrong_value", 3) - - with pytest.raises(ValueError): - data.set_structN2V_mask(SupportedStructAxis.VERTICAL.value, 1) - - -def test_export_to_yaml_float32_stats(tmp_path, minimum_data: dict): - """Test exporting and loading the pydantic model when the statistics are - np.float32.""" - data = DataConfig(**minimum_data) - - # set np.float32 stats values - data.set_means_and_stds([np.float32(1234.5678)], [np.float32(21.73)]) - - # export to yaml - config_path = tmp_path / "data_config.yaml" - with open(config_path, "w") as f: - # dump configuration - yaml.dump(data.model_dump(), f, default_flow_style=False, sort_keys=False) - - # load model - dictionary = yaml.load(config_path.open("r"), Loader=yaml.SafeLoader) - read_data = DataConfig(**dictionary) - assert read_data.model_dump() == data.model_dump() diff --git a/tests/config/test_fcn_algorithm_model.py b/tests/config/test_fcn_algorithm_model.py deleted file mode 100644 index cc3efdda2..000000000 --- a/tests/config/test_fcn_algorithm_model.py +++ /dev/null @@ -1,117 +0,0 @@ -import pytest - -from careamics.config.fcn_algorithm_model import FCNAlgorithmConfig -from careamics.config.support import ( - SupportedAlgorithm, - SupportedArchitecture, - SupportedLoss, -) - - -def test_all_algorithms_are_supported(): - """Test that all algorithms defined in the Literal are supported.""" - # list of supported algorithms - algorithms = list(SupportedAlgorithm) - - # Algorithm 
json schema to extract the literal value - schema = FCNAlgorithmConfig.model_json_schema() - - # check that all algorithms are supported - for algo in schema["properties"]["algorithm"]["enum"]: - assert algo in algorithms - - -# TODO: this should not support musplit and denoisplit losses -def test_all_losses_are_supported(): - """Test that all losses defined in the Literal are supported.""" - # list of supported losses - losses = list(SupportedLoss) - - # Algorithm json schema - schema = FCNAlgorithmConfig.model_json_schema() - - # check that all losses are supported - for loss in schema["properties"]["loss"]["enum"]: - assert loss in losses - - -def test_model_discriminator(minimum_algorithm_n2v): - """Test that discriminator permits correct assignment.""" - for model_name in SupportedArchitecture: - # TODO change once VAE are implemented - if model_name.value == "UNet": - minimum_algorithm_n2v["model"]["architecture"] = model_name.value - - algo = FCNAlgorithmConfig(**minimum_algorithm_n2v) - assert algo.model.architecture == model_name.value - - -@pytest.mark.parametrize( - "algorithm, loss, model", - [ - ("n2v", "n2v", {"architecture": "UNet", "n2v2": False}), - ("n2n", "mae", {"architecture": "UNet", "n2v2": False}), - ("care", "mae", {"architecture": "UNet", "n2v2": False}), - ], -) -def test_algorithm_constraints(algorithm: str, loss: str, model: dict): - """Test that constraints are passed for each algorithm.""" - FCNAlgorithmConfig(algorithm=algorithm, loss=loss, model=model) - - -def test_n_channels_n2v(): - """Check that an error is raised if n2v has different number of channels in - input and output.""" - model = { - "architecture": "UNet", - "in_channels": 1, - "num_classes": 2, - "n2v2": False, - } - loss = "n2v" - - with pytest.raises(ValueError): - FCNAlgorithmConfig(algorithm="n2v", loss=loss, model=model) - - -@pytest.mark.parametrize( - "algorithm, n_in, n_out", - [ - ("n2v", 2, 2), - ("n2n", 3, 3), - ("care", 1, 2), - ], -) -def test_comaptiblity_of_number_of_channels(algorithm, n_in, n_out): - """Check that no error is thrown when instantiating the algorithm with a valid - number of in and out channels.""" - model = { - "architecture": "UNet", - "in_channels": n_in, - "num_classes": n_out, - "n2v2": False, - } - loss = "n2v" if algorithm == "n2v" else "mae" - - FCNAlgorithmConfig(algorithm=algorithm, loss=loss, model=model) - - -def test_custom_model(custom_model_parameters): - """Test that a custom model can be instantiated.""" - # create algorithm configuration - FCNAlgorithmConfig( - algorithm=SupportedAlgorithm.CUSTOM.value, - loss="mse", - model=custom_model_parameters, - ) - - -def test_custom_model_wrong_algorithm(custom_model_parameters): - """Test that a custom model fails if the algorithm is not custom.""" - # create algorithm configuration - with pytest.raises(ValueError): - FCNAlgorithmConfig( - algorithm=SupportedAlgorithm.CARE.value, - loss="mse", - model=custom_model_parameters, - ) diff --git a/tests/config/test_n2v_configuration.py b/tests/config/test_n2v_configuration.py new file mode 100644 index 000000000..6d45baff8 --- /dev/null +++ b/tests/config/test_n2v_configuration.py @@ -0,0 +1,101 @@ +import pytest + +from careamics.config import ( + N2VConfiguration, +) +from careamics.config.support import ( + SupportedAlgorithm, + SupportedPixelManipulation, +) + + +@pytest.mark.parametrize( + "algorithm, strategy", + [ + ("n2v", SupportedPixelManipulation.UNIFORM.value), + ("n2v2", SupportedPixelManipulation.MEDIAN.value), + ], +) +def 
test_correct_n2v2_and_transforms( + minimum_n2v_configuration: dict, algorithm, strategy +): + """Test that N2V and N2V2 are correctly instantiated.""" + minimum_n2v_configuration["algorithm_config"] = { + "algorithm": "n2v", + "loss": "n2v", + "model": { + "architecture": "UNet", + "n2v2": algorithm == "n2v2", + }, + } + minimum_n2v_configuration["data_config"]["transforms"] = [ + { + "name": "N2VManipulate", + "strategy": strategy, + } + ] + + N2VConfiguration(**minimum_n2v_configuration) + + +@pytest.mark.parametrize( + "algorithm, strategy", + [ + ("n2v", SupportedPixelManipulation.MEDIAN.value), + ("n2v2", SupportedPixelManipulation.UNIFORM.value), + ], +) +def test_wrong_n2v2_and_transforms( + minimum_n2v_configuration: dict, algorithm, strategy +): + """Test that N2V and N2V2 throw an error if the strategy and the N2V2 UNet + parameters disagree.""" + minimum_n2v_configuration["algorithm_config"] = { + "algorithm": "n2v", + "loss": "n2v", + "model": { + "architecture": "UNet", + "n2v2": algorithm == "n2v2", + }, + } + minimum_n2v_configuration["data_config"]["transforms"] = [ + { + "name": "N2VManipulate", + "strategy": strategy, + } + ] + + with pytest.raises(ValueError): + N2VConfiguration(**minimum_n2v_configuration) + + +def test_setting_n2v2(minimum_n2v_configuration: dict): + # make sure we use n2v + minimum_n2v_configuration["algorithm_config"][ + "algorithm" + ] = SupportedAlgorithm.N2V.value + + # test config + config = N2VConfiguration(**minimum_n2v_configuration) + assert config.algorithm_config.algorithm == SupportedAlgorithm.N2V.value + assert not config.algorithm_config.model.n2v2 + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.UNIFORM.value + ) + + # set N2V2 + config.set_n2v2(True) + assert config.algorithm_config.model.n2v2 + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.MEDIAN.value + ) + + # set back to N2V + config.set_n2v2(False) + assert not config.algorithm_config.model.n2v2 + assert ( + config.data_config.transforms[-1].strategy + == SupportedPixelManipulation.UNIFORM.value + ) diff --git a/tests/config/test_unet_algorithm_model.py b/tests/config/test_unet_algorithm_model.py new file mode 100644 index 000000000..e8ee9a0c4 --- /dev/null +++ b/tests/config/test_unet_algorithm_model.py @@ -0,0 +1,32 @@ +from careamics.config import UNetBasedAlgorithm +from careamics.config.support import ( + SupportedAlgorithm, + SupportedLoss, +) + + +def test_all_algorithms_are_supported(): + """Test that all algorithms defined in the Literal are supported.""" + # list of supported algorithms + algorithms = list(SupportedAlgorithm) + + # Algorithm json schema to extract the literal value + schema = UNetBasedAlgorithm.model_json_schema() + + # check that all algorithms are supported + for algo in schema["properties"]["algorithm"]["enum"]: + assert algo in algorithms + + +# TODO: this should not support musplit and denoisplit losses +def test_all_losses_are_supported(): + """Test that all losses defined in the Literal are supported.""" + # list of supported losses + losses = list(SupportedLoss) + + # Algorithm json schema + schema = UNetBasedAlgorithm.model_json_schema() + + # check that all losses are supported + for loss in schema["properties"]["loss"]["enum"]: + assert loss in losses diff --git a/tests/config/test_vae_algorithm_model.py b/tests/config/test_vae_algorithm_model.py index 554b42dfc..976ec82b9 100644 --- a/tests/config/test_vae_algorithm_model.py +++ 
b/tests/config/test_vae_algorithm_model.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 
-from careamics.config import VAEAlgorithmConfig
+from careamics.config import VAEBasedAlgorithm
 from careamics.config.architectures import LVAEModel
 from careamics.config.nm_model import (
     GaussianMixtureNMConfig,
@@ -25,7 +25,7 @@ def test_all_losses_are_supported():
     losses = list(SupportedLoss)
 
     # Algorithm json schema
-    schema = VAEAlgorithmConfig.model_json_schema()
+    schema = VAEBasedAlgorithm.model_json_schema()
 
     # check that all losses are supported
     for loss in schema["properties"]["loss"]["enum"]:
@@ -35,7 +35,7 @@ def test_all_losses_are_supported():
 @pytest.mark.skip("Needs to be updated!")
 def test_noise_model_usplit(minimum_algorithm_musplit):
     """Test that the noise model is correctly provided."""
-    config = VAEAlgorithmConfig(**minimum_algorithm_musplit)
+    config = VAEBasedAlgorithm(**minimum_algorithm_musplit)
 
     assert config.noise_model is None
 
@@ -53,8 +53,7 @@ def test_noise_model_denoisplit(tmp_path: Path, create_dummy_noise_model):
         path=tmp_path / "dummy_noise_model.npz",
         # all other params are default
     )
-    config = VAEAlgorithmConfig(
-        algorithm_type="vae",
+    config = VAEBasedAlgorithm(
         algorithm="denoisplit",
         loss="denoisplit",
         model=LVAEModel(architecture="LVAE"),
@@ -67,4 +66,4 @@ def test_no_noise_model_error_denoisplit(minimum_algorithm_denoisplit):
     """Test that the noise model is correctly provided."""
     minimum_algorithm_denoisplit["noise_model"] = None
     with pytest.raises(ValueError):
-        VAEAlgorithmConfig(**minimum_algorithm_denoisplit)
+        VAEBasedAlgorithm(**minimum_algorithm_denoisplit)
diff --git a/tests/config/validators/test_model_validators.py b/tests/config/validators/test_model_validators.py
new file mode 100644
index 000000000..5589d7ba6
--- /dev/null
+++ b/tests/config/validators/test_model_validators.py
@@ -0,0 +1,52 @@
+import pytest
+
+from careamics.config.architectures import UNetModel
+from careamics.config.validators import (
+    model_matching_in_out_channels,
+    model_without_final_activation,
+    model_without_n2v2,
+)
+
+
+def test_model_without_n2v2():
+    """Test the validator ensuring that the model does not have `n2v2` enabled."""
+    model = UNetModel(architecture="UNet", final_activation="None", n2v2=False)
+    assert model_without_n2v2(model) == model
+
+    model = UNetModel(architecture="UNet", final_activation="None", n2v2=True)
+    with pytest.raises(ValueError):
+        model_without_n2v2(model)
+
+
+def test_model_without_final_activation():
+    """Test the validator ensuring that the model has no final activation."""
+    model = UNetModel(
+        architecture="UNet",
+        final_activation="None",
+    )
+    assert model_without_final_activation(model) == model
+
+    model = UNetModel(
+        architecture="UNet",
+        final_activation="ReLU",
+    )
+    with pytest.raises(ValueError):
+        model_without_final_activation(model)
+
+
+def test_model_matching_in_out_channels():
+    """Test the validator ensuring that the model input and output channels match."""
+    model = UNetModel(
+        architecture="UNet",
+        in_channels=1,
+        num_classes=1,
+    )
+    assert model_matching_in_out_channels(model) == model
+
+    model = UNetModel(
+        architecture="UNet",
+        in_channels=1,
+        num_classes=2,
+    )
+    with pytest.raises(ValueError):
+        model_matching_in_out_channels(model)
diff --git a/tests/conftest.py b/tests/conftest.py
index ab5189fc1..8a640d89b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,54 +3,12 @@
 import numpy as np
 import pytest
-from torch import nn, ones
 
-from careamics import CAREamist, Configuration
-from
careamics.config import register_model -from careamics.config.support import SupportedArchitecture, SupportedData +from careamics import CAREamist +from careamics.config import configuration_factory +from careamics.config.support import SupportedData from careamics.model_io import export_to_bmz -###################################### -## fixture for custom model testing ## -## -## - config/algorithm_model -## - config/architectures/custom_model -## - models/model_factory -## -###################################### - - -@register_model(name="linear") -class LinearModel(nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter(ones(in_features, out_features)) - self.bias = nn.Parameter(ones(out_features)) - - def forward(self, input): - return (input @ self.weight) + self.bias - - -@pytest.fixture -def custom_model_name() -> str: - return "linear" - - -@pytest.fixture -def custom_model_parameters(custom_model_name) -> dict: - return { - "architecture": SupportedArchitecture.CUSTOM.value, - "name": custom_model_name, - "in_features": 10, - "out_features": 5, - } - - -###################################### - @pytest.fixture def gaussian_likelihood_params(): @@ -186,6 +144,26 @@ def minimum_data() -> dict: return data +@pytest.fixture +def minimum_data_n2v() -> dict: + """Create a minimum N2V data dictionary. + + Returns + ------- + dict + A minimum data example. + """ + # create dictionary + data = { + "data_type": SupportedData.ARRAY.value, + "patch_size": [8, 8], + "axes": "YX", + "transforms": [{"name": "N2VManipulate"}], + } + + return data + + @pytest.fixture def minimum_inference() -> dict: """Create a minimum inference dictionary. @@ -224,8 +202,8 @@ def minimum_training() -> dict: @pytest.fixture -def minimum_configuration( - minimum_algorithm_n2v: dict, minimum_data: dict, minimum_training: dict +def minimum_n2v_configuration( + minimum_algorithm_n2v: dict, minimum_data_n2v: dict, minimum_training: dict ) -> dict: """Create a minimum configuration dictionary. @@ -235,8 +213,8 @@ def minimum_configuration( Temporary path for testing. minimum_algorithm : dict Minimum algorithm configuration. - minimum_data : dict - Minimum data configuration. + minimum_data_n2v : dict + Minimum N2V data configuration. minimum_training : dict Minimum training configuration. 
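# Aside: the `minimum_data_n2v` fixture above differs from `minimum_data` only by
# carrying the N2VManipulate transform, which is the tag that `data_factory`
# discriminates on. A minimal usage sketch (values copied from the fixture;
# illustrative only, not the factory's implementation):
from careamics.config import data_factory
from careamics.config.data import N2VDataConfig

# a data dictionary containing N2VManipulate is routed to the N2V data model
n2v_data = data_factory(
    {
        "data_type": "array",
        "patch_size": [8, 8],
        "axes": "YX",
        "transforms": [{"name": "N2VManipulate"}],
    }
)
assert isinstance(n2v_data, N2VDataConfig)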
@@ -250,14 +228,14 @@ def minimum_configuration( "experiment_name": "LevitatingFrog", "algorithm_config": minimum_algorithm_n2v, "training_config": minimum_training, - "data_config": minimum_data, + "data_config": minimum_data_n2v, } return configuration @pytest.fixture -def supervised_configuration( +def minimum_supervised_configuration( minimum_algorithm_supervised: dict, minimum_data: dict, minimum_training: dict ) -> dict: configuration = { @@ -327,13 +305,13 @@ def overlaps() -> tuple[int, int]: @pytest.fixture -def pre_trained(tmp_path, minimum_configuration): +def pre_trained(tmp_path, minimum_n2v_configuration): """Fixture to create a pre-trained CAREamics model.""" # training data train_array = np.arange(32 * 32).reshape((32, 32)).astype(np.float32) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 diff --git a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py index df6a71dce..7f41e3759 100644 --- a/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py +++ b/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py @@ -12,7 +12,7 @@ from pytorch_lightning.callbacks import ModelCheckpoint from torch.utils.data import DataLoader -from careamics.config import Configuration +from careamics.config import configuration_factory from careamics.config.support import SupportedData from careamics.dataset import IterablePredDataset from careamics.lightning import ( @@ -51,7 +51,7 @@ def prediction_writer_callback( # TODO: smoke test with tiff (& example custom save func?) 
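# Aside: the tests below, like the `pre_trained` fixture above, now build their
# configuration through `configuration_factory` instead of instantiating
# `Configuration` directly. The factory dispatches on the algorithm tag of the
# raw dictionary (the real implementation relies on a pydantic discriminated
# union driven by `_algorithm_config_discriminator`); a hedged sketch of the
# idea, with `sketch_configuration_factory` being a hypothetical name:
from careamics.config import (
    CAREConfiguration,
    N2NConfiguration,
    N2VConfiguration,
)


def sketch_configuration_factory(config: dict):
    # route the raw dictionary to the matching pydantic configuration model
    algorithm = config["algorithm_config"]["algorithm"]
    if algorithm == "n2v":
        return N2VConfiguration(**config)
    if algorithm == "n2n":
        return N2NConfiguration(**config)
    if algorithm == "care":
        return CAREConfiguration(**config)
    raise ValueError(f"Unsupported algorithm: {algorithm}")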
-def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): +def test_smoke_n2v_tiled_tiff(tmp_path, minimum_n2v_configuration): rng = np.random.default_rng(42) # training data @@ -65,7 +65,7 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): train_file = train_dir / file_name tifffile.imwrite(train_file, train_array) - cfg = Configuration(**minimum_configuration) + cfg = configuration_factory(minimum_n2v_configuration) # create lightning module model = create_careamics_module( @@ -129,7 +129,7 @@ def test_smoke_n2v_tiled_tiff(tmp_path, minimum_configuration): np.testing.assert_array_equal(save_data, predicted_images[0][0], verbose=True) -def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): +def test_smoke_n2v_untiled_tiff(tmp_path, minimum_n2v_configuration): rng = np.random.default_rng(42) # training data @@ -143,7 +143,7 @@ def test_smoke_n2v_untiled_tiff(tmp_path, minimum_configuration): train_file = train_dir / file_name tifffile.imwrite(train_file, train_array) - cfg = Configuration(**minimum_configuration) + cfg = configuration_factory(minimum_n2v_configuration) # create lightning module model = create_careamics_module( diff --git a/tests/lightning/test_LVAE_lightning_module.py b/tests/lightning/test_LVAE_lightning_module.py index 3ccd6bc62..19409a7da 100644 --- a/tests/lightning/test_LVAE_lightning_module.py +++ b/tests/lightning/test_LVAE_lightning_module.py @@ -9,7 +9,7 @@ from pytorch_lightning import Trainer from torch.utils.data import DataLoader, Dataset -from careamics.config import VAEAlgorithmConfig +from careamics.config import VAEBasedAlgorithm from careamics.config.architectures import LVAEModel from careamics.config.likelihood_model import ( GaussianLikelihoodConfig, @@ -87,7 +87,7 @@ def create_split_lightning_model( noise_model_config = None nm_lik_config = None - vae_config = VAEAlgorithmConfig( + vae_config = VAEBasedAlgorithm( algorithm=algorithm, loss=loss_config, model=lvae_config, @@ -167,7 +167,7 @@ def test_musplit_lightining_init( ) with exp_error: - vae_config = VAEAlgorithmConfig( + vae_config = VAEBasedAlgorithm( algorithm="musplit", loss=LVAELossConfig(loss_type=loss_type), model=lvae_config, @@ -243,7 +243,7 @@ def test_denoisplit_lightining_init( nm_lik_config = NMLikelihoodConfig() with exp_error: - vae_config = VAEAlgorithmConfig( + vae_config = VAEBasedAlgorithm( algorithm="denoisplit", loss=LVAELossConfig(loss_type=loss_type), model=lvae_config, diff --git a/tests/lightning/test_lightning_api.py b/tests/lightning/test_lightning_api.py index db32e3eb5..6f5e31938 100644 --- a/tests/lightning/test_lightning_api.py +++ b/tests/lightning/test_lightning_api.py @@ -6,7 +6,7 @@ ModelCheckpoint, ) -from careamics import Configuration +from careamics.config import configuration_factory from careamics.lightning import ( create_careamics_module, create_predict_datamodule, @@ -15,7 +15,7 @@ from careamics.prediction_utils import convert_outputs -def test_smoke_n2v_2d_array(tmp_path, minimum_configuration): +def test_smoke_n2v_2d_array(tmp_path, minimum_n2v_configuration): """Test a full run of N2V training with the lightning API.""" rng = np.random.default_rng(42) @@ -23,7 +23,7 @@ def test_smoke_n2v_2d_array(tmp_path, minimum_configuration): train_array = rng.integers(0, 255, (32, 32)).astype(np.float32) val_array = rng.integers(0, 255, (32, 32)).astype(np.float32) - cfg = Configuration(**minimum_configuration) + cfg = configuration_factory(minimum_n2v_configuration) # create lightning module model = create_careamics_module( 
@@ -73,14 +73,14 @@ def test_smoke_n2v_2d_array(tmp_path, minimum_configuration): assert predicted[0].squeeze().shape == val_array.shape -def test_smoke_n2v_2d_tiling(tmp_path, minimum_configuration): +def test_smoke_n2v_2d_tiling(tmp_path, minimum_n2v_configuration): """Test a full run of N2V training with the lightning API and tiled prediction.""" # training data rng = np.random.default_rng(42) train_array = rng.integers(0, 255, (32, 32)).astype(np.float32) val_array = rng.integers(0, 255, (32, 32)).astype(np.float32) - cfg = Configuration(**minimum_configuration) + cfg = configuration_factory(minimum_n2v_configuration) # create lightning module model = create_careamics_module( diff --git a/tests/lightning/test_lightning_module.py b/tests/lightning/test_lightning_module.py index fc8094875..fb97506cf 100644 --- a/tests/lightning/test_lightning_module.py +++ b/tests/lightning/test_lightning_module.py @@ -1,7 +1,7 @@ import pytest import torch -from careamics.config import FCNAlgorithmConfig +from careamics.config import UNetBasedAlgorithm from careamics.lightning.lightning_module import ( FCNModule, create_careamics_module, @@ -13,7 +13,7 @@ def test_careamics_module(minimum_algorithm_n2v): """Test that the minimum algorithm allows instantiating a the Lightning API intermediate layer.""" - algo_config = FCNAlgorithmConfig(**minimum_algorithm_n2v) + algo_config = UNetBasedAlgorithm(**minimum_algorithm_n2v) # extract model parameters model_parameters = algo_config.model.model_dump(exclude_none=True) @@ -33,7 +33,7 @@ def test_careamics_module(minimum_algorithm_n2v): def test_careamics_fcn(minimum_algorithm_n2v): """Test that the minimum algorithm allows instantiating a CAREamicsKiln.""" - algo_config = FCNAlgorithmConfig(**minimum_algorithm_n2v) + algo_config = UNetBasedAlgorithm(**minimum_algorithm_n2v) # instantiate CAREamicsKiln FCNModule(algo_config) @@ -59,7 +59,7 @@ def test_fcn_module_unet_2D_depth_2_shape(shape): }, "loss": "mae", } - algo_config = FCNAlgorithmConfig(**algo_dict) + algo_config = UNetBasedAlgorithm(**algo_dict) # instantiate CAREamicsKiln model = FCNModule(algo_config) @@ -94,7 +94,7 @@ def test_fcn_module_unet_2D_depth_3_shape(shape): }, "loss": "mae", } - algo_config = FCNAlgorithmConfig(**algo_dict) + algo_config = UNetBasedAlgorithm(**algo_dict) # instantiate CAREamicsKiln model = FCNModule(algo_config) @@ -127,7 +127,7 @@ def test_fcn_module_unet_depth_2_3D(shape): }, "loss": "mae", } - algo_config = FCNAlgorithmConfig(**algo_dict) + algo_config = UNetBasedAlgorithm(**algo_dict) # instantiate CAREamicsKiln model = FCNModule(algo_config) @@ -160,7 +160,7 @@ def test_fcn_module_unet_depth_3_3D(shape): }, "loss": "mae", } - algo_config = FCNAlgorithmConfig(**algo_dict) + algo_config = UNetBasedAlgorithm(**algo_dict) # instantiate CAREamicsKiln model = FCNModule(algo_config) @@ -194,7 +194,7 @@ def test_fcn_module_unet_depth_3_3D_n2v2(shape): }, "loss": "n2v", } - algo_config = FCNAlgorithmConfig(**algo_dict) + algo_config = UNetBasedAlgorithm(**algo_dict) # instantiate CAREamicsKiln model = FCNModule(algo_config) @@ -219,7 +219,7 @@ def test_fcn_module_unet_depth_2_channels_2D(n_channels): }, "loss": "mae", } - algo_config = FCNAlgorithmConfig(**algo_dict) + algo_config = UNetBasedAlgorithm(**algo_dict) # instantiate CAREamicsKiln model = FCNModule(algo_config) @@ -248,7 +248,7 @@ def test_fcn_module_unet_depth_3_channels_2D(n_channels, independent_channels): }, "loss": "mae", } - algo_config = FCNAlgorithmConfig(**algo_dict) + algo_config = 
UNetBasedAlgorithm(**algo_dict)
 
     # instantiate CAREamicsKiln
     model = FCNModule(algo_config)
@@ -273,7 +273,7 @@ def test_fcn_module_unet_depth_2_channels_3D(n_channels):
         },
         "loss": "mae",
     }
-    algo_config = FCNAlgorithmConfig(**algo_dict)
+    algo_config = UNetBasedAlgorithm(**algo_dict)
 
     # instantiate CAREamicsKiln
     model = FCNModule(algo_config)
@@ -298,7 +298,7 @@ def test_fcn_module_unet_depth_3_channels_3D(n_channels):
         },
         "loss": "mae",
     }
-    algo_config = FCNAlgorithmConfig(**algo_dict)
+    algo_config = UNetBasedAlgorithm(**algo_dict)
 
     # instantiate CAREamicsKiln
     model = FCNModule(algo_config)
@@ -311,15 +311,16 @@ def test_fcn_module_unet_depth_3_channels_3D(n_channels):
 
 
 @pytest.mark.parametrize("tiled", [False, True])
-def test_prediction_callback_during_training(minimum_configuration, tiled):
+def test_prediction_callback_during_training(minimum_n2v_configuration, tiled):
     import numpy as np
     from pytorch_lightning import Callback, Trainer
 
-    from careamics import CAREamist, Configuration
+    from careamics import CAREamist
+    from careamics.config import configuration_factory
     from careamics.lightning import PredictDataModule, create_predict_datamodule
     from careamics.prediction_utils import convert_outputs
 
-    config = Configuration(**minimum_configuration)
+    config = configuration_factory(minimum_n2v_configuration)
 
     class CustomPredictAfterValidationCallback(Callback):
         def __init__(self, pred_datamodule: PredictDataModule):
diff --git a/tests/lightning/test_lightning_module_onnx_exportability.py b/tests/lightning/test_lightning_module_onnx_exportability.py
index b68679d95..f4e8c0d7f 100644
--- a/tests/lightning/test_lightning_module_onnx_exportability.py
+++ b/tests/lightning/test_lightning_module_onnx_exportability.py
@@ -2,10 +2,12 @@
 import torch
 from onnx import checker
 
-from careamics.config import FCNAlgorithmConfig
+from careamics.config import UNetBasedAlgorithm
 from careamics.lightning.lightning_module import FCNModule
 
 
+# TODO: move to a module for special integration tests
+# TODO revisit for specific algorithm configuration
 @pytest.mark.parametrize(
     "algorithm, architecture, conv_dim, n2v2, loss, shape",
     [
@@ -32,7 +34,7 @@ def test_onnx_export(tmp_path, algorithm, architecture, conv_dim, n2v2, loss, sh
         },
         "loss": loss,
     }
-    algo_config = FCNAlgorithmConfig(**algo_config)
+    algo_config = UNetBasedAlgorithm(**algo_config)
 
     # instantiate CAREamicsKiln
     model = FCNModule(algo_config)
diff --git a/tests/lightning/test_train_data_module.py b/tests/lightning/test_train_data_module.py
index b1ba457ed..35fee190f 100644
--- a/tests/lightning/test_train_data_module.py
+++ b/tests/lightning/test_train_data_module.py
@@ -7,7 +7,9 @@
     SupportedData,
     SupportedPixelManipulation,
     SupportedStructAxis,
+    SupportedTransform,
 )
+from careamics.config.transformations import N2VManipulateModel, XYFlipModel
 from careamics.dataset import InMemoryDataset, PathIterableDataset
 from careamics.lightning import TrainDataModule, create_train_datamodule
 
@@ -40,7 +42,7 @@ def test_wrapper_unknown_type(simple_array):
         create_train_datamodule(
             train_data=simple_array,
             data_type="wrong_type",
-            patch_size=(10, 10),
+            patch_size=(8, 8),
             axes="YX",
             batch_size=2,
         )
@@ -63,6 +65,21 @@ def test_wrapper_train_array(simple_array):
     assert len(list(data_module.train_dataloader())) > 0
 
 
+def test_wrapper_supervised(simple_array):
+    """Test that a supervised data config, without N2VManipulate, is created."""
+    data_module = create_train_datamodule(
+        train_data=simple_array,
+        data_type="array",
+        patch_size=(8, 8),
+        axes="YX",
+        batch_size=2,
+        train_target_data=simple_array,
+        val_minimum_patches=2,
+    )
+    for transform in data_module.data_config.transforms:
+        assert transform.name != SupportedTransform.N2V_MANIPULATE.value
+
+
 def test_wrapper_supervised_n2v_throws_error(simple_array):
     """Test that an error is raised if target data is passed but the transformations
     (default ones) contain N2V manipulate."""
     with pytest.raises(ValueError):
@@ -70,14 +87,49 @@ def test_wrapper_supervised_n2v_throws_error(simple_array):
         create_train_datamodule(
             train_data=simple_array,
             data_type="array",
-            patch_size=(10, 10),
+            patch_size=(8, 8),
             axes="YX",
             batch_size=2,
             train_target_data=simple_array,
             val_minimum_patches=2,
+            transforms=[XYFlipModel(), N2VManipulateModel()],
+        )
+
+
+def test_wrapper_n2v_without_pm_error(simple_array):
+    """Test that an error is raised if N2V training is configured without the
+    N2VManipulate transform."""
+    with pytest.raises(ValueError):
+        create_train_datamodule(
+            train_data=simple_array,
+            data_type="array",
+            patch_size=(8, 8),
+            axes="YX",
+            batch_size=2,
+            val_minimum_patches=2,
+            transforms=[
+                XYFlipModel(),
+            ],
         )
 
 
+def test_wrapper_default_n2v():
+    """Test that instantiating a TrainDataModule with N2V works."""
+    data_module = create_train_datamodule(
+        train_data=np.zeros((10, 10)),
+        data_type="array",
+        patch_size=(8, 8),
+        axes="YX",
+        batch_size=2,
+    )
+
+    # N2VManipulate as last transform
+    assert (
+        data_module.data_config.transforms[-1].name
+        == SupportedTransform.N2V_MANIPULATE.value
+    )
+
+
 @pytest.mark.parametrize(
     "use_n2v2, strategy",
     [
@@ -90,7 +142,7 @@ def test_wrapper_n2v2(simple_array, use_n2v2, strategy):
     data_module = create_train_datamodule(
         train_data=simple_array,
         data_type="array",
-        patch_size=(16, 16),
+        patch_size=(8, 8),
         axes="YX",
         batch_size=2,
         use_n2v2=use_n2v2,
@@ -106,7 +158,7 @@ def test_wrapper_structn2v(simple_array):
     data_module = create_train_datamodule(
         train_data=simple_array,
         data_type="array",
-        patch_size=(16, 16),
+        patch_size=(8, 8),
         axes="YX",
         batch_size=2,
         struct_n2v_axis=struct_axis,
diff --git a/tests/models/lvae/test_lvae_architecture.py b/tests/models/lvae/test_lvae_architecture.py
index df2b2b88f..dc1ed8029 100644
--- a/tests/models/lvae/test_lvae_architecture.py
+++ b/tests/models/lvae/test_lvae_architecture.py
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 
-from careamics.config import VAEAlgorithmConfig
+from careamics.config import VAEBasedAlgorithm
 from careamics.config.architectures import LVAEModel
 from careamics.config.likelihood_model import GaussianLikelihoodConfig
 from careamics.config.loss_model import LVAELossConfig
@@ -35,7 +35,7 @@ def create_LVAE_model(
         analytical_kl=analytical_kl,
     )
 
-    config = VAEAlgorithmConfig(
+    config = VAEBasedAlgorithm(
         algorithm_type="vae",
         algorithm="musplit",
         loss=LVAELossConfig(loss_type="musplit"),
diff --git a/tests/models/lvae/test_noise_model.py b/tests/models/lvae/test_noise_model.py
index d3d5820ad..87afb59aa 100644
--- a/tests/models/lvae/test_noise_model.py
+++ b/tests/models/lvae/test_noise_model.py
@@ -3,8 +3,10 @@
 import numpy as np
 import pytest
 import torch
+from scipy.stats import wasserstein_distance
 
 from careamics.config import GaussianMixtureNMConfig, MultiChannelNMConfig
+from careamics.models.lvae.likelihoods import NoiseModelLikelihood
 from careamics.models.lvae.noise_models import (
     GaussianMixtureNoiseModel,
     MultiChannelNoiseModel,
 )
@@ -141,18 +143,93 @@ def test_multi_channel_noise_model_likelihood(
     assert likelihood.shape == inp_shape
 
 
-@pytest.mark.skip(reason="Need to refac noise model to be
able to train on CPU") -def test_gm_noise_model_training(tmp_path): - x = np.random.rand(3) - y = np.random.rand(3) +@pytest.mark.parametrize( + "image_size, max_value, noise_scale", + [ + ([5, 128, 128], 255, 0.1), + ([5, 128, 128], 255, 0.5), + ], +) +def test_gm_noise_model_training(image_size, max_value, noise_scale): + gen = np.random.default_rng(42) + signal_normalized = gen.uniform(0, 1, image_size) + noise = gen.normal(0, noise_scale, image_size) + observation_normalized = signal_normalized + noise + signal = signal_normalized * max_value + observation = observation_normalized * max_value nm_config = GaussianMixtureNMConfig( - model_type="GaussianMixtureNoiseModel", signal=x, observation=y + model_type="GaussianMixtureNoiseModel", + n_gaussian=1, + min_signal=signal.min(), + max_signal=signal.max(), ) + noise_model = GaussianMixtureNoiseModel(nm_config) + training_losses = noise_model.fit( + signal=signal, observation=observation, n_epochs=500 + ) + initial_loss = training_losses[0] + last_loss = training_losses[-1] + # Check if model is training + assert initial_loss > last_loss + + # check if estimated mean and std of a noisy sample are close to real ones + signal_tensor = torch.from_numpy(signal).to(torch.float32) + mus, sigmas, _ = noise_model.get_gaussian_parameters(signal_tensor) + + # learned mean should be close to the mean of the signal + learned_mu = mus.mean() / max_value + real_mu = signal_normalized.mean() + assert np.allclose(learned_mu, real_mu, atol=1e-2) + + # learned sigma should be close to the noise sigma + learned_sigma = sigmas.mean() / max_value + noise_image = observation_normalized - signal_normalized + real_sigma = noise_image.std() + assert np.allclose(learned_sigma, real_sigma, atol=1e-2) + + +@pytest.mark.parametrize("image_size, max_value", [([256, 256], 255)]) +def test_noise_model_sampling(image_size, max_value): + gen = np.random.default_rng(42) + + signal = gen.uniform(0, 1, image_size) + observation = signal + gen.normal(0, 0.1, signal.shape) + signal = signal * max_value + observation = observation * max_value + nm_config = GaussianMixtureNMConfig( + model_type="GaussianMixtureNoiseModel", + n_gaussian=1, + min_sigma=100, + min_signal=signal.min(), + max_signal=signal.max(), + ) noise_model = GaussianMixtureNoiseModel(nm_config) + noise_model.fit(signal=signal, observation=observation, n_epochs=200) + sampled_noise_data = noise_model.sample_observation_from_signal(signal) + assert sampled_noise_data.shape == signal.shape + + real_noise = observation - signal + predicted_noise = sampled_noise_data - signal + real_noise = real_noise / max_value + predicted_noise = predicted_noise / max_value + noise_distribution_difference = wasserstein_distance( + real_noise.ravel(), predicted_noise.ravel() + ) + assert noise_distribution_difference < 0.1 + - # Test training - output = noise_model.train(x, y, n_epochs=2) - assert output is not None - # TODO do something with output ? 
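# Aside: intuition for the `fit` loss asserted on above. With n_gaussian=1 the
# noise model predicts a Gaussian mean and standard deviation per signal value,
# and training minimizes the negative log-likelihood of the observation. A
# numpy sketch of that quantity (not the library's exact parametrization):
import numpy as np


def gaussian_nll(observation: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> float:
    """Mean negative log-likelihood of `observation` under N(mu, sigma**2)."""
    return float(
        np.mean(
            0.5 * np.log(2 * np.pi * sigma**2)
            + (observation - mu) ** 2 / (2 * sigma**2)
        )
    )


# With mu equal to the clean signal and sigma equal to the true noise std, this
# is roughly the floor that `training_losses` should approach, which is why the
# test expects initial_loss > last_loss and the learned mu/sigma to match the
# simulated signal mean and noise scale.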
+def test_noise_model_in_likelihood_call(): + test_input = torch.rand(256, 256) + test_target = torch.rand(256, 256) + + nm_config = GaussianMixtureNMConfig( + model_type="GaussianMixtureNoiseModel", n_gaussian=1 + ) + noise_model = GaussianMixtureNoiseModel(nm_config) + likelihood = NoiseModelLikelihood( + data_mean=test_input.mean(), data_std=test_input.std(), noise_model=noise_model + ) + log_likelihood, _ = likelihood(test_input, test_target) + assert log_likelihood is not None diff --git a/tests/models/test_model_factory.py b/tests/models/test_model_factory.py index 18864a569..c2650039c 100644 --- a/tests/models/test_model_factory.py +++ b/tests/models/test_model_factory.py @@ -1,15 +1,9 @@ -from torch import nn, ones - -from careamics.config.architectures import ( - CustomModel, - UNetModel, - register_model, -) -from careamics.config.support import SupportedArchitecture +from careamics.config.architectures import UNetModel from careamics.models import model_factory from careamics.models.unet import UNet +# TODO improve tests def test_model_registry_unet(): """Test that""" model_config = { @@ -19,34 +13,3 @@ def test_model_registry_unet(): # instantiate model model = model_factory(UNetModel(**model_config)) assert isinstance(model, UNet) - - -def test_model_registry_custom(): - """Test that a custom model can be retrieved and instantiated.""" - - # create and register a custom model - @register_model(name="linear_model") - class LinearModel(nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter(ones(in_features, out_features)) - self.bias = nn.Parameter(ones(out_features)) - - def forward(self, input): - return (input @ self.weight) + self.bias - - model_config = { - "architecture": SupportedArchitecture.CUSTOM.value, - "name": "linear_model", - "in_features": 10, - "out_features": 5, - } - - # instantiate model - model = model_factory(CustomModel(**model_config)) - assert isinstance(model, LinearModel) - assert model.in_features == 10 - assert model.out_features == 5 diff --git a/tests/test_careamist.py b/tests/test_careamist.py index 05113dde4..93a107088 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -8,7 +8,8 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint -from careamics import CAREamist, Configuration, save_configuration +from careamics import CAREamist +from careamics.config import configuration_factory, save_configuration from careamics.config.support import SupportedAlgorithm, SupportedData from careamics.dataset.dataset_utils import reshape_array from careamics.lightning.callbacks import HyperParametersCallback, ProgressBarCallback @@ -27,21 +28,25 @@ def test_no_parameters(): CAREamist() -def test_minimum_configuration_via_object(tmp_path: Path, minimum_configuration: dict): +def test_minimum_configuration_via_object( + tmp_path: Path, minimum_n2v_configuration: dict +): """Test that CAREamics can be instantiated with a minimum configuration object.""" # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) # instantiate CAREamist CAREamist(source=config, work_dir=tmp_path) -def test_minimum_configuration_via_path(tmp_path: Path, minimum_configuration: dict): +def test_minimum_configuration_via_path( + tmp_path: Path, minimum_n2v_configuration: dict +): """Test that CAREamics can 
be instantiated with a path to a minimum configuration. """ # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) path_to_config = save_configuration(config, tmp_path) # instantiate CAREamist @@ -49,11 +54,11 @@ def test_minimum_configuration_via_path(tmp_path: Path, minimum_configuration: d def test_train_error_target_unsupervised_algorithm( - tmp_path: Path, minimum_configuration: dict + tmp_path: Path, minimum_n2v_configuration: dict ): """Test that an error is raised when a target is provided for N2V.""" # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.algorithm_config.algorithm = SupportedAlgorithm.N2V.value # train error with Paths @@ -82,14 +87,13 @@ def test_train_error_target_unsupervised_algorithm( ) -@pytest.mark.skip(reason="bmz") -def test_train_single_array_no_val(tmp_path: Path, minimum_configuration: dict): +def test_train_single_array_no_val(tmp_path: Path, minimum_n2v_configuration: dict): """Test that CAREamics can be trained with arrays.""" # training data train_array = random_array((32, 32)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -117,15 +121,14 @@ def test_train_single_array_no_val(tmp_path: Path, minimum_configuration: dict): assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") -def test_train_array(tmp_path: Path, minimum_configuration: dict): +def test_train_array(tmp_path: Path, minimum_n2v_configuration: dict): """Test that CAREamics can be trained on arrays.""" # training data train_array = random_array((32, 32)) val_array = random_array((32, 32)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -153,10 +156,9 @@ def test_train_array(tmp_path: Path, minimum_configuration: dict): assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") @pytest.mark.parametrize("independent_channels", [False, True]) def test_train_array_channel( - tmp_path: Path, minimum_configuration: dict, independent_channels: bool + tmp_path: Path, minimum_n2v_configuration: dict, independent_channels: bool ): """Test that CAREamics can be trained on arrays with channels.""" # training data @@ -164,7 +166,7 @@ def test_train_array_channel( val_array = random_array((32, 32, 3)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YXC" config.algorithm_config.model.in_channels = 3 @@ -196,17 +198,16 @@ def test_train_array_channel( assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") -def test_train_array_3d(tmp_path: Path, minimum_configuration: dict): +def test_train_array_3d(tmp_path: Path, minimum_n2v_configuration: dict): """Test that CAREamics can be trained on 3D arrays.""" # training data train_array = random_array((8, 32, 32)) val_array = random_array((8, 32, 32)) # create configuration - minimum_configuration["data_config"]["axes"] = "ZYX" - minimum_configuration["data_config"]["patch_size"] = (8, 16, 16) - config = 
Configuration(**minimum_configuration) + minimum_n2v_configuration["data_config"]["axes"] = "ZYX" + minimum_n2v_configuration["data_config"]["patch_size"] = (8, 16, 16) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.batch_size = 2 config.data_config.data_type = SupportedData.ARRAY.value @@ -232,8 +233,9 @@ def test_train_array_3d(tmp_path: Path, minimum_configuration: dict): assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") -def test_train_tiff_files_in_memory_no_val(tmp_path: Path, minimum_configuration: dict): +def test_train_tiff_files_in_memory_no_val( + tmp_path: Path, minimum_n2v_configuration: dict +): """Test that CAREamics can be trained with tiff files in memory.""" # training data train_array = random_array((32, 32)) @@ -243,7 +245,7 @@ def test_train_tiff_files_in_memory_no_val(tmp_path: Path, minimum_configuration tifffile.imwrite(train_file, train_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -271,8 +273,7 @@ def test_train_tiff_files_in_memory_no_val(tmp_path: Path, minimum_configuration assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") -def test_train_tiff_files_in_memory(tmp_path: Path, minimum_configuration: dict): +def test_train_tiff_files_in_memory(tmp_path: Path, minimum_n2v_configuration: dict): """Test that CAREamics can be trained with tiff files in memory.""" # training data train_array = random_array((32, 32)) @@ -286,7 +287,7 @@ def test_train_tiff_files_in_memory(tmp_path: Path, minimum_configuration: dict) tifffile.imwrite(val_file, val_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -314,7 +315,7 @@ def test_train_tiff_files_in_memory(tmp_path: Path, minimum_configuration: dict) assert (tmp_path / "model.zip").exists() -def test_train_tiff_files(tmp_path: Path, minimum_configuration: dict): +def test_train_tiff_files(tmp_path: Path, minimum_n2v_configuration: dict): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. 
""" @@ -330,7 +331,7 @@ def test_train_tiff_files(tmp_path: Path, minimum_configuration: dict): tifffile.imwrite(val_file, val_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -358,8 +359,7 @@ def test_train_tiff_files(tmp_path: Path, minimum_configuration: dict): assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") -def test_train_array_supervised(tmp_path: Path, supervised_configuration: dict): +def test_train_array_supervised(tmp_path: Path, minimum_supervised_configuration: dict): """Test that CAREamics can be trained with arrays.""" # training data train_array = random_array((32, 32)) @@ -368,7 +368,7 @@ def test_train_array_supervised(tmp_path: Path, supervised_configuration: dict): val_target = random_array((32, 32)) # create configuration - config = Configuration(**supervised_configuration) + config = configuration_factory(minimum_supervised_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -401,9 +401,8 @@ def test_train_array_supervised(tmp_path: Path, supervised_configuration: dict): assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") def test_train_tiff_files_in_memory_supervised( - tmp_path: Path, supervised_configuration: dict + tmp_path: Path, minimum_supervised_configuration: dict ): """Test that CAREamics can be trained with tiff files in memory.""" # training data @@ -430,7 +429,7 @@ def test_train_tiff_files_in_memory_supervised( tifffile.imwrite(val_target_file, val_target) # create configuration - config = Configuration(**supervised_configuration) + config = configuration_factory(minimum_supervised_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -463,7 +462,9 @@ def test_train_tiff_files_in_memory_supervised( assert (tmp_path / "model.zip").exists() -def test_train_tiff_files_supervised(tmp_path: Path, supervised_configuration: dict): +def test_train_tiff_files_supervised( + tmp_path: Path, minimum_supervised_configuration: dict +): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. 
""" @@ -491,7 +492,7 @@ def test_train_tiff_files_supervised(tmp_path: Path, supervised_configuration: d tifffile.imwrite(val_target_file, val_target) # create configuration - config = Configuration(**supervised_configuration) + config = configuration_factory(minimum_supervised_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -525,18 +526,17 @@ def test_train_tiff_files_supervised(tmp_path: Path, supervised_configuration: d assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") @pytest.mark.parametrize("samples", [1, 2, 4]) @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_on_array_tiled( - tmp_path: Path, minimum_configuration: dict, batch_size, samples + tmp_path: Path, minimum_n2v_configuration: dict, batch_size, samples ): """Test that CAREamics can predict on arrays.""" # training data train_array = random_array((samples, 32, 32)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "SYX" config.data_config.batch_size = 2 @@ -571,18 +571,17 @@ def test_predict_on_array_tiled( assert (tmp_path / "model.zip").exists() -@pytest.mark.skip(reason="bmz") @pytest.mark.parametrize("samples", [1, 2, 4]) @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_arrays_no_tiling( - tmp_path: Path, minimum_configuration: dict, batch_size, samples + tmp_path: Path, minimum_n2v_configuration: dict, batch_size, samples ): """Test that CAREamics can predict on arrays without tiling.""" # training data train_array = random_array((samples, 32, 32)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "SYX" config.data_config.batch_size = 2 @@ -621,7 +620,7 @@ def test_predict_arrays_no_tiling( "0.001 different." 
) ) -def test_batched_prediction(tmp_path: Path, minimum_configuration: dict): +def test_batched_prediction(tmp_path: Path, minimum_n2v_configuration: dict): "Compare outputs when a batch size of 1 or 2 is used" tile_size = (16, 16) @@ -630,7 +629,7 @@ def test_batched_prediction(tmp_path: Path, minimum_configuration: dict): train_array = random_array(shape) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -657,7 +656,7 @@ def test_batched_prediction(tmp_path: Path, minimum_configuration: dict): @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_tiled_channel( tmp_path: Path, - minimum_configuration: dict, + minimum_n2v_configuration: dict, independent_channels: bool, batch_size: int, ): @@ -667,7 +666,7 @@ def test_predict_tiled_channel( val_array = random_array((3, 32, 32)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "CYX" config.algorithm_config.model.in_channels = 3 @@ -694,12 +693,11 @@ def test_predict_tiled_channel( ) -@pytest.mark.skip(reason="bmz") @pytest.mark.parametrize("tiled", [True, False]) @pytest.mark.parametrize("n_samples", [1, 2]) @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_path( - tmp_path: Path, minimum_configuration: dict, batch_size, n_samples, tiled + tmp_path: Path, minimum_n2v_configuration: dict, batch_size, n_samples, tiled ): """Test that CAREamics can predict with tiff files.""" # training data @@ -711,7 +709,7 @@ def test_predict_path( tifffile.imwrite(train_file, train_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -846,7 +844,7 @@ def test_export_bmz_pretrained_with_array(tmp_path: Path, pre_trained: Path): assert (tmp_path / "model2.zip").exists() -def test_predict_to_disk_path_tiff(tmp_path, minimum_configuration): +def test_predict_to_disk_path_tiff(tmp_path, minimum_n2v_configuration): """Test predict_to_disk function with path source and tiff write type.""" # prepare dummy data @@ -861,7 +859,7 @@ def test_predict_to_disk_path_tiff(tmp_path, minimum_configuration): tifffile.imwrite(train_file, train_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -879,7 +877,7 @@ def test_predict_to_disk_path_tiff(tmp_path, minimum_configuration): assert (tmp_path / "predictions" / f"image_{i}.tiff").is_file() -def test_predict_to_disk_datamodule_tiff(tmp_path, minimum_configuration): +def test_predict_to_disk_datamodule_tiff(tmp_path, minimum_n2v_configuration): """Test predict_to_disk function with datamodule source and tiff write type.""" # prepare dummy data @@ -894,7 +892,7 @@ def test_predict_to_disk_datamodule_tiff(tmp_path, minimum_configuration): tifffile.imwrite(train_file, train_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" 
config.data_config.batch_size = 2 @@ -920,7 +918,7 @@ def test_predict_to_disk_datamodule_tiff(tmp_path, minimum_configuration): assert (tmp_path / "predictions" / f"image_{i}.tiff").is_file() -def test_predict_to_disk_custom(tmp_path, minimum_configuration): +def test_predict_to_disk_custom(tmp_path, minimum_n2v_configuration): """Test predict_to_disk function with custom write type.""" def write_numpy(file_path: Path, img: NDArray, *args, **kwargs) -> None: @@ -938,7 +936,7 @@ def write_numpy(file_path: Path, img: NDArray, *args, **kwargs) -> None: tifffile.imwrite(train_file, train_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -961,7 +959,7 @@ def write_numpy(file_path: Path, img: NDArray, *args, **kwargs) -> None: assert (tmp_path / "predictions" / f"image_{i}.npy").is_file() -def test_predict_to_disk_custom_raises(tmp_path, minimum_configuration): +def test_predict_to_disk_custom_raises(tmp_path, minimum_n2v_configuration): """ Test predict_to_disk custom write type raises ValueError. @@ -983,7 +981,7 @@ def write_numpy(file_path: Path, img: NDArray, *args, **kwargs) -> None: tifffile.imwrite(train_file, train_array) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -1012,7 +1010,7 @@ def write_numpy(file_path: Path, img: NDArray, *args, **kwargs) -> None: ) -def test_add_custom_callback(tmp_path, minimum_configuration): +def test_add_custom_callback(tmp_path, minimum_n2v_configuration): """Test that custom callback can be added to the CAREamist.""" # define a custom callback @@ -1037,7 +1035,7 @@ def on_train_end(self, trainer, pl_module): train_array = random_array((32, 32)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -1057,10 +1055,10 @@ def on_train_end(self, trainer, pl_module): assert my_callback.has_ended -def test_error_passing_careamics_callback(tmp_path, minimum_configuration): +def test_error_passing_careamics_callback(tmp_path, minimum_n2v_configuration): """Test that an error is thrown if we pass known callbacks to CAREamist.""" # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -1095,13 +1093,13 @@ def test_error_passing_careamics_callback(tmp_path, minimum_configuration): CAREamist(source=config, work_dir=tmp_path, callbacks=[hyper_params]) -def test_stop_training(tmp_path: Path, minimum_configuration: dict): +def test_stop_training(tmp_path: Path, minimum_n2v_configuration: dict): """Test that CAREamics can stop the training""" # training data train_array = random_array((32, 32)) # create configuration - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 1_000 config.data_config.axes = "YX" config.data_config.batch_size = 2 @@ -1125,9 +1123,9 @@ def _train(): assert careamist.trainer.should_stop -def 
test_read_logger(tmp_path, minimum_configuration): +def test_read_logger(tmp_path, minimum_n2v_configuration): - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 10 array = np.arange(32 * 32).reshape((32, 32)) diff --git a/tests/test_conftest.py b/tests/test_conftest.py index 17e68e5ba..6cb54b9f7 100644 --- a/tests/test_conftest.py +++ b/tests/test_conftest.py @@ -1,13 +1,15 @@ -from careamics import Configuration -from careamics.config.data_model import DataConfig -from careamics.config.fcn_algorithm_model import FCNAlgorithmConfig +"""Tests for the pytest fixtures.""" + +from careamics.config import Configuration, N2VConfiguration +from careamics.config.algorithms import UNetBasedAlgorithm +from careamics.config.data import DataConfig, N2VDataConfig from careamics.config.inference_model import InferenceConfig from careamics.config.training_model import TrainingConfig def test_minimum_algorithm(minimum_algorithm_n2v): # create algorithm configuration - FCNAlgorithmConfig(**minimum_algorithm_n2v) + UNetBasedAlgorithm(**minimum_algorithm_n2v) def test_minimum_data(minimum_data): @@ -25,6 +27,16 @@ def test_minimum_training(minimum_training): TrainingConfig(**minimum_training) -def test_minimum_configuration(minimum_configuration): +def test_minimum_data_n2v(minimum_data_n2v): + # create data configuration + N2VDataConfig(**minimum_data_n2v) + + +def test_minimum_n2v_configuration(minimum_n2v_configuration): + # create configuration + N2VConfiguration(**minimum_n2v_configuration) + + +def test_minimum_configuration(minimum_supervised_configuration): # create configuration - Configuration(**minimum_configuration) + Configuration(**minimum_supervised_configuration) diff --git a/tests/utils/test_lightning_utils.py b/tests/utils/test_lightning_utils.py index 2e9ea4d4e..4b1ff44a0 100644 --- a/tests/utils/test_lightning_utils.py +++ b/tests/utils/test_lightning_utils.py @@ -1,13 +1,14 @@ import numpy as np -from careamics import CAREamist, Configuration +from careamics import CAREamist +from careamics.config import configuration_factory from careamics.utils import cwd from careamics.utils.lightning_utils import read_csv_logger -def test_read_logger(tmp_path, minimum_configuration): +def test_read_logger(tmp_path, minimum_n2v_configuration): - config = Configuration(**minimum_configuration) + config = configuration_factory(minimum_n2v_configuration) config.training_config.num_epochs = 10 array = np.arange(32 * 32).reshape((32, 32))
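
The recurring change throughout this patch replaces direct `Configuration(**dict)` construction with `configuration_factory`, which dispatches on the algorithm declared in the dict and returns the matching algorithm-specific configuration. The sketch below shows that pattern in one place for reviewers; it is illustrative only, not part of the patch: the fixture dict is taken as given, the helper name is made up, and the `train_source` keyword and mutated config fields mirror their use in the test hunks above.

```python
# Minimal sketch of the factory-based pattern these tests migrate to.
# Assumptions: the dict has the shape of the minimum_n2v_configuration fixture,
# and CAREamist.train accepts train_source as used in tests/test_careamist.py.
from pathlib import Path

import numpy as np

from careamics import CAREamist
from careamics.config import configuration_factory


def train_minimal_n2v(minimum_n2v_configuration: dict, work_dir: Path) -> CAREamist:
    # validate the dict and build the algorithm-specific configuration
    # (previously: Configuration(**minimum_n2v_configuration))
    config = configuration_factory(minimum_n2v_configuration)
    config.training_config.num_epochs = 1
    config.data_config.axes = "YX"
    config.data_config.batch_size = 2
    config.data_config.data_type = "array"  # SupportedData.ARRAY.value

    # train on a small in-memory array, as the array-based tests do
    careamist = CAREamist(source=config, work_dir=work_dir)
    careamist.train(train_source=np.random.rand(32, 32).astype(np.float32))
    return careamist
```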