Add rank correlation metrics #3276

Merged
merged 7 commits on Sep 3, 2024
Changes from 2 commits
2 changes: 2 additions & 0 deletions docs/source/metrics.rst
@@ -377,6 +377,8 @@ Complete list of metrics
regression.MedianAbsolutePercentageError
regression.MedianRelativeAbsoluteError
regression.PearsonCorrelation
regression.SpearmanRankCorrelation
regression.KendallRankCorrelation
regression.R2Score
regression.WaveHedgesDistance

2 changes: 2 additions & 0 deletions ignite/metrics/regression/__init__.py
@@ -3,6 +3,7 @@
from ignite.metrics.regression.fractional_bias import FractionalBias
from ignite.metrics.regression.geometric_mean_absolute_error import GeometricMeanAbsoluteError
from ignite.metrics.regression.geometric_mean_relative_absolute_error import GeometricMeanRelativeAbsoluteError
from ignite.metrics.regression.kendall_correlation import KendallRankCorrelation
from ignite.metrics.regression.manhattan_distance import ManhattanDistance
from ignite.metrics.regression.maximum_absolute_error import MaximumAbsoluteError
from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
@@ -13,4 +14,5 @@
from ignite.metrics.regression.median_relative_absolute_error import MedianRelativeAbsoluteError
from ignite.metrics.regression.pearson_correlation import PearsonCorrelation
from ignite.metrics.regression.r2_score import R2Score
from ignite.metrics.regression.spearman_correlation import SpearmanRankCorrelation
from ignite.metrics.regression.wave_hedges_distance import WaveHedgesDistance
109 changes: 109 additions & 0 deletions ignite/metrics/regression/kendall_correlation.py
@@ -0,0 +1,109 @@
from typing import Any, Callable, Tuple, Union

import torch

from scipy.stats import kendalltau
from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _compute_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]:
    if variant not in ("b", "c"):
        raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.")

    def _tau(predictions: Tensor, targets: Tensor) -> float:
        np_preds = predictions.flatten().numpy()
        np_targets = targets.flatten().numpy()
        r = kendalltau(np_preds, np_targets, variant=variant).statistic
        return r

    return _tau


class KendallRankCorrelation(EpochMetric):
r"""Calculates the
`Kendall rank correlation coefficient <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient>`_.

.. math::
\tau = 1-\frac{2(\text{number of discordant pairs})}{\left( \begin{array}{c}n\\2\end{array} \right)}

Two prediction-target pairs :math:`(P_i, A_i)` and :math:`(P_j, A_j)`, where :math:`i<j`,
are said to be concordant when both :math:`P_i<P_j` and :math:`A_i<A_j` holds
or both :math:`P_i>P_j` and :math:`A_i>A_j`.

The ``number of discordant pairs`` counts the number of pairs that are not concordant.

The computation of this metric is implemented with
`scipy.stats.kendalltau <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html>`_.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.

Parameters are inherited from ``Metric.__init__``.

Args:
variant: variant of kendall rank correlation. ``b`` or ``c`` is accepted.
Details can be found
`here <https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient#Accounting_for_ties>`_.
Default: ``b``
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = KendallRankCorrelation()
metric.attach(default_evaluator, 'kendall_tau')
y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['kendall_tau'])

.. testoutput::

0.4666666666666666
"""

    def __init__(
        self,
        variant: str = "b",
        output_transform: Callable[..., Any] = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        super().__init__(_compute_kendall_tau(variant), output_transform, check_compute_fn, device, skip_unrolling)

    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()
        if y_pred.ndim == 1:
            y_pred = y_pred.unsqueeze(1)
        if y.ndim == 1:
            y = y.unsqueeze(1)

        _check_output_shapes(output)
        _check_output_types(output)

        super().update(output)

    def compute(self) -> float:
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError(
                "KendallRankCorrelation must have at least one example before it can be computed."
            )

        return super().compute()
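Note (not part of this PR): the docstring's formula can be sanity-checked by counting discordant pairs directly. The sketch below is a minimal, hypothetical illustration of the tau-a pair-counting form, assuming no ties, in which case it agrees with scipy's default tau-b that the metric delegates to; the helper name `kendall_tau_a` is illustrative only.

```python
from itertools import combinations

import torch
from scipy.stats import kendalltau


def kendall_tau_a(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    # All index pairs (i, j) with i < j.
    pairs = list(combinations(range(len(y_pred)), 2))
    # A pair is discordant when predictions and targets order it differently.
    discordant = sum(
        1 for i, j in pairs if (y_pred[i] - y_pred[j]) * (y_true[i] - y_true[j]) < 0
    )
    return 1.0 - 2.0 * discordant / len(pairs)


y_true = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])

print(kendall_tau_a(y_pred, y_true))  # 0.4666... (4 discordant pairs out of 15)
# With no ties, scipy's tau-b gives the same value the metric reports.
print(kendalltau(y_pred.numpy(), y_true.numpy()).statistic)
```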
96 changes: 96 additions & 0 deletions ignite/metrics/regression/spearman_correlation.py
@@ -0,0 +1,96 @@
from typing import Any, Callable, Tuple, Union

import torch

from scipy.stats import spearmanr
from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.regression._base import _check_output_shapes, _check_output_types


def _compute_spearman_r(predictions: Tensor, targets: Tensor) -> float:
    np_preds = predictions.flatten().numpy()
    np_targets = targets.flatten().numpy()
    r = spearmanr(np_preds, np_targets).statistic
    return r


class SpearmanRankCorrelation(EpochMetric):
r"""Calculates the
`Spearman's rank correlation coefficient <https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient>`_.

.. math::
r_\text{s} = \text{Corr}[R[P], R[A]] = \frac{\text{Cov}[R[P], R[A]]}{\sigma_{R[P]} \sigma_{R[A]}}

where :math:`A` and :math:`P` are the ground truth and predicted value, and R[X] is the ranking value of X.

The computation of this metric is implemented with
`scipy.stats.spearmanr <https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html>`_.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)`.

Parameters are inherited from ``Metric.__init__``.

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = SpearmanRankCorrelation()
metric.attach(default_evaluator, 'spearman_corr')
y_true = torch.tensor([0., 1., 2., 3., 4., 5.])
y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['spearman_corr'])

.. testoutput::

0.7142857142857143
"""

    def __init__(
        self,
        output_transform: Callable[..., Any] = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        super().__init__(_compute_spearman_r, output_transform, check_compute_fn, device, skip_unrolling)

    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        y_pred, y = output[0].detach(), output[1].detach()
        if y_pred.ndim == 1:
            y_pred = y_pred.unsqueeze(1)
        if y.ndim == 1:
            y = y.unsqueeze(1)

        _check_output_shapes(output)
        _check_output_types(output)

        super().update(output)

    def compute(self) -> float:
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError(
                "SpearmanRankCorrelation must have at least one example before it can be computed."
            )

        return super().compute()
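Note (not part of this PR): as the docstring's formula says, Spearman's coefficient is simply the Pearson correlation of the rank variables. The sketch below is a minimal, hypothetical illustration assuming no ties, so plain ordinal ranks suffice (scipy uses average ranks when there are ties); the helper name `spearman_r` is illustrative only.

```python
import torch
from scipy.stats import spearmanr


def spearman_r(y_pred: torch.Tensor, y_true: torch.Tensor) -> float:
    # Double argsort turns each value into its 0..N-1 rank (valid when there are no ties).
    rank_pred = y_pred.argsort().argsort().float()
    rank_true = y_true.argsort().argsort().float()
    # Spearman's r is the Pearson correlation of the two rank vectors.
    return torch.corrcoef(torch.stack([rank_pred, rank_true]))[0, 1].item()


y_true = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = torch.tensor([0.5, 2.8, 1.9, 1.3, 6.0, 4.1])

print(spearman_r(y_pred, y_true))  # ~0.7142857
# Matches scipy, which the metric delegates to.
print(spearmanr(y_pred.numpy(), y_true.numpy()).statistic)
```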