diff --git a/tests/ignite/metrics/test_hsic.py b/tests/ignite/metrics/test_hsic.py index 0fcb24f35c9..f2d99d970f2 100644 --- a/tests/ignite/metrics/test_hsic.py +++ b/tests/ignite/metrics/test_hsic.py @@ -102,10 +102,10 @@ def test_compute(n_times, sigma_x: float, sigma_y: float, test_case: Tuple[Tenso hsic.update((x_batch, y_batch)) np_hsic_sum += np_hsic(x_batch, y_batch, sigma_x, sigma_y) - np_res = np_hsic_sum / n_iters + expected_hsic = np_hsic_sum / n_iters assert isinstance(hsic.compute(), float) - assert pytest.approx(np_res, abs=2e-5) == hsic.compute() + assert pytest.approx(expected_hsic, abs=2e-5) == hsic.compute() def test_accumulator_detached(): @@ -120,50 +120,54 @@ def test_accumulator_detached(): @pytest.mark.usefixtures("distributed") class TestDistributed: - @pytest.mark.parametrize("sigma_x", [-1.0, 1.0]) - @pytest.mark.parametrize("sigma_y", [-1.0, 1.0]) - def test_integration(self, sigma_x: float, sigma_y: float): - tol = 2e-5 - n_iters = 100 - batch_size = 20 - n_dims_x = 100 - n_dims_y = 50 - - rank = idist.get_rank() - torch.manual_seed(12 + rank) + @pytest.mark.parametrize("sigma_x", [-1.0, 1.0]) + @pytest.mark.parametrize("sigma_y", [-1.0, 1.0]) + def test_integration(self, sigma_x: float, sigma_y: float): + tol = 2e-5 + n_iters = 100 + batch_size = 20 + n_dims_x = 100 + n_dims_y = 50 - device = idist.device() - metric_devices = [torch.device("cpu")] - if device.type != "xla": - metric_devices.append(device) + rank = idist.get_rank() + torch.manual_seed(12 + rank) - for metric_device in metric_devices: - x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device) + device = idist.device() + metric_devices = [torch.device("cpu")] + if device.type != "xla": + metric_devices.append(device) + + for metric_device in metric_devices: + x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device) + + lin = nn.Linear(n_dims_x, n_dims_y).to(device) + y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4 - lin = nn.Linear(n_dims_x, n_dims_y).to(device) - y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4 + def data_loader(i, input_x, input_y): + return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size] - def data_loader(i): - return x[i * batch_size : (i + 1) * batch_size], y[i * batch_size : (i + 1) * batch_size] + engine = Engine(lambda e, i: data_loader(i, x, y)) - engine = Engine(lambda e, i: data_loader(i)) + m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device) + m.attach(engine, "hsic") - m = HSIC(sigma_x=sigma_x, sigma_y=sigma_y, device=metric_device) - m.attach(engine, "hsic") + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + assert "hsic" in engine.state.metrics + res = engine.state.metrics["hsic"] - assert "hsic" in engine.state.metrics - res = engine.state.metrics["hsic"] + x = idist.all_gather(x) + y = idist.all_gather(y) + total_n_iters = idist.all_reduce(n_iters) - np_res = 0.0 - for i in range(n_iters): - x_batch, y_batch = data_loader(i) - np_res += np_hsic(x_batch, y_batch) - np_res = np_res / n_iters + np_res = 0.0 + for i in range(total_n_iters): + x_batch, y_batch = data_loader(i, x, y) + np_res += np_hsic(x_batch, y_batch, sigma_x, sigma_y) - assert pytest.approx(np_res, abs=tol) == res + expected_hsic = np_res / total_n_iters + assert pytest.approx(expected_hsic, abs=tol) == res def test_accumulator_device(self): device = idist.device()