Skip to content

Commit

Permalink
Merge branch 'main' into ntm/fix/ONNX_support
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Nov 25, 2024
2 parents 7a414ee + 0e0bc28 commit b26f006
Show file tree
Hide file tree
Showing 15 changed files with 198 additions and 201 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dependencies = [
'numpy<2.0.0',
'torch>=2.0.0',
'torchvision',
'bioimageio.core>=0.6.9',
'bioimageio.core>=0.7.0',
'tifffile',
'psutil',
'pydantic>=2.5,<2.9',
Expand Down
5 changes: 1 addition & 4 deletions src/careamics/config/architectures/lvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ class LVAEModel(ArchitectureModel):
)

predict_logvar: Literal[None, "pixelwise"] = None

analytical_kl: bool = Field(
default=False,
)
analytical_kl: bool = Field(default=False)

@model_validator(mode="after")
def validate_conv_strides(self: Self) -> Self:
Expand Down
3 changes: 1 addition & 2 deletions src/careamics/config/loss_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class KLLossConfig(BaseModel):

model_config = ConfigDict(validate_assignment=True, validate_default=True)

loss_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
loss_type: Literal["kl", "kl_restricted"] = "kl"
"""Type of KL divergence used as KL loss."""
rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
"""Rescaling of the KL loss."""
Expand Down Expand Up @@ -48,7 +48,6 @@ class LVAELossConfig(BaseModel):
"""Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
denoisplit_weight: float = 0.9
"""Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""

kl_params: KLLossConfig = KLLossConfig()
"""KL loss configuration."""

