Commit f50a6b8

Incorporating code review changes

GaganCodes committed Nov 5, 2023
1 parent 11e75ff commit f50a6b8
Showing 5 changed files with 73 additions and 101 deletions.
1 change: 1 addition & 0 deletions src/pytorch-sphinx-theme
Submodule pytorch-sphinx-theme added at cf6f6c
55 changes: 24 additions & 31 deletions tests/metrics/functional/statistical/test_wasserstein.py
@@ -34,14 +34,7 @@ def _get_scipy_equivalent(
         y_weights_np = y_weights.numpy()
 
     if x.ndim == 1:
-        scipy_result = np.stack(
-            [
-                sp_wasserstein(sp_x, sp_y, sp_x_w, sp_y_w)
-                for sp_x, sp_y, sp_x_w, sp_y_w in zip(
-                    [x_np], [y_np], [x_weights_np], [y_weights_np]
-                )
-            ]
-        )
+        scipy_result = [sp_wasserstein(x_np, y_np, x_weights_np, y_weights_np)]
     else:
         scipy_result = np.stack(
             [
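Note: the simplification above works because zipping singleton lists yields exactly one tuple, so the stacked comprehension collapses to a single direct call. A minimal sketch of the equivalence (not part of the commit; the sample values are made up):

    import numpy as np
    from scipy.stats import wasserstein_distance as sp_wasserstein

    x_np, y_np = np.array([1.0, 2.0]), np.array([2.0, 3.0])
    # Old shape of the code: a comprehension over zip([x_np], [y_np]) runs exactly once.
    old = np.stack([sp_wasserstein(a, b) for a, b in zip([x_np], [y_np])])
    # New shape of the code: one direct call, wrapped in a list.
    new = [sp_wasserstein(x_np, y_np)]
    assert float(old[0]) == float(new[0])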
@@ -85,7 +78,7 @@ def _test_wasserstein1d_with_input(
                 rtol=1e-5,
             )
         else:
-            my_compute_result = torch.Tensor(
+            my_compute_result = torch.tensor(
                 [
                     wasserstein_1d(x, y, x_weights, y_weights)
                     for x, y, x_weights, y_weights in zip(x, y, x_weights, y_weights)
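Note: the torch.Tensor -> torch.tensor renames throughout this commit follow standard PyTorch guidance rather than anything specific to these files: torch.Tensor is the legacy constructor that always produces float32, while torch.tensor infers the dtype from its data. A quick illustration (not from the commit):

    import torch

    print(torch.Tensor([1, 2, 3]).dtype)   # torch.float32, always
    print(torch.tensor([1, 2, 3]).dtype)   # torch.int64, inferred from the ints
    print(torch.tensor([1.0, 2.0]).dtype)  # torch.float32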
@@ -113,29 +106,29 @@ def _test_wasserstein1d_with_input(
         )
 
     def test_wasserstein1d_distribution_values_only(self) -> None:
-        x = torch.Tensor([5, -5, -7, 9, -3])
-        y = torch.Tensor([9, -7, 5, -4, -2])
-        self._test_wasserstein1d_with_input(torch.Tensor([0.39999999999999997]), x, y)
+        x = torch.tensor([5, -5, -7, 9, -3])
+        y = torch.tensor([9, -7, 5, -4, -2])
+        self._test_wasserstein1d_with_input(torch.tensor([0.39999999999999997]), x, y)
 
     def test_wasserstein1d_distribution_and_weight_values(self) -> None:
-        x = torch.Tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
-        y = torch.Tensor([9, 6, -5, -11, 9, -4, -13, -19, -14, 4])
-        x_weights = torch.Tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
-        y_weights = torch.Tensor([2, 2, 1, 1, 2, 2, 1, 1, 1, 1])
+        x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
+        y = torch.tensor([9, 6, -5, -11, 9, -4, -13, -19, -14, 4])
+        x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
+        y_weights = torch.tensor([2, 2, 1, 1, 2, 2, 1, 1, 1, 1])
         self._test_wasserstein1d_with_input(
-            torch.Tensor([8.149068322981368]), x, y, x_weights, y_weights
+            torch.tensor([8.149068322981368]), x, y, x_weights, y_weights
         )
 
     def test_wasserstein1d_different_distribution_shape(self) -> None:
-        x = torch.Tensor([5, -5, -7, 9, -3])
-        y = torch.Tensor([9, -7, 5, -4, -2, 4, -1])
-        self._test_wasserstein1d_with_input(torch.Tensor([1.4571428571428569]), x, y)
+        x = torch.tensor([5, -5, -7, 9, -3])
+        y = torch.tensor([9, -7, 5, -4, -2, 4, -1])
+        self._test_wasserstein1d_with_input(torch.tensor([1.4571428571428569]), x, y)
 
     def test_wasserstein1d_identical_distributions(self) -> None:
-        x = torch.Tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
-        x_weights = torch.Tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
+        x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
+        x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
         self._test_wasserstein1d_with_input(
-            torch.Tensor([0.0]), x, x, x_weights, x_weights
+            torch.tensor([0.0]), x, x, x_weights, x_weights
         )
 
     def test_wasserstein1d_randomized_data_getter(self) -> None:
@@ -193,23 +186,23 @@ def test_wasserstein1d_invalid_input(self) -> None:
             wasserstein_1d(torch.rand(4), torch.rand(7, 3))
 
         with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."):
-            wasserstein_1d(torch.rand(4), torch.Tensor([]))
+            wasserstein_1d(torch.rand(4), torch.tensor([]))
 
         with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."):
-            wasserstein_1d(torch.Tensor([]), torch.rand(5))
+            wasserstein_1d(torch.tensor([]), torch.rand(5))
 
         with self.assertRaisesRegex(
             ValueError, "Weight tensor sum must be positive-finite."
         ):
             wasserstein_1d(
-                torch.rand(4), torch.rand(4), torch.Tensor([]), torch.rand(4)
+                torch.rand(4), torch.rand(4), torch.tensor([torch.inf]), torch.rand(4)
             )
 
         with self.assertRaisesRegex(
             ValueError, "Weight tensor sum must be positive-finite."
         ):
             wasserstein_1d(
-                torch.rand(4), torch.rand(4), torch.rand(4), torch.Tensor([])
+                torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([torch.inf])
             )
 
         with self.assertRaisesRegex(
@@ -230,19 +223,19 @@ def test_wasserstein1d_invalid_input(self) -> None:
 
         with self.assertRaisesRegex(ValueError, "All weights must be non-negative."):
             wasserstein_1d(
-                torch.rand(4), torch.rand(4), torch.Tensor([1, -1, 2, 3]), torch.rand(4)
+                torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3]), torch.rand(4)
             )
 
         with self.assertRaisesRegex(ValueError, "All weights must be non-negative."):
             wasserstein_1d(
-                torch.rand(4), torch.rand(4), torch.rand(4), torch.Tensor([1, -1, 2, 3])
+                torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3])
             )
 
         with self.assertRaisesRegex(ValueError, "All weights must be non-negative."):
             wasserstein_1d(
                 torch.rand(4),
                 torch.rand(4),
-                torch.Tensor([-1.0, -2.0, 0.0, 1.0]),
+                torch.tensor([-1.0, -2.0, 0.0, 1.0]),
                 torch.rand(4),
             )
 
@@ -251,5 +244,5 @@ def test_wasserstein1d_invalid_input(self) -> None:
                 torch.rand(4),
                 torch.rand(4),
                 torch.rand(4),
-                torch.Tensor([-1.5, -1.0, 0.5, 0.75]),
+                torch.tensor([-1.5, -1.0, 0.5, 0.75]),
             )
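Note: a plausible reading of the switch from torch.Tensor([]) to torch.tensor([torch.inf]) in the two positive-finite tests above: with validation consolidated in this commit (see torcheval/metrics/functional/statistical/wasserstein.py below), an empty weight tensor is rejected by the emptiness check before its sum is ever inspected, so an infinite weight is needed to actually exercise the "Weight tensor sum must be positive-finite." branch. Sketch (not from the commit):

    import torch

    w = torch.tensor([torch.inf])
    print(w.nelement() == 0)                   # False: passes the emptiness check
    print(bool(torch.all(w > 0)))              # True: passes non-negativity
    print(bool(0 < torch.sum(w) < torch.inf))  # False: the sum is inf, not finite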
59 changes: 27 additions & 32 deletions tests/metrics/statistical/test_wasserstein.py
@@ -6,8 +6,6 @@
 
 from typing import Optional
 
-import numpy as np
-
 import torch
 
 from scipy.stats import wasserstein_distance as sp_wasserstein
@@ -39,14 +37,7 @@ def _get_scipy_equivalent(
     if y_weights is not None:
         y_weights_np = y_weights.numpy().flatten()
 
-    scipy_result = np.stack(
-        [
-            sp_wasserstein(sp_x, sp_y, sp_x_w, sp_y_w)
-            for sp_x, sp_y, sp_x_w, sp_y_w in zip(
-                [x_np], [y_np], [x_weights_np], [y_weights_np]
-            )
-        ]
-    )
+    scipy_result = [sp_wasserstein(x_np, y_np, x_weights_np, y_weights_np)]
 
     return torch.tensor(scipy_result, device=device).to(torch.float)
 
@@ -84,11 +75,11 @@ def _check_against_scipy(
     def test_wasserstein1d_valid_input(self) -> None:
         # Checking with distribution values only
         metric = Wasserstein1D()
-        x = torch.Tensor([5, -5, -7, 9, -3])
-        y = torch.Tensor([9, -7, 5, -4, -2])
+        x = torch.tensor([5, -5, -7, 9, -3])
+        y = torch.tensor([9, -7, 5, -4, -2])
         metric.update(x, y)
         result = metric.compute()
-        expected = torch.Tensor([0.39999999999999997])
+        expected = torch.tensor([0.39999999999999997])
         torch.testing.assert_close(
             result,
             expected,
@@ -99,13 +90,13 @@ def test_wasserstein1d_valid_input(self) -> None:
 
         # Checking with distribution and weight values
         metric = Wasserstein1D()
-        x = torch.Tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
-        y = torch.Tensor([9, 6, -5, -11, 9, -4, -13, -19, -14, 4])
-        x_weights = torch.Tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
-        y_weights = torch.Tensor([2, 2, 1, 1, 2, 2, 1, 1, 1, 1])
+        x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
+        y = torch.tensor([9, 6, -5, -11, 9, -4, -13, -19, -14, 4])
+        x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
+        y_weights = torch.tensor([2, 2, 1, 1, 2, 2, 1, 1, 1, 1])
         metric.update(x, y, x_weights, y_weights)
         result = metric.compute()
-        expected = torch.Tensor([8.149068322981368])
+        expected = torch.tensor([8.149068322981368])
         torch.testing.assert_close(
             result,
             expected,
@@ -116,11 +107,11 @@ def test_wasserstein1d_valid_input(self) -> None:
 
         # Checking with different distribution shapes
         metric = Wasserstein1D()
-        x = torch.Tensor([5, -5, -7, 9, -3])
-        y = torch.Tensor([9, -7, 5, -4, -2, 4, -1])
+        x = torch.tensor([5, -5, -7, 9, -3])
+        y = torch.tensor([9, -7, 5, -4, -2, 4, -1])
         metric.update(x, y)
         result = metric.compute()
-        expected = torch.Tensor([1.4571428571428569])
+        expected = torch.tensor([1.4571428571428569])
         torch.testing.assert_close(
             result,
             expected,
@@ -131,11 +122,11 @@ def test_wasserstein1d_valid_input(self) -> None:
 
         # Checking with identical distributions
         metric = Wasserstein1D()
-        x = torch.Tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
-        x_weights = torch.Tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
+        x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18])
+        x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3])
         metric.update(x, x, x_weights, x_weights)
         result = metric.compute()
-        expected = torch.Tensor([0.0])
+        expected = torch.tensor([0.0])
         torch.testing.assert_close(
             result,
             expected,
@@ -165,20 +156,24 @@ def test_wasserstein1d_invalid_input(self) -> None:
             metric.update(torch.rand(4), torch.rand(7, 3))
 
         with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."):
-            metric.update(torch.rand(4), torch.Tensor([]))
+            metric.update(torch.rand(4), torch.tensor([]))
 
         with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."):
-            metric.update(torch.Tensor([]), torch.rand(5))
+            metric.update(torch.tensor([]), torch.rand(5))
 
         with self.assertRaisesRegex(
             ValueError, "Weight tensor sum must be positive-finite."
         ):
-            metric.update(torch.rand(4), torch.rand(4), torch.Tensor([]), torch.rand(4))
+            metric.update(
+                torch.rand(4), torch.rand(4), torch.tensor([torch.inf]), torch.rand(4)
+            )
 
         with self.assertRaisesRegex(
             ValueError, "Weight tensor sum must be positive-finite."
         ):
-            metric.update(torch.rand(4), torch.rand(4), torch.rand(4), torch.Tensor([]))
+            metric.update(
+                torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([torch.inf])
+            )
 
         with self.assertRaisesRegex(
             ValueError,
@@ -198,19 +193,19 @@ def test_wasserstein1d_invalid_input(self) -> None:
 
         with self.assertRaisesRegex(ValueError, "All weights must be non-negative."):
             metric.update(
-                torch.rand(4), torch.rand(4), torch.Tensor([1, -1, 2, 3]), torch.rand(4)
+                torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3]), torch.rand(4)
             )
 
         with self.assertRaisesRegex(ValueError, "All weights must be non-negative."):
             metric.update(
-                torch.rand(4), torch.rand(4), torch.rand(4), torch.Tensor([1, -1, 2, 3])
+                torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3])
             )
 
         with self.assertRaisesRegex(ValueError, "All weights must be non-negative."):
             metric.update(
                 torch.rand(4),
                 torch.rand(4),
-                torch.Tensor([-1.0, -2.0, 0.0, 1.0]),
+                torch.tensor([-1.0, -2.0, 0.0, 1.0]),
                 torch.rand(4),
             )
 
@@ -219,5 +214,5 @@ def test_wasserstein1d_invalid_input(self) -> None:
                 torch.rand(4),
                 torch.rand(4),
                 torch.rand(4),
-                torch.Tensor([-1.5, -1.0, 0.5, 0.75]),
+                torch.tensor([-1.5, -1.0, 0.5, 0.75]),
             )
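Note: for readers unfamiliar with the class-based API exercised above, a minimal usage sketch mirroring the first valid-input test (the import path is assumed from this repository's layout and may differ in a released package):

    import torch
    from torcheval.metrics.statistical.wasserstein import Wasserstein1D

    metric = Wasserstein1D()
    metric.update(
        torch.tensor([5, -5, -7, 9, -3]),
        torch.tensor([9, -7, 5, -4, -2]),
    )
    print(metric.compute())  # approximately 0.4, per the test's expected value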
51 changes: 21 additions & 30 deletions torcheval/metrics/functional/statistical/wasserstein.py
@@ -84,59 +84,50 @@ def wasserstein_1d(
         torch.tensor([0.75])
     """
-    _wasserstein_param_check(x, y, x_weights, y_weights)
     _wasserstein_update_input_check(x, y, x_weights, y_weights)
     return _wasserstein_compute(x, y, x_weights, y_weights)
 
 
-def _wasserstein_param_check(
+def _wasserstein_update_input_check(
     x: torch.Tensor,
     y: torch.Tensor,
     x_weights: Optional[torch.Tensor] = None,
     y_weights: Optional[torch.Tensor] = None,
 ) -> None:
+    if x.nelement() == 0 or y.nelement() == 0:
+        raise ValueError("Distribution cannot be empty.")
+    if x.dim() > 1 or y.dim() > 1:
+        raise ValueError("Distribution has to be one dimensional.")
+    if not x.device == y.device:
+        raise ValueError("Expected all the tensors to be on the same device.")
     if x_weights is not None:
+        if x_weights.nelement() == 0:
+            raise ValueError("Weights cannot be empty.")
         if not torch.all(x_weights > 0):
             raise ValueError("All weights must be non-negative.")
         if not (0 < torch.sum(x_weights) < torch.inf):
             raise ValueError("Weight tensor sum must be positive-finite.")
         if not x_weights.device == x.device:
             raise ValueError("Expected values and weights to be on the same device.")
+        if x_weights.shape != x.shape:
+            raise ValueError(
+                "Distribution values and weight tensors must be of the same shape, "
+                f"got shapes {x.shape} and {x_weights.shape}."
+            )
     if y_weights is not None:
+        if y_weights.nelement() == 0:
+            raise ValueError("Weights cannot be empty.")
         if not torch.all(y_weights > 0):
             raise ValueError("All weights must be non-negative.")
         if not (0 < torch.sum(y_weights) < torch.inf):
             raise ValueError("Weight tensor sum must be positive-finite.")
         if not y_weights.device == y.device:
             raise ValueError("Expected values and weights to be on the same device.")
-    if not x.device == y.device:
-        raise ValueError("Expected all the tensors to be on the same device.")
-
-
-def _wasserstein_update_input_check(
-    x: torch.Tensor,
-    y: torch.Tensor,
-    x_weights: Optional[torch.Tensor] = None,
-    y_weights: Optional[torch.Tensor] = None,
-) -> None:
-    if x.nelement() == 0 or y.nelement() == 0:
-        raise ValueError("Distribution cannot be empty.")
-    if x.dim() > 1 or y.dim() > 1:
-        raise ValueError("Distribution has to be one dimensional.")
-    if x_weights is not None and x_weights.nelement() == 0:
-        raise ValueError("Weights cannot be empty.")
-    if x_weights is not None and x_weights.shape != x.shape:
-        raise ValueError(
-            "Distribution values and weight tensors must be of the same shape, "
-            f"got shapes {x.shape} and {x_weights.shape}."
-        )
-    if y_weights is not None and y_weights.nelement() == 0:
-        raise ValueError("Weights cannot be empty.")
-    if y_weights is not None and y_weights.shape != y.shape:
-        raise ValueError(
-            "Distribution values and weight tensors must be of the same shape, "
-            f"got shapes {y.shape} and {y_weights.shape}."
-        )
+        if y_weights.shape != y.shape:
+            raise ValueError(
+                "Distribution values and weight tensors must be of the same shape, "
+                f"got shapes {y.shape} and {y_weights.shape}."
+            )


def _wasserstein_compute(
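Note: after this restructuring, all argument validation for the functional API lives in the single _wasserstein_update_input_check helper: emptiness, dimensionality, and device agreement are checked first, then, per weight tensor, non-negativity, a positive-finite sum, device, and shape. A quick demonstration against the functional entry point (import path assumed from the test files):

    import torch
    from torcheval.metrics.functional.statistical.wasserstein import wasserstein_1d

    try:
        wasserstein_1d(torch.rand(4), torch.tensor([]))
    except ValueError as err:
        print(err)  # Distribution cannot be empty.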
8 changes: 0 additions & 8 deletions torcheval/metrics/statistical/wasserstein.py
@@ -10,7 +10,6 @@
 
 from torcheval.metrics.functional.statistical.wasserstein import (
     _wasserstein_compute,
-    _wasserstein_param_check,
     _wasserstein_update_input_check,
 )
 from torcheval.metrics.metric import Metric
@@ -73,13 +72,6 @@ def update(
             new_weights_dist_1, new_weights_dist_2 (Tensor): Optional tensor weights for each value.
                 If unspecified, each value is assigned the same value (1.0).
         """
-        _wasserstein_param_check(
-            new_samples_dist_1,
-            new_samples_dist_2,
-            new_weights_dist_1,
-            new_weights_dist_2,
-        )
-
         _wasserstein_update_input_check(
             new_samples_dist_1,
             new_samples_dist_2,
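Note: with _wasserstein_param_check removed, Metric.update() runs the same single validation pass as the functional form, so both APIs raise identical errors. Sketch (import path assumed, as above):

    import torch
    from torcheval.metrics.statistical.wasserstein import Wasserstein1D

    metric = Wasserstein1D()
    try:
        metric.update(torch.rand(4), torch.tensor([]))
    except ValueError as err:
        print(err)  # Distribution cannot be empty.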
