From ab9767103fb7eec8800d5f8354e5127df8057de2 Mon Sep 17 00:00:00 2001 From: CatEek Date: Tue, 18 Jun 2024 17:18:57 +0200 Subject: [PATCH 1/7] draft + test --- .../dataset/dataset_utils/__init__.py | 9 +++++- .../dataset/dataset_utils/dataset_utils.py | 28 ++++++++++++++++ src/careamics/dataset/iterable_dataset.py | 32 +++++++++++++++---- tests/dataset/test_iterable_dataset.py | 31 ++++++++---------- 4 files changed, 76 insertions(+), 24 deletions(-) diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py index 69db09e87..667ddf443 100644 --- a/src/careamics/dataset/dataset_utils/__init__.py +++ b/src/careamics/dataset/dataset_utils/__init__.py @@ -3,6 +3,8 @@ __all__ = [ "reshape_array", "compute_normalization_stats", + "update_iterative_stats", + "finalize_iterative_stats", "get_files_size", "list_files", "validate_source_target_files", @@ -13,7 +15,12 @@ ] -from .dataset_utils import compute_normalization_stats, reshape_array +from .dataset_utils import ( + compute_normalization_stats, + finalize_iterative_stats, + reshape_array, + update_iterative_stats, +) from .file_utils import get_files_size, list_files, validate_source_target_files from .iterate_over_files import iterate_over_files from .read_tiff import read_tiff diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py index 6da4f122a..5d524b943 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -121,3 +121,31 @@ def compute_normalization_stats(image: np.ndarray) -> Tuple[np.ndarray, np.ndarr # Define the list of axes excluding the channel axis axes = tuple(np.delete(np.arange(image.ndim), 1)) return np.mean(image, axis=axes), np.std(image, axis=axes) + + +def update_iterative_stats(count, mean, m2, new_values): + count += np.array([len(arr.flatten()) for arr in new_values]) + # newvalues - oldMean + delta = [ + 
np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] + + mean += np.array([np.sum(d / c) for d, c in zip(delta, count)]) + # newvalues - newMeant + delta2 = [ + np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] + + m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)]) + + return (count, mean, m2) + + +def finalize_iterative_stats(count, mean, m2): + std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) + if any(c < 2 for c in count): + return float("nan"), float("nan") + else: + return mean, std diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 62cc6b2c3..0dd277339 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -15,7 +15,13 @@ from careamics.transforms import Compose from ..utils.logging import get_logger -from .dataset_utils import compute_normalization_stats, iterate_over_files, read_tiff +from .dataset_utils import ( + compute_normalization_stats, + iterate_over_files, + read_tiff, + update_iterative_stats, + finalize_iterative_stats, +) from .patching.patching import Stats, StatsOutput from .patching.random_patching import extract_patches_random @@ -134,9 +140,24 @@ def _calculate_mean_and_std(self) -> StatsOutput: for sample, target in iterate_over_files( self.data_config, self.data_files, self.target_files, self.read_source_func ): - sample_mean, sample_std = compute_normalization_stats(sample) - image_means.append(sample_mean) - image_stds.append(sample_std) + init_mean, _ = compute_normalization_stats(sample) + # separate channels + sample_channels = np.array(np.split(sample, sample.shape[1], axis=1)) + + counts, m2s = [], [] + if num_samples == 0: + mean = init_mean + + count = np.array([np.prod(channel.shape) for channel in sample_channels]) + m2 = np.array([ + np.sum(np.subtract(channel.flatten(), [mean[i]] * count[i]) ** 2) + for i, channel in 
enumerate(sample_channels) + ]) + + else: + count, mean, m2 = update_iterative_stats( + count, mean, m2, sample_channels + ) if target is not None: target_mean, target_std = compute_normalization_stats(target) @@ -149,8 +170,7 @@ def _calculate_mean_and_std(self) -> StatsOutput: raise ValueError("No samples found in the dataset.") # Average the means and stds per sample - image_means = np.mean(image_means, axis=0) - image_stds = np.sqrt(np.mean([std**2 for std in image_stds], axis=0)) + image_means, image_stds = finalize_iterative_stats(count, mean, m2) if target is not None: target_means = np.mean(target_means, axis=0) diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index bb233786d..dd49910dd 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -145,29 +145,21 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): ((32, 32, 32), "ZYX", (8, 8, 8)), ], ) -def test_compute_mean_std_transform_iterable( - tmp_path, ordered_array, shape, axes, patch_size -): +def test_compute_mean_std_transform_iterable(tmp_path, shape, axes, patch_size): """Test that mean and std are computed and correctly added to the configuration and transform.""" - # create array - n_files = 3 - array = ordered_array(shape) - - # save three files + n_files = 100 files = [] + array = np.random.randint(0, np.iinfo(np.uint16).max, (n_files, *shape)) - # create test array with channel axis - if "C" not in axes: - stacked_array = np.stack([array] * n_files)[:, np.newaxis, ...] - else: - stacked_array = np.stack([array] * n_files) for i in range(n_files): file = tmp_path / f"array{i}.tif" - tifffile.imwrite(file, array) + tifffile.imwrite(file, array[i]) files.append(file) + array = array[:, np.newaxis, ...] 
if "C" not in axes else array + # create config config_dict = { "data_type": SupportedData.TIFF.value, @@ -181,7 +173,12 @@ def test_compute_mean_std_transform_iterable( data_config=config, src_files=files, read_source_func=read_tiff ) - axes = tuple(np.delete(np.arange(stacked_array.ndim), 1)) + # define axes for mean and std computation + stats_axes = tuple(np.delete(np.arange(array.ndim), 1)) - assert np.array_equal(stacked_array.mean(axis=axes), dataset.data_config.image_mean) - assert np.array_equal(stacked_array.std(axis=axes), dataset.data_config.image_std) + assert np.allclose( + array.mean(axis=stats_axes), dataset.data_config.image_mean + ) + assert np.allclose( + array.std(axis=stats_axes), dataset.data_config.image_std + ) From c9639c882398ed28f03644b363e1f7b2d8195d44 Mon Sep 17 00:00:00 2001 From: CatEek Date: Tue, 18 Jun 2024 18:41:38 +0200 Subject: [PATCH 2/7] types, docstrings --- .../dataset/dataset_utils/dataset_utils.py | 52 ++++++++++++++++--- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py index 5d524b943..b8305b829 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -123,27 +123,63 @@ def compute_normalization_stats(image: np.ndarray) -> Tuple[np.ndarray, np.ndarr return np.mean(image, axis=axes), np.std(image, axis=axes) -def update_iterative_stats(count, mean, m2, new_values): +def update_iterative_stats( + count: np.ndarray, mean: np.ndarray, m2: np.ndarray, new_values: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Update the mean and variance of an array iteratively. + + Parameters + ---------- + count : np.ndarray + Number of elements in the array. + mean : np.ndarray + Mean of the array. + m2 : np.ndarray + Variance of the array. 
+ + Returns + ------- + Tuple[np.ndarray, np.ndarray, np.ndarray] + Updated count, mean, and variance. + """ count += np.array([len(arr.flatten()) for arr in new_values]) # newvalues - oldMean delta = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] + np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] mean += np.array([np.sum(d / c) for d, c in zip(delta, count)]) # newvalues - newMeant delta2 = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] + np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)]) return (count, mean, m2) -def finalize_iterative_stats(count, mean, m2): +def finalize_iterative_stats( + count: np.ndarray, mean: np.ndarray, m2: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """Finalize the mean and variance computation. + + Parameters + ---------- + count : np.ndarray + Number of elements in the array. + mean : np.ndarray + Mean of the array. + m2 : np.ndarray + Variance of the array. + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Final mean and standard deviation. 
+ """ std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) if any(c < 2 for c in count): return float("nan"), float("nan") From 8bd270e9bee404d9c0296e99d7880d01ca1f3dc2 Mon Sep 17 00:00:00 2001 From: CatEek Date: Tue, 18 Jun 2024 18:54:12 +0200 Subject: [PATCH 3/7] target stats, test --- src/careamics/dataset/iterable_dataset.py | 58 ++++++++++++++------- tests/dataset/test_iterable_dataset.py | 62 +++++++++++++++++++++-- 2 files changed, 99 insertions(+), 21 deletions(-) diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 0dd277339..3426f86bb 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -17,10 +17,10 @@ from ..utils.logging import get_logger from .dataset_utils import ( compute_normalization_stats, + finalize_iterative_stats, iterate_over_files, read_tiff, update_iterative_stats, - finalize_iterative_stats, ) from .patching.patching import Stats, StatsOutput from .patching.random_patching import extract_patches_random @@ -131,38 +131,59 @@ def _calculate_mean_and_std(self) -> StatsOutput: PatchedOutput Data class containing the image statistics. 
""" - image_means = [] - image_stds = [] - target_means = [] - target_stds = [] num_samples = 0 for sample, target in iterate_over_files( self.data_config, self.data_files, self.target_files, self.read_source_func ): + # compute mean and std for each sample init_mean, _ = compute_normalization_stats(sample) # separate channels sample_channels = np.array(np.split(sample, sample.shape[1], axis=1)) - counts, m2s = [], [] if num_samples == 0: mean = init_mean - - count = np.array([np.prod(channel.shape) for channel in sample_channels]) - m2 = np.array([ - np.sum(np.subtract(channel.flatten(), [mean[i]] * count[i]) ** 2) - for i, channel in enumerate(sample_channels) - ]) - + count = np.array( + [np.prod(channel.shape) for channel in sample_channels] + ) + m2 = np.array( + [ + np.sum( + np.subtract(channel.flatten(), [mean[i]] * count[i]) ** 2 + ) + for i, channel in enumerate(sample_channels) + ] + ) else: count, mean, m2 = update_iterative_stats( count, mean, m2, sample_channels ) if target is not None: - target_mean, target_std = compute_normalization_stats(target) - target_means.append(target_mean) - target_stds.append(target_std) + target_init_mean, _ = compute_normalization_stats(target) + target_channels = np.array(np.split(target, target.shape[1], axis=1)) + if num_samples == 0: + target_mean = target_init_mean + target_count = np.array( + [np.prod(channel.shape) for channel in target_channels] + ) + target_m2 = np.array( + [ + np.sum( + np.subtract( + channel.flatten(), + [target_mean[i]] * target_count[i], + ) + ** 2 + ) + for i, channel in enumerate(target_channels) + ] + ) + + else: + target_count, target_mean, target_m2 = update_iterative_stats( + target_count, target_mean, target_m2, target_channels + ) num_samples += 1 @@ -173,8 +194,9 @@ def _calculate_mean_and_std(self) -> StatsOutput: image_means, image_stds = finalize_iterative_stats(count, mean, m2) if target is not None: - target_means = np.mean(target_means, axis=0) - target_stds = 
np.sqrt(np.mean([std**2 for std in target_stds], axis=0)) + target_means, target_stds = finalize_iterative_stats( + target_count, target_mean, target_m2 + ) logger.info(f"Calculated mean and std for {num_samples} images") logger.info(f"Mean: {image_means}, std: {image_stds}") diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index dd49910dd..dac4da735 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest import tifffile @@ -152,7 +154,6 @@ def test_compute_mean_std_transform_iterable(tmp_path, shape, axes, patch_size): files = [] array = np.random.randint(0, np.iinfo(np.uint16).max, (n_files, *shape)) - for i in range(n_files): file = tmp_path / f"array{i}.tif" tifffile.imwrite(file, array[i]) @@ -176,9 +177,64 @@ def test_compute_mean_std_transform_iterable(tmp_path, shape, axes, patch_size): # define axes for mean and std computation stats_axes = tuple(np.delete(np.arange(array.ndim), 1)) + assert np.allclose(array.mean(axis=stats_axes), dataset.data_config.image_mean) + assert np.allclose(array.std(axis=stats_axes), dataset.data_config.image_std) + + +@pytest.mark.parametrize( + "shape, axes, patch_size", + [ + ((32, 32), "YX", (8, 8)), + ((2, 32, 32), "CYX", (8, 8)), + ((32, 32, 32), "ZYX", (8, 8, 8)), + ], +) +def test_compute_mean_std_transform_iterable_with_targets( + tmp_path, shape, axes, patch_size +): + """Test that mean and std are computed and correctly added to the configuration + and transform.""" + n_files = 100 + files = [] + target_files = [] + array = np.random.randint(0, np.iinfo(np.uint16).max, (n_files, *shape)) + target_array = np.random.randint(0, np.iinfo(np.uint16).max, (n_files, *shape)) + + for i in range(n_files): + file = tmp_path / "images" / f"array{i}.tif" + target_file = tmp_path / "targets" / f"array{i}.tif" + os.makedirs(file.parent, exist_ok=True) + os.makedirs(target_file.parent, 
exist_ok=True) + tifffile.imwrite(file, array[i]) + tifffile.imwrite(target_file, target_array[i]) + files.append(file) + target_files.append(target_file) + + array = array[:, np.newaxis, ...] if "C" not in axes else array + target_array = target_array[:, np.newaxis, ...] if "C" not in axes else target_array + + # create config + config_dict = { + "data_type": SupportedData.TIFF.value, + "patch_size": patch_size, + "axes": axes, + } + config = DataConfig(**config_dict) + + # create dataset + dataset = PathIterableDataset( + data_config=config, + src_files=files, + target_files=target_files, + read_source_func=read_tiff, + ) + + # define axes for mean and std computation + stats_axes = tuple(np.delete(np.arange(array.ndim), 1)) + assert np.allclose( - array.mean(axis=stats_axes), dataset.data_config.image_mean + target_array.mean(axis=stats_axes), dataset.data_config.target_mean ) assert np.allclose( - array.std(axis=stats_axes), dataset.data_config.image_std + target_array.std(axis=stats_axes), dataset.data_config.target_std ) From a127edf5912ecf9f381574dd855d0b852853a73a Mon Sep 17 00:00:00 2001 From: CatEek Date: Wed, 19 Jun 2024 11:57:11 +0200 Subject: [PATCH 4/7] mypy fix --- src/careamics/dataset/dataset_utils/dataset_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py b/src/careamics/dataset/dataset_utils/dataset_utils.py index b8305b829..29d244b29 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -136,6 +136,8 @@ def update_iterative_stats( Mean of the array. m2 : np.ndarray Variance of the array. + new_values : np.ndarray + New values to add to the mean and variance. 
Returns ------- @@ -182,6 +184,6 @@ def finalize_iterative_stats( """ std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) if any(c < 2 for c in count): - return float("nan"), float("nan") + return np.full(mean.shape, np.nan), np.full(std.shape, np.nan) else: return mean, std From 4ccdc25e9ab6e618830309c41d8daba432dfcc26 Mon Sep 17 00:00:00 2001 From: CatEek Date: Mon, 24 Jun 2024 12:12:36 +0200 Subject: [PATCH 5/7] stats calculation as a class refac --- .../dataset/dataset_utils/__init__.py | 5 +- .../dataset/dataset_utils/dataset_utils.py | 88 --------- .../dataset/dataset_utils/running_stats.py | 183 ++++++++++++++++++ src/careamics/dataset/iterable_dataset.py | 63 +----- src/careamics/dataset/patching/patching.py | 3 +- src/careamics/utils/running_stats.py | 43 ---- .../test_compute_normalization_stats.py | 2 +- tests/transforms/test_compose.py | 2 +- tests/transforms/test_normalize.py | 2 +- 9 files changed, 200 insertions(+), 191 deletions(-) create mode 100644 src/careamics/dataset/dataset_utils/running_stats.py delete mode 100644 src/careamics/utils/running_stats.py diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py index 667ddf443..35fe1a75d 100644 --- a/src/careamics/dataset/dataset_utils/__init__.py +++ b/src/careamics/dataset/dataset_utils/__init__.py @@ -12,17 +12,16 @@ "get_read_func", "read_zarr", "iterate_over_files", + "WelfordStatistics", ] from .dataset_utils import ( - compute_normalization_stats, - finalize_iterative_stats, reshape_array, - update_iterative_stats, ) from .file_utils import get_files_size, list_files, validate_source_target_files from .iterate_over_files import iterate_over_files from .read_tiff import read_tiff from .read_utils import get_read_func from .read_zarr import read_zarr +from .running_stats import WelfordStatistics, compute_normalization_stats diff --git a/src/careamics/dataset/dataset_utils/dataset_utils.py 
b/src/careamics/dataset/dataset_utils/dataset_utils.py index 29d244b29..ebaed0d46 100644 --- a/src/careamics/dataset/dataset_utils/dataset_utils.py +++ b/src/careamics/dataset/dataset_utils/dataset_utils.py @@ -99,91 +99,3 @@ def reshape_array(x: np.ndarray, axes: str) -> np.ndarray: _x = np.expand_dims(_x, new_axes.index("S") + 1) return _x - - -def compute_normalization_stats(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: - """ - Compute mean and standard deviation of an array. - - Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are - computed per channel. - - Parameters - ---------- - image : np.ndarray - Input array. - - Returns - ------- - Tuple[List[float], List[float]] - Lists of mean and standard deviation values per channel. - """ - # Define the list of axes excluding the channel axis - axes = tuple(np.delete(np.arange(image.ndim), 1)) - return np.mean(image, axis=axes), np.std(image, axis=axes) - - -def update_iterative_stats( - count: np.ndarray, mean: np.ndarray, m2: np.ndarray, new_values: np.ndarray -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Update the mean and variance of an array iteratively. - - Parameters - ---------- - count : np.ndarray - Number of elements in the array. - mean : np.ndarray - Mean of the array. - m2 : np.ndarray - Variance of the array. - new_values : np.ndarray - New values to add to the mean and variance. - - Returns - ------- - Tuple[np.ndarray, np.ndarray, np.ndarray] - Updated count, mean, and variance. 
- """ - count += np.array([len(arr.flatten()) for arr in new_values]) - # newvalues - oldMean - delta = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] - - mean += np.array([np.sum(d / c) for d, c in zip(delta, count)]) - # newvalues - newMeant - delta2 = [ - np.subtract(v.flatten(), [m] * len(v.flatten())) - for v, m in zip(new_values, mean) - ] - - m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)]) - - return (count, mean, m2) - - -def finalize_iterative_stats( - count: np.ndarray, mean: np.ndarray, m2: np.ndarray -) -> Tuple[np.ndarray, np.ndarray]: - """Finalize the mean and variance computation. - - Parameters - ---------- - count : np.ndarray - Number of elements in the array. - mean : np.ndarray - Mean of the array. - m2 : np.ndarray - Variance of the array. - - Returns - ------- - Tuple[np.ndarray, np.ndarray] - Final mean and standard deviation. - """ - std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) - if any(c < 2 for c in count): - return np.full(mean.shape, np.nan), np.full(std.shape, np.nan) - else: - return mean, std diff --git a/src/careamics/dataset/dataset_utils/running_stats.py b/src/careamics/dataset/dataset_utils/running_stats.py new file mode 100644 index 000000000..6977c1478 --- /dev/null +++ b/src/careamics/dataset/dataset_utils/running_stats.py @@ -0,0 +1,183 @@ +"""Computing data statistics.""" + +import numpy as np +from numpy.typing import NDArray + + +def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]: + """ + Compute mean and standard deviation of an array. + + Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are + computed per channel. + + Parameters + ---------- + image : NDArray + Input array. + + Returns + ------- + tuple[List[float], List[float]] + Lists of mean and standard deviation values per channel. 
+ """ + # Define the list of axes excluding the channel axis + axes = tuple(np.delete(np.arange(image.ndim), 1)) + return np.mean(image, axis=axes), np.std(image, axis=axes) + + +def update_iterative_stats( + count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray +) -> tuple[NDArray, NDArray, NDArray]: + """Update the mean and variance of an array iteratively. + + Parameters + ---------- + count : NDArray + Number of elements in the array. + mean : NDArray + Mean of the array. + m2 : NDArray + Variance of the array. + new_values : NDArray + New values to add to the mean and variance. + + Returns + ------- + tuple[NDArray, NDArray, NDArray] + Updated count, mean, and variance. + """ + count += np.array([len(arr.flatten()) for arr in new_values]) + # newvalues - oldMean + delta = [ + np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] + + mean += np.array([np.sum(d / c) for d, c in zip(delta, count)]) + # newvalues - newMeant + delta2 = [ + np.subtract(v.flatten(), [m] * len(v.flatten())) + for v, m in zip(new_values, mean) + ] + + m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)]) + + return (count, mean, m2) + + +def finalize_iterative_stats( + count: NDArray, mean: NDArray, m2: NDArray +) -> tuple[NDArray, NDArray]: + """Finalize the mean and variance computation. + + Parameters + ---------- + count : NDArray + Number of elements in the array. + mean : NDArray + Mean of the array. + m2 : NDArray + Variance of the array. + + Returns + ------- + tuple[NDArray, NDArray] + Final mean and standard deviation. + """ + std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)]) + if any(c < 2 for c in count): + return np.full(mean.shape, np.nan), np.full(std.shape, np.nan) + else: + return mean, std + + +class WelfordStatistics: + """Compute Welford statistics iteratively.""" + + def update(self, array: NDArray, num_samples: int) -> None: + """Update the Welford statistics. 
+ + Parameters + ---------- + array : NDArray + Input array. + num_samples : int + Current sample number. + """ + self.num_samples = num_samples + sample_channels = np.array(np.split(array, array.shape[1], axis=1)) + + if self.num_samples == 0: + self.mean, _ = compute_normalization_stats(array) + self.count = np.array( + [np.prod(channel.shape) for channel in sample_channels] + ) + self.m2 = np.array( + [ + np.sum( + np.subtract(channel.flatten(), [self.mean[i]] * self.count[i]) + ** 2 + ) + for i, channel in enumerate(sample_channels) + ] + ) + else: + self.count, self.mean, self.m2 = update_iterative_stats( + self.count, self.mean, self.m2, sample_channels + ) + + self.num_samples += 1 + + def finalize(self) -> tuple[NDArray, NDArray]: + """Finalize the Welford statistics. + + Returns + ------- + tuple[NDArray, NDArray] + Final mean and standard deviation. + """ + return finalize_iterative_stats(self.count, self.mean, self.m2) + + +# from multiprocessing import Value +# from typing import tuple + +# import numpy as np + + +# class RunningStats: +# """Calculates running mean and std.""" + +# def __init__(self) -> None: +# self.reset() + +# def reset(self) -> None: +# """Reset the running stats.""" +# self.avg_mean = Value("d", 0) +# self.avg_std = Value("d", 0) +# self.m2 = Value("d", 0) +# self.count = Value("i", 0) + +# def init(self, mean: float, std: float) -> None: +# """Initialize running stats.""" +# with self.avg_mean.get_lock(): +# self.avg_mean.value += mean +# with self.avg_std.get_lock(): +# self.avg_std.value = std + +# def compute_std(self) -> tuple[float, float]: +# """Compute std.""" +# if self.count.value >= 2: +# self.avg_std.value = np.sqrt(self.m2.value / self.count.value) + +# def update(self, value: float) -> None: +# """Update running stats.""" +# with self.count.get_lock(): +# self.count.value += 1 +# delta = value - self.avg_mean.value +# with self.avg_mean.get_lock(): +# self.avg_mean.value += delta / self.count.value +# delta2 = value - 
self.avg_mean.value +# with self.m2.get_lock(): +# self.m2.value += delta * delta2 diff --git a/src/careamics/dataset/iterable_dataset.py b/src/careamics/dataset/iterable_dataset.py index 3426f86bb..539980d11 100644 --- a/src/careamics/dataset/iterable_dataset.py +++ b/src/careamics/dataset/iterable_dataset.py @@ -16,12 +16,10 @@ from ..utils.logging import get_logger from .dataset_utils import ( - compute_normalization_stats, - finalize_iterative_stats, iterate_over_files, read_tiff, - update_iterative_stats, ) +from .dataset_utils.running_stats import WelfordStatistics from .patching.patching import Stats, StatsOutput from .patching.random_patching import extract_patches_random @@ -132,58 +130,19 @@ def _calculate_mean_and_std(self) -> StatsOutput: Data class containing the image statistics. """ num_samples = 0 + image_stats = WelfordStatistics() + if self.target_files is not None: + target_stats = WelfordStatistics() for sample, target in iterate_over_files( self.data_config, self.data_files, self.target_files, self.read_source_func ): - # compute mean and std for each sample - init_mean, _ = compute_normalization_stats(sample) - # separate channels - sample_channels = np.array(np.split(sample, sample.shape[1], axis=1)) - - if num_samples == 0: - mean = init_mean - count = np.array( - [np.prod(channel.shape) for channel in sample_channels] - ) - m2 = np.array( - [ - np.sum( - np.subtract(channel.flatten(), [mean[i]] * count[i]) ** 2 - ) - for i, channel in enumerate(sample_channels) - ] - ) - else: - count, mean, m2 = update_iterative_stats( - count, mean, m2, sample_channels - ) + # update the image statistics + image_stats.update(sample, num_samples) + # update the target statistics if target is available if target is not None: - target_init_mean, _ = compute_normalization_stats(target) - target_channels = np.array(np.split(target, target.shape[1], axis=1)) - if num_samples == 0: - target_mean = target_init_mean - target_count = np.array( - 
[np.prod(channel.shape) for channel in target_channels] - ) - target_m2 = np.array( - [ - np.sum( - np.subtract( - channel.flatten(), - [target_mean[i]] * target_count[i], - ) - ** 2 - ) - for i, channel in enumerate(target_channels) - ] - ) - - else: - target_count, target_mean, target_m2 = update_iterative_stats( - target_count, target_mean, target_m2, target_channels - ) + target_stats.update(target, num_samples) num_samples += 1 @@ -191,12 +150,10 @@ def _calculate_mean_and_std(self) -> StatsOutput: raise ValueError("No samples found in the dataset.") # Average the means and stds per sample - image_means, image_stds = finalize_iterative_stats(count, mean, m2) + image_means, image_stds = image_stats.finalize() if target is not None: - target_means, target_stds = finalize_iterative_stats( - target_count, target_mean, target_m2 - ) + target_means, target_stds = target_stats.finalize() logger.info(f"Calculated mean and std for {num_samples} images") logger.info(f"Mean: {image_means}, std: {image_stds}") diff --git a/src/careamics/dataset/patching/patching.py b/src/careamics/dataset/patching/patching.py index 50d6d0c1e..dcd816444 100644 --- a/src/careamics/dataset/patching/patching.py +++ b/src/careamics/dataset/patching/patching.py @@ -7,7 +7,8 @@ import numpy as np from ...utils.logging import get_logger -from ..dataset_utils import compute_normalization_stats, reshape_array +from ..dataset_utils import reshape_array +from ..dataset_utils.running_stats import compute_normalization_stats from .sequential_patching import extract_patches_sequential logger = get_logger(__name__) diff --git a/src/careamics/utils/running_stats.py b/src/careamics/utils/running_stats.py deleted file mode 100644 index 1268d3e43..000000000 --- a/src/careamics/utils/running_stats.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Running stats submodule, used in the Zarr dataset.""" - -# from multiprocessing import Value -# from typing import Tuple - -# import numpy as np - - -# class RunningStats: -# 
"""Calculates running mean and std.""" - -# def __init__(self) -> None: -# self.reset() - -# def reset(self) -> None: -# """Reset the running stats.""" -# self.avg_mean = Value("d", 0) -# self.avg_std = Value("d", 0) -# self.m2 = Value("d", 0) -# self.count = Value("i", 0) - -# def init(self, mean: float, std: float) -> None: -# """Initialize running stats.""" -# with self.avg_mean.get_lock(): -# self.avg_mean.value += mean -# with self.avg_std.get_lock(): -# self.avg_std.value = std - -# def compute_std(self) -> Tuple[float, float]: -# """Compute std.""" -# if self.count.value >= 2: -# self.avg_std.value = np.sqrt(self.m2.value / self.count.value) - -# def update(self, value: float) -> None: -# """Update running stats.""" -# with self.count.get_lock(): -# self.count.value += 1 -# delta = value - self.avg_mean.value -# with self.avg_mean.get_lock(): -# self.avg_mean.value += delta / self.count.value -# delta2 = value - self.avg_mean.value -# with self.m2.get_lock(): -# self.m2.value += delta * delta2 diff --git a/tests/dataset/dataset_utils/test_compute_normalization_stats.py b/tests/dataset/dataset_utils/test_compute_normalization_stats.py index cf6a157ee..d03e527db 100644 --- a/tests/dataset/dataset_utils/test_compute_normalization_stats.py +++ b/tests/dataset/dataset_utils/test_compute_normalization_stats.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from careamics.dataset.dataset_utils import compute_normalization_stats +from careamics.dataset.dataset_utils.running_stats import compute_normalization_stats @pytest.mark.parametrize("samples, channels", [[1, 2], [1, 2]]) diff --git a/tests/transforms/test_compose.py b/tests/transforms/test_compose.py index de7dcab9f..12f04e2e2 100644 --- a/tests/transforms/test_compose.py +++ b/tests/transforms/test_compose.py @@ -7,7 +7,7 @@ XYFlipModel, XYRandomRotate90Model, ) -from careamics.dataset.dataset_utils import compute_normalization_stats +from careamics.dataset.dataset_utils.running_stats import 
compute_normalization_stats from careamics.transforms import Compose, Normalize, XYFlip, XYRandomRotate90 diff --git a/tests/transforms/test_normalize.py b/tests/transforms/test_normalize.py index ee8303eba..98ec36420 100644 --- a/tests/transforms/test_normalize.py +++ b/tests/transforms/test_normalize.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from careamics.dataset.dataset_utils import compute_normalization_stats +from careamics.dataset.dataset_utils.running_stats import compute_normalization_stats from careamics.transforms import Denormalize, Normalize from careamics.transforms.normalize import _reshape_stats From 7edcba2a89ea941c23776ef458b26a6b9ea3c01b Mon Sep 17 00:00:00 2001 From: jdeschamps <6367888+jdeschamps@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:24:00 +0200 Subject: [PATCH 6/7] fix: fix calls to parameters in tests --- tests/dataset/test_iterable_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index dac4da735..e9f5db0a8 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -177,8 +177,8 @@ def test_compute_mean_std_transform_iterable(tmp_path, shape, axes, patch_size): # define axes for mean and std computation stats_axes = tuple(np.delete(np.arange(array.ndim), 1)) - assert np.allclose(array.mean(axis=stats_axes), dataset.data_config.image_mean) - assert np.allclose(array.std(axis=stats_axes), dataset.data_config.image_std) + assert np.allclose(array.mean(axis=stats_axes), dataset.data_config.image_means) + assert np.allclose(array.std(axis=stats_axes), dataset.data_config.image_stds) @pytest.mark.parametrize( @@ -233,8 +233,8 @@ def test_compute_mean_std_transform_iterable_with_targets( stats_axes = tuple(np.delete(np.arange(array.ndim), 1)) assert np.allclose( - target_array.mean(axis=stats_axes), dataset.data_config.target_mean + target_array.mean(axis=stats_axes), 
dataset.data_config.target_means ) assert np.allclose( - target_array.std(axis=stats_axes), dataset.data_config.target_std + target_array.std(axis=stats_axes), dataset.data_config.target_stds ) From d812971212ebae8a7af0f33ddad3c134b2702f2c Mon Sep 17 00:00:00 2001 From: CatEek Date: Mon, 24 Jun 2024 14:58:18 +0200 Subject: [PATCH 7/7] rename test, minor refac --- .../dataset/dataset_utils/__init__.py | 2 - .../dataset/dataset_utils/running_stats.py | 45 ++++++++++--------- tests/dataset/test_iterable_dataset.py | 4 +- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/careamics/dataset/dataset_utils/__init__.py b/src/careamics/dataset/dataset_utils/__init__.py index 35fe1a75d..b6a626aaf 100644 --- a/src/careamics/dataset/dataset_utils/__init__.py +++ b/src/careamics/dataset/dataset_utils/__init__.py @@ -3,8 +3,6 @@ __all__ = [ "reshape_array", "compute_normalization_stats", - "update_iterative_stats", - "finalize_iterative_stats", "get_files_size", "list_files", "validate_source_target_files", diff --git a/src/careamics/dataset/dataset_utils/running_stats.py b/src/careamics/dataset/dataset_utils/running_stats.py index 6977c1478..5ee40abd5 100644 --- a/src/careamics/dataset/dataset_utils/running_stats.py +++ b/src/careamics/dataset/dataset_utils/running_stats.py @@ -18,7 +18,7 @@ def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]: Returns ------- - tuple[List[float], List[float]] + tuple of (list of floats, list of floats) Lists of mean and standard deviation values per channel. """ # Define the list of axes excluding the channel axis @@ -47,7 +47,7 @@ def update_iterative_stats( tuple[NDArray, NDArray, NDArray] Updated count, mean, and variance. 
""" - count += np.array([len(arr.flatten()) for arr in new_values]) + count += np.array([np.prod(channel.shape) for channel in new_values]) # newvalues - oldMean delta = [ np.subtract(v.flatten(), [m] * len(v.flatten())) @@ -93,48 +93,51 @@ def finalize_iterative_stats( class WelfordStatistics: - """Compute Welford statistics iteratively.""" + """Compute Welford statistics iteratively. - def update(self, array: NDArray, num_samples: int) -> None: + The Welford algorithm is used to compute the mean and variance of an array + iteratively. Based on the implementation from: + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + """ + + def update(self, array: NDArray, sample_idx: int) -> None: """Update the Welford statistics. Parameters ---------- array : NDArray Input array. - num_samples : int + sample_idx : int Current sample number. """ - self.num_samples = num_samples + self.sample_idx = sample_idx sample_channels = np.array(np.split(array, array.shape[1], axis=1)) - if self.num_samples == 0: + # Initialize the statistics + if self.sample_idx == 0: + # Compute the mean and standard deviation self.mean, _ = compute_normalization_stats(array) - self.count = np.array( - [np.prod(channel.shape) for channel in sample_channels] - ) - self.m2 = np.array( - [ - np.sum( - np.subtract(channel.flatten(), [self.mean[i]] * self.count[i]) - ** 2 - ) - for i, channel in enumerate(sample_channels) - ] + # Initialize the count and m2 with zero-valued arrays of shape (C,) + self.count, self.mean, self.m2 = update_iterative_stats( + count=np.zeros(array.shape[1]), + mean=self.mean, + m2=np.zeros(array.shape[1]), + new_values=sample_channels, ) else: + # Update the statistics self.count, self.mean, self.m2 = update_iterative_stats( - self.count, self.mean, self.m2, sample_channels + count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels ) - self.num_samples += 1 + self.sample_idx += 1 def finalize(self) -> tuple[NDArray, 
NDArray]: """Finalize the Welford statistics. Returns ------- - tuple[NDArray, NDArray] + tuple of numpy arrays Final mean and standard deviation. """ return finalize_iterative_stats(self.count, self.mean, self.m2) diff --git a/tests/dataset/test_iterable_dataset.py b/tests/dataset/test_iterable_dataset.py index e9f5db0a8..d3e90febb 100644 --- a/tests/dataset/test_iterable_dataset.py +++ b/tests/dataset/test_iterable_dataset.py @@ -147,7 +147,7 @@ def test_extracting_val_files(tmp_path, ordered_array, percentage): ((32, 32, 32), "ZYX", (8, 8, 8)), ], ) -def test_compute_mean_std_transform_iterable(tmp_path, shape, axes, patch_size): +def test_compute_mean_std_transform_welford(tmp_path, shape, axes, patch_size): """Test that mean and std are computed and correctly added to the configuration and transform.""" n_files = 100 @@ -189,7 +189,7 @@ def test_compute_mean_std_transform_iterable(tmp_path, shape, axes, patch_size): ((32, 32, 32), "ZYX", (8, 8, 8)), ], ) -def test_compute_mean_std_transform_iterable_with_targets( +def test_compute_mean_std_transform_welford_with_targets( tmp_path, shape, axes, patch_size ): """Test that mean and std are computed and correctly added to the configuration