From aa79625325e43971c8fe3128a64b1fef55abf36b Mon Sep 17 00:00:00 2001 From: Melisande Croft <63270704+melisande-c@users.noreply.github.com> Date: Fri, 15 Nov 2024 07:49:34 +0000 Subject: [PATCH 1/7] Feature: Load model from BMZ using URL (#273) ### Description - **What**: Now it is possible to pass a URL to the `load_from_bmz` function to download and load BMZ files. - **Why**: Not many users will have access to the model resource URLs, but this functionality is useful for developing the CAREamics BMZ compatibility script. - **How**: - Type hint `path` as `pydantic.HttpUrl` as well in `load_from_bmz` (as in `bioimageio.core`); - Remove `path` validation checks from `load_from_bmz` and allow validation to be handled in `load_model_description`; - Call `download` on the file resources to download and get the correct path. ### Changes Made - **Modified**: - `load_from_bmz` - `extract_model_path` ### Additional Notes and Examples This will have merge conflicts with #271. There are currently no official tests (it does work); we can discuss using the URL of one of the existing CAREamics models uploaded to the BMZ or creating a Mock. --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [ ] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --- .../model_io/bioimage/model_description.py | 12 +++++++--- src/careamics/model_io/bmz_io.py | 23 +++++++++---------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index b4ba8d571..21ed50b8f 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -302,11 +302,17 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]: tuple of (path, path) Weights and configuration paths. """ - weights_path = model_desc.weights.pytorch_state_dict.source.path + if model_desc.weights.pytorch_state_dict is None: + raise ValueError("No model weights found in model description.") + weights_path = model_desc.weights.pytorch_state_dict.download().path for file in model_desc.attachments: - if file.source.path.name == "careamics.yaml": - config_path = file.source.path + file_path = file.source if isinstance(file.source, Path) else file.source.path + if file_path is None: + continue + file_path = Path(file_path) + if file_path.name == "careamics.yaml": + config_path = file.download().path break else: raise ValueError("Configuration file not found.") diff --git a/src/careamics/model_io/bmz_io.py b/src/careamics/model_io/bmz_io.py index d19a02a78..dc4564ecc 100644 --- a/src/careamics/model_io/bmz_io.py +++ b/src/careamics/model_io/bmz_io.py @@ -6,8 +6,9 @@ import numpy as np import pkg_resources -from bioimageio.core import load_description, test_model +from bioimageio.core import load_model_description, test_model from bioimageio.spec import ValidationSummary, save_bioimageio_package +from pydantic import HttpUrl from torch import __version__ as PYTORCH_VERSION from torch import load, save from torchvision import __version__ as TORCHVISION_VERSION @@ -193,32 +194,30 @@ def export_to_bmz( def load_from_bmz( - path: Union[Path, str] + path: Union[Path, str, HttpUrl] ) -> Tuple[Union[FCNModule, VAEModule], Configuration]: """Load a model from a BioImage Model Zoo archive. 
Parameters ---------- - path : Union[Path, str] - Path to the BioImage Model Zoo archive. + path : Path, str or HttpUrl + Path to the BioImage Model Zoo archive. An HTTP URL must point to a downloadable + location. Returns ------- - Tuple[CAREamicsKiln, Configuration] - CAREamics model and configuration. + FCNModule or VAEModule + The loaded CAREamics model. + Configuration + The loaded CAREamics configuration. Raises ------ ValueError If the path is not a zip file. """ - path = Path(path) - - if path.suffix != ".zip": - raise ValueError(f"Path must be a bioimage.io zip file, got {path}.") - # load description, this creates an unzipped folder next to the archive - model_desc = load_description(path) + model_desc = load_model_description(path) # extract relative paths weights_path, config_path = extract_model_path(model_desc) From c7e29126129174287bbba5c51f1e948924165bca Mon Sep 17 00:00:00 2001 From: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com> Date: Fri, 15 Nov 2024 09:01:09 +0100 Subject: [PATCH 2/7] feat: Enable prediction step during training (#266) ### Description Following https://github.com/CAREamics/careamics/issues/148, I have been exploring how to predict during training. This PR would allow adding `Callback`s that use `predict_step` during training. - **What**: Allow callbacks to call `predict_step` during training. - **Why**: Some applications might require predicting consistently on full images to assess training performance throughout training. - **How**: Modified `FCNModule.predict_step` to make it compatible with a `TrainDataModule` (all calls to `trainer.datamodule` were written with the expectation that it returns a `PredictDataModule`). ### Changes Made - **Modified**: `lightning_module.py`, `test_lightning_module.py` ### Related Issues - Resolves https://github.com/CAREamics/careamics/issues/148 ### Additional Notes and Examples ```python import numpy as np from pytorch_lightning import Callback, Trainer from careamics import CAREamist, Configuration from careamics.lightning import PredictDataModule, create_predict_datamodule from careamics.prediction_utils import convert_outputs config = Configuration(**minimum_configuration) class CustomPredictAfterValidationCallback(Callback): def __init__(self, pred_datamodule: PredictDataModule): self.pred_datamodule = pred_datamodule # prepare data and setup self.pred_datamodule.prepare_data() self.pred_datamodule.setup() self.pred_dataloader = pred_datamodule.predict_dataloader() def on_validation_epoch_end(self, trainer: Trainer, pl_module): if trainer.sanity_checking: # optional skip return # update statistics in the prediction dataset for coherence # (they can be computed on-line by the training dataset) self.pred_datamodule.predict_dataset.image_means = ( trainer.datamodule.train_dataset.image_stats.means ) self.pred_datamodule.predict_dataset.image_stds = ( trainer.datamodule.train_dataset.image_stats.stds ) # predict on the dataset outputs = [] for idx, batch in enumerate(self.pred_dataloader): batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0) outputs.append(pl_module.predict_step(batch, batch_idx=idx)) data = convert_outputs(outputs, self.pred_datamodule.tiled) # can save data here array = np.arange(32 * 32).reshape((32, 32)) pred_datamodule = create_predict_datamodule( pred_data=array, data_type=config.data_config.data_type, axes=config.data_config.axes, image_means=[11.8], # random placeholder image_stds=[3.14], # can do tiling here ) predict_after_val_callback = 
CustomPredictAfterValidationCallback( pred_datamodule=pred_datamodule ) engine = CAREamist(config, callbacks=[predict_after_val_callback]) engine.train(train_source=array) ``` Currently, this implementation is not fully satisfactory; here are a few important points: - For this PR to work we need to discriminate between `TrainDataModule` and `PredictDataModule` in `predict_step`, which is a bit of a hack as it currently checks `hasattr(..., "tiled")`. The reason is to avoid a circular import of `PredictDataModule`. We should revisit that. - `TrainDataModule` and `PredictDataModule` have incompatible members: `PredictDataModule` has `.tiled`, and the two have different naming conventions for the statistics (`PredictDataModule` has `image_means` and `image_stds`, while `TrainDataModule` has them wrapped in a `stats` dataclass). These statistics are retrieved either through `_trainer.datamodule.predict_dataset` or `_trainer.datamodule.train_dataset`. - We do not provide the `Callable` that would allow using such a feature. We might want to do some heavy lifting here as well (see example). - Probably the most serious issue: normalization is done in the datasets but denormalization is performed in the `predict_step`. In our case, that means that normalization could be applied by a `PredictDataModule` (in the `Callback`) and the denormalization by the `TrainDataModule` (in `predict_step`). That is incoherent and due to the way we wrote CAREamics. All in all, this draft exemplifies two problems with CAREamics: - `TrainDataModule` and `PredictDataModule` have different members - Normalization is done by the `DataModule` but denormalization by the `LightningModule` --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [x] PR to the documentation exists (for bug fixes / features) --- src/careamics/lightning/lightning_module.py | 31 ++++++++-- tests/lightning/test_lightning_module.py | 64 +++++++++++++++++++++ 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/src/careamics/lightning/lightning_module.py b/src/careamics/lightning/lightning_module.py index 438d80809..d9b65ff8d 100644 --- a/src/careamics/lightning/lightning_module.py +++ b/src/careamics/lightning/lightning_module.py @@ -14,6 +14,7 @@ SupportedOptimizer, SupportedScheduler, ) +from careamics.config.tile_information import TileInformation from careamics.losses import loss_factory from careamics.models.lvae.likelihoods import ( GaussianLikelihood, @@ -163,7 +164,17 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: Any Model output. 
""" - if self._trainer.datamodule.tiled: + # TODO refactor when redoing datasets + # hacky way to determine if it is PredictDataModule, otherwise there is a + # circular import to solve with isinstance + from_prediction = hasattr(self._trainer.datamodule, "tiled") + is_tiled = ( + len(batch) > 1 + and isinstance(batch[1], list) + and isinstance(batch[1][0], TileInformation) + ) + + if is_tiled: x, *aux = batch else: x = batch @@ -171,7 +182,10 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: # apply test-time augmentation if available # TODO: probably wont work with batch size > 1 - if self._trainer.datamodule.prediction_config.tta_transforms: + if ( + from_prediction + and self._trainer.datamodule.prediction_config.tta_transforms + ): tta = ImageRestorationTTA() augmented_batch = tta.forward(x) # list of augmented tensors augmented_output = [] @@ -183,9 +197,18 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any: output = self.model(x) # Denormalize the output + # TODO incompatible API between predict and train datasets denorm = Denormalize( - image_means=self._trainer.datamodule.predict_dataset.image_means, - image_stds=self._trainer.datamodule.predict_dataset.image_stds, + image_means=( + self._trainer.datamodule.predict_dataset.image_means + if from_prediction + else self._trainer.datamodule.train_dataset.image_stats.means + ), + image_stds=( + self._trainer.datamodule.predict_dataset.image_stds + if from_prediction + else self._trainer.datamodule.train_dataset.image_stats.stds + ), ) denormalized_output = denorm(patch=output.cpu().numpy()) diff --git a/tests/lightning/test_lightning_module.py b/tests/lightning/test_lightning_module.py index 0b0eaafcf..fc8094875 100644 --- a/tests/lightning/test_lightning_module.py +++ b/tests/lightning/test_lightning_module.py @@ -308,3 +308,67 @@ def test_fcn_module_unet_depth_3_channels_3D(n_channels): x = torch.rand((1, n_channels, 16, 64, 64)) y: torch.Tensor = model.forward(x) assert y.shape == x.shape + + +@pytest.mark.parametrize("tiled", [False, True]) +def test_prediction_callback_during_training(minimum_configuration, tiled): + import numpy as np + from pytorch_lightning import Callback, Trainer + + from careamics import CAREamist, Configuration + from careamics.lightning import PredictDataModule, create_predict_datamodule + from careamics.prediction_utils import convert_outputs + + config = Configuration(**minimum_configuration) + + class CustomPredictAfterValidationCallback(Callback): + def __init__(self, pred_datamodule: PredictDataModule): + self.pred_datamodule = pred_datamodule + + # prepare data and setup + self.pred_datamodule.prepare_data() + self.pred_datamodule.setup() + self.pred_dataloader = pred_datamodule.predict_dataloader() + + self.data = None + + def on_validation_epoch_end(self, trainer: Trainer, pl_module): + if trainer.sanity_checking: # optional skip + return + + # update statistics in the prediction dataset for coherence + # (they can computed on-line by the training dataset) + self.pred_datamodule.predict_dataset.image_means = ( + trainer.datamodule.train_dataset.image_stats.means + ) + self.pred_datamodule.predict_dataset.image_stds = ( + trainer.datamodule.train_dataset.image_stats.stds + ) + + # predict on the dataset + outputs = [] + for idx, batch in enumerate(self.pred_dataloader): + batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0) + outputs.append(pl_module.predict_step(batch, batch_idx=idx)) + + self.data = convert_outputs(outputs, self.pred_datamodule.tiled) 
+ + array = np.arange(64 * 64).reshape((64, 64)) + pred_datamodule = create_predict_datamodule( + pred_data=array, + data_type=config.data_config.data_type, + axes=config.data_config.axes, + image_means=[11.8], # random placeholder + image_stds=[3.14], + tile_size=(16, 16) if tiled else None, + tile_overlap=(8, 8) if tiled else None, + batch_size=2, + ) + + predict_after_val_callback = CustomPredictAfterValidationCallback( + pred_datamodule=pred_datamodule + ) + engine = CAREamist(config, callbacks=[predict_after_val_callback]) + engine.train(train_source=array) + + assert not np.allclose(array, predict_after_val_callback.data) From b29fc6c95c65026c6169fbb35f1f9b4f31383568 Mon Sep 17 00:00:00 2001 From: Federico Carrara <74301866+federico-carrara@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:08:43 +0100 Subject: [PATCH 3/7] refac: added possibility to pick `kl_restricted` loss & other loss-related refactoring (#272) ### Description In some LVAE training examples (see the `microSplit_reproducibility` repo) there is a need to consider the *restricted KL loss* instead of the simple sample-wise one. This PR allows the user to pick that one. - **What**: The KL loss type is no longer hardcoded in the loss functions. It is now also possible to pick the `kl_restricted` KL loss type. - **Why**: It is needed for some experiments/examples. - **How**: added an input parameter to the KL loss functions. ### Breaking changes The `kl_type` parameter is added to the loss functions, so we need to be careful to specify it correctly in the examples in the `microSplit_reproducibility` repo. --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --- .../config/architectures/lvae_model.py | 5 +- src/careamics/config/loss_model.py | 3 +- src/careamics/losses/loss_factory.py | 95 +------------------ src/careamics/losses/lvae/losses.py | 23 ++++- src/careamics/models/lvae/layers.py | 6 -- src/careamics/models/lvae/lvae.py | 3 - src/careamics/models/lvae/stochastic.py | 19 ++-- tests/losses/test_lvae_losses.py | 31 +++++- 8 files changed, 57 insertions(+), 128 deletions(-) diff --git a/src/careamics/config/architectures/lvae_model.py b/src/careamics/config/architectures/lvae_model.py index 9e9b20a94..8881845f0 100644 --- a/src/careamics/config/architectures/lvae_model.py +++ b/src/careamics/config/architectures/lvae_model.py @@ -38,10 +38,7 @@ class LVAEModel(ArchitectureModel): ) predict_logvar: Literal[None, "pixelwise"] = None - - analytical_kl: bool = Field( - default=False, - ) + analytical_kl: bool = Field(default=False) @model_validator(mode="after") def validate_conv_strides(self: Self) -> Self: diff --git a/src/careamics/config/loss_model.py b/src/careamics/config/loss_model.py index 2b8793fd2..53ded0c23 100644 --- a/src/careamics/config/loss_model.py +++ b/src/careamics/config/loss_model.py @@ -10,7 +10,7 @@ class KLLossConfig(BaseModel): model_config = ConfigDict(validate_assignment=True, validate_default=True) - loss_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl" + loss_type: Literal["kl", "kl_restricted"] = "kl" """Type of KL divergence used as KL loss.""" rescaling: Literal["latent_dim", "image_dim"] = "latent_dim" """Rescaling of the KL loss.""" @@ -48,7 +48,6 @@ class LVAELossConfig(BaseModel): """Weight for the muSplit loss (used in the
muSplit-denoiSplit loss).""" denoisplit_weight: float = 0.9 """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss).""" - kl_params: KLLossConfig = KLLossConfig() """KL loss configuration.""" diff --git a/src/careamics/losses/loss_factory.py b/src/careamics/losses/loss_factory.py index eb5236935..2aaac250f 100644 --- a/src/careamics/losses/loss_factory.py +++ b/src/careamics/losses/loss_factory.py @@ -6,9 +6,8 @@ from __future__ import annotations -import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Literal, Optional, Union +from typing import Callable, Union from torch import Tensor as tensor @@ -16,12 +15,6 @@ from .fcn.losses import mae_loss, mse_loss, n2v_loss from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss -if TYPE_CHECKING: - from careamics.models.lvae.likelihoods import ( - GaussianLikelihood, - NoiseModelLikelihood, - ) - @dataclass class FCNLossParameters: @@ -35,92 +28,6 @@ class FCNLossParameters: loss_weight: float -@dataclass # TODO why not pydantic? -class LVAELossParameters: - """Dataclass for LVAE loss.""" - - # TODO: refactor in more modular blocks (otherwise it gets messy very easily) - # e.g., - weights, - kl_params, ... - - # General params - noise_model_likelihood: Optional[NoiseModelLikelihood] = None - """Noise model likelihood instance.""" - gaussian_likelihood: Optional[GaussianLikelihood] = None - """Gaussian likelihood instance.""" - current_epoch: int = 0 - """Current epoch in the training loop.""" - reconstruction_weight: float = 1.0 - """Weight for the reconstruction loss in the total net loss - (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`).""" - kl_weight: float = 1.0 - """Weight for the KL loss in the total net loss. - (i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`).""" - musplit_weight: float = 0.1 - """Weight for the muSplit loss (used in the muSplit-denoiSplit loss).""" - denoisplit_weight: float = 0.9 - """Weight for the denoiSplit loss (used in the muSplit-deonoiSplit loss).""" - - # KL params - kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl" - """Type of KL divergence used as KL loss.""" - kl_rescaling: Literal["latent_dim", "image_dim"] = "latent_dim" - """Rescaling of the KL loss.""" - kl_aggregation: Literal["sum", "mean"] = "mean" - """Aggregation of the KL loss across different layers.""" - kl_free_bits_coeff: float = 0.0 - """Free bits coefficient for the KL loss.""" - kl_annealing: bool = False - """Whether to apply KL loss annealing.""" - kl_start: int = -1 - """Epoch at which KL loss annealing starts.""" - kl_annealtime: int = 10 - """Number of epochs for which KL loss annealing is applied.""" - non_stochastic: bool = False - """Whether to sample latents and compute KL.""" - - def __post_init__(self): - """Raise a deprecation warning.""" - warnings.warn( - f"{self.__class__.__name__} is deprecated. Please use`LVAELossConfig`.", - DeprecationWarning, - stacklevel=2, - ) - - -def loss_parameters_factory( - type: SupportedLoss, -) -> Union[FCNLossParameters, LVAELossParameters]: - """Return loss parameters. - - Parameters - ---------- - type : SupportedLoss - Requested loss. - - Returns - ------- - Union[FCNLossParameters, LVAELossParameters] - Loss parameters. - - Raises - ------ - NotImplementedError - If the loss is unknown. 
- """ - if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]: - return FCNLossParameters - - elif type in [ - SupportedLoss.MUSPLIT, - SupportedLoss.DENOISPLIT, - SupportedLoss.DENOISPLIT_MUSPLIT, - ]: - return LVAELossParameters # it returns the class, not an instance - - else: - raise NotImplementedError(f"Loss {type} is not yet supported.") - - def loss_factory(loss: Union[SupportedLoss, str]) -> Callable: """Return loss function. diff --git a/src/careamics/losses/lvae/losses.py b/src/careamics/losses/lvae/losses.py index 474ee9d8b..310d0bb80 100644 --- a/src/careamics/losses/lvae/losses.py +++ b/src/careamics/losses/lvae/losses.py @@ -107,6 +107,7 @@ def _reconstruction_loss_musplit_denoisplit( def get_kl_divergence_loss( + kl_type: Literal["kl", "kl_restricted"], topdown_data: dict[str, torch.Tensor], rescaling: Literal["latent_dim", "image_dim"], aggregation: Literal["mean", "sum"], @@ -135,6 +136,8 @@ def get_kl_divergence_loss( Parameters ---------- + kl_type : Literal["kl", "kl_restricted"] + The type of KL divergence loss to compute. topdown_data : dict[str, torch.Tensor] A dictionary containing information computed for each layer during the top-down pass. The dictionary must include the following keys: @@ -161,7 +164,7 @@ def get_kl_divergence_loss( The KL divergence loss. Shape is (1, ). """ kl = torch.cat( - [kl_layer.unsqueeze(1) for kl_layer in topdown_data["kl"]], + [kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_type]], dim=1, ) # shape: (B, n_layers) @@ -194,6 +197,7 @@ def get_kl_divergence_loss( def _get_kl_divergence_loss_musplit( topdown_data: dict[str, torch.Tensor], img_shape: tuple[int], + kl_type: Literal["kl", "kl_restricted"], ) -> torch.Tensor: """Compute the KL divergence loss for muSplit. @@ -207,6 +211,8 @@ def _get_kl_divergence_loss_musplit( (B, layers, `z_dims[i]`, H, W). img_shape : tuple[int] The shape of the input image to the LVAE model. Shape is ([Z], Y, X). + kl_type : Literal["kl", "kl_restricted"] + The type of KL divergence loss to compute. Returns ------- @@ -214,6 +220,7 @@ def _get_kl_divergence_loss_musplit( The KL divergence loss for the muSplit case. Shape is (1, ). """ return get_kl_divergence_loss( + kl_type=kl_type, topdown_data=topdown_data, rescaling="latent_dim", aggregation="mean", @@ -225,6 +232,7 @@ def _get_kl_divergence_loss_musplit( def _get_kl_divergence_loss_denoisplit( topdown_data: dict[str, torch.Tensor], img_shape: tuple[int], + kl_type: Literal["kl", "kl_restricted"], ) -> torch.Tensor: """Compute the KL divergence loss for denoiSplit. @@ -238,6 +246,8 @@ def _get_kl_divergence_loss_denoisplit( (B, layers, `z_dims[i]`, H, W). img_shape : tuple[int] The shape of the input image to the LVAE model. Shape is ([Z], Y, X). + kl_type : Literal["kl", "kl_restricted"] + The type of KL divergence loss to compute. Returns ------- @@ -245,6 +255,7 @@ def _get_kl_divergence_loss_denoisplit( The KL divergence loss for the denoiSplit case. Shape is (1, ). 
""" return get_kl_divergence_loss( + kl_type=kl_type, topdown_data=topdown_data, rescaling="image_dim", aggregation="sum", @@ -312,7 +323,9 @@ def musplit_loss( ) kl_loss = ( _get_kl_divergence_loss_musplit( - topdown_data=td_data, img_shape=targets.shape[2:] + topdown_data=td_data, + img_shape=targets.shape[2:], + kl_type=config.kl_params.loss_type, ) * kl_weight ) @@ -387,7 +400,9 @@ def denoisplit_loss( ) kl_loss = ( _get_kl_divergence_loss_denoisplit( - topdown_data=td_data, img_shape=targets.shape[2:] + topdown_data=td_data, + img_shape=targets.shape[2:], + kl_type=config.kl_params.loss_type, ) * kl_weight ) @@ -459,10 +474,12 @@ def denoisplit_musplit_loss( denoisplit_kl = _get_kl_divergence_loss_denoisplit( topdown_data=td_data, img_shape=targets.shape[2:], + kl_type=config.kl_params.loss_type, ) musplit_kl = _get_kl_divergence_loss_musplit( topdown_data=td_data, img_shape=targets.shape[2:], + kl_type=config.kl_params.loss_type, ) kl_loss = ( config.denoisplit_weight * denoisplit_kl + config.musplit_weight * musplit_kl diff --git a/src/careamics/models/lvae/layers.py b/src/careamics/models/lvae/layers.py index f7f570b62..14273b02d 100644 --- a/src/careamics/models/lvae/layers.py +++ b/src/careamics/models/lvae/layers.py @@ -963,7 +963,6 @@ def __init__( top_prior_param_shape: Union[Iterable[int], None] = None, analytical_kl: bool = False, retain_spatial_dims: bool = False, - restricted_kl: bool = False, vanilla_latent_hw: Union[Iterable[int], None] = None, input_image_shape: Union[tuple[int, int], None] = None, normalize_latent_factor: float = 1.0, @@ -1032,10 +1031,6 @@ def __init__( This implies that the oput spatial size equals the input spatial size. To achieve this, we centercrop the intermediate representation. Default is `False`. - restricted_kl: bool, optional - Whether to compute the restricted version of KL Divergence. - See `NormalStochasticBlock2d` module for more information about its computation. - Default is `False`. vanilla_latent_hw: Iterable[int], optional The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL). Default is `None`. 
@@ -1115,7 +1110,6 @@ def __init__( conv_dims=len(conv_strides), transform_p_params=(not is_top_layer), vanilla_latent_hw=vanilla_latent_hw, - restricted_kl=restricted_kl, use_naive_exponential=stochastic_use_naive_exponential, ) diff --git a/src/careamics/models/lvae/lvae.py b/src/careamics/models/lvae/lvae.py index b915034b0..5ac20f94b 100644 --- a/src/careamics/models/lvae/lvae.py +++ b/src/careamics/models/lvae/lvae.py @@ -100,7 +100,6 @@ def __init__( self.decoder_dropout = decoder_dropout self.nonlin = nonlinearity self.predict_logvar = predict_logvar - self.analytical_kl = analytical_kl # ------------------------------------------------------- @@ -162,7 +161,6 @@ def __init__( # ------------------------------------------------------- # Loss attributes - self._restricted_kl = False # HC # enabling reconstruction loss on mixed input self.mixed_rec_w = 0 self.nbr_consistency_w = 0 @@ -442,7 +440,6 @@ def create_top_down_layers(self) -> nn.ModuleList: res_block_kernel=self.decoder_res_block_kernel, gated=self.gated, analytical_kl=self.analytical_kl, - restricted_kl=self._restricted_kl, vanilla_latent_hw=self.get_latent_spatial_size(i), retain_spatial_dims=self.multiscale_decoder_retain_spatial_dims, input_image_shape=self.image_size, diff --git a/src/careamics/models/lvae/stochastic.py b/src/careamics/models/lvae/stochastic.py index cd794ae15..e7f899c9c 100644 --- a/src/careamics/models/lvae/stochastic.py +++ b/src/careamics/models/lvae/stochastic.py @@ -49,7 +49,6 @@ def __init__( kernel: int = 3, transform_p_params: bool = True, vanilla_latent_hw: int = None, - restricted_kl: bool = False, use_naive_exponential: bool = False, ): """ @@ -76,10 +75,6 @@ def __init__( vanilla_latent_hw: int, optional The shape of the latent tensor used for prediction (i.e., it influences the computation of restricted KL). Default is `None`. - restricted_kl: bool, optional - Whether to compute the restricted version of KL Divergence. - See NOTE 2 for more information about its computation. - Default is `False`. use_naive_exponential: bool, optional If `False`, exponentials are computed according to the alternative definition provided by `StableExponential` class. This should improve numerical stability @@ -95,7 +90,6 @@ def __init__( self.conv_dims = conv_dims self._use_naive_exponential = use_naive_exponential self._vanilla_latent_hw = vanilla_latent_hw - self._restricted_kl = restricted_kl conv_layer: ConvType = getattr(nn, f"Conv{conv_dims}d") @@ -199,7 +193,6 @@ def compute_kl_metrics( z: torch.Tensor The sampled latent tensor. """ - kl_samplewise_restricted = None if mode_pred is False: # if not predicting if analytical_kl: kl_elementwise = kl_divergence(q, p) @@ -207,15 +200,17 @@ def compute_kl_metrics( kl_elementwise = kl_normal_mc(z, p_params, q_params) all_dims = tuple(range(len(kl_elementwise.shape))) + kl_samplewise = kl_elementwise.sum(all_dims[1:]) + kl_channelwise = kl_elementwise.sum(all_dims[2:]) + # compute KL only on the portion of the latent space that is used for prediction. - if self._restricted_kl: - pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2 - assert pad > 0, "Disable restricted kl since there is no restriction." 
+ pad = (kl_elementwise.shape[-1] - self._vanilla_latent_hw) // 2 + if pad > 0: tmp = kl_elementwise[..., pad:-pad, pad:-pad] kl_samplewise_restricted = tmp.sum(all_dims[1:]) + else: + kl_samplewise_restricted = kl_samplewise - kl_samplewise = kl_elementwise.sum(all_dims[1:]) - kl_channelwise = kl_elementwise.sum(all_dims[2:]) # Compute spatial KL analytically (but conditioned on samples from # previous layers) kl_spatial = kl_elementwise.sum(1) diff --git a/tests/losses/test_lvae_losses.py b/tests/losses/test_lvae_losses.py index 648b9c7fb..d3c3e6a1e 100644 --- a/tests/losses/test_lvae_losses.py +++ b/tests/losses/test_lvae_losses.py @@ -17,6 +17,7 @@ GaussianLikelihoodConfig, NMLikelihoodConfig, ) +from careamics.config.loss_model import KLLossConfig from careamics.losses.loss_factory import ( SupportedLoss, loss_factory, @@ -188,6 +189,7 @@ def test_reconstruction_loss_musplit_denoisplit( @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("n_layers", [1, 4]) @pytest.mark.parametrize("enable_LC", [False, True]) +@pytest.mark.parametrize("kl_type", ["kl", "kl_restricted"]) @pytest.mark.parametrize("rescaling", ["latent_dim", "image_dim"]) @pytest.mark.parametrize("aggregation", ["mean", "sum"]) @pytest.mark.parametrize("free_bits_coeff", [0.0, 1.0]) @@ -195,6 +197,7 @@ def test_KL_divergence_loss( batch_size: int, n_layers: int, enable_LC: bool, + kl_type: Literal["kl", "kl_restricted"], rescaling: Literal["latent_dim", "image_dim"], aggregation: Literal["mean", "sum"], free_bits_coeff: float, @@ -210,12 +213,18 @@ def test_KL_divergence_loss( td_data = { "z": z, "kl": [torch.ones(batch_size) for _ in range(n_layers)], + "kl_restricted": [torch.ones(batch_size) for _ in range(n_layers)], } # compute the loss for different settings img_shape = (img_size, img_size) kl_loss = get_kl_divergence_loss( - td_data, rescaling, aggregation, free_bits_coeff, img_shape + kl_type=kl_type, + topdown_data=td_data, + rescaling=rescaling, + aggregation=aggregation, + free_bits_coeff=free_bits_coeff, + img_shape=img_shape, ) assert isinstance(kl_loss, torch.Tensor) assert isinstance(kl_loss.item(), float) @@ -226,12 +235,14 @@ def test_KL_divergence_loss( @pytest.mark.parametrize("predict_logvar", [None, "pixelwise"]) @pytest.mark.parametrize("n_layers", [1, 4]) @pytest.mark.parametrize("enable_LC", [False, True]) +@pytest.mark.parametrize("kl_type", ["kl", "kl_restricted"]) def test_musplit_loss( batch_size: int, target_ch: int, predict_logvar: str, n_layers: int, enable_LC: bool, + kl_type: Literal["kl", "kl_restricted"], ): # create test data img_size = 64 @@ -247,6 +258,7 @@ def test_musplit_loss( td_data = { "z": z, "kl": [torch.rand(batch_size) for _ in range(n_layers)], + "kl_restricted": [torch.rand(batch_size) for _ in range(n_layers)], } # create likelihood @@ -254,7 +266,8 @@ def test_musplit_loss( likelihood = likelihood_factory(config) # compute the loss - loss_parameters = LVAELossConfig(loss_type="musplit") + kl_params = KLLossConfig(loss_type=kl_type) + loss_parameters = LVAELossConfig(loss_type="musplit", kl_params=kl_params) output = musplit_loss( model_outputs=(reconstruction, td_data), targets=target, @@ -274,11 +287,13 @@ def test_musplit_loss( @pytest.mark.parametrize("batch_size", [1, 8]) @pytest.mark.parametrize("target_ch", [1, 3]) @pytest.mark.parametrize("n_layers", [1, 4]) +@pytest.mark.parametrize("kl_type", ["kl", "kl_restricted"]) def test_denoisplit_loss( tmp_path: Path, batch_size: int, target_ch: int, n_layers: int, + kl_type: Literal["kl", 
"kl_restricted"], ): # create test data img_size = 64 @@ -287,6 +302,7 @@ def test_denoisplit_loss( td_data = { "z": [torch.rand(batch_size, 128, img_size, img_size) for _ in range(n_layers)], "kl": [torch.rand(batch_size) for _ in range(n_layers)], + "kl_restricted": [torch.rand(batch_size) for _ in range(n_layers)], } # create likelihood @@ -297,7 +313,8 @@ def test_denoisplit_loss( likelihood = likelihood_factory(nm_config, noise_model=nm) # compute the loss - loss_parameters = LVAELossConfig(loss_type="denoisplit") + kl_params = KLLossConfig(loss_type=kl_type) + loss_parameters = LVAELossConfig(loss_type="denoisplit", kl_params=kl_params) output = denoisplit_loss( model_outputs=(reconstruction, td_data), targets=target, @@ -319,6 +336,7 @@ def test_denoisplit_loss( @pytest.mark.parametrize("predict_logvar", [None, "pixelwise"]) @pytest.mark.parametrize("n_layers", [1, 4]) @pytest.mark.parametrize("enable_LC", [False, True]) +@pytest.mark.parametrize("kl_type", ["kl", "kl_restricted"]) def test_denoisplit_musplit_loss( tmp_path: Path, batch_size: int, @@ -326,6 +344,7 @@ def test_denoisplit_musplit_loss( predict_logvar: str, n_layers: int, enable_LC: bool, + kl_type: Literal["kl", "kl_restricted"], ): # create test data img_size = 64 @@ -341,6 +360,7 @@ def test_denoisplit_musplit_loss( td_data = { "z": z, "kl": [torch.rand(batch_size) for _ in range(n_layers)], + "kl_restricted": [torch.rand(batch_size) for _ in range(n_layers)], } # create likelihood objects @@ -353,7 +373,10 @@ def test_denoisplit_musplit_loss( gaussian_likelihood = likelihood_factory(gaussian_config) # compute the loss - loss_parameters = LVAELossConfig(loss_type="denoisplit_musplit") + kl_params = KLLossConfig(loss_type=kl_type) + loss_parameters = LVAELossConfig( + loss_type="denoisplit_musplit", kl_params=kl_params + ) output = denoisplit_musplit_loss( model_outputs=(reconstruction, td_data), targets=target, From 6703079d5aa0aac176ca0246b67f2d6ceb46863c Mon Sep 17 00:00:00 2001 From: ashesh Date: Tue, 19 Nov 2024 10:38:20 +0100 Subject: [PATCH 4/7] A new enum for a new splitting task. (#270) ### Description Adding a new enum type for a splitting task which I had missed communicating earlier. I am putting the relevant things in microsplit-reproducibility repo after a brief chat with @veegalinova. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com> --- src/careamics/lvae_training/dataset/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/careamics/lvae_training/dataset/types.py b/src/careamics/lvae_training/dataset/types.py index 4174afe4b..5d0d9e0d9 100644 --- a/src/careamics/lvae_training/dataset/types.py +++ b/src/careamics/lvae_training/dataset/types.py @@ -15,6 +15,7 @@ class DataType(Enum): OptiMEM100_014 = 10 SeparateTiffData = 11 BioSR_MRC = 12 + PunctaRemoval = 13 # for the case when we have a set of differently sized crops for each channel. class DataSplitType(Enum): From 5f85f3be49f08794a57fcb41b791fe624ccfa030 Mon Sep 17 00:00:00 2001 From: Melisande Croft <63270704+melisande-c@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:03:13 +0100 Subject: [PATCH 5/7] Fix(BMZ): Relax model output validation kwargs; extract weights and config file following new `spec` and `core` release (#279) ### Description - **What**: Relaxing the model output validation kwargs, both absolute and relative tolerance, from the default, `1e-4`, to `1e-2`. 
- **Why**: The defaults are pretty strict and some of our uploaded models are stuck in pending because of slightly mismatching expected and actual outputs. - e.g. (Actually maybe the absolute tolerance should be set to 0, otherwise it still might not pass after this PR) ```console Output and expected output disagree: Not equal to tolerance rtol=0.0001, atol=0.00015 Mismatched elements: 40202 / 1048576 (3.83%) Max absolute difference: 0.1965332 Max relative difference: 0.0003221 ``` - **How**: Added the new test kwargs to the model description config param. Additionally, updated `bmz_export` so that the test kwargs in the model description are used during model testing at export time. ### Changes Made - **Modified**: - `create_model_description`: added test_kwargs to config param - `export_to_bmz`: use test_kwargs in model description for model testing at export time. ### Related Issues - Resolves the last checkbox in #278 EDIT: This PR also fixes loading from BMZ following an incompatible release of `bioimageio/core` (`0.7.0`) and `bioimageio/spec` (`0.5.3.5`). The problem is that `load_model_description` no longer unzips the archive file but only streams the `rdf.yaml` file data. This means we now have to extract the weights and CAREamics config from the zip to load them, which can be done using `bioimageio.spec._internal.io.resolve_and_extract`. --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [ ] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --------- Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com> --- .../model_io/bioimage/model_description.py | 17 +++++++++++++++-- src/careamics/model_io/bmz_io.py | 15 +++++++-------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index 21ed50b8f..dbd1dfe08 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union import numpy as np +from bioimageio.spec._internal.io import resolve_and_extract from bioimageio.spec.model.v0_5 import ( ArchitectureFromLibraryDescr, Author, @@ -280,6 +281,16 @@ def create_model_description( "https://careamics.github.io/latest/", ], license="BSD-3-Clause", + config={ + "bioimageio": { + "test_kwargs": { + "pytorch_state_dict": { + "absolute_tolerance": 1e-2, + "relative_tolerance": 1e-2, + } + } + } + }, version="0.1.0", weights=weights_descr, attachments=[FileDescr(source=config_path)], @@ -304,7 +315,9 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]: """ if model_desc.weights.pytorch_state_dict is None: raise ValueError("No model weights found in model description.") - weights_path = model_desc.weights.pytorch_state_dict.download().path + weights_path = resolve_and_extract( + model_desc.weights.pytorch_state_dict.source + ).path for file in model_desc.attachments: file_path = file.source if isinstance(file.source, Path) else file.source.path if file_path is None: continue file_path = Path(file_path) if file_path.name == "careamics.yaml": - config_path = file.download().path + config_path = resolve_and_extract(file.source).path break else: raise 
ValueError("Configuration file not found.") diff --git a/src/careamics/model_io/bmz_io.py b/src/careamics/model_io/bmz_io.py index dc4564ecc..65a3ea99c 100644 --- a/src/careamics/model_io/bmz_io.py +++ b/src/careamics/model_io/bmz_io.py @@ -21,7 +21,6 @@ create_env_text, create_model_description, extract_model_path, - get_unzip_path, ) @@ -185,7 +184,12 @@ def export_to_bmz( ) # test model description - summary: ValidationSummary = test_model(model_description) + test_kwargs = ( + model_description.config.get("bioimageio", {}) + .get("test_kwargs", {}) + .get("pytorch_state_dict", {}) + ) + summary: ValidationSummary = test_model(model_description, **test_kwargs) if summary.status == "failed": raise ValueError(f"Model description test failed: {summary}") @@ -219,14 +223,9 @@ def load_from_bmz( # load description, this creates an unzipped folder next to the archive model_desc = load_model_description(path) - # extract relative paths + # extract paths weights_path, config_path = extract_model_path(model_desc) - # create folder path and absolute paths - unzip_path = get_unzip_path(path) - weights_path = unzip_path / weights_path - config_path = unzip_path / config_path - # load configuration config = load_configuration(config_path) From 240e4d34ed5aa503cb8f5bcc3a281fa64c7142b2 Mon Sep 17 00:00:00 2001 From: Melisande Croft <63270704+melisande-c@users.noreply.github.com> Date: Wed, 20 Nov 2024 23:07:37 +0100 Subject: [PATCH 6/7] Fix(dependencies): Set bioimageio-core version greater than 0.7.0 (#280) ### Description - **What**: Set bioimageio-core version greater than 0.7.0 - **Why**: Following the new `bioimage-core` release (0.7.0), we needed to make some fixes (part of PR #279). The most convenient function to solve this problem, `resolve_and_extract` only exists since 0.7.0. - **How**: In pyproject.toml ### Changes Made - **Modified**: pyproject.toml --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [ ] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9c13382c0..b38b1c853 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ 'numpy<2.0.0', 'torch>=2.0.0', 'torchvision', - 'bioimageio.core>=0.6.9', + 'bioimageio.core>=0.7.0', 'tifffile', 'psutil', 'pydantic>=2.5,<2.9', From 0e0bc285d9f78b20a26c43b63214c0e09be91279 Mon Sep 17 00:00:00 2001 From: Federico Carrara <74301866+federico-carrara@users.noreply.github.com> Date: Sat, 23 Nov 2024 15:36:13 +0100 Subject: [PATCH 7/7] fix: fixed a bug in KL loss aggregation (LVAE) (#277) ### Description Found a bug in the KL loss aggregation happening in the `LadderVAE` model `training_step()`. Specifically, the application of free-bits (`free_bits_kl()`, basically clamping the values of KL entries to a certain lower threshold) was happening after KL entries were rescaled. In this way, when free-bits threshold was set to 1, all the KL entries were clamped to 1, as normally way smaller than this. - **What**: See above. - **Why**: Clear bug in the code. - **How**: Inverted the order of calls in the `get_kl_divergence_loss()` function & adjusted some parts of the code to reflect the changes. 
--- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [ ] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Joran Deschamps <6367888+jdeschamps@users.noreply.github.com> --- src/careamics/losses/lvae/losses.py | 18 +++++++------- src/careamics/models/lvae/utils.py | 37 ----------------------------- 2 files changed, 9 insertions(+), 46 deletions(-) diff --git a/src/careamics/losses/lvae/losses.py b/src/careamics/losses/lvae/losses.py index 310d0bb80..9514846c3 100644 --- a/src/careamics/losses/lvae/losses.py +++ b/src/careamics/losses/lvae/losses.py @@ -168,6 +168,9 @@ def get_kl_divergence_loss( dim=1, ) # shape: (B, n_layers) + # Apply free bits (& batch average) + kl = free_bits_kl(kl, free_bits_coeff) # shape: (n_layers,) + # In 3D case, rescale by Z dim # TODO If we have downsampling in Z dimension, then this needs to change. if len(img_shape) == 3: @@ -175,23 +178,20 @@ def get_kl_divergence_loss( # Rescaling if rescaling == "latent_dim": - for i in range(kl.shape[1]): + for i in range(len(kl)): latent_dim = topdown_data["z"][i].shape[1:] norm_factor = np.prod(latent_dim) - kl[:, i] = kl[:, i] / norm_factor + kl[i] = kl[i] / norm_factor elif rescaling == "image_dim": kl = kl / np.prod(img_shape[-2:]) - # Apply free bits - kl_loss = free_bits_kl(kl, free_bits_coeff) # shape: (n_layers,) - # Aggregation if aggregation == "mean": - kl_loss = kl_loss.mean() # shape: (1,) + kl = kl.mean() # shape: (1,) elif aggregation == "sum": - kl_loss = kl_loss.sum() # shape: (1,) + kl = kl.sum() # shape: (1,) - return kl_loss + return kl def _get_kl_divergence_loss_musplit( @@ -220,7 +220,7 @@ def _get_kl_divergence_loss_musplit( The KL divergence loss for the muSplit case. Shape is (1, ). """ return get_kl_divergence_loss( - kl_type=kl_type, + kl_type="kl", # TODO: hardcoded, deal in future PR topdown_data=topdown_data, rescaling="latent_dim", aggregation="mean", diff --git a/src/careamics/models/lvae/utils.py b/src/careamics/models/lvae/utils.py index 1089932a9..2698dbf5a 100644 --- a/src/careamics/models/lvae/utils.py +++ b/src/careamics/models/lvae/utils.py @@ -402,40 +402,3 @@ def kl_normal_mc(z, p_mulv, q_mulv): p_distrib = Normal(p_mu.get(), p_std) q_distrib = Normal(q_mu.get(), q_std) return q_distrib.log_prob(z) - p_distrib.log_prob(z) - - -def free_bits_kl( - kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6 -) -> torch.Tensor: - """ - Computes free-bits version of KL divergence. - Ensures that the KL doesn't go to zero for any latent dimension. - Hence, it contributes to use latent variables more efficiently, - leading to better representation learning. - - NOTE: - Takes in the KL with shape (batch size, layers), returns the KL with - free bits (for optimization) with shape (layers,), which is the average - free-bits KL per layer in the current batch. - If batch_average is False (default), the free bits are per layer and - per batch element. Otherwise, the free bits are still per layer, but - are assigned on average to the whole batch. In both cases, the batch - average is returned, so it's simply a matter of doing mean(clamp(KL)) - or clamp(mean(KL)). 
- - Args: - kl (torch.Tensor) - free_bits (float) - batch_average (bool, optional)) - eps (float, optional) - - Returns - ------- - The KL with free bits - """ - assert kl.dim() == 2 - if free_bits < eps: - return kl.mean(0) - if batch_average: - return kl.mean(0).clamp(min=free_bits) - return kl.clamp(min=free_bits).mean(0)