Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ndfilters.variance_filter() function. #19

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ pip install ndfilters
The [mean filter](https://ndfilters.readthedocs.io/en/latest/_autosummary/ndfilters.mean_filter.html#ndfilters.mean_filter)
calculates a multidimensional rolling mean for the given kernel shape.

![mean filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.mean_filter_0_2.png)
![mean filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.mean_filter_0_0.png)

### Trimmed mean filter

The [trimmed mean filter](https://ndfilters.readthedocs.io/en/latest/_autosummary/ndfilters.trimmed_mean_filter.html#ndfilters.trimmed_mean_filter)
is like the mean filter except it ignores a given portion of the dataset before calculating the mean at each pixel.

![trimmed mean filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.trimmed_mean_filter_0_0.png)

### Variance filter

The [variance filter](https://ndfilters.readthedocs.io/en/latest/_autosummary/ndfilters.variance_filter.html#ndfilters.variance_filter)
calculates the rolling variance for the given kernel shape.

![variance filter](https://ndfilters.readthedocs.io/en/latest/_images/ndfilters.variance_filter_0_0.png)
2 changes: 2 additions & 0 deletions ndfilters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from ._generic import generic_filter
from ._mean import mean_filter
from ._trimmed_mean import trimmed_mean_filter
from ._variance import variance_filter

__all__ = [
"generic_filter",
"mean_filter",
"trimmed_mean_filter",
"variance_filter",
]
96 changes: 96 additions & 0 deletions ndfilters/_tests/test_variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Literal
import pytest
import numpy as np
import scipy.ndimage
import scipy.stats
import astropy.units as u
import ndfilters


@pytest.mark.parametrize(
argnames="array",
argvalues=[
np.random.random(5),
np.random.random((5, 6)),
np.random.random((5, 6, 7)) * u.mm,
],
)
@pytest.mark.parametrize(
argnames="size",
argvalues=[2, (3,), (3, 4), (3, 4, 5)],
)
@pytest.mark.parametrize(
argnames="axis",
argvalues=[
None,
0,
-1,
(0,),
(-1,),
(0, 1),
(-2, -1),
(0, 1, 2),
(2, 1, 0),
],
)
@pytest.mark.parametrize(
argnames="mode",
argvalues=[
"mirror",
"nearest",
"wrap",
],
)
def test_variance_filter(
array: np.ndarray,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...],
mode: Literal["mirror", "nearest", "wrap", "truncate"],
):
kwargs = dict(
array=array,
size=size,
axis=axis,
mode=mode,
)

if axis is None:
axis_normalized = tuple(range(array.ndim))
else:
try:
axis_normalized = np.core.numeric.normalize_axis_tuple(
axis, ndim=array.ndim
)
except np.AxisError:
with pytest.raises(np.AxisError):
ndfilters.variance_filter(**kwargs)
return

if isinstance(size, int):
size_normalized = (size,) * len(axis_normalized)
else:
size_normalized = size

if len(size_normalized) != len(axis_normalized):
with pytest.raises(ValueError):
ndfilters.variance_filter(**kwargs)
return

result = ndfilters.variance_filter(**kwargs)

size_scipy = [1] * array.ndim
for i, ax in enumerate(axis_normalized):
size_scipy[ax] = size_normalized[i]

expected = scipy.ndimage.generic_filter(
input=array,
function=np.var,
size=size_scipy,
mode=mode,
)

if isinstance(result, u.Quantity):
assert np.allclose(result.value, expected)
assert result.unit == array.unit
else:
assert np.allclose(result, expected)
79 changes: 79 additions & 0 deletions ndfilters/_variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Literal
import numpy as np
import numba
import astropy.units as u
import ndfilters

__all__ = [
"variance_filter",
]


def variance_filter(
array: np.ndarray | u.Quantity,
size: int | tuple[int, ...],
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
) -> np.ndarray:
"""
Calculate a multidimensional rolling variance.

Parameters
----------
array
The input array to be filtered
size
The shape of the kernel over which the variance will be calculated.
axis
The axes over which to apply the kernel.
Should either be a scalar or have the same number of items as `size`.
If :obj:`None` (the default) the kernel spans every axis of the array.
where
An optional mask that can be used to exclude parts of the array during
filtering.
mode
The method used to extend the input array beyond its boundaries.
See :func:`scipy.ndimage.generic_filter` for the definitions.
Currently, only "mirror", "nearest", "wrap", and "truncate" modes are
supported.

Returns
-------
A copy of the array with the variance filter applied.

Examples
--------

.. jupyter-execute::

import matplotlib.pyplot as plt
import scipy.datasets
import ndfilters

img = scipy.datasets.ascent()
img_filtered = ndfilters.variance_filter(img, size=21)

fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
axs[0].set_title("original image");
axs[0].imshow(img, cmap="gray");
axs[1].set_title("filtered image");
axs[1].imshow(img_filtered, cmap="gray");

"""
return ndfilters.generic_filter(
array=array,
function=_variance,
size=size,
axis=axis,
where=where,
mode=mode,
)


@numba.njit
def _variance(
array: np.ndarray,
args: tuple[float],
) -> float:
return np.var(array)
Loading