diff --git a/src/careamics/config/data_model.py b/src/careamics/config/data_model.py index 3f7d2b46b..11b9ba01a 100644 --- a/src/careamics/config/data_model.py +++ b/src/careamics/config/data_model.py @@ -68,11 +68,6 @@ class DataConfig(BaseModel): ... axes="YX", ... transforms=[ ... { - ... "name": SupportedTransform.NORMALIZE.value, - ... "image_means": [167.6], - ... "image_stds": [47.2], - ... }, - ... { ... "name": "XYFlip", ... } ... ] @@ -91,10 +86,12 @@ class DataConfig(BaseModel): axes: str # Optional fields - image_mean: Optional[list] = Field(default=None, min_length=0, max_length=32) - image_std: Optional[list] = Field(default=None, min_length=0, max_length=32) - target_mean: Optional[list] = Field(default=None, min_length=0, max_length=32) - target_std: Optional[list] = Field(default=None, min_length=0, max_length=32) + image_mean: Optional[list[float]] = Field(default=None, min_length=0, max_length=32) + image_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32) + target_mean: Optional[list[float]] = Field( + default=None, min_length=0, max_length=32 + ) + target_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32) transforms: list[TRANSFORMS_UNION] = Field( default=[ diff --git a/src/careamics/config/transformations/normalize_model.py b/src/careamics/config/transformations/normalize_model.py index 19ce7522f..09bfe39ff 100644 --- a/src/careamics/config/transformations/normalize_model.py +++ b/src/careamics/config/transformations/normalize_model.py @@ -1,8 +1,9 @@ """Pydantic model for the Normalize transform.""" -from typing import Literal, Union +from typing import Literal, Optional -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator +from typing_extensions import Self from .transform_model import TransformModel @@ -28,7 +29,32 @@ class NormalizeModel(TransformModel): ) name: Literal["Normalize"] = "Normalize" - image_means: Union[tuple, None] = 
Field(default=(), min_length=0, max_length=32) - image_stds: Union[tuple, None] = Field(default=(), min_length=0, max_length=32) - target_means: Union[tuple, None] = Field(default=(), min_length=0, max_length=32) - target_stds: Union[tuple, None] = Field(default=(), min_length=0, max_length=32) + image_means: list = Field(..., min_length=0, max_length=32) + image_stds: list = Field(..., min_length=0, max_length=32) + target_means: Optional[list] = Field(default=None, min_length=0, max_length=32) + target_stds: Optional[list] = Field(default=None, min_length=0, max_length=32) + + @model_validator(mode="after") + def validate_means_stds(self: Self) -> Self: + """Validate that the means and stds have the same length. + + Returns + ------- + Self + The instance of the model. + """ + if len(self.image_means) != len(self.image_stds): + raise ValueError("The number of image means and stds must be the same.") + + if (self.target_means is None) != (self.target_stds is None): + raise ValueError( + "Both target means and stds must be provided together, or both None." + ) + + if self.target_means is not None and self.target_stds is not None: + if len(self.target_means) != len(self.target_stds): + raise ValueError( + "The number of target means and stds must be the same." 
+ ) + + return self diff --git a/src/careamics/dataset/in_memory_dataset.py b/src/careamics/dataset/in_memory_dataset.py index a8bc3efe1..4ff623548 100644 --- a/src/careamics/dataset/in_memory_dataset.py +++ b/src/careamics/dataset/in_memory_dataset.py @@ -85,13 +85,7 @@ def __init__( self.data = patches_data.patches self.data_targets = patches_data.targets - if ( - self.data_config.image_mean is not None - and not any(self.data_config.image_mean) - ) or ( - self.data_config.image_std is not None - and not any(self.data_config.image_std) - ): + if self.data_config.image_mean is None: self.image_means = patches_data.image_stats.means self.image_stds = patches_data.image_stats.stds logger.info( @@ -101,7 +95,7 @@ def __init__( self.image_means = self.data_config.image_mean self.image_stds = self.data_config.image_std - if not self.data_config.target_mean or not self.data_config.target_std: + if self.data_config.target_mean is None: self.target_means = patches_data.target_stats.means self.target_stds = patches_data.target_stats.stds else: @@ -118,7 +112,14 @@ def __init__( ) # get transforms self.patch_transform = Compose( - transform_list=[NormalizeModel(mean=self.mean, std=self.std)] + transform_list=[ + NormalizeModel( + image_means=self.image_means, + image_stds=self.image_stds, + target_means=self.target_means, + target_stds=self.target_stds, + ) + ] + self.data_config.transforms, ) diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index e0b0c78da..193553b6a 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -29,29 +29,19 @@ class PathIterableDataset(IterableDataset): ---------- data_config : DataConfig Data configuration. - src_files : list[Path] + src_files : list of pathlib.Path List of data files. - target_files : list[Path] or None, optional + target_files : list of pathlib.Path, optional Optional list of target files, by default None. 
read_source_func : Callable, optional Read source function for custom types, by default read_tiff. Attributes ---------- - data_path : list[Path] + data_path : list of pathlib.Path Path to the data, must be a directory. axes : str Description of axes in format STCZYX. - patch_size : list[int] or tuple[int] or None, optional - Size of the patches in each dimension, by default None. - patch_overlap : list[int] or tuple[int] or None, optional - Overlap of the patches in each dimension, by default None. - mean : float or None, optional - Expected mean of the dataset, by default None. - std : float or None, optional - Expected standard deviation of the dataset, by default None. - patch_transform : Callable or None, optional - Patch transform callable, by default None. """ def __init__( @@ -81,41 +71,47 @@ def __init__( self.read_source_func = read_source_func # compute mean and std over the dataset - # Only checking the image_mean because the DataConfig class ensures that + # only checking the image_mean because the DataConfig class ensures that # if image_mean is provided, image_std is also provided if not self.data_config.image_mean: - self.patches_data = self._calculate_mean_and_std() + self.data_stats = self._calculate_mean_and_std() logger.info( - f"Computed dataset mean: {self.patches_data.image_stats.means}," - f"std: {self.patches_data.image_stats.stds}" + f"Computed dataset mean: {self.data_stats.image_stats.means}," + f"std: {self.data_stats.image_stats.stds}" ) - else: - self.patches_data = StatsOutput( - Stats(self.data_config.image_mean, self.data_config.image_std), - Stats(self.data_config.target_mean, self.data_config.target_std), - ) - - if hasattr(self.data_config, "set_mean_and_std"): + # update the mean in the config self.data_config.set_mean_and_std( - image_mean=self.patches_data.image_stats.means, - image_std=self.patches_data.image_stats.stds, + image_mean=self.data_stats.image_stats.means, + image_std=self.data_stats.image_stats.stds, 
target_mean=( - tuple(self.patches_data.target_stats.means) - if self.patches_data.target_stats.means is not None - else [] + list(self.data_stats.target_stats.means) + if self.data_stats.target_stats.means is not None + else None ), target_std=( - tuple(self.patches_data.target_stats.stds) - if self.patches_data.target_stats.stds is not None - else [] + list(self.data_stats.target_stats.stds) + if self.data_stats.target_stats.stds is not None + else None ), ) - # get transforms + else: + # if mean and std are provided in the config, use them + self.data_stats = StatsOutput( + Stats(self.data_config.image_mean, self.data_config.image_std), + Stats(self.data_config.target_mean, self.data_config.target_std), + ) + + # create transform composed of normalization and other transforms self.patch_transform = Compose( transform_list=[ - NormalizeModel(mean=self.mean, std=self.std), + NormalizeModel( + image_means=self.data_stats.image_stats.means, + image_stds=self.data_stats.image_stats.stds, + target_means=self.data_stats.target_stats.means, + target_stds=self.data_stats.target_stats.stds, + ) ] + data_config.transforms ) @@ -155,7 +151,7 @@ def _calculate_mean_and_std(self) -> StatsOutput: if target is not None: target_means = np.mean(target_means, axis=0) - target_stds = np.mean([std**2 for std in image_stds], axis=0) + target_stds = np.mean([std**2 for std in target_stds], axis=0) logger.info(f"Calculated mean and std for {num_samples} images") logger.info(f"Mean: {image_means}, std: {image_stds}") @@ -179,8 +175,8 @@ def __iter__( Single patch. 
""" assert ( - self.patches_data.image_stats.means is not None - and self.patches_data.image_stats.stds is not None + self.data_stats.image_stats.means is not None + and self.data_stats.image_stats.stds is not None ), "Mean and std must be provided" # iterate over files diff --git a/tests/config/test_data_model.py b/tests/config/test_data_model.py index 7d07d1cd1..ced101d01 100644 --- a/tests/config/test_data_model.py +++ b/tests/config/test_data_model.py @@ -86,9 +86,11 @@ def test_set_mean_and_std(minimum_data: dict): def test_normalize_not_accepted(minimum_data: dict): """Test that normalize is not accepted, because it is mandatory and applied else where.""" - minimum_data["mean"] = 10.4 - minimum_data["std"] = 3.2 - minimum_data["transforms"] = [NormalizeModel(mean=0.485, std=0.229)] + minimum_data["image_means"] = [10.4] + minimum_data["image_stds"] = [3.2] + minimum_data["transforms"] = [ + NormalizeModel(image_means=[0.485], image_stds=[0.229]) + ] with pytest.raises(ValueError): DataConfig(**minimum_data) diff --git a/tests/config/transformations/test_normalize_model.py b/tests/config/transformations/test_normalize_model.py index c273dee4c..b5822cb8a 100644 --- a/tests/config/transformations/test_normalize_model.py +++ b/tests/config/transformations/test_normalize_model.py @@ -1,10 +1,12 @@ +import pytest + from careamics.config.transformations import NormalizeModel from careamics.transforms import Normalize def test_setting_image_means_std(): """Test that we can set the image_means and std values.""" - model = NormalizeModel(name="Normalize", image_means=[0.5], image_stds=[0.5]) + model = NormalizeModel(image_means=[0.5], image_stds=[0.5]) assert model.image_means == [0.5] assert model.image_stds == [0.5] @@ -14,7 +16,6 @@ def test_setting_image_means_std(): assert model.image_stds == [0.6] model = NormalizeModel( - name="Normalize", image_means=[0.5], image_stds=[0.5], target_means=[0.5], @@ -36,9 +37,49 @@ def test_setting_image_means_std(): assert 
model.target_stds == [0.6] +def test_error_different_length_means_stds(): + """Test that an error is raised if the image_means and stds have different + lengths.""" + with pytest.raises(ValueError): + NormalizeModel(image_means=[0.5], image_stds=[0.5, 0.6]) + + with pytest.raises(ValueError): + NormalizeModel(image_means=[0.5, 0.6], image_stds=[0.5]) + + with pytest.raises(ValueError): + NormalizeModel( + image_means=[0.5], + image_stds=[0.5], + target_means=[0.5], + ) + + with pytest.raises(ValueError): + NormalizeModel( + image_means=[0.5], + image_stds=[0.5], + target_stds=[0.5], + ) + + with pytest.raises(ValueError): + NormalizeModel( + image_means=[0.5], + image_stds=[0.5], + target_means=[0.5, 0.6], + target_stds=[0.5], + ) + + with pytest.raises(ValueError): + NormalizeModel( + image_means=[0.5], + image_stds=[0.5], + target_means=[0.5], + target_stds=[0.5, 0.6], + ) + + def test_comptatibility_with_transform(): """Test that the model allows instantiating a transform.""" - model = NormalizeModel(name="Normalize", image_means=[0.5], image_stds=[0.5]) + model = NormalizeModel(image_means=[0.5], image_stds=[0.5]) # instantiate transform Normalize(**model.model_dump())