From 9b209eb053ebbf0f0f2bbaad39cc9316125a5407 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 15 May 2023 18:17:31 +0100 Subject: [PATCH] Run tests Signed-off-by: Walter Hugo Lopez Pinaya --- generative/losses/perceptual.py | 3 ++- tests/test_perceptual_loss.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/generative/losses/perceptual.py b/generative/losses/perceptual.py index c5ad80f0..8fffb1c8 100644 --- a/generative/losses/perceptual.py +++ b/generative/losses/perceptual.py @@ -319,7 +319,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights, we make sure that the input and target have 3 channels, and then do Z-Score normalization. - The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package). + The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar + approach to the lpips package). """ # If input has just 1 channel, repeat channel to have 3 channels if input.shape[1] == 1 and target.shape[1] == 1: diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py index c96bec71..fee375dd 100644 --- a/tests/test_perceptual_loss.py +++ b/tests/test_perceptual_loss.py @@ -14,9 +14,10 @@ import unittest import torch -from generative.losses import PerceptualLoss from parameterized import parameterized +from generative.losses import PerceptualLoss + TEST_CASES = [ [{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)], [