From 879fe0771da6bc52baf311138e3eef10bf47d966 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Wed, 24 Apr 2024 20:55:23 +0200 Subject: [PATCH 01/14] Add model parent class --- src/careamics/careamist.py | 32 ++----------- src/careamics/config/algorithm_model.py | 6 +-- .../config/architectures/__init__.py | 2 + .../architectures/architecture_model.py | 29 +++++++++++ .../config/architectures/custom_model.py | 48 +++++++++++-------- .../config/architectures/register_model.py | 3 ++ .../config/architectures/unet_model.py | 6 ++- .../config/architectures/vae_model.py | 5 +- src/careamics/models/model_factory.py | 2 +- .../architectures/test_architecture_model.py | 11 +++++ .../config/architectures/test_custom_model.py | 24 ++++++---- tests/config/architectures/test_unet_model.py | 6 +-- tests/models/test_model_factory.py | 3 +- tests/test_careamist.py | 33 +++++++++++++ tests/test_lightning_module.py | 3 +- 15 files changed, 140 insertions(+), 73 deletions(-) create mode 100644 src/careamics/config/architectures/architecture_model.py create mode 100644 tests/config/architectures/test_architecture_model.py diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 2cba131b8..b02e1ac12 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -645,33 +645,9 @@ def predict( f"Invalid input. Expected a CAREamicsWood instance, paths or " f"np.ndarray (got {type(source)})." ) - + def export_model( - self, path: Union[Path, str], type: Literal["bmz", "script"] = "bmz" + self, + path: Union[Path, str] ) -> None: - """ - Export the model to the BioImage Model Zoo or torchscript format. - - Parameters - ---------- - path : Union[Path, str] - Path to save the model. - type : Literal["bmz", "script"], optional - Export format, by default "bmz". - - Raises - ------ - NotImplementedError - If the export format is not implemented yet. - """ - path = Path(path) - if type == "bmz": - raise NotImplementedError( - "Exporting a model to BioImage Model Zoo is not implemented yet." - ) - elif type == "script": - self.model.to_torchscript(path) - else: - raise ValueError( - f"Invalid export format. Expected 'bmz' or 'script', got {type}." - ) + pass \ No newline at end of file diff --git a/src/careamics/config/algorithm_model.py b/src/careamics/config/algorithm_model.py index 01d120513..b77670095 100644 --- a/src/careamics/config/algorithm_model.py +++ b/src/careamics/config/algorithm_model.py @@ -75,10 +75,8 @@ class AlgorithmModel(BaseModel): ... "model": { ... "architecture": "Custom", ... "name": "linear_model", - ... "parameters": { - ... "in_features": 10, - ... "out_features": 5, - ... }, + ... "in_features": 10, + ... "out_features": 5, ... } ... } >>> config = AlgorithmModel(**config_dict) diff --git a/src/careamics/config/architectures/__init__.py b/src/careamics/config/architectures/__init__.py index 11cab73eb..c65d97f09 100644 --- a/src/careamics/config/architectures/__init__.py +++ b/src/careamics/config/architectures/__init__.py @@ -1,6 +1,7 @@ """Deep-learning model configurations.""" __all__ = [ + "ArchitectureModel", "CustomModel", "UNetModel", "VAEModel", @@ -9,6 +10,7 @@ "register_model", ] +from .architecture_model import ArchitectureModel from .custom_model import CustomModel from .register_model import clear_custom_models, get_custom_model, register_model from .unet_model import UNetModel diff --git a/src/careamics/config/architectures/architecture_model.py b/src/careamics/config/architectures/architecture_model.py new file mode 100644 index 000000000..28113112c --- /dev/null +++ b/src/careamics/config/architectures/architecture_model.py @@ -0,0 +1,29 @@ +from typing import Any, Dict + +from pydantic import BaseModel + + +class ArchitectureModel(BaseModel): + """ + Base Pydantic model for all model architectures. + + The `model_dump` method allows removing the `architecture` key from the model. + """ + + architecture: str + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + """ + Dump the model as a dictionary, ignoring the architecture keyword. + + Returns + ------- + dict[str, Any] + Model as a dictionnary. + """ + model_dict = super().model_dump(**kwargs) + + # remove the architecture key + model_dict.pop("architecture") + + return model_dict diff --git a/src/careamics/config/architectures/custom_model.py b/src/careamics/config/architectures/custom_model.py index 647175e3c..557032a2d 100644 --- a/src/careamics/config/architectures/custom_model.py +++ b/src/careamics/config/architectures/custom_model.py @@ -3,19 +3,14 @@ from pprint import pformat from typing import Literal -from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from torch.nn import Module +from .architecture_model import ArchitectureModel from .register_model import get_custom_model -class CustomParametersModel(BaseModel): - """A Pydantic model that allows any parameter.""" - - model_config = ConfigDict(extra="allow") - - -class CustomModel(BaseModel): +class CustomModel(ArchitectureModel): """Custom model configuration. This Pydantic model allows storing parameters for a custom model. In order for the @@ -61,16 +56,16 @@ class CustomModel(BaseModel): >>> config_dict = { ... "architecture": "Custom", ... "name": "linear", - ... "parameters": { - ... "in_features": 10, - ... "out_features": 5, - ... }, + ... "in_features": 10, + ... "out_features": 5, ... } >>> config = CustomModel(**config_dict) """ # pydantic model config - model_config = ConfigDict(validate_assignment=True) + model_config = ConfigDict( + extra="allow", + ) # discriminator used for choosing the pydantic model in Model architecture: Literal["Custom"] @@ -78,9 +73,6 @@ class CustomModel(BaseModel): # name of the custom model name: str - # parameters - parameters: CustomParametersModel - @field_validator("name") @classmethod def custom_model_is_known(cls, value: str) -> str: @@ -115,12 +107,13 @@ def check_parameters(self: CustomModel) -> CustomModel: """ # instantiate model try: - get_custom_model(self.name)(**self.parameters.model_dump()) + get_custom_model(self.name)(**self.model_dump()) except Exception as e: raise ValueError( - f"error while passing parameters to the model: {e}. Verify that all " - f"mandatory parameters are provided, and that either the model accepts " - f"*args and **kwargs, or that no additional parameter is provided." + f"error while passing parameters to the model {e}. Verify that all " + f"mandatory parameters are provided, and that either the {e} accepts " + f"*args and **kwargs in its __init__() method, or that no additional" + f"parameter is provided." ) from None return self @@ -134,3 +127,18 @@ def __str__(self) -> str: Pretty string. """ return pformat(self.model_dump()) + + def model_dump(self) -> dict: + """Dump the model configuration. + + Returns + ------- + dict + Model configuration. + """ + model_dict = super().model_dump() + + # remove the name key + model_dict.pop("name") + + return model_dict diff --git a/src/careamics/config/architectures/register_model.py b/src/careamics/config/architectures/register_model.py index 6e62cbce1..f35b0b88c 100644 --- a/src/careamics/config/architectures/register_model.py +++ b/src/careamics/config/architectures/register_model.py @@ -38,6 +38,9 @@ def forward(self, input): return (input @ self.weight) + self.bias ``` """ + if name is None or name == "": + raise ValueError("Model name cannot be empty.") + if name in CUSTOM_MODELS: raise ValueError( f"Model {name} already exists. Choose a different name or run " diff --git a/src/careamics/config/architectures/unet_model.py b/src/careamics/config/architectures/unet_model.py index 13178731d..9a032e2b1 100644 --- a/src/careamics/config/architectures/unet_model.py +++ b/src/careamics/config/architectures/unet_model.py @@ -2,12 +2,14 @@ from typing import Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, field_validator + +from .architecture_model import ArchitectureModel # TODO tests activation <-> pydantic model, test the literals! # TODO annotations for the json schema? -class UNetModel(BaseModel): +class UNetModel(ArchitectureModel): """ Pydantic model for a N2V(2)-compatible UNet. diff --git a/src/careamics/config/architectures/vae_model.py b/src/careamics/config/architectures/vae_model.py index 608bf4f06..03c7eb60b 100644 --- a/src/careamics/config/architectures/vae_model.py +++ b/src/careamics/config/architectures/vae_model.py @@ -1,12 +1,13 @@ from typing import Literal from pydantic import ( - BaseModel, ConfigDict, ) +from .architecture_model import ArchitectureModel -class VAEModel(BaseModel): + +class VAEModel(ArchitectureModel): """VAE model placeholder.""" model_config = ConfigDict( diff --git a/src/careamics/models/model_factory.py b/src/careamics/models/model_factory.py index 3ca69a0b3..5b93622b3 100644 --- a/src/careamics/models/model_factory.py +++ b/src/careamics/models/model_factory.py @@ -44,7 +44,7 @@ def model_factory( assert isinstance(model_configuration, CustomModel) model = get_custom_model(model_configuration.name) - return model(**model_configuration.parameters.model_dump()) + return model(**model_configuration.model_dump()) else: raise NotImplementedError( f"Model {model_configuration.architecture} is not implemented or unknown." diff --git a/tests/config/architectures/test_architecture_model.py b/tests/config/architectures/test_architecture_model.py new file mode 100644 index 000000000..8db4fe660 --- /dev/null +++ b/tests/config/architectures/test_architecture_model.py @@ -0,0 +1,11 @@ +from careamics.config.architectures import ArchitectureModel + + +def test_model_dump(): + """Test that architecture keyword is removed from the model dump.""" + model_params = {"architecture": "LeCorbusier"} + model = ArchitectureModel(**model_params) + + # dump model + model_dict = model.model_dump() + assert model_dict == {} \ No newline at end of file diff --git a/tests/config/architectures/test_custom_model.py b/tests/config/architectures/test_custom_model.py index 6b8983937..a6ebd04db 100644 --- a/tests/config/architectures/test_custom_model.py +++ b/tests/config/architectures/test_custom_model.py @@ -2,7 +2,6 @@ from torch import nn, ones from careamics.config.architectures import CustomModel, get_custom_model, register_model -from careamics.config.architectures.custom_model import CustomParametersModel from careamics.config.support import SupportedArchitecture @@ -29,14 +28,18 @@ def forward(self, input): return input -def test_empty_parameters(): - """Test that the custom model parameters does not require any fields.""" - CustomParametersModel() - - def test_any_custom_parameters(): - """Test that the custom model parameters can have any fields.""" - CustomParametersModel(id=3, some_param={"a": 1, "b": 2}, t="test") + """Test that the custom model can have any fields. + + Note that those fields are validated by instantiating the + model. + """ + CustomModel( + architecture="Custom", + name="linear", + in_features=10, + out_features=5 + ) def test_linear_model(): @@ -57,7 +60,8 @@ def test_custom_model(): model_dict = { "architecture": SupportedArchitecture.CUSTOM.value, "name": "linear", - "parameters": {"in_features": 10, "out_features": 5}, + "in_features": 10, + "out_features": 5, } # create Pydantic model @@ -65,7 +69,7 @@ def test_custom_model(): # instantiate model model_class = get_custom_model(pydantic_model.name) - model = model_class(**pydantic_model.parameters.model_dump()) + model = model_class(**pydantic_model.model_dump()) assert isinstance(model, LinearModel) assert model.in_features == 10 diff --git a/tests/config/architectures/test_unet_model.py b/tests/config/architectures/test_unet_model.py index 299223ee3..0f41a91b0 100644 --- a/tests/config/architectures/test_unet_model.py +++ b/tests/config/architectures/test_unet_model.py @@ -113,9 +113,9 @@ def test_model_dump(): model_dict = model.model_dump(exclude_defaults=True) # check that default values are excluded except the architecture - assert "architecture" in model_dict - assert len(model_dict) == 3 + assert "architecture" not in model_dict + assert len(model_dict) == 2 # check that we get all the optional values with the exclude_defaults flag model_dict = model.model_dump(exclude_defaults=False) - assert len(model_dict) == len(dict(model)) + assert len(model_dict) == len(dict(model)) - 1 diff --git a/tests/models/test_model_factory.py b/tests/models/test_model_factory.py index 749524398..08f7075bb 100644 --- a/tests/models/test_model_factory.py +++ b/tests/models/test_model_factory.py @@ -42,7 +42,8 @@ def forward(self, input): model_config = { "architecture": SupportedArchitecture.CUSTOM.value, "name": "linear_model", - "parameters": {"in_features": 10, "out_features": 5}, + "in_features": 10, + "out_features": 5, } # instantiate model diff --git a/tests/test_careamist.py b/tests/test_careamist.py index 07a768102..aded0ce2b 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -447,3 +447,36 @@ def test_predict_pretrained(tmp_path, pre_trained): # check that it predicted assert predicted is not None assert predicted.squeeze().shape == train_array.shape + + +# TODO move to test_export_bmz +def test_export_bmz(tmp_path, pre_trained): + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict + predicted = careamist.predict(train_array, tile_overlap=(4, 4)) + + # save images + train_path = tmp_path / "train.npy" + np.save(train_path, train_array[np.newaxis, np.newaxis, ...]) + + predicted_path = tmp_path / "predicted.npy" + np.save(tmp_path / "predicted.npy", predicted[np.newaxis, ...]) + + from careamics.model_io.model_io_utils import export_bmz + + # export to BioImage Model Zoo + export_bmz( + model=careamist.model, + config=careamist.cfg, + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + inputs=train_path, + outputs=predicted_path, + ) diff --git a/tests/test_lightning_module.py b/tests/test_lightning_module.py index 14c5d8228..038ad8438 100644 --- a/tests/test_lightning_module.py +++ b/tests/test_lightning_module.py @@ -3,13 +3,12 @@ def test_careamics_module(minimum_algorithm_n2v): - """Test that the minimum algorithm allows isntantiating a the Lightning API + """Test that the minimum algorithm allows instantiating a the Lightning API intermediate layer.""" algo_config = AlgorithmModel(**minimum_algorithm_n2v) # extract model parameters model_parameters = algo_config.model.model_dump(exclude_none=True) - model_parameters.pop("architecture") # instantiate CAREamicsModule CAREamicsModule( From e6129074a419e96b902125d031abb63db4e2cb1b Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Thu, 25 Apr 2024 11:58:21 +0200 Subject: [PATCH 02/14] Configuration export in checkpoint via callback --- pyproject.toml | 3 +- src/careamics/callbacks/__init__.py | 7 +- .../callbacks/hyperparameters_callback.py | 42 +++++ ...ogress_bar.py => progress_bar_callback.py} | 9 +- src/careamics/careamist.py | 43 ++--- src/careamics/config/__init__.py | 1 + src/careamics/config/configuration_model.py | 49 ++++-- .../dataset/dataset_utils/dataset_utils.py | 5 +- src/careamics/lightning_module.py | 31 +++- src/careamics/model_io/bioimage/__init__.py | 4 + src/careamics/model_io/bioimage/io.py | 0 .../model_io/bioimage/model_description.py | 109 +++++++++---- .../model_io/bioimage/readme_factory.py | 9 +- src/careamics/model_io/model_io_utils.py | 153 ++++++++++++++++-- .../architectures/test_architecture_model.py | 2 +- .../config/architectures/test_custom_model.py | 13 +- tests/conftest.py | 29 +++- tests/test_careamist.py | 62 +------ 18 files changed, 397 insertions(+), 174 deletions(-) create mode 100644 src/careamics/callbacks/hyperparameters_callback.py rename src/careamics/callbacks/{progress_bar.py => progress_bar_callback.py} (93%) delete mode 100644 src/careamics/model_io/bioimage/io.py diff --git a/pyproject.toml b/pyproject.toml index 70f38a57b..e252a82e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,8 +37,9 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + 'pytorch==2.*', 'albumentations', - 'bioimageio.core', + 'bioimageio.core>=0.6.*', 'tifffile', 'psutil', 'pydantic>=2.5', diff --git a/src/careamics/callbacks/__init__.py b/src/careamics/callbacks/__init__.py index a7c04f454..74a82442d 100644 --- a/src/careamics/callbacks/__init__.py +++ b/src/careamics/callbacks/__init__.py @@ -1,5 +1,6 @@ -"""Callback module.""" +"""Callbacks module.""" -__all__ = ["ProgressBarCallback"] +__all__ = ["HyperParametersCallback", "ProgressBarCallback"] -from .progress_bar import ProgressBarCallback +from .hyperparameters_callback import HyperParametersCallback +from .progress_bar_callback import ProgressBarCallback diff --git a/src/careamics/callbacks/hyperparameters_callback.py b/src/careamics/callbacks/hyperparameters_callback.py new file mode 100644 index 000000000..d06090770 --- /dev/null +++ b/src/careamics/callbacks/hyperparameters_callback.py @@ -0,0 +1,42 @@ +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import Callback + +from careamics.config import Configuration + + +class HyperParametersCallback(Callback): + """ + Callback allowing saving CAREamics configuration as hyperparameters in the model. + + This allows saving the configuration as dictionnary in the checkpoints, and + loading it subsequently in a CAREamist instance. + + Attributes + ---------- + config : Configuration + CAREamics configuration to be saved as hyperparameter in the model. + """ + + def __init__(self, config: Configuration): + """ + Constructor. + + Parameters + ---------- + config : Configuration + CAREamics configuration to be saved as hyperparameter in the model. + """ + self.config = config + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule): + """ + Update the hyperparameters of the model with the configuration on train start. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer. + pl_module : LightningModule + PyTorch Lightning module. + """ + pl_module.hparams.update(self.config.model_dump()) diff --git a/src/careamics/callbacks/progress_bar.py b/src/careamics/callbacks/progress_bar_callback.py similarity index 93% rename from src/careamics/callbacks/progress_bar.py rename to src/careamics/callbacks/progress_bar_callback.py index 1ac6861aa..d7862091c 100644 --- a/src/careamics/callbacks/progress_bar.py +++ b/src/careamics/callbacks/progress_bar_callback.py @@ -12,7 +12,7 @@ class ProgressBarCallback(TQDMProgressBar): def init_train_tqdm(self) -> tqdm: """Override this to customize the tqdm bar for training.""" bar = tqdm( - desc='Training', + desc="Training", position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -27,12 +27,12 @@ def init_validation_tqdm(self) -> tqdm: # The main progress bar doesn't exist in `trainer.validate()` has_main_bar = self.train_progress_bar is not None bar = tqdm( - desc='Validating', + desc="Validating", position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, - file=sys.stdout + file=sys.stdout, ) return bar @@ -45,7 +45,7 @@ def init_test_tqdm(self) -> tqdm: leave=True, dynamic_ncols=False, ncols=100, - file=sys.stdout + file=sys.stdout, ) return bar @@ -55,4 +55,3 @@ def get_metrics( """Override this to customize the metrics displayed in the progress bar.""" pbar_metrics = trainer.progress_bar_metrics return {**pbar_metrics} - diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index b02e1ac12..144108d8a 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload import numpy as np -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ( Callback, EarlyStopping, @@ -20,22 +20,23 @@ ) from careamics.config.inference_model import TRANSFORMS_UNION from careamics.config.support import SupportedAlgorithm, SupportedLogger -from careamics.model_io import load_pretrained from careamics.lightning_datamodule import CAREamicsClay, CAREamicsWood from careamics.lightning_module import CAREamicsKiln from careamics.lightning_prediction import CAREamicsPredictionLoop +from careamics.model_io import load_pretrained from careamics.utils import check_path_exists, get_logger +from .callbacks import HyperParametersCallback + logger = get_logger(__name__) LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]] # TODO napari callbacks -# TODO save as modelzoo, lightning and pytorch_dict # TODO: how to do AMP? How to continue training? -class CAREamist(LightningModule): +class CAREamist: """ Main CAREamics class, allowing training and prediction using various algorithms. @@ -118,7 +119,6 @@ def __init__( If no hyper parameters are found in the checkpoint. ValueError If no data module hyper parameters are found in the checkpoint. - """ super().__init__() @@ -135,11 +135,11 @@ def __init__( # configuration object if isinstance(source, Configuration): self.cfg = source - self.save_hyperparameters(self.cfg.model_dump()) # instantiate model - self.model = CAREamicsKiln(self.cfg.algorithm_config) - self.model.hparams.update(self.cfg.model_dump()) + self.model = CAREamicsKiln( + algorithm_config=self.cfg.algorithm_config, + ) # path to configuration file or model else: @@ -152,12 +152,10 @@ def __init__( # load configuration self.cfg = load_configuration(source) - # save configuration in the working directory - # TODO Ugly, think of a better way to save the configuration - self.save_hyperparameters(self.cfg.model_dump()) - # instantiate model - self.model = CAREamicsKiln(self.cfg.algorithm_config) + self.model = CAREamicsKiln( + algorithm_config=self.cfg.algorithm_config, + ) # attempt loading a pre-trained model else: @@ -166,20 +164,16 @@ def __init__( # define the checkpoint saving callback self.callbacks = self._define_callbacks() - # torch.set_float32_matmul_precision('medium') - # instantiate logger if self.cfg.training_config.has_logger(): if self.cfg.training_config.logger == SupportedLogger.WANDB: self.experiment_logger: LOGGER_TYPES = WandbLogger( name=experiment_name, save_dir=self.work_dir / Path("logs"), - # **self.cfg.logger.model_dump(), ) elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD: self.experiment_logger = TensorBoardLogger( save_dir=self.work_dir / Path("logs"), - # **self.cfg.logger.model_dump(), ) else: self.experiment_logger = None @@ -190,7 +184,6 @@ def __init__( callbacks=self.callbacks, default_root_dir=self.work_dir, logger=self.experiment_logger, - # precision="bf16" ) # change the prediction loop, necessary for tiled prediction @@ -207,6 +200,7 @@ def _define_callbacks(self) -> List[Callback]: """ # checkpoint callback saves checkpoints during training self.callbacks = [ + HyperParametersCallback(self.cfg), ModelCheckpoint( dirpath=self.work_dir / Path("checkpoints"), filename=self.cfg.experiment_name, @@ -223,10 +217,6 @@ def _define_callbacks(self) -> List[Callback]: return self.callbacks - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """Save the configuration in the checkpoint.""" - checkpoint["cfg"] = self.cfg.model_dump() - def train( self, *, @@ -645,9 +635,6 @@ def predict( f"Invalid input. Expected a CAREamicsWood instance, paths or " f"np.ndarray (got {type(source)})." ) - - def export_model( - self, - path: Union[Path, str] - ) -> None: - pass \ No newline at end of file + + def export_model(self, path: Union[Path, str]) -> None: + pass diff --git a/src/careamics/config/__init__.py b/src/careamics/config/__init__.py index a551675cb..9ee70b7ee 100644 --- a/src/careamics/config/__init__.py +++ b/src/careamics/config/__init__.py @@ -15,6 +15,7 @@ "CustomModel", "create_inference_configuration", "clear_custom_models", + "ConfigurationInformation", ] from .algorithm_model import AlgorithmModel diff --git a/src/careamics/config/configuration_model.py b/src/careamics/config/configuration_model.py index 5570ad42b..1e437acd0 100644 --- a/src/careamics/config/configuration_model.py +++ b/src/careamics/config/configuration_model.py @@ -7,10 +7,12 @@ from typing import Dict, List, Literal, Union import yaml +from bioimageio.spec.generic.v0_3 import CiteEntry from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from .algorithm_model import AlgorithmModel from .data_model import DataModel +from .references import N2V2_REF, N2V_REF, STRUCTN2V_REF from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform from .training_model import TrainingModel from .transformations.n2v_manipulate_model import ( @@ -419,6 +421,34 @@ def get_algorithm_description(self) -> str: return "" + def get_algorithm_citations(self) -> List[CiteEntry]: + """ + Return a list of citation entries corresponding to the algorithm + defined in the configuration. + + Returns + ------- + List[CiteEntry] + List of citation entries. + """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = ( + self.data_config.transforms[-1].parameters.struct_mask_axis != "none" + ) + + # return the (struct)N2V(2) references + if use_n2v2 and use_structN2V: + return [N2V_REF, N2V2_REF, STRUCTN2V_REF] + elif use_n2v2: + return [N2V_REF, N2V2_REF] + elif use_structN2V: + return [N2V_REF, STRUCTN2V_REF] + else: + return [N2V_REF] + + raise ValueError("Citation not available for custom algorithm.") + def get_algorithm_references(self) -> str: """ Get the algorithm references. @@ -435,22 +465,9 @@ def get_algorithm_references(self) -> str: ) references = [ - 'Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' - 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' - "conference on computer vision and pattern recognition (pp. " - "2129-2137). doi: " - "[10.1109/cvpr.2019.00223](https://doi.org/10.1109/cvpr.2019.00223)\n", - "Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " - '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' - 'sampling strategies and a tweaked network architecture". In European ' - "Conference on Computer Vision (pp. 503-518). doi: " - "[10.1007/978-3-031-25069-9_33](https://doi.org/10.1007/978-3-031-" - "25069-9_33)\n", - "Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020" - '. "Removing structured noise with self-supervised blind-spot ' - 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' - "Imaging (ISBI) (pp. 159-163). doi: [10.1109/isbi45749.2020.9098336](" - "https://doi.org/10.1109/isbi45749.2020.9098336)\n", + N2V_REF.text + " doi: " + N2V_REF.doi, + N2V2_REF.text + " doi: " + N2V2_REF.doi, + STRUCTN2V_REF.text + " doi: " + STRUCTN2V_REF.doi, ] # return the (struct)N2V(2) references diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py index 2672e0b26..ace44bc9e 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -63,7 +63,10 @@ def reshape_array(x: np.ndarray, axes: str) -> np.ndarray: # sanity checks if len(_axes) != len(_x.shape): - raise ValueError(f"Incompatible data shape ({_x.shape}) and axes ({_axes}).") + raise ValueError( + f"Incompatible data shape ({_x.shape}) and axes ({_axes}). Are the axes " + f"correct?" + ) # get new x shape new_x_shape, new_axes, indices = _get_shape_order(_x.shape, _axes) diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py index e353265dc..26df7b6fb 100644 --- a/src/careamics/lightning_module.py +++ b/src/careamics/lightning_module.py @@ -71,8 +71,6 @@ def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None: self.lr_scheduler_name = algorithm_config.lr_scheduler.name self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters - # self.save_hyperparameters(algorithm_config.model_dump()) - def forward(self, x: Any) -> Any: """Forward pass. @@ -239,6 +237,33 @@ def __init__( lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau", lr_scheduler_parameters: Optional[dict] = None, ) -> None: + """ + Wrapper for the CAREamics model, exposing all algorithm configuration arguments. + + Parameters + ---------- + algorithm : Union[SupportedAlgorithm, str] + Algorithm to use for training (see SupportedAlgorithm). + loss : Union[SupportedLoss, str] + Loss function to use for training (see SupportedLoss). + architecture : Union[SupportedArchitecture, str] + Model architecture to use for training (see SupportedArchitecture). + model_parameters : dict, optional + Model parameters to use for training, by default {}. Model parameters are + defined in the relevant `torch.nn.Module` class, or Pyddantic model (see + `careamics.config.architectures`). + optimizer : Union[SupportedOptimizer, str], optional + Optimizer to use for training, by default "Adam" (see SupportedOptimizer). + optimizer_parameters : dict, optional + Optimizer parameters to use for training, as defined in `torch.optim`, by + default {}. + lr_scheduler : Union[SupportedScheduler, str], optional + Learning rate scheduler to use for training, by default "ReduceLROnPlateau" + (see SupportedScheduler). + lr_scheduler_parameters : dict, optional + Learning rate scheduler parameters to use for training, as defined in + `torch.optim`, by default {}. + """ # create a AlgorithmModel compatible dictionary if lr_scheduler_parameters is None: lr_scheduler_parameters = {} @@ -263,7 +288,7 @@ def __init__( # add model parameters to algorithm configuration algorithm_configuration["model"] = model_configuration - # self.save_hyperparameters({**model_configuration, **algorithm_configuration}) + # call the parent init using an AlgorithmModel instance super().__init__(AlgorithmModel(**algorithm_configuration)) diff --git a/src/careamics/model_io/bioimage/__init__.py b/src/careamics/model_io/bioimage/__init__.py index f469548f1..cff07fbcd 100644 --- a/src/careamics/model_io/bioimage/__init__.py +++ b/src/careamics/model_io/bioimage/__init__.py @@ -1 +1,5 @@ """Bioimage Model Zoo format functions.""" + +__all__ = ["create_model_description"] + +from .model_description import create_model_description diff --git a/src/careamics/model_io/bioimage/io.py b/src/careamics/model_io/bioimage/io.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index dbb7c9a25..2e18fc5f2 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -1,39 +1,49 @@ from pathlib import Path from typing import List, Optional, Tuple, Union +import numpy as np from bioimageio.spec.model.v0_5 import ( + ArchitectureFromLibraryDescr, Author, AxisBase, AxisId, BatchAxis, ChannelAxis, + EnvironmentFileDescr, FileDescr, + FixedZeroMeanUnitVarianceDescr, + FixedZeroMeanUnitVarianceKwargs, Identifier, InputTensorDescr, ModelDescr, OutputTensorDescr, - ParameterizedSize, + PytorchStateDictWeightsDescr, SpaceInputAxis, SpaceOutputAxis, TensorId, + Version, + WeightsDescr, ) -from careamics import CAREamist -from careamics.config import DataModel, save_configuration -from careamics.utils import cwd, get_careamics_home +from careamics.config import Configuration, DataModel from .readme_factory import readme_factory def _create_axes( + array: np.ndarray, data_config: DataModel, is_input: bool = True, channel_names: Optional[List[str]] = None, ) -> List[AxisBase]: """Create axes description. + Array shape is expected to be SC(Z)YX. + Parameters ---------- + array : np.ndarray + Array. config : DataModel CAREamics data configuration is_input : bool, optional @@ -65,31 +75,27 @@ def _create_axes( f"{data_config.axes}." ) else: - axes_model.append(ChannelAxis(channel_names=[Identifier("raw")])) + # singleton channel + axes_model.append(ChannelAxis(channel_names=[Identifier("channel")])) # spatial axes - for axes in data_config.axes: + for ind, axes in enumerate(data_config.axes): if axes in ["X", "Y", "Z"]: if is_input: axes_model.append( - SpaceInputAxis( - id=AxisId(axes.lower()), - size=ParameterizedSize( - min=16, step=8 - ), # TODO check the min/step - ) + SpaceInputAxis(id=AxisId(axes.lower()), size=array.shape[2 + ind]) ) else: axes_model.append( - SpaceOutputAxis( - id=AxisId(axes.lower()), size=ParameterizedSize(min=16, step=8) - ) + SpaceOutputAxis(id=AxisId(axes.lower()), size=array.shape[2 + ind]) ) return axes_model def _create_inputs_ouputs( + input_array: np.ndarray, + output_array: np.ndarray, data_config: DataModel, input_path: Union[Path, str], output_path: Union[Path, str], @@ -112,26 +118,44 @@ def _create_inputs_ouputs( Tuple[InputTensorDescr, OutputTensorDescr] Input and output tensor descriptions """ - input_axes = _create_axes(data_config) - output_axes = _create_axes(data_config) + input_axes = _create_axes(input_array, data_config) + output_axes = _create_axes(output_array, data_config, is_input=False) input_descr = InputTensorDescr( - id=TensorId("raw"), axes=input_axes, test_tensor=FileDescr(source=input_path) + id=TensorId("input"), + axes=input_axes, + test_tensor=FileDescr(source=input_path), + preprocessing=FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs( + mean=data_config.mean, std=data_config.std + ) + ), ) output_descr = OutputTensorDescr( - id=TensorId("pred"), axes=output_axes, test_tensor=FileDescr(source=output_path) + id=TensorId("prediction"), + axes=output_axes, + test_tensor=FileDescr(source=output_path), + postprocessing=FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs( + mean=data_config.mean, std=data_config.std + ) + ), ) return input_descr, output_descr def create_model_description( - careamist: CAREamist, + config: Configuration, name: str, general_description: str, authors: List[Author], inputs: Union[Path, str], outputs: Union[Path, str], - weights: Union[Path, str], + weights_path: Union[Path, str], + torch_version: str, + careamics_version: str, + config_path: Union[Path, str], + env_path: Union[Path, str], data_description: Optional[str] = None, custom_description: Optional[str] = None, ) -> ModelDescr: @@ -151,8 +175,14 @@ def create_model_description( Path to input .npy file. outputs : Union[Path, str] Path to output .npy file. - weights : Union[Path, str] + weights_path : Union[Path, str] Path to model weights. + torch_version : str + Pytorch version. + config_path : Union[Path, str] + Path to model configuration. + env_path : Union[Path, str] + Path to environment file. data_description : Optional[str], optional Description of the data, by default None custom_description : Optional[str], optional @@ -163,22 +193,40 @@ def create_model_description( ModelDescr Model description. """ + # documentation doc = readme_factory( - careamist.cfg, + config, + careamics_version=careamics_version, data_description=data_description, custom_description=custom_description, ) + # inputs, outputs input_descr, output_descr = _create_inputs_ouputs( - careamist.cfg.data_config, + input_array=np.load(inputs), + output_array=np.load(outputs), + data_config=config.data_config, input_path=inputs, output_path=outputs, ) - # export configuration - with cwd(get_careamics_home()): - config_path = save_configuration(careamist.cfg, get_careamics_home()) + # weights description + architecture_descr = ArchitectureFromLibraryDescr( + import_from="careamics.models", + callable=f"{config.algorithm_config.model.architecture}", + kwargs=config.algorithm_config.model.model_dump(), + ) + + weights_descr = WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=weights_path, + architecture=architecture_descr, + pytorch_version=Version(torch_version), + dependencies=EnvironmentFileDescr(source=env_path), + ), + ) + # overall model description model = ModelDescr( name=name, authors=authors, @@ -186,15 +234,16 @@ def create_model_description( documentation=doc, inputs=[input_descr], outputs=[output_descr], - tags=careamist.cfg.get_algorithm_keywords(), + tags=config.get_algorithm_keywords(), links=[ "https://github.com/CAREamics/careamics", "https://careamics.github.io/latest/", ], license="BSD-3-Clause", version="0.1.0", - weights=weights, - attachments=[config_path], + weights=weights_descr, + attachments=[FileDescr(source=config_path)], + cite=config.get_algorithm_citations(), ) return model diff --git a/src/careamics/model_io/bioimage/readme_factory.py b/src/careamics/model_io/bioimage/readme_factory.py index 146a10ffb..bf988e28b 100644 --- a/src/careamics/model_io/bioimage/readme_factory.py +++ b/src/careamics/model_io/bioimage/readme_factory.py @@ -1,17 +1,12 @@ from pathlib import Path from typing import Optional -import pkg_resources -import torch import yaml -from bioimageio.spec.model.v0_5 import Version from careamics.config import Configuration from careamics.config.support import SupportedAlgorithm from careamics.utils import cwd, get_careamics_home -pytorch_version = Version(torch.__version__) - def _yaml_block(yaml_str: str) -> str: """Return a markdown code block with a yaml string. @@ -31,6 +26,7 @@ def _yaml_block(yaml_str: str) -> str: def readme_factory( config: Configuration, + careamics_version: str, data_description: Optional[str] = None, custom_description: Optional[str] = None, ) -> Path: @@ -46,6 +42,8 @@ def readme_factory( ---------- config : Configuration CAREamics configuration + careamics_version : str + CAREamics version data_description : Optional[str], optional Description of the data, by default None custom_description : Optional[str], optional @@ -83,7 +81,6 @@ def readme_factory( description.append("\n\n") # algorithm details - careamics_version = pkg_resources.get_distribution("careamics").version description.append( f"{algorithm_flavour} was trained using CAREamics (version " f"{careamics_version}) with the following algorithm " diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py index 8b1d1fb65..4962d95c0 100644 --- a/src/careamics/model_io/model_io_utils.py +++ b/src/careamics/model_io/model_io_utils.py @@ -1,12 +1,18 @@ """Utility functions to load pretrained models.""" from pathlib import Path -from typing import Tuple, Union +from typing import List, Optional, Tuple, Union -from torch import load +import pkg_resources +from bioimageio.core import test_model +from bioimageio.spec import ValidationSummary, save_bioimageio_package +from torch import __version__, load, save -from careamics.config import Configuration +from careamics.config import Configuration, save_configuration +from careamics.config.support import SupportedArchitecture from careamics.lightning_module import CAREamicsKiln -from careamics.utils import check_path_exists +from careamics.utils import check_path_exists, get_careamics_home + +from .bioimage import create_model_description def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: @@ -34,27 +40,26 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio if path.suffix == ".ckpt": # load checkpoint - checkpoint = load(path) + checkpoint: dict = load(path) # attempt to load algorithm parameters try: cfg_dict = checkpoint["hyper_parameters"] except KeyError as e: raise ValueError( - "Invalid checkpoint file. No `hyper_parameters` found for the " - "algorithm." + f"Invalid checkpoint file. No `hyper_parameters` found in the " + f"checkpoint: {checkpoint.keys()}" ) from e model = _load_from_checkpoint(path) return model, Configuration(**cfg_dict) - elif path.suffix == "bioimage.io.zip": + elif path.suffix == ".zip": return _load_from_bmz(path) else: raise ValueError( - f"Invalid model format. Expected .ckpt or bioimage.io.zip, " - f"got {path.suffix}." + f"Invalid model format. Expected .ckpt or .zip, " f"got {path.suffix}." ) @@ -123,3 +128,131 @@ def _load_from_bmz( # load BMZ archive # extract model and call _load_from_torch_dict + + +def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: + """ + Export the model state dictionary to a file. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. + path : Union[Path, str] + Path to the file where to save the model state dictionary. + + Returns + ------- + Path + Path to the saved model state dictionary. + """ + path = Path(path) + + # make sure it has the correct suffix + if path.suffix not in ".pth": + path = path.with_suffix(".pth") + + # save model state dictionary + save(model.model.state_dict(), path) + + return path + + +def export_bmz( + model: CAREamicsKiln, + config: Configuration, + path: Union[Path, str], + name: str, + general_description: str, + authors: List[dict], + inputs: Union[Path, str], + outputs: Union[Path, str], + data_description: Optional[str] = None, + custom_description: Optional[str] = None, +) -> None: + """ + Export the model to BioImage Model Zoo format. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. + config : Configuration + Model configuration. + path : Union[Path, str] + Path to the output file. + name : str + Model name. + general_description : str + General description of the model. + authors : List[dict] + Authors of the model. + inputs : Union[Path, str] + Path to input .npy file. + outputs : Union[Path, str] + Path to output .npy file. + data_description : Optional[str], optional + Description of the data, by default None + custom_description : Optional[str], optional + Description of the custom algorithm, by default None + """ + path = Path(path) + + # method is not compatible with Custom models + if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM: + raise ValueError( + "Exporting Custom models to BioImage Model Zoo format is not supported." + ) + + # make sure it has the correct suffix + if path.suffix not in ".zip": + path = path.with_suffix(".zip") + + # versions + pytorch_version = __version__ + careamics_version = pkg_resources.get_distribution("careamics").version + + # create environment file + env_path = get_careamics_home() / "environment.yml" + env_path.write_text( + f"name: careamics\n" + f"dependencies:\n" + f" - python=3.8\n" + f" - pytorch={pytorch_version}\n" + f" - torchvision={pytorch_version}\n" + f" - pip\n" + f" - pip:\n" + f" - git+https://github.com/CAREamics/careamics.git@dl4mia\n" + ) + # TODO from pip with package version + + # export configuration + config_path = save_configuration(config, get_careamics_home()) + + # export model state dictionary + weight_path = _export_state_dict(model, get_careamics_home() / "weights.pth") + + # create model description + model_description = create_model_description( + config=config, + name=name, + general_description=general_description, + authors=authors, + inputs=inputs, + outputs=outputs, + weights_path=weight_path, + torch_version=pytorch_version, + careamics_version=careamics_version, + config_path=config_path, + env_path=env_path, + data_description=data_description, + custom_description=custom_description, + ) + + # test model description + summary: ValidationSummary = test_model(model_description) + if summary.status != "success": + raise ValueError(f"Model description test failed: {summary}") + + # save bmz model + save_bioimageio_package(model_description, output_path=path) diff --git a/tests/config/architectures/test_architecture_model.py b/tests/config/architectures/test_architecture_model.py index 8db4fe660..b97ab7e96 100644 --- a/tests/config/architectures/test_architecture_model.py +++ b/tests/config/architectures/test_architecture_model.py @@ -8,4 +8,4 @@ def test_model_dump(): # dump model model_dict = model.model_dump() - assert model_dict == {} \ No newline at end of file + assert model_dict == {} diff --git a/tests/config/architectures/test_custom_model.py b/tests/config/architectures/test_custom_model.py index a6ebd04db..63b6c7760 100644 --- a/tests/config/architectures/test_custom_model.py +++ b/tests/config/architectures/test_custom_model.py @@ -30,16 +30,11 @@ def forward(self, input): def test_any_custom_parameters(): """Test that the custom model can have any fields. - - Note that those fields are validated by instantiating the + + Note that those fields are validated by instantiating the model. """ - CustomModel( - architecture="Custom", - name="linear", - in_features=10, - out_features=5 - ) + CustomModel(architecture="Custom", name="linear", in_features=10, out_features=5) def test_linear_model(): @@ -60,7 +55,7 @@ def test_custom_model(): model_dict = { "architecture": SupportedArchitecture.CUSTOM.value, "name": "linear", - "in_features": 10, + "in_features": 10, "out_features": 5, } diff --git a/tests/conftest.py b/tests/conftest.py index b15e70b67..31b781b4a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,7 @@ import pytest import tifffile -from careamics.config import Configuration +from careamics import CAREamist, Configuration from careamics.config.algorithm_model import ( AlgorithmModel, LrSchedulerModel, @@ -312,3 +312,30 @@ def supervised_configuration( } return configuration + + +@pytest.fixture +def pre_trained(tmp_path, minimum_configuration): + """Fixture to create a pre-trained CAREamics model.""" + # training data + train_array = np.arange(32 * 32).reshape((32, 32)) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # check that it trained + pre_trained_path: Path = tmp_path / "checkpoints" / "last.ckpt" + assert pre_trained_path.exists() + + return pre_trained_path diff --git a/tests/test_careamist.py b/tests/test_careamist.py index aded0ce2b..fda43904b 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -8,33 +8,6 @@ from careamics.config.support import SupportedAlgorithm, SupportedData -@pytest.fixture -def pre_trained(tmp_path, minimum_configuration): - """Fixture to create a pre-trained CAREamics model.""" - # training data - train_array = np.ones((32, 32)) - - # create configuration - config = Configuration(**minimum_configuration) - config.training_config.num_epochs = 1 - config.data_config.axes = "YX" - config.data_config.batch_size = 2 - config.data_config.data_type = SupportedData.ARRAY.value - config.data_config.patch_size = (8, 8) - - # instantiate CAREamist - careamist = CAREamist(source=config, work_dir=tmp_path) - - # train CAREamist - careamist.train(train_source=train_array) - - # check that it trained - pre_trained_path: Path = tmp_path / "checkpoints" / "last.ckpt" - assert pre_trained_path.exists() - - return pre_trained_path - - def test_no_parameters(): """Test that CAREamics cannot be instantiated without parameters.""" with pytest.raises(TypeError): @@ -440,6 +413,8 @@ def test_predict_pretrained(tmp_path, pre_trained): # instantiate CAREamist careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + assert careamist.cfg.data_config.mean is not None + assert careamist.cfg.data_config.std is not None # predict predicted = careamist.predict(train_array, tile_overlap=(4, 4)) @@ -447,36 +422,3 @@ def test_predict_pretrained(tmp_path, pre_trained): # check that it predicted assert predicted is not None assert predicted.squeeze().shape == train_array.shape - - -# TODO move to test_export_bmz -def test_export_bmz(tmp_path, pre_trained): - # training data - train_array = np.ones((32, 32), dtype=np.float32) - - # instantiate CAREamist - careamist = CAREamist(source=pre_trained, work_dir=tmp_path) - - # predict - predicted = careamist.predict(train_array, tile_overlap=(4, 4)) - - # save images - train_path = tmp_path / "train.npy" - np.save(train_path, train_array[np.newaxis, np.newaxis, ...]) - - predicted_path = tmp_path / "predicted.npy" - np.save(tmp_path / "predicted.npy", predicted[np.newaxis, ...]) - - from careamics.model_io.model_io_utils import export_bmz - - # export to BioImage Model Zoo - export_bmz( - model=careamist.model, - config=careamist.cfg, - path=tmp_path / "model.zip", - name="TopModel", - general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}], - inputs=train_path, - outputs=predicted_path, - ) From 4dd6863c8e8e081741095578356dc6c26df54d22 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Thu, 25 Apr 2024 12:40:40 +0200 Subject: [PATCH 03/14] Remove torch from dependencies --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e252a82e7..1951b438c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,13 +37,13 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ - 'pytorch==2.*', + # pytorch should be installed via conda/pip beforehand 'albumentations', - 'bioimageio.core>=0.6.*', + 'bioimageio.core>=0.6.0', 'tifffile', 'psutil', 'pydantic>=2.5', - 'pytorch_lightning', + 'pytorch_lightning>=2.2.0', 'pyyaml', 'scikit-image', 'zarr', From 3ddbb9842560c4908bfc79b91427bc7393253ede Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Thu, 25 Apr 2024 14:05:01 +0200 Subject: [PATCH 04/14] (WIP) BMZ export --- src/careamics/config/configuration_factory.py | 12 ++++-- src/careamics/config/inference_model.py | 35 +---------------- src/careamics/config/references/__init__.py | 13 +++++++ src/careamics/config/references/references.py | 24 ++++++++++++ src/careamics/dataset/in_memory_dataset.py | 38 ++++++++----------- .../model_io/bioimage/model_description.py | 28 ++++++++++---- src/careamics/model_io/model_io_utils.py | 2 +- tests/config/test_inference_model.py | 19 ---------- tests/model_io/test_model_io_utils.py | 34 +++++++++++++++++ 9 files changed, 117 insertions(+), 88 deletions(-) create mode 100644 src/careamics/config/references/__init__.py create mode 100644 src/careamics/config/references/references.py create mode 100644 tests/model_io/test_model_io_utils.py diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index 16b5b66c8..c697b2a40 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -404,8 +404,6 @@ def create_inference_configuration( training_configuration: Configuration, tile_size: Optional[Tuple[int, ...]] = None, tile_overlap: Tuple[int, ...] = (48, 48), - mean: Optional[float] = None, - std: Optional[float] = None, data_type: Optional[Literal["array", "tiff", "custom"]] = None, axes: Optional[str] = None, transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None, @@ -442,6 +440,12 @@ def create_inference_configuration( InferenceConfiguration Configuration for inference with N2V. """ + if ( + training_configuration.data_config.mean is None + or training_configuration.data_config.std is None + ): + raise ValueError("Mean and std must be provided in the training configuration.") + if transforms is None: transforms = [ { @@ -454,8 +458,8 @@ def create_inference_configuration( tile_size=tile_size or training_configuration.data_config.patch_size, tile_overlap=tile_overlap, axes=axes or training_configuration.data_config.axes, - mean=mean or training_configuration.data_config.mean, - std=std or training_configuration.data_config.std, + mean=training_configuration.data_config.mean, + std=training_configuration.data_config.std, transforms=transforms, tta_transforms=tta_transforms, batch_size=batch_size, diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py index cc5243dc4..7ac0ff0de 100644 --- a/src/careamics/config/inference_model.py +++ b/src/careamics/config/inference_model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union from albumentations import Compose from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ -234,39 +234,6 @@ def add_std_and_mean_to_normalize( return pred_model - def _update(self, **kwargs: Any) -> None: - """Update multiple arguments at once.""" - self.__dict__.update(kwargs) - self.__class__.model_validate(self.__dict__) - - def set_mean_and_std(self, mean: float, std: float) -> None: - """ - Set mean and standard deviation of the data. - - This method should be used instead setting the fields directly, as it would - otherwise trigger a validation error. - - Parameters - ---------- - mean : float - Mean of the data. - std : float - Standard deviation of the data. - """ - self._update(mean=mean, std=std) - - # search in the transforms for Normalize and update parameters - if not isinstance(self.transforms, Compose): - for transform in self.transforms: - if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = mean - transform.parameters.std = std - else: - raise ValueError( - "Setting mean and std with Compose transforms is not allowed. Add " - "mean and std parameters directly to the transform in the Compose." - ) - def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None: """ Set 3D parameters. diff --git a/src/careamics/config/references/__init__.py b/src/careamics/config/references/__init__.py new file mode 100644 index 000000000..862827e26 --- /dev/null +++ b/src/careamics/config/references/__init__.py @@ -0,0 +1,13 @@ +"""Module containing references to the algorithm used in CAREamics.""" + +__all__ = [ + "N2V2_REF", + "N2V_REF", + "STRUCTN2V_REF", +] + +from .references import ( + N2V2_REF, + N2V_REF, + STRUCTN2V_REF, +) diff --git a/src/careamics/config/references/references.py b/src/careamics/config/references/references.py new file mode 100644 index 000000000..9b2e29211 --- /dev/null +++ b/src/careamics/config/references/references.py @@ -0,0 +1,24 @@ +from bioimageio.spec.generic.v0_3 import CiteEntry + +N2V_REF = CiteEntry( + text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' + 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' + "conference on computer vision and pattern recognition (pp. 2129-2137).", + doi="10.1109/cvpr.2019.00223", +) + +N2V2_REF = CiteEntry( + text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " + '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' + 'sampling strategies and a tweaked network architecture". In European ' + "Conference on Computer Vision (pp. 503-518).", + doi="10.1007/978-3-031-25069-9_33", +) + +STRUCTN2V_REF = CiteEntry( + text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." + '"Removing structured noise with self-supervised blind-spot ' + 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' + "Imaging (ISBI) (pp. 159-163).", + doi="10.1109/isbi45749.2020.9098336", +) diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 620b06345..3cdc3c7c7 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -269,7 +269,7 @@ def split_dataset( return dataset -class InMemoryPredictionDataset(InMemoryDataset): +class InMemoryPredictionDataset(Dataset): """ Dataset storing data in memory and allowing generating patches from it. @@ -311,25 +311,8 @@ def __init__( self.read_source_func = read_source_func # Generate patches - tiles = self._prepare_tiles() - - # Add results to members - self.data, computed_mean, computed_std = tiles - - if not self.pred_config.mean or not self.pred_config.std: - self.mean, self.std = computed_mean, computed_std - logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}") - - # if the transforms are not an instance of Compose - if hasattr(self.pred_config, "has_transform_list"): - if self.pred_config.has_transform_list(): - # update mean and std in configuration - # the object is mutable and should then be recorded in the CAREamist - self.pred_config.set_mean_and_std(self.mean, self.std) - else: - self.pred_config.set_mean_and_std(self.mean, self.std) - else: - self.mean, self.std = self.pred_config.mean, self.pred_config.std + self.data = self._prepare_tiles() + self.mean, self.std = self.pred_config.mean, self.pred_config.std # get transforms self.patch_transform = get_patch_transform( @@ -351,9 +334,20 @@ def _prepare_tiles(self) -> Callable: if self.tiling: return generate_patches_predict( self.input_array, self.axes, self.tile_size, self.tile_overlap - ), self.input_array.mean(), self.input_array.std() + ) else: - return self.input_array, self.input_array.mean(), self.input_array.std() + return self.input_array + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.data) def __getitem__(self, index: int) -> Tuple[np.ndarray, Any, Any, Any, Any]: """ diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index 2e18fc5f2..0a4a596bf 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -120,25 +120,37 @@ def _create_inputs_ouputs( """ input_axes = _create_axes(input_array, data_config) output_axes = _create_axes(output_array, data_config, is_input=False) + + # mean and std + mean = data_config.mean + std = data_config.std + + # and the mean and std required to invert the normalization + inv_mean = -mean / std + inv_std = 1 / std + + # create input/output descriptions input_descr = InputTensorDescr( id=TensorId("input"), axes=input_axes, test_tensor=FileDescr(source=input_path), - preprocessing=FixedZeroMeanUnitVarianceDescr( - kwargs=FixedZeroMeanUnitVarianceKwargs( - mean=data_config.mean, std=data_config.std + preprocessing=[ + FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std) ) - ), + ], ) output_descr = OutputTensorDescr( id=TensorId("prediction"), axes=output_axes, test_tensor=FileDescr(source=output_path), - postprocessing=FixedZeroMeanUnitVarianceDescr( - kwargs=FixedZeroMeanUnitVarianceKwargs( - mean=data_config.mean, std=data_config.std + postprocessing=[ + FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs( # invert normalization + mean=inv_mean, std=inv_std + ) ) - ), + ], ) return input_descr, output_descr diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py index 4962d95c0..88f6fd4f7 100644 --- a/src/careamics/model_io/model_io_utils.py +++ b/src/careamics/model_io/model_io_utils.py @@ -251,7 +251,7 @@ def export_bmz( # test model description summary: ValidationSummary = test_model(model_description) - if summary.status != "success": + if summary.status == "failed": raise ValueError(f"Model description test failed: {summary}") # save bmz model diff --git a/tests/config/test_inference_model.py b/tests/config/test_inference_model.py index 8277869cb..6fa9e068e 100644 --- a/tests/config/test_inference_model.py +++ b/tests/config/test_inference_model.py @@ -51,25 +51,6 @@ def test_mean_std_both_specified_or_none(minimum_inference: dict): InferenceModel(**minimum_inference) -def test_set_mean_and_std(minimum_inference: dict): - """Test that mean and std can be set after initialization.""" - # they can be set both, when they None - mean = 4.07 - std = 14.07 - pred = InferenceModel(**minimum_inference) - pred.set_mean_and_std(mean, std) - assert pred.mean == mean - assert pred.std == std - - # and if they are already set - minimum_inference["mean"] = 10.4 - minimum_inference["std"] = 3.2 - pred = InferenceModel(**minimum_inference) - pred.set_mean_and_std(mean, std) - assert pred.mean == mean - assert pred.std == std - - def test_tile_size(minimum_inference: dict): """Test that non-zero even patch size are accepted.""" # 2D diff --git a/tests/model_io/test_model_io_utils.py b/tests/model_io/test_model_io_utils.py new file mode 100644 index 000000000..46fea7273 --- /dev/null +++ b/tests/model_io/test_model_io_utils.py @@ -0,0 +1,34 @@ +import numpy as np + +from careamics import CAREamist +from careamics.model_io.model_io_utils import export_bmz + + +def test_export_bmz(tmp_path, pre_trained): + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict + predicted = careamist.predict(train_array, tta_transforms=False) + + # save images + train_path = tmp_path / "train.npy" + np.save(train_path, train_array[np.newaxis, np.newaxis, ...]) + + predicted_path = tmp_path / "predicted.npy" + np.save(tmp_path / "predicted.npy", predicted[np.newaxis, ...]) + + # export to BioImage Model Zoo + export_bmz( + model=careamist.model, + config=careamist.cfg, + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + inputs=train_path, + outputs=predicted_path, + ) From 3c223304f8e6dd215033c9933adbdec2abdab080 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Thu, 25 Apr 2024 23:22:28 +0200 Subject: [PATCH 05/14] Prediction without tile, tiling information object, patching refactoring --- src/careamics/__init__.py | 6 +- src/careamics/careamist.py | 8 +- src/careamics/config/configuration_factory.py | 6 +- src/careamics/config/inference_model.py | 62 +-- src/careamics/config/tile_information.py | 101 +++++ src/careamics/dataset/in_memory_dataset.py | 79 ++-- src/careamics/dataset/iterable_dataset.py | 84 ++-- src/careamics/dataset/patching/__init__.py | 8 - src/careamics/dataset/patching/patching.py | 202 +-------- .../dataset/patching/tiled_patching.py | 46 +-- src/careamics/dataset/zarr_dataset.py | 299 +++++++------- src/careamics/lightning_datamodule.py | 339 +-------------- src/careamics/lightning_module.py | 6 +- .../lightning_prediction_datamodule.py | 390 ++++++++++++++++++ ...iction.py => lightning_prediction_loop.py} | 59 ++- src/careamics/prediction/prediction_utils.py | 30 +- tests/config/test_inference_model.py | 32 +- tests/config/test_tile_information.py | 50 +++ tests/conftest.py | 4 +- .../dataset/patching/test_random_patching.py | 2 +- tests/dataset/patching/test_tiled_patching.py | 9 +- tests/prediction/test_prediction_utils.py | 13 +- tests/test_careamist.py | 39 +- tests/test_lightning_datamodule.py | 30 +- 24 files changed, 980 insertions(+), 924 deletions(-) create mode 100644 src/careamics/config/tile_information.py create mode 100644 src/careamics/lightning_prediction_datamodule.py rename src/careamics/{lightning_prediction.py => lightning_prediction_loop.py} (54%) create mode 100644 tests/config/test_tile_information.py diff --git a/src/careamics/__init__.py b/src/careamics/__init__.py index 86904d1d3..222aff45e 100644 --- a/src/careamics/__init__.py +++ b/src/careamics/__init__.py @@ -19,8 +19,6 @@ from .careamist import CAREamist from .config import Configuration, load_configuration, save_configuration -from .lightning_datamodule import ( - CAREamicsPredictDataModule, - CAREamicsTrainDataModule, -) +from .lightning_datamodule import CAREamicsTrainDataModule from .lightning_module import CAREamicsModule +from .lightning_prediction_datamodule import CAREamicsPredictDataModule diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 144108d8a..b8d76dec6 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -20,9 +20,10 @@ ) from careamics.config.inference_model import TRANSFORMS_UNION from careamics.config.support import SupportedAlgorithm, SupportedLogger -from careamics.lightning_datamodule import CAREamicsClay, CAREamicsWood +from careamics.lightning_datamodule import CAREamicsWood from careamics.lightning_module import CAREamicsKiln -from careamics.lightning_prediction import CAREamicsPredictionLoop +from careamics.lightning_prediction_datamodule import CAREamicsClay +from careamics.lightning_prediction_loop import CAREamicsPredictionLoop from careamics.model_io import load_pretrained from careamics.utils import check_path_exists, get_logger @@ -635,6 +636,3 @@ def predict( f"Invalid input. Expected a CAREamicsWood instance, paths or " f"np.ndarray (got {type(source)})." ) - - def export_model(self, path: Union[Path, str]) -> None: - pass diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index c697b2a40..028e826b3 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -403,7 +403,7 @@ def create_n2v_configuration( def create_inference_configuration( training_configuration: Configuration, tile_size: Optional[Tuple[int, ...]] = None, - tile_overlap: Tuple[int, ...] = (48, 48), + tile_overlap: Optional[Tuple[int, ...]] = None, data_type: Optional[Literal["array", "tiff", "custom"]] = None, axes: Optional[str] = None, transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None, @@ -413,7 +413,7 @@ def create_inference_configuration( """ Create a configuration for inference with N2V. - If not provided, `data_type`, `tile_size`, and `axes` are taken from the training + If not provided, `data_type` and `axes` are taken from the training configuration. If `transforms` are not provided, only normalization is applied. Parameters @@ -455,7 +455,7 @@ def create_inference_configuration( return InferenceModel( data_type=data_type or training_configuration.data_config.data_type, - tile_size=tile_size or training_configuration.data_config.patch_size, + tile_size=tile_size, tile_overlap=tile_overlap, axes=axes or training_configuration.data_config.axes, mean=training_configuration.data_config.mean, diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py index 7ac0ff0de..78583cecc 100644 --- a/src/careamics/config/inference_model.py +++ b/src/careamics/config/inference_model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Literal, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union from albumentations import Compose from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator @@ -20,16 +20,17 @@ class InferenceModel(BaseModel): # Mandatory fields data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData - tile_size: Union[List[int], Tuple[int]] = Field(..., min_length=2, max_length=3) - tile_overlap: List[int] = Field( - default=[48, 48], min_length=2, max_length=3 - ) # TODO Will be calculated automatically in the future + tile_size: Optional[Union[List[int], Tuple[int]]] = Field( + default=None, min_length=2, max_length=3 + ) + tile_overlap: Optional[Union[List[int], Tuple[int]]] = Field( + default=None, min_length=2, max_length=3 + ) axes: str - # Optional fields - mean: Optional[float] = None - std: Optional[float] = None + mean: float + std: float = Field(..., ge=0.0) transforms: Union[List[TRANSFORMS_UNION], Compose] = Field( default=[ @@ -71,12 +72,15 @@ def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: ValueError If the patch size is not even. """ - for dim in patch_list: - if dim < 1: - raise ValueError(f"Patch size must be non-zero positive (got {dim}).") + if patch_list is not None: + for dim in patch_list: + if dim < 1: + raise ValueError( + f"Patch size must be non-zero positive (got {dim})." + ) - if dim % 2 != 0: - raise ValueError(f"Patch size must be even (got {dim}).") + if dim % 2 != 0: + raise ValueError(f"Patch size must be even (got {dim}).") return patch_list @@ -162,20 +166,23 @@ def validate_dimensions(cls, pred_model: InferenceModel) -> InferenceModel: """ expected_len = 3 if "Z" in pred_model.axes else 2 - if len(pred_model.tile_size) != expected_len: - raise ValueError( - f"Tile size must have {expected_len} dimensions given axes " - f"{pred_model.axes} (got {pred_model.tile_size})." - ) + if pred_model.tile_size is not None and pred_model.tile_overlap is not None: + if len(pred_model.tile_size) != expected_len: + raise ValueError( + f"Tile size must have {expected_len} dimensions given axes " + f"{pred_model.axes} (got {pred_model.tile_size})." + ) - if len(pred_model.tile_overlap) != expected_len: - raise ValueError( - f"Tile overlap must have {expected_len} dimensions given axes " - f"{pred_model.axes} (got {pred_model.tile_overlap})." - ) + if len(pred_model.tile_overlap) != expected_len: + raise ValueError( + f"Tile overlap must have {expected_len} dimensions given axes " + f"{pred_model.axes} (got {pred_model.tile_overlap})." + ) - if any((i >= j) for i, j in zip(pred_model.tile_overlap, pred_model.tile_size)): - raise ValueError("Tile overlap must be smaller than tile size.") + if any( + (i >= j) for i, j in zip(pred_model.tile_overlap, pred_model.tile_size) + ): + raise ValueError("Tile overlap must be smaller than tile size.") return pred_model @@ -234,6 +241,11 @@ def add_std_and_mean_to_normalize( return pred_model + def _update(self, **kwargs: Any) -> None: + """Update multiple arguments at once.""" + self.__dict__.update(kwargs) + self.__class__.model_validate(self.__dict__) + def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None: """ Set 3D parameters. diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py new file mode 100644 index 000000000..6e669af9f --- /dev/null +++ b/src/careamics/config/tile_information.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from typing import Optional, Tuple + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator + + +class TileInformation(BaseModel): + """ + Pydantic model containing tile information. + + This model is used to represent the information required to stitch back a tile into + a larger image. It is used throughout the prediction pipeline of CAREamics. + """ + + model_config = ConfigDict(validate_default=True) + + array_shape: Tuple[int, ...] + tiled: bool = False + last_tile: bool = False + overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) + stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) + + @field_validator("array_shape") + @classmethod + def no_singleton_dimensions(cls, v: Tuple[int, ...]): + """ + Check that the array shape does not have any singleton dimensions. + + Parameters + ---------- + v : Tuple[int, ...] + Array shape to check. + + Returns + ------- + Tuple[int, ...] + The array shape if it does not contain singleton dimensions. + + Raises + ------ + ValueError + If the array shape contains singleton dimensions. + """ + if any(dim == 1 for dim in v): + raise ValueError("Array shape must not contain singleton dimensions.") + return v + + @field_validator("last_tile") + @classmethod + def only_if_tiled(cls, v: bool, values: ValidationInfo): + """ + Check that the last tile flag is only set to `True` if tiling is enabled, + otherwise set it to `False`. + + Parameters + ---------- + v : bool + Last tile flag. + values : ValidationInfo + Validation information. + + Returns + ------- + bool + The last tile flag. + """ + if not values.data["tiled"]: + return False + return v + + @field_validator("overlap_crop_coords", "stitch_coords") + @classmethod + def mandatory_if_tiled(cls, v: Optional[Tuple[int, ...]], values: ValidationInfo): + """ + Check that the coordinates are not `None` if tiling is enabled. + + The method also return `None` if tiling is not enabled. + + Parameters + ---------- + v : Optional[Tuple[int, ...]] + Coordinates to check. + values : ValidationInfo + Validation information. + + Returns + ------- + Optional[Tuple[int, ...]] + The coordinates if tiling is enabled, otherwise `None`. + + Raises + ------ + ValueError + If the coordinates are `None` and tiling is enabled. + """ + if values.data["tiled"]: + if v is None: + raise ValueError("Value must be specified if tiling is enabled.") + else: + return None diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 3cdc3c7c7..becf3eaa3 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -9,16 +9,17 @@ from torch.utils.data import Dataset from ..config import DataModel, InferenceModel +from ..config.tile_information import TileInformation from ..utils.logging import get_logger -from .dataset_utils import read_tiff +from .dataset_utils import read_tiff, reshape_array from .patching.patch_transform import get_patch_transform from .patching.patching import ( - generate_patches_predict, prepare_patches_supervised, prepare_patches_supervised_array, prepare_patches_unsupervised, prepare_patches_unsupervised_array, ) +from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) @@ -302,11 +303,13 @@ def __init__( self.axes = self.pred_config.axes self.tile_size = self.pred_config.tile_size self.tile_overlap = self.pred_config.tile_overlap - self.tiling = self.tile_size and self.tile_overlap self.mean = self.pred_config.mean self.std = self.pred_config.std self.data_target = data_target + # tiling only if both tile size and overlap are provided + self.tiling = self.tile_size is not None and self.tile_overlap is not None + # read function self.read_source_func = read_source_func @@ -320,23 +323,34 @@ def __init__( with_target=self.data_target is not None, ) - def _prepare_tiles(self) -> Callable: + def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: """ Iterate over data source and create an array of patches. - Calls consecutive function for supervised and unsupervised learning. - Returns ------- - np.ndarray - Array of patches. + List[XArrayTile] + List of tiles. """ + # reshape array + reshaped_sample = reshape_array(self.input_array, self.axes) + if self.tiling: - return generate_patches_predict( - self.input_array, self.axes, self.tile_size, self.tile_overlap + # generate patches, which returns a generator + patch_generator = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, ) + patches_list = list(patch_generator) + + if len(patches_list) == 0: + raise ValueError("No tiles generated, ") + + return patches_list else: - return self.input_array + array_shape = reshaped_sample.squeeze().shape + return [(reshaped_sample, TileInformation(array_shape=array_shape))] def __len__(self) -> int: """ @@ -349,7 +363,7 @@ def __len__(self) -> int: """ return len(self.data) - def __getitem__(self, index: int) -> Tuple[np.ndarray, Any, Any, Any, Any]: + def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]: """ Return the patch corresponding to the provided index. @@ -360,37 +374,18 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, Any, Any, Any, Any]: Returns ------- - Tuple[np.ndarray] - Patch. - - Raises - ------ - ValueError - If dataset mean and std are not set. + Tuple[np.ndarray, TileInformation] + Transformed patch. """ - if self.tiling: - ( - tile, - last_tile, - arr_shape, - overlap_crop_coords, - stitch_coords, - ) = self.data[index] + tile_array, tile_info = self.data[index] - # Albumentations requires Channel last - tile = np.moveaxis(tile, 0, -1) + # Albumentations requires channel last, use the XArrayTile array + patch = np.moveaxis(tile_array, 0, -1) - # Apply transforms - transformed_tile = self.patch_transform(image=tile)["image"] - tile = transformed_tile + # Apply transforms + transformed_patch = self.patch_transform(image=patch)["image"] - # move C axes back - tile = np.moveaxis(tile, -1, 0) - - return ( - tile, - last_tile, - arr_shape, - overlap_crop_coords, - stitch_coords, - ) # TODO can we wrap this into an object? + # move C axes back + transformed_patch = np.moveaxis(transformed_patch, -1, 0) + + return transformed_patch, tile_info diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index be8835b46..53df2e618 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -8,15 +8,14 @@ from torch.utils.data import IterableDataset, get_worker_info from ..config import DataModel, InferenceModel -from ..config.support import SupportedExtractionStrategy +from ..config.tile_information import TileInformation from ..utils.logging import get_logger -from .dataset_utils import read_tiff +from .dataset_utils import read_tiff, reshape_array from .patching import ( - generate_patches_predict, - generate_patches_supervised, - generate_patches_unsupervised, get_patch_transform, ) +from .patching.random_patching import extract_patches_random +from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) @@ -175,22 +174,18 @@ def __iter__( # iterate over files for sample_input, sample_target in self._iterate_over_files(): - if self.target_files is not None: - patches = generate_patches_supervised( - data=sample_input, - axes=self.data_config.axes, - patch_extraction_method=SupportedExtractionStrategy.RANDOM, - patch_size=self.data_config.patch_size, - target=sample_target, - ) + reshaped_sample = reshape_array(sample_input, self.data_config.axes) + reshaped_target = ( + None + if sample_target is None + else reshape_array(sample_target, self.data_config.axes) + ) - else: - patches = generate_patches_unsupervised( - data=sample_input, - axes=self.data_config.axes, - patch_extraction_method=SupportedExtractionStrategy.RANDOM, - patch_size=self.data_config.patch_size, - ) + patches = extract_patches_random( + arr=reshaped_sample, + patch_size=self.data_config.patch_size, + target=reshaped_target, + ) # iterate over patches # patches are tuples of (patch, target) if target is available @@ -368,12 +363,8 @@ def __init__( self.tile_overlap = self.prediction_config.tile_overlap self.read_source_func = read_source_func - # check that mean and std are provided - if not self.mean or not self.std: - raise ValueError( - "Mean and std must be provided to the configuration in order to " - " perform prediction." - ) + # tile only if both tile size and overlaps are provided + self.tile = self.tile_size is not None and self.tile_overlap is not None # get tta transforms self.patch_transform = get_patch_transform( @@ -383,7 +374,7 @@ def __init__( def __iter__( self, - ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]: + ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: """ Iterate over data source and yield single patch. @@ -397,18 +388,29 @@ def __iter__( ), "Mean and std must be provided" for sample, _ in self._iterate_over_files(): - patches = generate_patches_predict( - data=sample, - axes=self.axes, - tile_size=self.tile_size, - tile_overlap=self.tile_overlap, - ) - # TODO AttributeError: 'IterablePredictionDataset' object has no attribute - # 'mean' message appears if the predict func is run more than once - - for patch_data in patches: - # Albumentations expects the channel dimension to be last - transformed = self.patch_transform( - image=np.moveaxis(patch_data[0], 0, -1) + # reshape array + reshaped_sample = reshape_array(sample, self.axes) + + if self.tile: + # generate patches, return a generator + patch_gen = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, ) - yield (np.moveaxis(transformed["image"], -1, 0), *patch_data[1:]) + else: + # just wrap the sample in a generator with default tiling info + array_shape = reshaped_sample.squeeze().shape + patch_gen = ( + (reshaped_sample, TileInformation(array_shape=array_shape)) + for _ in range(1) + ) + + # apply transform to patches + for patch_array, tile_info in patch_gen: + # albumentations expects the channel dimension to be last + patch = np.moveaxis(patch_array, 0, -1) + transformed_patch = self.patch_transform(image=patch) + transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0) + + yield transformed_patch, tile_info diff --git a/src/careamics/dataset/patching/__init__.py b/src/careamics/dataset/patching/__init__.py index f1ee387c0..c684789bf 100644 --- a/src/careamics/dataset/patching/__init__.py +++ b/src/careamics/dataset/patching/__init__.py @@ -2,16 +2,8 @@ __all__ = [ - "generate_patches_predict", - "generate_patches_supervised", - "generate_patches_unsupervised", "get_patch_transform", "get_patch_transform_predict", ] from .patch_transform import get_patch_transform, get_patch_transform_predict -from .patching import ( - generate_patches_predict, - generate_patches_supervised, - generate_patches_unsupervised, -) diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py index 09dae8427..def9ce25e 100644 --- a/src/careamics/dataset/patching/patching.py +++ b/src/careamics/dataset/patching/patching.py @@ -4,29 +4,17 @@ These functions are used to tile images into patches or tiles. """ from pathlib import Path -from typing import Callable, Generator, List, Optional, Tuple, Union +from typing import Callable, List, Tuple, Union import numpy as np -import zarr -from ...config.support.supported_extraction_strategies import ( - SupportedExtractionStrategy, -) from ...utils.logging import get_logger from ..dataset_utils import reshape_array -from .random_patching import extract_patches_random, extract_patches_random_from_chunks from .sequential_patching import extract_patches_sequential -from .tiled_patching import extract_tiles logger = get_logger(__name__) -# TODO: several issues that require refactoring -# - some patching return array, others generator -# - in iterable and in memory, the reshaping happens at different moment -# - return type is not consistent (ndarray, ndarray or ndarray, None or just ndarray) - - # called by in memory dataset def prepare_patches_supervised( train_files: List[Path], @@ -222,191 +210,3 @@ def prepare_patches_unsupervised_array( patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size) return patches, _, mean, std # TODO inelegant, replace by dataclass? - - -# prediction, both in memory and iterable -def generate_patches_predict( - data: np.ndarray, - axes: str, - tile_size: Union[List[int], Tuple[int, ...]], - tile_overlap: Union[List[int], Tuple[int, ...]], -) -> List[Tuple[np.ndarray, bool, Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]]: - """ - Iterate over data source and create an array of patches. - - Returns - ------- - np.ndarray - Array of patches. - """ - # reshape array - reshaped_sample = reshape_array(data, axes) - # generate patches, return a generator - patches = extract_tiles( - arr=reshaped_sample, tile_size=tile_size, overlaps=tile_overlap - ) - patches_list = list(patches) # TODO: refactor to use generator ? - if len(patches_list) == 0: - raise ValueError("No tiles generated") - - return patches_list - -# iterator over files -def generate_patches_supervised( - data: Union[np.ndarray, zarr.Array], - axes: str, - patch_extraction_method: SupportedExtractionStrategy, - patch_size: Union[List[int], Tuple[int, ...]], - patch_overlap: Optional[Union[List[int], Tuple[int, ...]]] = None, - target: Optional[Union[np.ndarray, zarr.Array]] = None, -) -> Generator[np.ndarray, None, None]: - """ - Creates an iterator with patches and corresponding targets from a sample. - - Parameters - ---------- - sample : np.ndarray - Input array. - patch_extraction_method : ExtractionStrategies - Patch extraction method, as defined in extraction_strategy.ExtractionStrategy. - patch_size : Optional[Union[List[int], Tuple[int]]] - Size of the patches along each dimension of the array, except the first. - patch_overlap : Optional[Union[List[int], Tuple[int]]] - Overlap between patches. - - Returns - ------- - Generator[np.ndarray, None, None] - Generator yielding patches/tiles. - - Raises - ------ - ValueError - If overlap is not specified when using tiling. - ValueError - If patches is None. - """ - patches = None - targets = None - - # reshape target - reshaped_sample = reshape_array(data, axes) - reshaped_target = reshape_array(target, axes) - - if patch_size is not None: - patches = None - - if patch_extraction_method == SupportedExtractionStrategy.TILED: - if patch_overlap is None: - raise ValueError( - "Overlaps must be specified when using tiling (got None)." - ) - patches = extract_tiles( - arr=reshaped_sample, tile_size=patch_size, overlaps=patch_overlap - ) - - elif patch_extraction_method == SupportedExtractionStrategy.SEQUENTIAL: - patches, targets = extract_patches_sequential( - arr=reshaped_sample, patch_size=patch_size, target=reshaped_target - ) - - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM: - # Returns a generator of patches and targets(if present) - patches = extract_patches_random( - arr=reshaped_sample, patch_size=patch_size, target=reshaped_target - ) - - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM_ZARR: - # Returns a generator of patches and targets(if present) - patches = extract_patches_random_from_chunks( - reshaped_sample, - patch_size=patch_size, - chunk_size=reshaped_sample.chunks, - ) - - if patch_extraction_method == SupportedExtractionStrategy.SEQUENTIAL: - return patches, targets - else: - return patches - - else: - # no patching - return (reshaped_sample for _ in range(1)), reshaped_target - - -# iterator over files -def generate_patches_unsupervised( - data: Union[np.ndarray, zarr.Array], - axes: str, - patch_extraction_method: SupportedExtractionStrategy, - patch_size: Union[List[int], Tuple[int, ...]], - patch_overlap: Optional[Union[List[int], Tuple[int]]] = None, -) -> Generator[np.ndarray, None, None]: - """ - Creates an iterator over patches from a sample. - - Parameters - ---------- - sample : np.ndarray - Input array. - patch_extraction_method : SupportedExtractionStrategy - Patch extraction methods (see `config.support`). - patch_size : Optional[Union[List[int], Tuple[int]]] - Size of the patches along each dimension of the array, except the first. - patch_overlap : Optional[Union[List[int], Tuple[int]]] - Overlap between patches. - - Returns - ------- - Generator[np.ndarray, None, None] - Generator yielding patches/tiles. - - Raises - ------ - ValueError - If overlap is not specified when using tiling. - ValueError - If patches is None. - """ - # reshape array - reshaped_sample = reshape_array(data, axes) - - # if tiled (patches with overlaps) - if patch_extraction_method == SupportedExtractionStrategy.TILED: - if patch_overlap is None: - patch_overlap = [48] * len(patch_size) # TODO pass overlap instead - - # return a Generator of the following: - # - patch: np.ndarray, dims C(Z)YX - # - last_tile: bool - # - shape: Tuple[int], shape of a tile, excluding the S dimension - # - overlap_crop_coords: coordinates used to crop the patch during stitching - # - stitch_coords: coordinates used to stitch the tiles back to the full image - patches = extract_tiles( - arr=reshaped_sample, tile_size=patch_size, overlaps=patch_overlap - ) - - # random extraction - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM: - # return a Generator that yields the following: - # - patch: np.ndarray, dimension C(Z)YX - # - target_patch: np.ndarray, dimension C(Z)YX, or None - patches = extract_patches_random(reshaped_sample, patch_size=patch_size) - - # zarr specific random extraction - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM_ZARR: - # # Returns a generator of patches and targets(if present) - # patches = extract_patches_random_from_chunks( - # sample, patch_size=patch_size, chunk_size=sample.chunks - # ) - raise NotImplementedError("Random zarr extraction not implemented yet.") - - # no patching, return sample - elif patch_extraction_method == SupportedExtractionStrategy.NONE: - patches = (reshaped_sample for _ in range(1)) - - # no extraction method - else: - raise ValueError("Invalid patch extraction method.") - - return patches diff --git a/src/careamics/dataset/patching/tiled_patching.py b/src/careamics/dataset/patching/tiled_patching.py index c7c4f69e5..bd618daeb 100644 --- a/src/careamics/dataset/patching/tiled_patching.py +++ b/src/careamics/dataset/patching/tiled_patching.py @@ -3,10 +3,12 @@ import numpy as np +from careamics.config.tile_information import TileInformation + def _compute_crop_and_stitch_coords_1d( axis_size: int, tile_size: int, overlap: int -) -> Tuple[List[Tuple[int, int]], ...]: +) -> Tuple[List[Tuple[int, ...]], ...]: """ Compute the coordinates of each tile along an axis, given the overlap. @@ -21,7 +23,7 @@ def _compute_crop_and_stitch_coords_1d( Returns ------- - Tuple[Tuple[int]] + Tuple[Tuple[int, ...], ...] Tuple of all coordinates for given axis. """ # Compute the step between tiles @@ -75,24 +77,16 @@ def extract_tiles( arr: np.ndarray, tile_size: Union[List[int], Tuple[int, ...]], overlaps: Union[List[int], Tuple[int, ...]], -) -> Generator[ - Tuple[np.ndarray, bool, Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]], - None, - None, -]: +) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: """ Generate tiles from the input array with specified overlap. - The tiles cover the whole array. The method returns a generator that yields a tuple - containing the following: + The tiles cover the whole array. The method returns a generator that yields + tuples of array and tile information, the latter includes whether + the tile is the last one, the coordinates of the overlap crop, and the coordinates + of the stitched tile. - - tile: np.ndarray, shape (C, (Z), Y, X). - - last_tile: bool, whether this is the last tile. - - shape: Tuple[int], shape of a tile, excluding the S dimension. - - overlap_crop_coords: Tuple[int], coordinates used to crop the tile during - stitching. - - stitch_coords: Tuple[int], coordinates used to stitch the tiles back to the full - image. + The array has shape C(Z)YX, where C can be a singleton. Parameters ---------- @@ -105,9 +99,8 @@ def extract_tiles( Yields ------ - Generator[Tuple[np.ndarray, bool, Tuple[int], np.ndarray, np.ndarray], None, None] - Generator of tuple containing the tile, last tile boolean, array shape, - overlap_crop_coords, and stitching coords. + Generator[Tuple[np.ndarray, TileInformation], None, None] + Tile generator, yields the tile and additional information. """ # Iterate over num samples (S) for sample_idx in range(arr.shape[0]): @@ -153,10 +146,13 @@ def extract_tiles( else: last_tile = False - yield ( - tile.astype(np.float32), - last_tile, - arr.shape[1:], # TODO is this used anywhere?? - overlap_crop_coords, - stitch_coords, + # create tile information + tile_info = TileInformation( + array_shape=sample.squeeze().shape, + tiled=True, + last_tile=last_tile, + overlap_crop_coords=overlap_crop_coords, + stitch_coords=stitch_coords, ) + + yield tile, tile_info diff --git a/src/careamics/dataset/zarr_dataset.py b/src/careamics/dataset/zarr_dataset.py index 751d8bdd5..ee54fdd26 100644 --- a/src/careamics/dataset/zarr_dataset.py +++ b/src/careamics/dataset/zarr_dataset.py @@ -1,150 +1,149 @@ -from itertools import islice -from typing import Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import zarr - -from careamics.utils import RunningStats -from careamics.utils.logging import get_logger - -from ..config.support.supported_extraction_strategies import SupportedExtractionStrategy -from ..utils import normalize -from .dataset_utils import read_zarr -from .patching.patching import ( - generate_patches_unsupervised, -) - -logger = get_logger(__name__) - - -class ZarrDataset(torch.utils.data.IterableDataset): - """Dataset to extract patches from a zarr storage. - - Parameters - ---------- - data_source : Union[zarr.Group, zarr.Array] - Zarr storage. - axes : str - Description of axes in format STCZYX. - patch_extraction_method : Union[ExtractionStrategies, None] - Patch extraction strategy, as defined in extraction_strategy. - patch_size : Optional[Union[List[int], Tuple[int]]], optional - Size of the patches in each dimension, by default None. - num_patches : Optional[int], optional - Number of patches to extract, by default None. - mean : Optional[float], optional - Expected mean of the dataset, by default None. - std : Optional[float], optional - Expected standard deviation of the dataset, by default None. - patch_transform : Optional[Callable], optional - Patch transform callable, by default None. - patch_transform_params : Optional[Dict], optional - Patch transform parameters, by default None. - running_stats_window_perc : float, optional - Percentage of the dataset to use for calculating the initial mean and standard - deviation, by default 0.01. - mode : str, optional - train/predict, controls running stats calculation. - """ - - def __init__( - self, - data_source: Union[zarr.Group, zarr.Array], - axes: str, - patch_extraction_method: Union[SupportedExtractionStrategy, None], - patch_size: Optional[Union[List[int], Tuple[int]]] = None, - num_patches: Optional[int] = None, - mean: Optional[float] = None, - std: Optional[float] = None, - patch_transform: Optional[Callable] = None, - patch_transform_params: Optional[Dict] = None, - running_stats_window_perc: float = 0.01, - mode: str = "train", - ) -> None: - self.data_source = data_source - self.axes = axes - self.patch_extraction_method = patch_extraction_method - self.patch_size = patch_size - self.num_patches = num_patches - self.mean = mean - self.std = std - self.patch_transform = patch_transform - self.patch_transform_params = patch_transform_params - self.sample = read_zarr(self.data_source, self.axes) - self.running_stats_window = int( - np.prod(self.sample._cdata_shape) * running_stats_window_perc - ) - self.mode = mode - self.running_stats = RunningStats() - - self._calculate_initial_mean_std() - - def _calculate_initial_mean_std(self): - """Calculate initial mean and std of the dataset.""" - if self.mean is None and self.std is None: - idxs = np.random.randint( - 0, - np.prod(self.sample._cdata_shape), - size=max(1, self.running_stats_window), - ) - random_chunks = self.sample[idxs] - self.running_stats.init(random_chunks.mean(), random_chunks.std()) - - def _generate_patches(self): - """Generate patches from the dataset and calculates running stats. - - Yields - ------ - np.ndarray - Patch. - """ - patches = generate_patches_unsupervised( - self.sample, - self.patch_extraction_method, - self.patch_size, - ) - - # num_patches = np.ceil( - # np.prod(self.sample.chunks) - # / (np.prod(self.patch_size) * self.running_stats_window) - # ).astype(int) - - for idx, patch in enumerate(patches): - if self.mode != "predict": - self.running_stats.update(patch.mean()) - if isinstance(patch, tuple): - normalized_patch = normalize( - img=patch[0], - mean=self.running_stats.avg_mean.value, - std=self.running_stats.avg_std.value, - ) - patch = (normalized_patch, *patch[1:]) - else: - patch = normalize( - img=patch, - mean=self.running_stats.avg_mean.value, - std=self.running_stats.avg_std.value, - ) - - if self.patch_transform is not None: - assert self.patch_transform_params is not None - patch = self.patch_transform(patch, **self.patch_transform_params) - if self.num_patches is not None and idx >= self.num_patches: - return - else: - yield patch - self.mean = self.running_stats.avg_mean.value - self.std = self.running_stats.avg_std.value - - def __iter__(self): - """ - Iterate over data source and yield single patch. - - Yields - ------ - np.ndarray - """ - worker_info = torch.utils.data.get_worker_info() - num_workers = worker_info.num_workers if worker_info is not None else 1 - yield from islice(self._generate_patches(), 0, None, num_workers) +# from itertools import islice +# from typing import Callable, Dict, List, Optional, Tuple, Union + +# import numpy as np +# import torch +# import zarr + +# from careamics.utils import RunningStats +# from careamics.utils.logging import get_logger + +# from ..utils import normalize +# from .dataset_utils import read_zarr +# from .patching.patching import ( +# generate_patches_unsupervised, +# ) + +# logger = get_logger(__name__) + + +# class ZarrDataset(torch.utils.data.IterableDataset): +# """Dataset to extract patches from a zarr storage. + +# Parameters +# ---------- +# data_source : Union[zarr.Group, zarr.Array] +# Zarr storage. +# axes : str +# Description of axes in format STCZYX. +# patch_extraction_method : Union[ExtractionStrategies, None] +# Patch extraction strategy, as defined in extraction_strategy. +# patch_size : Optional[Union[List[int], Tuple[int]]], optional +# Size of the patches in each dimension, by default None. +# num_patches : Optional[int], optional +# Number of patches to extract, by default None. +# mean : Optional[float], optional +# Expected mean of the dataset, by default None. +# std : Optional[float], optional +# Expected standard deviation of the dataset, by default None. +# patch_transform : Optional[Callable], optional +# Patch transform callable, by default None. +# patch_transform_params : Optional[Dict], optional +# Patch transform parameters, by default None. +# running_stats_window_perc : float, optional +# Percentage of the dataset to use for calculating the initial mean and standard +# deviation, by default 0.01. +# mode : str, optional +# train/predict, controls running stats calculation. +# """ + +# def __init__( +# self, +# data_source: Union[zarr.Group, zarr.Array], +# axes: str, +# patch_extraction_method: Union[SupportedExtractionStrategy, None], +# patch_size: Optional[Union[List[int], Tuple[int]]] = None, +# num_patches: Optional[int] = None, +# mean: Optional[float] = None, +# std: Optional[float] = None, +# patch_transform: Optional[Callable] = None, +# patch_transform_params: Optional[Dict] = None, +# running_stats_window_perc: float = 0.01, +# mode: str = "train", +# ) -> None: +# self.data_source = data_source +# self.axes = axes +# self.patch_extraction_method = patch_extraction_method +# self.patch_size = patch_size +# self.num_patches = num_patches +# self.mean = mean +# self.std = std +# self.patch_transform = patch_transform +# self.patch_transform_params = patch_transform_params +# self.sample = read_zarr(self.data_source, self.axes) +# self.running_stats_window = int( +# np.prod(self.sample._cdata_shape) * running_stats_window_perc +# ) +# self.mode = mode +# self.running_stats = RunningStats() + +# self._calculate_initial_mean_std() + +# def _calculate_initial_mean_std(self): +# """Calculate initial mean and std of the dataset.""" +# if self.mean is None and self.std is None: +# idxs = np.random.randint( +# 0, +# np.prod(self.sample._cdata_shape), +# size=max(1, self.running_stats_window), +# ) +# random_chunks = self.sample[idxs] +# self.running_stats.init(random_chunks.mean(), random_chunks.std()) + +# def _generate_patches(self): +# """Generate patches from the dataset and calculates running stats. + +# Yields +# ------ +# np.ndarray +# Patch. +# """ +# patches = generate_patches_unsupervised( +# self.sample, +# self.patch_extraction_method, +# self.patch_size, +# ) + +# # num_patches = np.ceil( +# # np.prod(self.sample.chunks) +# # / (np.prod(self.patch_size) * self.running_stats_window) +# # ).astype(int) + +# for idx, patch in enumerate(patches): +# if self.mode != "predict": +# self.running_stats.update(patch.mean()) +# if isinstance(patch, tuple): +# normalized_patch = normalize( +# img=patch[0], +# mean=self.running_stats.avg_mean.value, +# std=self.running_stats.avg_std.value, +# ) +# patch = (normalized_patch, *patch[1:]) +# else: +# patch = normalize( +# img=patch, +# mean=self.running_stats.avg_mean.value, +# std=self.running_stats.avg_std.value, +# ) + +# if self.patch_transform is not None: +# assert self.patch_transform_params is not None +# patch = self.patch_transform(patch, **self.patch_transform_params) +# if self.num_patches is not None and idx >= self.num_patches: +# return +# else: +# yield patch +# self.mean = self.running_stats.avg_mean.value +# self.std = self.running_stats.avg_std.value + +# def __iter__(self): +# """ +# Iterate over data source and yield single patch. + +# Yields +# ------ +# np.ndarray +# """ +# worker_info = torch.utils.data.get_worker_info() +# num_workers = worker_info.num_workers if worker_info is not None else 1 +# yield from islice(self._generate_patches(), 0, None, num_workers) diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index eded7fed4..cc5425ecb 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -6,7 +6,7 @@ from albumentations import Compose from torch.utils.data import DataLoader -from careamics.config import DataModel, InferenceModel +from careamics.config import DataModel from careamics.config.data_model import TRANSFORMS_UNION from careamics.config.support import SupportedData from careamics.dataset.dataset_utils import ( @@ -17,16 +17,13 @@ ) from careamics.dataset.in_memory_dataset import ( InMemoryDataset, - InMemoryPredictionDataset, ) from careamics.dataset.iterable_dataset import ( - IterablePredictionDataset, PathIterableDataset, ) from careamics.utils import get_logger, get_ram_size DatasetType = Union[InMemoryDataset, PathIterableDataset] -PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset] logger = get_logger(__name__) @@ -189,15 +186,14 @@ def prepare_data(self) -> None: Hook used to prepare the data before calling `setup`. Here, we only need to examine the data if it was provided as a str or a Path. - - TODO: from lightning doc: + + TODO: from lightning doc: prepare_data is called from the main process. It is not recommended to assign state here (e.g. self.x = y) since it is called on a single process and if you - assign states here then they won’t be available for other processes. + assign states here then they won't be available for other processes. https://lightning.ai/docs/pytorch/stable/data/datamodule.html """ - # if the data is a Path or a str if ( not isinstance(self.train_data, np.ndarray) @@ -357,181 +353,6 @@ def val_dataloader(self) -> Any: ) -class CAREamicsClay(L.LightningDataModule): - """ - LightningDataModule for prediction dataset. - - The data module can be used with Path, str or numpy arrays. The data can be either - a folder containing images or a single file. - - To read custom data types, you can set `data_type` to `custom` in `data_config` - and provide a function that returns a numpy array from a path as - `read_source_func` parameter. The function will receive a Path object and - an axies string as arguments, the axes being derived from the `data_config`. - - You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. - "*.czi") to filter the files extension using `extension_filter`. - - Parameters - ---------- - prediction_config : InferenceModel - Pydantic model for CAREamics prediction configuration. - pred_data : Union[Path, str, np.ndarray] - Prediction data, can be a path to a folder, a file or a numpy array. - read_source_func : Optional[Callable], optional - Function to read custom types, by default None. - extension_filter : str, optional - Filter to filter file extensions for custom types, by default "". - dataloader_params : dict, optional - Dataloader parameters, by default {}. - """ - - def __init__( - self, - prediction_config: InferenceModel, - pred_data: Union[Path, str, np.ndarray], - read_source_func: Optional[Callable] = None, - extension_filter: str = "", - dataloader_params: Optional[dict] = None, - ) -> None: - """ - Constructor. - - The data module can be used with Path, str or numpy arrays. The data can be - either a folder containing images or a single file. - - To read custom data types, you can set `data_type` to `custom` in `data_config` - and provide a function that returns a numpy array from a path as - `read_source_func` parameter. The function will receive a Path object and - an axies string as arguments, the axes being derived from the `data_config`. - - You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. - "*.czi") to filter the files extension using `extension_filter`. - - Parameters - ---------- - prediction_config : InferenceModel - Pydantic model for CAREamics prediction configuration. - pred_data : Union[Path, str, np.ndarray] - Prediction data, can be a path to a folder, a file or a numpy array. - read_source_func : Optional[Callable], optional - Function to read custom types, by default None. - extension_filter : str, optional - Filter to filter file extensions for custom types, by default "". - dataloader_params : dict, optional - Dataloader parameters, by default {}. - - Raises - ------ - ValueError - If the data type is `custom` and no `read_source_func` is provided. - ValueError - If the data type is `array` and the input is not a numpy array. - ValueError - If the data type is `tiff` and the input is neither a Path nor a str. - """ - if dataloader_params is None: - dataloader_params = {} - if dataloader_params is None: - dataloader_params = {} - super().__init__() - - # check that a read source function is provided for custom types - if ( - prediction_config.data_type == SupportedData.CUSTOM - and read_source_func is None - ): - raise ValueError( - f"Data type {SupportedData.CUSTOM} is not allowed without " - f"specifying a `read_source_func`." - ) - - # and that arrays are passed, if array type specified - elif prediction_config.data_type == SupportedData.ARRAY and not isinstance( - pred_data, np.ndarray - ): - raise ValueError( - f"Expected array input (see configuration.data.data_type), but got " - f"{type(pred_data)} instead." - ) - - # and that Path or str are passed, if tiff file type specified - elif prediction_config.data_type == SupportedData.TIFF and not ( - isinstance(pred_data, Path) or isinstance(pred_data, str) - ): - raise ValueError( - f"Expected Path or str input (see configuration.data.data_type), " - f"but got {type(pred_data)} instead." - ) - - # configuration data - self.prediction_config = prediction_config - self.data_type = prediction_config.data_type - self.batch_size = prediction_config.batch_size - self.dataloader_params = dataloader_params - - self.pred_data = pred_data - self.tile_size = prediction_config.tile_size - self.tile_overlap = prediction_config.tile_overlap - - # read source function - if prediction_config.data_type == SupportedData.CUSTOM: - # mypy check - assert read_source_func is not None - - self.read_source_func: Callable = read_source_func - elif prediction_config.data_type != SupportedData.ARRAY: - self.read_source_func = get_read_func(prediction_config.data_type) - - self.extension_filter = extension_filter - - def prepare_data(self) -> None: - """Hook used to prepare the data before calling `setup`.""" - # if the data is a Path or a str - if not isinstance(self.pred_data, np.ndarray): - self.pred_files = list_files( - self.pred_data, self.data_type, self.extension_filter - ) - - def setup(self, stage: Optional[str] = None) -> None: - """ - Hook called at the beginning of predict. - - Parameters - ---------- - stage : Optional[str], optional - Stage, by default None. - """ - # if numpy array - if self.data_type == SupportedData.ARRAY: - # prediction dataset - self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset( - prediction_config=self.prediction_config, - inputs=self.pred_data, - ) - else: - self.predict_dataset = IterablePredictionDataset( - prediction_config=self.prediction_config, - src_files=self.pred_files, - read_source_func=self.read_source_func, - ) - - def predict_dataloader(self) -> DataLoader: - """ - Create a dataloader for prediction. - - Returns - ------- - DataLoader - Prediction dataloader. - """ - return DataLoader( - self.predict_dataset, - batch_size=self.batch_size, - **self.dataloader_params, - ) # TODO check workers are used - - class CAREamicsTrainDataModule(CAREamicsWood): """ LightningDataModule wrapper for training and validation datasets. @@ -839,155 +660,3 @@ def __init__( val_minimum_split=val_minimum_patches, use_in_memory=use_in_memory, ) - - -class CAREamicsPredictDataModule(CAREamicsClay): - """ - LightningDataModule wrapper of an inference dataset. - - Since the lightning datamodule has no access to the model, make sure that the - parameters passed to the datamodule are consistent with the model's requirements - and are coherent. - - The data module can be used with Path, str or numpy arrays. To use array data, set - `data_type` to `array` and pass a numpy array to `train_data`. - - The default transformations applied to the images are defined in - `careamics.config.inference_model`. To use different transformations, pass a list - of transforms or an albumentation `Compose` as `transforms` parameter. See examples - for more details. - - The `mean` and `std` parameters are only used if Normalization is defined either - in the default transformations or in the `transforms` parameter, but not with - a `Compose` object. If you pass a `Normalization` transform in a list as - `transforms`, then the mean and std parameters will be overwritten by those passed - to this method. - - By default, CAREamics only supports types defined in - `careamics.config.support.SupportedData`. To read custom data types, you can set - `data_type` to `custom` and provide a function that returns a numpy array from a - path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression - (e.g. "*.jpeg") to filter the files extension using `extension_filter`. - - In `dataloader_params`, you can pass any parameter accepted by PyTorch - dataloaders, except for `batch_size`, which is set by the `batch_size` - parameter. - - Parameters - ---------- - pred_data : Union[str, Path, np.ndarray] - Prediction data. - data_type : Union[Literal["array", "tiff", "custom"], SupportedData] - Data type, see `SupportedData` for available options. - tile_size : List[int] - Tile size, 2D or 3D tile size. - tile_overlap : List[int] - Tile overlap, 2D or 3D tile overlap. - axes : str - Axes of the data, choosen amongst SCZYX. - batch_size : int - Batch size. - tta_transforms : bool, optional - Use test time augmentation, by default True. - mean : Optional[float], optional - Mean value for normalization, only used if Normalization is defined, by - default None. - std : Optional[float], optional - Standard deviation value for normalization, only used if Normalization is - defined, by default None. - transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional - List of transforms to apply to prediction patches. If None, default - transforms are applied. - read_source_func : Optional[Callable], optional - Function to read the source data, used if `data_type` is `custom`, by - default None. - extension_filter : str, optional - Filter for file extensions, used if `data_type` is `custom`, by default "". - dataloader_params : dict, optional - Pytorch dataloader parameters, by default {}. - """ - - def __init__( - self, - pred_data: Union[str, Path, np.ndarray], - data_type: Union[Literal["array", "tiff", "custom"], SupportedData], - tile_size: List[int], - tile_overlap: List[int] = (48, 48), # TODO replace with calculator - axes: str = "YX", - batch_size: int = 1, - tta_transforms: bool = True, - mean: Optional[float] = None, - std: Optional[float] = None, - transforms: Optional[Union[List, Compose]] = None, - read_source_func: Optional[Callable] = None, - extension_filter: str = "", - dataloader_params: Optional[dict] = None, - ) -> None: - """ - Constructor. - - Parameters - ---------- - pred_data : Union[str, Path, np.ndarray] - Prediction data. - data_type : Union[Literal["array", "tiff", "custom"], SupportedData] - Data type, see `SupportedData` for available options. - tile_size : List[int] - Tile size, 2D or 3D tile size. - tile_overlap : List[int] - Tile overlap, 2D or 3D tile overlap. - axes : str - Axes of the data, choosen amongst SCZYX. - batch_size : int - Batch size. - tta_transforms : bool, optional - Use test time augmentation, by default True. - mean : Optional[float], optional - Mean value for normalization, only used if Normalization is defined, by - default None. - std : Optional[float], optional - Standard deviation value for normalization, only used if Normalization is - defined, by default None. - transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional - List of transforms to apply to prediction patches. If None, default - transforms are applied. - read_source_func : Optional[Callable], optional - Function to read the source data, used if `data_type` is `custom`, by - default None. - extension_filter : str, optional - Filter for file extensions, used if `data_type` is `custom`, by default "". - dataloader_params : dict, optional - Pytorch dataloader parameters, by default {}. - """ - if dataloader_params is None: - dataloader_params = {} - prediction_dict = { - "data_type": data_type, - "tile_size": tile_size, - "tile_overlap": tile_overlap, - "axes": axes, - "mean": mean, - "std": std, - "tta": tta_transforms, - "batch_size": batch_size, - } - - # if transforms are passed (otherwise it will use the default ones) - if transforms is not None: - prediction_dict["transforms"] = transforms - - # validate configuration - self.prediction_config = InferenceModel(**prediction_dict) - - # sanity check on the dataloader parameters - if "batch_size" in dataloader_params: - # remove it - del dataloader_params["batch_size"] - - super().__init__( - prediction_config=self.prediction_config, - pred_data=pred_data, - read_source_func=read_source_func, - extension_filter=extension_filter, - dataloader_params=dataloader_params, - ) diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py index 26df7b6fb..35cbd9417 100644 --- a/src/careamics/lightning_module.py +++ b/src/careamics/lightning_module.py @@ -170,7 +170,11 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: self._trainer.datamodule.predict_dataset.mean, self._trainer.datamodule.predict_dataset.std, ) - return denormalized_output, aux + + if len(aux) > 0: + return denormalized_output, aux + else: + return denormalized_output def configure_optimizers(self) -> Any: """Configure optimizers and learning rate schedulers. diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py new file mode 100644 index 000000000..654d44fdb --- /dev/null +++ b/src/careamics/lightning_prediction_datamodule.py @@ -0,0 +1,390 @@ +from pathlib import Path +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import numpy as np +import pytorch_lightning as L +from albumentations import Compose +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate + +from careamics.config import InferenceModel +from careamics.config.support import SupportedData +from careamics.config.tile_information import TileInformation +from careamics.dataset.dataset_utils import ( + get_read_func, + list_files, +) +from careamics.dataset.in_memory_dataset import ( + InMemoryPredictionDataset, +) +from careamics.dataset.iterable_dataset import ( + IterablePredictionDataset, +) +from careamics.utils import get_logger + +PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset] + +logger = get_logger(__name__) + + +def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any: + """ + Collate tiles received from CAREamics prediction dataloader. + + CAREamics prediction dataloader returns tuples of arrays and TileInformation. In + case of non-tiled data, this function will return the arrays. In case of tiled data, + it will return the arrays, the last tile flag, the overlap crop coordinates and the + stitch coordinates. + + Parameters + ---------- + batch : Tuple[Tuple[np.ndarray, TileInformation], ...] + Batch of tiles. + + Returns + ------- + Any + Collated batch. + """ + first_tile_info: TileInformation = batch[0][1] + # if not tiled, then return arrays + if not first_tile_info.tiled: + arrays, _ = zip(*batch) + + return default_collate(arrays) + # else we explicit the last_tile flag and coordinates + else: + new_batch = [ + (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords) + for tile, t in batch + ] + + return default_collate(new_batch) + + +class CAREamicsClay(L.LightningDataModule): + """ + LightningDataModule for prediction dataset. + + The data module can be used with Path, str or numpy arrays. The data can be either + a folder containing images or a single file. + + To read custom data types, you can set `data_type` to `custom` in `data_config` + and provide a function that returns a numpy array from a path as + `read_source_func` parameter. The function will receive a Path object and + an axies string as arguments, the axes being derived from the `data_config`. + + You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.czi") to filter the files extension using `extension_filter`. + + Parameters + ---------- + prediction_config : InferenceModel + Pydantic model for CAREamics prediction configuration. + pred_data : Union[Path, str, np.ndarray] + Prediction data, can be a path to a folder, a file or a numpy array. + read_source_func : Optional[Callable], optional + Function to read custom types, by default None. + extension_filter : str, optional + Filter to filter file extensions for custom types, by default "". + dataloader_params : dict, optional + Dataloader parameters, by default {}. + """ + + def __init__( + self, + prediction_config: InferenceModel, + pred_data: Union[Path, str, np.ndarray], + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + dataloader_params: Optional[dict] = None, + ) -> None: + """ + Constructor. + + The data module can be used with Path, str or numpy arrays. The data can be + either a folder containing images or a single file. + + To read custom data types, you can set `data_type` to `custom` in `data_config` + and provide a function that returns a numpy array from a path as + `read_source_func` parameter. The function will receive a Path object and + an axies string as arguments, the axes being derived from the `data_config`. + + You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.czi") to filter the files extension using `extension_filter`. + + Parameters + ---------- + prediction_config : InferenceModel + Pydantic model for CAREamics prediction configuration. + pred_data : Union[Path, str, np.ndarray] + Prediction data, can be a path to a folder, a file or a numpy array. + read_source_func : Optional[Callable], optional + Function to read custom types, by default None. + extension_filter : str, optional + Filter to filter file extensions for custom types, by default "". + dataloader_params : dict, optional + Dataloader parameters, by default {}. + + Raises + ------ + ValueError + If the data type is `custom` and no `read_source_func` is provided. + ValueError + If the data type is `array` and the input is not a numpy array. + ValueError + If the data type is `tiff` and the input is neither a Path nor a str. + """ + if dataloader_params is None: + dataloader_params = {} + if dataloader_params is None: + dataloader_params = {} + super().__init__() + + # check that a read source function is provided for custom types + if ( + prediction_config.data_type == SupportedData.CUSTOM + and read_source_func is None + ): + raise ValueError( + f"Data type {SupportedData.CUSTOM} is not allowed without " + f"specifying a `read_source_func`." + ) + + # and that arrays are passed, if array type specified + elif prediction_config.data_type == SupportedData.ARRAY and not isinstance( + pred_data, np.ndarray + ): + raise ValueError( + f"Expected array input (see configuration.data.data_type), but got " + f"{type(pred_data)} instead." + ) + + # and that Path or str are passed, if tiff file type specified + elif prediction_config.data_type == SupportedData.TIFF and not ( + isinstance(pred_data, Path) or isinstance(pred_data, str) + ): + raise ValueError( + f"Expected Path or str input (see configuration.data.data_type), " + f"but got {type(pred_data)} instead." + ) + + # configuration data + self.prediction_config = prediction_config + self.data_type = prediction_config.data_type + self.batch_size = prediction_config.batch_size + self.dataloader_params = dataloader_params + + self.pred_data = pred_data + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + + # read source function + if prediction_config.data_type == SupportedData.CUSTOM: + # mypy check + assert read_source_func is not None + + self.read_source_func: Callable = read_source_func + elif prediction_config.data_type != SupportedData.ARRAY: + self.read_source_func = get_read_func(prediction_config.data_type) + + self.extension_filter = extension_filter + + def prepare_data(self) -> None: + """Hook used to prepare the data before calling `setup`.""" + # if the data is a Path or a str + if not isinstance(self.pred_data, np.ndarray): + self.pred_files = list_files( + self.pred_data, self.data_type, self.extension_filter + ) + + def setup(self, stage: Optional[str] = None) -> None: + """ + Hook called at the beginning of predict. + + Parameters + ---------- + stage : Optional[str], optional + Stage, by default None. + """ + # if numpy array + if self.data_type == SupportedData.ARRAY: + # prediction dataset + self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset( + prediction_config=self.prediction_config, + inputs=self.pred_data, + ) + else: + self.predict_dataset = IterablePredictionDataset( + prediction_config=self.prediction_config, + src_files=self.pred_files, + read_source_func=self.read_source_func, + ) + + def predict_dataloader(self) -> DataLoader: + """ + Create a dataloader for prediction. + + Returns + ------- + DataLoader + Prediction dataloader. + """ + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + collate_fn=_collate_tiles, + **self.dataloader_params, + ) # TODO check workers are used + + +class CAREamicsPredictDataModule(CAREamicsClay): + """ + LightningDataModule wrapper of an inference dataset. + + Since the lightning datamodule has no access to the model, make sure that the + parameters passed to the datamodule are consistent with the model's requirements + and are coherent. + + The data module can be used with Path, str or numpy arrays. To use array data, set + `data_type` to `array` and pass a numpy array to `train_data`. + + The default transformations applied to the images are defined in + `careamics.config.inference_model`. To use different transformations, pass a list + of transforms or an albumentation `Compose` as `transforms` parameter. See examples + for more details. + + The `mean` and `std` parameters are only used if Normalization is defined either + in the default transformations or in the `transforms` parameter, but not with + a `Compose` object. If you pass a `Normalization` transform in a list as + `transforms`, then the mean and std parameters will be overwritten by those passed + to this method. + + By default, CAREamics only supports types defined in + `careamics.config.support.SupportedData`. To read custom data types, you can set + `data_type` to `custom` and provide a function that returns a numpy array from a + path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression + (e.g. "*.jpeg") to filter the files extension using `extension_filter`. + + In `dataloader_params`, you can pass any parameter accepted by PyTorch + dataloaders, except for `batch_size`, which is set by the `batch_size` + parameter. + + Parameters + ---------- + pred_data : Union[str, Path, np.ndarray] + Prediction data. + data_type : Union[Literal["array", "tiff", "custom"], SupportedData] + Data type, see `SupportedData` for available options. + mean : float + Mean value for normalization, only used if Normalization is defined in the + transforms. + std : float + Standard deviation value for normalization, only used if Normalization is + defined in the transform. + tile_size : Tuple[int, ...] + Tile size, 2D or 3D tile size. + tile_overlap : Tuple[int, ...] + Tile overlap, 2D or 3D tile overlap. + axes : str + Axes of the data, choosen amongst SCZYX. + batch_size : int + Batch size. + tta_transforms : bool, optional + Use test time augmentation, by default True. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to prediction patches. If None, default + transforms are applied. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. + extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + """ + + def __init__( + self, + pred_data: Union[str, Path, np.ndarray], + data_type: Union[Literal["array", "tiff", "custom"], SupportedData], + mean: float, + std: float, + tile_size: Optional[Tuple[int, ...]] = None, + tile_overlap: Optional[Tuple[int, ...]] = None, + axes: str = "YX", + batch_size: int = 1, + tta_transforms: bool = True, + transforms: Optional[Union[List, Compose]] = None, + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + dataloader_params: Optional[dict] = None, + ) -> None: + """ + Constructor. + + Parameters + ---------- + pred_data : Union[str, Path, np.ndarray] + Prediction data. + data_type : Union[Literal["array", "tiff", "custom"], SupportedData] + Data type, see `SupportedData` for available options. + tile_size : List[int] + Tile size, 2D or 3D tile size. + tile_overlap : List[int] + Tile overlap, 2D or 3D tile overlap. + axes : str + Axes of the data, choosen amongst SCZYX. + batch_size : int + Batch size. + tta_transforms : bool, optional + Use test time augmentation, by default True. + mean : Optional[float], optional + Mean value for normalization, only used if Normalization is defined, by + default None. + std : Optional[float], optional + Standard deviation value for normalization, only used if Normalization is + defined, by default None. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to prediction patches. If None, default + transforms are applied. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. + extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + """ + if dataloader_params is None: + dataloader_params = {} + prediction_dict = { + "data_type": data_type, + "tile_size": tile_size, + "tile_overlap": tile_overlap, + "axes": axes, + "mean": mean, + "std": std, + "tta": tta_transforms, + "batch_size": batch_size, + } + + # if transforms are passed (otherwise it will use the default ones) + if transforms is not None: + prediction_dict["transforms"] = transforms + + # validate configuration + self.prediction_config = InferenceModel(**prediction_dict) + + # sanity check on the dataloader parameters + if "batch_size" in dataloader_params: + # remove it + del dataloader_params["batch_size"] + + super().__init__( + prediction_config=self.prediction_config, + pred_data=pred_data, + read_source_func=read_source_func, + extension_filter=extension_filter, + dataloader_params=dataloader_params, + ) diff --git a/src/careamics/lightning_prediction.py b/src/careamics/lightning_prediction_loop.py similarity index 54% rename from src/careamics/lightning_prediction.py rename to src/careamics/lightning_prediction_loop.py index 7040d7bf5..4a2f64f7b 100644 --- a/src/careamics/lightning_prediction.py +++ b/src/careamics/lightning_prediction_loop.py @@ -10,33 +10,44 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop): - """Predict loop for tiles-based prediction.""" + """ + CAREamics prediction loop. - # def _predict_step(self, batch, batch_idx, dataloader_idx, dataloader_iter): - # self.model.predict_step(batch, batch_idx) + This class extends the PyTorch Lightning `_PredictionLoop` class to include + the stitching of the tiles into a single prediction result. + """ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: - """Calls ``on_predict_epoch_end`` hook. + """ + Calls `on_predict_epoch_end` hook. + + Adapted from the parent method. Returns ------- the results for all dataloaders - """ trainer = self.trainer call._call_callback_hooks(trainer, "on_predict_epoch_end") call._call_lightning_module_hook(trainer, "on_predict_epoch_end") if self.return_predictions: + ######################################################## + ################ CAREamics specific code ############### if len(self.predicted_array) == 1: return self.predicted_array[0] else: - return self.predicted_array # TODO revisit logic + # TODO revisit logic + return self.predicted_array + ######################################################## return None @_no_grad_context def run(self) -> Optional[_PREDICT_OUTPUT]: - """Runs the prediction loop. + """ + Runs the prediction loop. + + Adapted from the parent method in order to stitch the predictions. Returns ------- @@ -72,17 +83,29 @@ def run(self) -> Optional[_PREDICT_OUTPUT]: # run step hooks self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter) - # Stitching tiles together - last_tile, *data = self.predictions[batch_idx][1] - self.tiles.append(self.predictions[batch_idx][0]) - self.stitching_data.append(data) - if any(last_tile): - predicted_batches = stitch_prediction( - self.tiles, self.stitching_data - ) - self.predicted_array.append(predicted_batches) - self.tiles.clear() - self.stitching_data.clear() + ######################################################## + ################ CAREamics specific code ############### + is_tiled = len(self.predictions[batch_idx]) == 2 + if is_tiled: + # extract the last tile flag and the coordinates (crop and stitch) + last_tile, *stitch_data = self.predictions[batch_idx][1] + + # append the tile and the coordinates to the lists + self.tiles.append(self.predictions[batch_idx][0]) + self.stitching_data.append(stitch_data) + + # if last tile, stitch the tiles and add array to the prediction + if any(last_tile): + predicted_batches = stitch_prediction( + self.tiles, self.stitching_data + ) + self.predicted_array.append(predicted_batches) + self.tiles.clear() + self.stitching_data.clear() + else: + # simply add the prediction to the list + self.predicted_array.append(self.predictions[batch_idx]) + ######################################################## except StopIteration: break finally: diff --git a/src/careamics/prediction/prediction_utils.py b/src/careamics/prediction/prediction_utils.py index 32d8255d8..c9aa751fc 100644 --- a/src/careamics/prediction/prediction_utils.py +++ b/src/careamics/prediction/prediction_utils.py @@ -3,7 +3,7 @@ These functions are used during prediction. """ -from typing import List, Optional +from typing import List import numpy as np import torch @@ -11,8 +11,7 @@ def stitch_prediction( tiles: List[torch.Tensor], - stitching_data: List, - explicit_stitching: Optional[bool] = False, + stitching_data: List[List[torch.Tensor]], ) -> torch.Tensor: """ Stitch tiles back together to form a full image. @@ -21,25 +20,26 @@ def stitch_prediction( ---------- tiles : List[torch.Tensor] Cropped tiles and their respective stitching coordinates. - stitching_data : List - List of coordinates obtained from + stitching_coords : List + List of information and coordinates obtained from `dataset.tiled_patching.extract_tiles`. - explicit_stitching : bool, optional - Whether this function is called explicitly after prediction(Lighting) or inside - the predict function. Removes the first element (last tile indicator). Returns ------- np.ndarray Full image. """ - # Remove first element of stitching_data if explicit_stitching - # TODO revisit, no way around this? - if explicit_stitching: - stitching_data = [d[1:] for d in stitching_data] - - # Get whole sample shape. Unique to get rid of batch dimension. Expects tensors - input_shape = [x.unique() for x in stitching_data[0][0]] #TODO refatcor unique() ? + # retrieve whole array size, there is two cases to consider: + # 1. the tiles are stored in a list + # 2. the tiles are stored in a list with batches along the first dim + if tiles[0].shape[0] > 1: + input_shape = np.array( + [el.numpy() for el in stitching_data[0][0][0]], dtype=int + ).squeeze() + else: + input_shape = np.array( + [el.numpy() for el in stitching_data[0][0]], dtype=int + ).squeeze() predicted_image = np.zeros(input_shape, dtype=np.float32) diff --git a/tests/config/test_inference_model.py b/tests/config/test_inference_model.py index 6fa9e068e..5ceadf456 100644 --- a/tests/config/test_inference_model.py +++ b/tests/config/test_inference_model.py @@ -18,21 +18,13 @@ def test_wrong_extensions(minimum_inference: dict, ext: str): InferenceModel(**minimum_inference) -@pytest.mark.parametrize("mean, std", [(0, 124.5), (12.6, 0.1)]) -def test_mean_std_non_negative(minimum_inference: dict, mean, std): - """Test that non negative mean and std are accepted.""" - minimum_inference["mean"] = mean - minimum_inference["std"] = std - - prediction_model = InferenceModel(**minimum_inference) - assert prediction_model.mean == mean - assert prediction_model.std == std - - def test_mean_std_both_specified_or_none(minimum_inference: dict): - """Test an error is raised if std is specified but mean is None.""" - # No error if both are None - InferenceModel(**minimum_inference) + """Test error raising when setting mean and std.""" + # Errors if both are None + minimum_inference["mean"] = None + minimum_inference["std"] = None + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) # Error if only mean is defined minimum_inference["mean"] = 10.4 @@ -53,8 +45,17 @@ def test_mean_std_both_specified_or_none(minimum_inference: dict): def test_tile_size(minimum_inference: dict): """Test that non-zero even patch size are accepted.""" + # no tiling + prediction_model = InferenceModel(**minimum_inference) + # 2D + minimum_inference["tile_size"] = [12, 12] + minimum_inference["tile_overlap"] = [2, 2] + minimum_inference["axes"] = "YX" + prediction_model = InferenceModel(**minimum_inference) + assert prediction_model.tile_size == [12, 12] + assert prediction_model.tile_overlap == [2, 2] # 3D minimum_inference["tile_size"] = [12, 12, 12] @@ -93,6 +94,9 @@ def test_wrong_tile_overlap(minimum_inference: dict, tile_size, tile_overlap): def test_set_3d(minimum_inference: dict): """Test that 3D can be set.""" + minimum_inference["tile_size"] = [64, 64] + minimum_inference["tile_overlap"] = [32, 32] + pred = InferenceModel(**minimum_inference) assert "Z" not in pred.axes assert len(pred.tile_size) == 2 diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py new file mode 100644 index 000000000..78b24cc80 --- /dev/null +++ b/tests/config/test_tile_information.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest + +from careamics.config.tile_information import TileInformation + + +def test_defaults(): + """Test instantiating time information with defaults.""" + tile_info = TileInformation(array_shape=np.zeros((6, 6)).shape) + + assert tile_info.array_shape == (6, 6) + assert not tile_info.tiled + assert not tile_info.last_tile + assert tile_info.overlap_crop_coords is None + assert tile_info.stitch_coords is None + + +def test_tiled(): + """Test instantiating time information with parameters.""" + tile_info = TileInformation( + array_shape=np.zeros((6, 6)).shape, + tiled=True, + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + + assert tile_info.array_shape == (6, 6) + assert tile_info.tiled + assert tile_info.last_tile + assert tile_info.overlap_crop_coords == ((1, 2),) + assert tile_info.stitch_coords == ((3, 4),) + + +def test_validation_last_tile(): + """Test that last tile is only set if tiled is set.""" + tile_info = TileInformation(array_shape=(6, 6), last_tile=True) + assert not tile_info.last_tile + + +def test_error_on_coords(): + """Test than an error is raised if it is tiled but not coordinates are given.""" + with pytest.raises(ValueError): + TileInformation(array_shape=(6, 6), tiled=True) + + +def test_error_on_singleton_dims(): + """Test that an error is raised if the array shape contains singleton dimensions.""" + with pytest.raises(ValueError): + TileInformation(array_shape=(2, 1, 6, 6)) diff --git a/tests/conftest.py b/tests/conftest.py index 31b781b4a..13d553c3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,8 +123,8 @@ def minimum_inference() -> dict: # create dictionary predic = { "data_type": SupportedData.TIFF.value, - "tile_size": [64, 64], - "tile_overlap": [10, 10], + "mean": 0.0, + "std": 1.0, "axes": "SYX", } diff --git a/tests/dataset/patching/test_random_patching.py b/tests/dataset/patching/test_random_patching.py index 6bc85f74f..858cfe2e0 100644 --- a/tests/dataset/patching/test_random_patching.py +++ b/tests/dataset/patching/test_random_patching.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from careamics.dataset.patching.patching import extract_patches_random +from careamics.dataset.patching.random_patching import extract_patches_random @pytest.mark.parametrize( diff --git a/tests/dataset/patching/test_tiled_patching.py b/tests/dataset/patching/test_tiled_patching.py index 6099d72bf..a7e135d4f 100644 --- a/tests/dataset/patching/test_tiled_patching.py +++ b/tests/dataset/patching/test_tiled_patching.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from careamics.config.tile_information import TileInformation from careamics.dataset.patching.tiled_patching import ( _compute_crop_and_stitch_coords_1d, extract_tiles, @@ -17,7 +18,13 @@ def check_extract_tiles(array: np.ndarray, tile_size, overlaps): # Assemble all tiles and their respective coordinates for tile_data in tile_data_generator: - tile, _, _, overlap_crop_coords, stitch_coords = tile_data + tile = tile_data[0] + + tile_info: TileInformation = tile_data[1] + overlap_crop_coords = tile_info.overlap_crop_coords + stitch_coords = tile_info.stitch_coords + + # add data to lists tiles.append(tile) all_overlap_crop_coords.append(overlap_crop_coords) all_stitch_coords.append(stitch_coords) diff --git a/tests/prediction/test_prediction_utils.py b/tests/prediction/test_prediction_utils.py index 041215264..a4e9f395d 100644 --- a/tests/prediction/test_prediction_utils.py +++ b/tests/prediction/test_prediction_utils.py @@ -22,23 +22,20 @@ def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps): stitching_data = [] # extract tiles - tiling_outputs = extract_tiles(arr, tile_size, overlaps) + tile_generator = extract_tiles(arr, tile_size, overlaps) # Assemble all tiles as it is done during the prediction stage - for tile_data in tiling_outputs: - tile, _, input_shape, overlap_crop_coords, stitch_coords = tile_data - - tiles.append(from_numpy(tile)) # need to convert to torch.Tensor + for tile_data, tile_info in tile_generator: + tiles.append(from_numpy(tile_data)) # need to convert to torch.Tensor stitching_data.append( ( # this is way too wacky [tensor(i) for i in input_shape], # need to convert to torch.Tensor - [[tensor([j]) for j in i] for i in overlap_crop_coords], - [[tensor([j]) for j in i] for i in stitch_coords], + [[tensor([j]) for j in i] for i in tile_info.overlap_crop_coords], + [[tensor([j]) for j in i] for i in tile_info.stitch_coords], ) ) # compute stitching coordinates - result = stitch_prediction(tiles, stitching_data) assert (result == arr).all() diff --git a/tests/test_careamist.py b/tests/test_careamist.py index fda43904b..a180e976c 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -336,8 +336,8 @@ def test_train_tiff_files_supervised(tmp_path, supervised_configuration): @pytest.mark.parametrize("batch_size", [1, 2]) -def test_predict_array(tmp_path, minimum_configuration, batch_size): - """Test that CAREamics can predict with arrays.""" +def test_predict_on_array_tiled(tmp_path, minimum_configuration, batch_size): + """Test that CAREamics can predict on arrays.""" # training data train_array = np.random.rand(32, 32) @@ -357,17 +357,34 @@ def test_predict_array(tmp_path, minimum_configuration, batch_size): # predict CAREamist predicted = careamist.predict( - train_array, batch_size=batch_size, tile_overlap=(4, 4) + train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4) ) - # check thatmean/std were set properly - assert careamist.cfg.data_config.mean is not None - assert careamist.cfg.data_config.std is not None - assert careamist.cfg.data_config.mean == train_array.mean() - assert careamist.cfg.data_config.std == train_array.std() - # check prediction and its shape@pytest.mark.parametrize("batch_size", [1, 2]) + assert predicted.squeeze().shape == train_array.shape + + +def test_predict_array_no_tiling(tmp_path, minimum_configuration): + """Test that CAREamics can predict on arrays without tiling.""" + # training data + train_array = np.random.rand(4, 32, 32) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "SYX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # predict CAREamist + predicted = careamist.predict(train_array) - assert predicted is not None assert predicted.squeeze().shape == train_array.shape @@ -401,7 +418,6 @@ def test_predict_path(tmp_path, minimum_configuration, batch_size): ) # check that it predicted - assert predicted is not None assert predicted.squeeze().shape == train_array.shape @@ -420,5 +436,4 @@ def test_predict_pretrained(tmp_path, pre_trained): predicted = careamist.predict(train_array, tile_overlap=(4, 4)) # check that it predicted - assert predicted is not None assert predicted.squeeze().shape == train_array.shape diff --git a/tests/test_lightning_datamodule.py b/tests/test_lightning_datamodule.py index 27bc292df..50f5e370f 100644 --- a/tests/test_lightning_datamodule.py +++ b/tests/test_lightning_datamodule.py @@ -101,40 +101,44 @@ def test_lightning_predict_datamodule_wrong_type(simple_array): CAREamicsPredictDataModule( pred_data=simple_array, data_type="wrong_type", - tile_size=(10, 10), - tile_overlap=(2, 2), + mean=0.5, + std=0.1, axes="YX", batch_size=2, ) -def test_lightning_pred_datamodule_error_no_mean(simple_array): +def test_lightning_pred_datamodule_tiling(simple_array): """Test that the data module is created correctly with an array.""" # create data module - CAREamicsPredictDataModule( + data_module = CAREamicsPredictDataModule( pred_data=simple_array, data_type="array", - tile_size=(10, 10), - tile_overlap=(2, 2), + mean=0.5, + std=0.1, axes="YX", batch_size=2, + tile_overlap=[2, 2], + tile_size=[4, 4], ) + data_module.prepare_data() + data_module.setup() + assert len(list(data_module.predict_dataloader())) == 8 + -def test_lightning_pred_datamodule_array(simple_array): +def test_lightning_pred_datamodule_no_tiling(simple_array): """Test that the data module is created correctly with an array.""" # create data module data_module = CAREamicsPredictDataModule( pred_data=simple_array, data_type="array", - tile_size=(10, 10), - tile_overlap=(2, 2), - axes="YX", - batch_size=2, mean=0.5, std=0.1, + axes="YX", + batch_size=2, ) + data_module.prepare_data() data_module.setup() - - assert len(list(data_module.predict_dataloader())) > 0 + assert len(list(data_module.predict_dataloader())) == 1 From 3dc9d58fdd10ea0ec758fa94a644751be00c0a46 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 26 Apr 2024 15:59:25 +0200 Subject: [PATCH 06/14] Rename stitching --- src/careamics/prediction/__init__.py | 4 +- ...ediction_utils.py => stitch_prediction.py} | 56 ------------------- 2 files changed, 1 insertion(+), 59 deletions(-) rename src/careamics/prediction/{prediction_utils.py => stitch_prediction.py} (63%) diff --git a/src/careamics/prediction/__init__.py b/src/careamics/prediction/__init__.py index a11cfceb7..852e65de1 100644 --- a/src/careamics/prediction/__init__.py +++ b/src/careamics/prediction/__init__.py @@ -2,8 +2,6 @@ __all__ = [ "stitch_prediction", - "tta_backward", - "tta_forward", ] -from .prediction_utils import stitch_prediction, tta_backward, tta_forward +from .stitch_prediction import stitch_prediction diff --git a/src/careamics/prediction/prediction_utils.py b/src/careamics/prediction/stitch_prediction.py similarity index 63% rename from src/careamics/prediction/prediction_utils.py rename to src/careamics/prediction/stitch_prediction.py index c9aa751fc..f88233a81 100644 --- a/src/careamics/prediction/prediction_utils.py +++ b/src/careamics/prediction/stitch_prediction.py @@ -70,59 +70,3 @@ def stitch_prediction( ] = cropped_tile.to(torch.float32) return predicted_image - - -def tta_forward(x: np.ndarray) -> List: - """ - Augment 8-fold an array. - - The augmentation is performed using all 90 deg rotations and their flipped version, - as well as the original image flipped. - - Parameters - ---------- - x : torch.tensor - Data to augment. - - Returns - ------- - List - Stack of augmented images. - """ - x_aug = [ - x, - torch.rot90(x, 1, dims=(2, 3)), - torch.rot90(x, 2, dims=(2, 3)), - torch.rot90(x, 3, dims=(2, 3)), - ] - x_aug_flip = x_aug.copy() - for x_ in x_aug: - x_aug_flip.append(torch.flip(x_, dims=(1, 3))) - return x_aug_flip - - -def tta_backward(x_aug: List) -> np.ndarray: - """ - Invert `tta_forward` and average the 8 images. - - Parameters - ---------- - x_aug : List - Stack of 8-fold augmented images. - - Returns - ------- - np.ndarray - Average of de-augmented x_aug. - """ - x_deaug = [ - x_aug[0], - np.rot90(x_aug[1], -1), - np.rot90(x_aug[2], -2), - np.rot90(x_aug[3], -3), - np.fliplr(x_aug[4]), - np.rot90(np.fliplr(x_aug[5]), -1), - np.rot90(np.fliplr(x_aug[6]), -2), - np.rot90(np.fliplr(x_aug[7]), -3), - ] - return np.mean(x_deaug, 0) From a05dd4119fc6a25c8ce5a7399c2a3ef4cc4d2665 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 26 Apr 2024 16:22:44 +0200 Subject: [PATCH 07/14] Fix error in TileInformation --- src/careamics/config/tile_information.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py index 6e669af9f..e018e0f16 100644 --- a/src/careamics/config/tile_information.py +++ b/src/careamics/config/tile_information.py @@ -50,8 +50,7 @@ def no_singleton_dimensions(cls, v: Tuple[int, ...]): @classmethod def only_if_tiled(cls, v: bool, values: ValidationInfo): """ - Check that the last tile flag is only set to `True` if tiling is enabled, - otherwise set it to `False`. + Check that the last tile flag is only set if tiling is enabled. Parameters ---------- @@ -71,7 +70,9 @@ def only_if_tiled(cls, v: bool, values: ValidationInfo): @field_validator("overlap_crop_coords", "stitch_coords") @classmethod - def mandatory_if_tiled(cls, v: Optional[Tuple[int, ...]], values: ValidationInfo): + def mandatory_if_tiled( + cls, v: Optional[Tuple[int, ...]], values: ValidationInfo + ) -> Optional[Tuple[int, ...]]: """ Check that the coordinates are not `None` if tiling is enabled. @@ -97,5 +98,7 @@ def mandatory_if_tiled(cls, v: Optional[Tuple[int, ...]], values: ValidationInfo if values.data["tiled"]: if v is None: raise ValueError("Value must be specified if tiling is enabled.") + + return v else: return None From b9de6ffb25b15ecced9c01464168a66b364f839f Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 26 Apr 2024 18:17:52 +0200 Subject: [PATCH 08/14] Normalize/Denormalize transforms following BMZ, parent transform --- pyproject.toml | 13 +- .../config/architectures/custom_model.py | 11 +- src/careamics/config/configuration_factory.py | 16 +- src/careamics/config/data_model.py | 53 +++--- src/careamics/config/inference_model.py | 18 +- .../transformations/n2v_manipulate_model.py | 41 +++-- .../config/transformations/nd_flip_model.py | 34 ++-- .../config/transformations/normalize_model.py | 31 ++-- .../config/transformations/transform_model.py | 35 ++++ .../xy_random_rotate90_model.py | 30 ++-- src/careamics/dataset/patching/__init__.py | 3 +- .../dataset/patching/patch_transform.py | 156 +----------------- src/careamics/lightning_datamodule.py | 3 +- src/careamics/lightning_module.py | 17 +- src/careamics/transforms/__init__.py | 29 ++-- src/careamics/transforms/normalize.py | 97 +++++++++-- src/careamics/transforms/tta.py | 14 +- src/careamics/utils/__init__.py | 4 +- .../{normalization.py => running_stats.py} | 53 +----- tests/config/test_configuration_factory.py | 37 +++-- tests/config/test_configuration_model.py | 10 +- tests/config/test_data_model.py | 61 ++++--- tests/config/test_inference_model.py | 6 +- .../test_n2v_manipulate_model.py | 18 ++ tests/model_io/test_model_io_utils.py | 4 +- tests/prediction/test_prediction_utils.py | 2 +- tests/test_lightning_datamodule.py | 12 +- tests/transforms/test_normalize.py | 30 ++++ 28 files changed, 420 insertions(+), 418 deletions(-) create mode 100644 src/careamics/config/transformations/transform_model.py rename src/careamics/utils/{normalization.py => running_stats.py} (55%) create mode 100644 tests/config/transformations/test_n2v_manipulate_model.py create mode 100644 tests/transforms/test_normalize.py diff --git a/pyproject.toml b/pyproject.toml index 1951b438c..c81af28a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,11 +167,14 @@ ignore = [ [tool.numpydoc_validation] checks = [ "all", # report on all checks, except the below - "EX01", - "SA01", - "ES01", - "GL02", - "GL03", + "EX01", # Example section not found + "SA01", # See Also section not found + "ES01", # Extended Summar not found + "GL01", # Docstring text (summary) should start in the line immediately + # after the opening quotes + "GL02", # Closing quotes should be placed in the line after the last text + # in the docstring + "GL03", # Double line break found ] exclude = [ # don't report on objects that match any of these regex "test_*", diff --git a/src/careamics/config/architectures/custom_model.py b/src/careamics/config/architectures/custom_model.py index 557032a2d..f7540bc38 100644 --- a/src/careamics/config/architectures/custom_model.py +++ b/src/careamics/config/architectures/custom_model.py @@ -1,7 +1,7 @@ from __future__ import annotations from pprint import pformat -from typing import Literal +from typing import Any, Dict, Literal from pydantic import ConfigDict, field_validator, model_validator from torch.nn import Module @@ -128,12 +128,17 @@ def __str__(self) -> str: """ return pformat(self.model_dump()) - def model_dump(self) -> dict: + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: """Dump the model configuration. + Parameters + ---------- + kwargs : Any + Additional keyword arguments from Pydantic BaseModel model_dump method. + Returns ------- - dict + Dict[str, Any] Model configuration. """ model_dict = super().model_dump() diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index 028e826b3..283584aae 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -360,15 +360,13 @@ def create_n2v_configuration( # n2v2 and structn2v nv2_transform = { "name": SupportedTransform.N2V_MANIPULATE.value, - "parameters": { - "strategy": SupportedPixelManipulation.MEDIAN.value - if use_n2v2 - else SupportedPixelManipulation.UNIFORM.value, - "roi_size": roi_size, - "masked_pixel_percentage": masked_pixel_percentage, - "struct_mask_axis": struct_n2v_axis, - "struct_mask_span": struct_n2v_span, - }, + "strategy": SupportedPixelManipulation.MEDIAN.value + if use_n2v2 + else SupportedPixelManipulation.UNIFORM.value, + "roi_size": roi_size, + "masked_pixel_percentage": masked_pixel_percentage, + "struct_mask_axis": struct_n2v_axis, + "struct_mask_span": struct_n2v_span, } transforms.append(nv2_transform) diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index 8e2782cd1..8068565c6 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -34,7 +34,6 @@ ] -# TODO can we check whether N2V manipulate is in a Compose? # TODO does patches need to be multiple of 8 with UNet? class DataModel(BaseModel): """ @@ -87,9 +86,7 @@ class DataModel(BaseModel): # Dataset configuration data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData patch_size: List[int] = Field(..., min_length=2, max_length=3) - batch_size: int = Field( - default=1, ge=1, validate_default=True - ) # TODO Differentiate based on Train/inf ? + batch_size: int = Field(default=1, ge=1, validate_default=True) axes: str # Optional fields @@ -114,7 +111,7 @@ class DataModel(BaseModel): validate_default=True, ) - dataloader_params: Optional[dict] = None # TODO validate ? + dataloader_params: Optional[dict] = None @field_validator("patch_size") @classmethod @@ -188,7 +185,8 @@ def axes_valid(cls, axes: str) -> str: def validate_prediction_transforms( cls, transforms: Union[List[TRANSFORMS_UNION], Compose] ) -> Union[List[TRANSFORMS_UNION], Compose]: - """Validate N2VManipulate transform position in the transform list. + """ + Validate N2VManipulate transform position in the transform list. Parameters ---------- @@ -273,8 +271,8 @@ def add_std_and_mean_to_normalize(cls, data_model: DataModel) -> DataModel: if data_model.has_transform_list(): for transform in data_model.transforms: if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = data_model.mean - transform.parameters.std = data_model.std + transform.mean = data_model.mean + transform.std = data_model.std return data_model @@ -308,9 +306,9 @@ def validate_dimensions(cls, data_model: DataModel) -> DataModel: if data_model.has_transform_list(): for transform in data_model.transforms: if transform.name == SupportedTransform.NDFLIP: - transform.parameters.is_3D = True + transform.is_3D = True elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90: - transform.parameters.is_3D = True + transform.is_3D = True else: if len(data_model.patch_size) != 2: @@ -322,14 +320,15 @@ def validate_dimensions(cls, data_model: DataModel) -> DataModel: if data_model.has_transform_list(): for transform in data_model.transforms: if transform.name == SupportedTransform.NDFLIP: - transform.parameters.is_3D = False + transform.is_3D = False elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90: - transform.parameters.is_3D = False + transform.is_3D = False return data_model def __str__(self) -> str: - """Pretty string reprensenting the configuration. + """ + Pretty string reprensenting the configuration. Returns ------- @@ -339,7 +338,14 @@ def __str__(self) -> str: return pformat(self.model_dump()) def _update(self, **kwargs: Any) -> None: - """Update multiple arguments at once.""" + """ + Update multiple arguments at once. + + Parameters + ---------- + kwargs : Any + Keyword arguments to update. + """ self.__dict__.update(kwargs) self.__class__.model_validate(self.__dict__) @@ -442,8 +448,8 @@ def set_mean_and_std(self, mean: float, std: float) -> None: if self.has_transform_list(): for transform in self.transforms: if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = mean - transform.parameters.std = std + transform.mean = mean + transform.std = std else: raise ValueError( "Setting mean and std with Compose transforms is not allowed. Add " @@ -464,7 +470,8 @@ def set_3D(self, axes: str, patch_size: List[int]) -> None: self._update(axes=axes, patch_size=patch_size) def set_N2V2(self, use_n2v2: bool) -> None: - """Set N2V2. + """ + Set N2V2. Parameters ---------- @@ -484,7 +491,8 @@ def set_N2V2(self, use_n2v2: bool) -> None: self.set_N2V2_strategy("uniform") def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: - """Set N2V2 strategy. + """ + Set N2V2 strategy. Parameters ---------- @@ -503,7 +511,7 @@ def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: for transform in self.transforms: if transform.name == SupportedTransform.N2V_MANIPULATE.value: - transform.parameters.strategy = strategy + transform.strategy = strategy found_n2v = True if not found_n2v: @@ -522,7 +530,8 @@ def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: def set_structN2V_mask( self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int ) -> None: - """Set structN2V mask parameters. + """ + Set structN2V mask parameters. Setting `mask_axis` to `none` will disable structN2V. @@ -545,8 +554,8 @@ def set_structN2V_mask( for transform in self.transforms: if transform.name == SupportedTransform.N2V_MANIPULATE.value: - transform.parameters.struct_mask_axis = mask_axis - transform.parameters.struct_mask_span = mask_span + transform.struct_mask_axis = mask_axis + transform.struct_mask_span = mask_span found_n2v = True if not found_n2v: diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py index 78583cecc..c5834a4ec 100644 --- a/src/careamics/config/inference_model.py +++ b/src/careamics/config/inference_model.py @@ -122,7 +122,8 @@ def axes_valid(cls, axes: str) -> str: def validate_transforms( cls, transforms: Union[List[TRANSFORMS_UNION], Compose] ) -> Union[List[TRANSFORMS_UNION], Compose]: - """Validate that transforms do not have N2V pixel manipulate transforms. + """ + Validate that transforms do not have N2V pixel manipulate transforms. Parameters ---------- @@ -236,13 +237,20 @@ def add_std_and_mean_to_normalize( if not isinstance(pred_model.transforms, Compose): for transform in pred_model.transforms: if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = pred_model.mean - transform.parameters.std = pred_model.std + transform.mean = pred_model.mean + transform.std = pred_model.std return pred_model def _update(self, **kwargs: Any) -> None: - """Update multiple arguments at once.""" + """ + Update multiple arguments at once. + + Parameters + ---------- + kwargs : Any + Key-value pairs of arguments to update. + """ self.__dict__.update(kwargs) self.__class__.model_validate(self.__dict__) @@ -256,5 +264,7 @@ def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> No Axes. tile_size : List[int] Tile size. + tile_overlap : List[int] + Tile overlap. """ self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap) diff --git a/src/careamics/config/transformations/n2v_manipulate_model.py b/src/careamics/config/transformations/n2v_manipulate_model.py index 07929a084..5e4a03587 100644 --- a/src/careamics/config/transformations/n2v_manipulate_model.py +++ b/src/careamics/config/transformations/n2v_manipulate_model.py @@ -1,15 +1,36 @@ +"""Pydantic model for the N2VManipulate transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, field_validator +from .transform_model import TransformModel -class N2VManipulationParameters(BaseModel): - """Pydantic model used to validate N2V manipulation parameters.""" + +class N2VManipulateModel(TransformModel): + """ + Pydantic model used to represent N2V manipulation. + + Attributes + ---------- + name : Literal["N2VManipulate"] + Name of the transformation. + roi_size : int + Size of the masking region, by default 11. + masked_pixel_percentage : float + Percentage of masked pixels, by default 0.2. + strategy : Literal["uniform", "median"] + Strategy pixel value replacement, by default "uniform". + struct_mask_axis : Literal["horizontal", "vertical", "none"] + Axis of the structN2V mask, by default "none". + struct_mask_span : int + Span of the structN2V mask, by default 5. + """ model_config = ConfigDict( validate_assignment=True, ) + name: Literal["N2VManipulate"] roi_size: int = Field(default=11, ge=3, le=21) masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0) strategy: Literal["uniform", "median"] = Field(default="uniform") @@ -19,7 +40,8 @@ class N2VManipulationParameters(BaseModel): @field_validator("roi_size", "struct_mask_span") @classmethod def odd_value(cls, v: int) -> int: - """Validate that the value is odd. + """ + Validate that the value is odd. Parameters ---------- @@ -39,14 +61,3 @@ def odd_value(cls, v: int) -> int: if v % 2 == 0: raise ValueError("Size must be an odd number.") return v - - -class N2VManipulateModel(BaseModel): - """Pydantic model used to represent N2V manipulation.""" - - model_config = ConfigDict( - validate_assignment=True, - ) - - name: Literal["N2VManipulate"] - parameters: N2VManipulationParameters = N2VManipulationParameters() diff --git a/src/careamics/config/transformations/nd_flip_model.py b/src/careamics/config/transformations/nd_flip_model.py index 5c211e9c8..9787a13ac 100644 --- a/src/careamics/config/transformations/nd_flip_model.py +++ b/src/careamics/config/transformations/nd_flip_model.py @@ -1,26 +1,32 @@ +"""Pydantic model for the NDFlip transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field +from .transform_model import TransformModel -class NDFlipParameters(BaseModel): - """Pydantic model used to validate NDFlip parameters.""" - model_config = ConfigDict( - validate_assignment=True, - ) +class NDFlipModel(TransformModel): + """ + Pydantic model used to represent NDFlip transformation. - p: float = Field(default=0.5, ge=0.0, le=1.0) - is_3D: bool = Field(default=False) - flip_z: bool = Field(default=True) - - -class NDFlipModel(BaseModel): - """Pydantic model used to represent NDFlip transformation.""" + Attributes + ---------- + name : Literal["NDFlip"] + Name of the transformation. + p : float + Probability of applying the transformation, by default 0.5. + is_3D : bool + Whether the transformation should be applied in 3D, by default False. + flip_z : bool + Whether to flip the z axis, by default True. + """ model_config = ConfigDict( validate_assignment=True, ) name: Literal["NDFlip"] - parameters: NDFlipParameters = NDFlipParameters() + p: float = Field(default=0.5, ge=0.0, le=1.0) + is_3D: bool = Field(default=False) + flip_z: bool = Field(default=True) diff --git a/src/careamics/config/transformations/normalize_model.py b/src/careamics/config/transformations/normalize_model.py index 6f1de597f..cc156dbf9 100644 --- a/src/careamics/config/transformations/normalize_model.py +++ b/src/careamics/config/transformations/normalize_model.py @@ -1,26 +1,31 @@ +"""Pydantic model for the Normalize transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field +from .transform_model import TransformModel -class NormalizeParameters(BaseModel): - """Pydantic model used to validate Normalize parameters.""" - model_config = ConfigDict( - validate_assignment=True, - ) - - mean: float = Field(default=0.485) # albumentations default - std: float = Field(default=0.229) - max_pixel_value: float = Field(default=1.0, ge=0.0) # TODO explain why +class NormalizeModel(TransformModel): + """ + Pydantic model used to represent Normalize transformation. + The Normalize transform is a zero mean and unit variance transformation. -class NormalizeModel(BaseModel): - """Pydantic model used to represent Normalize transformation.""" + Attributes + ---------- + name : Literal["Normalize"] + Name of the transformation. + mean : float + Mean value for normalization. + std : float + Standard deviation value for normalization. + """ model_config = ConfigDict( validate_assignment=True, ) name: Literal["Normalize"] - parameters: NormalizeParameters = NormalizeParameters() + mean: float = Field(default=0.485) # albumentations defaults + std: float = Field(default=0.229) diff --git a/src/careamics/config/transformations/transform_model.py b/src/careamics/config/transformations/transform_model.py new file mode 100644 index 000000000..698375361 --- /dev/null +++ b/src/careamics/config/transformations/transform_model.py @@ -0,0 +1,35 @@ +"""Parent model for the transforms.""" +from typing import Any, Dict + +from pydantic import BaseModel + + +class TransformModel(BaseModel): + """ + Pydantic model used to represent a transformation. + + The `model_dump` method is overwritten to exclude the name field. + + Attributes + ---------- + name : str + Name of the transformation. + """ + + name: str + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """ + Return the model as a dictionary. + + Parameters + ---------- + **kwargs + Pydantic BaseMode model_dump method keyword arguments. + """ + model_dict = super().model_dump(**kwargs) + + # remove the name field + model_dict.pop("name") + + return model_dict diff --git a/src/careamics/config/transformations/xy_random_rotate90_model.py b/src/careamics/config/transformations/xy_random_rotate90_model.py index d9add88c1..af0cd1422 100644 --- a/src/careamics/config/transformations/xy_random_rotate90_model.py +++ b/src/careamics/config/transformations/xy_random_rotate90_model.py @@ -1,25 +1,29 @@ +"""Pydantic model for the XYRandomRotate90 transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field +from .transform_model import TransformModel -class XYRandomRotate90Parameters(BaseModel): - """Pydantic model used to validate NDFlip parameters.""" - model_config = ConfigDict( - validate_assignment=True, - ) +class XYRandomRotate90Model(TransformModel): + """ + Pydantic model used to represent NDFlip transformation. - p: float = Field(default=0.5, ge=0.0, le=1.0) - is_3D: bool = Field(default=False) - - -class XYRandomRotate90Model(BaseModel): - """Pydantic model used to represent NDFlip transformation.""" + Attributes + ---------- + name : Literal["XYRandomRotate90"] + Name of the transformation. + p : float + Probability of applying the transformation, by default 0.5. + is_3D : bool + Whether the transformation should be applied in 3D, by default False. + """ model_config = ConfigDict( validate_assignment=True, ) name: Literal["XYRandomRotate90"] - parameters: XYRandomRotate90Parameters = XYRandomRotate90Parameters() + p: float = Field(default=0.5, ge=0.0, le=1.0) + is_3D: bool = Field(default=False) diff --git a/src/careamics/dataset/patching/__init__.py b/src/careamics/dataset/patching/__init__.py index c684789bf..c9f5f219e 100644 --- a/src/careamics/dataset/patching/__init__.py +++ b/src/careamics/dataset/patching/__init__.py @@ -3,7 +3,6 @@ __all__ = [ "get_patch_transform", - "get_patch_transform_predict", ] -from .patch_transform import get_patch_transform, get_patch_transform_predict +from .patch_transform import get_patch_transform diff --git a/src/careamics/dataset/patching/patch_transform.py b/src/careamics/dataset/patching/patch_transform.py index 8a092cfc1..15cde203b 100644 --- a/src/careamics/dataset/patching/patch_transform.py +++ b/src/careamics/dataset/patching/patch_transform.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Union +from typing import List, Union import albumentations as Aug @@ -31,7 +31,7 @@ def get_patch_transform( # instantiate all transforms transforms = [ - all_transforms[transform.name](**transform.parameters.model_dump()) + all_transforms[transform.name](**transform.model_dump()) for transform in patch_transforms ] @@ -42,155 +42,3 @@ def get_patch_transform( if (with_target and normalize_mask) # TODO check this else {}, ) - - -# TODO kept here as reference -def _get_patch_transform( - patch_transforms: Union[List[TRANSFORMS_UNION], Aug.Compose], - mean: float, - std: float, - target: bool, - normalize_mask: bool = True, -) -> Aug.Compose: - """Return a pixel manipulation function. - - Used in N2V family of algorithms. - - Parameters - ---------- - patch_transform_type : str - Type of patch transform. - target : bool - Whether the transform is applied to the target(if the target is present). - mode : str - Train or predict mode. - - Returns - ------- - Union[None, Callable] - Patch transform function. - """ - if patch_transforms is None: - return Aug.Compose( - [Aug.NoOp()], - additional_targets={"target": "image"} - if (target and normalize_mask) # TODO why? there is no normalization here? - else {}, - ) - elif isinstance(patch_transforms, list): - patch_transforms[[t["name"] for t in patch_transforms].index("Normalize")][ - "parameters" - ] = { - "mean": mean, - "std": std, - "max_pixel_value": 1, - } - # TODO not very readable - return Aug.Compose( - [ - get_all_transforms()[transform["name"]](**transform["parameters"]) - if "parameters" in transform - else get_all_transforms()[transform["name"]]() - for transform in patch_transforms - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - elif isinstance(patch_transforms, Aug.Compose): - return Aug.Compose( - [ - t - for t in patch_transforms.transforms[:-1] - if not isinstance(t, Aug.Normalize) - ] - + [ - Aug.Normalize(mean=mean, std=std, max_pixel_value=1), - patch_transforms.transforms[-1] - if patch_transforms.transforms[-1].__class__.__name__ == "ManipulateN2V" - else Aug.NoOp(), - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - else: - raise ValueError( - f"Incorrect patch transform type {patch_transforms}. " - f"Please refer to the documentation." # TODO add link to documentation - ) - - -# TODO add tta -def get_patch_transform_predict( - patch_transforms: Union[List, Aug.Compose, None], - mean: float, - std: float, - target: bool, - normalize_mask: bool = True, -) -> Union[None, Callable]: - """Return a pixel manipulation function. - - Used in N2V family of algorithms. - - Parameters - ---------- - patch_transform_type : str - Type of patch transform. - target : bool - Whether the transform is applied to the target(if the target is present). - mode : str - Train or predict mode. - - Returns - ------- - Union[None, Callable] - Patch transform function. - """ - if patch_transforms is None: - return Aug.Compose( - [Aug.NoOp()], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - elif isinstance(patch_transforms, list): - patch_transforms[[t["name"] for t in patch_transforms].index("Normalize")][ - "parameters" - ] = { - "mean": mean, - "std": std, - "max_pixel_value": 1, - } - # TODO not very readable - return Aug.Compose( - [ - get_all_transforms()[transform["name"]](**transform["parameters"]) - if "parameters" in transform - else get_all_transforms()[transform["name"]]() - for transform in patch_transforms - if transform["name"] != "ManipulateN2V" - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - elif isinstance(patch_transforms, Aug.Compose): - return Aug.Compose( - [ - t - for t in patch_transforms.transforms[:-1] - if not isinstance(t, Aug.Normalize) - ] - + [ - Aug.Normalize(mean=mean, std=std, max_pixel_value=1), - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - else: - raise ValueError( - f"Incorrect patch transform type {patch_transforms}. " - f"Please refer to the documentation." # TODO add link to documentation - ) diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index cc5425ecb..7bc65dbba 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -479,7 +479,8 @@ class CAREamicsTrainDataModule(CAREamicsWood): >>> my_transforms = [ ... { ... "name": SupportedTransform.NORMALIZE.value, - ... "parameters": {"mean": 0, "std": 1}, + ... "mean": 0, + ... "std": 1, ... }, ... { ... "name": SupportedTransform.N2V_MANIPULATE.value, diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py index 35cbd9417..6e3c4f522 100644 --- a/src/careamics/lightning_module.py +++ b/src/careamics/lightning_module.py @@ -13,8 +13,7 @@ ) from careamics.losses import loss_factory from careamics.models.model_factory import model_factory -from careamics.transforms import ImageRestorationTTA -from careamics.utils import denormalize +from careamics.transforms import Denormalize, ImageRestorationTTA from careamics.utils.torch_utils import get_optimizer, get_scheduler @@ -37,11 +36,6 @@ class CAREamicsKiln(L.LightningModule): Optimizer parameters. lr_scheduler_name : str Learning rate scheduler name. - - Parameters - ---------- - algorithm_config : Union[AlgorithmModel, dict] - Algorithm configuration. """ def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None: @@ -164,12 +158,11 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: output = self.model(x) # Denormalize the output - # TODO replace with Albu class - denormalized_output = denormalize( - output, - self._trainer.datamodule.predict_dataset.mean, - self._trainer.datamodule.predict_dataset.std, + denorm = Denormalize( + mean=self._trainer.datamodule.predict_dataset.mean, + std=self._trainer.datamodule.predict_dataset.std, ) + denormalized_output = denorm(image=output)["image"] if len(aux) > 0: return denormalized_output, aux diff --git a/src/careamics/transforms/__init__.py b/src/careamics/transforms/__init__.py index 3f7e37a65..dd4c2604c 100644 --- a/src/careamics/transforms/__init__.py +++ b/src/careamics/transforms/__init__.py @@ -6,36 +6,31 @@ "NDFlip", "XYRandomRotate90", "ImageRestorationTTA", + "Denormalize", + "Normalize", ] -from inspect import getmembers, isclass - -import albumentations as Aug from .n2v_manipulate import N2VManipulate from .nd_flip import NDFlip +from .normalize import Denormalize, Normalize from .tta import ImageRestorationTTA from .xy_random_rotate90 import XYRandomRotate90 -ALL_TRANSFORMS = dict( - getmembers(Aug, isclass) - + [ - ("N2VManipulate", N2VManipulate), - ("NDFlip", NDFlip), - ("XYRandomRotate90", XYRandomRotate90), - ] -) +ALL_TRANSFORMS = { + "Normalize": Normalize, + "N2VManipulate": N2VManipulate, + "NDFlip": NDFlip, + "XYRandomRotate90": XYRandomRotate90, +} def get_all_transforms() -> dict: """Return all the transforms accepted by CAREamics. - This includes all transforms from Albumentations (see https://albumentations.ai/), - and custom transforms implemented in CAREamics. - - Note that while any Albumentations transform can be used in CAREamics, no check are - implemented to verify the compatibility of any other transforms than the ones - officially supported (see SupportedTransforms). + Note that while CAREamics accepts any `Compose` transforms from Albumentations (see + https://albumentations.ai/), only a few transformations are explicitely supported + (see `SupportedTransform`). Returns ------- diff --git a/src/careamics/transforms/normalize.py b/src/careamics/transforms/normalize.py index 3554a16d3..4ec529b0f 100644 --- a/src/careamics/transforms/normalize.py +++ b/src/careamics/transforms/normalize.py @@ -1,42 +1,61 @@ -from typing import Any, Tuple +from typing import Any import numpy as np from albumentations import DualTransform -# TODO unused -class Denormalize(DualTransform): - """Denormalize an image or image patch. +class Normalize(DualTransform): + """ + Normalize an image or image patch. + + Normalization is a zero mean and unit variance. This transform expects (Z)YXC + dimensions. + + Not that an epsilon value of 1e-6 is added to the standard deviation to avoid + division by zero and that it returns a float32 image. - This transform expects (Z)YXC dimensions. + Attributes + ---------- + mean : float + Mean value. + std : float + Standard deviation value. + eps : float + Epsilon value to avoid division by zero. """ def __init__( self, - mean: Tuple[float], - std: Tuple[float], - max_pixel_value: int = 1, - always_apply: bool = False, - p: float = 1.0, + mean: float, + std: float, ): - super().__init__(always_apply=always_apply, p=p) + super().__init__(always_apply=True, p=1) - self.mean = np.array(mean) - self.std = np.array(std) - self.max_pixel_value = max_pixel_value + self.mean = mean + self.std = std + self.eps = 1e-6 def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray: - """Apply the transform to the image. + """ + Apply the transform to the image. Parameters ---------- patch : np.ndarray Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + + Returns + ------- + np.ndarray + Normalized image or image patch. """ - return (patch * self.std + self.mean) * self.max_pixel_value + return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32) def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray: - """Apply the transform to the mask. + """ + Apply the transform to the mask. + + The mask is returned as is. Parameters ---------- @@ -44,3 +63,47 @@ def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray: Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). """ return mask + + +class Denormalize(DualTransform): + """ + Denormalize an image or image patch. + + Denormalization is performed expecting a zero mean and unit variance input. This + transform expects (Z)YXC dimensions. + + Not that an epsilon value of 1e-6 is added to the standard deviation to avoid + division by zero during the normalization step, which is taken into account during + denormalization. + + Attributes + ---------- + mean : float + Mean value. + std : float + Standard deviation value. + eps : float + Epsilon value to avoid division by zero. + """ + + def __init__( + self, + mean: float, + std: float, + ): + super().__init__(always_apply=True, p=1) + + self.mean = mean + self.std = std + self.eps = 1e-6 + + def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray: + """ + Apply the transform to the image. + + Parameters + ---------- + patch : np.ndarray + Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + """ + return patch * (self.std + self.eps) + self.mean diff --git a/src/careamics/transforms/tta.py b/src/careamics/transforms/tta.py index 55cc9cc13..39e65950e 100644 --- a/src/careamics/transforms/tta.py +++ b/src/careamics/transforms/tta.py @@ -1,3 +1,4 @@ +"""Test-time augmentations.""" from typing import List import numpy as np @@ -6,12 +7,16 @@ # TODO add tests class ImageRestorationTTA: - """Test-time augmentation for image restoration tasks. + """ + Test-time augmentation for image restoration tasks. The augmentation is performed using all 90 deg rotations and their flipped version, as well as the original image flipped. Tensors should be of shape SC(Z)YX + + This transformation is used in the LightningModule in order to perform test-time + agumentation. """ def __init__(self) -> None: @@ -19,16 +24,17 @@ def __init__(self) -> None: pass def forward(self, x: Tensor) -> List[Tensor]: - """Apply test-time augmentation to the input tensor. + """ + Apply test-time augmentation to the input tensor. Parameters ---------- - x : Any + x : Tensor Input tensor, shape SC(Z)YX. Returns ------- - Any + List[Tensor] List of augmented tensors. """ augmented = [ diff --git a/src/careamics/utils/__init__.py b/src/careamics/utils/__init__.py index 91d309e33..b37f4fef2 100644 --- a/src/careamics/utils/__init__.py +++ b/src/careamics/utils/__init__.py @@ -2,8 +2,6 @@ __all__ = [ - "denormalize", - "normalize", "check_axes_validity", "check_tiling_validity", "cwd", @@ -20,9 +18,9 @@ from .base_enum import BaseEnum from .context import cwd, get_careamics_home from .logging import get_logger -from .normalization import RunningStats, denormalize, normalize from .path_utils import check_path_exists from .ram import get_ram_size +from .running_stats import RunningStats from .validators import ( check_axes_validity, check_tiling_validity, diff --git a/src/careamics/utils/normalization.py b/src/careamics/utils/running_stats.py similarity index 55% rename from src/careamics/utils/normalization.py rename to src/careamics/utils/running_stats.py index 86cb8edd5..053f44d50 100644 --- a/src/careamics/utils/normalization.py +++ b/src/careamics/utils/running_stats.py @@ -4,62 +4,11 @@ These methods are used to normalize and denormalize images. """ from multiprocessing import Value -from typing import List, Tuple, Union +from typing import Tuple import numpy as np -def normalize( - img: Union[Tuple, np.ndarray], mean: float, std: float -) -> Union[List, np.ndarray]: - """ - Normalize an image using mean and standard deviation. - - Images are normalised by subtracting the mean and dividing by the standard - deviation. - - Parameters - ---------- - img : np.ndarray - Image to normalize. - mean : float - Mean. - std : float - Standard deviation. - - Returns - ------- - np.ndarray - Normalized array. - """ - zero_mean = img - mean - return zero_mean / std - - -def denormalize(img: np.ndarray, mean: float, std: float) -> np.ndarray: - """ - Denormalize an image using mean and standard deviation. - - Images are denormalised by multiplying by the standard deviation and adding the - mean. - - Parameters - ---------- - img : np.ndarray - Image to denormalize. - mean : float - Mean. - std : float - Standard deviation. - - Returns - ------- - np.ndarray - Denormalized array. - """ - return img * std + mean - - class RunningStats: """Calculates running mean and std.""" diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factory.py index 61708b6ba..711c4375d 100644 --- a/tests/config/test_configuration_factory.py +++ b/tests/config/test_configuration_factory.py @@ -18,13 +18,16 @@ def test_n2v_configuration(): batch_size=8, num_epochs=100, ) - assert config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert ( + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) - assert not config.data_config.transforms[-2].parameters.is_3D # XY_RANDOM_ROTATE90 - assert not config.data_config.transforms[-3].parameters.is_3D # NDFLIP + assert not config.data_config.transforms[-2].is_3D # XY_RANDOM_ROTATE90 + assert not config.data_config.transforms[-3].is_3D # NDFLIP assert not config.algorithm_config.model.is_3D() @@ -38,13 +41,16 @@ def test_n2v_3d_configuration(): batch_size=8, num_epochs=100, ) - assert config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert ( + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) - assert config.data_config.transforms[-2].parameters.is_3D # XY_RANDOM_ROTATE90 - assert config.data_config.transforms[-3].parameters.is_3D # NDFLIP + assert config.data_config.transforms[-2].is_3D # XY_RANDOM_ROTATE90 + assert config.data_config.transforms[-3].is_3D # NDFLIP assert config.algorithm_config.model.is_3D() @@ -157,7 +163,10 @@ def test_n2v_no_aug(): use_augmentations=False, ) assert len(config.data_config.transforms) == 2 - assert config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + assert ( + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) assert config.data_config.transforms[-2].name == SupportedTransform.NORMALIZE.value @@ -173,8 +182,8 @@ def test_n2v_augmentation_parameters(): roi_size=17, masked_pixel_percentage=0.5, ) - assert config.data_config.transforms[-1].parameters.roi_size == 17 - assert config.data_config.transforms[-1].parameters.masked_pixel_percentage == 0.5 + assert config.data_config.transforms[-1].roi_size == 17 + assert config.data_config.transforms[-1].masked_pixel_percentage == 0.5 def test_n2v2(): @@ -189,7 +198,7 @@ def test_n2v2(): use_n2v2=True, ) assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.MEDIAN.value ) @@ -207,7 +216,7 @@ def test_structn2v(): struct_n2v_span=7, ) assert ( - config.data_config.transforms[-1].parameters.struct_mask_axis + config.data_config.transforms[-1].struct_mask_axis == SupportedStructAxis.HORIZONTAL.value ) - assert config.data_config.transforms[-1].parameters.struct_mask_span == 7 + assert config.data_config.transforms[-1].struct_mask_span == 7 diff --git a/tests/config/test_configuration_model.py b/tests/config/test_configuration_model.py index 5f731159d..083ee1dda 100644 --- a/tests/config/test_configuration_model.py +++ b/tests/config/test_configuration_model.py @@ -114,7 +114,7 @@ def test_n2v2_and_transforms(minimum_configuration: dict, algorithm, strategy): config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value ) - assert config.data_config.transforms[-1].parameters.strategy == expected_strategy + assert config.data_config.transforms[-1].strategy == expected_strategy # passing ManipulateN2V with the wrong strategy minimum_configuration["data_config"]["transforms"] = [ @@ -126,7 +126,7 @@ def test_n2v2_and_transforms(minimum_configuration: dict, algorithm, strategy): } ] config = Configuration(**minimum_configuration) - assert config.data_config.transforms[-1].parameters.strategy == expected_strategy + assert config.data_config.transforms[-1].strategy == expected_strategy def test_setting_n2v2(minimum_configuration: dict): @@ -140,7 +140,7 @@ def test_setting_n2v2(minimum_configuration: dict): assert config.algorithm_config.algorithm == SupportedAlgorithm.N2V.value assert not config.algorithm_config.model.n2v2 assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) @@ -148,7 +148,7 @@ def test_setting_n2v2(minimum_configuration: dict): config.set_N2V2(True) assert config.algorithm_config.model.n2v2 assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.MEDIAN.value ) @@ -156,7 +156,7 @@ def test_setting_n2v2(minimum_configuration: dict): config.set_N2V2(False) assert not config.algorithm_config.model.n2v2 assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py index 7acce922e..cd42b55b4 100644 --- a/tests/config/test_data_model.py +++ b/tests/config/test_data_model.py @@ -7,7 +7,10 @@ SupportedStructAxis, SupportedTransform, ) -from careamics.config.transformations.xy_random_rotate90_model import ( +from careamics.config.transformations import ( + N2VManipulateModel, + NDFlipModel, + NormalizeModel, XYRandomRotate90Model, ) from careamics.transforms import get_all_transforms @@ -83,8 +86,8 @@ def test_mean_and_std_in_normalize(minimum_data: dict): {"name": SupportedTransform.NORMALIZE.value}, ] data = DataModel(**minimum_data) - assert data.transforms[0].parameters.mean == 10.4 - assert data.transforms[0].parameters.std == 3.2 + assert data.transforms[0].mean == 10.4 + assert data.transforms[0].std == 3.2 def test_patch_size(minimum_data: dict): @@ -155,7 +158,18 @@ def test_set_3d(minimum_data: dict): def test_passing_supported_transforms(minimum_data: dict, transforms): """Test that list of supported transforms can be passed.""" minimum_data["transforms"] = transforms - DataModel(**minimum_data) + model = DataModel(**minimum_data) + + supported = { + "NDFlip": NDFlipModel, + "XYRandomRotate90": XYRandomRotate90Model, + "Normalize": NormalizeModel, + "N2VManipulate": N2VManipulateModel, + } + + for ind, t in enumerate(transforms): + assert t["name"] == model.transforms[ind].name + assert isinstance(model.transforms[ind], supported[t["name"]]) @pytest.mark.parametrize( @@ -236,25 +250,24 @@ def test_correct_transform_parameters(minimum_data: dict): model = DataModel(**minimum_data) # Normalize - params = model.transforms[0].parameters.model_dump() + params = model.transforms[0].model_dump() assert "mean" in params assert "std" in params - assert "max_pixel_value" in params # NDFlip - params = model.transforms[1].parameters.model_dump() + params = model.transforms[1].model_dump() assert "p" in params assert "is_3D" in params assert "flip_z" in params # XYRandomRotate90 - params = model.transforms[2].parameters.model_dump() + params = model.transforms[2].model_dump() assert "p" in params assert "is_3D" in params assert isinstance(model.transforms[2], XYRandomRotate90Model) # N2VManipulate - params = model.transforms[3].parameters.model_dump() + params = model.transforms[3].model_dump() assert "roi_size" in params assert "masked_pixel_percentage" in params assert "strategy" in params @@ -307,13 +320,13 @@ def test_3D_and_transforms(minimum_data: dict): }, ] data = DataModel(**minimum_data) - assert data.transforms[0].parameters.is_3D is False - assert data.transforms[1].parameters.is_3D is False + assert data.transforms[0].is_3D is False + assert data.transforms[1].is_3D is False # change to 3D data.set_3D("ZYX", [64, 64, 64]) - data.transforms[0].parameters.is_3D = True - data.transforms[1].parameters.is_3D = True + data.transforms[0].is_3D = True + data.transforms[1].is_3D = True def test_set_n2v_strategy(minimum_data: dict): @@ -323,13 +336,13 @@ def test_set_n2v_strategy(minimum_data: dict): data = DataModel(**minimum_data) assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - assert data.transforms[-1].parameters.strategy == uniform + assert data.transforms[-1].strategy == uniform data.set_N2V2_strategy(median) - assert data.transforms[-1].parameters.strategy == median + assert data.transforms[-1].strategy == median data.set_N2V2_strategy(uniform) - assert data.transforms[-1].parameters.strategy == uniform + assert data.transforms[-1].strategy == uniform def test_set_n2v_strategy_wrong_value(minimum_data: dict): @@ -347,20 +360,20 @@ def test_set_struct_mask(minimum_data: dict): data = DataModel(**minimum_data) assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - assert data.transforms[-1].parameters.struct_mask_axis == none - assert data.transforms[-1].parameters.struct_mask_span == 5 + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 5 data.set_structN2V_mask(vertical, 3) - assert data.transforms[-1].parameters.struct_mask_axis == vertical - assert data.transforms[-1].parameters.struct_mask_span == 3 + assert data.transforms[-1].struct_mask_axis == vertical + assert data.transforms[-1].struct_mask_span == 3 data.set_structN2V_mask(horizontal, 7) - assert data.transforms[-1].parameters.struct_mask_axis == horizontal - assert data.transforms[-1].parameters.struct_mask_span == 7 + assert data.transforms[-1].struct_mask_axis == horizontal + assert data.transforms[-1].struct_mask_span == 7 data.set_structN2V_mask(none, 11) - assert data.transforms[-1].parameters.struct_mask_axis == none - assert data.transforms[-1].parameters.struct_mask_span == 11 + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 11 def test_set_struct_mask_wrong_value(minimum_data: dict): diff --git a/tests/config/test_inference_model.py b/tests/config/test_inference_model.py index 5ceadf456..bd6f0e98c 100644 --- a/tests/config/test_inference_model.py +++ b/tests/config/test_inference_model.py @@ -165,7 +165,7 @@ def test_passing_compose_transform(minimum_inference: dict): """Test that Compose transform can be passed.""" minimum_inference["transforms"] = Compose( [ - get_all_transforms()[SupportedTransform.NORMALIZE](), + get_all_transforms()[SupportedTransform.NORMALIZE](mean=10.4, std=3.2), get_all_transforms()[SupportedTransform.NDFLIP](), ] ) @@ -181,5 +181,5 @@ def test_mean_and_std_in_normalize(minimum_inference: dict): ] data = InferenceModel(**minimum_inference) - assert data.transforms[0].parameters.mean == 10.4 - assert data.transforms[0].parameters.std == 3.2 + assert data.transforms[0].mean == 10.4 + assert data.transforms[0].std == 3.2 diff --git a/tests/config/transformations/test_n2v_manipulate_model.py b/tests/config/transformations/test_n2v_manipulate_model.py new file mode 100644 index 000000000..966bd4bc0 --- /dev/null +++ b/tests/config/transformations/test_n2v_manipulate_model.py @@ -0,0 +1,18 @@ +import pytest + +from careamics.config.transformations.n2v_manipulate_model import N2VManipulateModel + + +def test_odd_roi_and_mask(): + """Test that errors are thrown if we pass even roi and mask sizes.""" + # no error + model = N2VManipulateModel(name="N2VManipulate", roi_size=3, struct_mask_span=7) + assert model.roi_size == 3 + assert model.struct_mask_span == 7 + + # errors + with pytest.raises(ValueError): + N2VManipulateModel(name="N2VManipulate", roi_size=4, struct_mask_span=7) + + with pytest.raises(ValueError): + N2VManipulateModel(name="N2VManipulate", roi_size=3, struct_mask_span=6) diff --git a/tests/model_io/test_model_io_utils.py b/tests/model_io/test_model_io_utils.py index 46fea7273..3a3af1717 100644 --- a/tests/model_io/test_model_io_utils.py +++ b/tests/model_io/test_model_io_utils.py @@ -11,7 +11,7 @@ def test_export_bmz(tmp_path, pre_trained): # instantiate CAREamist careamist = CAREamist(source=pre_trained, work_dir=tmp_path) - # predict + # predict (no tiling and no tta) predicted = careamist.predict(train_array, tta_transforms=False) # save images @@ -19,7 +19,7 @@ def test_export_bmz(tmp_path, pre_trained): np.save(train_path, train_array[np.newaxis, np.newaxis, ...]) predicted_path = tmp_path / "predicted.npy" - np.save(tmp_path / "predicted.npy", predicted[np.newaxis, ...]) + np.save(tmp_path / "predicted.npy", predicted) # export to BioImage Model Zoo export_bmz( diff --git a/tests/prediction/test_prediction_utils.py b/tests/prediction/test_prediction_utils.py index a4e9f395d..c96920ccd 100644 --- a/tests/prediction/test_prediction_utils.py +++ b/tests/prediction/test_prediction_utils.py @@ -2,7 +2,7 @@ from torch import from_numpy, tensor from careamics.dataset.patching.tiled_patching import extract_tiles -from careamics.prediction.prediction_utils import stitch_prediction +from careamics.prediction.stitch_prediction import stitch_prediction @pytest.mark.parametrize( diff --git a/tests/test_lightning_datamodule.py b/tests/test_lightning_datamodule.py index 50f5e370f..6eeebda9b 100644 --- a/tests/test_lightning_datamodule.py +++ b/tests/test_lightning_datamodule.py @@ -68,7 +68,7 @@ def test_lightning_train_datamodule_n2v2(simple_array, use_n2v2, strategy): batch_size=2, use_n2v2=use_n2v2, ) - assert data_module.data_config.transforms[-1].parameters.strategy == strategy + assert data_module.data_config.transforms[-1].strategy == strategy def test_lightning_train_datamodule_structn2v(simple_array): @@ -85,14 +85,8 @@ def test_lightning_train_datamodule_structn2v(simple_array): struct_n2v_axis=struct_axis, struct_n2v_span=struct_span, ) - assert ( - data_module.data_config.transforms[-1].parameters.struct_mask_axis - == struct_axis - ) - assert ( - data_module.data_config.transforms[-1].parameters.struct_mask_span - == struct_span - ) + assert data_module.data_config.transforms[-1].struct_mask_axis == struct_axis + assert data_module.data_config.transforms[-1].struct_mask_span == struct_span def test_lightning_predict_datamodule_wrong_type(simple_array): diff --git a/tests/transforms/test_normalize.py b/tests/transforms/test_normalize.py new file mode 100644 index 000000000..da61a9119 --- /dev/null +++ b/tests/transforms/test_normalize.py @@ -0,0 +1,30 @@ +import numpy as np + +from careamics.transforms import Denormalize, Normalize + + +def test_normalize_denormalize(): + """Test the Normalize transform.""" + # Create data + array = np.arange(100).reshape((1, 1, 10, 10)) + + # Create the transform + norm = Normalize( + mean=50, + std=25, + ) + + # Apply the transform + normalized: np.array = norm(image=array)["image"] + assert np.abs(normalized.mean()) < 0.02 + assert np.abs(normalized.std() - 1) < 0.2 + + # Create the denormalize transform + denorm = Denormalize( + mean=50, + std=25, + ) + + # Apply the denormalize transform + denormalized: np.array = denorm(image=normalized)["image"] + assert np.isclose(denormalized, array).all() From 6bd1c0f5e8167447d49dd49d81c1d6f3a4d87295 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 26 Apr 2024 18:46:57 +0200 Subject: [PATCH 09/14] Prevent extra parameters in transforms --- src/careamics/config/data_model.py | 9 ++++++--- .../config/transformations/transform_model.py | 11 ++++++++++- tests/config/test_configuration_model.py | 4 +--- tests/config/test_data_model.py | 10 +++------- .../transformations/test_n2v_manipulate_model.py | 8 ++++++++ .../config/transformations/test_normalize_model.py | 13 +++++++++++++ 6 files changed, 41 insertions(+), 14 deletions(-) create mode 100644 tests/config/transformations/test_normalize_model.py diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index 8068565c6..1acead884 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -55,7 +55,7 @@ class DataModel(BaseModel): ... ) To change the mean and std of the data: - >>> data.set_mean_and_std(mean=0., std=1.) + >>> data.set_mean_and_std(mean=214.3, std=84.5) One can pass also a list of transformations, by keyword, using the SupportedTransform or the name of an Albumentation transform: @@ -68,10 +68,13 @@ class DataModel(BaseModel): ... transforms=[ ... { ... "name": SupportedTransform.NORMALIZE.value, + ... "mean": 167.6, + ... "std": 47.2, ... }, ... { ... "name": "NDFlip", - ... "parameters": {"is_3D": True, "flip_Z": True} + ... "is_3D": True, + ... "flip_z": True, ... } ... ] ... ) @@ -343,7 +346,7 @@ def _update(self, **kwargs: Any) -> None: Parameters ---------- - kwargs : Any + **kwargs : Any Keyword arguments to update. """ self.__dict__.update(kwargs) diff --git a/src/careamics/config/transformations/transform_model.py b/src/careamics/config/transformations/transform_model.py index 698375361..ffbc6022d 100644 --- a/src/careamics/config/transformations/transform_model.py +++ b/src/careamics/config/transformations/transform_model.py @@ -1,7 +1,7 @@ """Parent model for the transforms.""" from typing import Any, Dict -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class TransformModel(BaseModel): @@ -16,6 +16,10 @@ class TransformModel(BaseModel): Name of the transformation. """ + model_config = ConfigDict( + extra="forbid", # throw errors if the parameters are not properly passed + ) + name: str def model_dump(self, **kwargs) -> Dict[str, Any]: @@ -26,6 +30,11 @@ def model_dump(self, **kwargs) -> Dict[str, Any]: ---------- **kwargs Pydantic BaseMode model_dump method keyword arguments. + + Returns + ------- + Dict[str, Any] + Dictionary representation of the model. """ model_dict = super().model_dump(**kwargs) diff --git a/tests/config/test_configuration_model.py b/tests/config/test_configuration_model.py index 083ee1dda..353184cf6 100644 --- a/tests/config/test_configuration_model.py +++ b/tests/config/test_configuration_model.py @@ -120,9 +120,7 @@ def test_n2v2_and_transforms(minimum_configuration: dict, algorithm, strategy): minimum_configuration["data_config"]["transforms"] = [ { "name": SupportedTransform.N2V_MANIPULATE.value, - "parameters": { - "strategy": strategy, - }, + "strategy": strategy, } ] config = Configuration(**minimum_configuration) diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py index cd42b55b4..a343da516 100644 --- a/tests/config/test_data_model.py +++ b/tests/config/test_data_model.py @@ -307,16 +307,12 @@ def test_3D_and_transforms(minimum_data: dict): minimum_data["transforms"] = [ { "name": SupportedTransform.NDFLIP.value, - "parameters": { - "is_3D": True, - "flip_z": True, - }, + "is_3D": True, + "flip_z": True, }, { "name": SupportedTransform.XY_RANDOM_ROTATE90.value, - "parameters": { - "is_3D": True, - }, + "is_3D": True, }, ] data = DataModel(**minimum_data) diff --git a/tests/config/transformations/test_n2v_manipulate_model.py b/tests/config/transformations/test_n2v_manipulate_model.py index 966bd4bc0..6939182c8 100644 --- a/tests/config/transformations/test_n2v_manipulate_model.py +++ b/tests/config/transformations/test_n2v_manipulate_model.py @@ -16,3 +16,11 @@ def test_odd_roi_and_mask(): with pytest.raises(ValueError): N2VManipulateModel(name="N2VManipulate", roi_size=3, struct_mask_span=6) + + +def test_extra_parameters(): + """Test that errors are thrown if we pass extra parameters.""" + with pytest.raises(ValueError): + N2VManipulateModel( + name="N2VManipulate", roi_size=3, struct_mask_span=7, extra_param=1 + ) diff --git a/tests/config/transformations/test_normalize_model.py b/tests/config/transformations/test_normalize_model.py new file mode 100644 index 000000000..324b635fe --- /dev/null +++ b/tests/config/transformations/test_normalize_model.py @@ -0,0 +1,13 @@ +from careamics.config.transformations import NormalizeModel + + +def test_setting_mean_std(): + """Test that we can set the mean and std values.""" + model = NormalizeModel(name="Normalize", mean=0.5, std=0.5) + assert model.mean == 0.5 + assert model.std == 0.5 + + model.mean = 0.6 + model.std = 0.6 + assert model.mean == 0.6 + assert model.std == 0.6 From 1d0750870c69dc4cb1ae07714e70cdbbb47f43e6 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Sat, 27 Apr 2024 13:29:47 +0200 Subject: [PATCH 10/14] Refactor validators, add patch min and power of 2 --- .../config/architectures/custom_model.py | 2 +- src/careamics/config/configuration_factory.py | 4 - src/careamics/config/configuration_model.py | 139 ++++++------------ src/careamics/config/data_model.py | 25 +--- src/careamics/config/inference_model.py | 30 +++- src/careamics/config/references/__init__.py | 33 ++++- .../references/algorithm_descriptions.py | 93 ++++++++++++ src/careamics/config/references/references.py | 7 +- src/careamics/config/validators/__init__.py | 5 + .../config/validators/validator_utils.py | 100 +++++++++++++ .../model_io/bioimage/model_description.py | 17 ++- src/careamics/prediction/stitch_prediction.py | 3 +- src/careamics/utils/__init__.py | 6 - src/careamics/utils/validators.py | 114 -------------- src/careamics/utils/wandb.py | 121 --------------- tests/config/test_data_model.py | 6 +- tests/config/test_inference_model.py | 12 +- .../validators/test_validator_utils.py} | 26 +++- tests/conftest.py | 118 +++------------ tests/dataset/test_in_memory_dataset.py | 10 +- tests/dataset/test_iterable_dataset.py | 16 +- tests/model_io/test_model_io_utils.py | 1 + ...ion_utils.py => test_stitch_prediction.py} | 4 +- tests/test_lightning_datamodule.py | 12 +- 24 files changed, 409 insertions(+), 495 deletions(-) create mode 100644 src/careamics/config/references/algorithm_descriptions.py create mode 100644 src/careamics/config/validators/__init__.py create mode 100644 src/careamics/config/validators/validator_utils.py delete mode 100644 src/careamics/utils/validators.py delete mode 100644 src/careamics/utils/wandb.py rename tests/{utils/test_axes.py => config/validators/test_validator_utils.py} (60%) rename tests/prediction/{test_prediction_utils.py => test_stitch_prediction.py} (92%) diff --git a/src/careamics/config/architectures/custom_model.py b/src/careamics/config/architectures/custom_model.py index f7540bc38..56e9b753b 100644 --- a/src/careamics/config/architectures/custom_model.py +++ b/src/careamics/config/architectures/custom_model.py @@ -55,7 +55,7 @@ class CustomModel(ArchitectureModel): >>> # Create a configuration >>> config_dict = { ... "architecture": "Custom", - ... "name": "linear", + ... "name": "my_linear", ... "in_features": 10, ... "out_features": 5, ... } diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index 283584aae..012cd65b2 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -29,10 +29,6 @@ def create_n2n_configuration( use_augmentations: bool = True, use_n2v2: bool = False, n_channels: int = 1, - roi_size: int = 11, - masked_pixel_percentage: float = 0.2, - struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none", - struct_n2v_span: int = 5, logger: Literal["wandb", "tensorboard", "none"] = "none", model_kwargs: Optional[dict] = None, ) -> Configuration: diff --git a/src/careamics/config/configuration_model.py b/src/careamics/config/configuration_model.py index 1e437acd0..1696037d3 100644 --- a/src/careamics/config/configuration_model.py +++ b/src/careamics/config/configuration_model.py @@ -12,7 +12,20 @@ from .algorithm_model import AlgorithmModel from .data_model import DataModel -from .references import N2V2_REF, N2V_REF, STRUCTN2V_REF +from .references import ( + N2V2Ref, + N2VRef, + StructN2VRef, + N2VDescription, + N2V2Description, + StructN2VDescription, + StructN2V2Description, + N2V, + N2V2, + STRUCT_N2V, + STRUCT_N2V2, + CUSTOM +) from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform from .training_model import TrainingModel from .transformations.n2v_manipulate_model import ( @@ -323,18 +336,18 @@ def get_algorithm_flavour(self) -> str: if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 use_structN2V = ( - self.data_config.transforms[-1].parameters.struct_mask_axis != "none" + self.data_config.transforms[-1].struct_mask_axis != "none" ) # return the n2v flavour if use_n2v2 and use_structN2V: - return "StructN2V2" + return STRUCT_N2V2 elif use_n2v2: - return "N2V2" + return N2V2 elif use_structN2V: - return "StructN2V" + return STRUCT_N2V else: - return "Noise2Void" + return N2V return self.algorithm_config.algorithm.capitalize() @@ -342,6 +355,8 @@ def get_algorithm_description(self) -> str: """ Return a description of the algorithm. + This method is used to generate the README of the BioImage Model Zoo export. + Returns ------- str @@ -349,82 +364,25 @@ def get_algorithm_description(self) -> str: """ algorithm_flavour = self.get_algorithm_flavour() - if algorithm_flavour == "Custom": + if algorithm_flavour == CUSTOM: return f"Custom algorithm, named {self.algorithm_config.model.name}" else: # currently only N2V flavours - if algorithm_flavour == "Noise2Void": - return ( - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels)." - ) - elif algorithm_flavour == "N2V2": - return ( - "N2V2 is an iteration of Noise2Void. " - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels). " - "N2V2 introduces blur-pool layers and removed skip connections in " - "the UNet architecture to remove checkboard artefacts, a common " - "artefacts ocurring in Noise2Void." - ) - elif algorithm_flavour == "StructN2V": - return ( - "StructN2V is an iteration of Noise2Void. " - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels). " - "StructN2V uses a linear mask (horizontal or vertical) to replace " - "the pixel values of neighbors of the masked pixels by a random " - "value. Such masking allows removing 1D structured noise from the " - "the images, the main failure case of the original N2V." - ) - elif algorithm_flavour == "StructN2V2": - return ( - "StructN2V2 is an iteration of Noise2Void that uses both " - "structN2V and N2V2 ." - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels). " - "StructN2V uses a linear mask (horizontal or vertical) to replace " - "the pixel values of neighbors of the masked pixels by a random " - "value. Such masking allows removing 1D structured noise from the " - "the images, the main failure case of the original N2V." - "N2V2 introduces blur-pool layers and removed skip connections in " - "the UNet architecture to remove checkboard artefacts, a common " - "artefacts ocurring in Noise2Void." - ) + if algorithm_flavour == N2V: + return N2VDescription().description + elif algorithm_flavour == N2V2: + return N2V2Description().description + elif algorithm_flavour == STRUCT_N2V: + return StructN2VDescription().description + elif algorithm_flavour == STRUCT_N2V2: + return StructN2V2Description().description return "" def get_algorithm_citations(self) -> List[CiteEntry]: """ - Return a list of citation entries corresponding to the algorithm - defined in the configuration. + Return a list of citation entries of the current algorithm. + + This is used to generate the model description for the BioImage Model Zoo. Returns ------- @@ -434,18 +392,18 @@ def get_algorithm_citations(self) -> List[CiteEntry]: if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 use_structN2V = ( - self.data_config.transforms[-1].parameters.struct_mask_axis != "none" + self.data_config.transforms[-1].struct_mask_axis != "none" ) # return the (struct)N2V(2) references if use_n2v2 and use_structN2V: - return [N2V_REF, N2V2_REF, STRUCTN2V_REF] + return [N2VRef, N2V2Ref, StructN2VRef] elif use_n2v2: - return [N2V_REF, N2V2_REF] + return [N2VRef, N2V2Ref] elif use_structN2V: - return [N2V_REF, STRUCTN2V_REF] + return [N2VRef, StructN2VRef] else: - return [N2V_REF] + return [N2VRef] raise ValueError("Citation not available for custom algorithm.") @@ -453,6 +411,8 @@ def get_algorithm_references(self) -> str: """ Get the algorithm references. + This is used to generate the README of the BioImage Model Zoo export. + Returns ------- str @@ -461,13 +421,13 @@ def get_algorithm_references(self) -> str: if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 use_structN2V = ( - self.data_config.transforms[-1].parameters.struct_mask_axis != "none" + self.data_config.transforms[-1].struct_mask_axis != "none" ) references = [ - N2V_REF.text + " doi: " + N2V_REF.doi, - N2V2_REF.text + " doi: " + N2V2_REF.doi, - STRUCTN2V_REF.text + " doi: " + STRUCTN2V_REF.doi, + N2VRef.text + " doi: " + N2VRef.doi, + N2V2Ref.text + " doi: " + N2V2Ref.doi, + StructN2VRef.text + " doi: " + StructN2VRef.doi, ] # return the (struct)N2V(2) references @@ -496,7 +456,7 @@ def get_algorithm_keywords(self) -> List[str]: if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 use_structN2V = ( - self.data_config.transforms[-1].parameters.struct_mask_axis != "none" + self.data_config.transforms[-1].struct_mask_axis != "none" ) keywords = [ @@ -506,13 +466,13 @@ def get_algorithm_keywords(self) -> List[str]: "3D" if "Z" in self.data_config.axes else "2D", "CAREamics", "pytorch", - "Noise2Void", + N2V, ] if use_n2v2: - keywords.append("N2V2") + keywords.append(N2V2) if use_structN2V: - keywords.append("StructN2V2") + keywords.append(STRUCT_N2V) else: keywords = ["CAREamics"] @@ -546,9 +506,6 @@ def model_dump( exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs ) - # change Path into str - # dictionary["working_directory"] = str(dictionary["working_directory"]) - return dictionary diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index 1acead884..c1f7c9a20 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -2,7 +2,7 @@ from __future__ import annotations from pprint import pformat -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Union, Tuple from albumentations import Compose from pydantic import ( @@ -15,8 +15,7 @@ ) from typing_extensions import Annotated -from careamics.utils import check_axes_validity - +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 from .support import SupportedTransform from .transformations.n2v_manipulate_model import N2VManipulateModel from .transformations.nd_flip_model import NDFlipModel @@ -34,7 +33,6 @@ ] -# TODO does patches need to be multiple of 8 with UNet? class DataModel(BaseModel): """ Data configuration. @@ -88,7 +86,7 @@ class DataModel(BaseModel): # Dataset configuration data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData - patch_size: List[int] = Field(..., min_length=2, max_length=3) + patch_size: Union[List[int], Tuple[int]] = Field(..., min_length=2, max_length=3) batch_size: int = Field(default=1, ge=1, validate_default=True) axes: str @@ -118,11 +116,11 @@ class DataModel(BaseModel): @field_validator("patch_size") @classmethod - def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: + def all_elements_power_of_2_minimum_8(cls, patch_list: List[int]) -> List[int]: """ Validate patch size. - Patch size must be non-zero, positive and even. + Patch size must be powers of 2 and minimum 8. Parameters ---------- @@ -137,18 +135,11 @@ def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: Raises ------ ValueError - If the patch size is 0. + If the patch size is smaller than 8. ValueError - If the patch size is not even. + If the patch size is not a power of 2. """ - for dim in patch_list: - if dim < 1: - raise ValueError(f"Patch size must be non-zero positive (got {dim}).") - - if dim % 2 != 0: - raise ValueError(f"Patch size must be even (got {dim}).") - - return patch_list + return patch_size_ge_than_8_power_of_2(patch_list) @field_validator("axes") @classmethod diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py index c5834a4ec..967d4f556 100644 --- a/src/careamics/config/inference_model.py +++ b/src/careamics/config/inference_model.py @@ -5,8 +5,7 @@ from albumentations import Compose from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from careamics.utils import check_axes_validity - +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 from .support import SupportedTransform from .transformations.normalize_model import NormalizeModel @@ -47,7 +46,7 @@ class InferenceModel(BaseModel): # Dataloader parameters batch_size: int = Field(default=1, ge=1) - @field_validator("tile_size", "tile_overlap") + @field_validator("tile_overlap") @classmethod def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: """ @@ -83,6 +82,31 @@ def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: raise ValueError(f"Patch size must be even (got {dim}).") return patch_list + + @field_validator("tile_size") + @classmethod + def tile_min_8_power_of_2(cls, tile_list: List[int]) -> List[int]: + """ + Validate that each entry is greater or equal than 8 and a power of 2. + + Parameters + ---------- + tile_list : List[int] + Patch size. + + Returns + ------- + List[int] + Validated patch size. + + Raises + ------ + ValueError + If the patch size if smaller than 8. + ValueError + If the patch size is not a power of 2. + """ + return patch_size_ge_than_8_power_of_2(tile_list) @field_validator("axes") @classmethod diff --git a/src/careamics/config/references/__init__.py b/src/careamics/config/references/__init__.py index 862827e26..d4ca6fdf4 100644 --- a/src/careamics/config/references/__init__.py +++ b/src/careamics/config/references/__init__.py @@ -1,13 +1,34 @@ """Module containing references to the algorithm used in CAREamics.""" __all__ = [ - "N2V2_REF", - "N2V_REF", - "STRUCTN2V_REF", + "N2V2Ref", + "N2VRef", + "StructN2VRef", + "N2VDescription", + "N2V2Description", + "StructN2VDescription", + "StructN2V2Description", + "N2V", + "N2V2", + "STRUCT_N2V", + "STRUCT_N2V2", + "CUSTOM" ] from .references import ( - N2V2_REF, - N2V_REF, - STRUCTN2V_REF, + N2V2Ref, + N2VRef, + StructN2VRef, +) + +from .algorithm_descriptions import ( + N2VDescription, + N2V2Description, + StructN2VDescription, + StructN2V2Description, + N2V, + N2V2, + STRUCT_N2V, + STRUCT_N2V2, + CUSTOM ) diff --git a/src/careamics/config/references/algorithm_descriptions.py b/src/careamics/config/references/algorithm_descriptions.py new file mode 100644 index 000000000..624cb7928 --- /dev/null +++ b/src/careamics/config/references/algorithm_descriptions.py @@ -0,0 +1,93 @@ +"""Descriptions of the algorithms used in CAREmics.""" +from pydantic import BaseModel + +CUSTOM = "Custom" +N2V = "Noise2Void" +N2V2 = "N2V2" +STRUCT_N2V = "StructN2V" +STRUCT_N2V2 = "StructN2V2" + + +N2V_DESCRIPTION = "Noise2Void is a UNet-based self-supervised algorithm that " \ + "uses blind-spot training to denoise images. In short, in every " \ + "patches during training, random pixels are selected and their " \ + "value replaced by a neighboring pixel value. The network is then " \ + "trained to predict the original pixel value. The algorithm " \ + "relies on the continuity of the signal (neighboring pixels have " \ + "similar values) and the pixel-wise independence of the noise " \ + "(the noise in a pixel is not correlated with the noise in " \ + "neighboring pixels)." + +class AlgorithmDescription(BaseModel): + """Description of an algorithm. + + Attributes + ---------- + description : str + Description of the algorithm. + """ + + description: str + + +class N2VDescription(AlgorithmDescription): + """Description of Noise2Void. + + Attributes + ---------- + description : str + Description of Noise2Void. + """ + + description: str = N2V_DESCRIPTION + + +class N2V2Description(AlgorithmDescription): + """Description of N2V2. + + Attributes + ---------- + description : str + Description of N2V2. + """ + + description: str = "N2V2 is a variant of Noise2Void. " + N2V_DESCRIPTION + \ + "\nN2V2 introduces blur-pool layers and removed skip " \ + "connections in the UNet architecture to remove checkboard " \ + "artefacts, a common artefacts ocurring in Noise2Void." + + +class StructN2VDescription(AlgorithmDescription): + """Description of StructN2V. + + Attributes + ---------- + description : str + Description of StructN2V. + """ + + description: str = "StructN2V is a variant of Noise2Void. " + N2V_DESCRIPTION + \ + "\nStructN2V uses a linear mask (horizontal or vertical) to replace " \ + "the pixel values of neighbors of the masked pixels by a random " \ + "value. Such masking allows removing 1D structured noise from the " \ + "the images, the main failure case of the original N2V." + + +class StructN2V2Description(AlgorithmDescription): + """Description of StructN2V2. + + Attributes + ---------- + description : str + Description of StructN2V2. + """ + + description: str = "StructN2V2 is a a variant of Noise2Void that uses both " \ + "structN2V and N2V2. "+ N2V_DESCRIPTION + \ + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace " \ + "the pixel values of neighbors of the masked pixels by a random " \ + "value. Such masking allows removing 1D structured noise from the " \ + "the images, the main failure case of the original N2V." \ + "\nN2V2 introduces blur-pool layers and removed skip connections in " \ + "the UNet architecture to remove checkboard artefacts, a common " \ + "artefacts ocurring in Noise2Void." \ No newline at end of file diff --git a/src/careamics/config/references/references.py b/src/careamics/config/references/references.py index 9b2e29211..46a50efc9 100644 --- a/src/careamics/config/references/references.py +++ b/src/careamics/config/references/references.py @@ -1,13 +1,14 @@ +"""References for the CAREamics algorithms.""" from bioimageio.spec.generic.v0_3 import CiteEntry -N2V_REF = CiteEntry( +N2VRef = CiteEntry( text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' "conference on computer vision and pattern recognition (pp. 2129-2137).", doi="10.1109/cvpr.2019.00223", ) -N2V2_REF = CiteEntry( +N2V2Ref = CiteEntry( text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' 'sampling strategies and a tweaked network architecture". In European ' @@ -15,7 +16,7 @@ doi="10.1007/978-3-031-25069-9_33", ) -STRUCTN2V_REF = CiteEntry( +StructN2VRef = CiteEntry( text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." '"Removing structured noise with self-supervised blind-spot ' 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' diff --git a/src/careamics/config/validators/__init__.py b/src/careamics/config/validators/__init__.py new file mode 100644 index 000000000..7b82a6a7e --- /dev/null +++ b/src/careamics/config/validators/__init__.py @@ -0,0 +1,5 @@ +"""Validator utilities.""" + +__all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"] + +from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 \ No newline at end of file diff --git a/src/careamics/config/validators/validator_utils.py b/src/careamics/config/validators/validator_utils.py new file mode 100644 index 000000000..a63726c45 --- /dev/null +++ b/src/careamics/config/validators/validator_utils.py @@ -0,0 +1,100 @@ +""" +Validator functions. + +These functions are used to validate dimensions and axes of inputs. +""" +from typing import Tuple, Optional, List, Union + +_AXES = "STCZYX" + + +def check_axes_validity(axes: str) -> bool: + """ + Sanity check on axes. + + The constraints on the axes are the following: + - must be a combination of 'STCZYX' + - must not contain duplicates + - must contain at least 2 contiguous axes: X and Y + - must contain at most 4 axes + - cannot contain both S and T axes + + Axes do not need to be in the order 'STCZYX', as this depends on the user data. + + Parameters + ---------- + axes : str + Axes to validate. + + Returns + ------- + bool + True if axes are valid, False otherwise. + """ + _axes = axes.upper() + + # Minimum is 2 (XY) and maximum is 4 (TZYX) + if len(_axes) < 2 or len(_axes) > 6: + raise ValueError( + f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes." + ) + + if "YX" not in _axes and "XY" not in _axes: + raise ValueError( + f"Invalid axes {axes}. Must contain at least X and Y axes consecutively." + ) + + # all characters must be in REF_AXES = 'STCZYX' + if not all(s in _AXES for s in _axes): + raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.") + + # check for repeating characters + for i, s in enumerate(_axes): + if i != _axes.rfind(s): + raise ValueError( + f"Invalid axes {axes}. Cannot contain duplicate axes" + f" (got multiple {axes[i]})." + ) + + return True + + +def patch_size_ge_than_8_power_of_2( + patch_list: Optional[Union[List[int], Tuple[int]]] + ) -> Optional[Union[List[int], Tuple[int]]]: + """ + Validate that each entry is greater or equal than 8 and a power of 2. + + If None is passed, the function will return None. + + Parameters + ---------- + patch_list : Optional[Union[List[int], Tuple[int]]] + Patch size. + + Returns + ------- + Optional[Union[List[int], Tuple[int]]] + Validated patch size. + + Raises + ------ + ValueError + If the patch size if smaller than 8. + ValueError + If the patch size is not a power of 2. + """ + if patch_list is not None: + for dim in patch_list: + if dim < 8: + raise ValueError( + f"Patch size must be non-zero positive (got {dim})." + ) + + if (dim & (dim - 1)) != 0: + raise ValueError( + f"Patch size must be a power of 2 in all dimensions " + f"(got {dim})." + ) + + return patch_list \ No newline at end of file diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index 0a4a596bf..88ffa4684 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -126,8 +126,12 @@ def _create_inputs_ouputs( std = data_config.std # and the mean and std required to invert the normalization - inv_mean = -mean / std - inv_std = 1 / std + # CAREamics denormalization: x = y * (std + eps) + mean + # BMZ normalization : x = (y - mean') / (std' + eps) + # to apply the BMZ normalization as a denormalization step, we need: + eps = 1e-6 + inv_mean = -mean / (std + eps) + inv_std = 1 / (std + eps) - eps # create input/output descriptions input_descr = InputTensorDescr( @@ -256,6 +260,15 @@ def create_model_description( weights=weights_descr, attachments=[FileDescr(source=config_path)], cite=config.get_algorithm_citations(), + config={ # conversion from float32 to float64 creates small differences... + "bioimageio": { + "test_kwargs": { + "pytorch_state_dict": { + "decimals": 2, # ...so we relax the constraints on the decimals + } + } + } + } ) return model diff --git a/src/careamics/prediction/stitch_prediction.py b/src/careamics/prediction/stitch_prediction.py index f88233a81..5e0ee7e11 100644 --- a/src/careamics/prediction/stitch_prediction.py +++ b/src/careamics/prediction/stitch_prediction.py @@ -41,7 +41,8 @@ def stitch_prediction( [el.numpy() for el in stitching_data[0][0]], dtype=int ).squeeze() - predicted_image = np.zeros(input_shape, dtype=np.float32) + # TODO should use torch.zeros instead of np.zeros + predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32)) for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip( tiles, stitching_data diff --git a/src/careamics/utils/__init__.py b/src/careamics/utils/__init__.py index b37f4fef2..ef2929113 100644 --- a/src/careamics/utils/__init__.py +++ b/src/careamics/utils/__init__.py @@ -2,8 +2,6 @@ __all__ = [ - "check_axes_validity", - "check_tiling_validity", "cwd", "MetricTracker", "get_ram_size", @@ -21,7 +19,3 @@ from .path_utils import check_path_exists from .ram import get_ram_size from .running_stats import RunningStats -from .validators import ( - check_axes_validity, - check_tiling_validity, -) diff --git a/src/careamics/utils/validators.py b/src/careamics/utils/validators.py deleted file mode 100644 index 672a5ed0c..000000000 --- a/src/careamics/utils/validators.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -Validator functions. - -These functions are used to validate dimensions and axes of inputs. -""" -from typing import List - -AXES = "STCZYX" - - -def check_axes_validity(axes: str) -> bool: - """ - Sanity check on axes. - - The constraints on the axes are the following: - - must be a combination of 'STCZYX' - - must not contain duplicates - - must contain at least 2 contiguous axes: X and Y - - must contain at most 4 axes - - cannot contain both S and T axes - - Axes do not need to be in the order 'STCZYX', as this depends on the user data. - - Parameters - ---------- - axes : str - Axes to validate. - - Returns - ------- - bool - True if axes are valid, False otherwise. - """ - _axes = axes.upper() - - # Minimum is 2 (XY) and maximum is 4 (TZYX) - if len(_axes) < 2 or len(_axes) > 6: - raise ValueError( - f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes." - ) - - if "YX" not in _axes and "XY" not in _axes: - raise ValueError( - f"Invalid axes {axes}. Must contain at least X and Y axes consecutively." - ) - - # all characters must be in REF_AXES = 'STCZYX' - if not all(s in AXES for s in _axes): - raise ValueError(f"Invalid axes {axes}. Must be a combination of {AXES}.") - - # check for repeating characters - for i, s in enumerate(_axes): - if i != _axes.rfind(s): - raise ValueError( - f"Invalid axes {axes}. Cannot contain duplicate axes" - f" (got multiple {axes[i]})." - ) - - return True - - -def check_tiling_validity(tile_shape: List[int], overlaps: List[int]) -> None: - """ - Check that the tiling parameters are valid. - - Parameters - ---------- - tile_shape : List[int] - Shape of the tiles. - overlaps : List[int] - Overlap between tiles. - - Raises - ------ - ValueError - If one of the parameters is None. - ValueError - If one of the element is zero. - ValueError - If one of the element is non-divisible by 2. - ValueError - If the number of elements in `overlaps` and `tile_shape` is different. - ValueError - If one of the overlaps is larger than the corresponding tile shape. - """ - # cannot be None - if tile_shape is None or overlaps is None: - raise ValueError( - "Cannot use tiling without specifying `tile_shape` and " - "`overlaps`, make sure they have been correctly specified." - ) - - # non-zero and divisible by two - for dims_list in [tile_shape, overlaps]: - for dim in dims_list: - if dim < 0: - raise ValueError(f"Entry must be non-null positive (got {dim}).") - - if dim % 2 != 0: - raise ValueError(f"Entry must be divisible by 2 (got {dim}).") - - # same length - if len(overlaps) != len(tile_shape): - raise ValueError( - f"Overlaps ({len(overlaps)}) and tile shape ({len(tile_shape)}) must " - f"have the same number of dimensions." - ) - - # overlaps smaller than tile shape - for overlap, tile_dim in zip(overlaps, tile_shape): - if overlap >= tile_dim: - raise ValueError( - f"Overlap ({overlap}) must be smaller than tile shape ({tile_dim})." - ) diff --git a/src/careamics/utils/wandb.py b/src/careamics/utils/wandb.py deleted file mode 100644 index 3bae730d7..000000000 --- a/src/careamics/utils/wandb.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -A WandB logger for CAREamics. - -Implements a WandB class for use within the Engine. -""" -import sys -from pathlib import Path -from typing import Dict, Union - -import torch -import wandb - -from ..config import Configuration - - -def is_notebook() -> bool: - """ - Check if the code is executed from a notebook or a qtconsole. - - Returns - ------- - bool - True if the code is executed from a notebooks, False otherwise. - """ - try: - from IPython import get_ipython - - shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": - return True # Jupyter notebook or qtconsole - else: - return False - except (NameError, ModuleNotFoundError): - return False - - -class WandBLogging: - """ - WandB logging class. - - Parameters - ---------- - experiment_name : str - Name of the experiment. - log_path : Path - Path in which to save the WandB log. - config : Configuration - Configuration of the model. - model_to_watch : torch.nn.Module - Model. - save_code : bool, optional - Whether to save the code, by default True. - """ - - def __init__( - self, - experiment_name: str, - log_path: Path, - config: Configuration, - model_to_watch: torch.nn.Module, - save_code: bool = True, - ): - """ - Constructor. - - Parameters - ---------- - experiment_name : str - Name of the experiment. - log_path : Path - Path in which to save the WandB log. - config : Configuration - Configuration of the model. - model_to_watch : torch.nn.Module - Model. - save_code : bool, optional - Whether to save the code, by default True. - """ - self.run = wandb.init( - project="careamics-restoration", - dir=log_path, - name=experiment_name, - config=config.model_dump() if config else None, - # save_code=save_code, - ) - if model_to_watch: - wandb.watch(model_to_watch, log="all", log_freq=1) - if save_code: - if is_notebook(): - # Get all sys path and select the root - code_path = Path([p for p in sys.path if "caremics" in p][-1]).parent - else: - code_path = Path("../") - self.log_code(code_path) - - def log_metrics(self, metric_dict: Dict) -> None: - """ - Log metrics to wandb. - - Parameters - ---------- - metric_dict : Dict - New metrics entry. - """ - self.run.log(metric_dict, commit=True) - - def log_code(self, code_path: Union[str, Path]) -> None: - """ - Log code to wandb. - - Parameters - ---------- - code_path : Union[str, Path] - Path to the code. - """ - self.run.log_code( - root=code_path, - include_fn=lambda path: path.endswith(".py") - or path.endswith(".yml") - or path.endswith(".yaml"), - ) diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py index a343da516..4e59d776a 100644 --- a/tests/config/test_data_model.py +++ b/tests/config/test_data_model.py @@ -96,15 +96,15 @@ def test_patch_size(minimum_data: dict): data_model = DataModel(**minimum_data) # 3D - minimum_data["patch_size"] = [12, 12, 12] + minimum_data["patch_size"] = [16, 8, 8] minimum_data["axes"] = "ZYX" data_model = DataModel(**minimum_data) - assert data_model.patch_size == [12, 12, 12] + assert data_model.patch_size == minimum_data["patch_size"] @pytest.mark.parametrize( - "patch_size", [[12], [0, 12, 12], [12, 12, 13], [12, 12, 12, 12]] + "patch_size", [[12], [0, 12, 12], [12, 12, 13], [16, 10, 16], [12, 12, 12, 12]] ) def test_wrong_patch_size(minimum_data: dict, patch_size): """Test that wrong patch sizes are not accepted (zero or odd, dims 1 or > 3).""" diff --git a/tests/config/test_inference_model.py b/tests/config/test_inference_model.py index bd6f0e98c..885a48169 100644 --- a/tests/config/test_inference_model.py +++ b/tests/config/test_inference_model.py @@ -49,22 +49,22 @@ def test_tile_size(minimum_inference: dict): prediction_model = InferenceModel(**minimum_inference) # 2D - minimum_inference["tile_size"] = [12, 12] + minimum_inference["tile_size"] = [16, 8] minimum_inference["tile_overlap"] = [2, 2] minimum_inference["axes"] = "YX" prediction_model = InferenceModel(**minimum_inference) - assert prediction_model.tile_size == [12, 12] - assert prediction_model.tile_overlap == [2, 2] + assert prediction_model.tile_size == minimum_inference["tile_size"] + assert prediction_model.tile_overlap == minimum_inference["tile_overlap"] # 3D - minimum_inference["tile_size"] = [12, 12, 12] + minimum_inference["tile_size"] = [16, 8, 32] minimum_inference["tile_overlap"] = [2, 2, 2] minimum_inference["axes"] = "ZYX" prediction_model = InferenceModel(**minimum_inference) - assert prediction_model.tile_size == [12, 12, 12] - assert prediction_model.tile_overlap == [2, 2, 2] + assert prediction_model.tile_size == minimum_inference["tile_size"] + assert prediction_model.tile_overlap == minimum_inference["tile_overlap"] @pytest.mark.parametrize( diff --git a/tests/utils/test_axes.py b/tests/config/validators/test_validator_utils.py similarity index 60% rename from tests/utils/test_axes.py rename to tests/config/validators/test_validator_utils.py index 0c8e56135..afc7d8210 100644 --- a/tests/utils/test_axes.py +++ b/tests/config/validators/test_validator_utils.py @@ -1,7 +1,8 @@ import pytest -from careamics.utils import check_axes_validity - +from careamics.config.validators import ( + check_axes_validity, patch_size_ge_than_8_power_of_2 +) @pytest.mark.parametrize( "axes, valid", @@ -43,3 +44,24 @@ def test_are_axes_valid(axes, valid): else: with pytest.raises((ValueError, NotImplementedError)): check_axes_validity(axes) + + +@pytest.mark.parametrize("patch_size, error", + [ + ((2, 8, 8), True), + ((10,), True), + ((8, 10, 16), True), + ((8, 13), True), + ((8, 16, 4), True), + ((8,), False), + ((8, 8), False), + ((8, 64, 64), False), + ] +) +def test_patch_size(patch_size, error): + """Test if patch size is valid.""" + if error: + with pytest.raises(ValueError): + patch_size_ge_than_8_power_of_2(patch_size) + else: + patch_size_ge_than_8_power_of_2(patch_size) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 13d553c3a..4bc99fd46 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,11 @@ -import tempfile from pathlib import Path -from typing import Callable, Generator, Tuple +from typing import Callable, Tuple import numpy as np import pytest -import tifffile from careamics import CAREamist, Configuration -from careamics.config.algorithm_model import ( - AlgorithmModel, - LrSchedulerModel, - OptimizerModel, -) -from careamics.config.data_model import DataModel from careamics.config.support import SupportedData -from careamics.config.training_model import TrainingModel # TODO add details about where each of these fixture is used (e.g. smoke test) @@ -41,7 +32,7 @@ def minimum_algorithm_custom() -> dict: # create dictionary algorithm = { "algorithm": "custom", - "loss": "n2v", + "loss": "mae", "model": { "architecture": "UNet", }, @@ -103,9 +94,9 @@ def minimum_data() -> dict: """ # create dictionary data = { - "data_type": SupportedData.TIFF.value, - "patch_size": [64, 64], - "axes": "SYX", + "data_type": SupportedData.ARRAY.value, + "patch_size": [8, 8], + "axes": "YX", } return data @@ -122,10 +113,10 @@ def minimum_inference() -> dict: """ # create dictionary predic = { - "data_type": SupportedData.TIFF.value, - "mean": 0.0, + "data_type": SupportedData.ARRAY.value, + "axes": "YX", + "mean": 2.0, "std": 1.0, - "axes": "SYX", } return predic @@ -142,7 +133,7 @@ def minimum_training() -> dict: """ # create dictionary training = { - "num_epochs": 666, + "num_epochs": 1, } return training @@ -181,6 +172,20 @@ def minimum_configuration( return configuration +@pytest.fixture +def supervised_configuration( + minimum_algorithm_supervised: dict, minimum_data: dict, minimum_training: dict +) -> dict: + configuration = { + "experiment_name": "LevitatingFrog", + "algorithm_config": minimum_algorithm_supervised, + "training_config": minimum_training, + "data_config": minimum_data, + } + + return configuration + + @pytest.fixture def ordered_array() -> Callable: """A function that returns an array with ordered values.""" @@ -227,17 +232,6 @@ def array_3D() -> np.ndarray: return np.arange(2048 * 3).reshape((1, 3, 8, 16, 16)) -@pytest.fixture -def temp_dir() -> Generator[Path, None, None]: - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - -@pytest.fixture -def image_size() -> Tuple[int, int]: - return (128, 128) - - @pytest.fixture def patch_size() -> Tuple[int, int]: return (64, 64) @@ -248,72 +242,6 @@ def overlaps() -> Tuple[int, int]: return (32, 32) -@pytest.fixture -def example_data_path( - temp_dir: Path, image_size: Tuple[int, int], patch_size: Tuple[int, int] -) -> Tuple[Path, Path]: - test_image = np.random.rand(*image_size) - - train_path = temp_dir / "train" - val_path = temp_dir / "val" - test_path = temp_dir / "test" - train_path.mkdir() - val_path.mkdir() - test_path.mkdir() - - tifffile.imwrite(train_path / "train_image.tif", test_image) - tifffile.imwrite(val_path / "val_image.tif", test_image) - tifffile.imwrite(test_path / "test_image.tif", test_image) - - return train_path, val_path, test_path - - -@pytest.fixture -def base_configuration(temp_dir: Path, patch_size) -> Configuration: - configuration = Configuration( - experiment_name="smoke_test", - working_directory=temp_dir, - algorithm_config=AlgorithmModel( - algorithm="n2v", - loss="n2v", - model={"architecture": "UNet"}, - is_3D="False", - transforms={"Flip": None, "ManipulateN2V": None}, - ), - data_config=DataModel( - in_memory=True, - extension="tif", - axes="YX", - ), - training_config=TrainingModel( - num_epochs=1, - patch_size=patch_size, - batch_size=2, - optimizer=OptimizerModel(name="Adam"), - lr_scheduler=LrSchedulerModel(name="ReduceLROnPlateau"), - extraction_strategy="random", - augmentation=True, - num_workers=0, - use_wandb=False, - ), - ) - return configuration - - -@pytest.fixture -def supervised_configuration( - minimum_algorithm_supervised: dict, minimum_data: dict, minimum_training: dict -) -> dict: - configuration = { - "experiment_name": "LevitatingFrog", - "algorithm_config": minimum_algorithm_supervised, - "training_config": minimum_training, - "data_config": minimum_data, - } - - return configuration - - @pytest.fixture def pre_trained(tmp_path, minimum_configuration): """Fixture to create a pre-trained CAREamics model.""" diff --git a/tests/dataset/test_in_memory_dataset.py b/tests/dataset/test_in_memory_dataset.py index 384211e36..8d2c27381 100644 --- a/tests/dataset/test_in_memory_dataset.py +++ b/tests/dataset/test_in_memory_dataset.py @@ -15,7 +15,7 @@ def test_number_of_patches(ordered_array): # create config config_dict = { "data_type": SupportedData.ARRAY.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) @@ -40,12 +40,12 @@ def test_compute_mean_std_transform(ordered_array): def test_extracting_val_array(ordered_array, percentage): """Test extracting a validation set patches from InMemoryDataset.""" # create array - array = ordered_array((20, 20)) + array = ordered_array((32, 32)) # create config config_dict = { "data_type": SupportedData.ARRAY.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) @@ -76,7 +76,7 @@ def test_extracting_val_array(ordered_array, percentage): def test_extracting_val_files(tmp_path, ordered_array, percentage): """Test extracting a validation set patches from InMemoryDataset.""" # create array - array = ordered_array((20, 20)) + array = ordered_array((32, 32)) # save array to file file_path = tmp_path / "array.tif" @@ -85,7 +85,7 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): # create config config_dict = { "data_type": SupportedData.ARRAY.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index 351df036b..5347310cf 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -12,16 +12,16 @@ "shape", [ # 2D - (20, 20), + (32, 32), # 3D - (20, 20, 20), + (32, 32, 32), ], ) def test_number_of_files(tmp_path, ordered_array, shape): """Test number of files in PathIterableDataset.""" # create array - array_size = 20 - patch_size = 4 + array_size = 32 + patch_size = 8 n_files = 3 factor = len(shape) axes = "YX" if factor == 2 else "ZYX" @@ -63,13 +63,13 @@ def test_read_function(tmp_path, ordered_array): def read_npy(file_path, *args, **kwargs): return np.load(file_path) - array_size = 20 - patch_size = 4 + array_size = 32 + patch_size = 8 n_files = 3 patch_sizes = [patch_size] * 2 # create array - array = ordered_array((n_files, array_size, array_size)) + array: np.ndarray = ordered_array((n_files, array_size, array_size)) # save each plane in a single .npy file files = [] @@ -115,7 +115,7 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): # create config config_dict = { "data_type": SupportedData.TIFF.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) diff --git a/tests/model_io/test_model_io_utils.py b/tests/model_io/test_model_io_utils.py index 3a3af1717..36cc575ae 100644 --- a/tests/model_io/test_model_io_utils.py +++ b/tests/model_io/test_model_io_utils.py @@ -32,3 +32,4 @@ def test_export_bmz(tmp_path, pre_trained): inputs=train_path, outputs=predicted_path, ) + assert (tmp_path / "model.zip").exists() diff --git a/tests/prediction/test_prediction_utils.py b/tests/prediction/test_stitch_prediction.py similarity index 92% rename from tests/prediction/test_prediction_utils.py rename to tests/prediction/test_stitch_prediction.py index c96920ccd..4908af233 100644 --- a/tests/prediction/test_prediction_utils.py +++ b/tests/prediction/test_stitch_prediction.py @@ -35,7 +35,7 @@ def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps): ) ) - # compute stitching coordinates + # compute stitching coordinates, it returns a torch.Tensor result = stitch_prediction(tiles, stitching_data) - assert (result == arr).all() + assert (result.numpy() == arr).all() diff --git a/tests/test_lightning_datamodule.py b/tests/test_lightning_datamodule.py index 6eeebda9b..33444971d 100644 --- a/tests/test_lightning_datamodule.py +++ b/tests/test_lightning_datamodule.py @@ -27,9 +27,10 @@ def test_lightning_train_datamodule_array(simple_array): data_module = CAREamicsTrainDataModule( train_data=simple_array, data_type="array", - patch_size=(2, 2), + patch_size=(8, 8), axes="YX", batch_size=2, + val_minimum_patches=2, ) data_module.prepare_data() data_module.setup() @@ -48,6 +49,7 @@ def test_lightning_train_datamodule_supervised_n2v_throws_error(simple_array): axes="YX", batch_size=2, train_target_data=simple_array, + val_minimum_patches=2, ) @@ -63,7 +65,7 @@ def test_lightning_train_datamodule_n2v2(simple_array, use_n2v2, strategy): data_module = CAREamicsTrainDataModule( train_data=simple_array, data_type="array", - patch_size=(10, 10), + patch_size=(16, 16), axes="YX", batch_size=2, use_n2v2=use_n2v2, @@ -79,7 +81,7 @@ def test_lightning_train_datamodule_structn2v(simple_array): data_module = CAREamicsTrainDataModule( train_data=simple_array, data_type="array", - patch_size=(10, 10), + patch_size=(16, 16), axes="YX", batch_size=2, struct_n2v_axis=struct_axis, @@ -113,12 +115,12 @@ def test_lightning_pred_datamodule_tiling(simple_array): axes="YX", batch_size=2, tile_overlap=[2, 2], - tile_size=[4, 4], + tile_size=[8, 8], ) data_module.prepare_data() data_module.setup() - assert len(list(data_module.predict_dataloader())) == 8 + assert len(list(data_module.predict_dataloader())) == 2 def test_lightning_pred_datamodule_no_tiling(simple_array): From 54e902964f116a1d2aa2d6ae370a1b009c1f2e31 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Sat, 27 Apr 2024 13:31:22 +0200 Subject: [PATCH 11/14] Add failing tests for model --- src/careamics/config/configuration_model.py | 29 ++- src/careamics/lightning_datamodule.py | 12 +- src/careamics/lightning_prediction_loop.py | 7 +- tests/test_lightning_module.py | 224 ++++++++++++++++++++ 4 files changed, 259 insertions(+), 13 deletions(-) diff --git a/src/careamics/config/configuration_model.py b/src/careamics/config/configuration_model.py index 1696037d3..abfb3712e 100644 --- a/src/careamics/config/configuration_model.py +++ b/src/careamics/config/configuration_model.py @@ -16,15 +16,21 @@ N2V2Ref, N2VRef, StructN2VRef, + N2NRef, + CARERef, N2VDescription, N2V2Description, StructN2VDescription, StructN2V2Description, + N2NDescription, + CAREDescription, N2V, N2V2, STRUCT_N2V, STRUCT_N2V2, - CUSTOM + CUSTOM, + N2N, + CARE ) from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform from .training_model import TrainingModel @@ -239,8 +245,7 @@ def validate_algorithm_and_data(self: Configuration) -> Configuration: name=SupportedTransform.N2V_MANIPULATE.value, ) ) - # TODO Doesn't validate the parameters of N2VManipulate !! - # make sure that N2V has the correct pixel manipulate strategy + median = SupportedPixelManipulation.MEDIAN.value uniform = SupportedPixelManipulation.UNIFORM.value strategy = median if self.algorithm_config.model.n2v2 else uniform @@ -348,9 +353,13 @@ def get_algorithm_flavour(self) -> str: return STRUCT_N2V else: return N2V - - return self.algorithm_config.algorithm.capitalize() - + elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: + return N2N + elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: + return CARE + else: + return CUSTOM + def get_algorithm_description(self) -> str: """ Return a description of the algorithm. @@ -375,6 +384,10 @@ def get_algorithm_description(self) -> str: return StructN2VDescription().description elif algorithm_flavour == STRUCT_N2V2: return StructN2V2Description().description + elif algorithm_flavour == N2N: + return N2NDescription().description + elif algorithm_flavour == CARE: + return CAREDescription().description return "" @@ -404,6 +417,10 @@ def get_algorithm_citations(self) -> List[CiteEntry]: return [N2VRef, StructN2VRef] else: return [N2VRef] + elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: + return [N2NRef] + elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: + return [CARERef] raise ValueError("Citation not available for custom algorithm.") diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index 7bc65dbba..56f48d47a 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -445,11 +445,11 @@ class CAREamicsTrainDataModule(CAREamicsWood): Create a CAREamicsTrainDataModule with default transforms with a numpy array: >>> import numpy as np >>> from careamics import CAREamicsTrainDataModule - >>> my_array = np.arange(100).reshape(10, 10) + >>> my_array = np.arange(256).reshape(16, 16) >>> data_module = CAREamicsTrainDataModule( ... train_data=my_array, ... data_type="array", - ... patch_size=(2, 2), + ... patch_size=(8, 8), ... axes='YX', ... batch_size=2, ... ) @@ -458,12 +458,14 @@ class CAREamicsTrainDataModule(CAREamicsWood): function and a filter for the files extension: >>> import numpy as np >>> from careamics import CAREamicsTrainDataModule + >>> >>> def read_npy(path): ... return np.load(path) + >>> >>> data_module = CAREamicsTrainDataModule( ... train_data="path/to/data", ... data_type="custom", - ... patch_size=(2, 2), + ... patch_size=(8, 8), ... axes='YX', ... batch_size=2, ... read_source_func=read_npy, @@ -475,7 +477,7 @@ class CAREamicsTrainDataModule(CAREamicsWood): >>> import numpy as np >>> from careamics import CAREamicsTrainDataModule >>> from careamics.config.support import SupportedTransform - >>> my_array = np.arange(100).reshape(10, 10) + >>> my_array = np.arange(256).reshape(16, 16) >>> my_transforms = [ ... { ... "name": SupportedTransform.NORMALIZE.value, @@ -489,7 +491,7 @@ class CAREamicsTrainDataModule(CAREamicsWood): >>> data_module = CAREamicsTrainDataModule( ... train_data=my_array, ... data_type="array", - ... patch_size=(2, 2), + ... patch_size=(8, 8), ... axes='YX', ... batch_size=2, ... transforms=my_transforms, diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py index 4a2f64f7b..ca11da96a 100644 --- a/src/careamics/lightning_prediction_loop.py +++ b/src/careamics/lightning_prediction_loop.py @@ -35,10 +35,13 @@ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: ######################################################## ################ CAREamics specific code ############### if len(self.predicted_array) == 1: - return self.predicted_array[0] + # TODO does this make sense to here? (force numpy array) + return self.predicted_array[0].numpy() else: # TODO revisit logic - return self.predicted_array + return [ + element.numpy() for element in self.predicted_array + ] ######################################################## return None diff --git a/tests/test_lightning_module.py b/tests/test_lightning_module.py index 038ad8438..fff2bf522 100644 --- a/tests/test_lightning_module.py +++ b/tests/test_lightning_module.py @@ -1,3 +1,6 @@ +import pytest +import torch + from careamics.config import AlgorithmModel from careamics.lightning_module import CAREamicsKiln, CAREamicsModule @@ -29,3 +32,224 @@ def test_careamics_kiln(minimum_algorithm_n2v): # instantiate CAREamicsKiln CAREamicsKiln(algo_config) + + +@pytest.mark.parametrize("shape", + [ + (8, 8), + (16, 16), + (32, 32), + ] +) +def test_careamics_kiln_unet_2D_depth_2_shape(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": 1, + "num_classes": 1, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("shape", + [ + (8, 8), + (16, 16), + (32, 32), + (64, 64), + (128, 128), + (256, 256), + ] +) +def test_careamics_kiln_unet_2D_depth_3_shape(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": 1, + "num_classes": 1, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("shape", + [ + (8, 32, 16), + (16, 32, 16), + (8, 32, 32), + (32, 64, 64), + ] +) +def test_careamics_kiln_unet_depth_2_3D(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": 1, + "num_classes": 1, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("shape", + [ + (8, 64, 64), + (16, 64, 64), + (16, 128, 128), + (32, 128, 128), + ] +) +def test_careamics_kiln_unet_depth_3_3D(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": 1, + "num_classes": 1, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_2_channels_2D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 32, 32)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +def test_careamics_kiln_unet_depth_3_channels_2D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 64, 64)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_2_channels_3D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 16, 32, 32)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_3_channels_3D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 16, 64, 64)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + From cfedbe8543c8473fe1caf9335d8f3f8e8d21f44b Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Sat, 27 Apr 2024 13:32:47 +0200 Subject: [PATCH 12/14] Add references, careamist test --- src/careamics/careamist.py | 112 +++++++- src/careamics/config/references/__init__.py | 16 +- .../references/algorithm_descriptions.py | 28 +- src/careamics/config/references/references.py | 17 +- src/careamics/model_io/__init__.py | 3 +- .../model_io/bioimage/model_description.py | 28 +- .../model_io/bioimage/readme_factory.py | 20 +- src/careamics/model_io/bmz_export.py | 140 ++++++++++ src/careamics/model_io/model_io_utils.py | 116 +------- ...t_model_io_utils.py => test_export_bmz.py} | 15 +- tests/test_careamist.py | 253 ++++++++++++++++-- 11 files changed, 562 insertions(+), 186 deletions(-) create mode 100644 src/careamics/model_io/bmz_export.py rename tests/model_io/{test_model_io_utils.py => test_export_bmz.py} (66%) diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index b8d76dec6..4df35d2cb 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -19,12 +19,12 @@ load_configuration, ) from careamics.config.inference_model import TRANSFORMS_UNION -from careamics.config.support import SupportedAlgorithm, SupportedLogger +from careamics.config.support import SupportedAlgorithm, SupportedLogger, SupportedData from careamics.lightning_datamodule import CAREamicsWood from careamics.lightning_module import CAREamicsKiln from careamics.lightning_prediction_datamodule import CAREamicsClay from careamics.lightning_prediction_loop import CAREamicsPredictionLoop -from careamics.model_io import load_pretrained +from careamics.model_io import load_pretrained, export_to_bmz from careamics.utils import check_path_exists, get_logger from .callbacks import HyperParametersCallback @@ -35,8 +35,6 @@ # TODO napari callbacks # TODO: how to do AMP? How to continue training? - - class CAREamist: """ Main CAREamics class, allowing training and prediction using various algorithms. @@ -63,6 +61,10 @@ class CAREamist: by default None. experiment_name : str, optional Experiment name used for checkpoints, by default "CAREamics". + train_datamodule : Optional[CAREamicsWood], optional + Training datamodule, by default None. + pred_datamodule : Optional[CAREamicsClay], optional + Prediction datamodule, by default None. """ @overload @@ -190,6 +192,10 @@ def __init__( # change the prediction loop, necessary for tiled prediction self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer) + # place holder for the datamodules + self.train_datamodule: Optional[CAREamicsWood] = None + self.pred_datamodule: Optional[CAREamicsClay] = None + def _define_callbacks(self) -> List[Callback]: """ Define the callbacks for the training loop. @@ -363,6 +369,9 @@ def _train_on_datamodule(self, datamodule: CAREamicsWood) -> None: datamodule : CAREamicsWood Datamodule to train on. """ + # record datamodule + self.train_datamodule = datamodule + self.trainer.fit(self.model, datamodule=datamodule) def _train_on_array( @@ -465,7 +474,10 @@ def _train_on_path( @overload def predict( # numpydoc ignore=GL08 - self, source: CAREamicsClay + self, + source: CAREamicsClay, + *, + checkpoint: Optional[Literal["best", "last"]] = None, ) -> Union[list, np.ndarray]: ... @@ -576,8 +588,12 @@ def predict( If the input is not a CAREamicsClay instance, a path or a numpy array. """ if isinstance(source, CAREamicsClay): - return self.trainer.predict(datamodule=source) + # record datamodule + self.pred_datamodule = source + return self.trainer.predict( + model=self.model, datamodule=source, ckpt_path=checkpoint + ) else: if self.cfg is None: raise ValueError( @@ -614,6 +630,9 @@ def predict( extension_filter=extension_filter, dataloader_params=dataloader_params, ) + + # record datamodule + self.pred_datamodule = datamodule return self.trainer.predict( model=self.model, datamodule=datamodule, ckpt_path=checkpoint @@ -626,6 +645,9 @@ def predict( pred_data=source, dataloader_params=dataloader_params, ) + + # record datamodule + self.pred_datamodule = datamodule return self.trainer.predict( model=self.model, datamodule=datamodule, ckpt_path=checkpoint @@ -636,3 +658,81 @@ def predict( f"Invalid input. Expected a CAREamicsWood instance, paths or " f"np.ndarray (got {type(source)})." ) + + def export_to_bmz( + self, + path: Union[Path, str], + name: str, + authors: List[dict], + input_array: Optional[np.ndarray] = None, + general_description: str = "", + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, + ) -> None: + """Export the model to the BioImage Model Zoo format. + + Input array must be of shape SC(Z)YX, with S and C singleton dimensions. + + Parameters + ---------- + path : Union[Path, str] + Path to save the model. + name : str + Name of the model. + authors : List[dict] + List of authors of the model. + input_array : Optional[np.ndarray], optional + Input array for the model, must be of shape SC(Z)YX, by default None. + general_description : str + General description of the model, used in the metadata of the BMZ archive. + channel_names : Optional[List[str]], optional + Channel names, by default None. + data_description : Optional[str], optional + Description of the data, by default None. + """ + if input_array is None: + # generate images, priority is given to the prediction data module + if self.pred_datamodule is not None: + # unpack a batch, ignore masks or targets + input, *_ = next(iter(self.pred_datamodule.predict_dataloader())) + + # convert torch.Tensor to numpy + input_array = input.numpy() + elif self.train_datamodule is not None: + input, *_ = next(iter(self.train_datamodule.train_dataloader())) + input_array = input.numpy() + else: + # create a random input array + input_array = np.random.normal( + loc=self.cfg.data_config.mean, + scale=self.cfg.data_config.std, + size=self.cfg.data_config.patch_size + ).astype(np.float32)[np.newaxis, np.newaxis, ...] # add S & C dimensions + + # if there is a batch dimension + if input_array.shape[0] > 1: + input_array = input_array[0:1, ...] # keep singleton dim + + # axes need to be without S + axes = self.cfg.data_config.axes.replace("S", "") + + # predict output, remove extra dimensions for the purpose of the prediction + output_array = self.predict( + input_array.squeeze(), + data_type=SupportedData.ARRAY.value, + axes=axes, + tta_transforms=False + ) + + export_to_bmz( + model=self.model, + config=self.cfg, + path=path, + name=name, + general_description=general_description, + authors=authors, + input_array=input_array, + output_array=output_array, + channel_names= channel_names, + data_description=data_description + ) \ No newline at end of file diff --git a/src/careamics/config/references/__init__.py b/src/careamics/config/references/__init__.py index d4ca6fdf4..1139297b7 100644 --- a/src/careamics/config/references/__init__.py +++ b/src/careamics/config/references/__init__.py @@ -12,13 +12,21 @@ "N2V2", "STRUCT_N2V", "STRUCT_N2V2", - "CUSTOM" + "CUSTOM", + "N2N", + "CARE", + "CAREDescription", + "N2NDescription", + "CARERef", + "N2NRef", ] from .references import ( N2V2Ref, N2VRef, StructN2VRef, + N2NRef, + CARERef, ) from .algorithm_descriptions import ( @@ -26,9 +34,13 @@ N2V2Description, StructN2VDescription, StructN2V2Description, + N2NDescription, + CAREDescription, N2V, N2V2, STRUCT_N2V, STRUCT_N2V2, - CUSTOM + CUSTOM, + N2N, + CARE ) diff --git a/src/careamics/config/references/algorithm_descriptions.py b/src/careamics/config/references/algorithm_descriptions.py index 624cb7928..5fe2704a6 100644 --- a/src/careamics/config/references/algorithm_descriptions.py +++ b/src/careamics/config/references/algorithm_descriptions.py @@ -6,6 +6,8 @@ N2V2 = "N2V2" STRUCT_N2V = "StructN2V" STRUCT_N2V2 = "StructN2V2" +N2N = "Noise2Noise" +CARE = "CARE" N2V_DESCRIPTION = "Noise2Void is a UNet-based self-supervised algorithm that " \ @@ -90,4 +92,28 @@ class StructN2V2Description(AlgorithmDescription): "the images, the main failure case of the original N2V." \ "\nN2V2 introduces blur-pool layers and removed skip connections in " \ "the UNet architecture to remove checkboard artefacts, a common " \ - "artefacts ocurring in Noise2Void." \ No newline at end of file + "artefacts ocurring in Noise2Void." + + +class N2NDescription(AlgorithmDescription): + """Description of Noise2Noise. + + Attributes + ---------- + description : str + Description of Noise2Noise. + """ + + description: str = "Noise2Noise" # TODO + + +class CAREDescription(AlgorithmDescription): + """Description of CARE. + + Attributes + ---------- + description : str + Description of CARE. + """ + + description: str = "CARE" # TODO \ No newline at end of file diff --git a/src/careamics/config/references/references.py b/src/careamics/config/references/references.py index 46a50efc9..28ad8f4f7 100644 --- a/src/careamics/config/references/references.py +++ b/src/careamics/config/references/references.py @@ -18,8 +18,21 @@ StructN2VRef = CiteEntry( text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." - '"Removing structured noise with self-supervised blind-spot ' - 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' + '\"Removing structured noise with self-supervised blind-spot ' + 'networks\". In 2020 IEEE 17th International Symposium on Biomedical ' "Imaging (ISBI) (pp. 159-163).", doi="10.1109/isbi45749.2020.9098336", ) + +N2NRef = CiteEntry( + text='Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., ' + 'Aittala, M. and Aila, T., 2018. \"Noise2Noise: Learning image restoration ' + 'without clean data\". arXiv preprint arXiv:1803.04189.', + doi="10.48550/arXiv.1803.04189", +) + +CARERef = CiteEntry( + text='Weigert, Martin, et al. \"Content-aware image restoration: pushing the ' + 'limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.', + doi="10.1038/s41592-018-0216-7", +) \ No newline at end of file diff --git a/src/careamics/model_io/__init__.py b/src/careamics/model_io/__init__.py index 32af8bbc7..9a8f0948a 100644 --- a/src/careamics/model_io/__init__.py +++ b/src/careamics/model_io/__init__.py @@ -1,7 +1,8 @@ """Model I/O utilities.""" -__all__ = ["load_pretrained"] +__all__ = ["load_pretrained", "export_to_bmz"] from .model_io_utils import load_pretrained +from .bmz_export import export_to_bmz diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index 88ffa4684..3863d6d7b 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -33,8 +33,8 @@ def _create_axes( array: np.ndarray, data_config: DataModel, - is_input: bool = True, channel_names: Optional[List[str]] = None, + is_input: bool = True, ) -> List[AxisBase]: """Create axes description. @@ -46,10 +46,10 @@ def _create_axes( Array. config : DataModel CAREamics data configuration - is_input : bool, optional - Whether the axes are input axes, by default True channel_names : Optional[List[str]], optional Channel names, by default None + is_input : bool, optional + Whether the axes are input axes, by default True Returns ------- @@ -61,6 +61,9 @@ def _create_axes( ValueError If channel names are not provided when channel axis is present """ + # axes have to be SC(Z)YX + spatial_axes = data_config.axes.replace("S", "").replace("C", "") + # batch is always present axes_model = [BatchAxis()] @@ -79,7 +82,7 @@ def _create_axes( axes_model.append(ChannelAxis(channel_names=[Identifier("channel")])) # spatial axes - for ind, axes in enumerate(data_config.axes): + for ind, axes in enumerate(spatial_axes): if axes in ["X", "Y", "Z"]: if is_input: axes_model.append( @@ -98,7 +101,8 @@ def _create_inputs_ouputs( output_array: np.ndarray, data_config: DataModel, input_path: Union[Path, str], - output_path: Union[Path, str], + output_path: Union[Path, str], + channel_names: Optional[List[str]] = None, ) -> Tuple[InputTensorDescr, OutputTensorDescr]: """Create input and output tensor description. @@ -118,8 +122,8 @@ def _create_inputs_ouputs( Tuple[InputTensorDescr, OutputTensorDescr] Input and output tensor descriptions """ - input_axes = _create_axes(input_array, data_config) - output_axes = _create_axes(output_array, data_config, is_input=False) + input_axes = _create_axes(input_array, data_config, channel_names) + output_axes = _create_axes(output_array, data_config, channel_names, False) # mean and std mean = data_config.mean @@ -172,8 +176,8 @@ def create_model_description( careamics_version: str, config_path: Union[Path, str], env_path: Union[Path, str], + channel_names: Optional[List[str]] = None, data_description: Optional[str] = None, - custom_description: Optional[str] = None, ) -> ModelDescr: """Create model description. @@ -199,10 +203,10 @@ def create_model_description( Path to model configuration. env_path : Union[Path, str] Path to environment file. + channel_names : Optional[List[str]], optional + Channel names, by default None. data_description : Optional[str], optional - Description of the data, by default None - custom_description : Optional[str], optional - Description of the custom algorithm, by default None + Description of the data, by default None. Returns ------- @@ -214,7 +218,6 @@ def create_model_description( config, careamics_version=careamics_version, data_description=data_description, - custom_description=custom_description, ) # inputs, outputs @@ -224,6 +227,7 @@ def create_model_description( data_config=config.data_config, input_path=inputs, output_path=outputs, + channel_names=channel_names, ) # weights description diff --git a/src/careamics/model_io/bioimage/readme_factory.py b/src/careamics/model_io/bioimage/readme_factory.py index bf988e28b..6bdd4f883 100644 --- a/src/careamics/model_io/bioimage/readme_factory.py +++ b/src/careamics/model_io/bioimage/readme_factory.py @@ -28,26 +28,20 @@ def readme_factory( config: Configuration, careamics_version: str, data_description: Optional[str] = None, - custom_description: Optional[str] = None, ) -> Path: """Create a README file for the model. `data_description` can be used to add more information about the content of the data the model was trained on. - `custom_description` can be used to add a custom description of the algorithm, only - used when the algorithm is set to `custom` in the configuration. - Parameters ---------- config : Configuration - CAREamics configuration + CAREamics configuration. careamics_version : str - CAREamics version + CAREamics version. data_description : Optional[str], optional - Description of the data, by default None - custom_description : Optional[str], optional - Description of custom algorithm, by default None + Description of the data, by default None. Returns ------- @@ -71,13 +65,7 @@ def readme_factory( # algorithm description description.append("Algorithm description:\n\n") - if ( - algorithm.algorithm == SupportedAlgorithm.CUSTOM - and custom_description is not None - ): - description.append(custom_description) - else: - description.append(config.get_algorithm_description()) + description.append(config.get_algorithm_description()) description.append("\n\n") # algorithm details diff --git a/src/careamics/model_io/bmz_export.py b/src/careamics/model_io/bmz_export.py new file mode 100644 index 000000000..6f8df3d9d --- /dev/null +++ b/src/careamics/model_io/bmz_export.py @@ -0,0 +1,140 @@ +"""Function to export to the BioImage Model Zoo format.""" +from pathlib import Path +from typing import List, Optional, Union +import tempfile + +import numpy as np +import pkg_resources +from bioimageio.core import test_model +from bioimageio.spec import ValidationSummary, save_bioimageio_package +from torch import __version__ + +from careamics.config import Configuration, save_configuration +from careamics.config.support import SupportedArchitecture +from careamics.lightning_module import CAREamicsKiln + +from .bioimage import create_model_description +from .model_io_utils import export_state_dict + + +# TODO break down in subfunctions +def export_to_bmz( + model: CAREamicsKiln, + config: Configuration, + path: Union[Path, str], + name: str, + general_description: str, + authors: List[dict], + input_array: np.ndarray, + output_array: np.ndarray, + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, +) -> None: + """Export the model to BioImage Model Zoo format. + + Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. + config : Configuration + Model configuration. + path : Union[Path, str] + Path to the output file. + name : str + Model name. + general_description : str + General description of the model. + authors : List[dict] + Authors of the model. + input_array : np.ndarray + Input array. + output_array : np.ndarray + Output array. + channel_names : Optional[List[str]], optional + Channel names, by default None + data_description : Optional[str], optional + Description of the data, by default None + + Raises + ------ + ValueError + If the model is a Custom model. + """ + path = Path(path) + + # method is not compatible with Custom models + if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM: + raise ValueError( + "Exporting Custom models to BioImage Model Zoo format is not supported." + ) + + # make sure that input and output arrays have the same shape + assert input_array.shape == output_array.shape, \ + f"Input ({input_array.shape}) and output ({output_array.shape}) arrays " \ + f"have different shapes" + + # make sure it has the correct suffix + if path.suffix not in ".zip": + path = path.with_suffix(".zip") + + # versions + pytorch_version = __version__ + careamics_version = pkg_resources.get_distribution("careamics").version + + # save files in temporary folder + with tempfile.TemporaryDirectory() as tmpdirname: + temp_path = Path(tmpdirname) + + # create environment file + # TODO move in bioimage module + env_path = temp_path / "environment.yml" + env_path.write_text( + f"name: careamics\n" + f"dependencies:\n" + f" - python=3.8\n" + f" - pytorch={pytorch_version}\n" + f" - torchvision={pytorch_version}\n" + f" - pip\n" + f" - pip:\n" + f" - git+https://github.com/CAREamics/careamics.git@dl4mia\n" + ) + # TODO from pip with package version + + # export input and ouputs + inputs = temp_path / "inputs.npy" + np.save(inputs, input_array) + outputs = temp_path / "outputs.npy" + np.save(outputs, output_array) + + # export configuration + config_path = save_configuration(config, temp_path) + + # export model state dictionary + weight_path = export_state_dict(model, temp_path / "weights.pth") + + # create model description + model_description = create_model_description( + config=config, + name=name, + general_description=general_description, + authors=authors, + inputs=inputs, + outputs=outputs, + weights_path=weight_path, + torch_version=pytorch_version, + careamics_version=careamics_version, + config_path=config_path, + env_path=env_path, + channel_names=channel_names, + data_description=data_description, + ) + + # test model description + summary: ValidationSummary = test_model(model_description) + if summary.status == "failed": + raise ValueError(f"Model description test failed: {summary}") + + # save bmz model + save_bioimageio_package(model_description, output_path=path) diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py index 88f6fd4f7..ab1bc61e3 100644 --- a/src/careamics/model_io/model_io_utils.py +++ b/src/careamics/model_io/model_io_utils.py @@ -1,18 +1,12 @@ """Utility functions to load pretrained models.""" from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Tuple, Union -import pkg_resources -from bioimageio.core import test_model -from bioimageio.spec import ValidationSummary, save_bioimageio_package from torch import __version__, load, save -from careamics.config import Configuration, save_configuration -from careamics.config.support import SupportedArchitecture +from careamics.config import Configuration from careamics.lightning_module import CAREamicsKiln -from careamics.utils import check_path_exists, get_careamics_home - -from .bioimage import create_model_description +from careamics.utils import check_path_exists def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: @@ -130,7 +124,7 @@ def _load_from_bmz( # extract model and call _load_from_torch_dict -def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: +def export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: """ Export the model state dictionary to a file. @@ -155,104 +149,4 @@ def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: # save model state dictionary save(model.model.state_dict(), path) - return path - - -def export_bmz( - model: CAREamicsKiln, - config: Configuration, - path: Union[Path, str], - name: str, - general_description: str, - authors: List[dict], - inputs: Union[Path, str], - outputs: Union[Path, str], - data_description: Optional[str] = None, - custom_description: Optional[str] = None, -) -> None: - """ - Export the model to BioImage Model Zoo format. - - Parameters - ---------- - model : CAREamicsKiln - CAREamics model to export. - config : Configuration - Model configuration. - path : Union[Path, str] - Path to the output file. - name : str - Model name. - general_description : str - General description of the model. - authors : List[dict] - Authors of the model. - inputs : Union[Path, str] - Path to input .npy file. - outputs : Union[Path, str] - Path to output .npy file. - data_description : Optional[str], optional - Description of the data, by default None - custom_description : Optional[str], optional - Description of the custom algorithm, by default None - """ - path = Path(path) - - # method is not compatible with Custom models - if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM: - raise ValueError( - "Exporting Custom models to BioImage Model Zoo format is not supported." - ) - - # make sure it has the correct suffix - if path.suffix not in ".zip": - path = path.with_suffix(".zip") - - # versions - pytorch_version = __version__ - careamics_version = pkg_resources.get_distribution("careamics").version - - # create environment file - env_path = get_careamics_home() / "environment.yml" - env_path.write_text( - f"name: careamics\n" - f"dependencies:\n" - f" - python=3.8\n" - f" - pytorch={pytorch_version}\n" - f" - torchvision={pytorch_version}\n" - f" - pip\n" - f" - pip:\n" - f" - git+https://github.com/CAREamics/careamics.git@dl4mia\n" - ) - # TODO from pip with package version - - # export configuration - config_path = save_configuration(config, get_careamics_home()) - - # export model state dictionary - weight_path = _export_state_dict(model, get_careamics_home() / "weights.pth") - - # create model description - model_description = create_model_description( - config=config, - name=name, - general_description=general_description, - authors=authors, - inputs=inputs, - outputs=outputs, - weights_path=weight_path, - torch_version=pytorch_version, - careamics_version=careamics_version, - config_path=config_path, - env_path=env_path, - data_description=data_description, - custom_description=custom_description, - ) - - # test model description - summary: ValidationSummary = test_model(model_description) - if summary.status == "failed": - raise ValueError(f"Model description test failed: {summary}") - - # save bmz model - save_bioimageio_package(model_description, output_path=path) + return path \ No newline at end of file diff --git a/tests/model_io/test_model_io_utils.py b/tests/model_io/test_export_bmz.py similarity index 66% rename from tests/model_io/test_model_io_utils.py rename to tests/model_io/test_export_bmz.py index 36cc575ae..6eb3850a6 100644 --- a/tests/model_io/test_model_io_utils.py +++ b/tests/model_io/test_export_bmz.py @@ -1,7 +1,7 @@ import numpy as np from careamics import CAREamist -from careamics.model_io.model_io_utils import export_bmz +from careamics.model_io import export_to_bmz def test_export_bmz(tmp_path, pre_trained): @@ -14,22 +14,15 @@ def test_export_bmz(tmp_path, pre_trained): # predict (no tiling and no tta) predicted = careamist.predict(train_array, tta_transforms=False) - # save images - train_path = tmp_path / "train.npy" - np.save(train_path, train_array[np.newaxis, np.newaxis, ...]) - - predicted_path = tmp_path / "predicted.npy" - np.save(tmp_path / "predicted.npy", predicted) - # export to BioImage Model Zoo - export_bmz( + export_to_bmz( model=careamist.model, config=careamist.cfg, path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", authors=[{"name": "Amod", "affiliation": "El"}], - inputs=train_path, - outputs=predicted_path, + input_array=train_array[np.newaxis, np.newaxis, ...], + output_array=predicted, ) assert (tmp_path / "model.zip").exists() diff --git a/tests/test_careamist.py b/tests/test_careamist.py index a180e976c..bbbc60a5c 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -7,7 +7,8 @@ from careamics import CAREamist, Configuration, save_configuration from careamics.config.support import SupportedAlgorithm, SupportedData - +# TODO test 3D and channels + def test_no_parameters(): """Test that CAREamics cannot be instantiated without parameters.""" with pytest.raises(TypeError): @@ -70,7 +71,7 @@ def test_train_error_target_unsupervised_algorithm(tmp_path, minimum_configurati def test_train_single_array_no_val(tmp_path, minimum_configuration): """Test that CAREamics can be trained with arrays.""" # training data - train_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) # create configuration config = Configuration(**minimum_configuration) @@ -89,12 +90,21 @@ def test_train_single_array_no_val(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_train_array(tmp_path, minimum_configuration): - """Test that CAREamics can be trained with arrays.""" + """Test that CAREamics can be trained on arrays.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) # create configuration config = Configuration(**minimum_configuration) @@ -113,11 +123,87 @@ def test_train_array(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array_channel(tmp_path, minimum_configuration): + """Test that CAREamics can be trained on arrays with channels.""" + # training data + train_array = np.random.rand(32, 32, 3) + val_array = np.random.rand(32, 32, 3) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YXC" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + channel_names=["red", "green", "blue"] + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array_3d(tmp_path, minimum_configuration): + """Test that CAREamics can be trained on 3D arrays.""" + # training data + train_array = np.random.rand(8, 32, 32) + val_array = np.random.rand(8, 32, 32) + + # create configuration + minimum_configuration["data_config"]["axes"] = "ZYX" + minimum_configuration["data_config"]["patch_size"] = (8, 16, 16) + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_train_tiff_files_in_memory_no_val(tmp_path, minimum_configuration): """Test that CAREamics can be trained with tiff files in memory.""" # training data - train_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) # save files train_file = tmp_path / "train.tiff" @@ -140,12 +226,21 @@ def test_train_tiff_files_in_memory_no_val(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_train_tiff_files_in_memory(tmp_path, minimum_configuration): """Test that CAREamics can be trained with tiff files in memory.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) # save files train_file = tmp_path / "train.tiff" @@ -171,14 +266,23 @@ def test_train_tiff_files_in_memory(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_train_tiff_files(tmp_path, minimum_configuration): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. """ # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) # save files train_file = tmp_path / "train.tiff" @@ -204,14 +308,23 @@ def test_train_tiff_files(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_train_array_supervised(tmp_path, supervised_configuration): """Test that CAREamics can be trained with arrays.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) - train_target = np.ones((32, 32)) - val_target = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) # create configuration config = Configuration(**supervised_configuration) @@ -235,14 +348,23 @@ def test_train_array_supervised(tmp_path, supervised_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_train_tiff_files_in_memory_supervised(tmp_path, supervised_configuration): """Test that CAREamics can be trained with tiff files in memory.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) - train_target = np.ones((32, 32)) - val_target = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) # save files images = tmp_path / "images" @@ -283,16 +405,25 @@ def test_train_tiff_files_in_memory_supervised(tmp_path, supervised_configuratio # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_train_tiff_files_supervised(tmp_path, supervised_configuration): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. """ # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) - train_target = np.ones((32, 32)) - val_target = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) # save files images = tmp_path / "images" @@ -334,6 +465,15 @@ def test_train_tiff_files_supervised(tmp_path, supervised_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_on_array_tiled(tmp_path, minimum_configuration, batch_size): @@ -362,8 +502,17 @@ def test_predict_on_array_tiled(tmp_path, minimum_configuration, batch_size): assert predicted.squeeze().shape == train_array.shape + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + -def test_predict_array_no_tiling(tmp_path, minimum_configuration): +def test_predict_arrays_no_tiling(tmp_path, minimum_configuration): """Test that CAREamics can predict on arrays without tiling.""" # training data train_array = np.random.rand(4, 32, 32) @@ -387,6 +536,15 @@ def test_predict_array_no_tiling(tmp_path, minimum_configuration): assert predicted.squeeze().shape == train_array.shape + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + @pytest.mark.parametrize("batch_size", [1, 2]) def test_predict_path(tmp_path, minimum_configuration, batch_size): @@ -420,12 +578,21 @@ def test_predict_path(tmp_path, minimum_configuration, batch_size): # check that it predicted assert predicted.squeeze().shape == train_array.shape + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + def test_predict_pretrained(tmp_path, pre_trained): """Test that CAREamics can be instantiated with a pre-trained network and predict on an array.""" # training data - train_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) # instantiate CAREamist careamist = CAREamist(source=pre_trained, work_dir=tmp_path) @@ -437,3 +604,41 @@ def test_predict_pretrained(tmp_path, pre_trained): # check that it predicted assert predicted.squeeze().shape == train_array.shape + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + + +def test_expot_bmz_pretrained(tmp_path, pre_trained): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ.""" + # training data + train_array = np.random.rand(32, 32).astype(np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # export to BMZ (random array created) + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model.zip").exists() + + # pass an array + careamist.export_to_bmz( + path=tmp_path / "model2.zip", + name="TopModel", + input_array=train_array[np.newaxis, np.newaxis, ...], + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}] + ) + assert (tmp_path / "model2.zip").exists() \ No newline at end of file From 7e8da2eb2f1f26e8afffd12095c3e24991606ed5 Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Sat, 27 Apr 2024 16:33:00 +0200 Subject: [PATCH 13/14] Add BMZ import/export from CAREamist --- src/careamics/careamist.py | 119 ++++++++------ src/careamics/config/data_model.py | 21 ++- src/careamics/config/inference_model.py | 23 +-- .../config/validators/validator_utils.py | 18 +-- src/careamics/model_io/__init__.py | 2 +- src/careamics/model_io/bioimage/__init__.py | 10 +- .../{readme_factory.py => _readme_factory.py} | 9 +- .../model_io/bioimage/bioimage_utils.py | 48 ++++++ .../model_io/bioimage/model_description.py | 74 +++++++-- .../model_io/{bmz_export.py => bmz_io.py} | 139 ++++++++++++++--- src/careamics/model_io/model_io_utils.py | 124 ++++----------- src/careamics/utils/__init__.py | 3 - src/careamics/utils/metrics.py | 46 ------ src/careamics/utils/running_stats.py | 89 +++++------ tests/conftest.py | 30 ++++ tests/model_io/test_bmz_io.py | 66 ++++++++ tests/model_io/test_export_bmz.py | 28 ---- tests/test_careamist.py | 147 ++++++++++++------ tests/utils/test_metrics.py | 22 --- 19 files changed, 601 insertions(+), 417 deletions(-) rename src/careamics/model_io/bioimage/{readme_factory.py => _readme_factory.py} (93%) create mode 100644 src/careamics/model_io/bioimage/bioimage_utils.py rename src/careamics/model_io/{bmz_export.py => bmz_io.py} (52%) create mode 100644 tests/model_io/test_bmz_io.py delete mode 100644 tests/model_io/test_export_bmz.py diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 4df35d2cb..fe5f0e4b1 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -1,4 +1,4 @@ -"""Main class to train and predict with CAREamics models.""" +"""A class to train, predict and export models in CAREamics.""" from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload @@ -19,12 +19,12 @@ load_configuration, ) from careamics.config.inference_model import TRANSFORMS_UNION -from careamics.config.support import SupportedAlgorithm, SupportedLogger, SupportedData +from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger from careamics.lightning_datamodule import CAREamicsWood from careamics.lightning_module import CAREamicsKiln from careamics.lightning_prediction_datamodule import CAREamicsClay from careamics.lightning_prediction_loop import CAREamicsPredictionLoop -from careamics.model_io import load_pretrained, export_to_bmz +from careamics.model_io import export_to_bmz, load_pretrained from careamics.utils import check_path_exists, get_logger from .callbacks import HyperParametersCallback @@ -33,11 +33,21 @@ LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]] + # TODO napari callbacks # TODO: how to do AMP? How to continue training? class CAREamist: - """ - Main CAREamics class, allowing training and prediction using various algorithms. + """Main CAREamics class, allowing training and prediction using various algorithms. + + Parameters + ---------- + source : Union[Path, str, Configuration] + Path to a configuration file or a trained model. + work_dir : Optional[str], optional + Path to working directory in which to save checkpoints and logs, + by default None. + experiment_name : str, optional + Experiment name used for checkpoints, by default "CAREamics". Attributes ---------- @@ -51,20 +61,10 @@ class CAREamist: Experiment logger, "wandb" or "tensorboard". work_dir : Path Working directory. - - Parameters - ---------- - source : Union[Path, str, Configuration] - Path to a configuration file or a trained model. - work_dir : Optional[str], optional - Path to working directory in which to save checkpoints and logs, - by default None. - experiment_name : str, optional - Experiment name used for checkpoints, by default "CAREamics". - train_datamodule : Optional[CAREamicsWood], optional - Training datamodule, by default None. - pred_datamodule : Optional[CAREamicsClay], optional - Prediction datamodule, by default None. + train_datamodule : Optional[CAREamicsWood] + Training datamodule. + pred_datamodule : Optional[CAREamicsClay] + Prediction datamodule. """ @overload @@ -474,7 +474,7 @@ def _train_on_path( @overload def predict( # numpydoc ignore=GL08 - self, + self, source: CAREamicsClay, *, checkpoint: Optional[Literal["best", "last"]] = None, @@ -630,7 +630,7 @@ def predict( extension_filter=extension_filter, dataloader_params=dataloader_params, ) - + # record datamodule self.pred_datamodule = datamodule @@ -645,7 +645,7 @@ def predict( pred_data=source, dataloader_params=dataloader_params, ) - + # record datamodule self.pred_datamodule = datamodule @@ -660,15 +660,15 @@ def predict( ) def export_to_bmz( - self, - path: Union[Path, str], - name: str, - authors: List[dict], - input_array: Optional[np.ndarray] = None, - general_description: str = "", - channel_names: Optional[List[str]] = None, - data_description: Optional[str] = None, - ) -> None: + self, + path: Union[Path, str], + name: str, + authors: List[dict], + input_array: Optional[np.ndarray] = None, + general_description: str = "", + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, + ) -> None: """Export the model to the BioImage Model Zoo format. Input array must be of shape SC(Z)YX, with S and C singleton dimensions. @@ -694,36 +694,55 @@ def export_to_bmz( # generate images, priority is given to the prediction data module if self.pred_datamodule is not None: # unpack a batch, ignore masks or targets - input, *_ = next(iter(self.pred_datamodule.predict_dataloader())) + input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader())) # convert torch.Tensor to numpy - input_array = input.numpy() + input_patch = input_patch.numpy() elif self.train_datamodule is not None: - input, *_ = next(iter(self.train_datamodule.train_dataloader())) - input_array = input.numpy() + input_patch, *_ = next(iter(self.train_datamodule.train_dataloader())) + input_patch = input_patch.numpy() else: + if ( + self.cfg.data_config.mean is None + or self.cfg.data_config.std is None + ): + raise ValueError( + "Mean and std cannot be None in the configuration in order to" + "export to the BMZ format. Was the model trained?" + ) + # create a random input array - input_array = np.random.normal( - loc=self.cfg.data_config.mean, + input_patch = np.random.normal( + loc=self.cfg.data_config.mean, scale=self.cfg.data_config.std, - size=self.cfg.data_config.patch_size - ).astype(np.float32)[np.newaxis, np.newaxis, ...] # add S & C dimensions + size=self.cfg.data_config.patch_size, + ).astype(np.float32)[ + np.newaxis, np.newaxis, ... + ] # add S & C dimensions + else: + input_patch = input_array # if there is a batch dimension - if input_array.shape[0] > 1: - input_array = input_array[0:1, ...] # keep singleton dim + if input_patch.shape[0] > 1: + input_patch = input_patch[0:1, ...] # keep singleton dim # axes need to be without S axes = self.cfg.data_config.axes.replace("S", "") - + # predict output, remove extra dimensions for the purpose of the prediction - output_array = self.predict( - input_array.squeeze(), + output_patch = self.predict( + input_patch.squeeze(), data_type=SupportedData.ARRAY.value, axes=axes, - tta_transforms=False + tta_transforms=False, ) + if not isinstance(output_patch, np.ndarray): + raise ValueError( + f"Numpy array required for export to BioImage Model Zoo, got " + f"{type(output_patch)}." + ) + export_to_bmz( model=self.model, config=self.cfg, @@ -731,8 +750,8 @@ def export_to_bmz( name=name, general_description=general_description, authors=authors, - input_array=input_array, - output_array=output_array, - channel_names= channel_names, - data_description=data_description - ) \ No newline at end of file + input_array=input_patch, + output_array=output_patch, + channel_names=channel_names, + data_description=data_description, + ) diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index c1f7c9a20..53255d784 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -2,7 +2,7 @@ from __future__ import annotations from pprint import pformat -from typing import Any, List, Literal, Optional, Union, Tuple +from typing import Any, List, Literal, Optional, Tuple, Union from albumentations import Compose from pydantic import ( @@ -15,12 +15,12 @@ ) from typing_extensions import Annotated -from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 from .support import SupportedTransform from .transformations.n2v_manipulate_model import N2VManipulateModel from .transformations.nd_flip_model import NDFlipModel from .transformations.normalize_model import NormalizeModel from .transformations.xy_random_rotate90_model import XYRandomRotate90Model +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 TRANSFORMS_UNION = Annotated[ Union[ @@ -86,7 +86,9 @@ class DataModel(BaseModel): # Dataset configuration data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData - patch_size: Union[List[int], Tuple[int]] = Field(..., min_length=2, max_length=3) + patch_size: Union[List[int], Tuple[int, ...]] = Field( + ..., min_length=2, max_length=3 + ) batch_size: int = Field(default=1, ge=1, validate_default=True) axes: str @@ -116,7 +118,9 @@ class DataModel(BaseModel): @field_validator("patch_size") @classmethod - def all_elements_power_of_2_minimum_8(cls, patch_list: List[int]) -> List[int]: + def all_elements_power_of_2_minimum_8( + cls, patch_list: Union[List[int], Tuple[int, ...]] + ) -> Union[List[int], Tuple[int, ...]]: """ Validate patch size. @@ -124,12 +128,12 @@ def all_elements_power_of_2_minimum_8(cls, patch_list: List[int]) -> List[int]: Parameters ---------- - patch_list : List[int] + patch_list : Union[List[int], Tuple[int, ...]] Patch size. Returns ------- - List[int] + Union[List[int], Tuple[int, ...]] Validated patch size. Raises @@ -139,7 +143,10 @@ def all_elements_power_of_2_minimum_8(cls, patch_list: List[int]) -> List[int]: ValueError If the patch size is not a power of 2. """ - return patch_size_ge_than_8_power_of_2(patch_list) + patch_validated = patch_size_ge_than_8_power_of_2(patch_list) + assert patch_validated is not None, "Patch cannot be None." + + return patch_validated @field_validator("axes") @classmethod diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py index 967d4f556..4a3a0992f 100644 --- a/src/careamics/config/inference_model.py +++ b/src/careamics/config/inference_model.py @@ -1,3 +1,4 @@ +"""Pydantic model representing CAREamics prediction configuration.""" from __future__ import annotations from typing import Any, List, Literal, Optional, Tuple, Union @@ -5,9 +6,9 @@ from albumentations import Compose from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 from .support import SupportedTransform from .transformations.normalize_model import NormalizeModel +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 TRANSFORMS_UNION = Union[NormalizeModel] @@ -19,10 +20,10 @@ class InferenceModel(BaseModel): # Mandatory fields data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData - tile_size: Optional[Union[List[int], Tuple[int]]] = Field( + tile_size: Optional[Union[List[int], Tuple[int, ...]]] = Field( default=None, min_length=2, max_length=3 ) - tile_overlap: Optional[Union[List[int], Tuple[int]]] = Field( + tile_overlap: Optional[Union[List[int], Tuple[int, ...]]] = Field( default=None, min_length=2, max_length=3 ) @@ -48,7 +49,9 @@ class InferenceModel(BaseModel): @field_validator("tile_overlap") @classmethod - def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: + def all_elements_non_zero_even( + cls, patch_list: Optional[Union[List[int], Tuple[int, ...]]] + ) -> Optional[Union[List[int], Tuple[int, ...]]]: """ Validate patch size. @@ -56,12 +59,12 @@ def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: Parameters ---------- - patch_list : List[int] + patch_list : Optional[Union[List[int], Tuple[int, ...]]] Patch size. Returns ------- - List[int] + Optional[Union[List[int], Tuple[int, ...]]] Validated patch size. Raises @@ -82,10 +85,12 @@ def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: raise ValueError(f"Patch size must be even (got {dim}).") return patch_list - + @field_validator("tile_size") @classmethod - def tile_min_8_power_of_2(cls, tile_list: List[int]) -> List[int]: + def tile_min_8_power_of_2( + cls, tile_list: Optional[Union[List[int], Tuple[int, ...]]] + ) -> Optional[Union[List[int], Tuple[int, ...]]]: """ Validate that each entry is greater or equal than 8 and a power of 2. @@ -272,7 +277,7 @@ def _update(self, **kwargs: Any) -> None: Parameters ---------- - kwargs : Any + **kwargs : Any Key-value pairs of arguments to update. """ self.__dict__.update(kwargs) diff --git a/src/careamics/config/validators/validator_utils.py b/src/careamics/config/validators/validator_utils.py index a63726c45..cb93b475f 100644 --- a/src/careamics/config/validators/validator_utils.py +++ b/src/careamics/config/validators/validator_utils.py @@ -3,7 +3,7 @@ These functions are used to validate dimensions and axes of inputs. """ -from typing import Tuple, Optional, List, Union +from typing import List, Optional, Tuple, Union _AXES = "STCZYX" @@ -60,8 +60,8 @@ def check_axes_validity(axes: str) -> bool: def patch_size_ge_than_8_power_of_2( - patch_list: Optional[Union[List[int], Tuple[int]]] - ) -> Optional[Union[List[int], Tuple[int]]]: + patch_list: Optional[Union[List[int], Tuple[int, ...]]] +) -> Optional[Union[List[int], Tuple[int, ...]]]: """ Validate that each entry is greater or equal than 8 and a power of 2. @@ -69,12 +69,12 @@ def patch_size_ge_than_8_power_of_2( Parameters ---------- - patch_list : Optional[Union[List[int], Tuple[int]]] + patch_list : Optional[Union[List[int], Tuple[int, ...]]] Patch size. Returns ------- - Optional[Union[List[int], Tuple[int]]] + Optional[Union[List[int], Tuple[int, ...]]] Validated patch size. Raises @@ -87,14 +87,12 @@ def patch_size_ge_than_8_power_of_2( if patch_list is not None: for dim in patch_list: if dim < 8: - raise ValueError( - f"Patch size must be non-zero positive (got {dim})." - ) + raise ValueError(f"Patch size must be non-zero positive (got {dim}).") if (dim & (dim - 1)) != 0: raise ValueError( f"Patch size must be a power of 2 in all dimensions " f"(got {dim})." ) - - return patch_list \ No newline at end of file + + return patch_list diff --git a/src/careamics/model_io/__init__.py b/src/careamics/model_io/__init__.py index 9a8f0948a..0e99771f4 100644 --- a/src/careamics/model_io/__init__.py +++ b/src/careamics/model_io/__init__.py @@ -4,5 +4,5 @@ __all__ = ["load_pretrained", "export_to_bmz"] +from .bmz_io import export_to_bmz from .model_io_utils import load_pretrained -from .bmz_export import export_to_bmz diff --git a/src/careamics/model_io/bioimage/__init__.py b/src/careamics/model_io/bioimage/__init__.py index cff07fbcd..f312bc7eb 100644 --- a/src/careamics/model_io/bioimage/__init__.py +++ b/src/careamics/model_io/bioimage/__init__.py @@ -1,5 +1,11 @@ """Bioimage Model Zoo format functions.""" -__all__ = ["create_model_description"] +__all__ = [ + "create_model_description", + "extract_model_path", + "get_unzip_path", + "create_env_text", +] -from .model_description import create_model_description +from .bioimage_utils import create_env_text, get_unzip_path +from .model_description import create_model_description, extract_model_path diff --git a/src/careamics/model_io/bioimage/readme_factory.py b/src/careamics/model_io/bioimage/_readme_factory.py similarity index 93% rename from src/careamics/model_io/bioimage/readme_factory.py rename to src/careamics/model_io/bioimage/_readme_factory.py index 6bdd4f883..e823f3781 100644 --- a/src/careamics/model_io/bioimage/readme_factory.py +++ b/src/careamics/model_io/bioimage/_readme_factory.py @@ -1,10 +1,10 @@ +"""Functions used to create a README.md file for BMZ export.""" from pathlib import Path from typing import Optional import yaml from careamics.config import Configuration -from careamics.config.support import SupportedAlgorithm from careamics.utils import cwd, get_careamics_home @@ -14,12 +14,12 @@ def _yaml_block(yaml_str: str) -> str: Parameters ---------- yaml_str : str - YAML string + YAML string. Returns ------- str - Markdown code block with the YAML string + Markdown code block with the YAML string. """ return f"```yaml\n{yaml_str}\n```" @@ -46,13 +46,14 @@ def readme_factory( Returns ------- Path - Path to the README file + Path to the README file. """ algorithm = config.algorithm_config training = config.training_config data = config.data_config # create file + # TODO use tempfile as in the bmz_io module with cwd(get_careamics_home()): readme = Path("README.md") readme.touch() diff --git a/src/careamics/model_io/bioimage/bioimage_utils.py b/src/careamics/model_io/bioimage/bioimage_utils.py new file mode 100644 index 000000000..1ce28bfcd --- /dev/null +++ b/src/careamics/model_io/bioimage/bioimage_utils.py @@ -0,0 +1,48 @@ +"""Bioimage.io utils.""" +from pathlib import Path +from typing import Union + + +def get_unzip_path(zip_path: Union[Path, str]) -> Path: + """Generate unzipped folder path from the bioimage.io model path. + + Parameters + ---------- + zip_path : Path + Path to the bioimage.io model. + + Returns + ------- + Path + Path to the unzipped folder. + """ + zip_path = Path(zip_path) + + return zip_path.parent / (str(zip_path.name) + ".unzip") + + +def create_env_text(pytorch_version: str) -> str: + """Create environment text for the bioimage model. + + Parameters + ---------- + pytorch_version : str + Pytorch version. + + Returns + ------- + str + Environment text. + """ + env = ( + f"name: careamics\n" + f"dependencies:\n" + f" - python=3.8\n" + f" - pytorch={pytorch_version}\n" + f" - torchvision={pytorch_version}\n" + f" - pip\n" + f" - pip:\n" + f" - git+https://github.com/CAREamics/careamics.git@dl4mia\n" + ) + # TODO from pip with package version + return env diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index 3863d6d7b..1901b9550 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -1,3 +1,4 @@ +"""Module use to build BMZ model description.""" from pathlib import Path from typing import List, Optional, Tuple, Union @@ -27,7 +28,7 @@ from careamics.config import Configuration, DataModel -from .readme_factory import readme_factory +from ._readme_factory import readme_factory def _create_axes( @@ -44,22 +45,22 @@ def _create_axes( ---------- array : np.ndarray Array. - config : DataModel - CAREamics data configuration + data_config : DataModel + CAREamics data configuration. channel_names : Optional[List[str]], optional - Channel names, by default None + Channel names, by default None. is_input : bool, optional - Whether the axes are input axes, by default True + Whether the axes are input axes, by default True. Returns ------- List[AxisBase] - List of axes description + List of axes description. Raises ------ ValueError - If channel names are not provided when channel axis is present + If channel names are not provided when channel axis is present. """ # axes have to be SC(Z)YX spatial_axes = data_config.axes.replace("S", "").replace("C", "") @@ -101,7 +102,7 @@ def _create_inputs_ouputs( output_array: np.ndarray, data_config: DataModel, input_path: Union[Path, str], - output_path: Union[Path, str], + output_path: Union[Path, str], channel_names: Optional[List[str]] = None, ) -> Tuple[InputTensorDescr, OutputTensorDescr]: """Create input and output tensor description. @@ -110,22 +111,30 @@ def _create_inputs_ouputs( Parameters ---------- + input_array : np.ndarray + Input array. + output_array : np.ndarray + Output array. data_config : DataModel - CAREamics data configuration + CAREamics data configuration. input_path : Union[Path, str] - Path to input .npy file + Path to input .npy file. output_path : Union[Path, str] - Path to output .npy file + Path to output .npy file. + channel_names : Optional[List[str]], optional + Channel names, by default None. Returns ------- Tuple[InputTensorDescr, OutputTensorDescr] - Input and output tensor descriptions + Input and output tensor descriptions. """ input_axes = _create_axes(input_array, data_config, channel_names) output_axes = _create_axes(output_array, data_config, channel_names, False) # mean and std + assert data_config.mean is not None, "Mean cannot be None." + assert data_config.std is not None, "Std cannot be None." mean = data_config.mean std = data_config.std @@ -183,8 +192,8 @@ def create_model_description( Parameters ---------- - careamist : CAREamist - CAREamist instance. + config : Configuration + CAREamics configuration. name : str Name fo the model. general_description : str @@ -199,6 +208,8 @@ def create_model_description( Path to model weights. torch_version : str Pytorch version. + careamics_version : str + CAREamics version. config_path : Union[Path, str] Path to model configuration. env_path : Union[Path, str] @@ -264,15 +275,44 @@ def create_model_description( weights=weights_descr, attachments=[FileDescr(source=config_path)], cite=config.get_algorithm_citations(), - config={ # conversion from float32 to float64 creates small differences... + config={ # conversion from float32 to float64 creates small differences... "bioimageio": { "test_kwargs": { "pytorch_state_dict": { - "decimals": 2, # ...so we relax the constraints on the decimals + "decimals": 2, # ...so we relax the constraints on the decimals } } } - } + }, ) return model + + +def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]: + """Return the relative path to the weights and configuration files. + + Parameters + ---------- + model_desc : ModelDescr + Model description. + + Returns + ------- + Tuple[Path, Path] + Weights and configuration paths. + """ + weights_path = model_desc.weights.pytorch_state_dict.source.path + + if len(model_desc.attachments) == 1: + config_path = model_desc.attachments[0].source.path + else: + for file in model_desc.attachments: + if file.source.path.suffix == ".yml": + config_path = file.source.path + break + + if config_path is None: + raise ValueError("Configuration file not found.") + + return weights_path, config_path diff --git a/src/careamics/model_io/bmz_export.py b/src/careamics/model_io/bmz_io.py similarity index 52% rename from src/careamics/model_io/bmz_export.py rename to src/careamics/model_io/bmz_io.py index 6f8df3d9d..06e5d9440 100644 --- a/src/careamics/model_io/bmz_export.py +++ b/src/careamics/model_io/bmz_io.py @@ -1,20 +1,74 @@ """Function to export to the BioImage Model Zoo format.""" -from pathlib import Path -from typing import List, Optional, Union import tempfile +from pathlib import Path +from typing import List, Optional, Tuple, Union import numpy as np import pkg_resources -from bioimageio.core import test_model +from bioimageio.core import load_description, test_model from bioimageio.spec import ValidationSummary, save_bioimageio_package -from torch import __version__ +from torch import __version__, load, save -from careamics.config import Configuration, save_configuration +from careamics.config import Configuration, load_configuration, save_configuration from careamics.config.support import SupportedArchitecture from careamics.lightning_module import CAREamicsKiln -from .bioimage import create_model_description -from .model_io_utils import export_state_dict +from .bioimage import ( + create_env_text, + create_model_description, + extract_model_path, + get_unzip_path, +) + + +def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: + """ + Export the model state dictionary to a file. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. + path : Union[Path, str] + Path to the file where to save the model state dictionary. + + Returns + ------- + Path + Path to the saved model state dictionary. + """ + path = Path(path) + + # make sure it has the correct suffix + if path.suffix not in ".pth": + path = path.with_suffix(".pth") + + # save model state dictionary + # we save through the torch model itself to avoid the initial "model." in the + # layers naming, which is incompatible with the way the BMZ load torch state dicts + save(model.model.state_dict(), path) + + return path + + +def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None: + """ + Load a model from a state dictionary. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to be updated with the weights. + path : Union[Path, str] + Path to the model state dictionary. + """ + path = Path(path) + + # load model state dictionary + # same as in _export_state_dict, we load through the torch model to be compatible + # witht bioimageio.core expectations for a torch state dict + state_dict = load(path) + model.model.load_state_dict(state_dict) # TODO break down in subfunctions @@ -53,10 +107,10 @@ def export_to_bmz( output_array : np.ndarray Output array. channel_names : Optional[List[str]], optional - Channel names, by default None + Channel names, by default None. data_description : Optional[str], optional - Description of the data, by default None - + Description of the data, by default None. + Raises ------ ValueError @@ -71,9 +125,10 @@ def export_to_bmz( ) # make sure that input and output arrays have the same shape - assert input_array.shape == output_array.shape, \ - f"Input ({input_array.shape}) and output ({output_array.shape}) arrays " \ + assert input_array.shape == output_array.shape, ( + f"Input ({input_array.shape}) and output ({output_array.shape}) arrays " f"have different shapes" + ) # make sure it has the correct suffix if path.suffix not in ".zip": @@ -90,17 +145,7 @@ def export_to_bmz( # create environment file # TODO move in bioimage module env_path = temp_path / "environment.yml" - env_path.write_text( - f"name: careamics\n" - f"dependencies:\n" - f" - python=3.8\n" - f" - pytorch={pytorch_version}\n" - f" - torchvision={pytorch_version}\n" - f" - pip\n" - f" - pip:\n" - f" - git+https://github.com/CAREamics/careamics.git@dl4mia\n" - ) - # TODO from pip with package version + env_path.write_text(create_env_text(pytorch_version)) # export input and ouputs inputs = temp_path / "inputs.npy" @@ -112,7 +157,7 @@ def export_to_bmz( config_path = save_configuration(config, temp_path) # export model state dictionary - weight_path = export_state_dict(model, temp_path / "weights.pth") + weight_path = _export_state_dict(model, temp_path / "weights.pth") # create model description model_description = create_model_description( @@ -138,3 +183,49 @@ def export_to_bmz( # save bmz model save_bioimageio_package(model_description, output_path=path) + + +def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: + """Load a model from a BioImage Model Zoo archive. + + Parameters + ---------- + path : Union[Path, str] + Path to the BioImage Model Zoo archive. + + Returns + ------- + Tuple[CAREamicsKiln, Configuration] + CAREamics model and configuration. + + Raises + ------ + ValueError + If the path is not a zip file. + """ + path = Path(path) + + if path.suffix != ".zip": + raise ValueError(f"Path must be a bioimage.io zip file, got {path}.") + + # load description, this creates an unzipped folder next to the archive + model_desc = load_description(path) + + # extract relative paths + weights_path, config_path = extract_model_path(model_desc) + + # create folder path and absolute paths + unzip_path = get_unzip_path(path) + weights_path = unzip_path / weights_path + config_path = unzip_path / config_path + + # load configuration + config = load_configuration(config_path) + + # create careamics lightning module + model = CAREamicsKiln(algorithm_config=config.algorithm_config) + + # load model state dictionary + _load_state_dict(model, weights_path) + + return model, config diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py index ab1bc61e3..720ac49e0 100644 --- a/src/careamics/model_io/model_io_utils.py +++ b/src/careamics/model_io/model_io_utils.py @@ -1,11 +1,13 @@ """Utility functions to load pretrained models.""" + from pathlib import Path from typing import Tuple, Union -from torch import __version__, load, save +from torch import load from careamics.config import Configuration from careamics.lightning_module import CAREamicsKiln +from careamics.model_io.bmz_io import load_from_bmz from careamics.utils import check_path_exists @@ -13,7 +15,7 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio """ Load a pretrained model from a checkpoint or a BioImage Model Zoo model. - Expected formats are .ckpt, .zip, .pth or .pt files. + Expected formats are .ckpt or .zip files. Parameters ---------- @@ -22,8 +24,8 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio Returns ------- - CAREamicsKiln - CAREamics model loaded from the checkpoint. + Tuple[CAREamicsKiln, Configuration] + Tuple of CAREamics model and its configuration. Raises ------ @@ -33,120 +35,46 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio path = check_path_exists(path) if path.suffix == ".ckpt": - # load checkpoint - checkpoint: dict = load(path) - - # attempt to load algorithm parameters - try: - cfg_dict = checkpoint["hyper_parameters"] - except KeyError as e: - raise ValueError( - f"Invalid checkpoint file. No `hyper_parameters` found in the " - f"checkpoint: {checkpoint.keys()}" - ) from e - - model = _load_from_checkpoint(path) - - return model, Configuration(**cfg_dict) - + return _load_checkpoint(path) elif path.suffix == ".zip": - return _load_from_bmz(path) + return load_from_bmz(path) else: raise ValueError( - f"Invalid model format. Expected .ckpt or .zip, " f"got {path.suffix}." + f"Invalid model format. Expected .ckpt or .zip, got {path.suffix}." ) -def _load_from_checkpoint(path: Union[Path, str]) -> CAREamicsKiln: +def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: """ - Load a model from a checkpoint. + Load a model from a checkpoint and return both model and configuration. Parameters ---------- path : Union[Path, str] Path to the checkpoint. - Returns - ------- - CAREamicsKiln - CAREamics model loaded from the checkpoint. - """ - return CAREamicsKiln.load_from_checkpoint(path) - - -def _load_from_torch_dict( - path: Union[Path, str] -) -> Tuple[CAREamicsKiln, Configuration]: - """ - Load a model from a PyTorch dictionary. - - Parameters - ---------- - path : Union[Path, str] - Path to the PyTorch dictionary. - Returns ------- Tuple[CAREamicsKiln, Configuration] - CAREamics model and Configuration loaded from the BioImage Model Zoo. - """ - raise NotImplementedError( - "Loading a model from a PyTorch dictionary is not implemented yet." - ) - - -def _load_from_bmz( - path: Union[Path, str], -) -> Tuple[CAREamicsKiln, Configuration]: - """ - Load a model from BioImage Model Zoo. - - Parameters - ---------- - path : Union[Path, str] - Path to the BioImage Model Zoo model. - - Returns - ------- - Tuple[CAREamicsKiln, Configuration] - CAREamics model and Configuration loaded from the BioImage Model Zoo. + Tuple of CAREamics model and its configuration. Raises ------ - NotImplementedError - If the method is not implemented yet. - """ - raise NotImplementedError( - "Loading a model from BioImage Model Zoo is not implemented yet." - ) - - # load BMZ archive - # extract model and call _load_from_torch_dict - - -def export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: - """ - Export the model state dictionary to a file. - - Parameters - ---------- - model : CAREamicsKiln - CAREamics model to export. - path : Union[Path, str] - Path to the file where to save the model state dictionary. - - Returns - ------- - Path - Path to the saved model state dictionary. + ValueError + If the checkpoint file does not contain hyper parameters (configuration). """ - path = Path(path) + # load checkpoint + checkpoint: dict = load(path) - # make sure it has the correct suffix - if path.suffix not in ".pth": - path = path.with_suffix(".pth") + # attempt to load configuration + try: + cfg_dict = checkpoint["hyper_parameters"] + except KeyError as e: + raise ValueError( + f"Invalid checkpoint file. No `hyper_parameters` found in the " + f"checkpoint: {checkpoint.keys()}" + ) from e - # save model state dictionary - save(model.model.state_dict(), path) + model = CAREamicsKiln.load_from_checkpoint(path) - return path \ No newline at end of file + return model, Configuration(**cfg_dict) diff --git a/src/careamics/utils/__init__.py b/src/careamics/utils/__init__.py index ef2929113..3207fe2b2 100644 --- a/src/careamics/utils/__init__.py +++ b/src/careamics/utils/__init__.py @@ -3,13 +3,11 @@ __all__ = [ "cwd", - "MetricTracker", "get_ram_size", "check_path_exists", "BaseEnum", "get_logger", "get_careamics_home", - "RunningStats", ] @@ -18,4 +16,3 @@ from .logging import get_logger from .path_utils import check_path_exists from .ram import get_ram_size -from .running_stats import RunningStats diff --git a/src/careamics/utils/metrics.py b/src/careamics/utils/metrics.py index b961fc2b5..23283c3a4 100644 --- a/src/careamics/utils/metrics.py +++ b/src/careamics/utils/metrics.py @@ -112,49 +112,3 @@ def scale_invariant_psnr( range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt) gt_ = _zero_mean(gt) / np.std(gt) return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter) - - -class MetricTracker: - """ - Metric tracker class. - - This class is used to track values, sum, count and average of a metric over time. - - Attributes - ---------- - val : int - Last value of the metric. - avg : torch.Tensor.float - Average value of the metric. - sum : int - Sum of the metric values (times number of values). - count : int - Number of values. - """ - - def __init__(self) -> None: - """Constructor.""" - self.reset() - - def reset(self) -> None: - """Reset the metric tracker state.""" - self.val = 0.0 - self.avg: torch.Tensor.float = 0.0 - self.sum = 0.0 - self.count = 0.0 - - def update(self, value: int, n: int = 1) -> None: - """ - Update the metric tracker state. - - Parameters - ---------- - value : int - Value to update the metric tracker with. - n : int - Number of values, equals to batch size. - """ - self.val = value - self.sum += value * n - self.count += n - self.avg = self.sum / self.count diff --git a/src/careamics/utils/running_stats.py b/src/careamics/utils/running_stats.py index 053f44d50..1268d3e43 100644 --- a/src/careamics/utils/running_stats.py +++ b/src/careamics/utils/running_stats.py @@ -1,46 +1,43 @@ -""" -Normalization submodule. - -These methods are used to normalize and denormalize images. -""" -from multiprocessing import Value -from typing import Tuple - -import numpy as np - - -class RunningStats: - """Calculates running mean and std.""" - - def __init__(self) -> None: - self.reset() - - def reset(self) -> None: - """Reset the running stats.""" - self.avg_mean = Value("d", 0) - self.avg_std = Value("d", 0) - self.m2 = Value("d", 0) - self.count = Value("i", 0) - - def init(self, mean: float, std: float) -> None: - """Initialize running stats.""" - with self.avg_mean.get_lock(): - self.avg_mean.value += mean - with self.avg_std.get_lock(): - self.avg_std.value = std - - def compute_std(self) -> Tuple[float, float]: - """Compute std.""" - if self.count.value >= 2: - self.avg_std.value = np.sqrt(self.m2.value / self.count.value) - - def update(self, value: float) -> None: - """Update running stats.""" - with self.count.get_lock(): - self.count.value += 1 - delta = value - self.avg_mean.value - with self.avg_mean.get_lock(): - self.avg_mean.value += delta / self.count.value - delta2 = value - self.avg_mean.value - with self.m2.get_lock(): - self.m2.value += delta * delta2 +"""Running stats submodule, used in the Zarr dataset.""" + +# from multiprocessing import Value +# from typing import Tuple + +# import numpy as np + + +# class RunningStats: +# """Calculates running mean and std.""" + +# def __init__(self) -> None: +# self.reset() + +# def reset(self) -> None: +# """Reset the running stats.""" +# self.avg_mean = Value("d", 0) +# self.avg_std = Value("d", 0) +# self.m2 = Value("d", 0) +# self.count = Value("i", 0) + +# def init(self, mean: float, std: float) -> None: +# """Initialize running stats.""" +# with self.avg_mean.get_lock(): +# self.avg_mean.value += mean +# with self.avg_std.get_lock(): +# self.avg_std.value = std + +# def compute_std(self) -> Tuple[float, float]: +# """Compute std.""" +# if self.count.value >= 2: +# self.avg_std.value = np.sqrt(self.m2.value / self.count.value) + +# def update(self, value: float) -> None: +# """Update running stats.""" +# with self.count.get_lock(): +# self.count.value += 1 +# delta = value - self.avg_mean.value +# with self.avg_mean.get_lock(): +# self.avg_mean.value += delta / self.count.value +# delta2 = value - self.avg_mean.value +# with self.m2.get_lock(): +# self.m2.value += delta * delta2 diff --git a/tests/conftest.py b/tests/conftest.py index 4bc99fd46..3e41040ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from careamics import CAREamist, Configuration from careamics.config.support import SupportedData +from careamics.model_io import export_to_bmz # TODO add details about where each of these fixture is used (e.g. smoke test) @@ -267,3 +268,32 @@ def pre_trained(tmp_path, minimum_configuration): assert pre_trained_path.exists() return pre_trained_path + + +@pytest.fixture +def pre_trained_bmz(tmp_path, pre_trained) -> Path: + """Fixture to create a BMZ model.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # export to BioImage Model Zoo + path = tmp_path / "model.zip" + export_to_bmz( + model=careamist.model, + config=careamist.cfg, + path=path, + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + input_array=train_array[np.newaxis, np.newaxis, ...], + output_array=predicted, + ) + assert path.exists() + + return path diff --git a/tests/model_io/test_bmz_io.py b/tests/model_io/test_bmz_io.py new file mode 100644 index 000000000..a031aa0d2 --- /dev/null +++ b/tests/model_io/test_bmz_io.py @@ -0,0 +1,66 @@ +import numpy as np +from torch import Tensor + +from careamics import CAREamist +from careamics.model_io import export_to_bmz, load_pretrained +from careamics.model_io.bmz_io import _export_state_dict, _load_state_dict + + +def test_state_dict_io(tmp_path, pre_trained): + """Test exporting and loading a state dict.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + path = tmp_path / "model.pth" + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # save model + _export_state_dict(careamist.model, path) + assert path.exists() + + # load model + _load_state_dict(careamist.model, path) + + # predict (no tiling and no tta) + predicted_loaded = careamist.predict(train_array, tta_transforms=False) + assert (predicted_loaded == predicted).all() + + +def test_bmz_io(tmp_path, pre_trained): + """Test exporting and loading to the BMZ.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # export to BioImage Model Zoo + path = tmp_path / "model.zip" + export_to_bmz( + model=careamist.model, + config=careamist.cfg, + path=path, + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + input_array=train_array[np.newaxis, np.newaxis, ...], + output_array=predicted, + ) + assert path.exists() + + # load model + model, config = load_pretrained(path) + assert config == careamist.cfg + + # compare predictions + torch_array = Tensor(train_array[np.newaxis, np.newaxis, ...]) + predicted = careamist.model.forward(torch_array).detach().numpy().squeeze() + predicted_loaded = model.forward(torch_array).detach().numpy().squeeze() + assert (predicted_loaded == predicted).all() diff --git a/tests/model_io/test_export_bmz.py b/tests/model_io/test_export_bmz.py deleted file mode 100644 index 6eb3850a6..000000000 --- a/tests/model_io/test_export_bmz.py +++ /dev/null @@ -1,28 +0,0 @@ -import numpy as np - -from careamics import CAREamist -from careamics.model_io import export_to_bmz - - -def test_export_bmz(tmp_path, pre_trained): - # training data - train_array = np.ones((32, 32), dtype=np.float32) - - # instantiate CAREamist - careamist = CAREamist(source=pre_trained, work_dir=tmp_path) - - # predict (no tiling and no tta) - predicted = careamist.predict(train_array, tta_transforms=False) - - # export to BioImage Model Zoo - export_to_bmz( - model=careamist.model, - config=careamist.cfg, - path=tmp_path / "model.zip", - name="TopModel", - general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}], - input_array=train_array[np.newaxis, np.newaxis, ...], - output_array=predicted, - ) - assert (tmp_path / "model.zip").exists() diff --git a/tests/test_careamist.py b/tests/test_careamist.py index bbbc60a5c..e53e7e0ee 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -8,14 +8,15 @@ from careamics.config.support import SupportedAlgorithm, SupportedData # TODO test 3D and channels - + + def test_no_parameters(): """Test that CAREamics cannot be instantiated without parameters.""" with pytest.raises(TypeError): CAREamist() -def test_minimum_configuration_via_object(tmp_path, minimum_configuration): +def test_minimum_configuration_via_object(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be instantiated with a minimum configuration object.""" # create configuration config = Configuration(**minimum_configuration) @@ -24,7 +25,7 @@ def test_minimum_configuration_via_object(tmp_path, minimum_configuration): CAREamist(source=config, work_dir=tmp_path) -def test_minimum_configuration_via_path(tmp_path, minimum_configuration): +def test_minimum_configuration_via_path(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be instantiated with a path to a minimum configuration. """ @@ -36,7 +37,9 @@ def test_minimum_configuration_via_path(tmp_path, minimum_configuration): CAREamist(source=path_to_config) -def test_train_error_target_unsupervised_algorithm(tmp_path, minimum_configuration): +def test_train_error_target_unsupervised_algorithm( + tmp_path: Path, minimum_configuration: dict +): """Test that an error is raised when a target is provided for N2V.""" # create configuration config = Configuration(**minimum_configuration) @@ -68,7 +71,7 @@ def test_train_error_target_unsupervised_algorithm(tmp_path, minimum_configurati ) -def test_train_single_array_no_val(tmp_path, minimum_configuration): +def test_train_single_array_no_val(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with arrays.""" # training data train_array = np.random.rand(32, 32) @@ -95,12 +98,12 @@ def test_train_single_array_no_val(tmp_path, minimum_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_array(tmp_path, minimum_configuration): +def test_train_array(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained on arrays.""" # training data train_array = np.random.rand(32, 32) @@ -128,12 +131,12 @@ def test_train_array(tmp_path, minimum_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_array_channel(tmp_path, minimum_configuration): +def test_train_array_channel(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained on arrays with channels.""" # training data train_array = np.random.rand(32, 32, 3) @@ -162,12 +165,12 @@ def test_train_array_channel(tmp_path, minimum_configuration): name="TopModel", general_description="A model that just walked in.", authors=[{"name": "Amod", "affiliation": "El"}], - channel_names=["red", "green", "blue"] + channel_names=["red", "green", "blue"], ) assert (tmp_path / "model.zip").exists() -def test_train_array_3d(tmp_path, minimum_configuration): +def test_train_array_3d(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained on 3D arrays.""" # training data train_array = np.random.rand(8, 32, 32) @@ -195,12 +198,12 @@ def test_train_array_3d(tmp_path, minimum_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_tiff_files_in_memory_no_val(tmp_path, minimum_configuration): +def test_train_tiff_files_in_memory_no_val(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with tiff files in memory.""" # training data train_array = np.random.rand(32, 32) @@ -231,12 +234,12 @@ def test_train_tiff_files_in_memory_no_val(tmp_path, minimum_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_tiff_files_in_memory(tmp_path, minimum_configuration): +def test_train_tiff_files_in_memory(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with tiff files in memory.""" # training data train_array = np.random.rand(32, 32) @@ -271,12 +274,12 @@ def test_train_tiff_files_in_memory(tmp_path, minimum_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_tiff_files(tmp_path, minimum_configuration): +def test_train_tiff_files(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. """ @@ -313,12 +316,12 @@ def test_train_tiff_files(tmp_path, minimum_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_array_supervised(tmp_path, supervised_configuration): +def test_train_array_supervised(tmp_path: Path, supervised_configuration: dict): """Test that CAREamics can be trained with arrays.""" # training data train_array = np.random.rand(32, 32) @@ -353,12 +356,14 @@ def test_train_array_supervised(tmp_path, supervised_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_tiff_files_in_memory_supervised(tmp_path, supervised_configuration): +def test_train_tiff_files_in_memory_supervised( + tmp_path: Path, supervised_configuration: dict +): """Test that CAREamics can be trained with tiff files in memory.""" # training data train_array = np.random.rand(32, 32) @@ -410,12 +415,12 @@ def test_train_tiff_files_in_memory_supervised(tmp_path, supervised_configuratio path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_train_tiff_files_supervised(tmp_path, supervised_configuration): +def test_train_tiff_files_supervised(tmp_path: Path, supervised_configuration: dict): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. """ @@ -470,13 +475,15 @@ def test_train_tiff_files_supervised(tmp_path, supervised_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() @pytest.mark.parametrize("batch_size", [1, 2]) -def test_predict_on_array_tiled(tmp_path, minimum_configuration, batch_size): +def test_predict_on_array_tiled( + tmp_path: Path, minimum_configuration: dict, batch_size +): """Test that CAREamics can predict on arrays.""" # training data train_array = np.random.rand(32, 32) @@ -507,12 +514,12 @@ def test_predict_on_array_tiled(tmp_path, minimum_configuration, batch_size): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_predict_arrays_no_tiling(tmp_path, minimum_configuration): +def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can predict on arrays without tiling.""" # training data train_array = np.random.rand(4, 32, 32) @@ -541,13 +548,13 @@ def test_predict_arrays_no_tiling(tmp_path, minimum_configuration): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() @pytest.mark.parametrize("batch_size", [1, 2]) -def test_predict_path(tmp_path, minimum_configuration, batch_size): +def test_predict_path(tmp_path: Path, minimum_configuration: dict, batch_size): """Test that CAREamics can predict with tiff files.""" # training data train_array = np.random.rand(32, 32) @@ -571,9 +578,7 @@ def test_predict_path(tmp_path, minimum_configuration, batch_size): careamist.train(train_source=train_file) # predict CAREamist - predicted = careamist.predict( - train_file, batch_size=batch_size, tile_overlap=(4, 4) - ) + predicted = careamist.predict(train_file, batch_size=batch_size) # check that it predicted assert predicted.squeeze().shape == train_array.shape @@ -583,16 +588,16 @@ def test_predict_path(tmp_path, minimum_configuration, batch_size): path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_predict_pretrained(tmp_path, pre_trained): +def test_predict_pretrained_checkpoint(tmp_path: Path, pre_trained: Path): """Test that CAREamics can be instantiated with a pre-trained network and predict on an array.""" - # training data - train_array = np.random.rand(32, 32) + # prediction data + source_array = np.random.rand(32, 32) # instantiate CAREamist careamist = CAREamist(source=pre_trained, work_dir=tmp_path) @@ -600,45 +605,87 @@ def test_predict_pretrained(tmp_path, pre_trained): assert careamist.cfg.data_config.std is not None # predict - predicted = careamist.predict(train_array, tile_overlap=(4, 4)) + predicted = careamist.predict(source_array) # check that it predicted - assert predicted.squeeze().shape == train_array.shape + assert predicted.squeeze().shape == source_array.shape - # export to BMZ + +def test_predict_pretrained_bmz(tmp_path: Path, pre_trained_bmz: Path): + """Test that CAREamics can be instantiated with a BMZ archive and predict.""" + # prediction data + source_array = np.random.rand(32, 32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained_bmz, work_dir=tmp_path) + + # predict + predicted = careamist.predict(source_array) + + # check that it predicted + assert predicted.squeeze().shape == source_array.shape + + +def test_export_bmz_pretrained_prediction(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ after prediction. + + In this case, the careamist extracts the BMZ test data from the prediction + datamodule. + """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # prediction data + source_array = np.random.rand(32, 32) + _ = careamist.predict(source_array) + assert len(careamist.pred_datamodule.predict_dataloader()) > 0 + + # export to BMZ (random array created) careamist.export_to_bmz( path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() -def test_expot_bmz_pretrained(tmp_path, pre_trained): +def test_export_bmz_pretrained_random_array(tmp_path: Path, pre_trained: Path): """Test that CAREamics can be instantiated with a pre-trained network and exported - to BMZ.""" - # training data - train_array = np.random.rand(32, 32).astype(np.float32) + to BMZ. + In this case, the careamist creates a random array for the BMZ archive test. + """ # instantiate CAREamist careamist = CAREamist(source=pre_trained, work_dir=tmp_path) - + # export to BMZ (random array created) careamist.export_to_bmz( path=tmp_path / "model.zip", name="TopModel", general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) assert (tmp_path / "model.zip").exists() - # pass an array + +def test_export_bmz_pretrained_with_array(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ. + + In this case, we provide an array to the BMZ archive test. + """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # alternatively we can pass an array + array = np.random.rand(32, 32).astype(np.float32) careamist.export_to_bmz( path=tmp_path / "model2.zip", name="TopModel", - input_array=train_array[np.newaxis, np.newaxis, ...], + input_array=array[np.newaxis, np.newaxis, ...], general_description="A model that just walked in.", - authors=[{"name": "Amod", "affiliation": "El"}] + authors=[{"name": "Amod", "affiliation": "El"}], ) - assert (tmp_path / "model2.zip").exists() \ No newline at end of file + assert (tmp_path / "model2.zip").exists() diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py index 4d6c85d68..b91e66909 100644 --- a/tests/utils/test_metrics.py +++ b/tests/utils/test_metrics.py @@ -2,7 +2,6 @@ import pytest from careamics.utils.metrics import ( - MetricTracker, _zero_mean, scale_invariant_psnr, ) @@ -29,24 +28,3 @@ def test_zero_mean(x): ) def test_scale_invariant_psnr(gt, pred, result): assert scale_invariant_psnr(gt, pred) == pytest.approx(result, rel=5e-3) - - -def test_metric_tracker(): - tracker = MetricTracker() - - # check initial state - assert tracker.sum == 0 - assert tracker.count == 0 - assert tracker.avg == 0 - assert tracker.val == 0 - - # run a few updates - n = 5 - for i in range(n): - tracker.update(i, n) - - # check values - assert tracker.sum == n * (n * (n - 1)) / 2 - assert tracker.count == n * n - assert tracker.avg == (n - 1) / 2 - assert tracker.val == n - 1 From a617d95e15373ba7759383a4fc206cc45dc0afd4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 27 Apr 2024 14:42:11 +0000 Subject: [PATCH 14/14] style(pre-commit.ci): auto fixes [...] --- examples/2D/n2n/example_SEM_careamist.ipynb | 10 ++- examples/2D/n2v/example_BSD68_careamist.ipynb | 2 +- examples/2D/n2v/example_SEM_lightning.ipynb | 10 +-- examples/3D/example_flywing_3D.ipynb | 1 + src/careamics/config/configuration_model.py | 48 +++++----- src/careamics/config/references/__init__.py | 33 ++++--- .../references/algorithm_descriptions.py | 88 +++++++++++-------- src/careamics/config/references/references.py | 16 ++-- src/careamics/config/validators/__init__.py | 2 +- src/careamics/conftest.py | 1 - src/careamics/lightning_prediction_loop.py | 6 +- .../config/validators/test_validator_utils.py | 13 +-- tests/models/test_model_factory.py | 2 +- tests/test_lightning_module.py | 21 +++-- 14 files changed, 132 insertions(+), 121 deletions(-) diff --git a/examples/2D/n2n/example_SEM_careamist.ipynb b/examples/2D/n2n/example_SEM_careamist.ipynb index 59eca1857..db7738060 100644 --- a/examples/2D/n2n/example_SEM_careamist.ipynb +++ b/examples/2D/n2n/example_SEM_careamist.ipynb @@ -14,7 +14,7 @@ "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", "\n", - "from careamics import CAREamist\n" + "from careamics import CAREamist" ] }, { @@ -156,8 +156,12 @@ "metadata": {}, "outputs": [], "source": [ - "engine.train(train_source=train_image[0], val_source=train_image[1],\n", - " train_target=train_image[2], val_target=train_image[3],)" + "engine.train(\n", + " train_source=train_image[0],\n", + " val_source=train_image[1],\n", + " train_target=train_image[2],\n", + " val_target=train_image[3],\n", + ")" ] }, { diff --git a/examples/2D/n2v/example_BSD68_careamist.ipynb b/examples/2D/n2v/example_BSD68_careamist.ipynb index df528ed2f..81698274d 100644 --- a/examples/2D/n2v/example_BSD68_careamist.ipynb +++ b/examples/2D/n2v/example_BSD68_careamist.ipynb @@ -205,7 +205,7 @@ "source": [ "# Create a list of ground truth images\n", "\n", - "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]\n" + "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]" ] }, { diff --git a/examples/2D/n2v/example_SEM_lightning.ipynb b/examples/2D/n2v/example_SEM_lightning.ipynb index 1442f0f09..93e4469c2 100644 --- a/examples/2D/n2v/example_SEM_lightning.ipynb +++ b/examples/2D/n2v/example_SEM_lightning.ipynb @@ -6,20 +6,20 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", "import shutil\n", - "import albumentations as Aug\n", + "from pathlib import Path\n", + "\n", "import matplotlib.pyplot as plt\n", "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", "from pytorch_lightning import Trainer\n", "\n", "from careamics import CAREamicsModule\n", - "from careamics.lightning_prediction import CAREamicsPredictionLoop\n", "from careamics.lightning_datamodule import (\n", " CAREamicsPredictDataModule,\n", " CAREamicsTrainDataModule,\n", - ")" + ")\n", + "from careamics.lightning_prediction import CAREamicsPredictionLoop" ] }, { @@ -142,7 +142,7 @@ " model_parameters={\"n2v2\": False},\n", " optimizer_parameters={\"lr\": 1e-3},\n", " lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n", - ")\n" + ")" ] }, { diff --git a/examples/3D/example_flywing_3D.ipynb b/examples/3D/example_flywing_3D.ipynb index 8a3297c73..512f8ef9b 100644 --- a/examples/3D/example_flywing_3D.ipynb +++ b/examples/3D/example_flywing_3D.ipynb @@ -12,6 +12,7 @@ "import numpy as np\n", "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", + "\n", "# from itkwidgets import compare, view # \"pip install itkwidgets \"if necessary\n", "from pytorch_lightning import Trainer\n", "\n", diff --git a/src/careamics/config/configuration_model.py b/src/careamics/config/configuration_model.py index abfb3712e..4c8e5dca8 100644 --- a/src/careamics/config/configuration_model.py +++ b/src/careamics/config/configuration_model.py @@ -13,24 +13,24 @@ from .algorithm_model import AlgorithmModel from .data_model import DataModel from .references import ( - N2V2Ref, - N2VRef, - StructN2VRef, - N2NRef, - CARERef, - N2VDescription, - N2V2Description, - StructN2VDescription, - StructN2V2Description, - N2NDescription, - CAREDescription, + CARE, + CUSTOM, + N2N, N2V, N2V2, STRUCT_N2V, STRUCT_N2V2, - CUSTOM, - N2N, - CARE + CAREDescription, + CARERef, + N2NDescription, + N2NRef, + N2V2Description, + N2V2Ref, + N2VDescription, + N2VRef, + StructN2V2Description, + StructN2VDescription, + StructN2VRef, ) from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform from .training_model import TrainingModel @@ -245,7 +245,7 @@ def validate_algorithm_and_data(self: Configuration) -> Configuration: name=SupportedTransform.N2V_MANIPULATE.value, ) ) - + median = SupportedPixelManipulation.MEDIAN.value uniform = SupportedPixelManipulation.UNIFORM.value strategy = median if self.algorithm_config.model.n2v2 else uniform @@ -340,9 +340,7 @@ def get_algorithm_flavour(self) -> str: """ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = ( - self.data_config.transforms[-1].struct_mask_axis != "none" - ) + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" # return the n2v flavour if use_n2v2 and use_structN2V: @@ -359,7 +357,7 @@ def get_algorithm_flavour(self) -> str: return CARE else: return CUSTOM - + def get_algorithm_description(self) -> str: """ Return a description of the algorithm. @@ -404,9 +402,7 @@ def get_algorithm_citations(self) -> List[CiteEntry]: """ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = ( - self.data_config.transforms[-1].struct_mask_axis != "none" - ) + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" # return the (struct)N2V(2) references if use_n2v2 and use_structN2V: @@ -437,9 +433,7 @@ def get_algorithm_references(self) -> str: """ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = ( - self.data_config.transforms[-1].struct_mask_axis != "none" - ) + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" references = [ N2VRef.text + " doi: " + N2VRef.doi, @@ -472,9 +466,7 @@ def get_algorithm_keywords(self) -> List[str]: """ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = ( - self.data_config.transforms[-1].struct_mask_axis != "none" - ) + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" keywords = [ "denoising", diff --git a/src/careamics/config/references/__init__.py b/src/careamics/config/references/__init__.py index 1139297b7..b314f1c89 100644 --- a/src/careamics/config/references/__init__.py +++ b/src/careamics/config/references/__init__.py @@ -21,26 +21,25 @@ "N2NRef", ] -from .references import ( - N2V2Ref, - N2VRef, - StructN2VRef, - N2NRef, - CARERef, -) - from .algorithm_descriptions import ( - N2VDescription, - N2V2Description, - StructN2VDescription, - StructN2V2Description, - N2NDescription, - CAREDescription, + CARE, + CUSTOM, + N2N, N2V, N2V2, STRUCT_N2V, STRUCT_N2V2, - CUSTOM, - N2N, - CARE + CAREDescription, + N2NDescription, + N2V2Description, + N2VDescription, + StructN2V2Description, + StructN2VDescription, +) +from .references import ( + CARERef, + N2NRef, + N2V2Ref, + N2VRef, + StructN2VRef, ) diff --git a/src/careamics/config/references/algorithm_descriptions.py b/src/careamics/config/references/algorithm_descriptions.py index 5fe2704a6..97e4ba83f 100644 --- a/src/careamics/config/references/algorithm_descriptions.py +++ b/src/careamics/config/references/algorithm_descriptions.py @@ -10,15 +10,18 @@ CARE = "CARE" -N2V_DESCRIPTION = "Noise2Void is a UNet-based self-supervised algorithm that " \ - "uses blind-spot training to denoise images. In short, in every " \ - "patches during training, random pixels are selected and their " \ - "value replaced by a neighboring pixel value. The network is then " \ - "trained to predict the original pixel value. The algorithm " \ - "relies on the continuity of the signal (neighboring pixels have " \ - "similar values) and the pixel-wise independence of the noise " \ - "(the noise in a pixel is not correlated with the noise in " \ - "neighboring pixels)." +N2V_DESCRIPTION = ( + "Noise2Void is a UNet-based self-supervised algorithm that " + "uses blind-spot training to denoise images. In short, in every " + "patches during training, random pixels are selected and their " + "value replaced by a neighboring pixel value. The network is then " + "trained to predict the original pixel value. The algorithm " + "relies on the continuity of the signal (neighboring pixels have " + "similar values) and the pixel-wise independence of the noise " + "(the noise in a pixel is not correlated with the noise in " + "neighboring pixels)." +) + class AlgorithmDescription(BaseModel): """Description of an algorithm. @@ -28,35 +31,38 @@ class AlgorithmDescription(BaseModel): description : str Description of the algorithm. """ - + description: str class N2VDescription(AlgorithmDescription): """Description of Noise2Void. - + Attributes ---------- description : str Description of Noise2Void. """ - + description: str = N2V_DESCRIPTION class N2V2Description(AlgorithmDescription): """Description of N2V2. - + Attributes ---------- description : str Description of N2V2. """ - - description: str = "N2V2 is a variant of Noise2Void. " + N2V_DESCRIPTION + \ - "\nN2V2 introduces blur-pool layers and removed skip " \ - "connections in the UNet architecture to remove checkboard " \ - "artefacts, a common artefacts ocurring in Noise2Void." + + description: str = ( + "N2V2 is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nN2V2 introduces blur-pool layers and removed skip " + "connections in the UNet architecture to remove checkboard " + "artefacts, a common artefacts ocurring in Noise2Void." + ) class StructN2VDescription(AlgorithmDescription): @@ -67,12 +73,15 @@ class StructN2VDescription(AlgorithmDescription): description : str Description of StructN2V. """ - - description: str = "StructN2V is a variant of Noise2Void. " + N2V_DESCRIPTION + \ - "\nStructN2V uses a linear mask (horizontal or vertical) to replace " \ - "the pixel values of neighbors of the masked pixels by a random " \ - "value. Such masking allows removing 1D structured noise from the " \ - "the images, the main failure case of the original N2V." + + description: str = ( + "StructN2V is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nStructN2V uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." + ) class StructN2V2Description(AlgorithmDescription): @@ -84,16 +93,19 @@ class StructN2V2Description(AlgorithmDescription): Description of StructN2V2. """ - description: str = "StructN2V2 is a a variant of Noise2Void that uses both " \ - "structN2V and N2V2. "+ N2V_DESCRIPTION + \ - "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace " \ - "the pixel values of neighbors of the masked pixels by a random " \ - "value. Such masking allows removing 1D structured noise from the " \ - "the images, the main failure case of the original N2V." \ - "\nN2V2 introduces blur-pool layers and removed skip connections in " \ - "the UNet architecture to remove checkboard artefacts, a common " \ - "artefacts ocurring in Noise2Void." - + description: str = ( + "StructN2V2 is a a variant of Noise2Void that uses both " + "structN2V and N2V2. " + + N2V_DESCRIPTION + + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." + "\nN2V2 introduces blur-pool layers and removed skip connections in " + "the UNet architecture to remove checkboard artefacts, a common " + "artefacts ocurring in Noise2Void." + ) + class N2NDescription(AlgorithmDescription): """Description of Noise2Noise. @@ -103,8 +115,8 @@ class N2NDescription(AlgorithmDescription): description : str Description of Noise2Noise. """ - - description: str = "Noise2Noise" # TODO + + description: str = "Noise2Noise" # TODO class CAREDescription(AlgorithmDescription): @@ -115,5 +127,5 @@ class CAREDescription(AlgorithmDescription): description : str Description of CARE. """ - - description: str = "CARE" # TODO \ No newline at end of file + + description: str = "CARE" # TODO diff --git a/src/careamics/config/references/references.py b/src/careamics/config/references/references.py index 28ad8f4f7..60c8413f9 100644 --- a/src/careamics/config/references/references.py +++ b/src/careamics/config/references/references.py @@ -18,21 +18,21 @@ StructN2VRef = CiteEntry( text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." - '\"Removing structured noise with self-supervised blind-spot ' - 'networks\". In 2020 IEEE 17th International Symposium on Biomedical ' + '"Removing structured noise with self-supervised blind-spot ' + 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' "Imaging (ISBI) (pp. 159-163).", doi="10.1109/isbi45749.2020.9098336", ) N2NRef = CiteEntry( - text='Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., ' - 'Aittala, M. and Aila, T., 2018. \"Noise2Noise: Learning image restoration ' - 'without clean data\". arXiv preprint arXiv:1803.04189.', + text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., " + 'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration ' + 'without clean data". arXiv preprint arXiv:1803.04189.', doi="10.48550/arXiv.1803.04189", ) CARERef = CiteEntry( - text='Weigert, Martin, et al. \"Content-aware image restoration: pushing the ' - 'limits of fluorescence microscopy.\" Nature methods 15.12 (2018): 1090-1097.', + text='Weigert, Martin, et al. "Content-aware image restoration: pushing the ' + 'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.', doi="10.1038/s41592-018-0216-7", -) \ No newline at end of file +) diff --git a/src/careamics/config/validators/__init__.py b/src/careamics/config/validators/__init__.py index 7b82a6a7e..53ddbf8db 100644 --- a/src/careamics/config/validators/__init__.py +++ b/src/careamics/config/validators/__init__.py @@ -2,4 +2,4 @@ __all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"] -from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 \ No newline at end of file +from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 diff --git a/src/careamics/conftest.py b/src/careamics/conftest.py index 929182a07..e0a1fae6c 100644 --- a/src/careamics/conftest.py +++ b/src/careamics/conftest.py @@ -16,7 +16,6 @@ def my_path(tmpdir_factory: TempPathFactory) -> Path: return tmpdir_factory.mktemp("my_path") - pytest_collect_file = Sybil( parsers=[ DocTestParser(), diff --git a/src/careamics/lightning_prediction_loop.py b/src/careamics/lightning_prediction_loop.py index ca11da96a..c7e00fd2e 100644 --- a/src/careamics/lightning_prediction_loop.py +++ b/src/careamics/lightning_prediction_loop.py @@ -36,12 +36,10 @@ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: ################ CAREamics specific code ############### if len(self.predicted_array) == 1: # TODO does this make sense to here? (force numpy array) - return self.predicted_array[0].numpy() + return self.predicted_array[0].numpy() else: # TODO revisit logic - return [ - element.numpy() for element in self.predicted_array - ] + return [element.numpy() for element in self.predicted_array] ######################################################## return None diff --git a/tests/config/validators/test_validator_utils.py b/tests/config/validators/test_validator_utils.py index afc7d8210..740f10a51 100644 --- a/tests/config/validators/test_validator_utils.py +++ b/tests/config/validators/test_validator_utils.py @@ -1,9 +1,11 @@ import pytest from careamics.config.validators import ( - check_axes_validity, patch_size_ge_than_8_power_of_2 + check_axes_validity, + patch_size_ge_than_8_power_of_2, ) + @pytest.mark.parametrize( "axes, valid", [ @@ -46,7 +48,8 @@ def test_are_axes_valid(axes, valid): check_axes_validity(axes) -@pytest.mark.parametrize("patch_size, error", +@pytest.mark.parametrize( + "patch_size, error", [ ((2, 8, 8), True), ((10,), True), @@ -56,12 +59,12 @@ def test_are_axes_valid(axes, valid): ((8,), False), ((8, 8), False), ((8, 64, 64), False), - ] -) + ], +) def test_patch_size(patch_size, error): """Test if patch size is valid.""" if error: with pytest.raises(ValueError): patch_size_ge_than_8_power_of_2(patch_size) else: - patch_size_ge_than_8_power_of_2(patch_size) \ No newline at end of file + patch_size_ge_than_8_power_of_2(patch_size) diff --git a/tests/models/test_model_factory.py b/tests/models/test_model_factory.py index 08f7075bb..f4526aaee 100644 --- a/tests/models/test_model_factory.py +++ b/tests/models/test_model_factory.py @@ -42,7 +42,7 @@ def forward(self, input): model_config = { "architecture": SupportedArchitecture.CUSTOM.value, "name": "linear_model", - "in_features": 10, + "in_features": 10, "out_features": 5, } diff --git a/tests/test_lightning_module.py b/tests/test_lightning_module.py index fff2bf522..afe2fab09 100644 --- a/tests/test_lightning_module.py +++ b/tests/test_lightning_module.py @@ -34,12 +34,13 @@ def test_careamics_kiln(minimum_algorithm_n2v): CAREamicsKiln(algo_config) -@pytest.mark.parametrize("shape", +@pytest.mark.parametrize( + "shape", [ (8, 8), (16, 16), (32, 32), - ] + ], ) def test_careamics_kiln_unet_2D_depth_2_shape(shape): algo_dict = { @@ -64,7 +65,8 @@ def test_careamics_kiln_unet_2D_depth_2_shape(shape): assert y.shape == x.shape -@pytest.mark.parametrize("shape", +@pytest.mark.parametrize( + "shape", [ (8, 8), (16, 16), @@ -72,7 +74,7 @@ def test_careamics_kiln_unet_2D_depth_2_shape(shape): (64, 64), (128, 128), (256, 256), - ] + ], ) def test_careamics_kiln_unet_2D_depth_3_shape(shape): algo_dict = { @@ -97,13 +99,14 @@ def test_careamics_kiln_unet_2D_depth_3_shape(shape): assert y.shape == x.shape -@pytest.mark.parametrize("shape", +@pytest.mark.parametrize( + "shape", [ (8, 32, 16), (16, 32, 16), (8, 32, 32), (32, 64, 64), - ] + ], ) def test_careamics_kiln_unet_depth_2_3D(shape): algo_dict = { @@ -128,13 +131,14 @@ def test_careamics_kiln_unet_depth_2_3D(shape): assert y.shape == x.shape -@pytest.mark.parametrize("shape", +@pytest.mark.parametrize( + "shape", [ (8, 64, 64), (16, 64, 64), (16, 128, 128), (32, 128, 128), - ] + ], ) def test_careamics_kiln_unet_depth_3_3D(shape): algo_dict = { @@ -252,4 +256,3 @@ def test_careamics_kiln_unet_depth_3_channels_3D(n_channels): x = torch.rand((1, n_channels, 16, 64, 64)) y: torch.Tensor = model.forward(x) assert y.shape == x.shape -