Skip to content

Commit

Permalink
(chore): rename image means and stds with plural value
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jun 14, 2024
1 parent 07fb84e commit 128418d
Show file tree
Hide file tree
Showing 21 changed files with 146 additions and 140 deletions.
8 changes: 4 additions & 4 deletions src/careamics/config/configuration_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ def create_inference_configuration(
Configuration used to configure CAREamicsPredictData.
"""
if (
configuration.data_config.image_mean is None
or configuration.data_config.image_std is None
configuration.data_config.image_means is None
or configuration.data_config.image_stds is None
):
raise ValueError("Mean and std must be provided in the configuration.")

Expand Down Expand Up @@ -637,8 +637,8 @@ def create_inference_configuration(
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes or configuration.data_config.axes,
image_mean=configuration.data_config.image_mean,
image_std=configuration.data_config.image_std,
image_means=configuration.data_config.image_means,
image_stds=configuration.data_config.image_stds,
tta_transforms=tta_transforms,
batch_size=batch_size,
)
80 changes: 42 additions & 38 deletions src/careamics/config/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pprint import pformat
from typing import Any, Literal, Optional, Union

import numpy as np
from numpy.typing import NDArray
from pydantic import (
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -56,7 +56,7 @@ class DataConfig(BaseModel):
... )
To change the mean and std of the data:
>>> data.set_mean_and_std(image_mean=[214.3], image_std=[84.5])
>>> data.set_mean_and_std(image_means=[214.3], image_stds=[84.5])
One can pass also a list of transformations, by keyword, using the
SupportedTransform value:
Expand Down Expand Up @@ -86,12 +86,16 @@ class DataConfig(BaseModel):
axes: str

# Optional fields
image_mean: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
image_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
target_mean: Optional[list[float]] = Field(
image_means: Optional[list[float]] = Field(
default=None, min_length=0, max_length=32
)
image_stds: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)
target_means: Optional[list[float]] = Field(
default=None, min_length=0, max_length=32
)
target_stds: Optional[list[float]] = Field(
default=None, min_length=0, max_length=32
)
target_std: Optional[list[float]] = Field(default=None, min_length=0, max_length=32)

transforms: list[TRANSFORMS_UNION] = Field(
default=[
Expand Down Expand Up @@ -232,29 +236,29 @@ def std_only_with_mean(self: Self) -> Self:
If std is not None and mean is None.
"""
# check that mean and std are either both None, or both specified
if (self.image_mean and not self.image_std) or (
self.image_std and not self.image_mean
if (self.image_means and not self.image_stds) or (
self.image_stds and not self.image_means
):
raise ValueError(
"Mean and std must be either both None, or both specified."
)

elif (self.image_mean is not None and self.image_std is not None) and (
len(self.image_mean) != len(self.image_std)
elif (self.image_means is not None and self.image_stds is not None) and (
len(self.image_means) != len(self.image_stds)
):
raise ValueError(
"Mean and std must be specified for each " "input channel."
)

if (self.target_mean and not self.target_std) or (
self.target_std and not self.target_mean
if (self.target_means and not self.target_stds) or (
self.target_stds and not self.target_means
):
raise ValueError(
"Mean and std must be either both None, or both specified "
)

elif self.target_mean is not None and self.target_std is not None:
if len(self.target_mean) != len(self.target_std):
elif self.target_means is not None and self.target_stds is not None:
if len(self.target_means) != len(self.target_stds):
raise ValueError(
"Mean and std must be either both None, or both specified for each "
"target channel."
Expand Down Expand Up @@ -344,10 +348,10 @@ def remove_n2v_manipulate(self) -> None:

def set_mean_and_std(
self,
image_mean: Union[np.ndarray, tuple, list, None],
image_std: Union[np.ndarray, tuple, list, None],
target_mean: Optional[Union[np.ndarray, tuple, list, None]] = None,
target_std: Optional[Union[np.ndarray, tuple, list, None]] = None,
image_means: Union[NDArray, tuple, list, None],
image_stds: Union[NDArray, tuple, list, None],
target_means: Optional[Union[NDArray, tuple, list, None]] = None,
target_stds: Optional[Union[NDArray, tuple, list, None]] = None,
) -> None:
"""
Set mean and standard deviation of the data.
Expand All @@ -357,30 +361,30 @@ def set_mean_and_std(
Parameters
----------
image_mean : np.ndarray or tuple or list
Mean value for normalization.
image_std : np.ndarray or tuple or list
Standard deviation value for normalization.
target_mean : np.ndarray or tuple or list, optional
Target mean value for normalization, by default ().
target_std : np.ndarray or tuple or list, optional
Target standard deviation value for normalization, by default ().
image_means : NDArray or tuple or list
Mean values for normalization.
image_stds : NDArray or tuple or list
Standard deviation values for normalization.
target_means : NDArray or tuple or list, optional
Target mean values for normalization, by default ().
target_stds : NDArray or tuple or list, optional
Target standard deviation values for normalization, by default ().
"""
# make sure we pass a list
if image_mean is not None:
image_mean = list(image_mean)
if image_std is not None:
image_std = list(image_std)
if target_mean is not None:
target_mean = list(target_mean)
if target_std is not None:
target_std = list(target_std)
if image_means is not None:
image_means = list(image_means)
if image_stds is not None:
image_stds = list(image_stds)
if target_means is not None:
target_means = list(target_means)
if target_stds is not None:
target_stds = list(target_stds)

self._update(
image_mean=image_mean,
image_std=image_std,
target_mean=target_mean,
target_std=target_std,
image_means=image_means,
image_stds=image_stds,
target_means=target_means,
target_stds=target_stds,
)

def set_3D(self, axes: str, patch_size: list[int]) -> None:
Expand Down
14 changes: 7 additions & 7 deletions src/careamics/config/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class InferenceConfig(BaseModel):

axes: str

image_mean: list = Field(..., min_length=0, max_length=32)
image_std: list = Field(..., min_length=0, max_length=32)
image_means: list = Field(..., min_length=0, max_length=32)
image_stds: list = Field(..., min_length=0, max_length=32)

# only default TTAs are supported for now
tta_transforms: bool = Field(default=True)
Expand Down Expand Up @@ -182,18 +182,18 @@ def std_only_with_mean(self: Self) -> Self:
If std is not None and mean is None.
"""
# check that mean and std are either both None, or both specified
if not self.image_mean and not self.image_std:
if not self.image_means and not self.image_stds:
raise ValueError("Mean and std must be specified during inference.")

if (self.image_mean and not self.image_std) or (
self.image_std and not self.image_mean
if (self.image_means and not self.image_stds) or (
self.image_stds and not self.image_means
):
raise ValueError(
"Mean and std must be either both None, or both specified."
)

elif (self.image_mean is not None and self.image_std is not None) and (
len(self.image_mean) != len(self.image_std)
elif (self.image_means is not None and self.image_stds is not None) and (
len(self.image_means) != len(self.image_stds)
):
raise ValueError(
"Mean and std must be specified for each " "input channel."
Expand Down
20 changes: 10 additions & 10 deletions src/careamics/dataset/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,30 @@ def __init__(
self.data = patches_data.patches
self.data_targets = patches_data.targets

if self.data_config.image_mean is None:
if self.data_config.image_means is None:
self.image_means = patches_data.image_stats.means
self.image_stds = patches_data.image_stats.stds
logger.info(
f"Computed dataset mean: {self.image_means}, std: {self.image_stds}"
)
else:
self.image_means = self.data_config.image_mean
self.image_stds = self.data_config.image_std
self.image_means = self.data_config.image_means
self.image_stds = self.data_config.image_stds

if self.data_config.target_mean is None:
if self.data_config.target_means is None:
self.target_means = patches_data.target_stats.means
self.target_stds = patches_data.target_stats.stds
else:
self.target_means = self.data_config.target_mean
self.target_stds = self.data_config.target_std
self.target_means = self.data_config.target_means
self.target_stds = self.data_config.target_stds

# update mean and std in configuration
# the object is mutable and should then be recorded in the CAREamist obj
self.data_config.set_mean_and_std(
image_mean=self.image_means,
image_std=self.image_stds,
target_mean=self.target_means,
target_std=self.target_stds,
image_means=self.image_means,
image_stds=self.image_stds,
target_means=self.target_means,
target_stds=self.target_stds,
)
# get transforms
self.patch_transform = Compose(
Expand Down
4 changes: 2 additions & 2 deletions src/careamics/dataset/in_memory_pred_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(
self.pred_config = prediction_config
self.input_array = inputs
self.axes = self.pred_config.axes
self.image_means = self.pred_config.image_mean
self.image_stds = self.pred_config.image_std
self.image_means = self.pred_config.image_means
self.image_stds = self.pred_config.image_stds

# Reshape data
self.data = reshape_array(self.input_array, self.axes)
Expand Down
4 changes: 2 additions & 2 deletions src/careamics/dataset/in_memory_tiled_pred_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __init__(
self.axes = self.pred_config.axes
self.tile_size = prediction_config.tile_size
self.tile_overlap = prediction_config.tile_overlap
self.image_means = self.pred_config.image_mean
self.image_stds = self.pred_config.image_std
self.image_means = self.pred_config.image_means
self.image_stds = self.pred_config.image_stds

# Generate patches
self.data = self._prepare_tiles()
Expand Down
14 changes: 7 additions & 7 deletions src/careamics/dataset/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(
# compute mean and std over the dataset
# only checking the image_mean because the DataConfig class ensures that
# if image_mean is provided, image_std is also provided
if not self.data_config.image_mean:
if not self.data_config.image_means:
self.data_stats = self._calculate_mean_and_std()
logger.info(
f"Computed dataset mean: {self.data_stats.image_stats.means},"
Expand All @@ -82,14 +82,14 @@ def __init__(

# update the mean in the config
self.data_config.set_mean_and_std(
image_mean=self.data_stats.image_stats.means,
image_std=self.data_stats.image_stats.stds,
target_mean=(
image_means=self.data_stats.image_stats.means,
image_stds=self.data_stats.image_stats.stds,
target_means=(
list(self.data_stats.target_stats.means)
if self.data_stats.target_stats.means is not None
else None
),
target_std=(
target_stds=(
list(self.data_stats.target_stats.stds)
if self.data_stats.target_stats.stds is not None
else None
Expand All @@ -99,8 +99,8 @@ def __init__(
else:
# if mean and std are provided in the config, use them
self.data_stats = StatsOutput(
Stats(self.data_config.image_mean, self.data_config.image_std),
Stats(self.data_config.target_mean, self.data_config.target_std),
Stats(self.data_config.image_means, self.data_config.image_stds),
Stats(self.data_config.target_means, self.data_config.target_stds),
)

# create transform composed of normalization and other transforms
Expand Down
8 changes: 4 additions & 4 deletions src/careamics/dataset/iterable_pred_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,13 @@ def __init__(

# check mean and std and create normalize transform
if (
self.prediction_config.image_mean is None
or self.prediction_config.image_std is None
self.prediction_config.image_means is None
or self.prediction_config.image_stds is None
):
raise ValueError("Mean and std must be provided for prediction.")
else:
self.image_means = self.prediction_config.image_mean
self.image_stds = self.prediction_config.image_std
self.image_means = self.prediction_config.image_means
self.image_stds = self.prediction_config.image_stds

# instantiate normalize transform
self.patch_transform = Compose(
Expand Down
8 changes: 4 additions & 4 deletions src/careamics/dataset/iterable_tiled_pred_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ def __init__(

# check mean and std and create normalize transform
if (
self.prediction_config.image_mean is None
or self.prediction_config.image_std is None
self.prediction_config.image_means is None
or self.prediction_config.image_stds is None
):
raise ValueError("Mean and std must be provided for prediction.")
else:
self.image_means = self.prediction_config.image_mean
self.image_stds = self.prediction_config.image_std
self.image_means = self.prediction_config.image_means
self.image_stds = self.prediction_config.image_stds

# instantiate normalize transform
self.patch_transform = Compose(
Expand Down
26 changes: 13 additions & 13 deletions src/careamics/lightning_prediction_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Prediction Lightning data modules."""

from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as L
Expand Down Expand Up @@ -271,10 +271,10 @@ class PredictDataWrapper(CAREamicsPredictData):
Prediction data.
data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
Data type, see `SupportedData` for available options.
image_mean : float
Mean value for normalization, only used if Normalization is defined.
image_std : float
Std value for normalization, only used if Normalization is defined.
image_means : list of float
Mean values for normalization, only used if Normalization is defined.
image_stds : list of float
Std values for normalization, only used if Normalization is defined.
tile_size : Tuple[int, ...]
Tile size, 2D or 3D tile size.
tile_overlap : Tuple[int, ...]
Expand All @@ -298,8 +298,8 @@ def __init__(
self,
pred_data: Union[str, Path, np.ndarray],
data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
image_mean=List,
image_std=List,
image_means=list[float],
image_stds=list[float],
tile_size: Optional[Tuple[int, ...]] = None,
tile_overlap: Optional[Tuple[int, ...]] = None,
axes: str = "YX",
Expand All @@ -318,10 +318,10 @@ def __init__(
Prediction data.
data_type : Union[Literal["array", "tiff", "custom"], SupportedData]
Data type, see `SupportedData` for available options.
image_mean : float
Mean value for normalization, only used if Normalization is defined.
image_std : float
Std value for normalization, only used if Normalization is defined.
image_means : list of float
Mean values for normalization, only used if Normalization is defined.
image_stds : list of float
Std values for normalization, only used if Normalization is defined.
tile_size : List[int]
Tile size, 2D or 3D tile size.
tile_overlap : List[int]
Expand All @@ -347,8 +347,8 @@ def __init__(
"tile_size": tile_size,
"tile_overlap": tile_overlap,
"axes": axes,
"image_mean": image_mean,
"image_std": image_std,
"image_means": image_means,
"image_stds": image_stds,
"tta": tta_transforms,
"batch_size": batch_size,
"transforms": [],
Expand Down
Loading

0 comments on commit 128418d

Please sign in to comment.