From dfa6ca423db5d4906835715970a11353233a4df8 Mon Sep 17 00:00:00 2001
From: Gagandeep
Date: Tue, 21 Nov 2023 09:52:09 -0800
Subject: [PATCH] Wasserstein (#184)

Summary:
Implemented the Wasserstein distance function and metric class for the 1-dimensional case.

Pull Request resolved: https://github.com/pytorch/torcheval/pull/184

Test Plan: Every feature of the code is tested via the unittest framework in
"tests/metrics/statistical/test_wasserstein.py" and
"tests/metrics/functional/statistical/test_wasserstein.py".

Fixes https://github.com/pytorch/torcheval/issues/137

Reviewed By: JKSenthil

Differential Revision: D50825881

Pulled By: bobakfb

fbshipit-source-id: 99936cd1773f6436c2f70d21a6de866dbca9d1e7
---
 src/pytorch-sphinx-theme                      |   1 +
 .../functional/statistical/__init__.py        |   5 +
 .../statistical/test_wasserstein.py           | 248 ++++++++++++++++++
 tests/metrics/statistical/__init__.py         |   5 +
 tests/metrics/statistical/test_wasserstein.py | 218 +++++++++++++++
 .../functional/statistical/__init__.py        |  12 +
 .../functional/statistical/wasserstein.py     | 176 +++++++++++++
 torcheval/metrics/statistical/__init__.py     |  12 +
 torcheval/metrics/statistical/wasserstein.py  | 145 ++++++++++
 torcheval/utils/random_data.py                |  37 +++
 10 files changed, 859 insertions(+)
 create mode 160000 src/pytorch-sphinx-theme
 create mode 100644 tests/metrics/functional/statistical/__init__.py
 create mode 100644 tests/metrics/functional/statistical/test_wasserstein.py
 create mode 100644 tests/metrics/statistical/__init__.py
 create mode 100644 tests/metrics/statistical/test_wasserstein.py
 create mode 100644 torcheval/metrics/functional/statistical/__init__.py
 create mode 100644 torcheval/metrics/functional/statistical/wasserstein.py
 create mode 100644 torcheval/metrics/statistical/__init__.py
 create mode 100644 torcheval/metrics/statistical/wasserstein.py

diff --git a/src/pytorch-sphinx-theme b/src/pytorch-sphinx-theme
new file mode 160000
index 00000000..cf6f6cd7
--- /dev/null
+++ b/src/pytorch-sphinx-theme
@@ -0,0 +1 @@
+Subproject commit cf6f6cd79dc0fc064b3d1d810821ed9375c2efeb
diff --git a/tests/metrics/functional/statistical/__init__.py b/tests/metrics/functional/statistical/__init__.py
new file mode 100644
index 00000000..2e41cd71
--- /dev/null
+++ b/tests/metrics/functional/statistical/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/metrics/functional/statistical/test_wasserstein.py b/tests/metrics/functional/statistical/test_wasserstein.py
new file mode 100644
index 00000000..b8de40ef
--- /dev/null
+++ b/tests/metrics/functional/statistical/test_wasserstein.py
@@ -0,0 +1,248 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
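+
+# The tests below validate wasserstein_1d against
+# scipy.stats.wasserstein_distance as the reference implementation, on
+# hand-picked and randomized inputs, on CPU and (when available) CUDA, and
+# check that invalid distributions and weights are rejected.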
+ +import unittest +from typing import Optional + +import numpy as np + +import torch + +from scipy.stats import wasserstein_distance as sp_wasserstein +from torcheval.metrics.functional.statistical.wasserstein import wasserstein_1d +from torcheval.utils import random_data as rd + + +class TestWasserstein1D(unittest.TestCase): + def _get_scipy_equivalent( + self, + x: torch.Tensor, + y: torch.Tensor, + x_weights: Optional[torch.Tensor] = None, + y_weights: Optional[torch.Tensor] = None, + device: str = "cpu", + ) -> torch.Tensor: + # Convert inputs to scipy style inputs + x_np = x.numpy() + y_np = y.numpy() + if x_weights is not None: + x_weights_np = x_weights.numpy() + if y_weights is not None: + y_weights_np = y_weights.numpy() + + if x.ndim == 1: + scipy_result = [sp_wasserstein(x_np, y_np, x_weights_np, y_weights_np)] + else: + scipy_result = np.stack( + [ + sp_wasserstein(sp_x, sp_y, sp_x_w, sp_y_w) + for sp_x, sp_y, sp_x_w, sp_y_w in zip( + x_np, y_np, x_weights_np, y_weights_np + ) + ] + ) + + return torch.tensor(scipy_result, device=device).to(torch.float) + + def _test_wasserstein1d_with_input( + self, + compute_result: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + x_weights: Optional[torch.Tensor] = None, + y_weights: Optional[torch.Tensor] = None, + ) -> None: + if x.ndim == 1: + my_compute_result = wasserstein_1d(x, y, x_weights, y_weights) + torch.testing.assert_close( + my_compute_result, + compute_result, + equal_nan=True, + atol=1e-8, + rtol=1e-5, + ) + + # Also test for cuda + if torch.cuda.is_available(): + compute_result_cuda = tuple(c.to("cuda") for c in compute_result) + my_compute_result_cuda = tuple(c.to("cuda") for c in my_compute_result) + + torch.testing.assert_close( + my_compute_result_cuda, + compute_result_cuda, + equal_nan=True, + atol=1e-8, + rtol=1e-5, + ) + else: + my_compute_result = torch.tensor( + [ + wasserstein_1d(x, y, x_weights, y_weights) + for x, y, x_weights, y_weights in zip(x, y, x_weights, y_weights) + ] + ).to(x.device) + torch.testing.assert_close( + my_compute_result, + compute_result, + equal_nan=True, + atol=1e-8, + rtol=1e-5, + ) + + # Also test for cuda + if torch.cuda.is_available(): + compute_result_cuda = tuple(c.to("cuda") for c in compute_result) + my_compute_result_cuda = tuple(c.to("cuda") for c in my_compute_result) + + torch.testing.assert_close( + my_compute_result_cuda, + compute_result_cuda, + equal_nan=True, + atol=1e-8, + rtol=1e-5, + ) + + def test_wasserstein1d_distribution_values_only(self) -> None: + x = torch.tensor([5, -5, -7, 9, -3]) + y = torch.tensor([9, -7, 5, -4, -2]) + self._test_wasserstein1d_with_input(torch.tensor([0.39999999999999997]), x, y) + + def test_wasserstein1d_distribution_and_weight_values(self) -> None: + x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18]) + y = torch.tensor([9, 6, -5, -11, 9, -4, -13, -19, -14, 4]) + x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3]) + y_weights = torch.tensor([2, 2, 1, 1, 2, 2, 1, 1, 1, 1]) + self._test_wasserstein1d_with_input( + torch.tensor([8.149068322981368]), x, y, x_weights, y_weights + ) + + def test_wasserstein1d_different_distribution_shape(self) -> None: + x = torch.tensor([5, -5, -7, 9, -3]) + y = torch.tensor([9, -7, 5, -4, -2, 4, -1]) + self._test_wasserstein1d_with_input(torch.tensor([1.4571428571428569]), x, y) + + def test_wasserstein1d_identical_distributions(self) -> None: + x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18]) + x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3]) + 
self._test_wasserstein1d_with_input( + torch.tensor([0.0]), x, x, x_weights, x_weights + ) + + def test_wasserstein1d_randomized_data_getter(self) -> None: + num_updates = 1 + batch_size = 32 + device = "cuda" if torch.cuda.is_available() else "cpu" + + for _ in range(10): + x, y, x_weights, y_weights = rd.get_rand_data_wasserstein1d( + num_updates, batch_size, device + ) + + compute_result = self._get_scipy_equivalent( + x.to("cpu"), + y.to("cpu"), + x_weights.to("cpu"), + y_weights.to("cpu"), + device, + ) + + self._test_wasserstein1d_with_input( + compute_result, x, y, x_weights, y_weights + ) + + num_updates = 8 + batch_size = 32 + device = "cuda" if torch.cuda.is_available() else "cpu" + + for _ in range(10): + x, y, x_weights, y_weights = rd.get_rand_data_wasserstein1d( + num_updates, batch_size, device + ) + + compute_result = self._get_scipy_equivalent( + x.to("cpu"), + y.to("cpu"), + x_weights.to("cpu"), + y_weights.to("cpu"), + device, + ) + + self._test_wasserstein1d_with_input( + compute_result, x, y, x_weights, y_weights + ) + + def test_wasserstein1d_invalid_input(self) -> None: + with self.assertRaisesRegex( + ValueError, "Distribution has to be one dimensional." + ): + wasserstein_1d(torch.rand(4, 2), torch.rand(7)) + + with self.assertRaisesRegex( + ValueError, "Distribution has to be one dimensional." + ): + wasserstein_1d(torch.rand(4), torch.rand(7, 3)) + + with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."): + wasserstein_1d(torch.rand(4), torch.tensor([])) + + with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."): + wasserstein_1d(torch.tensor([]), torch.rand(5)) + + with self.assertRaisesRegex( + ValueError, "Weight tensor sum must be positive-finite." + ): + wasserstein_1d( + torch.rand(4), torch.rand(4), torch.tensor([torch.inf]), torch.rand(4) + ) + + with self.assertRaisesRegex( + ValueError, "Weight tensor sum must be positive-finite." 
+ ): + wasserstein_1d( + torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([torch.inf]) + ) + + with self.assertRaisesRegex( + ValueError, + "Distribution values and weight tensors must be of the same shape, " + "got shapes " + r"torch.Size\(\[4\]\) and torch.Size\(\[7\]\).", + ): + wasserstein_1d(torch.rand(4), torch.rand(4), torch.rand(7), torch.rand(4)) + + with self.assertRaisesRegex( + ValueError, + "Distribution values and weight tensors must be of the same shape, " + "got shapes " + r"torch.Size\(\[6\]\) and torch.Size\(\[10\]\).", + ): + wasserstein_1d(torch.rand(6), torch.rand(6), torch.rand(6), torch.rand(10)) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + wasserstein_1d( + torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3]), torch.rand(4) + ) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + wasserstein_1d( + torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3]) + ) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + wasserstein_1d( + torch.rand(4), + torch.rand(4), + torch.tensor([-1.0, -2.0, 0.0, 1.0]), + torch.rand(4), + ) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + wasserstein_1d( + torch.rand(4), + torch.rand(4), + torch.rand(4), + torch.tensor([-1.5, -1.0, 0.5, 0.75]), + ) diff --git a/tests/metrics/statistical/__init__.py b/tests/metrics/statistical/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/tests/metrics/statistical/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/metrics/statistical/test_wasserstein.py b/tests/metrics/statistical/test_wasserstein.py new file mode 100644 index 00000000..9a02d8f6 --- /dev/null +++ b/tests/metrics/statistical/test_wasserstein.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
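+
+# Wasserstein1D is exercised below through torcheval's MetricClassTester
+# harness, which replays the same updates in-process and across
+# NUM_PROCESSES worker processes, so compute() and merge_state() are both
+# checked against the scipy reference.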
+ +from typing import Optional + +import torch + +from scipy.stats import wasserstein_distance as sp_wasserstein +from torcheval.metrics.statistical.wasserstein import Wasserstein1D +from torcheval.utils.random_data import get_rand_data_wasserstein1d +from torcheval.utils.test_utils.metric_class_tester import ( + BATCH_SIZE, + MetricClassTester, + NUM_PROCESSES, +) + +NUM_TOTAL_UPDATES = 8 + + +class TestWasserstein1D(MetricClassTester): + def _get_scipy_equivalent( + self, + x: torch.Tensor, + y: torch.Tensor, + x_weights: Optional[torch.Tensor] = None, + y_weights: Optional[torch.Tensor] = None, + device: str = "cpu", + ) -> torch.Tensor: + # Convert inputs to scipy style inputs + x_np = x.numpy().flatten() + y_np = y.numpy().flatten() + if x_weights is not None: + x_weights_np = x_weights.numpy().flatten() + if y_weights is not None: + y_weights_np = y_weights.numpy().flatten() + + scipy_result = [sp_wasserstein(x_np, y_np, x_weights_np, y_weights_np)] + + return torch.tensor(scipy_result, device=device).to(torch.float) + + def _check_against_scipy( + self, + x: torch.Tensor, + y: torch.Tensor, + x_weights: Optional[torch.Tensor] = None, + y_weights: Optional[torch.Tensor] = None, + device: str = "cpu", + ) -> None: + compute_result = self._get_scipy_equivalent( + x.to("cpu"), y.to("cpu"), x_weights.to("cpu"), y_weights.to("cpu"), device + ) + + self.run_class_implementation_tests( + metric=Wasserstein1D(device=device), + state_names={ + "dist_1_samples", + "dist_2_samples", + "dist_1_weights", + "dist_2_weights", + }, + update_kwargs={ + "new_samples_dist_1": x, + "new_samples_dist_2": y, + "new_weights_dist_1": x_weights, + "new_weights_dist_2": y_weights, + }, + compute_result=compute_result, + num_total_updates=NUM_TOTAL_UPDATES, + num_processes=NUM_PROCESSES, + ) + + def test_wasserstein1d_valid_input(self) -> None: + # Checking with distribution values only + metric = Wasserstein1D() + x = torch.tensor([5, -5, -7, 9, -3]) + y = torch.tensor([9, -7, 5, -4, -2]) + metric.update(x, y) + result = metric.compute() + expected = torch.tensor([0.39999999999999997]) + torch.testing.assert_close( + result, + expected, + equal_nan=True, + atol=1e-4, + rtol=1e-3, + ) + + # Checking with distribution and weight values + metric = Wasserstein1D() + x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18]) + y = torch.tensor([9, 6, -5, -11, 9, -4, -13, -19, -14, 4]) + x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3]) + y_weights = torch.tensor([2, 2, 1, 1, 2, 2, 1, 1, 1, 1]) + metric.update(x, y, x_weights, y_weights) + result = metric.compute() + expected = torch.tensor([8.149068322981368]) + torch.testing.assert_close( + result, + expected, + equal_nan=True, + atol=1e-4, + rtol=1e-3, + ) + + # Checking with different distribution shapes + metric = Wasserstein1D() + x = torch.tensor([5, -5, -7, 9, -3]) + y = torch.tensor([9, -7, 5, -4, -2, 4, -1]) + metric.update(x, y) + result = metric.compute() + expected = torch.tensor([1.4571428571428569]) + torch.testing.assert_close( + result, + expected, + equal_nan=True, + atol=1e-4, + rtol=1e-3, + ) + + # Checking with identical distributions + metric = Wasserstein1D() + x = torch.tensor([-13, -9, -19, 11, -18, -20, 8, 2, -8, -18]) + x_weights = torch.tensor([3, 3, 1, 2, 2, 3, 2, 2, 2, 3]) + metric.update(x, x, x_weights, x_weights) + result = metric.compute() + expected = torch.tensor([0.0]) + torch.testing.assert_close( + result, + expected, + equal_nan=True, + atol=1e-4, + rtol=1e-3, + ) + + def test_wasserstein1d_random_data_getter(self) 
-> None: + for _ in range(10): + x, y, x_weights, y_weights = get_rand_data_wasserstein1d( + num_updates=NUM_TOTAL_UPDATES, batch_size=BATCH_SIZE + ) + + self._check_against_scipy(x, y, x_weights, y_weights) + + def test_wasserstein1d_invalid_input(self) -> None: + metric = Wasserstein1D() + with self.assertRaisesRegex( + ValueError, "Distribution has to be one dimensional." + ): + metric.update(torch.rand(4, 2), torch.rand(7)) + + with self.assertRaisesRegex( + ValueError, "Distribution has to be one dimensional." + ): + metric.update(torch.rand(4), torch.rand(7, 3)) + + with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."): + metric.update(torch.rand(4), torch.tensor([])) + + with self.assertRaisesRegex(ValueError, "Distribution cannot be empty."): + metric.update(torch.tensor([]), torch.rand(5)) + + with self.assertRaisesRegex( + ValueError, "Weight tensor sum must be positive-finite." + ): + metric.update( + torch.rand(4), torch.rand(4), torch.tensor([torch.inf]), torch.rand(4) + ) + + with self.assertRaisesRegex( + ValueError, "Weight tensor sum must be positive-finite." + ): + metric.update( + torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([torch.inf]) + ) + + with self.assertRaisesRegex( + ValueError, + "Distribution values and weight tensors must be of the same shape, " + "got shapes " + r"torch.Size\(\[4\]\) and torch.Size\(\[7\]\).", + ): + metric.update(torch.rand(4), torch.rand(4), torch.rand(7), torch.rand(4)) + + with self.assertRaisesRegex( + ValueError, + "Distribution values and weight tensors must be of the same shape, " + "got shapes " + r"torch.Size\(\[6\]\) and torch.Size\(\[10\]\).", + ): + metric.update(torch.rand(6), torch.rand(6), torch.rand(6), torch.rand(10)) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + metric.update( + torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3]), torch.rand(4) + ) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + metric.update( + torch.rand(4), torch.rand(4), torch.rand(4), torch.tensor([1, -1, 2, 3]) + ) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + metric.update( + torch.rand(4), + torch.rand(4), + torch.tensor([-1.0, -2.0, 0.0, 1.0]), + torch.rand(4), + ) + + with self.assertRaisesRegex(ValueError, "All weights must be non-negative."): + metric.update( + torch.rand(4), + torch.rand(4), + torch.rand(4), + torch.tensor([-1.5, -1.0, 0.5, 0.75]), + ) diff --git a/torcheval/metrics/functional/statistical/__init__.py b/torcheval/metrics/functional/statistical/__init__.py new file mode 100644 index 00000000..4af25344 --- /dev/null +++ b/torcheval/metrics/functional/statistical/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-ignore-all-errors[16]: Undefined attribute of metric states. + +from torcheval.metrics.functional.statistical.wasserstein import wasserstein_1d + +__all__ = ["wasserstein_1d"] +__doc_name__ = "Statistical Metrics" diff --git a/torcheval/metrics/functional/statistical/wasserstein.py b/torcheval/metrics/functional/statistical/wasserstein.py new file mode 100644 index 00000000..b5c237b2 --- /dev/null +++ b/torcheval/metrics/functional/statistical/wasserstein.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+
+
+@torch.inference_mode()
+def wasserstein_1d(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_weights: Optional[torch.Tensor] = None,
+    y_weights: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    r"""
+    The Wasserstein distance, also called the Earth Mover's Distance, is a
+    measure of the similarity between two distributions.
+
+    The Wasserstein distance between two distributions is intuitively the
+    minimum weight of soil (times distance moved) that would need to be moved
+    if the two distributions were represented by two piles of soil.
+
+    Args
+    ----------
+    x, y (Tensor): 1D Tensor values observed in the distribution.
+    x_weights, y_weights (Tensor): Optional tensor weights for each value.
+        If unspecified, each value is assigned the same weight.
+        `x_weights` (resp. `y_weights`) must have the same length as
+        `x` (resp. `y`). If the weight sum differs from 1, it
+        must still be positive and finite so that the weights can be
+        normalized to sum to 1.
+
+    Returns
+    -------
+    distance : Tensor value
+        The computed distance between the distributions.
+
+    Notes
+    -----
+    The first Wasserstein distance between the distributions :math:`x` and
+    :math:`y` is:
+
+    .. math::
+
+        W_1 (x, y) = \inf_{\pi \in \Gamma (x, y)} \int_{\mathbb{R} \times
+        \mathbb{R}} |p-q| \mathrm{d} \pi (p, q)
+
+    where :math:`\Gamma (x, y)` is the set of (probability) distributions on
+    :math:`\mathbb{R} \times \mathbb{R}` whose marginals are :math:`x` and
+    :math:`y` on the first and second factors respectively.
+
+    If :math:`X` and :math:`Y` are the respective CDFs of :math:`x` and
+    :math:`y`, this distance also equals:
+
+    .. math::
+
+        W_1(x, y) = \int_{-\infty}^{+\infty} |X(t)-Y(t)| \, \mathrm{d}t
+
+    See [2]_ for a proof of the equivalence of both definitions.
+
+    The input distributions can be empirical, therefore coming from samples
+    whose values are effectively inputs of the function, or they can be seen as
+    generalized functions, in which case they are weighted sums of Dirac delta
+    functions located at the specified values.
+
+    References
+    ----------
+    .. [1] "Wasserstein metric", https://en.wikipedia.org/wiki/Wasserstein_metric
+    .. [2] Ramdas, Garcia, Cuturi "On Wasserstein Two Sample Testing and Related
+        Families of Nonparametric Tests" (2015). :arXiv:`1509.02237`.
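+
+    For the empirical distributions handled here, the integral above reduces
+    to a finite sum: with the pooled, sorted sample values
+    :math:`t_1 \le t_2 \le \dots \le t_n`,
+
+    .. math::
+
+        W_1(x, y) = \sum_{i=1}^{n-1} |X(t_i) - Y(t_i)| \, (t_{i+1} - t_i)
+
+    which is exactly what ``_wasserstein_compute`` below evaluates from the
+    sorted values and (cumulative) weights.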
+
+    Examples
+    --------
+    >>> from torcheval.metrics.functional import wasserstein_1d
+    >>> wasserstein_1d(torch.tensor([0,1,2]), torch.tensor([0,1,1]))
+    tensor([0.3333])
+    >>> wasserstein_1d(torch.tensor([0,1,2]), torch.tensor([0,1,1]), torch.tensor([1,2,0]), torch.tensor([1,1,1]))
+    tensor([0.])
+    >>> wasserstein_1d(torch.tensor([0,1,2,2]), torch.tensor([0,1]))
+    tensor([0.7500])
+
+    """
+    _wasserstein_update_input_check(x, y, x_weights, y_weights)
+    return _wasserstein_compute(x, y, x_weights, y_weights)
+
+
+def _wasserstein_update_input_check(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_weights: Optional[torch.Tensor] = None,
+    y_weights: Optional[torch.Tensor] = None,
+) -> None:
+    if x.nelement() == 0 or y.nelement() == 0:
+        raise ValueError("Distribution cannot be empty.")
+    if x.dim() > 1 or y.dim() > 1:
+        raise ValueError("Distribution has to be one dimensional.")
+    if not x.device == y.device:
+        raise ValueError("Expected all the tensors to be on the same device.")
+    if x_weights is not None:
+        if x_weights.nelement() == 0:
+            raise ValueError("Weights cannot be empty.")
+        # Individual zero weights are allowed (see the docstring examples);
+        # an all-zero weight vector is rejected by the sum check below.
+        if not torch.all(x_weights >= 0):
+            raise ValueError("All weights must be non-negative.")
+        if not (0 < torch.sum(x_weights) < torch.inf):
+            raise ValueError("Weight tensor sum must be positive-finite.")
+        if not x_weights.device == x.device:
+            raise ValueError("Expected values and weights to be on the same device.")
+        if x_weights.shape != x.shape:
+            raise ValueError(
+                "Distribution values and weight tensors must be of the same shape, "
+                f"got shapes {x.shape} and {x_weights.shape}."
+            )
+    if y_weights is not None:
+        if y_weights.nelement() == 0:
+            raise ValueError("Weights cannot be empty.")
+        if not torch.all(y_weights >= 0):
+            raise ValueError("All weights must be non-negative.")
+        if not (0 < torch.sum(y_weights) < torch.inf):
+            raise ValueError("Weight tensor sum must be positive-finite.")
+        if not y_weights.device == y.device:
+            raise ValueError("Expected values and weights to be on the same device.")
+        if y_weights.shape != y.shape:
+            raise ValueError(
+                "Distribution values and weight tensors must be of the same shape, "
+                f"got shapes {y.shape} and {y_weights.shape}."
+            )
+
+
+def _wasserstein_compute(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    x_weights: Optional[torch.Tensor],
+    y_weights: Optional[torch.Tensor],
+) -> torch.Tensor:
+    # Assigning device per input
+    device = x.device
+
+    # Finding the sorted values
+    x_sorter = torch.argsort(x)
+    y_sorter = torch.argsort(y)
+
+    # Pool all the values onto a common number line
+    all_values = torch.concatenate((x, y))
+    all_values, _ = torch.sort(all_values)
+
+    # Compute the differences between successive values in the pooled sample
+    deltas = torch.diff(all_values)
+
+    # Obtain respective positions of the x and y values among all_values
+    x_cdf_indices = torch.searchsorted(x[x_sorter], all_values[:-1], right=True)
+    y_cdf_indices = torch.searchsorted(y[y_sorter], all_values[:-1], right=True)
+
+    # Calculate the CDF of x and y using their weights, if specified
+    if x_weights is None:
+        x_cdf = x_cdf_indices.to(device) / x.size(0)
+    else:
+        x_sorted_cum_weights = torch.cat(
+            (torch.zeros(1, device=device), torch.cumsum(x_weights[x_sorter], dim=0))
+        )
+        x_cdf = x_sorted_cum_weights[x_cdf_indices] / x_sorted_cum_weights[-1]
+
+    if y_weights is None:
+        y_cdf = y_cdf_indices.to(device) / y.size(0)
+    else:
+        y_sorted_cum_weights = torch.cat(
+            (torch.zeros(1, device=device), torch.cumsum(y_weights[y_sorter], dim=0))
+        )
+        y_cdf = y_sorted_cum_weights[y_cdf_indices] / y_sorted_cum_weights[-1]
+
+    return torch.sum(
+        torch.multiply(torch.abs(x_cdf - y_cdf), deltas), dim=0, keepdim=True
+    ).to(device)
diff --git a/torcheval/metrics/statistical/__init__.py b/torcheval/metrics/statistical/__init__.py
new file mode 100644
index 00000000..8ec697f4
--- /dev/null
+++ b/torcheval/metrics/statistical/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-ignore-all-errors[16]: Undefined attribute of metric states.
+
+from torcheval.metrics.statistical.wasserstein import Wasserstein1D
+
+__all__ = ["Wasserstein1D"]
+__doc_name__ = "Statistical Metrics"
diff --git a/torcheval/metrics/statistical/wasserstein.py b/torcheval/metrics/statistical/wasserstein.py
new file mode 100644
index 00000000..694b8569
--- /dev/null
+++ b/torcheval/metrics/statistical/wasserstein.py
@@ -0,0 +1,145 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable, Optional, TypeVar
+
+import torch
+
+from torcheval.metrics.functional.statistical.wasserstein import (
+    _wasserstein_compute,
+    _wasserstein_update_input_check,
+)
+from torcheval.metrics.metric import Metric
+
+TWasserstein = TypeVar("TWasserstein")
+
+
+class Wasserstein1D(Metric[torch.Tensor]):
+    r"""
+    The Wasserstein distance, also called the Earth Mover's Distance, is a
+    measure of the similarity between two distributions.
+
+    The Wasserstein distance between two distributions is intuitively the
+    minimum weight of soil (times distance moved) that would need to be moved
+    if the two distributions were represented by two piles of soil.
+
+    Its functional version is :func:`torcheval.metrics.functional.statistical.wasserstein_1d`.
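+
+    Each call to ``update`` appends the new samples and weights to internal
+    state lists; the distance itself is only computed from the concatenated
+    state when ``compute`` is called, so memory use grows with the total
+    number of samples observed.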
+
+    Examples
+    --------
+    >>> from torcheval.metrics import Wasserstein1D
+    >>> metric = Wasserstein1D()
+    >>> metric.update(torch.tensor([0,1,2,2]), torch.tensor([0,1]))
+    >>> metric.compute()
+    tensor([0.7500])
+    >>> metric = Wasserstein1D()
+    >>> metric.update(torch.tensor([0,1,2]), torch.tensor([0,1,1]), torch.tensor([1,2,0]), torch.tensor([1,1,1]))
+    >>> metric.compute()
+    tensor([0.])
+    >>> metric = Wasserstein1D()
+    >>> metric.update(torch.tensor([0,1,2]), torch.tensor([0,1,1]))
+    >>> metric.compute()
+    tensor([0.3333])
+    >>> metric.update(torch.tensor([1,1,1]), torch.tensor([1,1,1]))
+    >>> metric.compute()
+    tensor([0.1667])
+    """
+
+    def __init__(self: TWasserstein, *, device: Optional[torch.device] = None) -> None:
+        super().__init__(device=device)
+        # Keeping record of samples
+        self._add_state("dist_1_samples", [])
+        self._add_state("dist_2_samples", [])
+        self._add_state("dist_1_weights", [])
+        self._add_state("dist_2_weights", [])
+
+    @torch.inference_mode()
+    def update(
+        self: TWasserstein,
+        new_samples_dist_1: torch.Tensor,
+        new_samples_dist_2: torch.Tensor,
+        new_weights_dist_1: Optional[torch.Tensor] = None,
+        new_weights_dist_2: Optional[torch.Tensor] = None,
+    ) -> TWasserstein:
+        r"""
+        Update states with distribution values and corresponding weights.
+
+        Args:
+            new_samples_dist_1, new_samples_dist_2 (Tensor): 1D Tensor values observed in the distribution.
+            new_weights_dist_1, new_weights_dist_2 (Tensor): Optional tensor weights for each value.
+                If unspecified, each value is assigned the same weight (1.0).
+        """
+        _wasserstein_update_input_check(
+            new_samples_dist_1,
+            new_samples_dist_2,
+            new_weights_dist_1,
+            new_weights_dist_2,
+        )
+
+        new_samples_dist_1 = new_samples_dist_1.to(self.device)
+        new_samples_dist_2 = new_samples_dist_2.to(self.device)
+
+        if new_weights_dist_1 is None:
+            new_weights_dist_1 = torch.ones_like(new_samples_dist_1, dtype=torch.float)
+        else:
+            new_weights_dist_1 = new_weights_dist_1.to(self.device)
+
+        if new_weights_dist_2 is None:
+            new_weights_dist_2 = torch.ones_like(new_samples_dist_2, dtype=torch.float)
+        else:
+            new_weights_dist_2 = new_weights_dist_2.to(self.device)
+
+        # When new data comes in, just add them to the list of samples
+        self.dist_1_samples.append(new_samples_dist_1)
+        self.dist_2_samples.append(new_samples_dist_2)
+        self.dist_1_weights.append(new_weights_dist_1)
+        self.dist_2_weights.append(new_weights_dist_2)
+
+        return self
+
+    @torch.inference_mode()
+    def compute(self) -> torch.Tensor:
+        r"""
+        Return the Wasserstein distance. If no ``update()`` calls are made
+        before ``compute()`` is called, return an empty tensor.
+
+        Returns:
+            Tensor: The computed Wasserstein distance.
+ """ + return _wasserstein_compute( + torch.cat(self.dist_1_samples, -1), + torch.cat(self.dist_2_samples, -1), + torch.cat(self.dist_1_weights, -1), + torch.cat(self.dist_2_weights, -1), + ) + + @torch.inference_mode() + def merge_state( + self: TWasserstein, metrics: Iterable[TWasserstein] + ) -> TWasserstein: + for metric in metrics: + if metric.dist_1_samples != []: + metric_dist_1_samples = torch.cat(metric.dist_1_samples, -1).to( + self.device + ) + self.dist_1_samples.append(metric_dist_1_samples) + + metric_dist_2_samples = torch.cat(metric.dist_2_samples, -1).to( + self.device + ) + self.dist_2_samples.append(metric_dist_2_samples) + + metric_dist_1_weights = torch.cat(metric.dist_1_weights, -1).to( + self.device + ) + self.dist_1_weights.append(metric_dist_1_weights) + + metric_dist_2_weights = torch.cat(metric.dist_2_weights, -1).to( + self.device + ) + self.dist_2_weights.append(metric_dist_2_weights) + + return self diff --git a/torcheval/utils/random_data.py b/torcheval/utils/random_data.py index 7e1835d4..084aba2c 100644 --- a/torcheval/utils/random_data.py +++ b/torcheval/utils/random_data.py @@ -159,3 +159,40 @@ def get_rand_data_binned_binary( threshold, _ = torch.sort(threshold) threshold = torch.unique(threshold) return input, target, threshold.to(device) + + +def get_rand_data_wasserstein1d( + num_updates: int, + batch_size: int, + device: Optional[torch.device] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Generates a random distribution dataset. + + Notes: + - If num_updates is 1, the update dimension will be omitted; tensors will have shape (batch_size,). + + Args: + num_updates: the number of calls to update on each rank. + batch_size: batch size of the dataset. + device: device for the returned Tensors + + Returns: + torch.Tensor: distribution values first distribution + torch.Tensor: distribution values second distribution + torch.Tensor: weight values first distribution + torch.Tensor: weight values second distribution + """ + if device is None: + device = torch.device("cpu") + + shape = [num_updates, batch_size] + if num_updates == 1: + shape = [batch_size] + + x = torch.rand(size=shape) + y = torch.rand(size=shape) + x_weights = torch.randint(low=1, high=10, size=shape) + y_weights = torch.randint(low=1, high=10, size=shape) + + return x.to(device), y.to(device), x_weights.to(device), y_weights.to(device)