Feature: Welford std approximation #153
def update_iterative_stats(
    count: np.ndarray, mean: np.ndarray, m2: np.ndarray, new_values: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Update the mean and variance of an array iteratively.
The method should be referenced in the docstring, or at least in comments.
Agreed, there should be a mention of the Welford algorithm and maybe a link to the wiki page (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance).
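For illustration, a docstring along these lines would cover both points. This is only a sketch, not the code from the PR: the function body, the assumption that `new_values` has its channel axis first, and the choice of population variance (`m2 / count`) are all mine.

```python
from typing import Tuple

import numpy as np


def update_iterative_stats(
    count: np.ndarray, mean: np.ndarray, m2: np.ndarray, new_values: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Update per-channel running statistics with Welford's algorithm.

    See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance

    Assumption for this sketch: `new_values` has shape (C, ...), while
    `count`, `mean` and `m2` have shape (C,). The population variance is
    `m2 / count` once all batches have been seen.
    """
    # Flatten everything but the channel axis: (C, N)
    flat = new_values.reshape(new_values.shape[0], -1).astype(np.float64)
    new_count = count + flat.shape[1]
    # Deviation from the old mean, then from the updated mean
    delta = flat - mean[:, None]
    new_mean = mean + delta.sum(axis=1) / new_count
    delta2 = flat - new_mean[:, None]
    new_m2 = m2 + (delta * delta2).sum(axis=1)
    return new_count, new_mean, new_m2
```

Starting from zero-filled `count`, `mean` and `m2`, calling this once per batch and finishing with `np.sqrt(m2 / count)` should agree with `np.std` over the full data up to floating-point error.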
@@ -13,7 +15,12 @@
]

from .dataset_utils import compute_normalization_stats, reshape_array
from .dataset_utils import (
    compute_normalization_stats,
I think I'd prefer `from .dataset_utils.welford import ...`
Methods that clearly belong to a particular algorithm should be identified as such, even in the import!
target_stds.append(target_std)
target_init_mean, _ = compute_normalization_stats(target)
target_channels = np.array(np.split(target, target.shape[1], axis=1))
if num_samples == 0:
This should be refactored somewhere else, especially because the whole procedure is called twice!
In the spirit of the previous comment, `update_iterative_stats` and the other method could become `_update_iterative_stats`, and then a new method in `dataset_utils.welford` called `welford_mean_and_std(array, ...)` could do the calculation.
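A rough sketch of what that refactor might look like, so the init-then-update logic lives in one place. Everything beyond the two names mentioned above is an assumption: I let the wrapper take an iterable of arrays with the channel on axis 1 (as in `compute_normalization_stats`), and I return the population std.

```python
from typing import Iterable, Tuple

import numpy as np


def _update_iterative_stats(
    count: int, mean: np.ndarray, m2: np.ndarray, new_values: np.ndarray
) -> Tuple[int, np.ndarray, np.ndarray]:
    """One Welford step for a (C, N) batch (private helper, sketch only)."""
    count = count + new_values.shape[1]
    delta = new_values - mean[:, None]
    mean = mean + delta.sum(axis=1) / count
    m2 = m2 + (delta * (new_values - mean[:, None])).sum(axis=1)
    return count, mean, m2


def welford_mean_and_std(
    arrays: Iterable[np.ndarray],
) -> Tuple[np.ndarray, np.ndarray]:
    """Per-channel mean and std over a stream of arrays (hypothetical API)."""
    count, mean, m2 = 0, None, None
    for array in arrays:
        n_channels = array.shape[1]
        # Bring the channel axis to the front and flatten the rest: (C, N)
        channels = np.moveaxis(array, 1, 0).reshape(n_channels, -1)
        channels = channels.astype(np.float64)
        if mean is None:
            # First array seen: start from empty statistics
            mean, m2 = np.zeros(n_channels), np.zeros(n_channels)
        count, mean, m2 = _update_iterative_stats(count, mean, m2, channels)
    if mean is None:
        raise ValueError("No arrays were provided.")
    return mean, np.sqrt(m2 / count)
```

With this shape, the caller only needs `mean, std = welford_mean_and_std(patches)` for inputs and again for targets, instead of repeating the `num_samples == 0` branch in both places.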
@@ -121,3 +121,69 @@ def compute_normalization_stats(image: np.ndarray) -> Tuple[np.ndarray, np.ndarr
    # Define the list of axes excluding the channel axis
    axes = tuple(np.delete(np.arange(image.ndim), 1))
    return np.mean(image, axis=axes), np.std(image, axis=axes)


def update_iterative_stats(
Not sure whether that's "better", but another approach could be:
class WelfordStatistics:
    count: int
    means: NDArray
    m2: NDArray

    def update(self, new_values: NDArray) -> None:
        ...  # update here

    def get_stats(self) -> Tuple[NDArray, NDArray]:
        ...  # finalize here
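To make the class idea concrete, here is one way it could be filled in. It is a sketch under my own assumptions, not a proposal for the final code: the constructor argument `n_channels`, the channel-on-axis-1 layout, and the population std are not in the PR, and the update merges each batch's mean and variance into the running totals (the parallel variant described on the same wiki page) rather than looping over values.

```python
from typing import Tuple

import numpy as np
from numpy.typing import NDArray


class WelfordStatistics:
    """Running per-channel mean and std (sketch; channel axis assumed to be 1)."""

    def __init__(self, n_channels: int) -> None:
        self.count = 0
        self.means = np.zeros(n_channels)
        self.m2 = np.zeros(n_channels)

    def update(self, new_values: NDArray) -> None:
        # Reduce over every axis except the channel axis
        axes = tuple(i for i in range(new_values.ndim) if i != 1)
        batch_count = new_values.size // new_values.shape[1]
        batch_means = new_values.mean(axis=axes, dtype=np.float64)
        batch_m2 = new_values.var(axis=axes, dtype=np.float64) * batch_count

        # Merge the batch statistics into the running statistics
        total = self.count + batch_count
        delta = batch_means - self.means
        self.means = self.means + delta * batch_count / total
        self.m2 = self.m2 + batch_m2 + delta**2 * self.count * batch_count / total
        self.count = total

    def get_stats(self) -> Tuple[NDArray, NDArray]:
        if self.count == 0:
            raise ValueError("update() must be called before get_stats().")
        return self.means, np.sqrt(self.m2 / self.count)
```

The upside over the free functions is that the caller no longer threads `count`, `mean` and `m2` through every call; the downside is one more stateful object to keep track of per dataset.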
Description
Changed the std computation for the iterative dataset to use the Welford algorithm.
Changes Made