Expand Down
31 changes: 27 additions & 4 deletions src/careamics/lightning/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SupportedOptimizer,
SupportedScheduler,
)
from careamics.config.tile_information import TileInformation
from careamics.losses import loss_factory
from careamics.models.lvae.likelihoods import (
GaussianLikelihood,
Expand Down Expand Up @@ -163,15 +164,28 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
Any
Model output.
"""
if self._trainer.datamodule.tiled:
# TODO refactor when redoing datasets
# hacky way to determine if it is PredictDataModule, otherwise there is a
# circular import to solve with isinstance
from_prediction = hasattr(self._trainer.datamodule, "tiled")
is_tiled = (
len(batch) > 1
and isinstance(batch[1], list)
and isinstance(batch[1][0], TileInformation)
)

if is_tiled:
x, *aux = batch
else:
x = batch
aux = []

# apply test-time augmentation if available
# TODO: probably won't work with batch size > 1
if self._trainer.datamodule.prediction_config.tta_transforms:
if (
from_prediction
and self._trainer.datamodule.prediction_config.tta_transforms
):
tta = ImageRestorationTTA()
augmented_batch = tta.forward(x) # list of augmented tensors
augmented_output = []
Expand All @@ -183,9 +197,18 @@ def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
output = self.model(x)

# Denormalize the output
# TODO incompatible API between predict and train datasets
denorm = Denormalize(
image_means=self._trainer.datamodule.predict_dataset.image_means,
image_stds=self._trainer.datamodule.predict_dataset.image_stds,
image_means=(
self._trainer.datamodule.predict_dataset.image_means
if from_prediction
else self._trainer.datamodule.train_dataset.image_stats.means
),
image_stds=(
self._trainer.datamodule.predict_dataset.image_stds
if from_prediction
else self._trainer.datamodule.train_dataset.image_stats.stds
),
)
denormalized_output = denorm(patch=output.cpu().numpy())

Expand Down
95 changes: 1 addition & 94 deletions src/careamics/losses/loss_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,15 @@

from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
from typing import Callable, Union

from torch import Tensor as tensor

from ..config.support import SupportedLoss
from .fcn.losses import mae_loss, mse_loss, n2v_loss
from .lvae.losses import denoisplit_loss, denoisplit_musplit_loss, musplit_loss

if TYPE_CHECKING:
from careamics.models.lvae.likelihoods import (
GaussianLikelihood,
NoiseModelLikelihood,
)


@dataclass
class FCNLossParameters:
Expand All @@ -35,92 +28,6 @@ class FCNLossParameters:
loss_weight: float


@dataclass # TODO why not pydantic?
class LVAELossParameters:
"""Dataclass for LVAE loss."""

# TODO: refactor in more modular blocks (otherwise it gets messy very easily)
# e.g., - weights, - kl_params, ...

# General params
noise_model_likelihood: Optional[NoiseModelLikelihood] = None
"""Noise model likelihood instance."""
gaussian_likelihood: Optional[GaussianLikelihood] = None
"""Gaussian likelihood instance."""
current_epoch: int = 0
"""Current epoch in the training loop."""
reconstruction_weight: float = 1.0
"""Weight for the reconstruction loss in the total net loss
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
kl_weight: float = 1.0
"""Weight for the KL loss in the total net loss.
(i.e., `net_loss = reconstruction_weight * rec_loss + kl_weight * kl_loss`)."""
musplit_weight: float = 0.1
"""Weight for the muSplit loss (used in the muSplit-denoiSplit loss)."""
denoisplit_weight: float = 0.9
"""Weight for the denoiSplit loss (used in the muSplit-denoiSplit loss)."""

# KL params
kl_type: Literal["kl", "kl_restricted", "kl_spatial", "kl_channelwise"] = "kl"
"""Type of KL divergence used as KL loss."""
kl_rescaling: Literal["latent_dim", "image_dim"] = "latent_dim"
"""Rescaling of the KL loss."""
kl_aggregation: Literal["sum", "mean"] = "mean"
"""Aggregation of the KL loss across different layers."""
kl_free_bits_coeff: float = 0.0
"""Free bits coefficient for the KL loss."""
kl_annealing: bool = False
"""Whether to apply KL loss annealing."""
kl_start: int = -1
"""Epoch at which KL loss annealing starts."""
kl_annealtime: int = 10
"""Number of epochs for which KL loss annealing is applied."""
non_stochastic: bool = False
"""Whether to sample latents and compute KL."""

def __post_init__(self):
"""Raise a deprecation warning."""
warnings.warn(
f"{self.__class__.__name__} is deprecated. Please use`LVAELossConfig`.",
DeprecationWarning,
stacklevel=2,
)


def loss_parameters_factory(
type: SupportedLoss,
) -> Union[FCNLossParameters, LVAELossParameters]:
"""Return loss parameters.
Parameters
----------
type : SupportedLoss
Requested loss.
Returns
-------
Union[FCNLossParameters, LVAELossParameters]
Loss parameters.
Raises
------
NotImplementedError
If the loss is unknown.
"""
if type in [SupportedLoss.N2V, SupportedLoss.MSE, SupportedLoss.MAE]:
return FCNLossParameters

elif type in [
SupportedLoss.MUSPLIT,
SupportedLoss.DENOISPLIT,
SupportedLoss.DENOISPLIT_MUSPLIT,
]:
return LVAELossParameters # it returns the class, not an instance

else:
raise NotImplementedError(f"Loss {type} is not yet supported.")


def loss_factory(loss: Union[SupportedLoss, str]) -> Callable:
"""Return loss function.
Expand Down
39 changes: 28 additions & 11 deletions src/careamics/losses/lvae/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _reconstruction_loss_musplit_denoisplit(


def get_kl_divergence_loss(
kl_type: Literal["kl", "kl_restricted"],
topdown_data: dict[str, torch.Tensor],
rescaling: Literal["latent_dim", "image_dim"],
aggregation: Literal["mean", "sum"],
Expand Down Expand Up @@ -135,6 +136,8 @@ def get_kl_divergence_loss(
Parameters
----------
kl_type : Literal["kl", "kl_restricted"]
The type of KL divergence loss to compute.
topdown_data : dict[str, torch.Tensor]
A dictionary containing information computed for each layer during the top-down
pass. The dictionary must include the following keys:
Expand All @@ -161,39 +164,40 @@ def get_kl_divergence_loss(
The KL divergence loss. Shape is (1, ).
"""
kl = torch.cat(
[kl_layer.unsqueeze(1) for kl_layer in topdown_data["kl"]],
[kl_layer.unsqueeze(1) for kl_layer in topdown_data[kl_type]],
dim=1,
) # shape: (B, n_layers)

# Apply free bits (& batch average)
kl = free_bits_kl(kl, free_bits_coeff) # shape: (n_layers,)

# In 3D case, rescale by Z dim
# TODO If we have downsampling in Z dimension, then this needs to change.
if len(img_shape) == 3:
kl = kl / img_shape[0]

# Rescaling
if rescaling == "latent_dim":
for i in range(kl.shape[1]):
for i in range(len(kl)):
latent_dim = topdown_data["z"][i].shape[1:]
norm_factor = np.prod(latent_dim)
kl[:, i] = kl[:, i] / norm_factor
kl[i] = kl[i] / norm_factor
elif rescaling == "image_dim":
kl = kl / np.prod(img_shape[-2:])

# Apply free bits
kl_loss = free_bits_kl(kl, free_bits_coeff) # shape: (n_layers,)

# Aggregation
if aggregation == "mean":
kl_loss = kl_loss.mean() # shape: (1,)
kl = kl.mean() # shape: (1,)
elif aggregation == "sum":
kl_loss = kl_loss.sum() # shape: (1,)
kl = kl.sum() # shape: (1,)

return kl_loss
return kl


def _get_kl_divergence_loss_musplit(
topdown_data: dict[str, torch.Tensor],
img_shape: tuple[int],
kl_type: Literal["kl", "kl_restricted"],
) -> torch.Tensor:
"""Compute the KL divergence loss for muSplit.
Expand All @@ -207,13 +211,16 @@ def _get_kl_divergence_loss_musplit(
(B, layers, `z_dims[i]`, H, W).
img_shape : tuple[int]
The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
kl_type : Literal["kl", "kl_restricted"]
The type of KL divergence loss to compute.
Returns
-------
kl_loss : torch.Tensor
The KL divergence loss for the muSplit case. Shape is (1, ).
"""
return get_kl_divergence_loss(
kl_type="kl", # TODO: hardcoded, deal in future PR
topdown_data=topdown_data,
rescaling="latent_dim",
aggregation="mean",
Expand All @@ -225,6 +232,7 @@ def _get_kl_divergence_loss_musplit(
def _get_kl_divergence_loss_denoisplit(
topdown_data: dict[str, torch.Tensor],
img_shape: tuple[int],
kl_type: Literal["kl", "kl_restricted"],
) -> torch.Tensor:
"""Compute the KL divergence loss for denoiSplit.
Expand All @@ -238,13 +246,16 @@ def _get_kl_divergence_loss_denoisplit(
(B, layers, `z_dims[i]`, H, W).
img_shape : tuple[int]
The shape of the input image to the LVAE model. Shape is ([Z], Y, X).
kl_type : Literal["kl", "kl_restricted"]
The type of KL divergence loss to compute.
Returns
-------
kl_loss : torch.Tensor
The KL divergence loss for the denoiSplit case. Shape is (1, ).
"""
return get_kl_divergence_loss(
kl_type=kl_type,
topdown_data=topdown_data,
rescaling="image_dim",
aggregation="sum",
Expand Down Expand Up @@ -312,7 +323,9 @@ def musplit_loss(
)
kl_loss = (
_get_kl_divergence_loss_musplit(
topdown_data=td_data, img_shape=targets.shape[2:]
topdown_data=td_data,
img_shape=targets.shape[2:],
kl_type=config.kl_params.loss_type,
)
* kl_weight
)
Expand Down Expand Up @@ -387,7 +400,9 @@ def denoisplit_loss(
)
kl_loss = (
_get_kl_divergence_loss_denoisplit(
topdown_data=td_data, img_shape=targets.shape[2:]
topdown_data=td_data,
img_shape=targets.shape[2:],
kl_type=config.kl_params.loss_type,
)
* kl_weight
)
Expand Down Expand Up @@ -459,10 +474,12 @@ def denoisplit_musplit_loss(
denoisplit_kl = _get_kl_divergence_loss_denoisplit(
topdown_data=td_data,
img_shape=targets.shape[2:],
kl_type=config.kl_params.loss_type,
)
musplit_kl = _get_kl_divergence_loss_musplit(
topdown_data=td_data,
img_shape=targets.shape[2:],
kl_type=config.kl_params.loss_type,
)
kl_loss = (
config.denoisplit_weight * denoisplit_kl + config.musplit_weight * musplit_kl
Expand Down
1 change: 1 addition & 0 deletions src/careamics/lvae_training/dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class DataType(Enum):
OptiMEM100_014 = 10
SeparateTiffData = 11
BioSR_MRC = 12
PunctaRemoval = 13 # for the case when we have a set of differently sized crops for each channel.


class DataSplitType(Enum):
Expand Down
Loading

0 comments on commit b26f006

Please sign in to comment.