Add Entropy metric (#3210)
* add Entropy metric

* fix error in torch.randint

* update Entropy metric to support other shapes

* Update ignite/metrics/entropy.py

Co-authored-by: vfdev <[email protected]>

* update test of Entropy metric

* Update ignite/metrics/entropy.py

* format code

* fix error in converting Tensor to ndarray

---------

Co-authored-by: vfdev <[email protected]>
kzkadc and vfdev-5 authored Mar 22, 2024
1 parent 1e7d336 commit c3845ba
Showing 4 changed files with 300 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -351,6 +351,7 @@ Complete list of metrics
InceptionScore
FID
CosineSimilarity
Entropy

Helpers for customizing metrics
-------------------------------
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -3,6 +3,7 @@
from ignite.metrics.classification_report import ClassificationReport
from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, JaccardIndex, mIoU
from ignite.metrics.cosine_similarity import CosineSimilarity
from ignite.metrics.entropy import Entropy
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.fbeta import Fbeta
from ignite.metrics.frequency import Frequency
@@ -39,6 +40,7 @@
"TopKCategoricalAccuracy",
"Average",
"DiceCoefficient",
"Entropy",
"EpochMetric",
"Fbeta",
"FID",
91 changes: 91 additions & 0 deletions ignite/metrics/entropy.py
@@ -0,0 +1,91 @@
from typing import Sequence

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Entropy"]


class Entropy(Metric):
r"""Calculates the mean of `entropy <https://en.wikipedia.org/wiki/Entropy_(information_theory)>`_.
.. math:: H = \frac{1}{N} \sum_{i=1}^N \sum_{c=1}^C -p_{i,c} \log p_{i,c},
\quad p_{i,c} = \frac{\exp(z_{i,c})}{\sum_{c'=1}^C \exp(z_{i,c'})}
where :math:`p_{i,c}` is the prediction probability of :math:`i`-th data belonging to the class :math:`c`.
- ``update`` must receive output of the form ``(y_pred, y)`` while ``y`` is not used in this metric.
- ``y_pred`` is expected to be the unnormalized logits for each class. :math:`(B, C)` (classification)
or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed.
Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
to the metric to transform the output into the form expected by the metric.
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
metric = Entropy()
metric.attach(default_evaluator, 'entropy')
y_true = torch.tensor([0, 1, 2]) # not considered in the Entropy metric.
y_pred = torch.tensor([
[ 0.0000, 0.6931, 1.0986],
[ 1.3863, 1.6094, 1.6094],
[ 0.0000, -2.3026, -2.3026]
])
state = default_evaluator.run([[y_pred, y_true]])
print(state.metrics['entropy'])
.. testoutput::
0.8902875582377116
"""

_state_dict_all_req_keys = ("_sum_of_entropies", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_entropies = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
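        # Work on a detached copy so accumulation does not track gradients.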
y_pred = output[0].detach()
if y_pred.ndim >= 3:
num_classes = y_pred.shape[1]
# (B, C, ...) -> (B, ..., C) -> (B*..., C)
# regarding as B*... predictions
y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes)
elif y_pred.ndim == 1:
raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.")

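        # log_softmax is used instead of log(softmax(...)) for numerical stability:
        # a probability that underflows to zero would otherwise produce
        # log(0) = -inf and 0 * -inf = nan in the entropy sum.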
prob = F.softmax(y_pred, dim=1)
log_prob = F.log_softmax(y_pred, dim=1)
entropy_sum = -torch.sum(prob * log_prob)
self._sum_of_entropies += entropy_sum.to(self._device)
self._num_examples += y_pred.shape[0]

@sync_all_reduce("_sum_of_entropies", "_num_examples")
def compute(self) -> float:
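        # In distributed runs, the sync_all_reduce decorator above sums both
        # accumulators across processes before this body executes.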
if self._num_examples == 0:
raise NotComputableError("Entropy must have at least one example before it can be computed.")
return self._sum_of_entropies.item() / self._num_examples
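
For reference, a minimal standalone sketch of using the metric outside of an ``Engine``, cross-checked against scipy in the same way as the ``np_entropy`` helper in the tests below. The batch shape and tolerance here are illustrative, not taken from the commit:

import torch
from scipy.special import softmax
from scipy.stats import entropy as scipy_entropy

from ignite.metrics import Entropy

metric = Entropy()
metric.reset()

# logits for a batch of 4 samples and 3 classes; y is ignored by Entropy
y_pred = torch.randn(4, 3)
y = torch.zeros(4)
metric.update((y_pred, y))

# reference value: softmax over classes, then mean Shannon entropy
prob = softmax(y_pred.numpy(), axis=1)
expected = scipy_entropy(prob, axis=1).mean()
assert abs(metric.compute() - expected) < 1e-5
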
206 changes: 206 additions & 0 deletions tests/ignite/metrics/test_entropy.py
@@ -0,0 +1,206 @@
import os

import numpy as np
import pytest
import torch
from scipy.special import softmax
from scipy.stats import entropy as scipy_entropy

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import Entropy


def np_entropy(np_y_pred: np.ndarray):
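    # Reference implementation: softmax over the class axis, then the mean
    # of the per-sample Shannon entropies computed by scipy.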
prob = softmax(np_y_pred, axis=1)
ent = np.mean(scipy_entropy(prob, axis=1))
return ent


def test_zero_sample():
ent = Entropy()
with pytest.raises(NotComputableError, match=r"Entropy must have at least one example before it can be computed"):
ent.compute()


def test_invalid_shape():
ent = Entropy()
y_pred = torch.randn(10).float()
with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"):
ent.update((y_pred, None))


@pytest.fixture(params=list(range(6)))
def test_case(request):
return [
(torch.randn((100, 10)), torch.randint(0, 10, size=[100]), 1),
(torch.rand((100, 500)), torch.randint(0, 500, size=[100]), 1),
# updated batches
(torch.normal(0.0, 5.0, size=(100, 10)), torch.randint(0, 10, size=[100]), 16),
(torch.normal(5.0, 3.0, size=(100, 200)), torch.randint(0, 200, size=[100]), 16),
# image segmentation
(torch.randn((100, 5, 32, 32)), torch.randint(0, 5, size=(100, 32, 32)), 16),
(torch.randn((100, 5, 224, 224)), torch.randint(0, 5, size=(100, 224, 224)), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case):
ent = Entropy()

y_pred, y, batch_size = test_case

ent.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ent.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ent.update((y_pred, y))

np_res = np_entropy(y_pred.numpy())

assert isinstance(ent.compute(), float)
assert pytest.approx(ent.compute()) == np_res


def _test_distrib_integration(device, tol=1e-6):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12 + rank)

def _test(metric_device):
n_iters = 100
batch_size = 10
n_cls = 50

y_true = torch.randint(0, n_cls, size=[n_iters * batch_size], dtype=torch.long).to(device)
y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_cls), dtype=torch.float).to(device)

def update(engine, i):
return (
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

m = Entropy(device=metric_device)
m.attach(engine, "entropy")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "entropy" in engine.state.metrics
res = engine.state.metrics["entropy"]

true_res = np_entropy(y_preds.cpu().numpy())

assert pytest.approx(res, rel=tol) == true_res

_test("cpu")
if device.type != "xla":
_test(idist.device())


def _test_distrib_accumulator_device(device):
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
device = torch.device(device)
ent = Entropy(device=metric_device)

for dev in [ent._device, ent._sum_of_entropies.device]:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

y_pred = torch.tensor([[2.0], [-2.0]])
y = torch.zeros(2)
ent.update((y_pred, y))

for dev in [ent._device, ent._sum_of_entropies.device]:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"


def test_accumulator_detached():
ent = Entropy()

y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
y = torch.zeros(2)
ent.update((y_pred, y))

assert not ent._sum_of_entropies.requires_grad


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_integration(device, tol=1e-4)
_test_distrib_accumulator_device(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_integration(device, tol=1e-4)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
