add HSIC metric #3282

Merged
merged 21 commits on Sep 27, 2024
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -357,6 +357,7 @@ Complete list of metrics
KLDivergence
JSDivergence
MaximumMeanDiscrepancy
HSIC
AveragePrecision
CohenKappa
GpuInfo
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -14,6 +14,7 @@
from ignite.metrics.gan.fid import FID
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.gpu_info import GpuInfo
from ignite.metrics.hsic import HSIC
from ignite.metrics.js_divergence import JSDivergence
from ignite.metrics.kl_divergence import KLDivergence
from ignite.metrics.loss import Loss
@@ -64,6 +65,7 @@
"JaccardIndex",
"JSDivergence",
"KLDivergence",
"HSIC",
"MaximumMeanDiscrepancy",
"MultiLabelConfusionMatrix",
"MutualInformation",
170 changes: 170 additions & 0 deletions ignite/metrics/hsic.py
@@ -0,0 +1,170 @@
from typing import Callable, Sequence, Union

import torch
from torch import Tensor

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["HSIC"]


class HSIC(Metric):
r"""Calculates the `Hilbert-Schmidt Independence Criterion (HSIC)
<https://papers.nips.cc/paper_files/paper/2007/hash/d5cfead94f5350c12c322b5b664544c1-Abstract.html>`_.

.. math::
\text{HSIC}(X,Y) = \frac{1}{B(B-3)}\left[ \text{tr}(\tilde{\mathbf{K}}\tilde{\mathbf{L}})
+ \frac{\mathbf{1}^\top \tilde{\mathbf{K}} \mathbf{11}^\top \tilde{\mathbf{L}} \mathbf{1}}{(B-1)(B-2)}
-\frac{2}{B-2}\mathbf{1}^\top \tilde{\mathbf{K}}\tilde{\mathbf{L}} \mathbf{1} \right]

where :math:`B` is the batch size, and :math:`\tilde{\mathbf{K}}`
and :math:`\tilde{\mathbf{L}}` are the Gram matrices of
the Gaussian RBF kernel with their diagonal entries being set to zero.

HSIC measures non-linear statistical independence between features :math:`X` and :math:`Y`.
HSIC becomes zero if and only if :math:`X` and :math:`Y` are independent.

This metric computes the unbiased estimator of HSIC proposed in
`Song et al. (2012) <https://jmlr.csail.mit.edu/papers/v13/song12a.html>`_.
The HSIC is estimated using Eq. (5) of the paper for each batch and the average is accumulated.

Each batch must contain at least four samples.

- ``update`` must receive output of the form ``(y_pred, y)``.

Args:
sigma_x: bandwidth of the kernel for :math:`X`.
If negative, a heuristic value determined by the median of the squared pairwise
distances between the samples is used. Default: -1
sigma_y: bandwidth of the kernel for :math:`Y`.
If negative, a heuristic value determined by the median of the squared pairwise
distances between the samples is used. Default: -1
ignore_invalid_batch: If ``True``, computation for a batch with fewer than four samples is skipped.
If ``False``, ``ValueError`` is raised when such a batch is received.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether output should be unrolled before being fed to the update method. Should be
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
Alternatively, ``output_transform`` can be used to handle this.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

``y_pred`` and ``y`` should contain the same number of samples; their feature shapes may differ.

For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = HSIC()
metric.attach(default_evaluator, "hsic")
X = torch.tensor([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.],
[10., 11., 12., 13., 14.],
[15., 16., 17., 18., 19.],
[20., 21., 22., 23., 24.],
[25., 26., 27., 28., 29.],
[30., 31., 32., 33., 34.],
[35., 36., 37., 38., 39.],
[40., 41., 42., 43., 44.],
[45., 46., 47., 48., 49.]])
Y = torch.sin(X * torch.pi * 2 / 50)
state = default_evaluator.run([[X, Y]])
print(state.metrics["hsic"])

.. testoutput::

0.09226646274328232

.. versionadded:: 0.5.2
"""

def __init__(
self,
sigma_x: float = -1,
sigma_y: float = -1,
ignore_invalid_batch: bool = True,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
):
super().__init__(output_transform, device, skip_unrolling=skip_unrolling)

self.sigma_x = sigma_x
self.sigma_y = sigma_y
self.ignore_invalid_batch = ignore_invalid_batch

_state_dict_all_req_keys = ("_sum_of_hsic", "_num_batches")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_hsic = torch.tensor(0.0, device=self._device)
self._num_batches = 0

@reinit__is_reduced
def update(self, output: Sequence[Tensor]) -> None:
X = output[0].detach().flatten(start_dim=1)
Y = output[1].detach().flatten(start_dim=1)
b = X.shape[0]

if b <= 3:
if self.ignore_invalid_batch:
return
else:
raise ValueError(f"A batch must contain more than four samples, got only {b} samples.")

mask = 1.0 - torch.eye(b, device=X.device)

xx = X @ X.T
rx = xx.diag().unsqueeze(0).expand_as(xx)
dxx = rx.T + rx - xx * 2

vx: Union[Tensor, float]
if self.sigma_x < 0:
vx = torch.quantile(dxx, 0.5)
else:
vx = self.sigma_x**2
K = torch.exp(-0.5 * dxx / vx) * mask

yy = Y @ Y.T
ry = yy.diag().unsqueeze(0).expand_as(yy)
dyy = ry.T + ry - yy * 2

vy: Union[Tensor, float]
if self.sigma_y < 0:
vy = torch.quantile(dyy, 0.5)
else:
vy = self.sigma_y**2
L = torch.exp(-0.5 * dyy / vy) * mask

KL = K @ L
trace = KL.trace()
second_term = K.sum() * L.sum() / ((b - 1) * (b - 2))
third_term = KL.sum() / (b - 2)

hsic = trace + second_term - third_term * 2.0
hsic /= b * (b - 3)
hsic = torch.clamp(hsic, min=0.0)  # HSIC is non-negative, but the unbiased estimator can be slightly negative
self._sum_of_hsic += hsic.to(self._device)

self._num_batches += 1

@sync_all_reduce("_sum_of_hsic", "_num_batches")
def compute(self) -> float:
if self._num_batches == 0:
raise NotComputableError("HSIC must have at least one batch before it can be computed.")

return self._sum_of_hsic.item() / self._num_batches
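
For reference, the per-batch estimator that ``update`` accumulates can be written as a small standalone function. The sketch below is illustrative and not part of the ignite API; it assumes the default median heuristic for both bandwidths, and unlike the metric it does not clamp negative values, average over batches, or synchronise across processes.

```python
# Minimal single-batch sketch of the unbiased HSIC estimator (Eq. (5) of Song et al., 2012).
# The helper name `hsic_single_batch` is hypothetical, used only for this illustration.
import torch


def hsic_single_batch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = x.flatten(start_dim=1)
    y = y.flatten(start_dim=1)
    b = x.shape[0]
    assert b > 3, "the unbiased estimator needs at least four samples"

    def rbf_gram(z: torch.Tensor) -> torch.Tensor:
        # squared pairwise distances via ||z_i||^2 + ||z_j||^2 - 2 z_i.z_j
        zz = z @ z.T
        rz = zz.diag().unsqueeze(0).expand_as(zz)
        d = rz.T + rz - 2.0 * zz
        v = torch.quantile(d, 0.5)  # median heuristic, as when sigma < 0 in the metric
        k = torch.exp(-0.5 * d / v)
        return k * (1.0 - torch.eye(b, device=z.device))  # zero the diagonal

    K, L = rbf_gram(x), rbf_gram(y)
    KL = K @ L
    return (
        KL.trace()
        + K.sum() * L.sum() / ((b - 1) * (b - 2))
        - 2.0 * KL.sum() / (b - 2)
    ) / (b * (b - 3))


# quick check on one batch of dependent data
torch.manual_seed(0)
X = torch.randn(32, 8)
Y = torch.sin(X @ torch.randn(8, 4))
print(hsic_single_batch(X, Y))
```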
185 changes: 185 additions & 0 deletions tests/ignite/metrics/test_hsic.py
@@ -0,0 +1,185 @@
from typing import Tuple

import numpy as np
import pytest

import torch
from torch import nn, Tensor

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import HSIC


def np_hsic(x: Tensor, y: Tensor, sigma_x: float = -1, sigma_y: float = -1) -> float:
x_np = x.detach().cpu().numpy()
y_np = y.detach().cpu().numpy()
b = x_np.shape[0]

ii, jj = np.meshgrid(np.arange(b), np.arange(b), indexing="ij")
mask = 1.0 - np.eye(b)

dxx = np.square(x_np[ii] - x_np[jj]).sum(axis=2)
if sigma_x < 0:
vx = np.median(dxx)
else:
vx = sigma_x * sigma_x
K = np.exp(-0.5 * dxx / vx) * mask

dyy = np.square(y_np[ii] - y_np[jj]).sum(axis=2)
if sigma_y < 0:
vy = np.median(dyy)
else:
vy = sigma_y * sigma_y
L = np.exp(-0.5 * dyy / vy) * mask

KL = K @ L
ones = np.ones(b)
hsic = np.trace(KL) + (ones @ K @ ones) * (ones @ L @ ones) / ((b - 1) * (b - 2)) - ones @ KL @ ones * 2 / (b - 2)
hsic /= b * (b - 3)
hsic = np.clip(hsic, 0.0, None)
return hsic


def test_zero_batch():
hsic = HSIC()
with pytest.raises(NotComputableError, match=r"HSIC must have at least one batch before it can be computed"):
hsic.compute()


def test_invalid_batch():
hsic = HSIC(ignore_invalid_batch=False)
X = torch.tensor([[1, 2, 3]]).float()
Y = torch.tensor([[4, 5, 6]]).float()
with pytest.raises(ValueError, match=r"A batch must contain at least four samples, got only"):
hsic.update((X, Y))


@pytest.fixture(params=[0, 1, 2])
def test_case(request) -> Tuple[Tensor, Tensor, int]:
if request.param == 0:
# independent
N = 100
b = 10
x, y = torch.randn((N, 50)), torch.randn((N, 30))
elif request.param == 1:
# linearly dependent
N = 100
b = 10
x = torch.normal(1.0, 2.0, size=(N, 10))
y = x @ torch.rand(10, 15) * 3 + torch.randn(N, 15) * 1e-4
else:
# non-linearly dependent
N = 200
b = 20
x = torch.randn(N, 5)
y = x @ torch.normal(0.0, torch.pi, size=(5, 3))
y = (
torch.stack([torch.sin(y[:, 0]), torch.cos(y[:, 1]), torch.exp(y[:, 2])], dim=1)
+ torch.randn_like(y) * 1e-4
)

return x, y, b


@pytest.mark.parametrize("n_times", range(3))
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tensor, Tensor, int]):
x, y, batch_size = test_case

hsic = HSIC(sigma_x=sigma_x, sigma_y=sigma_y)

hsic.reset()

np_hsic_sum = 0.0
n_iters = y.shape[0] // batch_size
for i in range(n_iters):
idx = i * batch_size
x_batch = x[idx : idx + batch_size]
y_batch = y[idx : idx + batch_size]

hsic.update((x_batch, y_batch))
np_hsic_sum += np_hsic(x_batch, y_batch, sigma_x, sigma_y)
np_res = np_hsic_sum / n_iters

assert isinstance(hsic.compute(), float)
assert pytest.approx(np_res, abs=2e-5) == hsic.compute()


def test_accumulator_detached():
hsic = HSIC()

x = torch.rand(10, 10, dtype=torch.float)
y = torch.rand(10, 10, dtype=torch.float)
hsic.update((x, y))

assert not hsic._sum_of_hsic.requires_grad


@pytest.mark.usefixtures("distributed")
class TestDistributed:
@pytest.mark.parametrize("sigma_x", [-1.0, 1.0])
@pytest.mark.parametrize("sigma_y", [-1.0, 1.0])
def test_integration(self, sigma_x: float, sigma_y: float):
tol = 2e-5
n_iters = 100
batch_size = 20
n_dims_x = 100
n_dims_y = 50

rank = idist.get_rank()
torch.manual_seed(12 + rank)

device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)

lin = nn.Linear(n_dims_x, n_dims_y)
for metric_device in metric_devices:
x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)

lin.to(device)
y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y, device=device) * 1e-4

def data_loader(i):
return x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size]

engine = Engine(lambda e, i: data_loader(i))

m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device)
m.attach(engine, "hsic")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

assert "hsic" in engine.state.metrics
res = engine.state.metrics["hsic"]

np_res = 0.0
for i in range(n_iters):
x_batch, y_batch = data_loader(i)
np_res += np_hsic(x_batch, y_batch)
np_res = np_res / n_iters

assert pytest.approx(np_res, abs=tol) == res

def test_accumulator_device(self):
device = idist.device()
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(device)
for metric_device in metric_devices:
hsic = HSIC(device=metric_device)

for dev in (hsic._device, hsic._sum_of_hsic.device):
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

x = torch.zeros(10, 10).float()
y = torch.ones(10, 10).float()
hsic.update((x, y))

for dev in (hsic._device, hsic._sum_of_hsic.device):
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
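
As a quick sanity check outside the test suite, the metric can also be driven directly through ``reset`` / ``update`` / ``compute``. The sketch below mirrors the independent and non-linearly dependent cases exercised by the fixtures above; the exact tensors, seed, and printed values are illustrative, not taken from the tests.

```python
# HSIC should stay near zero for independent inputs and be clearly larger
# for non-linearly dependent ones.
import torch

from ignite.metrics import HSIC

torch.manual_seed(0)
metric = HSIC()

# independent features: expect a value close to zero
metric.reset()
metric.update((torch.randn(64, 20), torch.randn(64, 10)))
print(f"independent: {metric.compute():.5f}")

# non-linearly dependent features: expect a clearly larger value
metric.reset()
x = torch.randn(64, 20)
y = torch.sin(x @ torch.randn(20, 10))
metric.update((x, y))
print(f"dependent:   {metric.compute():.5f}")
```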