Feature: Welford std approximation #153

Merged · 9 commits · Jun 24, 2024
6 changes: 5 additions & 1 deletion src/careamics/dataset/dataset_utils/__init__.py
@@ -10,12 +10,16 @@
"get_read_func",
"read_zarr",
"iterate_over_files",
"WelfordStatistics",
]


from .dataset_utils import compute_normalization_stats, reshape_array
from .dataset_utils import (
reshape_array,
)
from .file_utils import get_files_size, list_files, validate_source_target_files
from .iterate_over_files import iterate_over_files
from .read_tiff import read_tiff
from .read_utils import get_read_func
from .read_zarr import read_zarr
from .running_stats import WelfordStatistics, compute_normalization_stats
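
With this change, `compute_normalization_stats` moves to the new `running_stats` module but is still re-exported from the `dataset_utils` package, and `WelfordStatistics` is exposed alongside it. A minimal import sketch, based on the `__init__.py` above:

# Both names resolve through the dataset_utils package after this PR.
from careamics.dataset.dataset_utils import WelfordStatistics, compute_normalization_stats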
22 changes: 0 additions & 22 deletions src/careamics/dataset/dataset_utils/dataset_utils.py
@@ -99,25 +99,3 @@ def reshape_array(x: np.ndarray, axes: str) -> np.ndarray:
_x = np.expand_dims(_x, new_axes.index("S") + 1)

return _x


def compute_normalization_stats(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Compute mean and standard deviation of an array.

Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are
computed per channel.

Parameters
----------
image : np.ndarray
Input array.

Returns
-------
Tuple[List[float], List[float]]
Lists of mean and standard deviation values per channel.
"""
# Define the list of axes excluding the channel axis
axes = tuple(np.delete(np.arange(image.ndim), 1))
return np.mean(image, axis=axes), np.std(image, axis=axes)
186 changes: 186 additions & 0 deletions src/careamics/dataset/dataset_utils/running_stats.py
@@ -0,0 +1,186 @@
"""Computing data statistics."""

import numpy as np
from numpy.typing import NDArray


def compute_normalization_stats(image: NDArray) -> tuple[NDArray, NDArray]:
"""
Compute mean and standard deviation of an array.

Expected input shape is (S, C, (Z), Y, X). The mean and standard deviation are
computed per channel.

Parameters
----------
image : NDArray
Input array.

Returns
-------
tuple of (NDArray, NDArray)
Arrays of mean and standard deviation values per channel.
"""
# Define the list of axes excluding the channel axis
axes = tuple(np.delete(np.arange(image.ndim), 1))
return np.mean(image, axis=axes), np.std(image, axis=axes)


def update_iterative_stats(
count: NDArray, mean: NDArray, m2: NDArray, new_values: NDArray
) -> tuple[NDArray, NDArray, NDArray]:
"""Update the mean and variance of an array iteratively.

Parameters
----------
count : NDArray
Number of elements in the array.
mean : NDArray
Mean of the array.
m2 : NDArray
Sum of squared deviations from the current mean (the Welford M2 aggregate).
new_values : NDArray
New values to add to the mean and variance.

Returns
-------
tuple[NDArray, NDArray, NDArray]
Updated count, mean, and M2 aggregate.
"""
count += np.array([np.prod(channel.shape) for channel in new_values])
# newvalues - oldMean
delta = [
np.subtract(v.flatten(), [m] * len(v.flatten()))
for v, m in zip(new_values, mean)
]

mean += np.array([np.sum(d / c) for d, c in zip(delta, count)])
# newvalues - newMean
delta2 = [
np.subtract(v.flatten(), [m] * len(v.flatten()))
for v, m in zip(new_values, mean)
]

m2 += np.array([np.sum(d * d2) for d, d2 in zip(delta, delta2)])

return (count, mean, m2)


def finalize_iterative_stats(
count: NDArray, mean: NDArray, m2: NDArray
) -> tuple[NDArray, NDArray]:
"""Finalize the mean and variance computation.

Parameters
----------
count : NDArray
Number of elements in the array.
mean : NDArray
Mean of the array.
m2 : NDArray
Sum of squared deviations from the current mean (the Welford M2 aggregate).

Returns
-------
tuple[NDArray, NDArray]
Final mean and standard deviation.
"""
std = np.array([np.sqrt(m / c) for m, c in zip(m2, count)])
if any(c < 2 for c in count):
return np.full(mean.shape, np.nan), np.full(std.shape, np.nan)
else:
return mean, std


class WelfordStatistics:
"""Compute Welford statistics iteratively.

The Welford algorithm is used to compute the mean and variance of an array
iteratively. Based on the implementation from:
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
"""

def update(self, array: NDArray, sample_idx: int) -> None:
"""Update the Welford statistics.

Parameters
----------
array : NDArray
Input array.
sample_idx : int
Current sample number.
"""
self.sample_idx = sample_idx
sample_channels = np.array(np.split(array, array.shape[1], axis=1))

# Initialize the statistics
if self.sample_idx == 0:
# Compute the mean and standard deviation
self.mean, _ = compute_normalization_stats(array)
# Initialize the count and m2 with zero-valued arrays of shape (C,)
self.count, self.mean, self.m2 = update_iterative_stats(
count=np.zeros(array.shape[1]),
mean=self.mean,
m2=np.zeros(array.shape[1]),
new_values=sample_channels,
)
else:
# Update the statistics
self.count, self.mean, self.m2 = update_iterative_stats(
count=self.count, mean=self.mean, m2=self.m2, new_values=sample_channels
)

self.sample_idx += 1

def finalize(self) -> tuple[NDArray, NDArray]:
"""Finalize the Welford statistics.

Returns
-------
tuple of numpy arrays
Final mean and standard deviation.
"""
return finalize_iterative_stats(self.count, self.mean, self.m2)


# from multiprocessing import Value
# from typing import tuple

# import numpy as np


# class RunningStats:
# """Calculates running mean and std."""

# def __init__(self) -> None:
# self.reset()

# def reset(self) -> None:
# """Reset the running stats."""
# self.avg_mean = Value("d", 0)
# self.avg_std = Value("d", 0)
# self.m2 = Value("d", 0)
# self.count = Value("i", 0)

# def init(self, mean: float, std: float) -> None:
# """Initialize running stats."""
# with self.avg_mean.get_lock():
# self.avg_mean.value += mean
# with self.avg_std.get_lock():
# self.avg_std.value = std

# def compute_std(self) -> tuple[float, float]:
# """Compute std."""
# if self.count.value >= 2:
# self.avg_std.value = np.sqrt(self.m2.value / self.count.value)

# def update(self, value: float) -> None:
# """Update running stats."""
# with self.count.get_lock():
# self.count.value += 1
# delta = value - self.avg_mean.value
# with self.avg_mean.get_lock():
# self.avg_mean.value += delta / self.count.value
# delta2 = value - self.avg_mean.value
# with self.m2.get_lock():
# self.m2.value += delta * delta2
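
Taken together, the module implements the standard Welford/Chan scheme: update_iterative_stats folds each new sample into per-channel count, mean, and M2 aggregates, and finalize_iterative_stats converts M2 / count into a standard deviation. A minimal usage sketch of the public class (illustrative only; the shape follows the documented (S, C, (Z), Y, X) convention):

import numpy as np

from careamics.dataset.dataset_utils.running_stats import WelfordStatistics

rng = np.random.default_rng(0)
stats = WelfordStatistics()

# Feed samples one at a time, e.g. while streaming files from disk.
for idx in range(3):
    sample = rng.normal(loc=5.0, scale=2.0, size=(1, 2, 64, 64))  # one sample, two channels
    stats.update(sample, idx)

mean, std = stats.finalize()  # per-channel arrays of shape (2,)

Because the M2 aggregates combine exactly, the finalized values match the statistics of all samples pooled together while only ever holding one sample in memory.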
29 changes: 14 additions & 15 deletions src/careamics/dataset/iterable_dataset.py
@@ -15,7 +15,11 @@
from careamics.transforms import Compose

from ..utils.logging import get_logger
from .dataset_utils import compute_normalization_stats, iterate_over_files, read_tiff
from .dataset_utils import (
iterate_over_files,
read_tiff,
)
from .dataset_utils.running_stats import WelfordStatistics
from .patching.patching import Stats, StatsOutput
from .patching.random_patching import extract_patches_random

@@ -125,36 +129,31 @@ def _calculate_mean_and_std(self) -> StatsOutput:
PatchedOutput
Data class containing the image statistics.
"""
image_means = []
image_stds = []
target_means = []
target_stds = []
num_samples = 0
image_stats = WelfordStatistics()
if self.target_files is not None:
target_stats = WelfordStatistics()

for sample, target in iterate_over_files(
self.data_config, self.data_files, self.target_files, self.read_source_func
):
sample_mean, sample_std = compute_normalization_stats(sample)
image_means.append(sample_mean)
image_stds.append(sample_std)
# update the image statistics
image_stats.update(sample, num_samples)

# update the target statistics if target is available
if target is not None:
target_mean, target_std = compute_normalization_stats(target)
target_means.append(target_mean)
target_stds.append(target_std)
target_stats.update(target, num_samples)

num_samples += 1

if num_samples == 0:
raise ValueError("No samples found in the dataset.")

# Average the means and stds per sample
image_means = np.mean(image_means, axis=0)
image_stds = np.sqrt(np.mean([std**2 for std in image_stds], axis=0))
image_means, image_stds = image_stats.finalize()

if target is not None:
target_means = np.mean(target_means, axis=0)
target_stds = np.sqrt(np.mean([std**2 for std in target_stds], axis=0))
target_means, target_stds = target_stats.finalize()

logger.info(f"Calculated mean and std for {num_samples} images")
logger.info(f"Mean: {image_means}, std: {image_stds}")
3 changes: 2 additions & 1 deletion src/careamics/dataset/patching/patching.py
@@ -7,7 +7,8 @@
import numpy as np

from ...utils.logging import get_logger
from ..dataset_utils import compute_normalization_stats, reshape_array
from ..dataset_utils import reshape_array
from ..dataset_utils.running_stats import compute_normalization_stats
from .sequential_patching import extract_patches_sequential

logger = get_logger(__name__)
43 changes: 0 additions & 43 deletions src/careamics/utils/running_stats.py

This file was deleted.

@@ -1,7 +1,7 @@
import numpy as np
import pytest

from careamics.dataset.dataset_utils import compute_normalization_stats
from careamics.dataset.dataset_utils.running_stats import compute_normalization_stats


@pytest.mark.parametrize("samples, channels", [[1, 2], [1, 2]])
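
The test diff above only tracks the moved import. A consistency check along the following lines could additionally exercise the new class (a sketch, not part of this PR): the finalized Welford statistics should match per-channel statistics computed over all samples at once.

import numpy as np

from careamics.dataset.dataset_utils import WelfordStatistics

def test_welford_matches_pooled_stats():
    rng = np.random.default_rng(42)
    samples = [rng.normal(loc=5.0, size=(1, 3, 32, 32)) for _ in range(4)]

    stats = WelfordStatistics()
    for i, sample in enumerate(samples):
        stats.update(sample, i)
    mean, std = stats.finalize()

    stacked = np.concatenate(samples, axis=0)  # shape (S, C, Y, X)
    np.testing.assert_allclose(mean, stacked.mean(axis=(0, 2, 3)), rtol=1e-6)
    np.testing.assert_allclose(std, stacked.std(axis=(0, 2, 3)), rtol=1e-6)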