Channel-dependent normalization (#134)
### Description

Input patches and targets should be normalized separately, because they
may have very different pixel value ranges. In addition, normalization was
previously computed over all channels at once, which is not correct in cases
where two channels should not be mixed.

- **What**: Added separate normalization for images and targets, as well as per-channel normalization
- **Why**: Normalizing everything with the same statistics is not correct


### Changes Made

- Normalization function, plus changes to the datasets and to the training and prediction loops.
- Relevant tests

Please ensure your PR meets the following requirements:

 

- [x] Code builds and passes tests locally, including doctests
- [x]  New tests have been added (for bug fixes/features)
- [x]  Pre-commit passes
- [ ]  No change to the documentation needed

---------

Co-authored-by: Joran Deschamps <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Jun 14, 2024
1 parent aa66daf commit 07fb84e
Showing 37 changed files with 987 additions and 394 deletions.
11 changes: 10 additions & 1 deletion src/careamics/careamist.py
@@ -516,7 +516,7 @@ def predict(
self,
source: Union[CAREamicsPredictData, Path, str, NDArray],
*,
batch_size: int = 1,
batch_size: Optional[int] = None,
tile_size: Optional[Tuple[int, ...]] = None,
tile_overlap: Tuple[int, ...] = (48, 48),
axes: Optional[str] = None,
@@ -597,6 +597,15 @@ def predict(
"No configuration found. Train a model or load from a "
"checkpoint before predicting."
)

# Reuse batch size if not provided explicitly
if batch_size is None:
batch_size = (
self.train_datamodule.batch_size
if self.train_datamodule
else self.cfg.data_config.batch_size
)

# create predict config, reuse training config if parameters missing
prediction_config = create_inference_configuration(
configuration=self.cfg,
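For readability, a hedged sketch of the fallback logic introduced above; `train_datamodule` and `data_config` are stand-ins for the corresponding `CAREamist` attributes, not the library's actual objects:

```python
from typing import Optional

def resolve_batch_size(batch_size: Optional[int], train_datamodule, data_config) -> int:
    """Mirror of the fallback added in `predict`: an explicit value wins."""
    if batch_size is not None:
        return batch_size
    # otherwise reuse the training batch size, or the one stored in the config
    return train_datamodule.batch_size if train_datamodule else data_config.batch_size
```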
9 changes: 6 additions & 3 deletions src/careamics/config/configuration_factory.py
@@ -604,7 +604,10 @@ def create_inference_configuration(
InferenceConfiguration
Configuration used to configure CAREamicsPredictData.
"""
if configuration.data_config.mean is None or configuration.data_config.std is None:
if (
configuration.data_config.image_mean is None
or configuration.data_config.image_std is None
):
raise ValueError("Mean and std must be provided in the configuration.")

# tile size for UNets
@@ -634,8 +637,8 @@
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes or configuration.data_config.axes,
mean=configuration.data_config.mean,
std=configuration.data_config.std,
image_mean=configuration.data_config.image_mean,
image_std=configuration.data_config.image_std,
tta_transforms=tta_transforms,
batch_size=batch_size,
)
91 changes: 74 additions & 17 deletions src/careamics/config/data_model.py
@@ -5,6 +5,7 @@
from pprint import pformat
from typing import Any, Literal, Optional, Union

import numpy as np
from pydantic import (
BaseModel,
ConfigDict,
@@ -37,7 +38,9 @@ class DataConfig(BaseModel):
If std is specified, mean must be specified as well. Note that setting the std first
and then the mean (if they were both `None` before) will raise a validation error.
Prefer instead `set_mean_and_std` to set both at once.
Prefer instead `set_mean_and_std` to set both at once. Means and stds are expected
to be lists of floats, one for each channel. For supervised tasks, the mean and std
of the target could be different from the input data.
All supported transforms are defined in the SupportedTransform enum.
@@ -53,7 +56,7 @@ class DataConfig(BaseModel):
... )
To change the mean and std of the data:
>>> data.set_mean_and_std(mean=214.3, std=84.5)
>>> data.set_mean_and_std(image_mean=[214.3], image_std=[84.5])
One can pass also a list of transformations, by keyword, using the
SupportedTransform value:
@@ -78,13 +81,17 @@ class DataConfig(BaseModel):

# Dataset configuration
data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
patch_size: list[int] = Field(..., min_length=2, max_length=3)
patch_size: Union[list[int]] = Field(..., min_length=2, max_length=3)
batch_size: int = Field(default=1, ge=1, validate_default=True)
axes: str

# Optional fields
mean: Optional[float] = None
std: Optional[float] = None
image_mean: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
image_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
target_mean: Optional[list[float]] = Field(
default=None, min_length=0, max_length=32
)
target_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)

transforms: list[TRANSFORMS_UNION] = Field(
default=[
@@ -105,20 +112,22 @@

@field_validator("patch_size")
@classmethod
def all_elements_power_of_2_minimum_8(cls, patch_list: list[int]) -> list[int]:
def all_elements_power_of_2_minimum_8(
cls, patch_list: Union[list[int]]
) -> Union[list[int]]:
"""
Validate patch size.
Patch size must be powers of 2 and minimum 8.
Parameters
----------
patch_list : list[int]
patch_list : list of int
Patch size.
Returns
-------
list[int]
list of int
Validated patch size.
Raises
@@ -180,7 +189,7 @@ def validate_prediction_transforms(
Returns
-------
list[TRANSFORMS_UNION]
list of transforms
Validated transforms.
Raises
@@ -223,11 +232,34 @@ def std_only_with_mean(self: Self) -> Self:
If std is not None and mean is None.
"""
# check that mean and std are either both None, or both specified
if (self.mean is None) != (self.std is None):
if (self.image_mean and not self.image_std) or (
self.image_std and not self.image_mean
):
raise ValueError(
"Mean and std must be either both None, or both specified."
)

elif (self.image_mean is not None and self.image_std is not None) and (
len(self.image_mean) != len(self.image_std)
):
raise ValueError(
"Mean and std must be specified for each " "input channel."
)

if (self.target_mean and not self.target_std) or (
self.target_std and not self.target_mean
):
raise ValueError(
"Mean and std must be either both None, or both specified "
)

elif self.target_mean is not None and self.target_std is not None:
if len(self.target_mean) != len(self.target_std):
raise ValueError(
"Mean and std must be either both None, or both specified for each "
"target channel."
)

return self

@model_validator(mode="after")
@@ -310,7 +342,13 @@ def remove_n2v_manipulate(self) -> None:
if self.has_n2v_manipulate():
self.transforms.pop(-1)

def set_mean_and_std(self, mean: float, std: float) -> None:
def set_mean_and_std(
self,
image_mean: Union[np.ndarray, tuple, list, None],
image_std: Union[np.ndarray, tuple, list, None],
target_mean: Optional[Union[np.ndarray, tuple, list, None]] = None,
target_std: Optional[Union[np.ndarray, tuple, list, None]] = None,
) -> None:
"""
Set mean and standard deviation of the data.
@@ -319,12 +357,31 @@ def set_mean_and_std(self, mean: float, std: float) -> None:
Parameters
----------
mean : float
Mean of the data.
std : float
Standard deviation of the data.
image_mean : np.ndarray or tuple or list
Mean value for normalization.
image_std : np.ndarray or tuple or list
Standard deviation value for normalization.
target_mean : np.ndarray or tuple or list, optional
Target mean value for normalization, by default None.
target_std : np.ndarray or tuple or list, optional
Target standard deviation value for normalization, by default None.
"""
self._update(mean=mean, std=std)
# make sure we pass a list
if image_mean is not None:
image_mean = list(image_mean)
if image_std is not None:
image_std = list(image_std)
if target_mean is not None:
target_mean = list(target_mean)
if target_std is not None:
target_std = list(target_std)

self._update(
image_mean=image_mean,
image_std=image_std,
target_mean=target_mean,
target_std=target_std,
)

def set_3D(self, axes: str, patch_size: list[int]) -> None:
"""
@@ -334,7 +391,7 @@ def set_3D(self, axes: str, patch_size: list[int]) -> None:
----------
axes : str
Axes.
patch_size : list[int]
patch_size : list of int
Patch size.
"""
self._update(axes=axes, patch_size=patch_size)
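A hedged usage sketch of the new per-channel API on `DataConfig` (field names follow the diff above; the import path and axes/constructor validation are assumptions):

```python
from careamics.config.data_model import DataConfig  # path assumed from this diff

config = DataConfig(
    data_type="array",
    patch_size=[64, 64],
    axes="CYX",  # two-channel 2D data (exact axes validation may differ)
)
# per-channel input stats, plus separate target stats for a supervised task
config.set_mean_and_std(
    image_mean=[128.0, 340.5],
    image_std=[30.2, 85.1],
    target_mean=[0.5, 0.7],
    target_std=[0.1, 0.2],
)
```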
36 changes: 26 additions & 10 deletions src/careamics/config/inference_model.py
@@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing_extensions import Self
@@ -17,13 +17,17 @@ class InferenceConfig(BaseModel):

# Mandatory fields
data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
tile_size: Optional[list[int]] = Field(default=None, min_length=2, max_length=3)
tile_overlap: Optional[list[int]] = Field(default=None, min_length=2, max_length=3)
tile_size: Optional[Union[list[int]]] = Field(
default=None, min_length=2, max_length=3
)
tile_overlap: Optional[Union[list[int]]] = Field(
default=None, min_length=2, max_length=3
)

axes: str

mean: float
std: float = Field(..., ge=0.0)
image_mean: list = Field(..., min_length=0, max_length=32)
image_std: list = Field(..., min_length=0, max_length=32)

# only default TTAs are supported for now
tta_transforms: bool = Field(default=True)
@@ -80,12 +84,12 @@ def tile_min_8_power_of_2(
Parameters
----------
tile_list : list[int] or None
tile_list : list of int
Patch size.
Returns
-------
list[int] or None
list of int
Validated patch size.
Raises
@@ -178,11 +182,23 @@ def std_only_with_mean(self: Self) -> Self:
If std is not None and mean is None.
"""
# check that mean and std are either both None, or both specified
if (self.mean is None) != (self.std is None):
if not self.image_mean and not self.image_std:
raise ValueError("Mean and std must be specified during inference.")

if (self.image_mean and not self.image_std) or (
self.image_std and not self.image_mean
):
raise ValueError(
"Mean and std must be either both None, or both specified."
)

elif (self.image_mean is not None and self.image_std is not None) and (
len(self.image_mean) != len(self.image_std)
):
raise ValueError(
"Mean and std must be specified for each " "input channel."
)

return self

def _update(self, **kwargs: Any) -> None:
@@ -205,9 +221,9 @@ def set_3D(self, axes: str, tile_size: list[int], tile_overlap: list[int]) -> No
----------
axes : str
Axes.
tile_size : list[int]
tile_size : list of int
Tile size.
tile_overlap : list[int]
tile_overlap : list of int
Tile overlap.
"""
self._update(axes=axes, tile_size=tile_size, tile_overlap=tile_overlap)
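Similarly, a hedged sketch of the updated `InferenceConfig` with per-channel lists instead of single floats (the import path and the minimal set of required fields are assumptions based on this diff):

```python
from careamics.config.inference_model import InferenceConfig  # path assumed

pred_config = InferenceConfig(
    data_type="tiff",
    axes="YX",
    image_mean=[214.3],   # one entry per input channel
    image_std=[84.5],
    tile_size=[128, 128],
    tile_overlap=[48, 48],
)
```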
36 changes: 32 additions & 4 deletions src/careamics/config/transformations/normalize_model.py
@@ -1,8 +1,9 @@
"""Pydantic model for the Normalize transform."""

from typing import Literal
from typing import Literal, Optional

from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self

from .transform_model import TransformModel

@@ -28,5 +29,32 @@ class NormalizeModel(TransformModel):
)

name: Literal["Normalize"] = "Normalize"
mean: float = Field(default=0.485) # albumentations defaults
std: float = Field(default=0.229)
image_means: list = Field(..., min_length=0, max_length=32)
image_stds: list = Field(..., min_length=0, max_length=32)
target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)

@model_validator(mode="after")
def validate_means_stds(self: Self) -> Self:
"""Validate that the means and stds have the same length.
Returns
-------
Self
The instance of the model.
"""
if len(self.image_means) != len(self.image_stds):
raise ValueError("The number of image means and stds must be the same.")

if (self.target_means is None) != (self.target_stds is None):
raise ValueError(
"Both target means and stds must be provided together, or bot None."
)

if self.target_means is not None and self.target_stds is not None:
if len(self.target_means) != len(self.target_stds):
raise ValueError(
"The number of target means and stds must be the same."
)

return self
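A hedged example of the new per-channel fields on `NormalizeModel` (field names from the diff; the import path is assumed from the file location):

```python
from careamics.config.transformations.normalize_model import NormalizeModel  # path assumed

norm = NormalizeModel(
    image_means=[128.0, 340.5],  # one entry per input channel
    image_stds=[30.2, 85.1],
    target_means=[0.5, 0.7],     # optional, for supervised targets
    target_stds=[0.1, 0.2],
)
```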
3 changes: 2 additions & 1 deletion src/careamics/dataset/dataset_utils/__init__.py
@@ -2,6 +2,7 @@

__all__ = [
"reshape_array",
"compute_normalization_stats",
"get_files_size",
"list_files",
"validate_source_target_files",
@@ -12,7 +13,7 @@
]


from .dataset_utils import reshape_array
from .dataset_utils import compute_normalization_stats, reshape_array
from .file_utils import get_files_size, list_files, validate_source_target_files
from .iterate_over_files import iterate_over_files
from .read_tiff import read_tiff
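The signature of the new `compute_normalization_stats` helper is not shown in this excerpt; conceptually it amounts to per-channel statistics over every non-channel axis, as in this illustrative numpy stand-in:

```python
import numpy as np

def per_channel_stats(data: np.ndarray, channel_axis: int = 1):
    """Return (means, stds), one value per channel (axis convention assumed)."""
    reduce_axes = tuple(i for i in range(data.ndim) if i != channel_axis)
    return data.mean(axis=reduce_axes), data.std(axis=reduce_axes)

# example: an (S, C, Y, X) stack with two channels
stack = np.random.rand(4, 2, 64, 64)
means, stds = per_channel_stats(stack)  # two arrays of shape (2,)
```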