Skip to content
This repository has been archived by the owner on Feb 7, 2025. It is now read-only.

Commit

Permalink
Run tests
Browse files Browse the repository at this point in the history
Signed-off-by: Walter Hugo Lopez Pinaya <[email protected]>
  • Loading branch information
Warvito committed May 15, 2023
1 parent f969c24 commit 9b209eb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion generative/losses/perceptual.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights,
we make sure that the input and target have 3 channels, and then do Z-Score normalization.
The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package).
The outputs are normalised across the channels, and we obtain the mean from the spatial dimensions (similar
approach to the lpips package).
"""
# If input has just 1 channel, repeat channel to have 3 channels
if input.shape[1] == 1 and target.shape[1] == 1:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import unittest

import torch
from generative.losses import PerceptualLoss
from parameterized import parameterized

from generative.losses import PerceptualLoss

TEST_CASES = [
[{"spatial_dims": 2, "network_type": "squeeze"}, (2, 1, 64, 64), (2, 1, 64, 64)],
[
Expand Down

0 comments on commit 9b209eb

Please sign in to comment.