Skip to content

Commit

Permalink
add HSIC metric (#3282)
Browse files Browse the repository at this point in the history
* add HSIC metric

* minor update on docstring

* add reference to the HSIC formula in docstring

* update version directive

* fix formatting issue

* add type hints

* accumulate HSIC value for each batch

* update test to clip value for each batch

* fix accumulator device error

* fix error in making y

* fix test to use the same linear layer across metric_devices

* Revert "fix test to use the same linear layer across metric_devices"

This reverts commit cb71355.

* Fixed distributed tests

* Fixed code formatting errors

---------

Co-authored-by: vfdev <[email protected]>
  • Loading branch information
kzkadc and vfdev-5 authored Sep 27, 2024
1 parent 9481227 commit 8e53b76
Show file tree
Hide file tree
Showing 4 changed files with 361 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ Complete list of metrics
KLDivergence
JSDivergence
MaximumMeanDiscrepancy
HSIC
AveragePrecision
CohenKappa
GpuInfo
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ignite.metrics.gan.fid import FID
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.gpu_info import GpuInfo
from ignite.metrics.hsic import HSIC
from ignite.metrics.js_divergence import JSDivergence
from ignite.metrics.kl_divergence import KLDivergence
from ignite.metrics.loss import Loss
Expand Down Expand Up @@ -64,6 +65,7 @@
"JaccardIndex",
"JSDivergence",
"KLDivergence",
"HSIC",
"MaximumMeanDiscrepancy",
"MultiLabelConfusionMatrix",
"MutualInformation",
Expand Down
170 changes: 170 additions & 0 deletions ignite/metrics/hsic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import Callable, Sequence, Union

import torch
from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["HSIC"]


class HSIC(Metric):
r"""Calculates the `Hilbert-Schmidt Independence Criterion (HSIC)
<https://papers.nips.cc/paper_files/paper/2007/hash/d5cfead94f5350c12c322b5b664544c1-Abstract.html>`_.
.. math::
\text{HSIC}(X,Y) = \frac{1}{B(B-3)}\left[ \text{tr}(\tilde{\mathbf{K}}\tilde{\mathbf{L}})
+ \frac{\mathbf{1}^\top \tilde{\mathbf{K}} \mathbf{11}^\top \tilde{\mathbf{L}} \mathbf{1}}{(B-1)(B-2)}
-\frac{2}{B-2}\mathbf{1}^\top \tilde{\mathbf{K}}\tilde{\mathbf{L}} \mathbf{1} \right]
where :math:`B` is the batch size, and :math:`\tilde{\mathbf{K}}`
and :math:`\tilde{\mathbf{L}}` are the Gram matrices of
the Gaussian RBF kernel with their diagonal entries being set to zero.
HSIC measures non-linear statistical independence between features :math:`X` and :math:`Y`.
HSIC becomes zero if and only if :math:`X` and :math:`Y` are independent.
This metric computes the unbiased estimator of HSIC proposed in
`Song et al. (2012) <https://jmlr.csail.mit.edu/papers/v13/song12a.html>`_.
The HSIC is estimated using Eq. (5) of the paper for each batch and the average is accumulated.
Each batch must contain at least four samples.
- ``update`` must receive output of the form ``(y_pred, y)``.
Args:
sigma_x: bandwidth of the kernel for :math:`X`.
If negative, a heuristic value determined by the median of the distances between
the samples is used. Default: -1
sigma_y: bandwidth of the kernel for :math:`Y`.
If negative, a heuristic value determined by the median of the distances
between the samples is used. Default: -1
ignore_invalid_batch: If ``True``, computation for a batch with less than four samples is skipped.
If ``False``, ``ValueError`` is raised when received such a batch.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)``
Alternatively, ``output_transform`` can be used to handle this.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
to the metric to transform the output into the form expected by the metric.
``y_pred`` and ``y`` should have the same shape.
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
metric = HSIC()
metric.attach(default_evaluator, "hsic")
X = torch.tensor([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.],
[25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34.],
[35., 36., 37., 38., 39.],
[40., 41., 42., 43., 44.],
[45., 46., 47., 48., 49.]])
Y = torch.sin(X * torch.pi * 2 / 50)
state = default_evaluator.run([[X, Y]])
print(state.metrics["hsic"])
.. testoutput::
0.09226646274328232
.. versionadded:: 0.5.2
"""

def __init__(
self,
sigma_x: float = -1,
sigma_y: float = -1,
ignore_invalid_batch: bool = True,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
):
super().__init__(output_transform, device, skip_unrolling=skip_unrolling)

self.sigma_x = sigma_x
self.sigma_y = sigma_y
self.ignore_invalid_batch = ignore_invalid_batch

_state_dict_all_req_keys = ("_sum_of_hsic", "_num_batches")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_hsic = torch.tensor(0.0, device=self._device)
self._num_batches = 0

@reinit__is_reduced
def update(self, output: Sequence[Tensor]) -> None:
X = output[0].detach().flatten(start_dim=1)
Y = output[1].detach().flatten(start_dim=1)
b = X.shape[0]

if b <= 3:
if self.ignore_invalid_batch:
return
else:
raise ValueError(f"A batch must contain more than four samples, got only {b} samples.")

mask = 1.0 - torch.eye(b, device=X.device)

xx = X @ X.T
rx = xx.diag().unsqueeze(0).expand_as(xx)
dxx = rx.T + rx - xx * 2

vx: Union[Tensor, float]
if self.sigma_x < 0:
vx = torch.quantile(dxx, 0.5)
else:
vx = self.sigma_x**2
K = torch.exp(-0.5 * dxx / vx) * mask

yy = Y @ Y.T
ry = yy.diag().unsqueeze(0).expand_as(yy)
dyy = ry.T + ry - yy * 2

vy: Union[Tensor, float]
if self.sigma_y < 0:
vy = torch.quantile(dyy, 0.5)
else:
vy = self.sigma_y**2
L = torch.exp(-0.5 * dyy / vy) * mask

KL = K @ L
trace = KL.trace()
second_term = K.sum() * L.sum() / ((b - 1) * (b - 2))
third_term = KL.sum() / (b - 2)

hsic = trace + second_term - third_term * 2.0
hsic /= b * (b - 3)
hsic = torch.clamp(hsic, min=0.0) # HSIC must not be negative
self._sum_of_hsic += hsic.to(self._device)

self._num_batches += 1

@sync_all_reduce("_sum_of_hsic", "_num_batches")
def compute(self) -> float:
if self._num_batches == 0:
raise NotComputableError("HSIC must have at least one batch before it can be computed.")

return self._sum_of_hsic.item() / self._num_batches
188 changes: 188 additions & 0 deletions tests/ignite/metrics/test_hsic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from typing import Tuple

import numpy as np
import pytest

import torch
from torch import nn, Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import HSIC


def np_hsic(x: Tensor, y: Tensor, sigma_x: float = -1, sigma_y: float = -1) -> float:
x_np = x.detach().cpu().numpy()
y_np = y.detach().cpu().numpy()
b = x_np.shape[0]

ii, jj = np.meshgrid(np.arange(b), np.arange(b), indexing="ij")
mask = 1.0 - np.eye(b)

dxx = np.square(x_np[ii] - x_np[jj]).sum(axis=2)
if sigma_x < 0:
vx = np.median(dxx)
else:
vx = sigma_x * sigma_x
K = np.exp(-0.5 * dxx / vx) * mask

dyy = np.square(y_np[ii] - y_np[jj]).sum(axis=2)
if sigma_y < 0:
vy = np.median(dyy)
else:
vy = sigma_y * sigma_y
L = np.exp(-0.5 * dyy / vy) * mask

KL = K @ L
ones = np.ones(b)
hsic = np.trace(KL) + (ones @ K @ ones) * (ones @ L @ ones) / ((b - 1) * (b - 2)) - ones @ KL @ ones * 2 / (b - 2)
hsic /= b * (b - 3)
hsic = np.clip(hsic, 0.0, None)
return hsic


def test_zero_batch():
hsic = HSIC()
with pytest.raises(NotComputableError, match=r"HSIC must have at least one batch before it can be computed"):
hsic.compute()


def test_invalid_batch():
hsic = HSIC(ignore_invalid_batch=False)
X = torch.tensor([[1, 2, 3]]).float()
Y = torch.tensor([[4, 5, 6]]).float()
with pytest.raises(ValueError, match=r"A batch must contain more than four samples, got only"):
hsic.update((X, Y))


@pytest.fixture(params=[0, 1, 2])
def test_case(request) -> Tuple[Tensor, Tensor, int]:
if request.param == 0:
# independent
N = 100
b = 10
x, y = torch.randn((N, 50)), torch.randn((N, 30))
elif request.param == 1:
# linearly dependent
N = 100
b = 10
x = torch.normal(1.0, 2.0, size=(N, 10))
y = x @ torch.rand(10, 15) * 3 + torch.randn(N, 15) * 1e-4
else:
# non-linearly dependent
N = 200
b = 20
x = torch.randn(N, 5)
y = x @ torch.normal(0.0, torch.pi, size=(5, 3))
y = (
torch.stack([torch.sin(y[:, 0]), torch.cos(y[:, 1]), torch.exp(y[:, 2])], dim=1)
+ torch.randn_like(y) * 1e-4
)

return x, y, b


@pytest.mark.parametrize("n_times", range(3))
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int]):
x, y, batch_size = test_case

hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y)

hsic.reset()

np_hsic_sum = 0.0
n_iters = y.shape[0] // batch_size
for i in range(n_iters):
idx = i * batch_size
x_batch = x[idx : idx + batch_size]
y_batch = y[idx : idx + batch_size]

hsic.update((x_batch, y_batch))
np_hsic_sum += np_hsic(x_batch, y_batch, sigma_x, sigma_y)
expected_hsic = np_hsic_sum / n_iters

assert isinstance(hsic.compute(), float)
assert pytest.approx(expected_hsic, abs=2e-5) == hsic.compute()


def test_accumulator_detached():
hsic = HSIC()

x = torch.rand(10, 10, dtype=torch.float)
y = torch.rand(10, 10, dtype=torch.float)
hsic.update((x, y))

assert not hsic._sum_of_hsic.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
def test_integration(self, sigma_x: float, sigma_y: float):
tol = 2e-5
n_iters = 100
batch_size = 20
n_dims_x = 100
n_dims_y = 50

rank = idist.get_rank()
torch.manual_seed(12 + rank)

device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)

for metric_device in metric_devices:
x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)

lin = nn.Linear(n_dims_x, n_dims_y).to(device)
y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4

def data_loader(i, input_x, input_y):
return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size]

engine = Engine(lambda e, i: data_loader(i, x, y))

m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device)
m.attach(engine, "hsic")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

assert "hsic" in engine.state.metrics
res = engine.state.metrics["hsic"]

x = idist.all_gather(x)
y = idist.all_gather(y)
total_n_iters = idist.all_reduce(n_iters)

np_res = 0.0
for i in range(total_n_iters):
x_batch, y_batch = data_loader(i, x, y)
np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y)

expected_hsic = np_res / total_n_iters
assert pytest.approx(expected_hsic, abs=tol) == res

def test_accumulator_device(self):
device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)
for metric_device in metric_devices:
hsic = HSIC(device=metric_device)

for dev in (hsic._device, hsic._sum_of_hsic.device):
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

x = torch.zeros(10, 10).float()
y = torch.ones(10, 10).float()
hsic.update((x, y))

for dev in (hsic._device, hsic._sum_of_hsic.device):
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

0 comments on commit 8e53b76

Please sign in to comment.