Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mad wrapper #4

Merged
merged 1 commit into from
Mar 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/stats.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ but the output will be a numpy array.
hmean
kurtosis
skew
median_abs_deviation
```

## Other statistical functions
Expand Down
62 changes: 62 additions & 0 deletions src/xarray_einstats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,65 @@ def skew(da, bias=True, dims=None, nan_policy=None, **kwargs):
if nan_policy is not None:
skew_kwargs["nan_policy"] = nan_policy
return _apply_reduce_func(stats.skew, da, dims, kwargs, skew_kwargs)


def median_abs_deviation(da, dims=None, center=None, scale=1, nan_policy=None, **kwargs):
    """Wrap and extend :func:`scipy.stats.median_abs_deviation`.

    Usage examples available at :ref:`stats_tutorial`.

    Every parameter accepts the same values and types as in the scipy
    function, except ``scale``: here it may also be a
    :class:`~xarray.DataArray`, in which case xarray takes care of
    broadcasting it against the input, as shown in the example below.

    Parameters
    ----------
    da : DataArray
        Input data.
    dims : str or sequence of str, optional
        Dimension(s) to reduce. When omitted, ``get_default_dims(da.dims)``
        is used. Anything that is not a single string is stacked into one
        auxiliary dimension before the reduction.
    center : callable, optional
        Forwarded to scipy only when it is not ``None``.
    scale : scalar, str or DataArray, default 1
        Forwarded to scipy as its ``scale`` argument.
    nan_policy : str, optional
        Forwarded to scipy only when it is not ``None``.
    **kwargs
        Passed on to :func:`xarray.apply_ufunc`.

    Returns
    -------
    DataArray
        The output of :func:`xarray.apply_ufunc`, with the reduced
        dimension(s) removed.

    Examples
    --------
    Use a ``DataArray`` as ``scale``.

    .. jupyter-execute::

        import xarray as xr
        from xarray_einstats import tutorial, stats
        ds = tutorial.generate_mcmc_like_dataset(3)
        s_da = xr.DataArray([1, 2, 1, 1], coords={"chain": ds.chain})
        stats.median_abs_deviation(ds["mu"], dims="draw", scale=s_da)

    Note that this doesn't work with the scipy counterpart because
    ``s_da`` can't be broadcast with the output:

    .. jupyter-execute::
        :raises: ValueError

        from scipy import stats
        stats.median_abs_deviation(ds["mu"], axis=1, scale=s_da)

    """
    # Resolve the dimension(s) to reduce into a single core dimension.
    if dims is None:
        dims = get_default_dims(da.dims)
    if isinstance(dims, str):
        core_dims = [dims]
    else:
        # Several dims: collapse them so scipy reduces over a single axis.
        da = da.stack(__aux_dim__=dims)
        core_dims = ["__aux_dim__"]

    # Only a DataArray scale can carry core dimensions of its own.
    if isinstance(scale, xr.DataArray):
        scale_dims = [dim for dim in scale.dims if dim in core_dims]
    else:
        scale_dims = []

    # apply_ufunc moves core dims to the end, hence the reduction over axis -1.
    # Optional arguments are forwarded only when set, so scipy defaults apply.
    mad_kwargs = {"axis": -1}
    if center is not None:
        mad_kwargs["center"] = center
    if nan_policy is not None:
        mad_kwargs["nan_policy"] = nan_policy

    return xr.apply_ufunc(
        lambda values, scale_values, **scipy_kwargs: stats.median_abs_deviation(
            values, scale=scale_values, **scipy_kwargs
        ),
        da,
        scale,
        input_core_dims=[core_dims, scale_dims],
        output_core_dims=[[]],
        kwargs=mad_kwargs,
        **kwargs,
    )
17 changes: 16 additions & 1 deletion src/xarray_einstats/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Test the stats module."""
import numpy as np
import pytest
import xarray as xr
from scipy import stats
from xarray.testing import assert_allclose

Expand All @@ -15,6 +16,7 @@
gmean,
hmean,
kurtosis,
median_abs_deviation,
rankdata,
skew,
)
Expand Down Expand Up @@ -134,7 +136,9 @@ def test_rankdata(data, dims):


@pytest.mark.parametrize("dims", ("team", ("chain", "draw"), None))
@pytest.mark.parametrize("func", (gmean, hmean, circmean, circstd, circvar, kurtosis, skew))
@pytest.mark.parametrize(
"func", (gmean, hmean, circmean, circstd, circvar, kurtosis, skew, median_abs_deviation)
)
def test_reduce_function(data, dims, func):
da = data["mu"]
out = func(da, dims=dims)
Expand All @@ -145,3 +149,14 @@ def test_reduce_function(data, dims, func):
expected_dims = [dim for dim in da.dims if dim not in dims]
assert_dims_in_da(out, expected_dims)
assert_dims_not_in_da(out, dims)


def test_mad_da_scale(data):
    """Check that a ``DataArray`` scale is applied chain by chain."""
    mu = data["mu"]
    chain_scale = xr.DataArray([1, 2, 1, 1], coords={"chain": data.chain})
    scaled = median_abs_deviation(mu, dims="draw", scale=chain_scale)

    # Scalar-scale references for the first two chains (scale values 1 and 2).
    references = {
        chain_id: median_abs_deviation(mu.sel(chain=chain_id), dims="draw", scale=scalar)
        for chain_id, scalar in ((0, 1), (1, 2))
    }

    assert_dims_in_da(scaled, ("chain", "team"))
    assert_dims_not_in_da(scaled, ["draw"])
    for chain_id, expected in references.items():
        assert_allclose(scaled.sel(chain=chain_id), expected)