Add clustering metrics #3290

Merged · 18 commits · Oct 21, 2024
1 change: 1 addition & 0 deletions docs/source/defaults.rst
@@ -12,6 +12,7 @@
from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.metrics.clustering import *
from ignite.metrics.regression import *
from ignite.utils import *

3 changes: 3 additions & 0 deletions docs/source/metrics.rst
@@ -382,6 +382,9 @@ Complete list of metrics
regression.KendallRankCorrelation
regression.R2Score
regression.WaveHedgesDistance
clustering.SilhouetteScore
clustering.DaviesBouldinScore
clustering.CalinskiHarabaszScore


.. note::
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -1,3 +1,4 @@
import ignite.metrics.clustering
import ignite.metrics.regression

from ignite.metrics.accumulation import Average, GeometricAverage, VariableAccumulation
@@ -82,6 +83,7 @@
"RougeN",
"RougeL",
"regression",
"clustering",
"AveragePrecision",
"CohenKappa",
"GpuInfo",
3 changes: 3 additions & 0 deletions ignite/metrics/clustering/__init__.py
@@ -0,0 +1,3 @@
from ignite.metrics.clustering.calinski_harabasz_score import CalinskiHarabaszScore
from ignite.metrics.clustering.davies_bouldin_score import DaviesBouldinScore
from ignite.metrics.clustering.silhouette_score import SilhouetteScore
42 changes: 42 additions & 0 deletions ignite/metrics/clustering/_base.py
@@ -0,0 +1,42 @@
from typing import Tuple

from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetric


class _ClusteringMetricBase(EpochMetric):
    required_output_keys = ("features", "labels")

    def _check_shape(self, output: Tuple[Tensor, Tensor]) -> None:
        features, labels = output
        if features.ndimension() != 2:
            raise ValueError("Features should be of shape (batch_size, n_features).")

        if labels.ndimension() != 1:
            raise ValueError("Labels should be of shape (batch_size, ).")

    def _check_type(self, output: Tuple[Tensor, Tensor]) -> None:
        features, labels = output
        if len(self._predictions) < 1:
            return
        dtype_preds = self._predictions[-1].dtype
        if dtype_preds != features.dtype:
            raise ValueError(
                f"Incoherent types between input features and stored features: {dtype_preds} vs {features.dtype}"
            )

        dtype_targets = self._targets[-1].dtype
        if dtype_targets != labels.dtype:
            raise ValueError(
                f"Incoherent types between input labels and stored labels: {dtype_targets} vs {labels.dtype}"
            )

    def compute(self) -> float:
        if len(self._predictions) < 1 or len(self._targets) < 1:
            raise NotComputableError(
                f"{self.__class__.__name__} must have at least one example before it can be computed."
            )

        return super().compute()
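
Note: the PR also adds ignite/metrics/clustering/silhouette_score.py (imported in the clustering __init__.py above), but that file is not included in this excerpt. As a hypothetical sketch only, assuming it mirrors the two score files shown below and delegates to sklearn.metrics.silhouette_score, a wrapper built on this base class could look like:

# Hypothetical sketch: the actual silhouette_score.py from this PR is not shown in this excerpt.
from typing import Any, Callable, Union

import torch
from torch import Tensor

from ignite.metrics.clustering._base import _ClusteringMetricBase


def _silhouette_score(features: Tensor, labels: Tensor) -> float:
    # sklearn expects numpy arrays: features of shape (B, D), labels of shape (B,)
    from sklearn.metrics import silhouette_score

    return silhouette_score(features.numpy(), labels.numpy())


class SilhouetteScore(_ClusteringMetricBase):
    def __init__(
        self,
        output_transform: Callable[..., Any] = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        try:
            from sklearn.metrics import silhouette_score  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This module requires scikit-learn to be installed.")

        super().__init__(_silhouette_score, output_transform, check_compute_fn, device, skip_unrolling)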
106 changes: 106 additions & 0 deletions ignite/metrics/clustering/calinski_harabasz_score.py
@@ -0,0 +1,106 @@
from typing import Any, Callable, Union

import torch
from torch import Tensor

from ignite.metrics.clustering._base import _ClusteringMetricBase

__all__ = ["CalinskiHarabaszScore"]


def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float:
    from sklearn.metrics import calinski_harabasz_score

    np_features = features.numpy()
    np_labels = labels.numpy()
    score = calinski_harabasz_score(np_features, np_labels)
    return score


class CalinskiHarabaszScore(_ClusteringMetricBase):
    r"""Calculates the
    `Calinski-Harabasz score <https://en.wikipedia.org/wiki/Calinski%E2%80%93Harabasz_index>`_.

    The Calinski-Harabasz score evaluates the quality of clustering results.

    More details can be found
    `here <https://scikit-learn.org/stable/modules/clustering.html#calinski-harabasz-index>`_.

    A higher Calinski-Harabasz score indicates
    a better clustering result (i.e., clusters are well-separated).

    The computation of this metric is implemented with
    `sklearn.metrics.calinski_harabasz_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.calinski_harabasz_score.html>`_.

    - ``update`` must receive output of the form ``(features, labels)``
      or ``{'features': features, 'labels': labels}``.
    - `features` and `labels` must be of shape `(B, D)` and `(B,)` respectively.

    Parameters are inherited from ``EpochMetric.__init__``.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(features, labels)``
            or ``{'features': features, 'labels': labels}``.
        check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no
            issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``.
            Default, True.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether output should be unrolled before being fed to the update method.
            Should be true for multi-output models, for example, if ``y_pred`` contains multi-output as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(features, labels)`` or ``{'features': features, 'labels': labels, ...}``.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = CalinskiHarabaszScore()
            metric.attach(default_evaluator, "calinski_harabasz_score")
            X = torch.tensor([
                [-1.04, -0.71, -1.42, -0.28, -0.43],
                [0.47, 0.96, -0.43, 1.57, -2.24],
                [-0.62, -0.29, 0.10, -0.72, -1.69],
                [0.96, -0.77, 0.60, -0.89, 0.49],
                [-1.33, -1.53, 0.25, -1.60, -2.0],
                [-0.63, -0.55, -1.03, -0.89, -0.77],
                [-0.26, -1.67, -0.24, -1.33, -0.40],
                [-0.20, -1.34, -0.52, -1.55, -1.50],
                [2.68, 1.13, 2.51, 0.80, 0.92],
                [0.33, 2.88, 1.35, -0.56, 1.71]
            ])
            Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])
            state = default_evaluator.run([{"features": X, "labels": Y}])
            print(state.metrics["calinski_harabasz_score"])

        .. testoutput::

            5.733935121807529

    .. versionadded:: 0.5.2
    """

    def __init__(
        self,
        output_transform: Callable[..., Any] = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        try:
            from sklearn.metrics import calinski_harabasz_score  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This module requires scikit-learn to be installed.")

        super().__init__(_calinski_harabasz_score, output_transform, check_compute_fn, device, skip_unrolling)
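
Beyond attaching the metric to an engine as in the docstring example, it can also be driven directly through the standard Metric update/compute interface. A minimal sketch, with toy data invented for illustration:

# Minimal sketch: using the metric without an Engine (toy data for illustration).
import torch

from ignite.metrics.clustering import CalinskiHarabaszScore

metric = CalinskiHarabaszScore()
metric.reset()

# Two small, well-separated clusters of 2-D points.
features = torch.tensor([[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.1, 4.9]])
labels = torch.tensor([0, 0, 1, 1])

metric.update((features, labels))
print(metric.compute())  # higher values indicate better-separated clusters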
106 changes: 106 additions & 0 deletions ignite/metrics/clustering/davies_bouldin_score.py
@@ -0,0 +1,106 @@
from typing import Any, Callable, Union

import torch
from torch import Tensor

from ignite.metrics.clustering._base import _ClusteringMetricBase

__all__ = ["DaviesBouldinScore"]


def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float:
    from sklearn.metrics import davies_bouldin_score

    np_features = features.numpy()
    np_labels = labels.numpy()
    score = davies_bouldin_score(np_features, np_labels)
    return score


class DaviesBouldinScore(_ClusteringMetricBase):
    r"""Calculates the
    `Davies-Bouldin score <https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index>`_.

    The Davies-Bouldin score evaluates the quality of clustering results.

    More details can be found
    `here <https://scikit-learn.org/1.5/modules/clustering.html#davies-bouldin-index>`_.

    The Davies-Bouldin score is non-negative,
    and values closer to zero indicate a better clustering result (i.e., clusters are well-separated).

    The computation of this metric is implemented with
    `sklearn.metrics.davies_bouldin_score
    <https://scikit-learn.org/1.5/modules/generated/sklearn.metrics.davies_bouldin_score.html>`_.

    - ``update`` must receive output of the form ``(features, labels)``
      or ``{'features': features, 'labels': labels}``.
    - `features` and `labels` must be of shape `(B, D)` and `(B,)` respectively.

    Parameters are inherited from ``EpochMetric.__init__``.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(features, labels)``
            or ``{'features': features, 'labels': labels}``.
        check_compute_fn: if True, ``compute_fn`` is run on the first batch of data to ensure there are no
            issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``.
            Default, True.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        skip_unrolling: specifies whether output should be unrolled before being fed to the update method.
            Should be true for multi-output models, for example, if ``y_pred`` contains multi-output as
            ``(y_pred_a, y_pred_b)``. Alternatively, ``output_transform`` can be used to handle this.

    Examples:
        To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
        The output of the engine's ``process_function`` needs to be in the format of
        ``(features, labels)`` or ``{'features': features, 'labels': labels, ...}``.

        .. include:: defaults.rst
            :start-after: :orphan:

        .. testcode::

            metric = DaviesBouldinScore()
            metric.attach(default_evaluator, "davies_bouldin_score")
            X = torch.tensor([
                [-1.04, -0.71, -1.42, -0.28, -0.43],
                [0.47, 0.96, -0.43, 1.57, -2.24],
                [-0.62, -0.29, 0.10, -0.72, -1.69],
                [0.96, -0.77, 0.60, -0.89, 0.49],
                [-1.33, -1.53, 0.25, -1.60, -2.0],
                [-0.63, -0.55, -1.03, -0.89, -0.77],
                [-0.26, -1.67, -0.24, -1.33, -0.40],
                [-0.20, -1.34, -0.52, -1.55, -1.50],
                [2.68, 1.13, 2.51, 0.80, 0.92],
                [0.33, 2.88, 1.35, -0.56, 1.71]
            ])
            Y = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2])
            state = default_evaluator.run([{"features": X, "labels": Y}])
            print(state.metrics["davies_bouldin_score"])

        .. testoutput::

            1.3838673743829881

    .. versionadded:: 0.5.2
    """

    def __init__(
        self,
        output_transform: Callable[..., Any] = lambda x: x,
        check_compute_fn: bool = True,
        device: Union[str, torch.device] = torch.device("cpu"),
        skip_unrolling: bool = False,
    ) -> None:
        try:
            from sklearn.metrics import davies_bouldin_score  # noqa: F401
        except ImportError:
            raise ModuleNotFoundError("This module requires scikit-learn to be installed.")

        super().__init__(_davies_bouldin_score, output_transform, check_compute_fn, device, skip_unrolling)
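
Since all three new metrics share the ("features", "labels") output keys, they can be attached to the same evaluation engine. A minimal sketch, assuming an identity process function like the default_evaluator used in the docstring examples, with random data invented for illustration:

# Minimal sketch: attaching all three clustering metrics to one evaluator.
import torch

from ignite.engine import Engine
from ignite.metrics.clustering import CalinskiHarabaszScore, DaviesBouldinScore, SilhouetteScore

# The process function simply forwards each batch, which already holds "features" and "labels".
evaluator = Engine(lambda engine, batch: batch)

CalinskiHarabaszScore().attach(evaluator, "calinski_harabasz")
DaviesBouldinScore().attach(evaluator, "davies_bouldin")
SilhouetteScore().attach(evaluator, "silhouette")

features = torch.randn(32, 5)        # (B, D) feature vectors
labels = torch.randint(0, 3, (32,))  # (B,) cluster assignments
state = evaluator.run([{"features": features, "labels": labels}])
print(state.metrics)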