
Add Entropy metric #3210

Merged · 11 commits · Mar 22, 2024
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -351,6 +351,7 @@ Complete list of metrics
InceptionScore
FID
CosineSimilarity
Entropy

Helpers for customizing metrics
-------------------------------
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -3,6 +3,7 @@
from ignite.metrics.classification_report import ClassificationReport
from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, JaccardIndex, mIoU
from ignite.metrics.cosine_similarity import CosineSimilarity
from ignite.metrics.entropy import Entropy
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.fbeta import Fbeta
from ignite.metrics.frequency import Frequency
@@ -39,6 +40,7 @@
"TopKCategoricalAccuracy",
"Average",
"DiceCoefficient",
"Entropy",
"EpochMetric",
"Fbeta",
"FID",
82 changes: 82 additions & 0 deletions ignite/metrics/entropy.py
@@ -0,0 +1,82 @@
from typing import Sequence

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Entropy"]


class Entropy(Metric):
r"""Calculates the mean of `entropy <https://en.wikipedia.org/wiki/Entropy_(information_theory)>`_.

.. math:: H = \frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C -p_{i,c} \log p_{i,c},
\quad p_{i,c} = \frac{\exp(z_{i,c})}{\sum_{c'=1}^C \exp(z_{i,c'})}

where :math:`p_{i,c}` is the prediction probability of :math:`i`-th data belonging to the class :math:`c`.

    - ``update`` must receive output of the form ``(y_pred, y)``, though ``y`` is not used in this metric.
- ``y_pred`` is expected to be the unnormalized logits for each class.

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
        ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

        For more information on how metrics work with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = Entropy()
metric.attach(default_evaluator, 'entropy')
y_true = torch.tensor([0, 1, 2]) # not considered in the Entropy metric.
y_pred = torch.tensor([
[ 0.0000, 0.6931, 1.0986],
[ 1.3863, 1.6094, 1.6094],
[ 0.0000, -2.3026, -2.3026]
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['entropy'])

.. testoutput::

0.8902875582377116
"""

_state_dict_all_req_keys = ("_sum_of_entropies", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_entropies = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
    def update(self, output: Sequence[torch.Tensor]) -> None:
        y_pred = output[0].detach()
        # log_softmax is used alongside softmax instead of taking log(prob)
        # directly, which keeps the computation numerically stable
        prob = F.softmax(y_pred, dim=1)
        log_prob = F.log_softmax(y_pred, dim=1)
        # accumulate the sum of per-sample entropies over the batch
        entropy_sum = -torch.sum(prob * log_prob)
        self._sum_of_entropies += entropy_sum.to(self._device)
        self._num_examples += y_pred.shape[0]

@sync_all_reduce("_sum_of_entropies", "_num_examples")
def compute(self) -> float:
if self._num_examples == 0:
raise NotComputableError("Entropy must have at least one example before it can be computed.")
return self._sum_of_entropies.item() / self._num_examples
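
Beyond the ``Engine``-based example in the docstring, the metric can also be driven through its public ``reset``/``update``/``compute`` API directly (the same pattern the tests below use). A minimal standalone sketch, assuming only the behavior shown above; the logits are chosen so the softmax rows are simple fractions, and the printed value is approximate:

import torch

from ignite.metrics import Entropy

metric = Entropy()
metric.reset()

# softmax of these rows is ~[1/6, 2/6, 3/6] and ~[4/14, 5/14, 5/14]
y_pred = torch.tensor([[0.0000, 0.6931, 1.0986], [1.3863, 1.6094, 1.6094]])
y = torch.zeros(2)  # ignored by Entropy

metric.update((y_pred, y))
print(metric.compute())  # ~1.0524, the mean of the two per-sample entropies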
196 changes: 196 additions & 0 deletions tests/ignite/metrics/test_entropy.py
@@ -0,0 +1,196 @@
import os

import numpy as np
import pytest
import torch

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import Entropy


def np_entropy(np_y_pred: np.ndarray):
    # subtract the row-wise max before exponentiating, for numerical stability
    np_y_pred = np_y_pred - np_y_pred.max(axis=1, keepdims=True)
    prob = np.exp(np_y_pred) / np.sum(np.exp(np_y_pred), axis=1, keepdims=True)
    log_prob = np_y_pred - np.log(np.sum(np.exp(np_y_pred), axis=1, keepdims=True))
    # mean of the per-sample entropies
    np_ent = -np.sum(prob * log_prob) / np_y_pred.shape[0]
    return np_ent
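
As a sanity check (not part of the PR), this NumPy reference can be compared against the softmax/log_softmax route the metric itself uses. A small hedged sketch, reusing the imports at the top of this file:

import torch.nn.functional as F

logits = torch.randn(8, 5)
prob = F.softmax(logits, dim=1)
log_prob = F.log_softmax(logits, dim=1)
torch_ent = (-torch.sum(prob * log_prob) / logits.shape[0]).item()
assert abs(torch_ent - np_entropy(logits.numpy())) < 1e-6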


def test_zero_sample():
ent = Entropy()
with pytest.raises(NotComputableError, match=r"Entropy must have at least one example before it can be computed"):
ent.compute()


@pytest.fixture(params=list(range(4)))
def test_case(request):
return [
(torch.randn((100, 10)), torch.randint(0, 10, size=[100]), 1),
(torch.rand((100, 500)), torch.randint(0, 500, size=[100]), 1),
# updated batches
(torch.normal(0.0, 5.0, size=(100, 10)), torch.randint(0, 10, size=[100]), 16),
(torch.normal(5.0, 3.0, size=(100, 200)), torch.randint(0, 200, size=[100]), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case):
ent = Entropy()

y_pred, y, batch_size = test_case

ent.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ent.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ent.update((y_pred, y))

np_res = np_entropy(y_pred.numpy())

assert isinstance(ent.compute(), float)
assert pytest.approx(ent.compute()) == np_res


def _test_distrib_integration(device, tol=1e-6):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12 + rank)

def _test(metric_device):
n_iters = 100
batch_size = 10
n_cls = 50

y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device)
y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device)

def update(engine, i):
return (
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

m = Entropy(device=metric_device)
m.attach(engine, "entropy")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "entropy" in engine.state.metrics
res = engine.state.metrics["entropy"]

true_res = np_entropy(y_preds.numpy())

assert pytest.approx(res, rel=tol) == true_res

_test("cpu")
if device.type != "xla":
_test(idist.device())


def _test_distrib_accumulator_device(device):
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
device = torch.device(device)
ent = Entropy(device=metric_device)

for dev in [ent._device, ent._sum_of_entropies.device]:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

y_pred = torch.tensor([[2.0], [-2.0]])
y = torch.zeros(2)
ent.update((y_pred, y))

for dev in [ent._device, ent._sum_of_entropies.device]:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"


def test_accumulator_detached():
ent = Entropy()

y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
y = torch.zeros(2)
ent.update((y_pred, y))

assert not ent._sum_of_entropies.requires_grad


@pytest.mark.distributed
Collaborator comment:
For distributed config tests, could you please rewrite them using new testing formalism that we are trying to adopt. Here is an example of the code to inspire of:

@pytest.mark.usefixtures("distributed")
class TestDistributed:
    @pytest.mark.parametrize("average", [False, "macro", "weighted", "micro"])
    @pytest.mark.parametrize("n_epochs", [1, 2])
    def test_integration_multiclass(self, average, n_epochs):
        ...
Here is a PR showing how to pass from old code to the new one:
https://github.com/pytorch/ignite/pull/3208/files#diff-c56c264ef288f88e5738e9ad22de66dffd4c58d2e656eb62e8dbaa678672317d

Thanks!
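
For context, a hedged sketch of what that migration could look like for this metric, reusing ``np_entropy`` above and the ``distributed`` fixture from the quoted pattern; the class and method names here are illustrative, not the final PR code:

@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_integration(self):
        from ignite.engine import Engine

        rank = idist.get_rank()
        torch.manual_seed(12 + rank)
        device = idist.device()

        n_iters, batch_size, n_cls = 100, 10, 50
        y_true = torch.randint(0, n_cls, size=[n_iters * batch_size]).to(device)
        y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_cls)).to(device)

        def update(engine, i):
            return (
                y_preds[i * batch_size : (i + 1) * batch_size],
                y_true[i * batch_size : (i + 1) * batch_size],
            )

        engine = Engine(update)
        Entropy(device=device).attach(engine, "entropy")
        engine.run(data=list(range(n_iters)), max_epochs=1)

        # gather predictions from all workers before computing the reference value
        y_preds = idist.all_gather(y_preds)
        res = engine.state.metrics["entropy"]
        assert pytest.approx(res, rel=1e-6) == np_entropy(y_preds.cpu().numpy())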

@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_integration(device, tol=1e-4)
_test_distrib_accumulator_device(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_integration(device, tol=1e-4)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)