Add PearsonCorrelation metric #3212

Merged 21 commits on Mar 30, 2024
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -371,6 +371,7 @@ Complete list of metrics
regression.MedianAbsoluteError
regression.MedianAbsolutePercentageError
regression.MedianRelativeAbsoluteError
regression.PearsonCorrelation
regression.R2Score
regression.WaveHedgesDistance

1 change: 1 addition & 0 deletions ignite/metrics/regression/__init__.py
@@ -11,5 +11,6 @@
from ignite.metrics.regression.median_absolute_error import MedianAbsoluteError
from ignite.metrics.regression.median_absolute_percentage_error import MedianAbsolutePercentageError
from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
from ignite.metrics.regression.pearson_correlation import PearsonCorrelation
from ignite.metrics.regression.r2_score import R2Score
from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
123 changes: 123 additions & 0 deletions ignite/metrics/regression/pearson_correlation.py
@@ -0,0 +1,123 @@
from typing import Callable, Tuple, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

from ignite.metrics.regression._base import _BaseRegression


class PearsonCorrelation(_BaseRegression):
r"""Calculates the
`Pearson correlation coefficient <https://en.wikipedia.org/wiki/Pearson_correlation_coefficient>`_.

.. math::
r = \frac{\sum_{j=1}^n (P_j-\bar{P})(A_j-\bar{A})}
{\max (\sqrt{\sum_{j=1}^n (P_j-\bar{P})^2 \sum_{j=1}^n (A_j-\bar{A})^2}, \epsilon)},
\quad \bar{P}=\frac{1}{n}\sum_{j=1}^n P_j, \quad \bar{A}=\frac{1}{n}\sum_{j=1}^n A_j

where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y` and `y_pred` must be of the same shape `(N, )` or `(N, 1)`.

Parameters are inherited from ``Metric.__init__``.

Args:
eps: a small value to avoid division by zero. Default: 1e-8
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = PearsonCorrelation()
metric.attach(default_evaluator, 'corr')
y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
y_pred = torch.tensor([0.5, 1.3, 1.9, 2.8, 4.1, 6.0])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['corr'])

.. testoutput::

0.9768688678741455
"""

def __init__(
self,
eps: float = 1e-8,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
):
super().__init__(output_transform, device)

self.eps = eps

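# accumulator attributes that form the metric's state_dict for checkpointing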
_state_dict_all_req_keys = (
"_sum_of_y_preds",
"_sum_of_ys",
"_sum_of_y_pred_squares",
"_sum_of_y_squares",
"_sum_of_products",
"_num_examples",
)

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_y_preds = torch.tensor(0.0, device=self._device)
self._sum_of_ys = torch.tensor(0.0, device=self._device)
self._sum_of_y_pred_squares = torch.tensor(0.0, device=self._device)
self._sum_of_y_squares = torch.tensor(0.0, device=self._device)
self._sum_of_products = torch.tensor(0.0, device=self._device)
self._num_examples = 0

def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
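# accumulate running sums (sufficient statistics) from which compute() derives the correlation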
self._sum_of_y_preds += y_pred.sum()
self._sum_of_ys += y.sum()
self._sum_of_y_pred_squares += y_pred.square().sum()
self._sum_of_y_squares += y.square().sum()
self._sum_of_products += (y_pred * y).sum()
self._num_examples += y.shape[0]

@sync_all_reduce(
"_sum_of_y_preds",
"_sum_of_ys",
"_sum_of_y_pred_squares",
"_sum_of_y_squares",
"_sum_of_products",
"_num_examples",
)
def compute(self) -> float:
n = self._num_examples
if n == 0:
raise NotComputableError("PearsonCorrelation must have at least one example before it can be computed.")

# cov = E[xy] - E[x]*E[y]
cov = self._sum_of_products / n - self._sum_of_y_preds * self._sum_of_ys / (n * n)

# var = E[x^2] - E[x]^2
y_pred_mean = self._sum_of_y_preds / n
y_pred_var = self._sum_of_y_pred_squares / n - y_pred_mean * y_pred_mean
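# floating-point error can make the variance slightly negative; clamp it to zero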
y_pred_var = torch.clamp(y_pred_var, min=0.0)

y_mean = self._sum_of_ys / n
y_var = self._sum_of_y_squares / n - y_mean * y_mean
y_var = torch.clamp(y_var, min=0.0)

r = cov / torch.clamp(torch.sqrt(y_pred_var * y_var), min=self.eps)
return float(r.item())
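For reference, the streaming formulation above can be sanity-checked outside ignite. The following is a minimal, illustrative sketch (not part of this PR; streaming_pearson is a made-up helper) that mirrors the accumulators kept by PearsonCorrelation and compares the result against scipy.stats.pearsonr:

# Illustrative sketch (assumes torch and scipy are installed; streaming_pearson is a hypothetical helper).
import torch
from scipy.stats import pearsonr

def streaming_pearson(batches, eps=1e-8):
    # same running sums as PearsonCorrelation._update
    sum_p = sum_a = sum_pp = sum_aa = sum_pa = 0.0
    n = 0
    for y_pred, y in batches:
        sum_p += y_pred.sum().item()
        sum_a += y.sum().item()
        sum_pp += y_pred.square().sum().item()
        sum_aa += y.square().sum().item()
        sum_pa += (y_pred * y).sum().item()
        n += y.shape[0]
    cov = sum_pa / n - sum_p * sum_a / (n * n)       # E[xy] - E[x]E[y]
    var_p = max(sum_pp / n - (sum_p / n) ** 2, 0.0)  # E[x^2] - E[x]^2
    var_a = max(sum_aa / n - (sum_a / n) ** 2, 0.0)
    return cov / max((var_p * var_a) ** 0.5, eps)

y_pred = torch.tensor([0.5, 1.3, 1.9, 2.8, 4.1, 6.0])
y_true = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
# splitting the data into two batches yields the same value as a single pass
print(streaming_pearson([(y_pred[:3], y_true[:3]), (y_pred[3:], y_true[3:])]))  # ~0.9769
print(pearsonr(y_pred.numpy(), y_true.numpy()).statistic)                       # reference value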
258 changes: 258 additions & 0 deletions tests/ignite/metrics/regression/test_pearson_correlation.py
@@ -0,0 +1,258 @@
from typing import Tuple

import numpy as np
import pytest
import torch
from scipy.stats import pearsonr
from torch import Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.regression import PearsonCorrelation


def np_corr_eps(np_y_pred: np.ndarray, np_y: np.ndarray, eps: float = 1e-8):
cov = np.cov(np_y_pred, np_y, ddof=0)[0, 1]
std_y_pred = np.std(np_y_pred, ddof=0)
std_y = np.std(np_y, ddof=0)
corr = cov / np.clip(std_y_pred * std_y, eps, None)
return corr


def scipy_corr(np_y_pred: np.ndarray, np_y: np.ndarray):
corr = pearsonr(np_y_pred, np_y)
return corr.statistic


def test_zero_sample():
m = PearsonCorrelation()
with pytest.raises(
NotComputableError, match=r"PearsonCorrelation must have at least one example before it can be computed"
):
m.compute()


def test_wrong_input_shapes():
m = PearsonCorrelation()

with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
m.update((torch.rand(4), torch.rand(4, 1)))

with pytest.raises(ValueError, match=r"Input data shapes should be the same, but given"):
m.update((torch.rand(4, 1), torch.rand(4)))


def test_degenerated_sample():
# one sample
m = PearsonCorrelation()
y_pred = torch.tensor([1.0])
y = torch.tensor([1.0])
m.update((y_pred, y))

np_y_pred = y_pred.numpy()
np_y = y.numpy()
np_res = np_corr_eps(np_y_pred, np_y)
assert pytest.approx(np_res) == m.compute()

# constant samples
m.reset()
y_pred = torch.ones(10).float()
y = torch.zeros(10).float()
m.update((y_pred, y))

np_y_pred = y_pred.numpy()
np_y = y.numpy()
np_res = np_corr_eps(np_y_pred, np_y)
assert pytest.approx(np_res) == m.compute()


def test_pearson_correlation():
a = np.random.randn(4).astype(np.float32)
b = np.random.randn(4).astype(np.float32)
c = np.random.randn(4).astype(np.float32)
d = np.random.randn(4).astype(np.float32)
ground_truth = np.random.randn(4).astype(np.float32)

m = PearsonCorrelation()

m.update((torch.from_numpy(a), torch.from_numpy(ground_truth)))
np_ans = scipy_corr(a, ground_truth)
assert m.compute() == pytest.approx(np_ans, rel=1e-4)

m.update((torch.from_numpy(b), torch.from_numpy(ground_truth)))
np_ans = scipy_corr(np.concatenate([a, b]), np.concatenate([ground_truth] * 2))
assert m.compute() == pytest.approx(np_ans, rel=1e-4)

m.update((torch.from_numpy(c), torch.from_numpy(ground_truth)))
np_ans = scipy_corr(np.concatenate([a, b, c]), np.concatenate([ground_truth] * 3))
assert m.compute() == pytest.approx(np_ans, rel=1e-4)

m.update((torch.from_numpy(d), torch.from_numpy(ground_truth)))
np_ans = scipy_corr(np.concatenate([a, b, c, d]), np.concatenate([ground_truth] * 4))
assert m.compute() == pytest.approx(np_ans, rel=1e-4)


@pytest.fixture(params=list(range(2)))
def test_case(request):
# correlated sample
x = torch.randn(size=[50]).float()
y = x + torch.randn_like(x) * 0.1

return [
(x, y, 1),
(torch.rand(size=(50, 1)).float(), torch.rand(size=(50, 1)).float(), 10),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_integration(n_times, test_case: Tuple[Tensor, Tensor, int]):
y_pred, y, batch_size = test_case

def update_fn(engine: Engine, batch):
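# np_y and np_y_pred are defined below; the closure reads them only when the engine runs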
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

engine = Engine(update_fn)

m = PearsonCorrelation()
m.attach(engine, "corr")

np_y = y.ravel().numpy()
np_y_pred = y_pred.ravel().numpy()

data = list(range(y_pred.shape[0] // batch_size))
corr = engine.run(data, max_epochs=1).metrics["corr"]

np_ans = scipy_corr(np_y_pred, np_y)

assert pytest.approx(np_ans, rel=2e-4) == corr


def test_accumulator_detached():
corr = PearsonCorrelation()

y_pred = torch.tensor([2.0, 3.0], requires_grad=True)
y = torch.tensor([-2.0, -1.0])
corr.update((y_pred, y))

assert all(
(not accumulator.requires_grad)
for accumulator in (
corr._sum_of_products,
corr._sum_of_y_pred_squares,
corr._sum_of_y_preds,
corr._sum_of_y_squares,
corr._sum_of_ys,
)
)


@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_compute(self):
rank = idist.get_rank()
device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)

torch.manual_seed(10 + rank)
for metric_device in metric_devices:
m = PearsonCorrelation(device=metric_device)

y_pred = torch.rand(size=[100], device=device)
y = torch.rand(size=[100], device=device)

m.update((y_pred, y))

y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

np_y = y.cpu().numpy()
np_y_pred = y_pred.cpu().numpy()

np_ans = scipy_corr(np_y_pred, np_y)

assert pytest.approx(np_ans) == m.compute()

@pytest.mark.parametrize("n_epochs", [1, 2])
def test_integration(self, n_epochs: int):
tol = 2e-4
rank = idist.get_rank()
device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)

n_iters = 80
batch_size = 16

for metric_device in metric_devices:
torch.manual_seed(12 + rank)

y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)

engine = Engine(
lambda e, i: (
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)
)

corr = PearsonCorrelation(device=metric_device)
corr.attach(engine, "corr")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "corr" in engine.state.metrics

res = engine.state.metrics["corr"]

np_y = y_true.cpu().numpy()
np_y_pred = y_preds.cpu().numpy()

np_ans = scipy_corr(np_y_pred, np_y)

assert pytest.approx(np_ans, rel=tol) == res

def test_accumulator_device(self):
device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)
for metric_device in metric_devices:
corr = PearsonCorrelation(device=metric_device)

devices = (
corr._device,
corr._sum_of_products.device,
corr._sum_of_y_pred_squares.device,
corr._sum_of_y_preds.device,
corr._sum_of_y_squares.device,
corr._sum_of_ys.device,
)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

y_pred = torch.tensor([2.0, 3.0])
y = torch.tensor([-1.0, 1.0])
corr.update((y_pred, y))

devices = (
corr._device,
corr._sum_of_products.device,
corr._sum_of_y_pred_squares.device,
corr._sum_of_y_preds.device,
corr._sum_of_y_squares.device,
corr._sum_of_ys.device,
)
for dev in devices:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"