From 63a31a0821dd83c7f5769ba20bf98b31b32ea486 Mon Sep 17 00:00:00 2001 From: tiphaine Date: Fri, 22 Nov 2024 15:24:01 +0100 Subject: [PATCH 1/2] Correct FID result --- tests/metrics/image/test_fid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/image/test_fid.py b/tests/metrics/image/test_fid.py index a20d44f..cca643c 100644 --- a/tests/metrics/image/test_fid.py +++ b/tests/metrics/image/test_fid.py @@ -78,7 +78,7 @@ def test_fid_random_data_default_model(self) -> None: 299, ) self._test_fid( - imgs=imgs, feature_dim=2048, expected_result=torch.tensor(4.48304) + imgs=imgs, feature_dim=2048, expected_result=torch.tensor(4.58449) ) def test_fid_random_data_custom_model(self) -> None: From d2de82150e72072d55abcfa96f9214173bd87d77 Mon Sep 17 00:00:00 2001 From: tiphaine Date: Fri, 22 Nov 2024 15:25:40 +0100 Subject: [PATCH 2/2] Fix precision in FID metric (#192) --- torcheval/metrics/image/fid.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torcheval/metrics/image/fid.py b/torcheval/metrics/image/fid.py index 4bcf7b3..17859dc 100644 --- a/torcheval/metrics/image/fid.py +++ b/torcheval/metrics/image/fid.py @@ -95,13 +95,19 @@ def __init__( self.model.requires_grad_(False) # Initialize state variables used to compute FID - self._add_state("real_sum", torch.zeros(feature_dim, device=device)) self._add_state( - "real_cov_sum", torch.zeros((feature_dim, feature_dim), device=device) + "real_sum", torch.zeros(feature_dim, device=device, dtype=torch.float64) ) - self._add_state("fake_sum", torch.zeros(feature_dim, device=device)) self._add_state( - "fake_cov_sum", torch.zeros((feature_dim, feature_dim), device=device) + "real_cov_sum", + torch.zeros((feature_dim, feature_dim), device=device, dtype=torch.float64), + ) + self._add_state( + "fake_sum", torch.zeros(feature_dim, device=device, dtype=torch.float64) + ) + self._add_state( + "fake_cov_sum", + torch.zeros((feature_dim, feature_dim), device=device, dtype=torch.float64), ) self._add_state("num_real_images", torch.tensor(0, device=device).int()) self._add_state("num_fake_images", torch.tensor(0, device=device).int()) @@ -200,6 +206,7 @@ def compute(self: TFrechetInceptionDistance) -> Tensor: fid = gaussian_frechet_distance( real_mean.squeeze(), real_cov, fake_mean.squeeze(), fake_cov ) + fid = fid.to(torch.float32) return fid def _FID_parameter_check(