diff --git a/examples/2D/n2n/example_SEM_careamist.ipynb b/examples/2D/n2n/example_SEM_careamist.ipynb index 59eca1857..db7738060 100644 --- a/examples/2D/n2n/example_SEM_careamist.ipynb +++ b/examples/2D/n2n/example_SEM_careamist.ipynb @@ -14,7 +14,7 @@ "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", "\n", - "from careamics import CAREamist\n" + "from careamics import CAREamist" ] }, { @@ -156,8 +156,12 @@ "metadata": {}, "outputs": [], "source": [ - "engine.train(train_source=train_image[0], val_source=train_image[1],\n", - " train_target=train_image[2], val_target=train_image[3],)" + "engine.train(\n", + " train_source=train_image[0],\n", + " val_source=train_image[1],\n", + " train_target=train_image[2],\n", + " val_target=train_image[3],\n", + ")" ] }, { diff --git a/examples/2D/n2v/example_BSD68_careamist.ipynb b/examples/2D/n2v/example_BSD68_careamist.ipynb index df528ed2f..81698274d 100644 --- a/examples/2D/n2v/example_BSD68_careamist.ipynb +++ b/examples/2D/n2v/example_BSD68_careamist.ipynb @@ -205,7 +205,7 @@ "source": [ "# Create a list of ground truth images\n", "\n", - "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]\n" + "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]" ] }, { diff --git a/examples/2D/n2v/example_SEM_lightning.ipynb b/examples/2D/n2v/example_SEM_lightning.ipynb index 1442f0f09..93e4469c2 100644 --- a/examples/2D/n2v/example_SEM_lightning.ipynb +++ b/examples/2D/n2v/example_SEM_lightning.ipynb @@ -6,20 +6,20 @@ "metadata": {}, "outputs": [], "source": [ - "from pathlib import Path\n", "import shutil\n", - "import albumentations as Aug\n", + "from pathlib import Path\n", + "\n", "import matplotlib.pyplot as plt\n", "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", "from pytorch_lightning import Trainer\n", "\n", "from careamics import CAREamicsModule\n", - "from careamics.lightning_prediction import CAREamicsPredictionLoop\n", "from careamics.lightning_datamodule import (\n", " CAREamicsPredictDataModule,\n", " CAREamicsTrainDataModule,\n", - ")" + ")\n", + "from careamics.lightning_prediction import CAREamicsPredictionLoop" ] }, { @@ -142,7 +142,7 @@ " model_parameters={\"n2v2\": False},\n", " optimizer_parameters={\"lr\": 1e-3},\n", " lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n", - ")\n" + ")" ] }, { diff --git a/examples/3D/example_flywing_3D.ipynb b/examples/3D/example_flywing_3D.ipynb index 8a3297c73..512f8ef9b 100644 --- a/examples/3D/example_flywing_3D.ipynb +++ b/examples/3D/example_flywing_3D.ipynb @@ -12,6 +12,7 @@ "import numpy as np\n", "import tifffile\n", "from careamics_portfolio import PortfolioManager\n", + "\n", "# from itkwidgets import compare, view # \"pip install itkwidgets \"if necessary\n", "from pytorch_lightning import Trainer\n", "\n", diff --git a/pyproject.toml b/pyproject.toml index 70f38a57b..c81af28a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,12 +37,13 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + # pytorch should be installed via conda/pip beforehand 'albumentations', - 'bioimageio.core', + 'bioimageio.core>=0.6.0', 'tifffile', 'psutil', 'pydantic>=2.5', - 'pytorch_lightning', + 'pytorch_lightning>=2.2.0', 'pyyaml', 'scikit-image', 'zarr', @@ -166,11 +167,14 @@ ignore = [ [tool.numpydoc_validation] checks = [ "all", # report on all checks, except the below - "EX01", - "SA01", - "ES01", - "GL02", - "GL03", + "EX01", # Example section not found + "SA01", # See Also section not found + "ES01", # Extended Summar not found + "GL01", # Docstring text (summary) should start in the line immediately + # after the opening quotes + "GL02", # Closing quotes should be placed in the line after the last text + # in the docstring + "GL03", # Double line break found ] exclude = [ # don't report on objects that match any of these regex "test_*", diff --git a/src/careamics/__init__.py b/src/careamics/__init__.py index 86904d1d3..222aff45e 100644 --- a/src/careamics/__init__.py +++ b/src/careamics/__init__.py @@ -19,8 +19,6 @@ from .careamist import CAREamist from .config import Configuration, load_configuration, save_configuration -from .lightning_datamodule import ( - CAREamicsPredictDataModule, - CAREamicsTrainDataModule, -) +from .lightning_datamodule import CAREamicsTrainDataModule from .lightning_module import CAREamicsModule +from .lightning_prediction_datamodule import CAREamicsPredictDataModule diff --git a/src/careamics/callbacks/__init__.py b/src/careamics/callbacks/__init__.py index a7c04f454..74a82442d 100644 --- a/src/careamics/callbacks/__init__.py +++ b/src/careamics/callbacks/__init__.py @@ -1,5 +1,6 @@ -"""Callback module.""" +"""Callbacks module.""" -__all__ = ["ProgressBarCallback"] +__all__ = ["HyperParametersCallback", "ProgressBarCallback"] -from .progress_bar import ProgressBarCallback +from .hyperparameters_callback import HyperParametersCallback +from .progress_bar_callback import ProgressBarCallback diff --git a/src/careamics/callbacks/hyperparameters_callback.py b/src/careamics/callbacks/hyperparameters_callback.py new file mode 100644 index 000000000..d06090770 --- /dev/null +++ b/src/careamics/callbacks/hyperparameters_callback.py @@ -0,0 +1,42 @@ +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import Callback + +from careamics.config import Configuration + + +class HyperParametersCallback(Callback): + """ + Callback allowing saving CAREamics configuration as hyperparameters in the model. + + This allows saving the configuration as dictionnary in the checkpoints, and + loading it subsequently in a CAREamist instance. + + Attributes + ---------- + config : Configuration + CAREamics configuration to be saved as hyperparameter in the model. + """ + + def __init__(self, config: Configuration): + """ + Constructor. + + Parameters + ---------- + config : Configuration + CAREamics configuration to be saved as hyperparameter in the model. + """ + self.config = config + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule): + """ + Update the hyperparameters of the model with the configuration on train start. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer. + pl_module : LightningModule + PyTorch Lightning module. + """ + pl_module.hparams.update(self.config.model_dump()) diff --git a/src/careamics/callbacks/progress_bar.py b/src/careamics/callbacks/progress_bar_callback.py similarity index 93% rename from src/careamics/callbacks/progress_bar.py rename to src/careamics/callbacks/progress_bar_callback.py index 1ac6861aa..d7862091c 100644 --- a/src/careamics/callbacks/progress_bar.py +++ b/src/careamics/callbacks/progress_bar_callback.py @@ -12,7 +12,7 @@ class ProgressBarCallback(TQDMProgressBar): def init_train_tqdm(self) -> tqdm: """Override this to customize the tqdm bar for training.""" bar = tqdm( - desc='Training', + desc="Training", position=(2 * self.process_position), disable=self.is_disabled, leave=True, @@ -27,12 +27,12 @@ def init_validation_tqdm(self) -> tqdm: # The main progress bar doesn't exist in `trainer.validate()` has_main_bar = self.train_progress_bar is not None bar = tqdm( - desc='Validating', + desc="Validating", position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, - file=sys.stdout + file=sys.stdout, ) return bar @@ -45,7 +45,7 @@ def init_test_tqdm(self) -> tqdm: leave=True, dynamic_ncols=False, ncols=100, - file=sys.stdout + file=sys.stdout, ) return bar @@ -55,4 +55,3 @@ def get_metrics( """Override this to customize the metrics displayed in the progress bar.""" pbar_metrics = trainer.progress_bar_metrics return {**pbar_metrics} - diff --git a/src/careamics/careamist.py b/src/careamics/careamist.py index 2cba131b8..fe5f0e4b1 100644 --- a/src/careamics/careamist.py +++ b/src/careamics/careamist.py @@ -1,10 +1,10 @@ -"""Main class to train and predict with CAREamics models.""" +"""A class to train, predict and export models in CAREamics.""" from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, overload import numpy as np -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ( Callback, EarlyStopping, @@ -19,25 +19,35 @@ load_configuration, ) from careamics.config.inference_model import TRANSFORMS_UNION -from careamics.config.support import SupportedAlgorithm, SupportedLogger -from careamics.model_io import load_pretrained -from careamics.lightning_datamodule import CAREamicsClay, CAREamicsWood +from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger +from careamics.lightning_datamodule import CAREamicsWood from careamics.lightning_module import CAREamicsKiln -from careamics.lightning_prediction import CAREamicsPredictionLoop +from careamics.lightning_prediction_datamodule import CAREamicsClay +from careamics.lightning_prediction_loop import CAREamicsPredictionLoop +from careamics.model_io import export_to_bmz, load_pretrained from careamics.utils import check_path_exists, get_logger +from .callbacks import HyperParametersCallback + logger = get_logger(__name__) LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]] + # TODO napari callbacks -# TODO save as modelzoo, lightning and pytorch_dict # TODO: how to do AMP? How to continue training? +class CAREamist: + """Main CAREamics class, allowing training and prediction using various algorithms. - -class CAREamist(LightningModule): - """ - Main CAREamics class, allowing training and prediction using various algorithms. + Parameters + ---------- + source : Union[Path, str, Configuration] + Path to a configuration file or a trained model. + work_dir : Optional[str], optional + Path to working directory in which to save checkpoints and logs, + by default None. + experiment_name : str, optional + Experiment name used for checkpoints, by default "CAREamics". Attributes ---------- @@ -51,16 +61,10 @@ class CAREamist(LightningModule): Experiment logger, "wandb" or "tensorboard". work_dir : Path Working directory. - - Parameters - ---------- - source : Union[Path, str, Configuration] - Path to a configuration file or a trained model. - work_dir : Optional[str], optional - Path to working directory in which to save checkpoints and logs, - by default None. - experiment_name : str, optional - Experiment name used for checkpoints, by default "CAREamics". + train_datamodule : Optional[CAREamicsWood] + Training datamodule. + pred_datamodule : Optional[CAREamicsClay] + Prediction datamodule. """ @overload @@ -118,7 +122,6 @@ def __init__( If no hyper parameters are found in the checkpoint. ValueError If no data module hyper parameters are found in the checkpoint. - """ super().__init__() @@ -135,11 +138,11 @@ def __init__( # configuration object if isinstance(source, Configuration): self.cfg = source - self.save_hyperparameters(self.cfg.model_dump()) # instantiate model - self.model = CAREamicsKiln(self.cfg.algorithm_config) - self.model.hparams.update(self.cfg.model_dump()) + self.model = CAREamicsKiln( + algorithm_config=self.cfg.algorithm_config, + ) # path to configuration file or model else: @@ -152,12 +155,10 @@ def __init__( # load configuration self.cfg = load_configuration(source) - # save configuration in the working directory - # TODO Ugly, think of a better way to save the configuration - self.save_hyperparameters(self.cfg.model_dump()) - # instantiate model - self.model = CAREamicsKiln(self.cfg.algorithm_config) + self.model = CAREamicsKiln( + algorithm_config=self.cfg.algorithm_config, + ) # attempt loading a pre-trained model else: @@ -166,20 +167,16 @@ def __init__( # define the checkpoint saving callback self.callbacks = self._define_callbacks() - # torch.set_float32_matmul_precision('medium') - # instantiate logger if self.cfg.training_config.has_logger(): if self.cfg.training_config.logger == SupportedLogger.WANDB: self.experiment_logger: LOGGER_TYPES = WandbLogger( name=experiment_name, save_dir=self.work_dir / Path("logs"), - # **self.cfg.logger.model_dump(), ) elif self.cfg.training_config.logger == SupportedLogger.TENSORBOARD: self.experiment_logger = TensorBoardLogger( save_dir=self.work_dir / Path("logs"), - # **self.cfg.logger.model_dump(), ) else: self.experiment_logger = None @@ -190,12 +187,15 @@ def __init__( callbacks=self.callbacks, default_root_dir=self.work_dir, logger=self.experiment_logger, - # precision="bf16" ) # change the prediction loop, necessary for tiled prediction self.trainer.predict_loop = CAREamicsPredictionLoop(self.trainer) + # place holder for the datamodules + self.train_datamodule: Optional[CAREamicsWood] = None + self.pred_datamodule: Optional[CAREamicsClay] = None + def _define_callbacks(self) -> List[Callback]: """ Define the callbacks for the training loop. @@ -207,6 +207,7 @@ def _define_callbacks(self) -> List[Callback]: """ # checkpoint callback saves checkpoints during training self.callbacks = [ + HyperParametersCallback(self.cfg), ModelCheckpoint( dirpath=self.work_dir / Path("checkpoints"), filename=self.cfg.experiment_name, @@ -223,10 +224,6 @@ def _define_callbacks(self) -> List[Callback]: return self.callbacks - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """Save the configuration in the checkpoint.""" - checkpoint["cfg"] = self.cfg.model_dump() - def train( self, *, @@ -372,6 +369,9 @@ def _train_on_datamodule(self, datamodule: CAREamicsWood) -> None: datamodule : CAREamicsWood Datamodule to train on. """ + # record datamodule + self.train_datamodule = datamodule + self.trainer.fit(self.model, datamodule=datamodule) def _train_on_array( @@ -474,7 +474,10 @@ def _train_on_path( @overload def predict( # numpydoc ignore=GL08 - self, source: CAREamicsClay + self, + source: CAREamicsClay, + *, + checkpoint: Optional[Literal["best", "last"]] = None, ) -> Union[list, np.ndarray]: ... @@ -585,8 +588,12 @@ def predict( If the input is not a CAREamicsClay instance, a path or a numpy array. """ if isinstance(source, CAREamicsClay): - return self.trainer.predict(datamodule=source) + # record datamodule + self.pred_datamodule = source + return self.trainer.predict( + model=self.model, datamodule=source, ckpt_path=checkpoint + ) else: if self.cfg is None: raise ValueError( @@ -624,6 +631,9 @@ def predict( dataloader_params=dataloader_params, ) + # record datamodule + self.pred_datamodule = datamodule + return self.trainer.predict( model=self.model, datamodule=datamodule, ckpt_path=checkpoint ) @@ -636,6 +646,9 @@ def predict( dataloader_params=dataloader_params, ) + # record datamodule + self.pred_datamodule = datamodule + return self.trainer.predict( model=self.model, datamodule=datamodule, ckpt_path=checkpoint ) @@ -646,32 +659,99 @@ def predict( f"np.ndarray (got {type(source)})." ) - def export_model( - self, path: Union[Path, str], type: Literal["bmz", "script"] = "bmz" + def export_to_bmz( + self, + path: Union[Path, str], + name: str, + authors: List[dict], + input_array: Optional[np.ndarray] = None, + general_description: str = "", + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, ) -> None: - """ - Export the model to the BioImage Model Zoo or torchscript format. + """Export the model to the BioImage Model Zoo format. + + Input array must be of shape SC(Z)YX, with S and C singleton dimensions. Parameters ---------- path : Union[Path, str] Path to save the model. - type : Literal["bmz", "script"], optional - Export format, by default "bmz". - - Raises - ------ - NotImplementedError - If the export format is not implemented yet. + name : str + Name of the model. + authors : List[dict] + List of authors of the model. + input_array : Optional[np.ndarray], optional + Input array for the model, must be of shape SC(Z)YX, by default None. + general_description : str + General description of the model, used in the metadata of the BMZ archive. + channel_names : Optional[List[str]], optional + Channel names, by default None. + data_description : Optional[str], optional + Description of the data, by default None. """ - path = Path(path) - if type == "bmz": - raise NotImplementedError( - "Exporting a model to BioImage Model Zoo is not implemented yet." - ) - elif type == "script": - self.model.to_torchscript(path) + if input_array is None: + # generate images, priority is given to the prediction data module + if self.pred_datamodule is not None: + # unpack a batch, ignore masks or targets + input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader())) + + # convert torch.Tensor to numpy + input_patch = input_patch.numpy() + elif self.train_datamodule is not None: + input_patch, *_ = next(iter(self.train_datamodule.train_dataloader())) + input_patch = input_patch.numpy() + else: + if ( + self.cfg.data_config.mean is None + or self.cfg.data_config.std is None + ): + raise ValueError( + "Mean and std cannot be None in the configuration in order to" + "export to the BMZ format. Was the model trained?" + ) + + # create a random input array + input_patch = np.random.normal( + loc=self.cfg.data_config.mean, + scale=self.cfg.data_config.std, + size=self.cfg.data_config.patch_size, + ).astype(np.float32)[ + np.newaxis, np.newaxis, ... + ] # add S & C dimensions else: + input_patch = input_array + + # if there is a batch dimension + if input_patch.shape[0] > 1: + input_patch = input_patch[0:1, ...] # keep singleton dim + + # axes need to be without S + axes = self.cfg.data_config.axes.replace("S", "") + + # predict output, remove extra dimensions for the purpose of the prediction + output_patch = self.predict( + input_patch.squeeze(), + data_type=SupportedData.ARRAY.value, + axes=axes, + tta_transforms=False, + ) + + if not isinstance(output_patch, np.ndarray): raise ValueError( - f"Invalid export format. Expected 'bmz' or 'script', got {type}." + f"Numpy array required for export to BioImage Model Zoo, got " + f"{type(output_patch)}." ) + + export_to_bmz( + model=self.model, + config=self.cfg, + path=path, + name=name, + general_description=general_description, + authors=authors, + input_array=input_patch, + output_array=output_patch, + channel_names=channel_names, + data_description=data_description, + ) diff --git a/src/careamics/config/__init__.py b/src/careamics/config/__init__.py index a551675cb..9ee70b7ee 100644 --- a/src/careamics/config/__init__.py +++ b/src/careamics/config/__init__.py @@ -15,6 +15,7 @@ "CustomModel", "create_inference_configuration", "clear_custom_models", + "ConfigurationInformation", ] from .algorithm_model import AlgorithmModel diff --git a/src/careamics/config/algorithm_model.py b/src/careamics/config/algorithm_model.py index 01d120513..b77670095 100644 --- a/src/careamics/config/algorithm_model.py +++ b/src/careamics/config/algorithm_model.py @@ -75,10 +75,8 @@ class AlgorithmModel(BaseModel): ... "model": { ... "architecture": "Custom", ... "name": "linear_model", - ... "parameters": { - ... "in_features": 10, - ... "out_features": 5, - ... }, + ... "in_features": 10, + ... "out_features": 5, ... } ... } >>> config = AlgorithmModel(**config_dict) diff --git a/src/careamics/config/architectures/__init__.py b/src/careamics/config/architectures/__init__.py index 11cab73eb..c65d97f09 100644 --- a/src/careamics/config/architectures/__init__.py +++ b/src/careamics/config/architectures/__init__.py @@ -1,6 +1,7 @@ """Deep-learning model configurations.""" __all__ = [ + "ArchitectureModel", "CustomModel", "UNetModel", "VAEModel", @@ -9,6 +10,7 @@ "register_model", ] +from .architecture_model import ArchitectureModel from .custom_model import CustomModel from .register_model import clear_custom_models, get_custom_model, register_model from .unet_model import UNetModel diff --git a/src/careamics/config/architectures/architecture_model.py b/src/careamics/config/architectures/architecture_model.py new file mode 100644 index 000000000..28113112c --- /dev/null +++ b/src/careamics/config/architectures/architecture_model.py @@ -0,0 +1,29 @@ +from typing import Any, Dict + +from pydantic import BaseModel + + +class ArchitectureModel(BaseModel): + """ + Base Pydantic model for all model architectures. + + The `model_dump` method allows removing the `architecture` key from the model. + """ + + architecture: str + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + """ + Dump the model as a dictionary, ignoring the architecture keyword. + + Returns + ------- + dict[str, Any] + Model as a dictionnary. + """ + model_dict = super().model_dump(**kwargs) + + # remove the architecture key + model_dict.pop("architecture") + + return model_dict diff --git a/src/careamics/config/architectures/custom_model.py b/src/careamics/config/architectures/custom_model.py index 647175e3c..56e9b753b 100644 --- a/src/careamics/config/architectures/custom_model.py +++ b/src/careamics/config/architectures/custom_model.py @@ -1,21 +1,16 @@ from __future__ import annotations from pprint import pformat -from typing import Literal +from typing import Any, Dict, Literal -from pydantic import BaseModel, ConfigDict, field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from torch.nn import Module +from .architecture_model import ArchitectureModel from .register_model import get_custom_model -class CustomParametersModel(BaseModel): - """A Pydantic model that allows any parameter.""" - - model_config = ConfigDict(extra="allow") - - -class CustomModel(BaseModel): +class CustomModel(ArchitectureModel): """Custom model configuration. This Pydantic model allows storing parameters for a custom model. In order for the @@ -60,17 +55,17 @@ class CustomModel(BaseModel): >>> # Create a configuration >>> config_dict = { ... "architecture": "Custom", - ... "name": "linear", - ... "parameters": { - ... "in_features": 10, - ... "out_features": 5, - ... }, + ... "name": "my_linear", + ... "in_features": 10, + ... "out_features": 5, ... } >>> config = CustomModel(**config_dict) """ # pydantic model config - model_config = ConfigDict(validate_assignment=True) + model_config = ConfigDict( + extra="allow", + ) # discriminator used for choosing the pydantic model in Model architecture: Literal["Custom"] @@ -78,9 +73,6 @@ class CustomModel(BaseModel): # name of the custom model name: str - # parameters - parameters: CustomParametersModel - @field_validator("name") @classmethod def custom_model_is_known(cls, value: str) -> str: @@ -115,12 +107,13 @@ def check_parameters(self: CustomModel) -> CustomModel: """ # instantiate model try: - get_custom_model(self.name)(**self.parameters.model_dump()) + get_custom_model(self.name)(**self.model_dump()) except Exception as e: raise ValueError( - f"error while passing parameters to the model: {e}. Verify that all " - f"mandatory parameters are provided, and that either the model accepts " - f"*args and **kwargs, or that no additional parameter is provided." + f"error while passing parameters to the model {e}. Verify that all " + f"mandatory parameters are provided, and that either the {e} accepts " + f"*args and **kwargs in its __init__() method, or that no additional" + f"parameter is provided." ) from None return self @@ -134,3 +127,23 @@ def __str__(self) -> str: Pretty string. """ return pformat(self.model_dump()) + + def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + """Dump the model configuration. + + Parameters + ---------- + kwargs : Any + Additional keyword arguments from Pydantic BaseModel model_dump method. + + Returns + ------- + Dict[str, Any] + Model configuration. + """ + model_dict = super().model_dump() + + # remove the name key + model_dict.pop("name") + + return model_dict diff --git a/src/careamics/config/architectures/register_model.py b/src/careamics/config/architectures/register_model.py index 6e62cbce1..f35b0b88c 100644 --- a/src/careamics/config/architectures/register_model.py +++ b/src/careamics/config/architectures/register_model.py @@ -38,6 +38,9 @@ def forward(self, input): return (input @ self.weight) + self.bias ``` """ + if name is None or name == "": + raise ValueError("Model name cannot be empty.") + if name in CUSTOM_MODELS: raise ValueError( f"Model {name} already exists. Choose a different name or run " diff --git a/src/careamics/config/architectures/unet_model.py b/src/careamics/config/architectures/unet_model.py index 13178731d..9a032e2b1 100644 --- a/src/careamics/config/architectures/unet_model.py +++ b/src/careamics/config/architectures/unet_model.py @@ -2,12 +2,14 @@ from typing import Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, field_validator + +from .architecture_model import ArchitectureModel # TODO tests activation <-> pydantic model, test the literals! # TODO annotations for the json schema? -class UNetModel(BaseModel): +class UNetModel(ArchitectureModel): """ Pydantic model for a N2V(2)-compatible UNet. diff --git a/src/careamics/config/architectures/vae_model.py b/src/careamics/config/architectures/vae_model.py index 608bf4f06..03c7eb60b 100644 --- a/src/careamics/config/architectures/vae_model.py +++ b/src/careamics/config/architectures/vae_model.py @@ -1,12 +1,13 @@ from typing import Literal from pydantic import ( - BaseModel, ConfigDict, ) +from .architecture_model import ArchitectureModel -class VAEModel(BaseModel): + +class VAEModel(ArchitectureModel): """VAE model placeholder.""" model_config = ConfigDict( diff --git a/src/careamics/config/configuration_factory.py b/src/careamics/config/configuration_factory.py index 16b5b66c8..012cd65b2 100644 --- a/src/careamics/config/configuration_factory.py +++ b/src/careamics/config/configuration_factory.py @@ -29,10 +29,6 @@ def create_n2n_configuration( use_augmentations: bool = True, use_n2v2: bool = False, n_channels: int = 1, - roi_size: int = 11, - masked_pixel_percentage: float = 0.2, - struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none", - struct_n2v_span: int = 5, logger: Literal["wandb", "tensorboard", "none"] = "none", model_kwargs: Optional[dict] = None, ) -> Configuration: @@ -360,15 +356,13 @@ def create_n2v_configuration( # n2v2 and structn2v nv2_transform = { "name": SupportedTransform.N2V_MANIPULATE.value, - "parameters": { - "strategy": SupportedPixelManipulation.MEDIAN.value - if use_n2v2 - else SupportedPixelManipulation.UNIFORM.value, - "roi_size": roi_size, - "masked_pixel_percentage": masked_pixel_percentage, - "struct_mask_axis": struct_n2v_axis, - "struct_mask_span": struct_n2v_span, - }, + "strategy": SupportedPixelManipulation.MEDIAN.value + if use_n2v2 + else SupportedPixelManipulation.UNIFORM.value, + "roi_size": roi_size, + "masked_pixel_percentage": masked_pixel_percentage, + "struct_mask_axis": struct_n2v_axis, + "struct_mask_span": struct_n2v_span, } transforms.append(nv2_transform) @@ -403,9 +397,7 @@ def create_n2v_configuration( def create_inference_configuration( training_configuration: Configuration, tile_size: Optional[Tuple[int, ...]] = None, - tile_overlap: Tuple[int, ...] = (48, 48), - mean: Optional[float] = None, - std: Optional[float] = None, + tile_overlap: Optional[Tuple[int, ...]] = None, data_type: Optional[Literal["array", "tiff", "custom"]] = None, axes: Optional[str] = None, transforms: Optional[Union[List[Dict[str, Any]], Compose]] = None, @@ -415,7 +407,7 @@ def create_inference_configuration( """ Create a configuration for inference with N2V. - If not provided, `data_type`, `tile_size`, and `axes` are taken from the training + If not provided, `data_type` and `axes` are taken from the training configuration. If `transforms` are not provided, only normalization is applied. Parameters @@ -442,6 +434,12 @@ def create_inference_configuration( InferenceConfiguration Configuration for inference with N2V. """ + if ( + training_configuration.data_config.mean is None + or training_configuration.data_config.std is None + ): + raise ValueError("Mean and std must be provided in the training configuration.") + if transforms is None: transforms = [ { @@ -451,11 +449,11 @@ def create_inference_configuration( return InferenceModel( data_type=data_type or training_configuration.data_config.data_type, - tile_size=tile_size or training_configuration.data_config.patch_size, + tile_size=tile_size, tile_overlap=tile_overlap, axes=axes or training_configuration.data_config.axes, - mean=mean or training_configuration.data_config.mean, - std=std or training_configuration.data_config.std, + mean=training_configuration.data_config.mean, + std=training_configuration.data_config.std, transforms=transforms, tta_transforms=tta_transforms, batch_size=batch_size, diff --git a/src/careamics/config/configuration_model.py b/src/careamics/config/configuration_model.py index 5570ad42b..4c8e5dca8 100644 --- a/src/careamics/config/configuration_model.py +++ b/src/careamics/config/configuration_model.py @@ -7,10 +7,31 @@ from typing import Dict, List, Literal, Union import yaml +from bioimageio.spec.generic.v0_3 import CiteEntry from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from .algorithm_model import AlgorithmModel from .data_model import DataModel +from .references import ( + CARE, + CUSTOM, + N2N, + N2V, + N2V2, + STRUCT_N2V, + STRUCT_N2V2, + CAREDescription, + CARERef, + N2NDescription, + N2NRef, + N2V2Description, + N2V2Ref, + N2VDescription, + N2VRef, + StructN2V2Description, + StructN2VDescription, + StructN2VRef, +) from .support import SupportedAlgorithm, SupportedPixelManipulation, SupportedTransform from .training_model import TrainingModel from .transformations.n2v_manipulate_model import ( @@ -224,8 +245,7 @@ def validate_algorithm_and_data(self: Configuration) -> Configuration: name=SupportedTransform.N2V_MANIPULATE.value, ) ) - # TODO Doesn't validate the parameters of N2VManipulate !! - # make sure that N2V has the correct pixel manipulate strategy + median = SupportedPixelManipulation.MEDIAN.value uniform = SupportedPixelManipulation.UNIFORM.value strategy = median if self.algorithm_config.model.n2v2 else uniform @@ -320,26 +340,30 @@ def get_algorithm_flavour(self) -> str: """ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = ( - self.data_config.transforms[-1].parameters.struct_mask_axis != "none" - ) + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" # return the n2v flavour if use_n2v2 and use_structN2V: - return "StructN2V2" + return STRUCT_N2V2 elif use_n2v2: - return "N2V2" + return N2V2 elif use_structN2V: - return "StructN2V" + return STRUCT_N2V else: - return "Noise2Void" - - return self.algorithm_config.algorithm.capitalize() + return N2V + elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: + return N2N + elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: + return CARE + else: + return CUSTOM def get_algorithm_description(self) -> str: """ Return a description of the algorithm. + This method is used to generate the README of the BioImage Model Zoo export. + Returns ------- str @@ -347,82 +371,61 @@ def get_algorithm_description(self) -> str: """ algorithm_flavour = self.get_algorithm_flavour() - if algorithm_flavour == "Custom": + if algorithm_flavour == CUSTOM: return f"Custom algorithm, named {self.algorithm_config.model.name}" else: # currently only N2V flavours - if algorithm_flavour == "Noise2Void": - return ( - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels)." - ) - elif algorithm_flavour == "N2V2": - return ( - "N2V2 is an iteration of Noise2Void. " - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels). " - "N2V2 introduces blur-pool layers and removed skip connections in " - "the UNet architecture to remove checkboard artefacts, a common " - "artefacts ocurring in Noise2Void." - ) - elif algorithm_flavour == "StructN2V": - return ( - "StructN2V is an iteration of Noise2Void. " - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels). " - "StructN2V uses a linear mask (horizontal or vertical) to replace " - "the pixel values of neighbors of the masked pixels by a random " - "value. Such masking allows removing 1D structured noise from the " - "the images, the main failure case of the original N2V." - ) - elif algorithm_flavour == "StructN2V2": - return ( - "StructN2V2 is an iteration of Noise2Void that uses both " - "structN2V and N2V2 ." - "Noise2Void is a UNet-based self-supervised algorithm that uses " - "blind-spot training to denoise images. In short, in every " - "patches during training, random pixels are selected and their " - "value replaced by a neighboring pixel value. The network is then " - "trained to predict the original pixel value. The algorithm " - "relies on the continuity of the signal (neighboring pixels have " - "similar values) and the pixel-wise independence of the noise " - "(the noise in a pixel is not correlated with the noise in " - "neighboring pixels). " - "StructN2V uses a linear mask (horizontal or vertical) to replace " - "the pixel values of neighbors of the masked pixels by a random " - "value. Such masking allows removing 1D structured noise from the " - "the images, the main failure case of the original N2V." - "N2V2 introduces blur-pool layers and removed skip connections in " - "the UNet architecture to remove checkboard artefacts, a common " - "artefacts ocurring in Noise2Void." - ) + if algorithm_flavour == N2V: + return N2VDescription().description + elif algorithm_flavour == N2V2: + return N2V2Description().description + elif algorithm_flavour == STRUCT_N2V: + return StructN2VDescription().description + elif algorithm_flavour == STRUCT_N2V2: + return StructN2V2Description().description + elif algorithm_flavour == N2N: + return N2NDescription().description + elif algorithm_flavour == CARE: + return CAREDescription().description return "" + def get_algorithm_citations(self) -> List[CiteEntry]: + """ + Return a list of citation entries of the current algorithm. + + This is used to generate the model description for the BioImage Model Zoo. + + Returns + ------- + List[CiteEntry] + List of citation entries. + """ + if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: + use_n2v2 = self.algorithm_config.model.n2v2 + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" + + # return the (struct)N2V(2) references + if use_n2v2 and use_structN2V: + return [N2VRef, N2V2Ref, StructN2VRef] + elif use_n2v2: + return [N2VRef, N2V2Ref] + elif use_structN2V: + return [N2VRef, StructN2VRef] + else: + return [N2VRef] + elif self.algorithm_config.algorithm == SupportedAlgorithm.N2N: + return [N2NRef] + elif self.algorithm_config.algorithm == SupportedAlgorithm.CARE: + return [CARERef] + + raise ValueError("Citation not available for custom algorithm.") + def get_algorithm_references(self) -> str: """ Get the algorithm references. + This is used to generate the README of the BioImage Model Zoo export. + Returns ------- str @@ -430,27 +433,12 @@ def get_algorithm_references(self) -> str: """ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = ( - self.data_config.transforms[-1].parameters.struct_mask_axis != "none" - ) + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" references = [ - 'Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' - 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' - "conference on computer vision and pattern recognition (pp. " - "2129-2137). doi: " - "[10.1109/cvpr.2019.00223](https://doi.org/10.1109/cvpr.2019.00223)\n", - "Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " - '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' - 'sampling strategies and a tweaked network architecture". In European ' - "Conference on Computer Vision (pp. 503-518). doi: " - "[10.1007/978-3-031-25069-9_33](https://doi.org/10.1007/978-3-031-" - "25069-9_33)\n", - "Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020" - '. "Removing structured noise with self-supervised blind-spot ' - 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' - "Imaging (ISBI) (pp. 159-163). doi: [10.1109/isbi45749.2020.9098336](" - "https://doi.org/10.1109/isbi45749.2020.9098336)\n", + N2VRef.text + " doi: " + N2VRef.doi, + N2V2Ref.text + " doi: " + N2V2Ref.doi, + StructN2VRef.text + " doi: " + StructN2VRef.doi, ] # return the (struct)N2V(2) references @@ -478,9 +466,7 @@ def get_algorithm_keywords(self) -> List[str]: """ if self.algorithm_config.algorithm == SupportedAlgorithm.N2V: use_n2v2 = self.algorithm_config.model.n2v2 - use_structN2V = ( - self.data_config.transforms[-1].parameters.struct_mask_axis != "none" - ) + use_structN2V = self.data_config.transforms[-1].struct_mask_axis != "none" keywords = [ "denoising", @@ -489,13 +475,13 @@ def get_algorithm_keywords(self) -> List[str]: "3D" if "Z" in self.data_config.axes else "2D", "CAREamics", "pytorch", - "Noise2Void", + N2V, ] if use_n2v2: - keywords.append("N2V2") + keywords.append(N2V2) if use_structN2V: - keywords.append("StructN2V2") + keywords.append(STRUCT_N2V) else: keywords = ["CAREamics"] @@ -529,9 +515,6 @@ def model_dump( exclude_none=exclude_none, exclude_defaults=exclude_defaults, **kwargs ) - # change Path into str - # dictionary["working_directory"] = str(dictionary["working_directory"]) - return dictionary diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index 8e2782cd1..53255d784 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -2,7 +2,7 @@ from __future__ import annotations from pprint import pformat -from typing import Any, List, Literal, Optional, Union +from typing import Any, List, Literal, Optional, Tuple, Union from albumentations import Compose from pydantic import ( @@ -15,13 +15,12 @@ ) from typing_extensions import Annotated -from careamics.utils import check_axes_validity - from .support import SupportedTransform from .transformations.n2v_manipulate_model import N2VManipulateModel from .transformations.nd_flip_model import NDFlipModel from .transformations.normalize_model import NormalizeModel from .transformations.xy_random_rotate90_model import XYRandomRotate90Model +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 TRANSFORMS_UNION = Annotated[ Union[ @@ -34,8 +33,6 @@ ] -# TODO can we check whether N2V manipulate is in a Compose? -# TODO does patches need to be multiple of 8 with UNet? class DataModel(BaseModel): """ Data configuration. @@ -56,7 +53,7 @@ class DataModel(BaseModel): ... ) To change the mean and std of the data: - >>> data.set_mean_and_std(mean=0., std=1.) + >>> data.set_mean_and_std(mean=214.3, std=84.5) One can pass also a list of transformations, by keyword, using the SupportedTransform or the name of an Albumentation transform: @@ -69,10 +66,13 @@ class DataModel(BaseModel): ... transforms=[ ... { ... "name": SupportedTransform.NORMALIZE.value, + ... "mean": 167.6, + ... "std": 47.2, ... }, ... { ... "name": "NDFlip", - ... "parameters": {"is_3D": True, "flip_Z": True} + ... "is_3D": True, + ... "flip_z": True, ... } ... ] ... ) @@ -86,10 +86,10 @@ class DataModel(BaseModel): # Dataset configuration data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData - patch_size: List[int] = Field(..., min_length=2, max_length=3) - batch_size: int = Field( - default=1, ge=1, validate_default=True - ) # TODO Differentiate based on Train/inf ? + patch_size: Union[List[int], Tuple[int, ...]] = Field( + ..., min_length=2, max_length=3 + ) + batch_size: int = Field(default=1, ge=1, validate_default=True) axes: str # Optional fields @@ -114,41 +114,39 @@ class DataModel(BaseModel): validate_default=True, ) - dataloader_params: Optional[dict] = None # TODO validate ? + dataloader_params: Optional[dict] = None @field_validator("patch_size") @classmethod - def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: + def all_elements_power_of_2_minimum_8( + cls, patch_list: Union[List[int], Tuple[int, ...]] + ) -> Union[List[int], Tuple[int, ...]]: """ Validate patch size. - Patch size must be non-zero, positive and even. + Patch size must be powers of 2 and minimum 8. Parameters ---------- - patch_list : List[int] + patch_list : Union[List[int], Tuple[int, ...]] Patch size. Returns ------- - List[int] + Union[List[int], Tuple[int, ...]] Validated patch size. Raises ------ ValueError - If the patch size is 0. + If the patch size is smaller than 8. ValueError - If the patch size is not even. + If the patch size is not a power of 2. """ - for dim in patch_list: - if dim < 1: - raise ValueError(f"Patch size must be non-zero positive (got {dim}).") + patch_validated = patch_size_ge_than_8_power_of_2(patch_list) + assert patch_validated is not None, "Patch cannot be None." - if dim % 2 != 0: - raise ValueError(f"Patch size must be even (got {dim}).") - - return patch_list + return patch_validated @field_validator("axes") @classmethod @@ -188,7 +186,8 @@ def axes_valid(cls, axes: str) -> str: def validate_prediction_transforms( cls, transforms: Union[List[TRANSFORMS_UNION], Compose] ) -> Union[List[TRANSFORMS_UNION], Compose]: - """Validate N2VManipulate transform position in the transform list. + """ + Validate N2VManipulate transform position in the transform list. Parameters ---------- @@ -273,8 +272,8 @@ def add_std_and_mean_to_normalize(cls, data_model: DataModel) -> DataModel: if data_model.has_transform_list(): for transform in data_model.transforms: if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = data_model.mean - transform.parameters.std = data_model.std + transform.mean = data_model.mean + transform.std = data_model.std return data_model @@ -308,9 +307,9 @@ def validate_dimensions(cls, data_model: DataModel) -> DataModel: if data_model.has_transform_list(): for transform in data_model.transforms: if transform.name == SupportedTransform.NDFLIP: - transform.parameters.is_3D = True + transform.is_3D = True elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90: - transform.parameters.is_3D = True + transform.is_3D = True else: if len(data_model.patch_size) != 2: @@ -322,14 +321,15 @@ def validate_dimensions(cls, data_model: DataModel) -> DataModel: if data_model.has_transform_list(): for transform in data_model.transforms: if transform.name == SupportedTransform.NDFLIP: - transform.parameters.is_3D = False + transform.is_3D = False elif transform.name == SupportedTransform.XY_RANDOM_ROTATE90: - transform.parameters.is_3D = False + transform.is_3D = False return data_model def __str__(self) -> str: - """Pretty string reprensenting the configuration. + """ + Pretty string reprensenting the configuration. Returns ------- @@ -339,7 +339,14 @@ def __str__(self) -> str: return pformat(self.model_dump()) def _update(self, **kwargs: Any) -> None: - """Update multiple arguments at once.""" + """ + Update multiple arguments at once. + + Parameters + ---------- + **kwargs : Any + Keyword arguments to update. + """ self.__dict__.update(kwargs) self.__class__.model_validate(self.__dict__) @@ -442,8 +449,8 @@ def set_mean_and_std(self, mean: float, std: float) -> None: if self.has_transform_list(): for transform in self.transforms: if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = mean - transform.parameters.std = std + transform.mean = mean + transform.std = std else: raise ValueError( "Setting mean and std with Compose transforms is not allowed. Add " @@ -464,7 +471,8 @@ def set_3D(self, axes: str, patch_size: List[int]) -> None: self._update(axes=axes, patch_size=patch_size) def set_N2V2(self, use_n2v2: bool) -> None: - """Set N2V2. + """ + Set N2V2. Parameters ---------- @@ -484,7 +492,8 @@ def set_N2V2(self, use_n2v2: bool) -> None: self.set_N2V2_strategy("uniform") def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: - """Set N2V2 strategy. + """ + Set N2V2 strategy. Parameters ---------- @@ -503,7 +512,7 @@ def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: for transform in self.transforms: if transform.name == SupportedTransform.N2V_MANIPULATE.value: - transform.parameters.strategy = strategy + transform.strategy = strategy found_n2v = True if not found_n2v: @@ -522,7 +531,8 @@ def set_N2V2_strategy(self, strategy: Literal["uniform", "median"]) -> None: def set_structN2V_mask( self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int ) -> None: - """Set structN2V mask parameters. + """ + Set structN2V mask parameters. Setting `mask_axis` to `none` will disable structN2V. @@ -545,8 +555,8 @@ def set_structN2V_mask( for transform in self.transforms: if transform.name == SupportedTransform.N2V_MANIPULATE.value: - transform.parameters.struct_mask_axis = mask_axis - transform.parameters.struct_mask_span = mask_span + transform.struct_mask_axis = mask_axis + transform.struct_mask_span = mask_span found_n2v = True if not found_n2v: diff --git a/src/careamics/config/inference_model.py b/src/careamics/config/inference_model.py index cc5243dc4..4a3a0992f 100644 --- a/src/careamics/config/inference_model.py +++ b/src/careamics/config/inference_model.py @@ -1,3 +1,4 @@ +"""Pydantic model representing CAREamics prediction configuration.""" from __future__ import annotations from typing import Any, List, Literal, Optional, Tuple, Union @@ -5,10 +6,9 @@ from albumentations import Compose from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from careamics.utils import check_axes_validity - from .support import SupportedTransform from .transformations.normalize_model import NormalizeModel +from .validators import check_axes_validity, patch_size_ge_than_8_power_of_2 TRANSFORMS_UNION = Union[NormalizeModel] @@ -20,16 +20,17 @@ class InferenceModel(BaseModel): # Mandatory fields data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData - tile_size: Union[List[int], Tuple[int]] = Field(..., min_length=2, max_length=3) - tile_overlap: List[int] = Field( - default=[48, 48], min_length=2, max_length=3 - ) # TODO Will be calculated automatically in the future + tile_size: Optional[Union[List[int], Tuple[int, ...]]] = Field( + default=None, min_length=2, max_length=3 + ) + tile_overlap: Optional[Union[List[int], Tuple[int, ...]]] = Field( + default=None, min_length=2, max_length=3 + ) axes: str - # Optional fields - mean: Optional[float] = None - std: Optional[float] = None + mean: float + std: float = Field(..., ge=0.0) transforms: Union[List[TRANSFORMS_UNION], Compose] = Field( default=[ @@ -46,9 +47,11 @@ class InferenceModel(BaseModel): # Dataloader parameters batch_size: int = Field(default=1, ge=1) - @field_validator("tile_size", "tile_overlap") + @field_validator("tile_overlap") @classmethod - def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: + def all_elements_non_zero_even( + cls, patch_list: Optional[Union[List[int], Tuple[int, ...]]] + ) -> Optional[Union[List[int], Tuple[int, ...]]]: """ Validate patch size. @@ -56,12 +59,12 @@ def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: Parameters ---------- - patch_list : List[int] + patch_list : Optional[Union[List[int], Tuple[int, ...]]] Patch size. Returns ------- - List[int] + Optional[Union[List[int], Tuple[int, ...]]] Validated patch size. Raises @@ -71,15 +74,45 @@ def all_elements_non_zero_even(cls, patch_list: List[int]) -> List[int]: ValueError If the patch size is not even. """ - for dim in patch_list: - if dim < 1: - raise ValueError(f"Patch size must be non-zero positive (got {dim}).") + if patch_list is not None: + for dim in patch_list: + if dim < 1: + raise ValueError( + f"Patch size must be non-zero positive (got {dim})." + ) - if dim % 2 != 0: - raise ValueError(f"Patch size must be even (got {dim}).") + if dim % 2 != 0: + raise ValueError(f"Patch size must be even (got {dim}).") return patch_list + @field_validator("tile_size") + @classmethod + def tile_min_8_power_of_2( + cls, tile_list: Optional[Union[List[int], Tuple[int, ...]]] + ) -> Optional[Union[List[int], Tuple[int, ...]]]: + """ + Validate that each entry is greater or equal than 8 and a power of 2. + + Parameters + ---------- + tile_list : List[int] + Patch size. + + Returns + ------- + List[int] + Validated patch size. + + Raises + ------ + ValueError + If the patch size if smaller than 8. + ValueError + If the patch size is not a power of 2. + """ + return patch_size_ge_than_8_power_of_2(tile_list) + @field_validator("axes") @classmethod def axes_valid(cls, axes: str) -> str: @@ -118,7 +151,8 @@ def axes_valid(cls, axes: str) -> str: def validate_transforms( cls, transforms: Union[List[TRANSFORMS_UNION], Compose] ) -> Union[List[TRANSFORMS_UNION], Compose]: - """Validate that transforms do not have N2V pixel manipulate transforms. + """ + Validate that transforms do not have N2V pixel manipulate transforms. Parameters ---------- @@ -162,20 +196,23 @@ def validate_dimensions(cls, pred_model: InferenceModel) -> InferenceModel: """ expected_len = 3 if "Z" in pred_model.axes else 2 - if len(pred_model.tile_size) != expected_len: - raise ValueError( - f"Tile size must have {expected_len} dimensions given axes " - f"{pred_model.axes} (got {pred_model.tile_size})." - ) + if pred_model.tile_size is not None and pred_model.tile_overlap is not None: + if len(pred_model.tile_size) != expected_len: + raise ValueError( + f"Tile size must have {expected_len} dimensions given axes " + f"{pred_model.axes} (got {pred_model.tile_size})." + ) - if len(pred_model.tile_overlap) != expected_len: - raise ValueError( - f"Tile overlap must have {expected_len} dimensions given axes " - f"{pred_model.axes} (got {pred_model.tile_overlap})." - ) + if len(pred_model.tile_overlap) != expected_len: + raise ValueError( + f"Tile overlap must have {expected_len} dimensions given axes " + f"{pred_model.axes} (got {pred_model.tile_overlap})." + ) - if any((i >= j) for i, j in zip(pred_model.tile_overlap, pred_model.tile_size)): - raise ValueError("Tile overlap must be smaller than tile size.") + if any( + (i >= j) for i, j in zip(pred_model.tile_overlap, pred_model.tile_size) + ): + raise ValueError("Tile overlap must be smaller than tile size.") return pred_model @@ -229,43 +266,22 @@ def add_std_and_mean_to_normalize( if not isinstance(pred_model.transforms, Compose): for transform in pred_model.transforms: if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = pred_model.mean - transform.parameters.std = pred_model.std + transform.mean = pred_model.mean + transform.std = pred_model.std return pred_model def _update(self, **kwargs: Any) -> None: - """Update multiple arguments at once.""" - self.__dict__.update(kwargs) - self.__class__.model_validate(self.__dict__) - - def set_mean_and_std(self, mean: float, std: float) -> None: """ - Set mean and standard deviation of the data. - - This method should be used instead setting the fields directly, as it would - otherwise trigger a validation error. + Update multiple arguments at once. Parameters ---------- - mean : float - Mean of the data. - std : float - Standard deviation of the data. + **kwargs : Any + Key-value pairs of arguments to update. """ - self._update(mean=mean, std=std) - - # search in the transforms for Normalize and update parameters - if not isinstance(self.transforms, Compose): - for transform in self.transforms: - if transform.name == SupportedTransform.NORMALIZE.value: - transform.parameters.mean = mean - transform.parameters.std = std - else: - raise ValueError( - "Setting mean and std with Compose transforms is not allowed. Add " - "mean and std parameters directly to the transform in the Compose." - ) + self.__dict__.update(kwargs) + self.__class__.model_validate(self.__dict__) def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> None: """ @@ -277,5 +293,7 @@ def set_3D(self, axes: str, tile_size: List[int], tile_overlap: List[int]) -> No Axes. tile_size : List[int] Tile size. + tile_overlap : List[int] + Tile overlap. """ self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap) diff --git a/src/careamics/config/references/__init__.py b/src/careamics/config/references/__init__.py new file mode 100644 index 000000000..b314f1c89 --- /dev/null +++ b/src/careamics/config/references/__init__.py @@ -0,0 +1,45 @@ +"""Module containing references to the algorithm used in CAREamics.""" + +__all__ = [ + "N2V2Ref", + "N2VRef", + "StructN2VRef", + "N2VDescription", + "N2V2Description", + "StructN2VDescription", + "StructN2V2Description", + "N2V", + "N2V2", + "STRUCT_N2V", + "STRUCT_N2V2", + "CUSTOM", + "N2N", + "CARE", + "CAREDescription", + "N2NDescription", + "CARERef", + "N2NRef", +] + +from .algorithm_descriptions import ( + CARE, + CUSTOM, + N2N, + N2V, + N2V2, + STRUCT_N2V, + STRUCT_N2V2, + CAREDescription, + N2NDescription, + N2V2Description, + N2VDescription, + StructN2V2Description, + StructN2VDescription, +) +from .references import ( + CARERef, + N2NRef, + N2V2Ref, + N2VRef, + StructN2VRef, +) diff --git a/src/careamics/config/references/algorithm_descriptions.py b/src/careamics/config/references/algorithm_descriptions.py new file mode 100644 index 000000000..97e4ba83f --- /dev/null +++ b/src/careamics/config/references/algorithm_descriptions.py @@ -0,0 +1,131 @@ +"""Descriptions of the algorithms used in CAREmics.""" +from pydantic import BaseModel + +CUSTOM = "Custom" +N2V = "Noise2Void" +N2V2 = "N2V2" +STRUCT_N2V = "StructN2V" +STRUCT_N2V2 = "StructN2V2" +N2N = "Noise2Noise" +CARE = "CARE" + + +N2V_DESCRIPTION = ( + "Noise2Void is a UNet-based self-supervised algorithm that " + "uses blind-spot training to denoise images. In short, in every " + "patches during training, random pixels are selected and their " + "value replaced by a neighboring pixel value. The network is then " + "trained to predict the original pixel value. The algorithm " + "relies on the continuity of the signal (neighboring pixels have " + "similar values) and the pixel-wise independence of the noise " + "(the noise in a pixel is not correlated with the noise in " + "neighboring pixels)." +) + + +class AlgorithmDescription(BaseModel): + """Description of an algorithm. + + Attributes + ---------- + description : str + Description of the algorithm. + """ + + description: str + + +class N2VDescription(AlgorithmDescription): + """Description of Noise2Void. + + Attributes + ---------- + description : str + Description of Noise2Void. + """ + + description: str = N2V_DESCRIPTION + + +class N2V2Description(AlgorithmDescription): + """Description of N2V2. + + Attributes + ---------- + description : str + Description of N2V2. + """ + + description: str = ( + "N2V2 is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nN2V2 introduces blur-pool layers and removed skip " + "connections in the UNet architecture to remove checkboard " + "artefacts, a common artefacts ocurring in Noise2Void." + ) + + +class StructN2VDescription(AlgorithmDescription): + """Description of StructN2V. + + Attributes + ---------- + description : str + Description of StructN2V. + """ + + description: str = ( + "StructN2V is a variant of Noise2Void. " + + N2V_DESCRIPTION + + "\nStructN2V uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." + ) + + +class StructN2V2Description(AlgorithmDescription): + """Description of StructN2V2. + + Attributes + ---------- + description : str + Description of StructN2V2. + """ + + description: str = ( + "StructN2V2 is a a variant of Noise2Void that uses both " + "structN2V and N2V2. " + + N2V_DESCRIPTION + + "\nStructN2V2 uses a linear mask (horizontal or vertical) to replace " + "the pixel values of neighbors of the masked pixels by a random " + "value. Such masking allows removing 1D structured noise from the " + "the images, the main failure case of the original N2V." + "\nN2V2 introduces blur-pool layers and removed skip connections in " + "the UNet architecture to remove checkboard artefacts, a common " + "artefacts ocurring in Noise2Void." + ) + + +class N2NDescription(AlgorithmDescription): + """Description of Noise2Noise. + + Attributes + ---------- + description : str + Description of Noise2Noise. + """ + + description: str = "Noise2Noise" # TODO + + +class CAREDescription(AlgorithmDescription): + """Description of CARE. + + Attributes + ---------- + description : str + Description of CARE. + """ + + description: str = "CARE" # TODO diff --git a/src/careamics/config/references/references.py b/src/careamics/config/references/references.py new file mode 100644 index 000000000..60c8413f9 --- /dev/null +++ b/src/careamics/config/references/references.py @@ -0,0 +1,38 @@ +"""References for the CAREamics algorithms.""" +from bioimageio.spec.generic.v0_3 import CiteEntry + +N2VRef = CiteEntry( + text='Krull, A., Buchholz, T.O. and Jug, F., 2019. "Noise2Void - Learning ' + 'denoising from single noisy images". In Proceedings of the IEEE/CVF ' + "conference on computer vision and pattern recognition (pp. 2129-2137).", + doi="10.1109/cvpr.2019.00223", +) + +N2V2Ref = CiteEntry( + text="Höck, E., Buchholz, T.O., Brachmann, A., Jug, F. and Freytag, A., " + '2022. "N2V2 - Fixing Noise2Void checkerboard artifacts with modified ' + 'sampling strategies and a tweaked network architecture". In European ' + "Conference on Computer Vision (pp. 503-518).", + doi="10.1007/978-3-031-25069-9_33", +) + +StructN2VRef = CiteEntry( + text="Broaddus, C., Krull, A., Weigert, M., Schmidt, U. and Myers, G., 2020." + '"Removing structured noise with self-supervised blind-spot ' + 'networks". In 2020 IEEE 17th International Symposium on Biomedical ' + "Imaging (ISBI) (pp. 159-163).", + doi="10.1109/isbi45749.2020.9098336", +) + +N2NRef = CiteEntry( + text="Lehtinen, J., Munkberg, J., Hasselgren, J., Laine, S., Karras, T., " + 'Aittala, M. and Aila, T., 2018. "Noise2Noise: Learning image restoration ' + 'without clean data". arXiv preprint arXiv:1803.04189.', + doi="10.48550/arXiv.1803.04189", +) + +CARERef = CiteEntry( + text='Weigert, Martin, et al. "Content-aware image restoration: pushing the ' + 'limits of fluorescence microscopy." Nature methods 15.12 (2018): 1090-1097.', + doi="10.1038/s41592-018-0216-7", +) diff --git a/src/careamics/config/tile_information.py b/src/careamics/config/tile_information.py new file mode 100644 index 000000000..e018e0f16 --- /dev/null +++ b/src/careamics/config/tile_information.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +from typing import Optional, Tuple + +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator + + +class TileInformation(BaseModel): + """ + Pydantic model containing tile information. + + This model is used to represent the information required to stitch back a tile into + a larger image. It is used throughout the prediction pipeline of CAREamics. + """ + + model_config = ConfigDict(validate_default=True) + + array_shape: Tuple[int, ...] + tiled: bool = False + last_tile: bool = False + overlap_crop_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) + stitch_coords: Optional[Tuple[Tuple[int, ...], ...]] = Field(default=None) + + @field_validator("array_shape") + @classmethod + def no_singleton_dimensions(cls, v: Tuple[int, ...]): + """ + Check that the array shape does not have any singleton dimensions. + + Parameters + ---------- + v : Tuple[int, ...] + Array shape to check. + + Returns + ------- + Tuple[int, ...] + The array shape if it does not contain singleton dimensions. + + Raises + ------ + ValueError + If the array shape contains singleton dimensions. + """ + if any(dim == 1 for dim in v): + raise ValueError("Array shape must not contain singleton dimensions.") + return v + + @field_validator("last_tile") + @classmethod + def only_if_tiled(cls, v: bool, values: ValidationInfo): + """ + Check that the last tile flag is only set if tiling is enabled. + + Parameters + ---------- + v : bool + Last tile flag. + values : ValidationInfo + Validation information. + + Returns + ------- + bool + The last tile flag. + """ + if not values.data["tiled"]: + return False + return v + + @field_validator("overlap_crop_coords", "stitch_coords") + @classmethod + def mandatory_if_tiled( + cls, v: Optional[Tuple[int, ...]], values: ValidationInfo + ) -> Optional[Tuple[int, ...]]: + """ + Check that the coordinates are not `None` if tiling is enabled. + + The method also return `None` if tiling is not enabled. + + Parameters + ---------- + v : Optional[Tuple[int, ...]] + Coordinates to check. + values : ValidationInfo + Validation information. + + Returns + ------- + Optional[Tuple[int, ...]] + The coordinates if tiling is enabled, otherwise `None`. + + Raises + ------ + ValueError + If the coordinates are `None` and tiling is enabled. + """ + if values.data["tiled"]: + if v is None: + raise ValueError("Value must be specified if tiling is enabled.") + + return v + else: + return None diff --git a/src/careamics/config/transformations/n2v_manipulate_model.py b/src/careamics/config/transformations/n2v_manipulate_model.py index 07929a084..5e4a03587 100644 --- a/src/careamics/config/transformations/n2v_manipulate_model.py +++ b/src/careamics/config/transformations/n2v_manipulate_model.py @@ -1,15 +1,36 @@ +"""Pydantic model for the N2VManipulate transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import ConfigDict, Field, field_validator +from .transform_model import TransformModel -class N2VManipulationParameters(BaseModel): - """Pydantic model used to validate N2V manipulation parameters.""" + +class N2VManipulateModel(TransformModel): + """ + Pydantic model used to represent N2V manipulation. + + Attributes + ---------- + name : Literal["N2VManipulate"] + Name of the transformation. + roi_size : int + Size of the masking region, by default 11. + masked_pixel_percentage : float + Percentage of masked pixels, by default 0.2. + strategy : Literal["uniform", "median"] + Strategy pixel value replacement, by default "uniform". + struct_mask_axis : Literal["horizontal", "vertical", "none"] + Axis of the structN2V mask, by default "none". + struct_mask_span : int + Span of the structN2V mask, by default 5. + """ model_config = ConfigDict( validate_assignment=True, ) + name: Literal["N2VManipulate"] roi_size: int = Field(default=11, ge=3, le=21) masked_pixel_percentage: float = Field(default=0.2, ge=0.05, le=1.0) strategy: Literal["uniform", "median"] = Field(default="uniform") @@ -19,7 +40,8 @@ class N2VManipulationParameters(BaseModel): @field_validator("roi_size", "struct_mask_span") @classmethod def odd_value(cls, v: int) -> int: - """Validate that the value is odd. + """ + Validate that the value is odd. Parameters ---------- @@ -39,14 +61,3 @@ def odd_value(cls, v: int) -> int: if v % 2 == 0: raise ValueError("Size must be an odd number.") return v - - -class N2VManipulateModel(BaseModel): - """Pydantic model used to represent N2V manipulation.""" - - model_config = ConfigDict( - validate_assignment=True, - ) - - name: Literal["N2VManipulate"] - parameters: N2VManipulationParameters = N2VManipulationParameters() diff --git a/src/careamics/config/transformations/nd_flip_model.py b/src/careamics/config/transformations/nd_flip_model.py index 5c211e9c8..9787a13ac 100644 --- a/src/careamics/config/transformations/nd_flip_model.py +++ b/src/careamics/config/transformations/nd_flip_model.py @@ -1,26 +1,32 @@ +"""Pydantic model for the NDFlip transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field +from .transform_model import TransformModel -class NDFlipParameters(BaseModel): - """Pydantic model used to validate NDFlip parameters.""" - model_config = ConfigDict( - validate_assignment=True, - ) +class NDFlipModel(TransformModel): + """ + Pydantic model used to represent NDFlip transformation. - p: float = Field(default=0.5, ge=0.0, le=1.0) - is_3D: bool = Field(default=False) - flip_z: bool = Field(default=True) - - -class NDFlipModel(BaseModel): - """Pydantic model used to represent NDFlip transformation.""" + Attributes + ---------- + name : Literal["NDFlip"] + Name of the transformation. + p : float + Probability of applying the transformation, by default 0.5. + is_3D : bool + Whether the transformation should be applied in 3D, by default False. + flip_z : bool + Whether to flip the z axis, by default True. + """ model_config = ConfigDict( validate_assignment=True, ) name: Literal["NDFlip"] - parameters: NDFlipParameters = NDFlipParameters() + p: float = Field(default=0.5, ge=0.0, le=1.0) + is_3D: bool = Field(default=False) + flip_z: bool = Field(default=True) diff --git a/src/careamics/config/transformations/normalize_model.py b/src/careamics/config/transformations/normalize_model.py index 6f1de597f..cc156dbf9 100644 --- a/src/careamics/config/transformations/normalize_model.py +++ b/src/careamics/config/transformations/normalize_model.py @@ -1,26 +1,31 @@ +"""Pydantic model for the Normalize transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field +from .transform_model import TransformModel -class NormalizeParameters(BaseModel): - """Pydantic model used to validate Normalize parameters.""" - model_config = ConfigDict( - validate_assignment=True, - ) - - mean: float = Field(default=0.485) # albumentations default - std: float = Field(default=0.229) - max_pixel_value: float = Field(default=1.0, ge=0.0) # TODO explain why +class NormalizeModel(TransformModel): + """ + Pydantic model used to represent Normalize transformation. + The Normalize transform is a zero mean and unit variance transformation. -class NormalizeModel(BaseModel): - """Pydantic model used to represent Normalize transformation.""" + Attributes + ---------- + name : Literal["Normalize"] + Name of the transformation. + mean : float + Mean value for normalization. + std : float + Standard deviation value for normalization. + """ model_config = ConfigDict( validate_assignment=True, ) name: Literal["Normalize"] - parameters: NormalizeParameters = NormalizeParameters() + mean: float = Field(default=0.485) # albumentations defaults + std: float = Field(default=0.229) diff --git a/src/careamics/config/transformations/transform_model.py b/src/careamics/config/transformations/transform_model.py new file mode 100644 index 000000000..ffbc6022d --- /dev/null +++ b/src/careamics/config/transformations/transform_model.py @@ -0,0 +1,44 @@ +"""Parent model for the transforms.""" +from typing import Any, Dict + +from pydantic import BaseModel, ConfigDict + + +class TransformModel(BaseModel): + """ + Pydantic model used to represent a transformation. + + The `model_dump` method is overwritten to exclude the name field. + + Attributes + ---------- + name : str + Name of the transformation. + """ + + model_config = ConfigDict( + extra="forbid", # throw errors if the parameters are not properly passed + ) + + name: str + + def model_dump(self, **kwargs) -> Dict[str, Any]: + """ + Return the model as a dictionary. + + Parameters + ---------- + **kwargs + Pydantic BaseMode model_dump method keyword arguments. + + Returns + ------- + Dict[str, Any] + Dictionary representation of the model. + """ + model_dict = super().model_dump(**kwargs) + + # remove the name field + model_dict.pop("name") + + return model_dict diff --git a/src/careamics/config/transformations/xy_random_rotate90_model.py b/src/careamics/config/transformations/xy_random_rotate90_model.py index d9add88c1..af0cd1422 100644 --- a/src/careamics/config/transformations/xy_random_rotate90_model.py +++ b/src/careamics/config/transformations/xy_random_rotate90_model.py @@ -1,25 +1,29 @@ +"""Pydantic model for the XYRandomRotate90 transform.""" from typing import Literal -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field +from .transform_model import TransformModel -class XYRandomRotate90Parameters(BaseModel): - """Pydantic model used to validate NDFlip parameters.""" - model_config = ConfigDict( - validate_assignment=True, - ) +class XYRandomRotate90Model(TransformModel): + """ + Pydantic model used to represent NDFlip transformation. - p: float = Field(default=0.5, ge=0.0, le=1.0) - is_3D: bool = Field(default=False) - - -class XYRandomRotate90Model(BaseModel): - """Pydantic model used to represent NDFlip transformation.""" + Attributes + ---------- + name : Literal["XYRandomRotate90"] + Name of the transformation. + p : float + Probability of applying the transformation, by default 0.5. + is_3D : bool + Whether the transformation should be applied in 3D, by default False. + """ model_config = ConfigDict( validate_assignment=True, ) name: Literal["XYRandomRotate90"] - parameters: XYRandomRotate90Parameters = XYRandomRotate90Parameters() + p: float = Field(default=0.5, ge=0.0, le=1.0) + is_3D: bool = Field(default=False) diff --git a/src/careamics/config/validators/__init__.py b/src/careamics/config/validators/__init__.py new file mode 100644 index 000000000..53ddbf8db --- /dev/null +++ b/src/careamics/config/validators/__init__.py @@ -0,0 +1,5 @@ +"""Validator utilities.""" + +__all__ = ["check_axes_validity", "patch_size_ge_than_8_power_of_2"] + +from .validator_utils import check_axes_validity, patch_size_ge_than_8_power_of_2 diff --git a/src/careamics/config/validators/validator_utils.py b/src/careamics/config/validators/validator_utils.py new file mode 100644 index 000000000..cb93b475f --- /dev/null +++ b/src/careamics/config/validators/validator_utils.py @@ -0,0 +1,98 @@ +""" +Validator functions. + +These functions are used to validate dimensions and axes of inputs. +""" +from typing import List, Optional, Tuple, Union + +_AXES = "STCZYX" + + +def check_axes_validity(axes: str) -> bool: + """ + Sanity check on axes. + + The constraints on the axes are the following: + - must be a combination of 'STCZYX' + - must not contain duplicates + - must contain at least 2 contiguous axes: X and Y + - must contain at most 4 axes + - cannot contain both S and T axes + + Axes do not need to be in the order 'STCZYX', as this depends on the user data. + + Parameters + ---------- + axes : str + Axes to validate. + + Returns + ------- + bool + True if axes are valid, False otherwise. + """ + _axes = axes.upper() + + # Minimum is 2 (XY) and maximum is 4 (TZYX) + if len(_axes) < 2 or len(_axes) > 6: + raise ValueError( + f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes." + ) + + if "YX" not in _axes and "XY" not in _axes: + raise ValueError( + f"Invalid axes {axes}. Must contain at least X and Y axes consecutively." + ) + + # all characters must be in REF_AXES = 'STCZYX' + if not all(s in _AXES for s in _axes): + raise ValueError(f"Invalid axes {axes}. Must be a combination of {_AXES}.") + + # check for repeating characters + for i, s in enumerate(_axes): + if i != _axes.rfind(s): + raise ValueError( + f"Invalid axes {axes}. Cannot contain duplicate axes" + f" (got multiple {axes[i]})." + ) + + return True + + +def patch_size_ge_than_8_power_of_2( + patch_list: Optional[Union[List[int], Tuple[int, ...]]] +) -> Optional[Union[List[int], Tuple[int, ...]]]: + """ + Validate that each entry is greater or equal than 8 and a power of 2. + + If None is passed, the function will return None. + + Parameters + ---------- + patch_list : Optional[Union[List[int], Tuple[int, ...]]] + Patch size. + + Returns + ------- + Optional[Union[List[int], Tuple[int, ...]]] + Validated patch size. + + Raises + ------ + ValueError + If the patch size if smaller than 8. + ValueError + If the patch size is not a power of 2. + """ + if patch_list is not None: + for dim in patch_list: + if dim < 8: + raise ValueError(f"Patch size must be non-zero positive (got {dim}).") + + if (dim & (dim - 1)) != 0: + raise ValueError( + f"Patch size must be a power of 2 in all dimensions " + f"(got {dim})." + ) + + return patch_list diff --git a/src/careamics/conftest.py b/src/careamics/conftest.py index 929182a07..e0a1fae6c 100644 --- a/src/careamics/conftest.py +++ b/src/careamics/conftest.py @@ -16,7 +16,6 @@ def my_path(tmpdir_factory: TempPathFactory) -> Path: return tmpdir_factory.mktemp("my_path") - pytest_collect_file = Sybil( parsers=[ DocTestParser(), diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py index 2672e0b26..ace44bc9e 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -63,7 +63,10 @@ def reshape_array(x: np.ndarray, axes: str) -> np.ndarray: # sanity checks if len(_axes) != len(_x.shape): - raise ValueError(f"Incompatible data shape ({_x.shape}) and axes ({_axes}).") + raise ValueError( + f"Incompatible data shape ({_x.shape}) and axes ({_axes}). Are the axes " + f"correct?" + ) # get new x shape new_x_shape, new_axes, indices = _get_shape_order(_x.shape, _axes) diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index 620b06345..becf3eaa3 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -9,16 +9,17 @@ from torch.utils.data import Dataset from ..config import DataModel, InferenceModel +from ..config.tile_information import TileInformation from ..utils.logging import get_logger -from .dataset_utils import read_tiff +from .dataset_utils import read_tiff, reshape_array from .patching.patch_transform import get_patch_transform from .patching.patching import ( - generate_patches_predict, prepare_patches_supervised, prepare_patches_supervised_array, prepare_patches_unsupervised, prepare_patches_unsupervised_array, ) +from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) @@ -269,7 +270,7 @@ def split_dataset( return dataset -class InMemoryPredictionDataset(InMemoryDataset): +class InMemoryPredictionDataset(Dataset): """ Dataset storing data in memory and allowing generating patches from it. @@ -302,34 +303,19 @@ def __init__( self.axes = self.pred_config.axes self.tile_size = self.pred_config.tile_size self.tile_overlap = self.pred_config.tile_overlap - self.tiling = self.tile_size and self.tile_overlap self.mean = self.pred_config.mean self.std = self.pred_config.std self.data_target = data_target + # tiling only if both tile size and overlap are provided + self.tiling = self.tile_size is not None and self.tile_overlap is not None + # read function self.read_source_func = read_source_func # Generate patches - tiles = self._prepare_tiles() - - # Add results to members - self.data, computed_mean, computed_std = tiles - - if not self.pred_config.mean or not self.pred_config.std: - self.mean, self.std = computed_mean, computed_std - logger.info(f"Computed dataset mean: {self.mean}, std: {self.std}") - - # if the transforms are not an instance of Compose - if hasattr(self.pred_config, "has_transform_list"): - if self.pred_config.has_transform_list(): - # update mean and std in configuration - # the object is mutable and should then be recorded in the CAREamist - self.pred_config.set_mean_and_std(self.mean, self.std) - else: - self.pred_config.set_mean_and_std(self.mean, self.std) - else: - self.mean, self.std = self.pred_config.mean, self.pred_config.std + self.data = self._prepare_tiles() + self.mean, self.std = self.pred_config.mean, self.pred_config.std # get transforms self.patch_transform = get_patch_transform( @@ -337,25 +323,47 @@ def __init__( with_target=self.data_target is not None, ) - def _prepare_tiles(self) -> Callable: + def _prepare_tiles(self) -> List[Tuple[np.ndarray, TileInformation]]: """ Iterate over data source and create an array of patches. - Calls consecutive function for supervised and unsupervised learning. - Returns ------- - np.ndarray - Array of patches. + List[XArrayTile] + List of tiles. """ + # reshape array + reshaped_sample = reshape_array(self.input_array, self.axes) + if self.tiling: - return generate_patches_predict( - self.input_array, self.axes, self.tile_size, self.tile_overlap - ), self.input_array.mean(), self.input_array.std() + # generate patches, which returns a generator + patch_generator = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, + ) + patches_list = list(patch_generator) + + if len(patches_list) == 0: + raise ValueError("No tiles generated, ") + + return patches_list else: - return self.input_array, self.input_array.mean(), self.input_array.std() + array_shape = reshaped_sample.squeeze().shape + return [(reshaped_sample, TileInformation(array_shape=array_shape))] + + def __len__(self) -> int: + """ + Return the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ + return len(self.data) - def __getitem__(self, index: int) -> Tuple[np.ndarray, Any, Any, Any, Any]: + def __getitem__(self, index: int) -> Tuple[np.ndarray, TileInformation]: """ Return the patch corresponding to the provided index. @@ -366,37 +374,18 @@ def __getitem__(self, index: int) -> Tuple[np.ndarray, Any, Any, Any, Any]: Returns ------- - Tuple[np.ndarray] - Patch. - - Raises - ------ - ValueError - If dataset mean and std are not set. + Tuple[np.ndarray, TileInformation] + Transformed patch. """ - if self.tiling: - ( - tile, - last_tile, - arr_shape, - overlap_crop_coords, - stitch_coords, - ) = self.data[index] + tile_array, tile_info = self.data[index] - # Albumentations requires Channel last - tile = np.moveaxis(tile, 0, -1) + # Albumentations requires channel last, use the XArrayTile array + patch = np.moveaxis(tile_array, 0, -1) - # Apply transforms - transformed_tile = self.patch_transform(image=tile)["image"] - tile = transformed_tile + # Apply transforms + transformed_patch = self.patch_transform(image=patch)["image"] - # move C axes back - tile = np.moveaxis(tile, -1, 0) - - return ( - tile, - last_tile, - arr_shape, - overlap_crop_coords, - stitch_coords, - ) # TODO can we wrap this into an object? + # move C axes back + transformed_patch = np.moveaxis(transformed_patch, -1, 0) + + return transformed_patch, tile_info diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index be8835b46..53df2e618 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -8,15 +8,14 @@ from torch.utils.data import IterableDataset, get_worker_info from ..config import DataModel, InferenceModel -from ..config.support import SupportedExtractionStrategy +from ..config.tile_information import TileInformation from ..utils.logging import get_logger -from .dataset_utils import read_tiff +from .dataset_utils import read_tiff, reshape_array from .patching import ( - generate_patches_predict, - generate_patches_supervised, - generate_patches_unsupervised, get_patch_transform, ) +from .patching.random_patching import extract_patches_random +from .patching.tiled_patching import extract_tiles logger = get_logger(__name__) @@ -175,22 +174,18 @@ def __iter__( # iterate over files for sample_input, sample_target in self._iterate_over_files(): - if self.target_files is not None: - patches = generate_patches_supervised( - data=sample_input, - axes=self.data_config.axes, - patch_extraction_method=SupportedExtractionStrategy.RANDOM, - patch_size=self.data_config.patch_size, - target=sample_target, - ) + reshaped_sample = reshape_array(sample_input, self.data_config.axes) + reshaped_target = ( + None + if sample_target is None + else reshape_array(sample_target, self.data_config.axes) + ) - else: - patches = generate_patches_unsupervised( - data=sample_input, - axes=self.data_config.axes, - patch_extraction_method=SupportedExtractionStrategy.RANDOM, - patch_size=self.data_config.patch_size, - ) + patches = extract_patches_random( + arr=reshaped_sample, + patch_size=self.data_config.patch_size, + target=reshaped_target, + ) # iterate over patches # patches are tuples of (patch, target) if target is available @@ -368,12 +363,8 @@ def __init__( self.tile_overlap = self.prediction_config.tile_overlap self.read_source_func = read_source_func - # check that mean and std are provided - if not self.mean or not self.std: - raise ValueError( - "Mean and std must be provided to the configuration in order to " - " perform prediction." - ) + # tile only if both tile size and overlaps are provided + self.tile = self.tile_size is not None and self.tile_overlap is not None # get tta transforms self.patch_transform = get_patch_transform( @@ -383,7 +374,7 @@ def __init__( def __iter__( self, - ) -> Generator[Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]], None, None]: + ) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: """ Iterate over data source and yield single patch. @@ -397,18 +388,29 @@ def __iter__( ), "Mean and std must be provided" for sample, _ in self._iterate_over_files(): - patches = generate_patches_predict( - data=sample, - axes=self.axes, - tile_size=self.tile_size, - tile_overlap=self.tile_overlap, - ) - # TODO AttributeError: 'IterablePredictionDataset' object has no attribute - # 'mean' message appears if the predict func is run more than once - - for patch_data in patches: - # Albumentations expects the channel dimension to be last - transformed = self.patch_transform( - image=np.moveaxis(patch_data[0], 0, -1) + # reshape array + reshaped_sample = reshape_array(sample, self.axes) + + if self.tile: + # generate patches, return a generator + patch_gen = extract_tiles( + arr=reshaped_sample, + tile_size=self.tile_size, + overlaps=self.tile_overlap, ) - yield (np.moveaxis(transformed["image"], -1, 0), *patch_data[1:]) + else: + # just wrap the sample in a generator with default tiling info + array_shape = reshaped_sample.squeeze().shape + patch_gen = ( + (reshaped_sample, TileInformation(array_shape=array_shape)) + for _ in range(1) + ) + + # apply transform to patches + for patch_array, tile_info in patch_gen: + # albumentations expects the channel dimension to be last + patch = np.moveaxis(patch_array, 0, -1) + transformed_patch = self.patch_transform(image=patch) + transformed_patch = np.moveaxis(transformed_patch["image"], -1, 0) + + yield transformed_patch, tile_info diff --git a/src/careamics/dataset/patching/__init__.py b/src/careamics/dataset/patching/__init__.py index f1ee387c0..c9f5f219e 100644 --- a/src/careamics/dataset/patching/__init__.py +++ b/src/careamics/dataset/patching/__init__.py @@ -2,16 +2,7 @@ __all__ = [ - "generate_patches_predict", - "generate_patches_supervised", - "generate_patches_unsupervised", "get_patch_transform", - "get_patch_transform_predict", ] -from .patch_transform import get_patch_transform, get_patch_transform_predict -from .patching import ( - generate_patches_predict, - generate_patches_supervised, - generate_patches_unsupervised, -) +from .patch_transform import get_patch_transform diff --git a/src/careamics/dataset/patching/patch_transform.py b/src/careamics/dataset/patching/patch_transform.py index 8a092cfc1..15cde203b 100644 --- a/src/careamics/dataset/patching/patch_transform.py +++ b/src/careamics/dataset/patching/patch_transform.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Union +from typing import List, Union import albumentations as Aug @@ -31,7 +31,7 @@ def get_patch_transform( # instantiate all transforms transforms = [ - all_transforms[transform.name](**transform.parameters.model_dump()) + all_transforms[transform.name](**transform.model_dump()) for transform in patch_transforms ] @@ -42,155 +42,3 @@ def get_patch_transform( if (with_target and normalize_mask) # TODO check this else {}, ) - - -# TODO kept here as reference -def _get_patch_transform( - patch_transforms: Union[List[TRANSFORMS_UNION], Aug.Compose], - mean: float, - std: float, - target: bool, - normalize_mask: bool = True, -) -> Aug.Compose: - """Return a pixel manipulation function. - - Used in N2V family of algorithms. - - Parameters - ---------- - patch_transform_type : str - Type of patch transform. - target : bool - Whether the transform is applied to the target(if the target is present). - mode : str - Train or predict mode. - - Returns - ------- - Union[None, Callable] - Patch transform function. - """ - if patch_transforms is None: - return Aug.Compose( - [Aug.NoOp()], - additional_targets={"target": "image"} - if (target and normalize_mask) # TODO why? there is no normalization here? - else {}, - ) - elif isinstance(patch_transforms, list): - patch_transforms[[t["name"] for t in patch_transforms].index("Normalize")][ - "parameters" - ] = { - "mean": mean, - "std": std, - "max_pixel_value": 1, - } - # TODO not very readable - return Aug.Compose( - [ - get_all_transforms()[transform["name"]](**transform["parameters"]) - if "parameters" in transform - else get_all_transforms()[transform["name"]]() - for transform in patch_transforms - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - elif isinstance(patch_transforms, Aug.Compose): - return Aug.Compose( - [ - t - for t in patch_transforms.transforms[:-1] - if not isinstance(t, Aug.Normalize) - ] - + [ - Aug.Normalize(mean=mean, std=std, max_pixel_value=1), - patch_transforms.transforms[-1] - if patch_transforms.transforms[-1].__class__.__name__ == "ManipulateN2V" - else Aug.NoOp(), - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - else: - raise ValueError( - f"Incorrect patch transform type {patch_transforms}. " - f"Please refer to the documentation." # TODO add link to documentation - ) - - -# TODO add tta -def get_patch_transform_predict( - patch_transforms: Union[List, Aug.Compose, None], - mean: float, - std: float, - target: bool, - normalize_mask: bool = True, -) -> Union[None, Callable]: - """Return a pixel manipulation function. - - Used in N2V family of algorithms. - - Parameters - ---------- - patch_transform_type : str - Type of patch transform. - target : bool - Whether the transform is applied to the target(if the target is present). - mode : str - Train or predict mode. - - Returns - ------- - Union[None, Callable] - Patch transform function. - """ - if patch_transforms is None: - return Aug.Compose( - [Aug.NoOp()], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - elif isinstance(patch_transforms, list): - patch_transforms[[t["name"] for t in patch_transforms].index("Normalize")][ - "parameters" - ] = { - "mean": mean, - "std": std, - "max_pixel_value": 1, - } - # TODO not very readable - return Aug.Compose( - [ - get_all_transforms()[transform["name"]](**transform["parameters"]) - if "parameters" in transform - else get_all_transforms()[transform["name"]]() - for transform in patch_transforms - if transform["name"] != "ManipulateN2V" - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - elif isinstance(patch_transforms, Aug.Compose): - return Aug.Compose( - [ - t - for t in patch_transforms.transforms[:-1] - if not isinstance(t, Aug.Normalize) - ] - + [ - Aug.Normalize(mean=mean, std=std, max_pixel_value=1), - ], - additional_targets={"target": "image"} - if (target and normalize_mask) - else {}, - ) - else: - raise ValueError( - f"Incorrect patch transform type {patch_transforms}. " - f"Please refer to the documentation." # TODO add link to documentation - ) diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py index 09dae8427..def9ce25e 100644 --- a/src/careamics/dataset/patching/patching.py +++ b/src/careamics/dataset/patching/patching.py @@ -4,29 +4,17 @@ These functions are used to tile images into patches or tiles. """ from pathlib import Path -from typing import Callable, Generator, List, Optional, Tuple, Union +from typing import Callable, List, Tuple, Union import numpy as np -import zarr -from ...config.support.supported_extraction_strategies import ( - SupportedExtractionStrategy, -) from ...utils.logging import get_logger from ..dataset_utils import reshape_array -from .random_patching import extract_patches_random, extract_patches_random_from_chunks from .sequential_patching import extract_patches_sequential -from .tiled_patching import extract_tiles logger = get_logger(__name__) -# TODO: several issues that require refactoring -# - some patching return array, others generator -# - in iterable and in memory, the reshaping happens at different moment -# - return type is not consistent (ndarray, ndarray or ndarray, None or just ndarray) - - # called by in memory dataset def prepare_patches_supervised( train_files: List[Path], @@ -222,191 +210,3 @@ def prepare_patches_unsupervised_array( patches, _ = extract_patches_sequential(reshaped_sample, patch_size=patch_size) return patches, _, mean, std # TODO inelegant, replace by dataclass? - - -# prediction, both in memory and iterable -def generate_patches_predict( - data: np.ndarray, - axes: str, - tile_size: Union[List[int], Tuple[int, ...]], - tile_overlap: Union[List[int], Tuple[int, ...]], -) -> List[Tuple[np.ndarray, bool, Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]]: - """ - Iterate over data source and create an array of patches. - - Returns - ------- - np.ndarray - Array of patches. - """ - # reshape array - reshaped_sample = reshape_array(data, axes) - # generate patches, return a generator - patches = extract_tiles( - arr=reshaped_sample, tile_size=tile_size, overlaps=tile_overlap - ) - patches_list = list(patches) # TODO: refactor to use generator ? - if len(patches_list) == 0: - raise ValueError("No tiles generated") - - return patches_list - -# iterator over files -def generate_patches_supervised( - data: Union[np.ndarray, zarr.Array], - axes: str, - patch_extraction_method: SupportedExtractionStrategy, - patch_size: Union[List[int], Tuple[int, ...]], - patch_overlap: Optional[Union[List[int], Tuple[int, ...]]] = None, - target: Optional[Union[np.ndarray, zarr.Array]] = None, -) -> Generator[np.ndarray, None, None]: - """ - Creates an iterator with patches and corresponding targets from a sample. - - Parameters - ---------- - sample : np.ndarray - Input array. - patch_extraction_method : ExtractionStrategies - Patch extraction method, as defined in extraction_strategy.ExtractionStrategy. - patch_size : Optional[Union[List[int], Tuple[int]]] - Size of the patches along each dimension of the array, except the first. - patch_overlap : Optional[Union[List[int], Tuple[int]]] - Overlap between patches. - - Returns - ------- - Generator[np.ndarray, None, None] - Generator yielding patches/tiles. - - Raises - ------ - ValueError - If overlap is not specified when using tiling. - ValueError - If patches is None. - """ - patches = None - targets = None - - # reshape target - reshaped_sample = reshape_array(data, axes) - reshaped_target = reshape_array(target, axes) - - if patch_size is not None: - patches = None - - if patch_extraction_method == SupportedExtractionStrategy.TILED: - if patch_overlap is None: - raise ValueError( - "Overlaps must be specified when using tiling (got None)." - ) - patches = extract_tiles( - arr=reshaped_sample, tile_size=patch_size, overlaps=patch_overlap - ) - - elif patch_extraction_method == SupportedExtractionStrategy.SEQUENTIAL: - patches, targets = extract_patches_sequential( - arr=reshaped_sample, patch_size=patch_size, target=reshaped_target - ) - - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM: - # Returns a generator of patches and targets(if present) - patches = extract_patches_random( - arr=reshaped_sample, patch_size=patch_size, target=reshaped_target - ) - - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM_ZARR: - # Returns a generator of patches and targets(if present) - patches = extract_patches_random_from_chunks( - reshaped_sample, - patch_size=patch_size, - chunk_size=reshaped_sample.chunks, - ) - - if patch_extraction_method == SupportedExtractionStrategy.SEQUENTIAL: - return patches, targets - else: - return patches - - else: - # no patching - return (reshaped_sample for _ in range(1)), reshaped_target - - -# iterator over files -def generate_patches_unsupervised( - data: Union[np.ndarray, zarr.Array], - axes: str, - patch_extraction_method: SupportedExtractionStrategy, - patch_size: Union[List[int], Tuple[int, ...]], - patch_overlap: Optional[Union[List[int], Tuple[int]]] = None, -) -> Generator[np.ndarray, None, None]: - """ - Creates an iterator over patches from a sample. - - Parameters - ---------- - sample : np.ndarray - Input array. - patch_extraction_method : SupportedExtractionStrategy - Patch extraction methods (see `config.support`). - patch_size : Optional[Union[List[int], Tuple[int]]] - Size of the patches along each dimension of the array, except the first. - patch_overlap : Optional[Union[List[int], Tuple[int]]] - Overlap between patches. - - Returns - ------- - Generator[np.ndarray, None, None] - Generator yielding patches/tiles. - - Raises - ------ - ValueError - If overlap is not specified when using tiling. - ValueError - If patches is None. - """ - # reshape array - reshaped_sample = reshape_array(data, axes) - - # if tiled (patches with overlaps) - if patch_extraction_method == SupportedExtractionStrategy.TILED: - if patch_overlap is None: - patch_overlap = [48] * len(patch_size) # TODO pass overlap instead - - # return a Generator of the following: - # - patch: np.ndarray, dims C(Z)YX - # - last_tile: bool - # - shape: Tuple[int], shape of a tile, excluding the S dimension - # - overlap_crop_coords: coordinates used to crop the patch during stitching - # - stitch_coords: coordinates used to stitch the tiles back to the full image - patches = extract_tiles( - arr=reshaped_sample, tile_size=patch_size, overlaps=patch_overlap - ) - - # random extraction - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM: - # return a Generator that yields the following: - # - patch: np.ndarray, dimension C(Z)YX - # - target_patch: np.ndarray, dimension C(Z)YX, or None - patches = extract_patches_random(reshaped_sample, patch_size=patch_size) - - # zarr specific random extraction - elif patch_extraction_method == SupportedExtractionStrategy.RANDOM_ZARR: - # # Returns a generator of patches and targets(if present) - # patches = extract_patches_random_from_chunks( - # sample, patch_size=patch_size, chunk_size=sample.chunks - # ) - raise NotImplementedError("Random zarr extraction not implemented yet.") - - # no patching, return sample - elif patch_extraction_method == SupportedExtractionStrategy.NONE: - patches = (reshaped_sample for _ in range(1)) - - # no extraction method - else: - raise ValueError("Invalid patch extraction method.") - - return patches diff --git a/src/careamics/dataset/patching/tiled_patching.py b/src/careamics/dataset/patching/tiled_patching.py index c7c4f69e5..bd618daeb 100644 --- a/src/careamics/dataset/patching/tiled_patching.py +++ b/src/careamics/dataset/patching/tiled_patching.py @@ -3,10 +3,12 @@ import numpy as np +from careamics.config.tile_information import TileInformation + def _compute_crop_and_stitch_coords_1d( axis_size: int, tile_size: int, overlap: int -) -> Tuple[List[Tuple[int, int]], ...]: +) -> Tuple[List[Tuple[int, ...]], ...]: """ Compute the coordinates of each tile along an axis, given the overlap. @@ -21,7 +23,7 @@ def _compute_crop_and_stitch_coords_1d( Returns ------- - Tuple[Tuple[int]] + Tuple[Tuple[int, ...], ...] Tuple of all coordinates for given axis. """ # Compute the step between tiles @@ -75,24 +77,16 @@ def extract_tiles( arr: np.ndarray, tile_size: Union[List[int], Tuple[int, ...]], overlaps: Union[List[int], Tuple[int, ...]], -) -> Generator[ - Tuple[np.ndarray, bool, Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]], - None, - None, -]: +) -> Generator[Tuple[np.ndarray, TileInformation], None, None]: """ Generate tiles from the input array with specified overlap. - The tiles cover the whole array. The method returns a generator that yields a tuple - containing the following: + The tiles cover the whole array. The method returns a generator that yields + tuples of array and tile information, the latter includes whether + the tile is the last one, the coordinates of the overlap crop, and the coordinates + of the stitched tile. - - tile: np.ndarray, shape (C, (Z), Y, X). - - last_tile: bool, whether this is the last tile. - - shape: Tuple[int], shape of a tile, excluding the S dimension. - - overlap_crop_coords: Tuple[int], coordinates used to crop the tile during - stitching. - - stitch_coords: Tuple[int], coordinates used to stitch the tiles back to the full - image. + The array has shape C(Z)YX, where C can be a singleton. Parameters ---------- @@ -105,9 +99,8 @@ def extract_tiles( Yields ------ - Generator[Tuple[np.ndarray, bool, Tuple[int], np.ndarray, np.ndarray], None, None] - Generator of tuple containing the tile, last tile boolean, array shape, - overlap_crop_coords, and stitching coords. + Generator[Tuple[np.ndarray, TileInformation], None, None] + Tile generator, yields the tile and additional information. """ # Iterate over num samples (S) for sample_idx in range(arr.shape[0]): @@ -153,10 +146,13 @@ def extract_tiles( else: last_tile = False - yield ( - tile.astype(np.float32), - last_tile, - arr.shape[1:], # TODO is this used anywhere?? - overlap_crop_coords, - stitch_coords, + # create tile information + tile_info = TileInformation( + array_shape=sample.squeeze().shape, + tiled=True, + last_tile=last_tile, + overlap_crop_coords=overlap_crop_coords, + stitch_coords=stitch_coords, ) + + yield tile, tile_info diff --git a/src/careamics/dataset/zarr_dataset.py b/src/careamics/dataset/zarr_dataset.py index 751d8bdd5..ee54fdd26 100644 --- a/src/careamics/dataset/zarr_dataset.py +++ b/src/careamics/dataset/zarr_dataset.py @@ -1,150 +1,149 @@ -from itertools import islice -from typing import Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import zarr - -from careamics.utils import RunningStats -from careamics.utils.logging import get_logger - -from ..config.support.supported_extraction_strategies import SupportedExtractionStrategy -from ..utils import normalize -from .dataset_utils import read_zarr -from .patching.patching import ( - generate_patches_unsupervised, -) - -logger = get_logger(__name__) - - -class ZarrDataset(torch.utils.data.IterableDataset): - """Dataset to extract patches from a zarr storage. - - Parameters - ---------- - data_source : Union[zarr.Group, zarr.Array] - Zarr storage. - axes : str - Description of axes in format STCZYX. - patch_extraction_method : Union[ExtractionStrategies, None] - Patch extraction strategy, as defined in extraction_strategy. - patch_size : Optional[Union[List[int], Tuple[int]]], optional - Size of the patches in each dimension, by default None. - num_patches : Optional[int], optional - Number of patches to extract, by default None. - mean : Optional[float], optional - Expected mean of the dataset, by default None. - std : Optional[float], optional - Expected standard deviation of the dataset, by default None. - patch_transform : Optional[Callable], optional - Patch transform callable, by default None. - patch_transform_params : Optional[Dict], optional - Patch transform parameters, by default None. - running_stats_window_perc : float, optional - Percentage of the dataset to use for calculating the initial mean and standard - deviation, by default 0.01. - mode : str, optional - train/predict, controls running stats calculation. - """ - - def __init__( - self, - data_source: Union[zarr.Group, zarr.Array], - axes: str, - patch_extraction_method: Union[SupportedExtractionStrategy, None], - patch_size: Optional[Union[List[int], Tuple[int]]] = None, - num_patches: Optional[int] = None, - mean: Optional[float] = None, - std: Optional[float] = None, - patch_transform: Optional[Callable] = None, - patch_transform_params: Optional[Dict] = None, - running_stats_window_perc: float = 0.01, - mode: str = "train", - ) -> None: - self.data_source = data_source - self.axes = axes - self.patch_extraction_method = patch_extraction_method - self.patch_size = patch_size - self.num_patches = num_patches - self.mean = mean - self.std = std - self.patch_transform = patch_transform - self.patch_transform_params = patch_transform_params - self.sample = read_zarr(self.data_source, self.axes) - self.running_stats_window = int( - np.prod(self.sample._cdata_shape) * running_stats_window_perc - ) - self.mode = mode - self.running_stats = RunningStats() - - self._calculate_initial_mean_std() - - def _calculate_initial_mean_std(self): - """Calculate initial mean and std of the dataset.""" - if self.mean is None and self.std is None: - idxs = np.random.randint( - 0, - np.prod(self.sample._cdata_shape), - size=max(1, self.running_stats_window), - ) - random_chunks = self.sample[idxs] - self.running_stats.init(random_chunks.mean(), random_chunks.std()) - - def _generate_patches(self): - """Generate patches from the dataset and calculates running stats. - - Yields - ------ - np.ndarray - Patch. - """ - patches = generate_patches_unsupervised( - self.sample, - self.patch_extraction_method, - self.patch_size, - ) - - # num_patches = np.ceil( - # np.prod(self.sample.chunks) - # / (np.prod(self.patch_size) * self.running_stats_window) - # ).astype(int) - - for idx, patch in enumerate(patches): - if self.mode != "predict": - self.running_stats.update(patch.mean()) - if isinstance(patch, tuple): - normalized_patch = normalize( - img=patch[0], - mean=self.running_stats.avg_mean.value, - std=self.running_stats.avg_std.value, - ) - patch = (normalized_patch, *patch[1:]) - else: - patch = normalize( - img=patch, - mean=self.running_stats.avg_mean.value, - std=self.running_stats.avg_std.value, - ) - - if self.patch_transform is not None: - assert self.patch_transform_params is not None - patch = self.patch_transform(patch, **self.patch_transform_params) - if self.num_patches is not None and idx >= self.num_patches: - return - else: - yield patch - self.mean = self.running_stats.avg_mean.value - self.std = self.running_stats.avg_std.value - - def __iter__(self): - """ - Iterate over data source and yield single patch. - - Yields - ------ - np.ndarray - """ - worker_info = torch.utils.data.get_worker_info() - num_workers = worker_info.num_workers if worker_info is not None else 1 - yield from islice(self._generate_patches(), 0, None, num_workers) +# from itertools import islice +# from typing import Callable, Dict, List, Optional, Tuple, Union + +# import numpy as np +# import torch +# import zarr + +# from careamics.utils import RunningStats +# from careamics.utils.logging import get_logger + +# from ..utils import normalize +# from .dataset_utils import read_zarr +# from .patching.patching import ( +# generate_patches_unsupervised, +# ) + +# logger = get_logger(__name__) + + +# class ZarrDataset(torch.utils.data.IterableDataset): +# """Dataset to extract patches from a zarr storage. + +# Parameters +# ---------- +# data_source : Union[zarr.Group, zarr.Array] +# Zarr storage. +# axes : str +# Description of axes in format STCZYX. +# patch_extraction_method : Union[ExtractionStrategies, None] +# Patch extraction strategy, as defined in extraction_strategy. +# patch_size : Optional[Union[List[int], Tuple[int]]], optional +# Size of the patches in each dimension, by default None. +# num_patches : Optional[int], optional +# Number of patches to extract, by default None. +# mean : Optional[float], optional +# Expected mean of the dataset, by default None. +# std : Optional[float], optional +# Expected standard deviation of the dataset, by default None. +# patch_transform : Optional[Callable], optional +# Patch transform callable, by default None. +# patch_transform_params : Optional[Dict], optional +# Patch transform parameters, by default None. +# running_stats_window_perc : float, optional +# Percentage of the dataset to use for calculating the initial mean and standard +# deviation, by default 0.01. +# mode : str, optional +# train/predict, controls running stats calculation. +# """ + +# def __init__( +# self, +# data_source: Union[zarr.Group, zarr.Array], +# axes: str, +# patch_extraction_method: Union[SupportedExtractionStrategy, None], +# patch_size: Optional[Union[List[int], Tuple[int]]] = None, +# num_patches: Optional[int] = None, +# mean: Optional[float] = None, +# std: Optional[float] = None, +# patch_transform: Optional[Callable] = None, +# patch_transform_params: Optional[Dict] = None, +# running_stats_window_perc: float = 0.01, +# mode: str = "train", +# ) -> None: +# self.data_source = data_source +# self.axes = axes +# self.patch_extraction_method = patch_extraction_method +# self.patch_size = patch_size +# self.num_patches = num_patches +# self.mean = mean +# self.std = std +# self.patch_transform = patch_transform +# self.patch_transform_params = patch_transform_params +# self.sample = read_zarr(self.data_source, self.axes) +# self.running_stats_window = int( +# np.prod(self.sample._cdata_shape) * running_stats_window_perc +# ) +# self.mode = mode +# self.running_stats = RunningStats() + +# self._calculate_initial_mean_std() + +# def _calculate_initial_mean_std(self): +# """Calculate initial mean and std of the dataset.""" +# if self.mean is None and self.std is None: +# idxs = np.random.randint( +# 0, +# np.prod(self.sample._cdata_shape), +# size=max(1, self.running_stats_window), +# ) +# random_chunks = self.sample[idxs] +# self.running_stats.init(random_chunks.mean(), random_chunks.std()) + +# def _generate_patches(self): +# """Generate patches from the dataset and calculates running stats. + +# Yields +# ------ +# np.ndarray +# Patch. +# """ +# patches = generate_patches_unsupervised( +# self.sample, +# self.patch_extraction_method, +# self.patch_size, +# ) + +# # num_patches = np.ceil( +# # np.prod(self.sample.chunks) +# # / (np.prod(self.patch_size) * self.running_stats_window) +# # ).astype(int) + +# for idx, patch in enumerate(patches): +# if self.mode != "predict": +# self.running_stats.update(patch.mean()) +# if isinstance(patch, tuple): +# normalized_patch = normalize( +# img=patch[0], +# mean=self.running_stats.avg_mean.value, +# std=self.running_stats.avg_std.value, +# ) +# patch = (normalized_patch, *patch[1:]) +# else: +# patch = normalize( +# img=patch, +# mean=self.running_stats.avg_mean.value, +# std=self.running_stats.avg_std.value, +# ) + +# if self.patch_transform is not None: +# assert self.patch_transform_params is not None +# patch = self.patch_transform(patch, **self.patch_transform_params) +# if self.num_patches is not None and idx >= self.num_patches: +# return +# else: +# yield patch +# self.mean = self.running_stats.avg_mean.value +# self.std = self.running_stats.avg_std.value + +# def __iter__(self): +# """ +# Iterate over data source and yield single patch. + +# Yields +# ------ +# np.ndarray +# """ +# worker_info = torch.utils.data.get_worker_info() +# num_workers = worker_info.num_workers if worker_info is not None else 1 +# yield from islice(self._generate_patches(), 0, None, num_workers) diff --git a/src/careamics/lightning_datamodule.py b/src/careamics/lightning_datamodule.py index eded7fed4..56f48d47a 100644 --- a/src/careamics/lightning_datamodule.py +++ b/src/careamics/lightning_datamodule.py @@ -6,7 +6,7 @@ from albumentations import Compose from torch.utils.data import DataLoader -from careamics.config import DataModel, InferenceModel +from careamics.config import DataModel from careamics.config.data_model import TRANSFORMS_UNION from careamics.config.support import SupportedData from careamics.dataset.dataset_utils import ( @@ -17,16 +17,13 @@ ) from careamics.dataset.in_memory_dataset import ( InMemoryDataset, - InMemoryPredictionDataset, ) from careamics.dataset.iterable_dataset import ( - IterablePredictionDataset, PathIterableDataset, ) from careamics.utils import get_logger, get_ram_size DatasetType = Union[InMemoryDataset, PathIterableDataset] -PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset] logger = get_logger(__name__) @@ -189,15 +186,14 @@ def prepare_data(self) -> None: Hook used to prepare the data before calling `setup`. Here, we only need to examine the data if it was provided as a str or a Path. - - TODO: from lightning doc: + + TODO: from lightning doc: prepare_data is called from the main process. It is not recommended to assign state here (e.g. self.x = y) since it is called on a single process and if you - assign states here then they won’t be available for other processes. + assign states here then they won't be available for other processes. https://lightning.ai/docs/pytorch/stable/data/datamodule.html """ - # if the data is a Path or a str if ( not isinstance(self.train_data, np.ndarray) @@ -357,181 +353,6 @@ def val_dataloader(self) -> Any: ) -class CAREamicsClay(L.LightningDataModule): - """ - LightningDataModule for prediction dataset. - - The data module can be used with Path, str or numpy arrays. The data can be either - a folder containing images or a single file. - - To read custom data types, you can set `data_type` to `custom` in `data_config` - and provide a function that returns a numpy array from a path as - `read_source_func` parameter. The function will receive a Path object and - an axies string as arguments, the axes being derived from the `data_config`. - - You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. - "*.czi") to filter the files extension using `extension_filter`. - - Parameters - ---------- - prediction_config : InferenceModel - Pydantic model for CAREamics prediction configuration. - pred_data : Union[Path, str, np.ndarray] - Prediction data, can be a path to a folder, a file or a numpy array. - read_source_func : Optional[Callable], optional - Function to read custom types, by default None. - extension_filter : str, optional - Filter to filter file extensions for custom types, by default "". - dataloader_params : dict, optional - Dataloader parameters, by default {}. - """ - - def __init__( - self, - prediction_config: InferenceModel, - pred_data: Union[Path, str, np.ndarray], - read_source_func: Optional[Callable] = None, - extension_filter: str = "", - dataloader_params: Optional[dict] = None, - ) -> None: - """ - Constructor. - - The data module can be used with Path, str or numpy arrays. The data can be - either a folder containing images or a single file. - - To read custom data types, you can set `data_type` to `custom` in `data_config` - and provide a function that returns a numpy array from a path as - `read_source_func` parameter. The function will receive a Path object and - an axies string as arguments, the axes being derived from the `data_config`. - - You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. - "*.czi") to filter the files extension using `extension_filter`. - - Parameters - ---------- - prediction_config : InferenceModel - Pydantic model for CAREamics prediction configuration. - pred_data : Union[Path, str, np.ndarray] - Prediction data, can be a path to a folder, a file or a numpy array. - read_source_func : Optional[Callable], optional - Function to read custom types, by default None. - extension_filter : str, optional - Filter to filter file extensions for custom types, by default "". - dataloader_params : dict, optional - Dataloader parameters, by default {}. - - Raises - ------ - ValueError - If the data type is `custom` and no `read_source_func` is provided. - ValueError - If the data type is `array` and the input is not a numpy array. - ValueError - If the data type is `tiff` and the input is neither a Path nor a str. - """ - if dataloader_params is None: - dataloader_params = {} - if dataloader_params is None: - dataloader_params = {} - super().__init__() - - # check that a read source function is provided for custom types - if ( - prediction_config.data_type == SupportedData.CUSTOM - and read_source_func is None - ): - raise ValueError( - f"Data type {SupportedData.CUSTOM} is not allowed without " - f"specifying a `read_source_func`." - ) - - # and that arrays are passed, if array type specified - elif prediction_config.data_type == SupportedData.ARRAY and not isinstance( - pred_data, np.ndarray - ): - raise ValueError( - f"Expected array input (see configuration.data.data_type), but got " - f"{type(pred_data)} instead." - ) - - # and that Path or str are passed, if tiff file type specified - elif prediction_config.data_type == SupportedData.TIFF and not ( - isinstance(pred_data, Path) or isinstance(pred_data, str) - ): - raise ValueError( - f"Expected Path or str input (see configuration.data.data_type), " - f"but got {type(pred_data)} instead." - ) - - # configuration data - self.prediction_config = prediction_config - self.data_type = prediction_config.data_type - self.batch_size = prediction_config.batch_size - self.dataloader_params = dataloader_params - - self.pred_data = pred_data - self.tile_size = prediction_config.tile_size - self.tile_overlap = prediction_config.tile_overlap - - # read source function - if prediction_config.data_type == SupportedData.CUSTOM: - # mypy check - assert read_source_func is not None - - self.read_source_func: Callable = read_source_func - elif prediction_config.data_type != SupportedData.ARRAY: - self.read_source_func = get_read_func(prediction_config.data_type) - - self.extension_filter = extension_filter - - def prepare_data(self) -> None: - """Hook used to prepare the data before calling `setup`.""" - # if the data is a Path or a str - if not isinstance(self.pred_data, np.ndarray): - self.pred_files = list_files( - self.pred_data, self.data_type, self.extension_filter - ) - - def setup(self, stage: Optional[str] = None) -> None: - """ - Hook called at the beginning of predict. - - Parameters - ---------- - stage : Optional[str], optional - Stage, by default None. - """ - # if numpy array - if self.data_type == SupportedData.ARRAY: - # prediction dataset - self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset( - prediction_config=self.prediction_config, - inputs=self.pred_data, - ) - else: - self.predict_dataset = IterablePredictionDataset( - prediction_config=self.prediction_config, - src_files=self.pred_files, - read_source_func=self.read_source_func, - ) - - def predict_dataloader(self) -> DataLoader: - """ - Create a dataloader for prediction. - - Returns - ------- - DataLoader - Prediction dataloader. - """ - return DataLoader( - self.predict_dataset, - batch_size=self.batch_size, - **self.dataloader_params, - ) # TODO check workers are used - - class CAREamicsTrainDataModule(CAREamicsWood): """ LightningDataModule wrapper for training and validation datasets. @@ -624,11 +445,11 @@ class CAREamicsTrainDataModule(CAREamicsWood): Create a CAREamicsTrainDataModule with default transforms with a numpy array: >>> import numpy as np >>> from careamics import CAREamicsTrainDataModule - >>> my_array = np.arange(100).reshape(10, 10) + >>> my_array = np.arange(256).reshape(16, 16) >>> data_module = CAREamicsTrainDataModule( ... train_data=my_array, ... data_type="array", - ... patch_size=(2, 2), + ... patch_size=(8, 8), ... axes='YX', ... batch_size=2, ... ) @@ -637,12 +458,14 @@ class CAREamicsTrainDataModule(CAREamicsWood): function and a filter for the files extension: >>> import numpy as np >>> from careamics import CAREamicsTrainDataModule + >>> >>> def read_npy(path): ... return np.load(path) + >>> >>> data_module = CAREamicsTrainDataModule( ... train_data="path/to/data", ... data_type="custom", - ... patch_size=(2, 2), + ... patch_size=(8, 8), ... axes='YX', ... batch_size=2, ... read_source_func=read_npy, @@ -654,11 +477,12 @@ class CAREamicsTrainDataModule(CAREamicsWood): >>> import numpy as np >>> from careamics import CAREamicsTrainDataModule >>> from careamics.config.support import SupportedTransform - >>> my_array = np.arange(100).reshape(10, 10) + >>> my_array = np.arange(256).reshape(16, 16) >>> my_transforms = [ ... { ... "name": SupportedTransform.NORMALIZE.value, - ... "parameters": {"mean": 0, "std": 1}, + ... "mean": 0, + ... "std": 1, ... }, ... { ... "name": SupportedTransform.N2V_MANIPULATE.value, @@ -667,7 +491,7 @@ class CAREamicsTrainDataModule(CAREamicsWood): >>> data_module = CAREamicsTrainDataModule( ... train_data=my_array, ... data_type="array", - ... patch_size=(2, 2), + ... patch_size=(8, 8), ... axes='YX', ... batch_size=2, ... transforms=my_transforms, @@ -839,155 +663,3 @@ def __init__( val_minimum_split=val_minimum_patches, use_in_memory=use_in_memory, ) - - -class CAREamicsPredictDataModule(CAREamicsClay): - """ - LightningDataModule wrapper of an inference dataset. - - Since the lightning datamodule has no access to the model, make sure that the - parameters passed to the datamodule are consistent with the model's requirements - and are coherent. - - The data module can be used with Path, str or numpy arrays. To use array data, set - `data_type` to `array` and pass a numpy array to `train_data`. - - The default transformations applied to the images are defined in - `careamics.config.inference_model`. To use different transformations, pass a list - of transforms or an albumentation `Compose` as `transforms` parameter. See examples - for more details. - - The `mean` and `std` parameters are only used if Normalization is defined either - in the default transformations or in the `transforms` parameter, but not with - a `Compose` object. If you pass a `Normalization` transform in a list as - `transforms`, then the mean and std parameters will be overwritten by those passed - to this method. - - By default, CAREamics only supports types defined in - `careamics.config.support.SupportedData`. To read custom data types, you can set - `data_type` to `custom` and provide a function that returns a numpy array from a - path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression - (e.g. "*.jpeg") to filter the files extension using `extension_filter`. - - In `dataloader_params`, you can pass any parameter accepted by PyTorch - dataloaders, except for `batch_size`, which is set by the `batch_size` - parameter. - - Parameters - ---------- - pred_data : Union[str, Path, np.ndarray] - Prediction data. - data_type : Union[Literal["array", "tiff", "custom"], SupportedData] - Data type, see `SupportedData` for available options. - tile_size : List[int] - Tile size, 2D or 3D tile size. - tile_overlap : List[int] - Tile overlap, 2D or 3D tile overlap. - axes : str - Axes of the data, choosen amongst SCZYX. - batch_size : int - Batch size. - tta_transforms : bool, optional - Use test time augmentation, by default True. - mean : Optional[float], optional - Mean value for normalization, only used if Normalization is defined, by - default None. - std : Optional[float], optional - Standard deviation value for normalization, only used if Normalization is - defined, by default None. - transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional - List of transforms to apply to prediction patches. If None, default - transforms are applied. - read_source_func : Optional[Callable], optional - Function to read the source data, used if `data_type` is `custom`, by - default None. - extension_filter : str, optional - Filter for file extensions, used if `data_type` is `custom`, by default "". - dataloader_params : dict, optional - Pytorch dataloader parameters, by default {}. - """ - - def __init__( - self, - pred_data: Union[str, Path, np.ndarray], - data_type: Union[Literal["array", "tiff", "custom"], SupportedData], - tile_size: List[int], - tile_overlap: List[int] = (48, 48), # TODO replace with calculator - axes: str = "YX", - batch_size: int = 1, - tta_transforms: bool = True, - mean: Optional[float] = None, - std: Optional[float] = None, - transforms: Optional[Union[List, Compose]] = None, - read_source_func: Optional[Callable] = None, - extension_filter: str = "", - dataloader_params: Optional[dict] = None, - ) -> None: - """ - Constructor. - - Parameters - ---------- - pred_data : Union[str, Path, np.ndarray] - Prediction data. - data_type : Union[Literal["array", "tiff", "custom"], SupportedData] - Data type, see `SupportedData` for available options. - tile_size : List[int] - Tile size, 2D or 3D tile size. - tile_overlap : List[int] - Tile overlap, 2D or 3D tile overlap. - axes : str - Axes of the data, choosen amongst SCZYX. - batch_size : int - Batch size. - tta_transforms : bool, optional - Use test time augmentation, by default True. - mean : Optional[float], optional - Mean value for normalization, only used if Normalization is defined, by - default None. - std : Optional[float], optional - Standard deviation value for normalization, only used if Normalization is - defined, by default None. - transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional - List of transforms to apply to prediction patches. If None, default - transforms are applied. - read_source_func : Optional[Callable], optional - Function to read the source data, used if `data_type` is `custom`, by - default None. - extension_filter : str, optional - Filter for file extensions, used if `data_type` is `custom`, by default "". - dataloader_params : dict, optional - Pytorch dataloader parameters, by default {}. - """ - if dataloader_params is None: - dataloader_params = {} - prediction_dict = { - "data_type": data_type, - "tile_size": tile_size, - "tile_overlap": tile_overlap, - "axes": axes, - "mean": mean, - "std": std, - "tta": tta_transforms, - "batch_size": batch_size, - } - - # if transforms are passed (otherwise it will use the default ones) - if transforms is not None: - prediction_dict["transforms"] = transforms - - # validate configuration - self.prediction_config = InferenceModel(**prediction_dict) - - # sanity check on the dataloader parameters - if "batch_size" in dataloader_params: - # remove it - del dataloader_params["batch_size"] - - super().__init__( - prediction_config=self.prediction_config, - pred_data=pred_data, - read_source_func=read_source_func, - extension_filter=extension_filter, - dataloader_params=dataloader_params, - ) diff --git a/src/careamics/lightning_module.py b/src/careamics/lightning_module.py index e353265dc..6e3c4f522 100644 --- a/src/careamics/lightning_module.py +++ b/src/careamics/lightning_module.py @@ -13,8 +13,7 @@ ) from careamics.losses import loss_factory from careamics.models.model_factory import model_factory -from careamics.transforms import ImageRestorationTTA -from careamics.utils import denormalize +from careamics.transforms import Denormalize, ImageRestorationTTA from careamics.utils.torch_utils import get_optimizer, get_scheduler @@ -37,11 +36,6 @@ class CAREamicsKiln(L.LightningModule): Optimizer parameters. lr_scheduler_name : str Learning rate scheduler name. - - Parameters - ---------- - algorithm_config : Union[AlgorithmModel, dict] - Algorithm configuration. """ def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None: @@ -71,8 +65,6 @@ def __init__(self, algorithm_config: Union[AlgorithmModel, dict]) -> None: self.lr_scheduler_name = algorithm_config.lr_scheduler.name self.lr_scheduler_params = algorithm_config.lr_scheduler.parameters - # self.save_hyperparameters(algorithm_config.model_dump()) - def forward(self, x: Any) -> Any: """Forward pass. @@ -166,13 +158,16 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: output = self.model(x) # Denormalize the output - # TODO replace with Albu class - denormalized_output = denormalize( - output, - self._trainer.datamodule.predict_dataset.mean, - self._trainer.datamodule.predict_dataset.std, + denorm = Denormalize( + mean=self._trainer.datamodule.predict_dataset.mean, + std=self._trainer.datamodule.predict_dataset.std, ) - return denormalized_output, aux + denormalized_output = denorm(image=output)["image"] + + if len(aux) > 0: + return denormalized_output, aux + else: + return denormalized_output def configure_optimizers(self) -> Any: """Configure optimizers and learning rate schedulers. @@ -239,6 +234,33 @@ def __init__( lr_scheduler: Union[SupportedScheduler, str] = "ReduceLROnPlateau", lr_scheduler_parameters: Optional[dict] = None, ) -> None: + """ + Wrapper for the CAREamics model, exposing all algorithm configuration arguments. + + Parameters + ---------- + algorithm : Union[SupportedAlgorithm, str] + Algorithm to use for training (see SupportedAlgorithm). + loss : Union[SupportedLoss, str] + Loss function to use for training (see SupportedLoss). + architecture : Union[SupportedArchitecture, str] + Model architecture to use for training (see SupportedArchitecture). + model_parameters : dict, optional + Model parameters to use for training, by default {}. Model parameters are + defined in the relevant `torch.nn.Module` class, or Pyddantic model (see + `careamics.config.architectures`). + optimizer : Union[SupportedOptimizer, str], optional + Optimizer to use for training, by default "Adam" (see SupportedOptimizer). + optimizer_parameters : dict, optional + Optimizer parameters to use for training, as defined in `torch.optim`, by + default {}. + lr_scheduler : Union[SupportedScheduler, str], optional + Learning rate scheduler to use for training, by default "ReduceLROnPlateau" + (see SupportedScheduler). + lr_scheduler_parameters : dict, optional + Learning rate scheduler parameters to use for training, as defined in + `torch.optim`, by default {}. + """ # create a AlgorithmModel compatible dictionary if lr_scheduler_parameters is None: lr_scheduler_parameters = {} @@ -263,7 +285,7 @@ def __init__( # add model parameters to algorithm configuration algorithm_configuration["model"] = model_configuration - # self.save_hyperparameters({**model_configuration, **algorithm_configuration}) + # call the parent init using an AlgorithmModel instance super().__init__(AlgorithmModel(**algorithm_configuration)) diff --git a/src/careamics/lightning_prediction_datamodule.py b/src/careamics/lightning_prediction_datamodule.py new file mode 100644 index 000000000..654d44fdb --- /dev/null +++ b/src/careamics/lightning_prediction_datamodule.py @@ -0,0 +1,390 @@ +from pathlib import Path +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import numpy as np +import pytorch_lightning as L +from albumentations import Compose +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate + +from careamics.config import InferenceModel +from careamics.config.support import SupportedData +from careamics.config.tile_information import TileInformation +from careamics.dataset.dataset_utils import ( + get_read_func, + list_files, +) +from careamics.dataset.in_memory_dataset import ( + InMemoryPredictionDataset, +) +from careamics.dataset.iterable_dataset import ( + IterablePredictionDataset, +) +from careamics.utils import get_logger + +PredictDatasetType = Union[InMemoryPredictionDataset, IterablePredictionDataset] + +logger = get_logger(__name__) + + +def _collate_tiles(batch: List[Tuple[np.ndarray, TileInformation]]) -> Any: + """ + Collate tiles received from CAREamics prediction dataloader. + + CAREamics prediction dataloader returns tuples of arrays and TileInformation. In + case of non-tiled data, this function will return the arrays. In case of tiled data, + it will return the arrays, the last tile flag, the overlap crop coordinates and the + stitch coordinates. + + Parameters + ---------- + batch : Tuple[Tuple[np.ndarray, TileInformation], ...] + Batch of tiles. + + Returns + ------- + Any + Collated batch. + """ + first_tile_info: TileInformation = batch[0][1] + # if not tiled, then return arrays + if not first_tile_info.tiled: + arrays, _ = zip(*batch) + + return default_collate(arrays) + # else we explicit the last_tile flag and coordinates + else: + new_batch = [ + (tile, t.last_tile, t.array_shape, t.overlap_crop_coords, t.stitch_coords) + for tile, t in batch + ] + + return default_collate(new_batch) + + +class CAREamicsClay(L.LightningDataModule): + """ + LightningDataModule for prediction dataset. + + The data module can be used with Path, str or numpy arrays. The data can be either + a folder containing images or a single file. + + To read custom data types, you can set `data_type` to `custom` in `data_config` + and provide a function that returns a numpy array from a path as + `read_source_func` parameter. The function will receive a Path object and + an axies string as arguments, the axes being derived from the `data_config`. + + You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.czi") to filter the files extension using `extension_filter`. + + Parameters + ---------- + prediction_config : InferenceModel + Pydantic model for CAREamics prediction configuration. + pred_data : Union[Path, str, np.ndarray] + Prediction data, can be a path to a folder, a file or a numpy array. + read_source_func : Optional[Callable], optional + Function to read custom types, by default None. + extension_filter : str, optional + Filter to filter file extensions for custom types, by default "". + dataloader_params : dict, optional + Dataloader parameters, by default {}. + """ + + def __init__( + self, + prediction_config: InferenceModel, + pred_data: Union[Path, str, np.ndarray], + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + dataloader_params: Optional[dict] = None, + ) -> None: + """ + Constructor. + + The data module can be used with Path, str or numpy arrays. The data can be + either a folder containing images or a single file. + + To read custom data types, you can set `data_type` to `custom` in `data_config` + and provide a function that returns a numpy array from a path as + `read_source_func` parameter. The function will receive a Path object and + an axies string as arguments, the axes being derived from the `data_config`. + + You can also provide a `fnmatch` and `Path.rglob` compatible expression (e.g. + "*.czi") to filter the files extension using `extension_filter`. + + Parameters + ---------- + prediction_config : InferenceModel + Pydantic model for CAREamics prediction configuration. + pred_data : Union[Path, str, np.ndarray] + Prediction data, can be a path to a folder, a file or a numpy array. + read_source_func : Optional[Callable], optional + Function to read custom types, by default None. + extension_filter : str, optional + Filter to filter file extensions for custom types, by default "". + dataloader_params : dict, optional + Dataloader parameters, by default {}. + + Raises + ------ + ValueError + If the data type is `custom` and no `read_source_func` is provided. + ValueError + If the data type is `array` and the input is not a numpy array. + ValueError + If the data type is `tiff` and the input is neither a Path nor a str. + """ + if dataloader_params is None: + dataloader_params = {} + if dataloader_params is None: + dataloader_params = {} + super().__init__() + + # check that a read source function is provided for custom types + if ( + prediction_config.data_type == SupportedData.CUSTOM + and read_source_func is None + ): + raise ValueError( + f"Data type {SupportedData.CUSTOM} is not allowed without " + f"specifying a `read_source_func`." + ) + + # and that arrays are passed, if array type specified + elif prediction_config.data_type == SupportedData.ARRAY and not isinstance( + pred_data, np.ndarray + ): + raise ValueError( + f"Expected array input (see configuration.data.data_type), but got " + f"{type(pred_data)} instead." + ) + + # and that Path or str are passed, if tiff file type specified + elif prediction_config.data_type == SupportedData.TIFF and not ( + isinstance(pred_data, Path) or isinstance(pred_data, str) + ): + raise ValueError( + f"Expected Path or str input (see configuration.data.data_type), " + f"but got {type(pred_data)} instead." + ) + + # configuration data + self.prediction_config = prediction_config + self.data_type = prediction_config.data_type + self.batch_size = prediction_config.batch_size + self.dataloader_params = dataloader_params + + self.pred_data = pred_data + self.tile_size = prediction_config.tile_size + self.tile_overlap = prediction_config.tile_overlap + + # read source function + if prediction_config.data_type == SupportedData.CUSTOM: + # mypy check + assert read_source_func is not None + + self.read_source_func: Callable = read_source_func + elif prediction_config.data_type != SupportedData.ARRAY: + self.read_source_func = get_read_func(prediction_config.data_type) + + self.extension_filter = extension_filter + + def prepare_data(self) -> None: + """Hook used to prepare the data before calling `setup`.""" + # if the data is a Path or a str + if not isinstance(self.pred_data, np.ndarray): + self.pred_files = list_files( + self.pred_data, self.data_type, self.extension_filter + ) + + def setup(self, stage: Optional[str] = None) -> None: + """ + Hook called at the beginning of predict. + + Parameters + ---------- + stage : Optional[str], optional + Stage, by default None. + """ + # if numpy array + if self.data_type == SupportedData.ARRAY: + # prediction dataset + self.predict_dataset: PredictDatasetType = InMemoryPredictionDataset( + prediction_config=self.prediction_config, + inputs=self.pred_data, + ) + else: + self.predict_dataset = IterablePredictionDataset( + prediction_config=self.prediction_config, + src_files=self.pred_files, + read_source_func=self.read_source_func, + ) + + def predict_dataloader(self) -> DataLoader: + """ + Create a dataloader for prediction. + + Returns + ------- + DataLoader + Prediction dataloader. + """ + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + collate_fn=_collate_tiles, + **self.dataloader_params, + ) # TODO check workers are used + + +class CAREamicsPredictDataModule(CAREamicsClay): + """ + LightningDataModule wrapper of an inference dataset. + + Since the lightning datamodule has no access to the model, make sure that the + parameters passed to the datamodule are consistent with the model's requirements + and are coherent. + + The data module can be used with Path, str or numpy arrays. To use array data, set + `data_type` to `array` and pass a numpy array to `train_data`. + + The default transformations applied to the images are defined in + `careamics.config.inference_model`. To use different transformations, pass a list + of transforms or an albumentation `Compose` as `transforms` parameter. See examples + for more details. + + The `mean` and `std` parameters are only used if Normalization is defined either + in the default transformations or in the `transforms` parameter, but not with + a `Compose` object. If you pass a `Normalization` transform in a list as + `transforms`, then the mean and std parameters will be overwritten by those passed + to this method. + + By default, CAREamics only supports types defined in + `careamics.config.support.SupportedData`. To read custom data types, you can set + `data_type` to `custom` and provide a function that returns a numpy array from a + path. Additionally, pass a `fnmatch` and `Path.rglob` compatible expression + (e.g. "*.jpeg") to filter the files extension using `extension_filter`. + + In `dataloader_params`, you can pass any parameter accepted by PyTorch + dataloaders, except for `batch_size`, which is set by the `batch_size` + parameter. + + Parameters + ---------- + pred_data : Union[str, Path, np.ndarray] + Prediction data. + data_type : Union[Literal["array", "tiff", "custom"], SupportedData] + Data type, see `SupportedData` for available options. + mean : float + Mean value for normalization, only used if Normalization is defined in the + transforms. + std : float + Standard deviation value for normalization, only used if Normalization is + defined in the transform. + tile_size : Tuple[int, ...] + Tile size, 2D or 3D tile size. + tile_overlap : Tuple[int, ...] + Tile overlap, 2D or 3D tile overlap. + axes : str + Axes of the data, choosen amongst SCZYX. + batch_size : int + Batch size. + tta_transforms : bool, optional + Use test time augmentation, by default True. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to prediction patches. If None, default + transforms are applied. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. + extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + """ + + def __init__( + self, + pred_data: Union[str, Path, np.ndarray], + data_type: Union[Literal["array", "tiff", "custom"], SupportedData], + mean: float, + std: float, + tile_size: Optional[Tuple[int, ...]] = None, + tile_overlap: Optional[Tuple[int, ...]] = None, + axes: str = "YX", + batch_size: int = 1, + tta_transforms: bool = True, + transforms: Optional[Union[List, Compose]] = None, + read_source_func: Optional[Callable] = None, + extension_filter: str = "", + dataloader_params: Optional[dict] = None, + ) -> None: + """ + Constructor. + + Parameters + ---------- + pred_data : Union[str, Path, np.ndarray] + Prediction data. + data_type : Union[Literal["array", "tiff", "custom"], SupportedData] + Data type, see `SupportedData` for available options. + tile_size : List[int] + Tile size, 2D or 3D tile size. + tile_overlap : List[int] + Tile overlap, 2D or 3D tile overlap. + axes : str + Axes of the data, choosen amongst SCZYX. + batch_size : int + Batch size. + tta_transforms : bool, optional + Use test time augmentation, by default True. + mean : Optional[float], optional + Mean value for normalization, only used if Normalization is defined, by + default None. + std : Optional[float], optional + Standard deviation value for normalization, only used if Normalization is + defined, by default None. + transforms : Optional[Union[List[TRANSFORMS_UNION], Compose]], optional + List of transforms to apply to prediction patches. If None, default + transforms are applied. + read_source_func : Optional[Callable], optional + Function to read the source data, used if `data_type` is `custom`, by + default None. + extension_filter : str, optional + Filter for file extensions, used if `data_type` is `custom`, by default "". + dataloader_params : dict, optional + Pytorch dataloader parameters, by default {}. + """ + if dataloader_params is None: + dataloader_params = {} + prediction_dict = { + "data_type": data_type, + "tile_size": tile_size, + "tile_overlap": tile_overlap, + "axes": axes, + "mean": mean, + "std": std, + "tta": tta_transforms, + "batch_size": batch_size, + } + + # if transforms are passed (otherwise it will use the default ones) + if transforms is not None: + prediction_dict["transforms"] = transforms + + # validate configuration + self.prediction_config = InferenceModel(**prediction_dict) + + # sanity check on the dataloader parameters + if "batch_size" in dataloader_params: + # remove it + del dataloader_params["batch_size"] + + super().__init__( + prediction_config=self.prediction_config, + pred_data=pred_data, + read_source_func=read_source_func, + extension_filter=extension_filter, + dataloader_params=dataloader_params, + ) diff --git a/src/careamics/lightning_prediction.py b/src/careamics/lightning_prediction_loop.py similarity index 52% rename from src/careamics/lightning_prediction.py rename to src/careamics/lightning_prediction_loop.py index 7040d7bf5..c7e00fd2e 100644 --- a/src/careamics/lightning_prediction.py +++ b/src/careamics/lightning_prediction_loop.py @@ -10,33 +10,45 @@ class CAREamicsPredictionLoop(L.loops._PredictionLoop): - """Predict loop for tiles-based prediction.""" + """ + CAREamics prediction loop. - # def _predict_step(self, batch, batch_idx, dataloader_idx, dataloader_iter): - # self.model.predict_step(batch, batch_idx) + This class extends the PyTorch Lightning `_PredictionLoop` class to include + the stitching of the tiles into a single prediction result. + """ def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: - """Calls ``on_predict_epoch_end`` hook. + """ + Calls `on_predict_epoch_end` hook. + + Adapted from the parent method. Returns ------- the results for all dataloaders - """ trainer = self.trainer call._call_callback_hooks(trainer, "on_predict_epoch_end") call._call_lightning_module_hook(trainer, "on_predict_epoch_end") if self.return_predictions: + ######################################################## + ################ CAREamics specific code ############### if len(self.predicted_array) == 1: - return self.predicted_array[0] + # TODO does this make sense to here? (force numpy array) + return self.predicted_array[0].numpy() else: - return self.predicted_array # TODO revisit logic + # TODO revisit logic + return [element.numpy() for element in self.predicted_array] + ######################################################## return None @_no_grad_context def run(self) -> Optional[_PREDICT_OUTPUT]: - """Runs the prediction loop. + """ + Runs the prediction loop. + + Adapted from the parent method in order to stitch the predictions. Returns ------- @@ -72,17 +84,29 @@ def run(self) -> Optional[_PREDICT_OUTPUT]: # run step hooks self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter) - # Stitching tiles together - last_tile, *data = self.predictions[batch_idx][1] - self.tiles.append(self.predictions[batch_idx][0]) - self.stitching_data.append(data) - if any(last_tile): - predicted_batches = stitch_prediction( - self.tiles, self.stitching_data - ) - self.predicted_array.append(predicted_batches) - self.tiles.clear() - self.stitching_data.clear() + ######################################################## + ################ CAREamics specific code ############### + is_tiled = len(self.predictions[batch_idx]) == 2 + if is_tiled: + # extract the last tile flag and the coordinates (crop and stitch) + last_tile, *stitch_data = self.predictions[batch_idx][1] + + # append the tile and the coordinates to the lists + self.tiles.append(self.predictions[batch_idx][0]) + self.stitching_data.append(stitch_data) + + # if last tile, stitch the tiles and add array to the prediction + if any(last_tile): + predicted_batches = stitch_prediction( + self.tiles, self.stitching_data + ) + self.predicted_array.append(predicted_batches) + self.tiles.clear() + self.stitching_data.clear() + else: + # simply add the prediction to the list + self.predicted_array.append(self.predictions[batch_idx]) + ######################################################## except StopIteration: break finally: diff --git a/src/careamics/model_io/__init__.py b/src/careamics/model_io/__init__.py index 32af8bbc7..0e99771f4 100644 --- a/src/careamics/model_io/__init__.py +++ b/src/careamics/model_io/__init__.py @@ -1,7 +1,8 @@ """Model I/O utilities.""" -__all__ = ["load_pretrained"] +__all__ = ["load_pretrained", "export_to_bmz"] +from .bmz_io import export_to_bmz from .model_io_utils import load_pretrained diff --git a/src/careamics/model_io/bioimage/__init__.py b/src/careamics/model_io/bioimage/__init__.py index f469548f1..f312bc7eb 100644 --- a/src/careamics/model_io/bioimage/__init__.py +++ b/src/careamics/model_io/bioimage/__init__.py @@ -1 +1,11 @@ """Bioimage Model Zoo format functions.""" + +__all__ = [ + "create_model_description", + "extract_model_path", + "get_unzip_path", + "create_env_text", +] + +from .bioimage_utils import create_env_text, get_unzip_path +from .model_description import create_model_description, extract_model_path diff --git a/src/careamics/model_io/bioimage/readme_factory.py b/src/careamics/model_io/bioimage/_readme_factory.py similarity index 74% rename from src/careamics/model_io/bioimage/readme_factory.py rename to src/careamics/model_io/bioimage/_readme_factory.py index 146a10ffb..e823f3781 100644 --- a/src/careamics/model_io/bioimage/readme_factory.py +++ b/src/careamics/model_io/bioimage/_readme_factory.py @@ -1,17 +1,12 @@ +"""Functions used to create a README.md file for BMZ export.""" from pathlib import Path from typing import Optional -import pkg_resources -import torch import yaml -from bioimageio.spec.model.v0_5 import Version from careamics.config import Configuration -from careamics.config.support import SupportedAlgorithm from careamics.utils import cwd, get_careamics_home -pytorch_version = Version(torch.__version__) - def _yaml_block(yaml_str: str) -> str: """Return a markdown code block with a yaml string. @@ -19,48 +14,46 @@ def _yaml_block(yaml_str: str) -> str: Parameters ---------- yaml_str : str - YAML string + YAML string. Returns ------- str - Markdown code block with the YAML string + Markdown code block with the YAML string. """ return f"```yaml\n{yaml_str}\n```" def readme_factory( config: Configuration, + careamics_version: str, data_description: Optional[str] = None, - custom_description: Optional[str] = None, ) -> Path: """Create a README file for the model. `data_description` can be used to add more information about the content of the data the model was trained on. - `custom_description` can be used to add a custom description of the algorithm, only - used when the algorithm is set to `custom` in the configuration. - Parameters ---------- config : Configuration - CAREamics configuration + CAREamics configuration. + careamics_version : str + CAREamics version. data_description : Optional[str], optional - Description of the data, by default None - custom_description : Optional[str], optional - Description of custom algorithm, by default None + Description of the data, by default None. Returns ------- Path - Path to the README file + Path to the README file. """ algorithm = config.algorithm_config training = config.training_config data = config.data_config # create file + # TODO use tempfile as in the bmz_io module with cwd(get_careamics_home()): readme = Path("README.md") readme.touch() @@ -73,17 +66,10 @@ def readme_factory( # algorithm description description.append("Algorithm description:\n\n") - if ( - algorithm.algorithm == SupportedAlgorithm.CUSTOM - and custom_description is not None - ): - description.append(custom_description) - else: - description.append(config.get_algorithm_description()) + description.append(config.get_algorithm_description()) description.append("\n\n") # algorithm details - careamics_version = pkg_resources.get_distribution("careamics").version description.append( f"{algorithm_flavour} was trained using CAREamics (version " f"{careamics_version}) with the following algorithm " diff --git a/src/careamics/model_io/bioimage/bioimage_utils.py b/src/careamics/model_io/bioimage/bioimage_utils.py new file mode 100644 index 000000000..1ce28bfcd --- /dev/null +++ b/src/careamics/model_io/bioimage/bioimage_utils.py @@ -0,0 +1,48 @@ +"""Bioimage.io utils.""" +from pathlib import Path +from typing import Union + + +def get_unzip_path(zip_path: Union[Path, str]) -> Path: + """Generate unzipped folder path from the bioimage.io model path. + + Parameters + ---------- + zip_path : Path + Path to the bioimage.io model. + + Returns + ------- + Path + Path to the unzipped folder. + """ + zip_path = Path(zip_path) + + return zip_path.parent / (str(zip_path.name) + ".unzip") + + +def create_env_text(pytorch_version: str) -> str: + """Create environment text for the bioimage model. + + Parameters + ---------- + pytorch_version : str + Pytorch version. + + Returns + ------- + str + Environment text. + """ + env = ( + f"name: careamics\n" + f"dependencies:\n" + f" - python=3.8\n" + f" - pytorch={pytorch_version}\n" + f" - torchvision={pytorch_version}\n" + f" - pip\n" + f" - pip:\n" + f" - git+https://github.com/CAREamics/careamics.git@dl4mia\n" + ) + # TODO from pip with package version + return env diff --git a/src/careamics/model_io/bioimage/io.py b/src/careamics/model_io/bioimage/io.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index dbb7c9a25..1901b9550 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -1,56 +1,70 @@ +"""Module use to build BMZ model description.""" from pathlib import Path from typing import List, Optional, Tuple, Union +import numpy as np from bioimageio.spec.model.v0_5 import ( + ArchitectureFromLibraryDescr, Author, AxisBase, AxisId, BatchAxis, ChannelAxis, + EnvironmentFileDescr, FileDescr, + FixedZeroMeanUnitVarianceDescr, + FixedZeroMeanUnitVarianceKwargs, Identifier, InputTensorDescr, ModelDescr, OutputTensorDescr, - ParameterizedSize, + PytorchStateDictWeightsDescr, SpaceInputAxis, SpaceOutputAxis, TensorId, + Version, + WeightsDescr, ) -from careamics import CAREamist -from careamics.config import DataModel, save_configuration -from careamics.utils import cwd, get_careamics_home +from careamics.config import Configuration, DataModel -from .readme_factory import readme_factory +from ._readme_factory import readme_factory def _create_axes( + array: np.ndarray, data_config: DataModel, - is_input: bool = True, channel_names: Optional[List[str]] = None, + is_input: bool = True, ) -> List[AxisBase]: """Create axes description. + Array shape is expected to be SC(Z)YX. + Parameters ---------- - config : DataModel - CAREamics data configuration - is_input : bool, optional - Whether the axes are input axes, by default True + array : np.ndarray + Array. + data_config : DataModel + CAREamics data configuration. channel_names : Optional[List[str]], optional - Channel names, by default None + Channel names, by default None. + is_input : bool, optional + Whether the axes are input axes, by default True. Returns ------- List[AxisBase] - List of axes description + List of axes description. Raises ------ ValueError - If channel names are not provided when channel axis is present + If channel names are not provided when channel axis is present. """ + # axes have to be SC(Z)YX + spatial_axes = data_config.axes.replace("S", "").replace("C", "") + # batch is always present axes_model = [BatchAxis()] @@ -65,34 +79,31 @@ def _create_axes( f"{data_config.axes}." ) else: - axes_model.append(ChannelAxis(channel_names=[Identifier("raw")])) + # singleton channel + axes_model.append(ChannelAxis(channel_names=[Identifier("channel")])) # spatial axes - for axes in data_config.axes: + for ind, axes in enumerate(spatial_axes): if axes in ["X", "Y", "Z"]: if is_input: axes_model.append( - SpaceInputAxis( - id=AxisId(axes.lower()), - size=ParameterizedSize( - min=16, step=8 - ), # TODO check the min/step - ) + SpaceInputAxis(id=AxisId(axes.lower()), size=array.shape[2 + ind]) ) else: axes_model.append( - SpaceOutputAxis( - id=AxisId(axes.lower()), size=ParameterizedSize(min=16, step=8) - ) + SpaceOutputAxis(id=AxisId(axes.lower()), size=array.shape[2 + ind]) ) return axes_model def _create_inputs_ouputs( + input_array: np.ndarray, + output_array: np.ndarray, data_config: DataModel, input_path: Union[Path, str], output_path: Union[Path, str], + channel_names: Optional[List[str]] = None, ) -> Tuple[InputTensorDescr, OutputTensorDescr]: """Create input and output tensor description. @@ -100,47 +111,89 @@ def _create_inputs_ouputs( Parameters ---------- + input_array : np.ndarray + Input array. + output_array : np.ndarray + Output array. data_config : DataModel - CAREamics data configuration + CAREamics data configuration. input_path : Union[Path, str] - Path to input .npy file + Path to input .npy file. output_path : Union[Path, str] - Path to output .npy file + Path to output .npy file. + channel_names : Optional[List[str]], optional + Channel names, by default None. Returns ------- Tuple[InputTensorDescr, OutputTensorDescr] - Input and output tensor descriptions + Input and output tensor descriptions. """ - input_axes = _create_axes(data_config) - output_axes = _create_axes(data_config) + input_axes = _create_axes(input_array, data_config, channel_names) + output_axes = _create_axes(output_array, data_config, channel_names, False) + + # mean and std + assert data_config.mean is not None, "Mean cannot be None." + assert data_config.std is not None, "Std cannot be None." + mean = data_config.mean + std = data_config.std + + # and the mean and std required to invert the normalization + # CAREamics denormalization: x = y * (std + eps) + mean + # BMZ normalization : x = (y - mean') / (std' + eps) + # to apply the BMZ normalization as a denormalization step, we need: + eps = 1e-6 + inv_mean = -mean / (std + eps) + inv_std = 1 / (std + eps) - eps + + # create input/output descriptions input_descr = InputTensorDescr( - id=TensorId("raw"), axes=input_axes, test_tensor=FileDescr(source=input_path) + id=TensorId("input"), + axes=input_axes, + test_tensor=FileDescr(source=input_path), + preprocessing=[ + FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs(mean=mean, std=std) + ) + ], ) output_descr = OutputTensorDescr( - id=TensorId("pred"), axes=output_axes, test_tensor=FileDescr(source=output_path) + id=TensorId("prediction"), + axes=output_axes, + test_tensor=FileDescr(source=output_path), + postprocessing=[ + FixedZeroMeanUnitVarianceDescr( + kwargs=FixedZeroMeanUnitVarianceKwargs( # invert normalization + mean=inv_mean, std=inv_std + ) + ) + ], ) return input_descr, output_descr def create_model_description( - careamist: CAREamist, + config: Configuration, name: str, general_description: str, authors: List[Author], inputs: Union[Path, str], outputs: Union[Path, str], - weights: Union[Path, str], + weights_path: Union[Path, str], + torch_version: str, + careamics_version: str, + config_path: Union[Path, str], + env_path: Union[Path, str], + channel_names: Optional[List[str]] = None, data_description: Optional[str] = None, - custom_description: Optional[str] = None, ) -> ModelDescr: """Create model description. Parameters ---------- - careamist : CAREamist - CAREamist instance. + config : Configuration + CAREamics configuration. name : str Name fo the model. general_description : str @@ -151,34 +204,60 @@ def create_model_description( Path to input .npy file. outputs : Union[Path, str] Path to output .npy file. - weights : Union[Path, str] + weights_path : Union[Path, str] Path to model weights. + torch_version : str + Pytorch version. + careamics_version : str + CAREamics version. + config_path : Union[Path, str] + Path to model configuration. + env_path : Union[Path, str] + Path to environment file. + channel_names : Optional[List[str]], optional + Channel names, by default None. data_description : Optional[str], optional - Description of the data, by default None - custom_description : Optional[str], optional - Description of the custom algorithm, by default None + Description of the data, by default None. Returns ------- ModelDescr Model description. """ + # documentation doc = readme_factory( - careamist.cfg, + config, + careamics_version=careamics_version, data_description=data_description, - custom_description=custom_description, ) + # inputs, outputs input_descr, output_descr = _create_inputs_ouputs( - careamist.cfg.data_config, + input_array=np.load(inputs), + output_array=np.load(outputs), + data_config=config.data_config, input_path=inputs, output_path=outputs, + channel_names=channel_names, ) - # export configuration - with cwd(get_careamics_home()): - config_path = save_configuration(careamist.cfg, get_careamics_home()) + # weights description + architecture_descr = ArchitectureFromLibraryDescr( + import_from="careamics.models", + callable=f"{config.algorithm_config.model.architecture}", + kwargs=config.algorithm_config.model.model_dump(), + ) + + weights_descr = WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=weights_path, + architecture=architecture_descr, + pytorch_version=Version(torch_version), + dependencies=EnvironmentFileDescr(source=env_path), + ), + ) + # overall model description model = ModelDescr( name=name, authors=authors, @@ -186,15 +265,54 @@ def create_model_description( documentation=doc, inputs=[input_descr], outputs=[output_descr], - tags=careamist.cfg.get_algorithm_keywords(), + tags=config.get_algorithm_keywords(), links=[ "https://github.com/CAREamics/careamics", "https://careamics.github.io/latest/", ], license="BSD-3-Clause", version="0.1.0", - weights=weights, - attachments=[config_path], + weights=weights_descr, + attachments=[FileDescr(source=config_path)], + cite=config.get_algorithm_citations(), + config={ # conversion from float32 to float64 creates small differences... + "bioimageio": { + "test_kwargs": { + "pytorch_state_dict": { + "decimals": 2, # ...so we relax the constraints on the decimals + } + } + } + }, ) return model + + +def extract_model_path(model_desc: ModelDescr) -> Tuple[Path, Path]: + """Return the relative path to the weights and configuration files. + + Parameters + ---------- + model_desc : ModelDescr + Model description. + + Returns + ------- + Tuple[Path, Path] + Weights and configuration paths. + """ + weights_path = model_desc.weights.pytorch_state_dict.source.path + + if len(model_desc.attachments) == 1: + config_path = model_desc.attachments[0].source.path + else: + for file in model_desc.attachments: + if file.source.path.suffix == ".yml": + config_path = file.source.path + break + + if config_path is None: + raise ValueError("Configuration file not found.") + + return weights_path, config_path diff --git a/src/careamics/model_io/bmz_io.py b/src/careamics/model_io/bmz_io.py new file mode 100644 index 000000000..06e5d9440 --- /dev/null +++ b/src/careamics/model_io/bmz_io.py @@ -0,0 +1,231 @@ +"""Function to export to the BioImage Model Zoo format.""" +import tempfile +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import numpy as np +import pkg_resources +from bioimageio.core import load_description, test_model +from bioimageio.spec import ValidationSummary, save_bioimageio_package +from torch import __version__, load, save + +from careamics.config import Configuration, load_configuration, save_configuration +from careamics.config.support import SupportedArchitecture +from careamics.lightning_module import CAREamicsKiln + +from .bioimage import ( + create_env_text, + create_model_description, + extract_model_path, + get_unzip_path, +) + + +def _export_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> Path: + """ + Export the model state dictionary to a file. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. + path : Union[Path, str] + Path to the file where to save the model state dictionary. + + Returns + ------- + Path + Path to the saved model state dictionary. + """ + path = Path(path) + + # make sure it has the correct suffix + if path.suffix not in ".pth": + path = path.with_suffix(".pth") + + # save model state dictionary + # we save through the torch model itself to avoid the initial "model." in the + # layers naming, which is incompatible with the way the BMZ load torch state dicts + save(model.model.state_dict(), path) + + return path + + +def _load_state_dict(model: CAREamicsKiln, path: Union[Path, str]) -> None: + """ + Load a model from a state dictionary. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to be updated with the weights. + path : Union[Path, str] + Path to the model state dictionary. + """ + path = Path(path) + + # load model state dictionary + # same as in _export_state_dict, we load through the torch model to be compatible + # witht bioimageio.core expectations for a torch state dict + state_dict = load(path) + model.model.load_state_dict(state_dict) + + +# TODO break down in subfunctions +def export_to_bmz( + model: CAREamicsKiln, + config: Configuration, + path: Union[Path, str], + name: str, + general_description: str, + authors: List[dict], + input_array: np.ndarray, + output_array: np.ndarray, + channel_names: Optional[List[str]] = None, + data_description: Optional[str] = None, +) -> None: + """Export the model to BioImage Model Zoo format. + + Arrays are expected to be SC(Z)YX with singleton dimensions allowed for S and C. + + Parameters + ---------- + model : CAREamicsKiln + CAREamics model to export. + config : Configuration + Model configuration. + path : Union[Path, str] + Path to the output file. + name : str + Model name. + general_description : str + General description of the model. + authors : List[dict] + Authors of the model. + input_array : np.ndarray + Input array. + output_array : np.ndarray + Output array. + channel_names : Optional[List[str]], optional + Channel names, by default None. + data_description : Optional[str], optional + Description of the data, by default None. + + Raises + ------ + ValueError + If the model is a Custom model. + """ + path = Path(path) + + # method is not compatible with Custom models + if config.algorithm_config.model.architecture == SupportedArchitecture.CUSTOM: + raise ValueError( + "Exporting Custom models to BioImage Model Zoo format is not supported." + ) + + # make sure that input and output arrays have the same shape + assert input_array.shape == output_array.shape, ( + f"Input ({input_array.shape}) and output ({output_array.shape}) arrays " + f"have different shapes" + ) + + # make sure it has the correct suffix + if path.suffix not in ".zip": + path = path.with_suffix(".zip") + + # versions + pytorch_version = __version__ + careamics_version = pkg_resources.get_distribution("careamics").version + + # save files in temporary folder + with tempfile.TemporaryDirectory() as tmpdirname: + temp_path = Path(tmpdirname) + + # create environment file + # TODO move in bioimage module + env_path = temp_path / "environment.yml" + env_path.write_text(create_env_text(pytorch_version)) + + # export input and ouputs + inputs = temp_path / "inputs.npy" + np.save(inputs, input_array) + outputs = temp_path / "outputs.npy" + np.save(outputs, output_array) + + # export configuration + config_path = save_configuration(config, temp_path) + + # export model state dictionary + weight_path = _export_state_dict(model, temp_path / "weights.pth") + + # create model description + model_description = create_model_description( + config=config, + name=name, + general_description=general_description, + authors=authors, + inputs=inputs, + outputs=outputs, + weights_path=weight_path, + torch_version=pytorch_version, + careamics_version=careamics_version, + config_path=config_path, + env_path=env_path, + channel_names=channel_names, + data_description=data_description, + ) + + # test model description + summary: ValidationSummary = test_model(model_description) + if summary.status == "failed": + raise ValueError(f"Model description test failed: {summary}") + + # save bmz model + save_bioimageio_package(model_description, output_path=path) + + +def load_from_bmz(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: + """Load a model from a BioImage Model Zoo archive. + + Parameters + ---------- + path : Union[Path, str] + Path to the BioImage Model Zoo archive. + + Returns + ------- + Tuple[CAREamicsKiln, Configuration] + CAREamics model and configuration. + + Raises + ------ + ValueError + If the path is not a zip file. + """ + path = Path(path) + + if path.suffix != ".zip": + raise ValueError(f"Path must be a bioimage.io zip file, got {path}.") + + # load description, this creates an unzipped folder next to the archive + model_desc = load_description(path) + + # extract relative paths + weights_path, config_path = extract_model_path(model_desc) + + # create folder path and absolute paths + unzip_path = get_unzip_path(path) + weights_path = unzip_path / weights_path + config_path = unzip_path / config_path + + # load configuration + config = load_configuration(config_path) + + # create careamics lightning module + model = CAREamicsKiln(algorithm_config=config.algorithm_config) + + # load model state dictionary + _load_state_dict(model, weights_path) + + return model, config diff --git a/src/careamics/model_io/model_io_utils.py b/src/careamics/model_io/model_io_utils.py index 8b1d1fb65..720ac49e0 100644 --- a/src/careamics/model_io/model_io_utils.py +++ b/src/careamics/model_io/model_io_utils.py @@ -1,4 +1,5 @@ """Utility functions to load pretrained models.""" + from pathlib import Path from typing import Tuple, Union @@ -6,6 +7,7 @@ from careamics.config import Configuration from careamics.lightning_module import CAREamicsKiln +from careamics.model_io.bmz_io import load_from_bmz from careamics.utils import check_path_exists @@ -13,7 +15,7 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio """ Load a pretrained model from a checkpoint or a BioImage Model Zoo model. - Expected formats are .ckpt, .zip, .pth or .pt files. + Expected formats are .ckpt or .zip files. Parameters ---------- @@ -22,8 +24,8 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio Returns ------- - CAREamicsKiln - CAREamics model loaded from the checkpoint. + Tuple[CAREamicsKiln, Configuration] + Tuple of CAREamics model and its configuration. Raises ------ @@ -33,93 +35,46 @@ def load_pretrained(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuratio path = check_path_exists(path) if path.suffix == ".ckpt": - # load checkpoint - checkpoint = load(path) - - # attempt to load algorithm parameters - try: - cfg_dict = checkpoint["hyper_parameters"] - except KeyError as e: - raise ValueError( - "Invalid checkpoint file. No `hyper_parameters` found for the " - "algorithm." - ) from e - - model = _load_from_checkpoint(path) - - return model, Configuration(**cfg_dict) - - elif path.suffix == "bioimage.io.zip": - return _load_from_bmz(path) + return _load_checkpoint(path) + elif path.suffix == ".zip": + return load_from_bmz(path) else: raise ValueError( - f"Invalid model format. Expected .ckpt or bioimage.io.zip, " - f"got {path.suffix}." + f"Invalid model format. Expected .ckpt or .zip, got {path.suffix}." ) -def _load_from_checkpoint(path: Union[Path, str]) -> CAREamicsKiln: +def _load_checkpoint(path: Union[Path, str]) -> Tuple[CAREamicsKiln, Configuration]: """ - Load a model from a checkpoint. + Load a model from a checkpoint and return both model and configuration. Parameters ---------- path : Union[Path, str] Path to the checkpoint. - Returns - ------- - CAREamicsKiln - CAREamics model loaded from the checkpoint. - """ - return CAREamicsKiln.load_from_checkpoint(path) - - -def _load_from_torch_dict( - path: Union[Path, str] -) -> Tuple[CAREamicsKiln, Configuration]: - """ - Load a model from a PyTorch dictionary. - - Parameters - ---------- - path : Union[Path, str] - Path to the PyTorch dictionary. - Returns ------- Tuple[CAREamicsKiln, Configuration] - CAREamics model and Configuration loaded from the BioImage Model Zoo. - """ - raise NotImplementedError( - "Loading a model from a PyTorch dictionary is not implemented yet." - ) - - -def _load_from_bmz( - path: Union[Path, str], -) -> Tuple[CAREamicsKiln, Configuration]: - """ - Load a model from BioImage Model Zoo. - - Parameters - ---------- - path : Union[Path, str] - Path to the BioImage Model Zoo model. - - Returns - ------- - Tuple[CAREamicsKiln, Configuration] - CAREamics model and Configuration loaded from the BioImage Model Zoo. + Tuple of CAREamics model and its configuration. Raises ------ - NotImplementedError - If the method is not implemented yet. + ValueError + If the checkpoint file does not contain hyper parameters (configuration). """ - raise NotImplementedError( - "Loading a model from BioImage Model Zoo is not implemented yet." - ) + # load checkpoint + checkpoint: dict = load(path) + + # attempt to load configuration + try: + cfg_dict = checkpoint["hyper_parameters"] + except KeyError as e: + raise ValueError( + f"Invalid checkpoint file. No `hyper_parameters` found in the " + f"checkpoint: {checkpoint.keys()}" + ) from e + + model = CAREamicsKiln.load_from_checkpoint(path) - # load BMZ archive - # extract model and call _load_from_torch_dict + return model, Configuration(**cfg_dict) diff --git a/src/careamics/models/model_factory.py b/src/careamics/models/model_factory.py index 3ca69a0b3..5b93622b3 100644 --- a/src/careamics/models/model_factory.py +++ b/src/careamics/models/model_factory.py @@ -44,7 +44,7 @@ def model_factory( assert isinstance(model_configuration, CustomModel) model = get_custom_model(model_configuration.name) - return model(**model_configuration.parameters.model_dump()) + return model(**model_configuration.model_dump()) else: raise NotImplementedError( f"Model {model_configuration.architecture} is not implemented or unknown." diff --git a/src/careamics/prediction/__init__.py b/src/careamics/prediction/__init__.py index a11cfceb7..852e65de1 100644 --- a/src/careamics/prediction/__init__.py +++ b/src/careamics/prediction/__init__.py @@ -2,8 +2,6 @@ __all__ = [ "stitch_prediction", - "tta_backward", - "tta_forward", ] -from .prediction_utils import stitch_prediction, tta_backward, tta_forward +from .stitch_prediction import stitch_prediction diff --git a/src/careamics/prediction/prediction_utils.py b/src/careamics/prediction/prediction_utils.py deleted file mode 100644 index 32d8255d8..000000000 --- a/src/careamics/prediction/prediction_utils.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -Prediction convenience functions. - -These functions are used during prediction. -""" -from typing import List, Optional - -import numpy as np -import torch - - -def stitch_prediction( - tiles: List[torch.Tensor], - stitching_data: List, - explicit_stitching: Optional[bool] = False, -) -> torch.Tensor: - """ - Stitch tiles back together to form a full image. - - Parameters - ---------- - tiles : List[torch.Tensor] - Cropped tiles and their respective stitching coordinates. - stitching_data : List - List of coordinates obtained from - `dataset.tiled_patching.extract_tiles`. - explicit_stitching : bool, optional - Whether this function is called explicitly after prediction(Lighting) or inside - the predict function. Removes the first element (last tile indicator). - - Returns - ------- - np.ndarray - Full image. - """ - # Remove first element of stitching_data if explicit_stitching - # TODO revisit, no way around this? - if explicit_stitching: - stitching_data = [d[1:] for d in stitching_data] - - # Get whole sample shape. Unique to get rid of batch dimension. Expects tensors - input_shape = [x.unique() for x in stitching_data[0][0]] #TODO refatcor unique() ? - - predicted_image = np.zeros(input_shape, dtype=np.float32) - - for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip( - tiles, stitching_data - ): - for batch_idx in range(tile_batch.shape[0]): - # Compute coordinates for cropping predicted tile - slices = tuple( - [ - slice(c[0][batch_idx], c[1][batch_idx]) - for c in overlap_crop_coords_batch - ] - ) - - # Crop predited tile according to overlap coordinates - cropped_tile = tile_batch[batch_idx].squeeze()[slices] - - # Insert cropped tile into predicted image using stitch coordinates - predicted_image[ - ( - ..., - *[ - slice(c[0][batch_idx], c[1][batch_idx]) - for c in stitch_coords_batch - ], - ) - ] = cropped_tile.to(torch.float32) - - return predicted_image - - -def tta_forward(x: np.ndarray) -> List: - """ - Augment 8-fold an array. - - The augmentation is performed using all 90 deg rotations and their flipped version, - as well as the original image flipped. - - Parameters - ---------- - x : torch.tensor - Data to augment. - - Returns - ------- - List - Stack of augmented images. - """ - x_aug = [ - x, - torch.rot90(x, 1, dims=(2, 3)), - torch.rot90(x, 2, dims=(2, 3)), - torch.rot90(x, 3, dims=(2, 3)), - ] - x_aug_flip = x_aug.copy() - for x_ in x_aug: - x_aug_flip.append(torch.flip(x_, dims=(1, 3))) - return x_aug_flip - - -def tta_backward(x_aug: List) -> np.ndarray: - """ - Invert `tta_forward` and average the 8 images. - - Parameters - ---------- - x_aug : List - Stack of 8-fold augmented images. - - Returns - ------- - np.ndarray - Average of de-augmented x_aug. - """ - x_deaug = [ - x_aug[0], - np.rot90(x_aug[1], -1), - np.rot90(x_aug[2], -2), - np.rot90(x_aug[3], -3), - np.fliplr(x_aug[4]), - np.rot90(np.fliplr(x_aug[5]), -1), - np.rot90(np.fliplr(x_aug[6]), -2), - np.rot90(np.fliplr(x_aug[7]), -3), - ] - return np.mean(x_deaug, 0) diff --git a/src/careamics/prediction/stitch_prediction.py b/src/careamics/prediction/stitch_prediction.py new file mode 100644 index 000000000..5e0ee7e11 --- /dev/null +++ b/src/careamics/prediction/stitch_prediction.py @@ -0,0 +1,73 @@ +""" +Prediction convenience functions. + +These functions are used during prediction. +""" +from typing import List + +import numpy as np +import torch + + +def stitch_prediction( + tiles: List[torch.Tensor], + stitching_data: List[List[torch.Tensor]], +) -> torch.Tensor: + """ + Stitch tiles back together to form a full image. + + Parameters + ---------- + tiles : List[torch.Tensor] + Cropped tiles and their respective stitching coordinates. + stitching_coords : List + List of information and coordinates obtained from + `dataset.tiled_patching.extract_tiles`. + + Returns + ------- + np.ndarray + Full image. + """ + # retrieve whole array size, there is two cases to consider: + # 1. the tiles are stored in a list + # 2. the tiles are stored in a list with batches along the first dim + if tiles[0].shape[0] > 1: + input_shape = np.array( + [el.numpy() for el in stitching_data[0][0][0]], dtype=int + ).squeeze() + else: + input_shape = np.array( + [el.numpy() for el in stitching_data[0][0]], dtype=int + ).squeeze() + + # TODO should use torch.zeros instead of np.zeros + predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32)) + + for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip( + tiles, stitching_data + ): + for batch_idx in range(tile_batch.shape[0]): + # Compute coordinates for cropping predicted tile + slices = tuple( + [ + slice(c[0][batch_idx], c[1][batch_idx]) + for c in overlap_crop_coords_batch + ] + ) + + # Crop predited tile according to overlap coordinates + cropped_tile = tile_batch[batch_idx].squeeze()[slices] + + # Insert cropped tile into predicted image using stitch coordinates + predicted_image[ + ( + ..., + *[ + slice(c[0][batch_idx], c[1][batch_idx]) + for c in stitch_coords_batch + ], + ) + ] = cropped_tile.to(torch.float32) + + return predicted_image diff --git a/src/careamics/transforms/__init__.py b/src/careamics/transforms/__init__.py index 3f7e37a65..dd4c2604c 100644 --- a/src/careamics/transforms/__init__.py +++ b/src/careamics/transforms/__init__.py @@ -6,36 +6,31 @@ "NDFlip", "XYRandomRotate90", "ImageRestorationTTA", + "Denormalize", + "Normalize", ] -from inspect import getmembers, isclass - -import albumentations as Aug from .n2v_manipulate import N2VManipulate from .nd_flip import NDFlip +from .normalize import Denormalize, Normalize from .tta import ImageRestorationTTA from .xy_random_rotate90 import XYRandomRotate90 -ALL_TRANSFORMS = dict( - getmembers(Aug, isclass) - + [ - ("N2VManipulate", N2VManipulate), - ("NDFlip", NDFlip), - ("XYRandomRotate90", XYRandomRotate90), - ] -) +ALL_TRANSFORMS = { + "Normalize": Normalize, + "N2VManipulate": N2VManipulate, + "NDFlip": NDFlip, + "XYRandomRotate90": XYRandomRotate90, +} def get_all_transforms() -> dict: """Return all the transforms accepted by CAREamics. - This includes all transforms from Albumentations (see https://albumentations.ai/), - and custom transforms implemented in CAREamics. - - Note that while any Albumentations transform can be used in CAREamics, no check are - implemented to verify the compatibility of any other transforms than the ones - officially supported (see SupportedTransforms). + Note that while CAREamics accepts any `Compose` transforms from Albumentations (see + https://albumentations.ai/), only a few transformations are explicitely supported + (see `SupportedTransform`). Returns ------- diff --git a/src/careamics/transforms/normalize.py b/src/careamics/transforms/normalize.py index 3554a16d3..4ec529b0f 100644 --- a/src/careamics/transforms/normalize.py +++ b/src/careamics/transforms/normalize.py @@ -1,42 +1,61 @@ -from typing import Any, Tuple +from typing import Any import numpy as np from albumentations import DualTransform -# TODO unused -class Denormalize(DualTransform): - """Denormalize an image or image patch. +class Normalize(DualTransform): + """ + Normalize an image or image patch. + + Normalization is a zero mean and unit variance. This transform expects (Z)YXC + dimensions. + + Not that an epsilon value of 1e-6 is added to the standard deviation to avoid + division by zero and that it returns a float32 image. - This transform expects (Z)YXC dimensions. + Attributes + ---------- + mean : float + Mean value. + std : float + Standard deviation value. + eps : float + Epsilon value to avoid division by zero. """ def __init__( self, - mean: Tuple[float], - std: Tuple[float], - max_pixel_value: int = 1, - always_apply: bool = False, - p: float = 1.0, + mean: float, + std: float, ): - super().__init__(always_apply=always_apply, p=p) + super().__init__(always_apply=True, p=1) - self.mean = np.array(mean) - self.std = np.array(std) - self.max_pixel_value = max_pixel_value + self.mean = mean + self.std = std + self.eps = 1e-6 def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray: - """Apply the transform to the image. + """ + Apply the transform to the image. Parameters ---------- patch : np.ndarray Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + + Returns + ------- + np.ndarray + Normalized image or image patch. """ - return (patch * self.std + self.mean) * self.max_pixel_value + return ((patch - self.mean) / (self.std + self.eps)).astype(np.float32) def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray: - """Apply the transform to the mask. + """ + Apply the transform to the mask. + + The mask is returned as is. Parameters ---------- @@ -44,3 +63,47 @@ def apply_to_mask(self, mask: np.ndarray, **kwargs: Any) -> np.ndarray: Mask or mask patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). """ return mask + + +class Denormalize(DualTransform): + """ + Denormalize an image or image patch. + + Denormalization is performed expecting a zero mean and unit variance input. This + transform expects (Z)YXC dimensions. + + Not that an epsilon value of 1e-6 is added to the standard deviation to avoid + division by zero during the normalization step, which is taken into account during + denormalization. + + Attributes + ---------- + mean : float + Mean value. + std : float + Standard deviation value. + eps : float + Epsilon value to avoid division by zero. + """ + + def __init__( + self, + mean: float, + std: float, + ): + super().__init__(always_apply=True, p=1) + + self.mean = mean + self.std = std + self.eps = 1e-6 + + def apply(self, patch: np.ndarray, **kwargs: Any) -> np.ndarray: + """ + Apply the transform to the image. + + Parameters + ---------- + patch : np.ndarray + Image or image patch, 2D or 3D, shape (y, x, c) or (z, y, x, c). + """ + return patch * (self.std + self.eps) + self.mean diff --git a/src/careamics/transforms/tta.py b/src/careamics/transforms/tta.py index 55cc9cc13..39e65950e 100644 --- a/src/careamics/transforms/tta.py +++ b/src/careamics/transforms/tta.py @@ -1,3 +1,4 @@ +"""Test-time augmentations.""" from typing import List import numpy as np @@ -6,12 +7,16 @@ # TODO add tests class ImageRestorationTTA: - """Test-time augmentation for image restoration tasks. + """ + Test-time augmentation for image restoration tasks. The augmentation is performed using all 90 deg rotations and their flipped version, as well as the original image flipped. Tensors should be of shape SC(Z)YX + + This transformation is used in the LightningModule in order to perform test-time + agumentation. """ def __init__(self) -> None: @@ -19,16 +24,17 @@ def __init__(self) -> None: pass def forward(self, x: Tensor) -> List[Tensor]: - """Apply test-time augmentation to the input tensor. + """ + Apply test-time augmentation to the input tensor. Parameters ---------- - x : Any + x : Tensor Input tensor, shape SC(Z)YX. Returns ------- - Any + List[Tensor] List of augmented tensors. """ augmented = [ diff --git a/src/careamics/utils/__init__.py b/src/careamics/utils/__init__.py index 91d309e33..3207fe2b2 100644 --- a/src/careamics/utils/__init__.py +++ b/src/careamics/utils/__init__.py @@ -2,28 +2,17 @@ __all__ = [ - "denormalize", - "normalize", - "check_axes_validity", - "check_tiling_validity", "cwd", - "MetricTracker", "get_ram_size", "check_path_exists", "BaseEnum", "get_logger", "get_careamics_home", - "RunningStats", ] from .base_enum import BaseEnum from .context import cwd, get_careamics_home from .logging import get_logger -from .normalization import RunningStats, denormalize, normalize from .path_utils import check_path_exists from .ram import get_ram_size -from .validators import ( - check_axes_validity, - check_tiling_validity, -) diff --git a/src/careamics/utils/metrics.py b/src/careamics/utils/metrics.py index b961fc2b5..23283c3a4 100644 --- a/src/careamics/utils/metrics.py +++ b/src/careamics/utils/metrics.py @@ -112,49 +112,3 @@ def scale_invariant_psnr( range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt) gt_ = _zero_mean(gt) / np.std(gt) return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter) - - -class MetricTracker: - """ - Metric tracker class. - - This class is used to track values, sum, count and average of a metric over time. - - Attributes - ---------- - val : int - Last value of the metric. - avg : torch.Tensor.float - Average value of the metric. - sum : int - Sum of the metric values (times number of values). - count : int - Number of values. - """ - - def __init__(self) -> None: - """Constructor.""" - self.reset() - - def reset(self) -> None: - """Reset the metric tracker state.""" - self.val = 0.0 - self.avg: torch.Tensor.float = 0.0 - self.sum = 0.0 - self.count = 0.0 - - def update(self, value: int, n: int = 1) -> None: - """ - Update the metric tracker state. - - Parameters - ---------- - value : int - Value to update the metric tracker with. - n : int - Number of values, equals to batch size. - """ - self.val = value - self.sum += value * n - self.count += n - self.avg = self.sum / self.count diff --git a/src/careamics/utils/normalization.py b/src/careamics/utils/normalization.py deleted file mode 100644 index 86cb8edd5..000000000 --- a/src/careamics/utils/normalization.py +++ /dev/null @@ -1,97 +0,0 @@ -""" -Normalization submodule. - -These methods are used to normalize and denormalize images. -""" -from multiprocessing import Value -from typing import List, Tuple, Union - -import numpy as np - - -def normalize( - img: Union[Tuple, np.ndarray], mean: float, std: float -) -> Union[List, np.ndarray]: - """ - Normalize an image using mean and standard deviation. - - Images are normalised by subtracting the mean and dividing by the standard - deviation. - - Parameters - ---------- - img : np.ndarray - Image to normalize. - mean : float - Mean. - std : float - Standard deviation. - - Returns - ------- - np.ndarray - Normalized array. - """ - zero_mean = img - mean - return zero_mean / std - - -def denormalize(img: np.ndarray, mean: float, std: float) -> np.ndarray: - """ - Denormalize an image using mean and standard deviation. - - Images are denormalised by multiplying by the standard deviation and adding the - mean. - - Parameters - ---------- - img : np.ndarray - Image to denormalize. - mean : float - Mean. - std : float - Standard deviation. - - Returns - ------- - np.ndarray - Denormalized array. - """ - return img * std + mean - - -class RunningStats: - """Calculates running mean and std.""" - - def __init__(self) -> None: - self.reset() - - def reset(self) -> None: - """Reset the running stats.""" - self.avg_mean = Value("d", 0) - self.avg_std = Value("d", 0) - self.m2 = Value("d", 0) - self.count = Value("i", 0) - - def init(self, mean: float, std: float) -> None: - """Initialize running stats.""" - with self.avg_mean.get_lock(): - self.avg_mean.value += mean - with self.avg_std.get_lock(): - self.avg_std.value = std - - def compute_std(self) -> Tuple[float, float]: - """Compute std.""" - if self.count.value >= 2: - self.avg_std.value = np.sqrt(self.m2.value / self.count.value) - - def update(self, value: float) -> None: - """Update running stats.""" - with self.count.get_lock(): - self.count.value += 1 - delta = value - self.avg_mean.value - with self.avg_mean.get_lock(): - self.avg_mean.value += delta / self.count.value - delta2 = value - self.avg_mean.value - with self.m2.get_lock(): - self.m2.value += delta * delta2 diff --git a/src/careamics/utils/running_stats.py b/src/careamics/utils/running_stats.py new file mode 100644 index 000000000..1268d3e43 --- /dev/null +++ b/src/careamics/utils/running_stats.py @@ -0,0 +1,43 @@ +"""Running stats submodule, used in the Zarr dataset.""" + +# from multiprocessing import Value +# from typing import Tuple + +# import numpy as np + + +# class RunningStats: +# """Calculates running mean and std.""" + +# def __init__(self) -> None: +# self.reset() + +# def reset(self) -> None: +# """Reset the running stats.""" +# self.avg_mean = Value("d", 0) +# self.avg_std = Value("d", 0) +# self.m2 = Value("d", 0) +# self.count = Value("i", 0) + +# def init(self, mean: float, std: float) -> None: +# """Initialize running stats.""" +# with self.avg_mean.get_lock(): +# self.avg_mean.value += mean +# with self.avg_std.get_lock(): +# self.avg_std.value = std + +# def compute_std(self) -> Tuple[float, float]: +# """Compute std.""" +# if self.count.value >= 2: +# self.avg_std.value = np.sqrt(self.m2.value / self.count.value) + +# def update(self, value: float) -> None: +# """Update running stats.""" +# with self.count.get_lock(): +# self.count.value += 1 +# delta = value - self.avg_mean.value +# with self.avg_mean.get_lock(): +# self.avg_mean.value += delta / self.count.value +# delta2 = value - self.avg_mean.value +# with self.m2.get_lock(): +# self.m2.value += delta * delta2 diff --git a/src/careamics/utils/validators.py b/src/careamics/utils/validators.py deleted file mode 100644 index 672a5ed0c..000000000 --- a/src/careamics/utils/validators.py +++ /dev/null @@ -1,114 +0,0 @@ -""" -Validator functions. - -These functions are used to validate dimensions and axes of inputs. -""" -from typing import List - -AXES = "STCZYX" - - -def check_axes_validity(axes: str) -> bool: - """ - Sanity check on axes. - - The constraints on the axes are the following: - - must be a combination of 'STCZYX' - - must not contain duplicates - - must contain at least 2 contiguous axes: X and Y - - must contain at most 4 axes - - cannot contain both S and T axes - - Axes do not need to be in the order 'STCZYX', as this depends on the user data. - - Parameters - ---------- - axes : str - Axes to validate. - - Returns - ------- - bool - True if axes are valid, False otherwise. - """ - _axes = axes.upper() - - # Minimum is 2 (XY) and maximum is 4 (TZYX) - if len(_axes) < 2 or len(_axes) > 6: - raise ValueError( - f"Invalid axes {axes}. Must contain at least 2 and at most 6 axes." - ) - - if "YX" not in _axes and "XY" not in _axes: - raise ValueError( - f"Invalid axes {axes}. Must contain at least X and Y axes consecutively." - ) - - # all characters must be in REF_AXES = 'STCZYX' - if not all(s in AXES for s in _axes): - raise ValueError(f"Invalid axes {axes}. Must be a combination of {AXES}.") - - # check for repeating characters - for i, s in enumerate(_axes): - if i != _axes.rfind(s): - raise ValueError( - f"Invalid axes {axes}. Cannot contain duplicate axes" - f" (got multiple {axes[i]})." - ) - - return True - - -def check_tiling_validity(tile_shape: List[int], overlaps: List[int]) -> None: - """ - Check that the tiling parameters are valid. - - Parameters - ---------- - tile_shape : List[int] - Shape of the tiles. - overlaps : List[int] - Overlap between tiles. - - Raises - ------ - ValueError - If one of the parameters is None. - ValueError - If one of the element is zero. - ValueError - If one of the element is non-divisible by 2. - ValueError - If the number of elements in `overlaps` and `tile_shape` is different. - ValueError - If one of the overlaps is larger than the corresponding tile shape. - """ - # cannot be None - if tile_shape is None or overlaps is None: - raise ValueError( - "Cannot use tiling without specifying `tile_shape` and " - "`overlaps`, make sure they have been correctly specified." - ) - - # non-zero and divisible by two - for dims_list in [tile_shape, overlaps]: - for dim in dims_list: - if dim < 0: - raise ValueError(f"Entry must be non-null positive (got {dim}).") - - if dim % 2 != 0: - raise ValueError(f"Entry must be divisible by 2 (got {dim}).") - - # same length - if len(overlaps) != len(tile_shape): - raise ValueError( - f"Overlaps ({len(overlaps)}) and tile shape ({len(tile_shape)}) must " - f"have the same number of dimensions." - ) - - # overlaps smaller than tile shape - for overlap, tile_dim in zip(overlaps, tile_shape): - if overlap >= tile_dim: - raise ValueError( - f"Overlap ({overlap}) must be smaller than tile shape ({tile_dim})." - ) diff --git a/src/careamics/utils/wandb.py b/src/careamics/utils/wandb.py deleted file mode 100644 index 3bae730d7..000000000 --- a/src/careamics/utils/wandb.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -A WandB logger for CAREamics. - -Implements a WandB class for use within the Engine. -""" -import sys -from pathlib import Path -from typing import Dict, Union - -import torch -import wandb - -from ..config import Configuration - - -def is_notebook() -> bool: - """ - Check if the code is executed from a notebook or a qtconsole. - - Returns - ------- - bool - True if the code is executed from a notebooks, False otherwise. - """ - try: - from IPython import get_ipython - - shell = get_ipython().__class__.__name__ - if shell == "ZMQInteractiveShell": - return True # Jupyter notebook or qtconsole - else: - return False - except (NameError, ModuleNotFoundError): - return False - - -class WandBLogging: - """ - WandB logging class. - - Parameters - ---------- - experiment_name : str - Name of the experiment. - log_path : Path - Path in which to save the WandB log. - config : Configuration - Configuration of the model. - model_to_watch : torch.nn.Module - Model. - save_code : bool, optional - Whether to save the code, by default True. - """ - - def __init__( - self, - experiment_name: str, - log_path: Path, - config: Configuration, - model_to_watch: torch.nn.Module, - save_code: bool = True, - ): - """ - Constructor. - - Parameters - ---------- - experiment_name : str - Name of the experiment. - log_path : Path - Path in which to save the WandB log. - config : Configuration - Configuration of the model. - model_to_watch : torch.nn.Module - Model. - save_code : bool, optional - Whether to save the code, by default True. - """ - self.run = wandb.init( - project="careamics-restoration", - dir=log_path, - name=experiment_name, - config=config.model_dump() if config else None, - # save_code=save_code, - ) - if model_to_watch: - wandb.watch(model_to_watch, log="all", log_freq=1) - if save_code: - if is_notebook(): - # Get all sys path and select the root - code_path = Path([p for p in sys.path if "caremics" in p][-1]).parent - else: - code_path = Path("../") - self.log_code(code_path) - - def log_metrics(self, metric_dict: Dict) -> None: - """ - Log metrics to wandb. - - Parameters - ---------- - metric_dict : Dict - New metrics entry. - """ - self.run.log(metric_dict, commit=True) - - def log_code(self, code_path: Union[str, Path]) -> None: - """ - Log code to wandb. - - Parameters - ---------- - code_path : Union[str, Path] - Path to the code. - """ - self.run.log_code( - root=code_path, - include_fn=lambda path: path.endswith(".py") - or path.endswith(".yml") - or path.endswith(".yaml"), - ) diff --git a/tests/config/architectures/test_architecture_model.py b/tests/config/architectures/test_architecture_model.py new file mode 100644 index 000000000..b97ab7e96 --- /dev/null +++ b/tests/config/architectures/test_architecture_model.py @@ -0,0 +1,11 @@ +from careamics.config.architectures import ArchitectureModel + + +def test_model_dump(): + """Test that architecture keyword is removed from the model dump.""" + model_params = {"architecture": "LeCorbusier"} + model = ArchitectureModel(**model_params) + + # dump model + model_dict = model.model_dump() + assert model_dict == {} diff --git a/tests/config/architectures/test_custom_model.py b/tests/config/architectures/test_custom_model.py index 6b8983937..63b6c7760 100644 --- a/tests/config/architectures/test_custom_model.py +++ b/tests/config/architectures/test_custom_model.py @@ -2,7 +2,6 @@ from torch import nn, ones from careamics.config.architectures import CustomModel, get_custom_model, register_model -from careamics.config.architectures.custom_model import CustomParametersModel from careamics.config.support import SupportedArchitecture @@ -29,14 +28,13 @@ def forward(self, input): return input -def test_empty_parameters(): - """Test that the custom model parameters does not require any fields.""" - CustomParametersModel() - - def test_any_custom_parameters(): - """Test that the custom model parameters can have any fields.""" - CustomParametersModel(id=3, some_param={"a": 1, "b": 2}, t="test") + """Test that the custom model can have any fields. + + Note that those fields are validated by instantiating the + model. + """ + CustomModel(architecture="Custom", name="linear", in_features=10, out_features=5) def test_linear_model(): @@ -57,7 +55,8 @@ def test_custom_model(): model_dict = { "architecture": SupportedArchitecture.CUSTOM.value, "name": "linear", - "parameters": {"in_features": 10, "out_features": 5}, + "in_features": 10, + "out_features": 5, } # create Pydantic model @@ -65,7 +64,7 @@ def test_custom_model(): # instantiate model model_class = get_custom_model(pydantic_model.name) - model = model_class(**pydantic_model.parameters.model_dump()) + model = model_class(**pydantic_model.model_dump()) assert isinstance(model, LinearModel) assert model.in_features == 10 diff --git a/tests/config/architectures/test_unet_model.py b/tests/config/architectures/test_unet_model.py index 299223ee3..0f41a91b0 100644 --- a/tests/config/architectures/test_unet_model.py +++ b/tests/config/architectures/test_unet_model.py @@ -113,9 +113,9 @@ def test_model_dump(): model_dict = model.model_dump(exclude_defaults=True) # check that default values are excluded except the architecture - assert "architecture" in model_dict - assert len(model_dict) == 3 + assert "architecture" not in model_dict + assert len(model_dict) == 2 # check that we get all the optional values with the exclude_defaults flag model_dict = model.model_dump(exclude_defaults=False) - assert len(model_dict) == len(dict(model)) + assert len(model_dict) == len(dict(model)) - 1 diff --git a/tests/config/test_configuration_factory.py b/tests/config/test_configuration_factory.py index 61708b6ba..711c4375d 100644 --- a/tests/config/test_configuration_factory.py +++ b/tests/config/test_configuration_factory.py @@ -18,13 +18,16 @@ def test_n2v_configuration(): batch_size=8, num_epochs=100, ) - assert config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert ( + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) - assert not config.data_config.transforms[-2].parameters.is_3D # XY_RANDOM_ROTATE90 - assert not config.data_config.transforms[-3].parameters.is_3D # NDFLIP + assert not config.data_config.transforms[-2].is_3D # XY_RANDOM_ROTATE90 + assert not config.data_config.transforms[-3].is_3D # NDFLIP assert not config.algorithm_config.model.is_3D() @@ -38,13 +41,16 @@ def test_n2v_3d_configuration(): batch_size=8, num_epochs=100, ) - assert config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) + assert ( + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) - assert config.data_config.transforms[-2].parameters.is_3D # XY_RANDOM_ROTATE90 - assert config.data_config.transforms[-3].parameters.is_3D # NDFLIP + assert config.data_config.transforms[-2].is_3D # XY_RANDOM_ROTATE90 + assert config.data_config.transforms[-3].is_3D # NDFLIP assert config.algorithm_config.model.is_3D() @@ -157,7 +163,10 @@ def test_n2v_no_aug(): use_augmentations=False, ) assert len(config.data_config.transforms) == 2 - assert config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value + assert ( + config.data_config.transforms[-1].name + == SupportedTransform.N2V_MANIPULATE.value + ) assert config.data_config.transforms[-2].name == SupportedTransform.NORMALIZE.value @@ -173,8 +182,8 @@ def test_n2v_augmentation_parameters(): roi_size=17, masked_pixel_percentage=0.5, ) - assert config.data_config.transforms[-1].parameters.roi_size == 17 - assert config.data_config.transforms[-1].parameters.masked_pixel_percentage == 0.5 + assert config.data_config.transforms[-1].roi_size == 17 + assert config.data_config.transforms[-1].masked_pixel_percentage == 0.5 def test_n2v2(): @@ -189,7 +198,7 @@ def test_n2v2(): use_n2v2=True, ) assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.MEDIAN.value ) @@ -207,7 +216,7 @@ def test_structn2v(): struct_n2v_span=7, ) assert ( - config.data_config.transforms[-1].parameters.struct_mask_axis + config.data_config.transforms[-1].struct_mask_axis == SupportedStructAxis.HORIZONTAL.value ) - assert config.data_config.transforms[-1].parameters.struct_mask_span == 7 + assert config.data_config.transforms[-1].struct_mask_span == 7 diff --git a/tests/config/test_configuration_model.py b/tests/config/test_configuration_model.py index 5f731159d..353184cf6 100644 --- a/tests/config/test_configuration_model.py +++ b/tests/config/test_configuration_model.py @@ -114,19 +114,17 @@ def test_n2v2_and_transforms(minimum_configuration: dict, algorithm, strategy): config.data_config.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value ) - assert config.data_config.transforms[-1].parameters.strategy == expected_strategy + assert config.data_config.transforms[-1].strategy == expected_strategy # passing ManipulateN2V with the wrong strategy minimum_configuration["data_config"]["transforms"] = [ { "name": SupportedTransform.N2V_MANIPULATE.value, - "parameters": { - "strategy": strategy, - }, + "strategy": strategy, } ] config = Configuration(**minimum_configuration) - assert config.data_config.transforms[-1].parameters.strategy == expected_strategy + assert config.data_config.transforms[-1].strategy == expected_strategy def test_setting_n2v2(minimum_configuration: dict): @@ -140,7 +138,7 @@ def test_setting_n2v2(minimum_configuration: dict): assert config.algorithm_config.algorithm == SupportedAlgorithm.N2V.value assert not config.algorithm_config.model.n2v2 assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) @@ -148,7 +146,7 @@ def test_setting_n2v2(minimum_configuration: dict): config.set_N2V2(True) assert config.algorithm_config.model.n2v2 assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.MEDIAN.value ) @@ -156,7 +154,7 @@ def test_setting_n2v2(minimum_configuration: dict): config.set_N2V2(False) assert not config.algorithm_config.model.n2v2 assert ( - config.data_config.transforms[-1].parameters.strategy + config.data_config.transforms[-1].strategy == SupportedPixelManipulation.UNIFORM.value ) diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py index 7acce922e..4e59d776a 100644 --- a/tests/config/test_data_model.py +++ b/tests/config/test_data_model.py @@ -7,7 +7,10 @@ SupportedStructAxis, SupportedTransform, ) -from careamics.config.transformations.xy_random_rotate90_model import ( +from careamics.config.transformations import ( + N2VManipulateModel, + NDFlipModel, + NormalizeModel, XYRandomRotate90Model, ) from careamics.transforms import get_all_transforms @@ -83,8 +86,8 @@ def test_mean_and_std_in_normalize(minimum_data: dict): {"name": SupportedTransform.NORMALIZE.value}, ] data = DataModel(**minimum_data) - assert data.transforms[0].parameters.mean == 10.4 - assert data.transforms[0].parameters.std == 3.2 + assert data.transforms[0].mean == 10.4 + assert data.transforms[0].std == 3.2 def test_patch_size(minimum_data: dict): @@ -93,15 +96,15 @@ def test_patch_size(minimum_data: dict): data_model = DataModel(**minimum_data) # 3D - minimum_data["patch_size"] = [12, 12, 12] + minimum_data["patch_size"] = [16, 8, 8] minimum_data["axes"] = "ZYX" data_model = DataModel(**minimum_data) - assert data_model.patch_size == [12, 12, 12] + assert data_model.patch_size == minimum_data["patch_size"] @pytest.mark.parametrize( - "patch_size", [[12], [0, 12, 12], [12, 12, 13], [12, 12, 12, 12]] + "patch_size", [[12], [0, 12, 12], [12, 12, 13], [16, 10, 16], [12, 12, 12, 12]] ) def test_wrong_patch_size(minimum_data: dict, patch_size): """Test that wrong patch sizes are not accepted (zero or odd, dims 1 or > 3).""" @@ -155,7 +158,18 @@ def test_set_3d(minimum_data: dict): def test_passing_supported_transforms(minimum_data: dict, transforms): """Test that list of supported transforms can be passed.""" minimum_data["transforms"] = transforms - DataModel(**minimum_data) + model = DataModel(**minimum_data) + + supported = { + "NDFlip": NDFlipModel, + "XYRandomRotate90": XYRandomRotate90Model, + "Normalize": NormalizeModel, + "N2VManipulate": N2VManipulateModel, + } + + for ind, t in enumerate(transforms): + assert t["name"] == model.transforms[ind].name + assert isinstance(model.transforms[ind], supported[t["name"]]) @pytest.mark.parametrize( @@ -236,25 +250,24 @@ def test_correct_transform_parameters(minimum_data: dict): model = DataModel(**minimum_data) # Normalize - params = model.transforms[0].parameters.model_dump() + params = model.transforms[0].model_dump() assert "mean" in params assert "std" in params - assert "max_pixel_value" in params # NDFlip - params = model.transforms[1].parameters.model_dump() + params = model.transforms[1].model_dump() assert "p" in params assert "is_3D" in params assert "flip_z" in params # XYRandomRotate90 - params = model.transforms[2].parameters.model_dump() + params = model.transforms[2].model_dump() assert "p" in params assert "is_3D" in params assert isinstance(model.transforms[2], XYRandomRotate90Model) # N2VManipulate - params = model.transforms[3].parameters.model_dump() + params = model.transforms[3].model_dump() assert "roi_size" in params assert "masked_pixel_percentage" in params assert "strategy" in params @@ -294,26 +307,22 @@ def test_3D_and_transforms(minimum_data: dict): minimum_data["transforms"] = [ { "name": SupportedTransform.NDFLIP.value, - "parameters": { - "is_3D": True, - "flip_z": True, - }, + "is_3D": True, + "flip_z": True, }, { "name": SupportedTransform.XY_RANDOM_ROTATE90.value, - "parameters": { - "is_3D": True, - }, + "is_3D": True, }, ] data = DataModel(**minimum_data) - assert data.transforms[0].parameters.is_3D is False - assert data.transforms[1].parameters.is_3D is False + assert data.transforms[0].is_3D is False + assert data.transforms[1].is_3D is False # change to 3D data.set_3D("ZYX", [64, 64, 64]) - data.transforms[0].parameters.is_3D = True - data.transforms[1].parameters.is_3D = True + data.transforms[0].is_3D = True + data.transforms[1].is_3D = True def test_set_n2v_strategy(minimum_data: dict): @@ -323,13 +332,13 @@ def test_set_n2v_strategy(minimum_data: dict): data = DataModel(**minimum_data) assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - assert data.transforms[-1].parameters.strategy == uniform + assert data.transforms[-1].strategy == uniform data.set_N2V2_strategy(median) - assert data.transforms[-1].parameters.strategy == median + assert data.transforms[-1].strategy == median data.set_N2V2_strategy(uniform) - assert data.transforms[-1].parameters.strategy == uniform + assert data.transforms[-1].strategy == uniform def test_set_n2v_strategy_wrong_value(minimum_data: dict): @@ -347,20 +356,20 @@ def test_set_struct_mask(minimum_data: dict): data = DataModel(**minimum_data) assert data.transforms[-1].name == SupportedTransform.N2V_MANIPULATE.value - assert data.transforms[-1].parameters.struct_mask_axis == none - assert data.transforms[-1].parameters.struct_mask_span == 5 + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 5 data.set_structN2V_mask(vertical, 3) - assert data.transforms[-1].parameters.struct_mask_axis == vertical - assert data.transforms[-1].parameters.struct_mask_span == 3 + assert data.transforms[-1].struct_mask_axis == vertical + assert data.transforms[-1].struct_mask_span == 3 data.set_structN2V_mask(horizontal, 7) - assert data.transforms[-1].parameters.struct_mask_axis == horizontal - assert data.transforms[-1].parameters.struct_mask_span == 7 + assert data.transforms[-1].struct_mask_axis == horizontal + assert data.transforms[-1].struct_mask_span == 7 data.set_structN2V_mask(none, 11) - assert data.transforms[-1].parameters.struct_mask_axis == none - assert data.transforms[-1].parameters.struct_mask_span == 11 + assert data.transforms[-1].struct_mask_axis == none + assert data.transforms[-1].struct_mask_span == 11 def test_set_struct_mask_wrong_value(minimum_data: dict): diff --git a/tests/config/test_inference_model.py b/tests/config/test_inference_model.py index 8277869cb..885a48169 100644 --- a/tests/config/test_inference_model.py +++ b/tests/config/test_inference_model.py @@ -18,21 +18,13 @@ def test_wrong_extensions(minimum_inference: dict, ext: str): InferenceModel(**minimum_inference) -@pytest.mark.parametrize("mean, std", [(0, 124.5), (12.6, 0.1)]) -def test_mean_std_non_negative(minimum_inference: dict, mean, std): - """Test that non negative mean and std are accepted.""" - minimum_inference["mean"] = mean - minimum_inference["std"] = std - - prediction_model = InferenceModel(**minimum_inference) - assert prediction_model.mean == mean - assert prediction_model.std == std - - def test_mean_std_both_specified_or_none(minimum_inference: dict): - """Test an error is raised if std is specified but mean is None.""" - # No error if both are None - InferenceModel(**minimum_inference) + """Test error raising when setting mean and std.""" + # Errors if both are None + minimum_inference["mean"] = None + minimum_inference["std"] = None + with pytest.raises(ValueError): + InferenceModel(**minimum_inference) # Error if only mean is defined minimum_inference["mean"] = 10.4 @@ -51,38 +43,28 @@ def test_mean_std_both_specified_or_none(minimum_inference: dict): InferenceModel(**minimum_inference) -def test_set_mean_and_std(minimum_inference: dict): - """Test that mean and std can be set after initialization.""" - # they can be set both, when they None - mean = 4.07 - std = 14.07 - pred = InferenceModel(**minimum_inference) - pred.set_mean_and_std(mean, std) - assert pred.mean == mean - assert pred.std == std - - # and if they are already set - minimum_inference["mean"] = 10.4 - minimum_inference["std"] = 3.2 - pred = InferenceModel(**minimum_inference) - pred.set_mean_and_std(mean, std) - assert pred.mean == mean - assert pred.std == std - - def test_tile_size(minimum_inference: dict): """Test that non-zero even patch size are accepted.""" + # no tiling + prediction_model = InferenceModel(**minimum_inference) + # 2D + minimum_inference["tile_size"] = [16, 8] + minimum_inference["tile_overlap"] = [2, 2] + minimum_inference["axes"] = "YX" + prediction_model = InferenceModel(**minimum_inference) + assert prediction_model.tile_size == minimum_inference["tile_size"] + assert prediction_model.tile_overlap == minimum_inference["tile_overlap"] # 3D - minimum_inference["tile_size"] = [12, 12, 12] + minimum_inference["tile_size"] = [16, 8, 32] minimum_inference["tile_overlap"] = [2, 2, 2] minimum_inference["axes"] = "ZYX" prediction_model = InferenceModel(**minimum_inference) - assert prediction_model.tile_size == [12, 12, 12] - assert prediction_model.tile_overlap == [2, 2, 2] + assert prediction_model.tile_size == minimum_inference["tile_size"] + assert prediction_model.tile_overlap == minimum_inference["tile_overlap"] @pytest.mark.parametrize( @@ -112,6 +94,9 @@ def test_wrong_tile_overlap(minimum_inference: dict, tile_size, tile_overlap): def test_set_3d(minimum_inference: dict): """Test that 3D can be set.""" + minimum_inference["tile_size"] = [64, 64] + minimum_inference["tile_overlap"] = [32, 32] + pred = InferenceModel(**minimum_inference) assert "Z" not in pred.axes assert len(pred.tile_size) == 2 @@ -180,7 +165,7 @@ def test_passing_compose_transform(minimum_inference: dict): """Test that Compose transform can be passed.""" minimum_inference["transforms"] = Compose( [ - get_all_transforms()[SupportedTransform.NORMALIZE](), + get_all_transforms()[SupportedTransform.NORMALIZE](mean=10.4, std=3.2), get_all_transforms()[SupportedTransform.NDFLIP](), ] ) @@ -196,5 +181,5 @@ def test_mean_and_std_in_normalize(minimum_inference: dict): ] data = InferenceModel(**minimum_inference) - assert data.transforms[0].parameters.mean == 10.4 - assert data.transforms[0].parameters.std == 3.2 + assert data.transforms[0].mean == 10.4 + assert data.transforms[0].std == 3.2 diff --git a/tests/config/test_tile_information.py b/tests/config/test_tile_information.py new file mode 100644 index 000000000..78b24cc80 --- /dev/null +++ b/tests/config/test_tile_information.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest + +from careamics.config.tile_information import TileInformation + + +def test_defaults(): + """Test instantiating time information with defaults.""" + tile_info = TileInformation(array_shape=np.zeros((6, 6)).shape) + + assert tile_info.array_shape == (6, 6) + assert not tile_info.tiled + assert not tile_info.last_tile + assert tile_info.overlap_crop_coords is None + assert tile_info.stitch_coords is None + + +def test_tiled(): + """Test instantiating time information with parameters.""" + tile_info = TileInformation( + array_shape=np.zeros((6, 6)).shape, + tiled=True, + last_tile=True, + overlap_crop_coords=((1, 2),), + stitch_coords=((3, 4),), + ) + + assert tile_info.array_shape == (6, 6) + assert tile_info.tiled + assert tile_info.last_tile + assert tile_info.overlap_crop_coords == ((1, 2),) + assert tile_info.stitch_coords == ((3, 4),) + + +def test_validation_last_tile(): + """Test that last tile is only set if tiled is set.""" + tile_info = TileInformation(array_shape=(6, 6), last_tile=True) + assert not tile_info.last_tile + + +def test_error_on_coords(): + """Test than an error is raised if it is tiled but not coordinates are given.""" + with pytest.raises(ValueError): + TileInformation(array_shape=(6, 6), tiled=True) + + +def test_error_on_singleton_dims(): + """Test that an error is raised if the array shape contains singleton dimensions.""" + with pytest.raises(ValueError): + TileInformation(array_shape=(2, 1, 6, 6)) diff --git a/tests/config/transformations/test_n2v_manipulate_model.py b/tests/config/transformations/test_n2v_manipulate_model.py new file mode 100644 index 000000000..6939182c8 --- /dev/null +++ b/tests/config/transformations/test_n2v_manipulate_model.py @@ -0,0 +1,26 @@ +import pytest + +from careamics.config.transformations.n2v_manipulate_model import N2VManipulateModel + + +def test_odd_roi_and_mask(): + """Test that errors are thrown if we pass even roi and mask sizes.""" + # no error + model = N2VManipulateModel(name="N2VManipulate", roi_size=3, struct_mask_span=7) + assert model.roi_size == 3 + assert model.struct_mask_span == 7 + + # errors + with pytest.raises(ValueError): + N2VManipulateModel(name="N2VManipulate", roi_size=4, struct_mask_span=7) + + with pytest.raises(ValueError): + N2VManipulateModel(name="N2VManipulate", roi_size=3, struct_mask_span=6) + + +def test_extra_parameters(): + """Test that errors are thrown if we pass extra parameters.""" + with pytest.raises(ValueError): + N2VManipulateModel( + name="N2VManipulate", roi_size=3, struct_mask_span=7, extra_param=1 + ) diff --git a/tests/config/transformations/test_normalize_model.py b/tests/config/transformations/test_normalize_model.py new file mode 100644 index 000000000..324b635fe --- /dev/null +++ b/tests/config/transformations/test_normalize_model.py @@ -0,0 +1,13 @@ +from careamics.config.transformations import NormalizeModel + + +def test_setting_mean_std(): + """Test that we can set the mean and std values.""" + model = NormalizeModel(name="Normalize", mean=0.5, std=0.5) + assert model.mean == 0.5 + assert model.std == 0.5 + + model.mean = 0.6 + model.std = 0.6 + assert model.mean == 0.6 + assert model.std == 0.6 diff --git a/tests/utils/test_axes.py b/tests/config/validators/test_validator_utils.py similarity index 60% rename from tests/utils/test_axes.py rename to tests/config/validators/test_validator_utils.py index 0c8e56135..740f10a51 100644 --- a/tests/utils/test_axes.py +++ b/tests/config/validators/test_validator_utils.py @@ -1,6 +1,9 @@ import pytest -from careamics.utils import check_axes_validity +from careamics.config.validators import ( + check_axes_validity, + patch_size_ge_than_8_power_of_2, +) @pytest.mark.parametrize( @@ -43,3 +46,25 @@ def test_are_axes_valid(axes, valid): else: with pytest.raises((ValueError, NotImplementedError)): check_axes_validity(axes) + + +@pytest.mark.parametrize( + "patch_size, error", + [ + ((2, 8, 8), True), + ((10,), True), + ((8, 10, 16), True), + ((8, 13), True), + ((8, 16, 4), True), + ((8,), False), + ((8, 8), False), + ((8, 64, 64), False), + ], +) +def test_patch_size(patch_size, error): + """Test if patch size is valid.""" + if error: + with pytest.raises(ValueError): + patch_size_ge_than_8_power_of_2(patch_size) + else: + patch_size_ge_than_8_power_of_2(patch_size) diff --git a/tests/conftest.py b/tests/conftest.py index b15e70b67..3e41040ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,12 @@ -import tempfile from pathlib import Path -from typing import Callable, Generator, Tuple +from typing import Callable, Tuple import numpy as np import pytest -import tifffile - -from careamics.config import Configuration -from careamics.config.algorithm_model import ( - AlgorithmModel, - LrSchedulerModel, - OptimizerModel, -) -from careamics.config.data_model import DataModel + +from careamics import CAREamist, Configuration from careamics.config.support import SupportedData -from careamics.config.training_model import TrainingModel +from careamics.model_io import export_to_bmz # TODO add details about where each of these fixture is used (e.g. smoke test) @@ -41,7 +33,7 @@ def minimum_algorithm_custom() -> dict: # create dictionary algorithm = { "algorithm": "custom", - "loss": "n2v", + "loss": "mae", "model": { "architecture": "UNet", }, @@ -103,9 +95,9 @@ def minimum_data() -> dict: """ # create dictionary data = { - "data_type": SupportedData.TIFF.value, - "patch_size": [64, 64], - "axes": "SYX", + "data_type": SupportedData.ARRAY.value, + "patch_size": [8, 8], + "axes": "YX", } return data @@ -122,10 +114,10 @@ def minimum_inference() -> dict: """ # create dictionary predic = { - "data_type": SupportedData.TIFF.value, - "tile_size": [64, 64], - "tile_overlap": [10, 10], - "axes": "SYX", + "data_type": SupportedData.ARRAY.value, + "axes": "YX", + "mean": 2.0, + "std": 1.0, } return predic @@ -142,7 +134,7 @@ def minimum_training() -> dict: """ # create dictionary training = { - "num_epochs": 666, + "num_epochs": 1, } return training @@ -181,6 +173,20 @@ def minimum_configuration( return configuration +@pytest.fixture +def supervised_configuration( + minimum_algorithm_supervised: dict, minimum_data: dict, minimum_training: dict +) -> dict: + configuration = { + "experiment_name": "LevitatingFrog", + "algorithm_config": minimum_algorithm_supervised, + "training_config": minimum_training, + "data_config": minimum_data, + } + + return configuration + + @pytest.fixture def ordered_array() -> Callable: """A function that returns an array with ordered values.""" @@ -227,17 +233,6 @@ def array_3D() -> np.ndarray: return np.arange(2048 * 3).reshape((1, 3, 8, 16, 16)) -@pytest.fixture -def temp_dir() -> Generator[Path, None, None]: - with tempfile.TemporaryDirectory() as temp_dir: - yield Path(temp_dir) - - -@pytest.fixture -def image_size() -> Tuple[int, int]: - return (128, 128) - - @pytest.fixture def patch_size() -> Tuple[int, int]: return (64, 64) @@ -249,66 +244,56 @@ def overlaps() -> Tuple[int, int]: @pytest.fixture -def example_data_path( - temp_dir: Path, image_size: Tuple[int, int], patch_size: Tuple[int, int] -) -> Tuple[Path, Path]: - test_image = np.random.rand(*image_size) +def pre_trained(tmp_path, minimum_configuration): + """Fixture to create a pre-trained CAREamics model.""" + # training data + train_array = np.arange(32 * 32).reshape((32, 32)) - train_path = temp_dir / "train" - val_path = temp_dir / "val" - test_path = temp_dir / "test" - train_path.mkdir() - val_path.mkdir() - test_path.mkdir() + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) - tifffile.imwrite(train_path / "train_image.tif", test_image) - tifffile.imwrite(val_path / "val_image.tif", test_image) - tifffile.imwrite(test_path / "test_image.tif", test_image) + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) - return train_path, val_path, test_path + # train CAREamist + careamist.train(train_source=train_array) + # check that it trained + pre_trained_path: Path = tmp_path / "checkpoints" / "last.ckpt" + assert pre_trained_path.exists() -@pytest.fixture -def base_configuration(temp_dir: Path, patch_size) -> Configuration: - configuration = Configuration( - experiment_name="smoke_test", - working_directory=temp_dir, - algorithm_config=AlgorithmModel( - algorithm="n2v", - loss="n2v", - model={"architecture": "UNet"}, - is_3D="False", - transforms={"Flip": None, "ManipulateN2V": None}, - ), - data_config=DataModel( - in_memory=True, - extension="tif", - axes="YX", - ), - training_config=TrainingModel( - num_epochs=1, - patch_size=patch_size, - batch_size=2, - optimizer=OptimizerModel(name="Adam"), - lr_scheduler=LrSchedulerModel(name="ReduceLROnPlateau"), - extraction_strategy="random", - augmentation=True, - num_workers=0, - use_wandb=False, - ), - ) - return configuration + return pre_trained_path @pytest.fixture -def supervised_configuration( - minimum_algorithm_supervised: dict, minimum_data: dict, minimum_training: dict -) -> dict: - configuration = { - "experiment_name": "LevitatingFrog", - "algorithm_config": minimum_algorithm_supervised, - "training_config": minimum_training, - "data_config": minimum_data, - } +def pre_trained_bmz(tmp_path, pre_trained) -> Path: + """Fixture to create a BMZ model.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # export to BioImage Model Zoo + path = tmp_path / "model.zip" + export_to_bmz( + model=careamist.model, + config=careamist.cfg, + path=path, + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + input_array=train_array[np.newaxis, np.newaxis, ...], + output_array=predicted, + ) + assert path.exists() - return configuration + return path diff --git a/tests/dataset/patching/test_random_patching.py b/tests/dataset/patching/test_random_patching.py index 6bc85f74f..858cfe2e0 100644 --- a/tests/dataset/patching/test_random_patching.py +++ b/tests/dataset/patching/test_random_patching.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from careamics.dataset.patching.patching import extract_patches_random +from careamics.dataset.patching.random_patching import extract_patches_random @pytest.mark.parametrize( diff --git a/tests/dataset/patching/test_tiled_patching.py b/tests/dataset/patching/test_tiled_patching.py index 6099d72bf..a7e135d4f 100644 --- a/tests/dataset/patching/test_tiled_patching.py +++ b/tests/dataset/patching/test_tiled_patching.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from careamics.config.tile_information import TileInformation from careamics.dataset.patching.tiled_patching import ( _compute_crop_and_stitch_coords_1d, extract_tiles, @@ -17,7 +18,13 @@ def check_extract_tiles(array: np.ndarray, tile_size, overlaps): # Assemble all tiles and their respective coordinates for tile_data in tile_data_generator: - tile, _, _, overlap_crop_coords, stitch_coords = tile_data + tile = tile_data[0] + + tile_info: TileInformation = tile_data[1] + overlap_crop_coords = tile_info.overlap_crop_coords + stitch_coords = tile_info.stitch_coords + + # add data to lists tiles.append(tile) all_overlap_crop_coords.append(overlap_crop_coords) all_stitch_coords.append(stitch_coords) diff --git a/tests/dataset/test_in_memory_dataset.py b/tests/dataset/test_in_memory_dataset.py index 384211e36..8d2c27381 100644 --- a/tests/dataset/test_in_memory_dataset.py +++ b/tests/dataset/test_in_memory_dataset.py @@ -15,7 +15,7 @@ def test_number_of_patches(ordered_array): # create config config_dict = { "data_type": SupportedData.ARRAY.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) @@ -40,12 +40,12 @@ def test_compute_mean_std_transform(ordered_array): def test_extracting_val_array(ordered_array, percentage): """Test extracting a validation set patches from InMemoryDataset.""" # create array - array = ordered_array((20, 20)) + array = ordered_array((32, 32)) # create config config_dict = { "data_type": SupportedData.ARRAY.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) @@ -76,7 +76,7 @@ def test_extracting_val_array(ordered_array, percentage): def test_extracting_val_files(tmp_path, ordered_array, percentage): """Test extracting a validation set patches from InMemoryDataset.""" # create array - array = ordered_array((20, 20)) + array = ordered_array((32, 32)) # save array to file file_path = tmp_path / "array.tif" @@ -85,7 +85,7 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): # create config config_dict = { "data_type": SupportedData.ARRAY.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index 351df036b..5347310cf 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -12,16 +12,16 @@ "shape", [ # 2D - (20, 20), + (32, 32), # 3D - (20, 20, 20), + (32, 32, 32), ], ) def test_number_of_files(tmp_path, ordered_array, shape): """Test number of files in PathIterableDataset.""" # create array - array_size = 20 - patch_size = 4 + array_size = 32 + patch_size = 8 n_files = 3 factor = len(shape) axes = "YX" if factor == 2 else "ZYX" @@ -63,13 +63,13 @@ def test_read_function(tmp_path, ordered_array): def read_npy(file_path, *args, **kwargs): return np.load(file_path) - array_size = 20 - patch_size = 4 + array_size = 32 + patch_size = 8 n_files = 3 patch_sizes = [patch_size] * 2 # create array - array = ordered_array((n_files, array_size, array_size)) + array: np.ndarray = ordered_array((n_files, array_size, array_size)) # save each plane in a single .npy file files = [] @@ -115,7 +115,7 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): # create config config_dict = { "data_type": SupportedData.TIFF.value, - "patch_size": [4, 4], + "patch_size": [8, 8], "axes": "YX", } config = DataModel(**config_dict) diff --git a/tests/model_io/test_bmz_io.py b/tests/model_io/test_bmz_io.py new file mode 100644 index 000000000..a031aa0d2 --- /dev/null +++ b/tests/model_io/test_bmz_io.py @@ -0,0 +1,66 @@ +import numpy as np +from torch import Tensor + +from careamics import CAREamist +from careamics.model_io import export_to_bmz, load_pretrained +from careamics.model_io.bmz_io import _export_state_dict, _load_state_dict + + +def test_state_dict_io(tmp_path, pre_trained): + """Test exporting and loading a state dict.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + path = tmp_path / "model.pth" + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # save model + _export_state_dict(careamist.model, path) + assert path.exists() + + # load model + _load_state_dict(careamist.model, path) + + # predict (no tiling and no tta) + predicted_loaded = careamist.predict(train_array, tta_transforms=False) + assert (predicted_loaded == predicted).all() + + +def test_bmz_io(tmp_path, pre_trained): + """Test exporting and loading to the BMZ.""" + # training data + train_array = np.ones((32, 32), dtype=np.float32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # predict (no tiling and no tta) + predicted = careamist.predict(train_array, tta_transforms=False) + + # export to BioImage Model Zoo + path = tmp_path / "model.zip" + export_to_bmz( + model=careamist.model, + config=careamist.cfg, + path=path, + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + input_array=train_array[np.newaxis, np.newaxis, ...], + output_array=predicted, + ) + assert path.exists() + + # load model + model, config = load_pretrained(path) + assert config == careamist.cfg + + # compare predictions + torch_array = Tensor(train_array[np.newaxis, np.newaxis, ...]) + predicted = careamist.model.forward(torch_array).detach().numpy().squeeze() + predicted_loaded = model.forward(torch_array).detach().numpy().squeeze() + assert (predicted_loaded == predicted).all() diff --git a/tests/models/test_model_factory.py b/tests/models/test_model_factory.py index 749524398..f4526aaee 100644 --- a/tests/models/test_model_factory.py +++ b/tests/models/test_model_factory.py @@ -42,7 +42,8 @@ def forward(self, input): model_config = { "architecture": SupportedArchitecture.CUSTOM.value, "name": "linear_model", - "parameters": {"in_features": 10, "out_features": 5}, + "in_features": 10, + "out_features": 5, } # instantiate model diff --git a/tests/prediction/test_prediction_utils.py b/tests/prediction/test_stitch_prediction.py similarity index 64% rename from tests/prediction/test_prediction_utils.py rename to tests/prediction/test_stitch_prediction.py index 041215264..4908af233 100644 --- a/tests/prediction/test_prediction_utils.py +++ b/tests/prediction/test_stitch_prediction.py @@ -2,7 +2,7 @@ from torch import from_numpy, tensor from careamics.dataset.patching.tiled_patching import extract_tiles -from careamics.prediction.prediction_utils import stitch_prediction +from careamics.prediction.stitch_prediction import stitch_prediction @pytest.mark.parametrize( @@ -22,23 +22,20 @@ def test_stitch_prediction(ordered_array, input_shape, tile_size, overlaps): stitching_data = [] # extract tiles - tiling_outputs = extract_tiles(arr, tile_size, overlaps) + tile_generator = extract_tiles(arr, tile_size, overlaps) # Assemble all tiles as it is done during the prediction stage - for tile_data in tiling_outputs: - tile, _, input_shape, overlap_crop_coords, stitch_coords = tile_data - - tiles.append(from_numpy(tile)) # need to convert to torch.Tensor + for tile_data, tile_info in tile_generator: + tiles.append(from_numpy(tile_data)) # need to convert to torch.Tensor stitching_data.append( ( # this is way too wacky [tensor(i) for i in input_shape], # need to convert to torch.Tensor - [[tensor([j]) for j in i] for i in overlap_crop_coords], - [[tensor([j]) for j in i] for i in stitch_coords], + [[tensor([j]) for j in i] for i in tile_info.overlap_crop_coords], + [[tensor([j]) for j in i] for i in tile_info.stitch_coords], ) ) - # compute stitching coordinates - + # compute stitching coordinates, it returns a torch.Tensor result = stitch_prediction(tiles, stitching_data) - assert (result == arr).all() + assert (result.numpy() == arr).all() diff --git a/tests/test_careamist.py b/tests/test_careamist.py index 07a768102..e53e7e0ee 100644 --- a/tests/test_careamist.py +++ b/tests/test_careamist.py @@ -7,32 +7,7 @@ from careamics import CAREamist, Configuration, save_configuration from careamics.config.support import SupportedAlgorithm, SupportedData - -@pytest.fixture -def pre_trained(tmp_path, minimum_configuration): - """Fixture to create a pre-trained CAREamics model.""" - # training data - train_array = np.ones((32, 32)) - - # create configuration - config = Configuration(**minimum_configuration) - config.training_config.num_epochs = 1 - config.data_config.axes = "YX" - config.data_config.batch_size = 2 - config.data_config.data_type = SupportedData.ARRAY.value - config.data_config.patch_size = (8, 8) - - # instantiate CAREamist - careamist = CAREamist(source=config, work_dir=tmp_path) - - # train CAREamist - careamist.train(train_source=train_array) - - # check that it trained - pre_trained_path: Path = tmp_path / "checkpoints" / "last.ckpt" - assert pre_trained_path.exists() - - return pre_trained_path +# TODO test 3D and channels def test_no_parameters(): @@ -41,7 +16,7 @@ def test_no_parameters(): CAREamist() -def test_minimum_configuration_via_object(tmp_path, minimum_configuration): +def test_minimum_configuration_via_object(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be instantiated with a minimum configuration object.""" # create configuration config = Configuration(**minimum_configuration) @@ -50,7 +25,7 @@ def test_minimum_configuration_via_object(tmp_path, minimum_configuration): CAREamist(source=config, work_dir=tmp_path) -def test_minimum_configuration_via_path(tmp_path, minimum_configuration): +def test_minimum_configuration_via_path(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be instantiated with a path to a minimum configuration. """ @@ -62,7 +37,9 @@ def test_minimum_configuration_via_path(tmp_path, minimum_configuration): CAREamist(source=path_to_config) -def test_train_error_target_unsupervised_algorithm(tmp_path, minimum_configuration): +def test_train_error_target_unsupervised_algorithm( + tmp_path: Path, minimum_configuration: dict +): """Test that an error is raised when a target is provided for N2V.""" # create configuration config = Configuration(**minimum_configuration) @@ -94,10 +71,10 @@ def test_train_error_target_unsupervised_algorithm(tmp_path, minimum_configurati ) -def test_train_single_array_no_val(tmp_path, minimum_configuration): +def test_train_single_array_no_val(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with arrays.""" # training data - train_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) # create configuration config = Configuration(**minimum_configuration) @@ -116,12 +93,21 @@ def test_train_single_array_no_val(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() -def test_train_array(tmp_path, minimum_configuration): - """Test that CAREamics can be trained with arrays.""" + +def test_train_array(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained on arrays.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) # create configuration config = Configuration(**minimum_configuration) @@ -140,11 +126,87 @@ def test_train_array(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array_channel(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained on arrays with channels.""" + # training data + train_array = np.random.rand(32, 32, 3) + val_array = np.random.rand(32, 32, 3) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "YXC" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) -def test_train_tiff_files_in_memory_no_val(tmp_path, minimum_configuration): + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + channel_names=["red", "green", "blue"], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_array_3d(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can be trained on 3D arrays.""" + # training data + train_array = np.random.rand(8, 32, 32) + val_array = np.random.rand(8, 32, 32) + + # create configuration + minimum_configuration["data_config"]["axes"] = "ZYX" + minimum_configuration["data_config"]["patch_size"] = (8, 16, 16) + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array, val_source=val_array) + + # check that it trained + assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_train_tiff_files_in_memory_no_val(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with tiff files in memory.""" # training data - train_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) # save files train_file = tmp_path / "train.tiff" @@ -167,12 +229,21 @@ def test_train_tiff_files_in_memory_no_val(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() -def test_train_tiff_files_in_memory(tmp_path, minimum_configuration): + +def test_train_tiff_files_in_memory(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with tiff files in memory.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) # save files train_file = tmp_path / "train.tiff" @@ -198,14 +269,23 @@ def test_train_tiff_files_in_memory(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + -def test_train_tiff_files(tmp_path, minimum_configuration): +def test_train_tiff_files(tmp_path: Path, minimum_configuration: dict): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. """ # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) # save files train_file = tmp_path / "train.tiff" @@ -231,14 +311,23 @@ def test_train_tiff_files(tmp_path, minimum_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + -def test_train_array_supervised(tmp_path, supervised_configuration): +def test_train_array_supervised(tmp_path: Path, supervised_configuration: dict): """Test that CAREamics can be trained with arrays.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) - train_target = np.ones((32, 32)) - val_target = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) # create configuration config = Configuration(**supervised_configuration) @@ -262,14 +351,25 @@ def test_train_array_supervised(tmp_path, supervised_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() -def test_train_tiff_files_in_memory_supervised(tmp_path, supervised_configuration): + +def test_train_tiff_files_in_memory_supervised( + tmp_path: Path, supervised_configuration: dict +): """Test that CAREamics can be trained with tiff files in memory.""" # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) - train_target = np.ones((32, 32)) - val_target = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) # save files images = tmp_path / "images" @@ -310,16 +410,25 @@ def test_train_tiff_files_in_memory_supervised(tmp_path, supervised_configuratio # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + -def test_train_tiff_files_supervised(tmp_path, supervised_configuration): +def test_train_tiff_files_supervised(tmp_path: Path, supervised_configuration: dict): """Test that CAREamics can be trained with tiff files by deactivating the in memory dataset. """ # training data - train_array = np.ones((32, 32)) - val_array = np.ones((32, 32)) - train_target = np.ones((32, 32)) - val_target = np.ones((32, 32)) + train_array = np.random.rand(32, 32) + val_array = np.random.rand(32, 32) + train_target = np.random.rand(32, 32) + val_target = np.random.rand(32, 32) # save files images = tmp_path / "images" @@ -361,10 +470,21 @@ def test_train_tiff_files_supervised(tmp_path, supervised_configuration): # check that it trained assert Path(tmp_path / "checkpoints" / "last.ckpt").exists() + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + @pytest.mark.parametrize("batch_size", [1, 2]) -def test_predict_array(tmp_path, minimum_configuration, batch_size): - """Test that CAREamics can predict with arrays.""" +def test_predict_on_array_tiled( + tmp_path: Path, minimum_configuration: dict, batch_size +): + """Test that CAREamics can predict on arrays.""" # training data train_array = np.random.rand(32, 32) @@ -384,22 +504,57 @@ def test_predict_array(tmp_path, minimum_configuration, batch_size): # predict CAREamist predicted = careamist.predict( - train_array, batch_size=batch_size, tile_overlap=(4, 4) + train_array, batch_size=batch_size, tile_size=(16, 16), tile_overlap=(4, 4) ) - # check thatmean/std were set properly - assert careamist.cfg.data_config.mean is not None - assert careamist.cfg.data_config.std is not None - assert careamist.cfg.data_config.mean == train_array.mean() - assert careamist.cfg.data_config.std == train_array.std() - # check prediction and its shape@pytest.mark.parametrize("batch_size", [1, 2]) + assert predicted.squeeze().shape == train_array.shape + + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_predict_arrays_no_tiling(tmp_path: Path, minimum_configuration: dict): + """Test that CAREamics can predict on arrays without tiling.""" + # training data + train_array = np.random.rand(4, 32, 32) + + # create configuration + config = Configuration(**minimum_configuration) + config.training_config.num_epochs = 1 + config.data_config.axes = "SYX" + config.data_config.batch_size = 2 + config.data_config.data_type = SupportedData.ARRAY.value + config.data_config.patch_size = (8, 8) + + # instantiate CAREamist + careamist = CAREamist(source=config, work_dir=tmp_path) + + # train CAREamist + careamist.train(train_source=train_array) + + # predict CAREamist + predicted = careamist.predict(train_array) - assert predicted is not None assert predicted.squeeze().shape == train_array.shape + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + @pytest.mark.parametrize("batch_size", [1, 2]) -def test_predict_path(tmp_path, minimum_configuration, batch_size): +def test_predict_path(tmp_path: Path, minimum_configuration: dict, batch_size): """Test that CAREamics can predict with tiff files.""" # training data train_array = np.random.rand(32, 32) @@ -423,27 +578,114 @@ def test_predict_path(tmp_path, minimum_configuration, batch_size): careamist.train(train_source=train_file) # predict CAREamist - predicted = careamist.predict( - train_file, batch_size=batch_size, tile_overlap=(4, 4) - ) + predicted = careamist.predict(train_file, batch_size=batch_size) # check that it predicted - assert predicted is not None assert predicted.squeeze().shape == train_array.shape + # export to BMZ + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() -def test_predict_pretrained(tmp_path, pre_trained): + +def test_predict_pretrained_checkpoint(tmp_path: Path, pre_trained: Path): """Test that CAREamics can be instantiated with a pre-trained network and predict on an array.""" - # training data - train_array = np.ones((32, 32)) + # prediction data + source_array = np.random.rand(32, 32) # instantiate CAREamist careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + assert careamist.cfg.data_config.mean is not None + assert careamist.cfg.data_config.std is not None # predict - predicted = careamist.predict(train_array, tile_overlap=(4, 4)) + predicted = careamist.predict(source_array) # check that it predicted - assert predicted is not None - assert predicted.squeeze().shape == train_array.shape + assert predicted.squeeze().shape == source_array.shape + + +def test_predict_pretrained_bmz(tmp_path: Path, pre_trained_bmz: Path): + """Test that CAREamics can be instantiated with a BMZ archive and predict.""" + # prediction data + source_array = np.random.rand(32, 32) + + # instantiate CAREamist + careamist = CAREamist(source=pre_trained_bmz, work_dir=tmp_path) + + # predict + predicted = careamist.predict(source_array) + + # check that it predicted + assert predicted.squeeze().shape == source_array.shape + + +def test_export_bmz_pretrained_prediction(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ after prediction. + + In this case, the careamist extracts the BMZ test data from the prediction + datamodule. + """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # prediction data + source_array = np.random.rand(32, 32) + _ = careamist.predict(source_array) + assert len(careamist.pred_datamodule.predict_dataloader()) > 0 + + # export to BMZ (random array created) + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_export_bmz_pretrained_random_array(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ. + + In this case, the careamist creates a random array for the BMZ archive test. + """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # export to BMZ (random array created) + careamist.export_to_bmz( + path=tmp_path / "model.zip", + name="TopModel", + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model.zip").exists() + + +def test_export_bmz_pretrained_with_array(tmp_path: Path, pre_trained: Path): + """Test that CAREamics can be instantiated with a pre-trained network and exported + to BMZ. + + In this case, we provide an array to the BMZ archive test. + """ + # instantiate CAREamist + careamist = CAREamist(source=pre_trained, work_dir=tmp_path) + + # alternatively we can pass an array + array = np.random.rand(32, 32).astype(np.float32) + careamist.export_to_bmz( + path=tmp_path / "model2.zip", + name="TopModel", + input_array=array[np.newaxis, np.newaxis, ...], + general_description="A model that just walked in.", + authors=[{"name": "Amod", "affiliation": "El"}], + ) + assert (tmp_path / "model2.zip").exists() diff --git a/tests/test_lightning_datamodule.py b/tests/test_lightning_datamodule.py index 27bc292df..33444971d 100644 --- a/tests/test_lightning_datamodule.py +++ b/tests/test_lightning_datamodule.py @@ -27,9 +27,10 @@ def test_lightning_train_datamodule_array(simple_array): data_module = CAREamicsTrainDataModule( train_data=simple_array, data_type="array", - patch_size=(2, 2), + patch_size=(8, 8), axes="YX", batch_size=2, + val_minimum_patches=2, ) data_module.prepare_data() data_module.setup() @@ -48,6 +49,7 @@ def test_lightning_train_datamodule_supervised_n2v_throws_error(simple_array): axes="YX", batch_size=2, train_target_data=simple_array, + val_minimum_patches=2, ) @@ -63,12 +65,12 @@ def test_lightning_train_datamodule_n2v2(simple_array, use_n2v2, strategy): data_module = CAREamicsTrainDataModule( train_data=simple_array, data_type="array", - patch_size=(10, 10), + patch_size=(16, 16), axes="YX", batch_size=2, use_n2v2=use_n2v2, ) - assert data_module.data_config.transforms[-1].parameters.strategy == strategy + assert data_module.data_config.transforms[-1].strategy == strategy def test_lightning_train_datamodule_structn2v(simple_array): @@ -79,20 +81,14 @@ def test_lightning_train_datamodule_structn2v(simple_array): data_module = CAREamicsTrainDataModule( train_data=simple_array, data_type="array", - patch_size=(10, 10), + patch_size=(16, 16), axes="YX", batch_size=2, struct_n2v_axis=struct_axis, struct_n2v_span=struct_span, ) - assert ( - data_module.data_config.transforms[-1].parameters.struct_mask_axis - == struct_axis - ) - assert ( - data_module.data_config.transforms[-1].parameters.struct_mask_span - == struct_span - ) + assert data_module.data_config.transforms[-1].struct_mask_axis == struct_axis + assert data_module.data_config.transforms[-1].struct_mask_span == struct_span def test_lightning_predict_datamodule_wrong_type(simple_array): @@ -101,40 +97,44 @@ def test_lightning_predict_datamodule_wrong_type(simple_array): CAREamicsPredictDataModule( pred_data=simple_array, data_type="wrong_type", - tile_size=(10, 10), - tile_overlap=(2, 2), + mean=0.5, + std=0.1, axes="YX", batch_size=2, ) -def test_lightning_pred_datamodule_error_no_mean(simple_array): +def test_lightning_pred_datamodule_tiling(simple_array): """Test that the data module is created correctly with an array.""" # create data module - CAREamicsPredictDataModule( + data_module = CAREamicsPredictDataModule( pred_data=simple_array, data_type="array", - tile_size=(10, 10), - tile_overlap=(2, 2), + mean=0.5, + std=0.1, axes="YX", batch_size=2, + tile_overlap=[2, 2], + tile_size=[8, 8], ) + data_module.prepare_data() + data_module.setup() + assert len(list(data_module.predict_dataloader())) == 2 + -def test_lightning_pred_datamodule_array(simple_array): +def test_lightning_pred_datamodule_no_tiling(simple_array): """Test that the data module is created correctly with an array.""" # create data module data_module = CAREamicsPredictDataModule( pred_data=simple_array, data_type="array", - tile_size=(10, 10), - tile_overlap=(2, 2), - axes="YX", - batch_size=2, mean=0.5, std=0.1, + axes="YX", + batch_size=2, ) + data_module.prepare_data() data_module.setup() - - assert len(list(data_module.predict_dataloader())) > 0 + assert len(list(data_module.predict_dataloader())) == 1 diff --git a/tests/test_lightning_module.py b/tests/test_lightning_module.py index 14c5d8228..afe2fab09 100644 --- a/tests/test_lightning_module.py +++ b/tests/test_lightning_module.py @@ -1,15 +1,17 @@ +import pytest +import torch + from careamics.config import AlgorithmModel from careamics.lightning_module import CAREamicsKiln, CAREamicsModule def test_careamics_module(minimum_algorithm_n2v): - """Test that the minimum algorithm allows isntantiating a the Lightning API + """Test that the minimum algorithm allows instantiating a the Lightning API intermediate layer.""" algo_config = AlgorithmModel(**minimum_algorithm_n2v) # extract model parameters model_parameters = algo_config.model.model_dump(exclude_none=True) - model_parameters.pop("architecture") # instantiate CAREamicsModule CAREamicsModule( @@ -30,3 +32,227 @@ def test_careamics_kiln(minimum_algorithm_n2v): # instantiate CAREamicsKiln CAREamicsKiln(algo_config) + + +@pytest.mark.parametrize( + "shape", + [ + (8, 8), + (16, 16), + (32, 32), + ], +) +def test_careamics_kiln_unet_2D_depth_2_shape(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": 1, + "num_classes": 1, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize( + "shape", + [ + (8, 8), + (16, 16), + (32, 32), + (64, 64), + (128, 128), + (256, 256), + ], +) +def test_careamics_kiln_unet_2D_depth_3_shape(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": 1, + "num_classes": 1, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize( + "shape", + [ + (8, 32, 16), + (16, 32, 16), + (8, 32, 32), + (32, 64, 64), + ], +) +def test_careamics_kiln_unet_depth_2_3D(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": 1, + "num_classes": 1, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize( + "shape", + [ + (8, 64, 64), + (16, 64, 64), + (16, 128, 128), + (32, 128, 128), + ], +) +def test_careamics_kiln_unet_depth_3_3D(shape): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 3, + "in_channels": 1, + "num_classes": 1, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, 1, *shape)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_2_channels_2D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 32, 32)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +def test_careamics_kiln_unet_depth_3_channels_2D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 64, 64)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_2_channels_3D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 2, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 16, 32, 32)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape + + +@pytest.mark.parametrize("n_channels", [1, 3, 4]) +def test_careamics_kiln_unet_depth_3_channels_3D(n_channels): + algo_dict = { + "algorithm": "n2n", + "model": { + "architecture": "UNet", + "conv_dims": 2, + "in_channels": n_channels, + "num_classes": n_channels, + "depth": 3, + }, + "loss": "mae", + } + algo_config = AlgorithmModel(**algo_dict) + + # instantiate CAREamicsKiln + model = CAREamicsKiln(algo_config) + + # test forward pass + x = torch.rand((1, n_channels, 16, 64, 64)) + y: torch.Tensor = model.forward(x) + assert y.shape == x.shape diff --git a/tests/transforms/test_normalize.py b/tests/transforms/test_normalize.py new file mode 100644 index 000000000..da61a9119 --- /dev/null +++ b/tests/transforms/test_normalize.py @@ -0,0 +1,30 @@ +import numpy as np + +from careamics.transforms import Denormalize, Normalize + + +def test_normalize_denormalize(): + """Test the Normalize transform.""" + # Create data + array = np.arange(100).reshape((1, 1, 10, 10)) + + # Create the transform + norm = Normalize( + mean=50, + std=25, + ) + + # Apply the transform + normalized: np.array = norm(image=array)["image"] + assert np.abs(normalized.mean()) < 0.02 + assert np.abs(normalized.std() - 1) < 0.2 + + # Create the denormalize transform + denorm = Denormalize( + mean=50, + std=25, + ) + + # Apply the denormalize transform + denormalized: np.array = denorm(image=normalized)["image"] + assert np.isclose(denormalized, array).all() diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py index 4d6c85d68..b91e66909 100644 --- a/tests/utils/test_metrics.py +++ b/tests/utils/test_metrics.py @@ -2,7 +2,6 @@ import pytest from careamics.utils.metrics import ( - MetricTracker, _zero_mean, scale_invariant_psnr, ) @@ -29,24 +28,3 @@ def test_zero_mean(x): ) def test_scale_invariant_psnr(gt, pred, result): assert scale_invariant_psnr(gt, pred) == pytest.approx(result, rel=5e-3) - - -def test_metric_tracker(): - tracker = MetricTracker() - - # check initial state - assert tracker.sum == 0 - assert tracker.count == 0 - assert tracker.avg == 0 - assert tracker.val == 0 - - # run a few updates - n = 5 - for i in range(n): - tracker.update(i, n) - - # check values - assert tracker.sum == n * (n * (n - 1)) / 2 - assert tracker.count == n * n - assert tracker.avg == (n - 1) / 2 - assert tracker.val == n - 1