Skip to content

Commit

Permalink
(fix): list of floats in model, fix normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jun 13, 2024
1 parent 4a2cb91 commit 178e1b9
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 67 deletions.
15 changes: 6 additions & 9 deletions src/careamics/config/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ class DataConfig(BaseModel):
... axes="YX",
... transforms=[
... {
... "name": SupportedTransform.NORMALIZE.value,
... "image_means": [167.6],
... "image_stds": [47.2],
... },
... {
... "name": "XYFlip",
... }
... ]
Expand All @@ -91,10 +86,12 @@ class DataConfig(BaseModel):
axes: str

# Optional fields
image_mean: Optional[list] = Field(default=None, min_length=0, max_length=32)
image_std: Optional[list] = Field(default=None, min_length=0, max_length=32)
target_mean: Optional[list] = Field(default=None, min_length=0, max_length=32)
target_std: Optional[list] = Field(default=None, min_length=0, max_length=32)
image_mean: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
image_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
target_mean: Optional[list[float]] = Field(
default=None, min_length=0, max_length=32
)
target_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)

transforms: list[TRANSFORMS_UNION] = Field(
default=[
Expand Down
38 changes: 32 additions & 6 deletions src/careamics/config/transformations/normalize_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Pydantic model for the Normalize transform."""

from typing import Literal, Union
from typing import Literal, Optional

from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self

from .transform_model import TransformModel

Expand All @@ -28,7 +29,32 @@ class NormalizeModel(TransformModel):
)

name: Literal["Normalize"] = "Normalize"
image_means: Union[tuple, None] = Field(default=(), min_length=0, max_length=32)
image_stds: Union[tuple, None] = Field(default=(), min_length=0, max_length=32)
target_means: Union[tuple, None] = Field(default=(), min_length=0, max_length=32)
target_stds: Union[tuple, None] = Field(default=(), min_length=0, max_length=32)
image_means: list = Field(..., min_length=0, max_length=32)
image_stds: list = Field(..., min_length=0, max_length=32)
target_means: Optional[list] = Field(default=None, min_length=0, max_length=32)
target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32)

@model_validator(mode="after")
def validate_means_stds(self: Self) -> Self:
"""Validate that the means and stds have the same length.
Returns
-------
Self
The instance of the model.
"""
if len(self.image_means) != len(self.image_stds):
raise ValueError("The number of image means and stds must be the same.")

if (self.target_means is None) != (self.target_stds is None):
raise ValueError(
"Both target means and stds must be provided together, or bot None."
)

if self.target_means is not None and self.target_stds is not None:
if len(self.target_means) != len(self.target_stds):
raise ValueError(
"The number of target means and stds must be the same."
)

return self
19 changes: 10 additions & 9 deletions src/careamics/dataset/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,7 @@ def __init__(
self.data = patches_data.patches
self.data_targets = patches_data.targets

if (
self.data_config.image_mean is not None
and not any(self.data_config.image_mean)
) or (
self.data_config.image_std is not None
and not any(self.data_config.image_std)
):
if self.data_config.image_mean is None:
self.image_means = patches_data.image_stats.means
self.image_stds = patches_data.image_stats.stds
logger.info(
Expand All @@ -101,7 +95,7 @@ def __init__(
self.image_means = self.data_config.image_mean
self.image_stds = self.data_config.image_std

if not self.data_config.target_mean or not self.data_config.target_std:
if self.data_config.target_mean is None:
self.target_means = patches_data.target_stats.means
self.target_stds = patches_data.target_stats.stds
else:
Expand All @@ -118,7 +112,14 @@ def __init__(
)
# get transforms
self.patch_transform = Compose(
transform_list=[NormalizeModel(mean=self.mean, std=self.std)]
transform_list=[
NormalizeModel(
image_means=self.image_means,
image_stds=self.image_stds,
target_means=self.target_means,
target_stds=self.target_stds,
)
]
+ self.data_config.transforms,
)

Expand Down
70 changes: 33 additions & 37 deletions src/careamics/dataset/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,19 @@ class PathIterableDataset(IterableDataset):
----------
data_config : DataConfig
Data configuration.
src_files : list[Path]
src_files : list of pathlib.Path
List of data files.
target_files : list[Path] or None, optional
target_files : list of pathlib.Path, optional
Optional list of target files, by default None.
read_source_func : Callable, optional
Read source function for custom types, by default read_tiff.
Attributes
----------
data_path : list[Path]
data_path : list of pathlib.Path
Path to the data, must be a directory.
axes : str
Description of axes in format STCZYX.
patch_size : list[int] or tuple[int] or None, optional
Size of the patches in each dimension, by default None.
patch_overlap : list[int] or tuple[int] or None, optional
Overlap of the patches in each dimension, by default None.
mean : float or None, optional
Expected mean of the dataset, by default None.
std : float or None, optional
Expected standard deviation of the dataset, by default None.
patch_transform : Callable or None, optional
Patch transform callable, by default None.
"""

def __init__(
Expand Down Expand Up @@ -81,41 +71,47 @@ def __init__(
self.read_source_func = read_source_func

# compute mean and std over the dataset
# Only checking the image_mean because the DataConfig class ensures that
# only checking the image_mean because the DataConfig class ensures that
# if image_mean is provided, image_std is also provided
if not self.data_config.image_mean:
self.patches_data = self._calculate_mean_and_std()
self.data_stats = self._calculate_mean_and_std()
logger.info(
f"Computed dataset mean: {self.patches_data.image_stats.means},"
f"std: {self.patches_data.image_stats.stds}"
f"Computed dataset mean: {self.data_stats.image_stats.means},"
f"std: {self.data_stats.image_stats.stds}"
)

else:
self.patches_data = StatsOutput(
Stats(self.data_config.image_mean, self.data_config.image_std),
Stats(self.data_config.target_mean, self.data_config.target_std),
)

if hasattr(self.data_config, "set_mean_and_std"):
# update the mean in the config
self.data_config.set_mean_and_std(
image_mean=self.patches_data.image_stats.means,
image_std=self.patches_data.image_stats.stds,
image_mean=self.data_stats.image_stats.means,
image_std=self.data_stats.image_stats.stds,
target_mean=(
tuple(self.patches_data.target_stats.means)
if self.patches_data.target_stats.means is not None
else []
list(self.data_stats.target_stats.means)
if self.data_stats.target_stats.means is not None
else None
),
target_std=(
tuple(self.patches_data.target_stats.stds)
if self.patches_data.target_stats.stds is not None
else []
list(self.data_stats.target_stats.stds)
if self.data_stats.target_stats.stds is not None
else None
),
)

# get transforms
else:
# if mean and std are provided in the config, use them
self.data_stats = StatsOutput(
Stats(self.data_config.image_mean, self.data_config.image_std),
Stats(self.data_config.target_mean, self.data_config.target_std),
)

# create transform composed of normalization and other transforms
self.patch_transform = Compose(
transform_list=[
NormalizeModel(mean=self.mean, std=self.std),
NormalizeModel(
image_means=self.data_stats.image_stats.means,
image_stds=self.data_stats.image_stats.stds,
target_means=self.data_stats.target_stats.means,
target_stds=self.data_stats.target_stats.stds,
)
]
+ data_config.transforms
)
Expand Down Expand Up @@ -155,7 +151,7 @@ def _calculate_mean_and_std(self) -> StatsOutput:

if target is not None:
target_means = np.mean(target_means, axis=0)
target_stds = np.mean([std**2 for std in image_stds], axis=0)
target_stds = np.mean([std**2 for std in target_stds], axis=0)

logger.info(f"Calculated mean and std for {num_samples} images")
logger.info(f"Mean: {image_means}, std: {image_stds}")
Expand All @@ -179,8 +175,8 @@ def __iter__(
Single patch.
"""
assert (
self.patches_data.image_stats.means is not None
and self.patches_data.image_stats.stds is not None
self.data_stats.image_stats.means is not None
and self.data_stats.image_stats.stds is not None
), "Mean and std must be provided"

# iterate over files
Expand Down
8 changes: 5 additions & 3 deletions tests/config/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ def test_set_mean_and_std(minimum_data: dict):
def test_normalize_not_accepted(minimum_data: dict):
"""Test that normalize is not accepted, because it is mandatory and applied else
where."""
minimum_data["mean"] = 10.4
minimum_data["std"] = 3.2
minimum_data["transforms"] = [NormalizeModel(mean=0.485, std=0.229)]
minimum_data["image_means"] = [10.4]
minimum_data["image_stds"] = [3.2]
minimum_data["transforms"] = [
NormalizeModel(image_means=[0.485], image_stds=[0.229])
]

with pytest.raises(ValueError):
DataConfig(**minimum_data)
Expand Down
47 changes: 44 additions & 3 deletions tests/config/transformations/test_normalize_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pytest

from careamics.config.transformations import NormalizeModel
from careamics.transforms import Normalize


def test_setting_image_means_std():
"""Test that we can set the image_means and std values."""
model = NormalizeModel(name="Normalize", image_means=[0.5], image_stds=[0.5])
model = NormalizeModel(image_means=[0.5], image_stds=[0.5])
assert model.image_means == [0.5]
assert model.image_stds == [0.5]

Expand All @@ -14,7 +16,6 @@ def test_setting_image_means_std():
assert model.image_stds == [0.6]

model = NormalizeModel(
name="Normalize",
image_means=[0.5],
image_stds=[0.5],
target_means=[0.5],
Expand All @@ -36,9 +37,49 @@ def test_setting_image_means_std():
assert model.target_stds == [0.6]


def test_error_different_length_means_stds():
"""Test that an error is raised if the image_means and stds have different
lengths."""
with pytest.raises(ValueError):
NormalizeModel(image_means=[0.5], image_stds=[0.5, 0.6])

with pytest.raises(ValueError):
NormalizeModel(image_means=[0.5, 0.6], image_stds=[0.5])

with pytest.raises(ValueError):
NormalizeModel(
image_means=[0.5],
image_stds=[0.5],
target_means=[0.5],
)

with pytest.raises(ValueError):
NormalizeModel(
image_means=[0.5],
image_stds=[0.5],
target_stds=[0.5],
)

with pytest.raises(ValueError):
NormalizeModel(
image_means=[0.5],
image_stds=[0.5],
target_means=[0.5, 0.6],
target_stds=[0.5],
)

with pytest.raises(ValueError):
NormalizeModel(
image_means=[0.5],
image_stds=[0.5],
target_means=[0.5],
target_stds=[0.5, 0.6],
)


def test_comptatibility_with_transform():
"""Test that the model allows instantiating a transform."""
model = NormalizeModel(name="Normalize", image_means=[0.5], image_stds=[0.5])
model = NormalizeModel(image_means=[0.5], image_stds=[0.5])

# instantiate transform
Normalize(**model.model_dump())

0 comments on commit 178e1b9

Please sign in to comment.