Making functions in utils/metrics.py compatible with torch.Tensor inputs #216

Status: Draft. Wants to merge 12 commits into base: main.
Changes from 10 commits
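A minimal usage sketch of what this change enables, based on the diff and tests below. The import path is taken from the diff; the values are illustrative only:

```python
import numpy as np
import torch

from careamics.utils.metrics import psnr, scale_invariant_psnr

gt = np.array([1, 2, 3, 4, 5, 6])
pred = np.array([1, 2, 3, 4, 5, 6])
print(scale_invariant_psnr(gt, pred))      # numpy inputs keep working as before

gt_t = torch.tensor([1, 2, 3, 4, 5, 6])    # integer dtype on purpose
pred_t = torch.tensor([1, 2, 3, 4, 5, 6])
print(scale_invariant_psnr(gt_t, pred_t))  # tensors are now accepted and cast internally
print(psnr(gt_t, pred_t, range=5.0))       # psnr also accepts tensors via np.asarray
```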
96 changes: 96 additions & 0 deletions src/careamics/check_metrics.ipynb
@@ -0,0 +1,96 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"from skimage.metrics import peak_signal_noise_ratio as psnr\n",
"\n",
"from careamics.utils.metrics import _zero_mean, scale_invariant_psnr, torch_cast_to_float, _fix, _fix_range"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x, y = np.array([1, 2, 3, 4, 5, 6]), np.array([1, 2, 3, 4, 5, 6])\n",
"# x, y = torch.tensor([1, 2, 3, 4, 5, 6]), torch.tensor([1, 2, 3, 4, 5, 6])\n",
"# x, y = torch_cast_to_float(x), torch_cast_to_float(y)\n",
"\n",
"# scale_invariant_psnr(x, y)\n",
"psnr(np.asarray(x), np.asarray(y), data_range=5.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"range_ = (x.max() - x.min()) / x.std()\n",
"print(range_)\n",
"x_ = _zero_mean(x) / x.std()\n",
"print(x_)\n",
"# fix\n",
"y_ = _zero_mean(y)\n",
"print(y_)\n",
"x__ = _zero_mean(x_)\n",
"print(x__)\n",
"# fix range\n",
"a = (x__ * y_).sum() / (y_**2).sum()\n",
"print(a)\n",
"y__ = y_ * a\n",
"print(y__)\n",
"print(psnr(np.asarray(x__), np.asarray(y__), data_range=range_))\n",
"# psnr(\n",
"# np.asarray([-1.3363, -0.8018, -0.2673, 0.2673, 0.8018, 1.3363]),\n",
"# np.asarray([-1.3363, -0.8018, -0.2673, 0.2673, 0.8018, 1.3363]),\n",
"# data_range=2.6726\n",
"# )\n",
"print(psnr(\n",
" np.asarray([-1.46385011, -0.87831007, -0.29277002, 0.29277002, 0.87831007, 1.46385011]),\n",
" np.asarray([-1.46385011, -0.87831007, -0.29277002, 0.29277002, 0.87831007, 1.46385011]),\n",
" data_range=2.9277\n",
"))\n",
"np.allclose(np.asarray(x__), np.asarray([-1.46385011, -0.87831007, -0.29277002, 0.29277002, 0.87831007, 1.46385011]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float32)\n",
"torch_cast_to_float(x)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "train_lvae",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
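The notebook above checks `scale_invariant_psnr` step by step: normalise the ground truth, zero-mean both images, rescale the prediction by the least-squares factor a = sum(gt * x) / sum(x * x), then call skimage's PSNR with data_range = (max - min) / std. A condensed, numpy-only restatement of that manual computation is sketched below for comparison; the function name is illustrative:

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio


def manual_scale_invariant_psnr(gt, pred):
    """Reproduce the notebook's manual computation with numpy only."""
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    data_range = (gt.max() - gt.min()) / gt.std()      # range of the normalised ground truth
    gt_norm = (gt - gt.mean()) / gt.std()              # zero-mean, unit-std ground truth
    gt_zm = gt_norm - gt_norm.mean()                   # zero-mean again (a no-op up to rounding)
    pred_zm = pred - pred.mean()                       # zero-mean prediction
    a = (gt_zm * pred_zm).sum() / (pred_zm ** 2).sum() # least-squares scale factor
    return peak_signal_noise_ratio(gt_zm, a * pred_zm, data_range=data_range)


x = np.array([1, 2, 3, 4, 5, 6])
print(manual_scale_invariant_psnr(x, x))  # finite but very large, as observed in the notebook
```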
99 changes: 75 additions & 24 deletions src/careamics/utils/metrics.py
@@ -4,14 +4,20 @@
This module contains various metrics and a metrics tracking class.
"""

# NOTE: this doesn't work with torch tensors, since `torch` refuses to
# compute the `mean()` or `std()` of a tensor whose dtype is not float.

from typing import Union
from warnings import warn

import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio

Array = Union[np.ndarray, torch.Tensor]


def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float:
def psnr(gt: Array, pred: Array, range: float = 255.0) -> float:
"""
Peak Signal to Noise Ratio.

@@ -20,9 +26,9 @@ def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float:

Parameters
----------
gt : NumPy array
gt : Array
Ground truth image.
pred : NumPy array
pred : Array
Predicted image.
range : float, optional
The images pixel range, by default 255.0.
@@ -32,84 +38,129 @@ def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float:
float
PSNR value.
"""
return peak_signal_noise_ratio(gt, pred, data_range=range)
# TODO: replace with explicit formula (?) it'd be a couple lines of code
# and won't impact performance. On the contrary it would make the code
# more explicit and easier to test.
return peak_signal_noise_ratio(
np.asarray(gt),
np.asarray(pred),
data_range=range,
)


def _zero_mean(x: np.ndarray) -> np.ndarray:
def _zero_mean(x: Array) -> Array:
"""
Zero the mean of an array.

NOTE: `torch` does not support the `mean()` method for tensors whose
`dtype` is not `float`. Hence, this function will raise a warning and
automatically cast the input tensor to `float` if it is a `torch.Tensor`.

Parameters
----------
x : NumPy array
x : Array
Input array.

Returns
-------
NumPy array
Array
Zero-mean array.
"""
return x - np.mean(x)
x = _torch_cast_to_double(x)
return x - x.mean()


def _fix_range(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
def _fix_range(gt: Array, x: Array) -> Array:
"""
Adjust the range of an array based on a reference ground-truth array.

Parameters
----------
gt : np.ndarray
gt : Array
Ground truth image.
x : np.ndarray
x : Array
Input array.

Returns
-------
np.ndarray
Array
Range-adjusted array.
"""
a = np.sum(gt * x) / (np.sum(x * x))
a = (gt * x).sum() / (x * x).sum()
return x * a


def _fix(gt: np.ndarray, x: np.ndarray) -> np.ndarray:
def _fix(gt: Array, x: Array) -> Array:
"""
Zero-mean a ground-truth array and adjust the range of the array.

Parameters
----------
gt : np.ndarray
gt : Array
Ground truth image.
x : np.ndarray
x : Array
Input array.

Returns
-------
np.ndarray
Array
Zero-mean and range-adjusted array.
"""
gt_ = _zero_mean(gt)
return _fix_range(gt_, _zero_mean(x))


def scale_invariant_psnr(
gt: np.ndarray, pred: np.ndarray
) -> Union[float, torch.tensor]:
def scale_invariant_psnr(gt: Array, pred: Array) -> Union[float, torch.Tensor]:
"""
Scale invariant PSNR.

NOTE: `torch` does not support the `mean()` method for tensors whose
`dtype` is not `float`. Hence, this function will raise a warning and
automatically cast the input tensor to `float` if it is a `torch.Tensor`.

NOTE: results may vary slightly between `numpy` and `torch` due to the way
`std()` is computed: `torch` uses the unbiased estimator by default (i.e., SSE/(n-1)),
while `numpy` uses the biased one (i.e., SSE/n).

Parameters
----------
gt : np.ndarray
gt : Array
Ground truth image.
pred : np.ndarray
pred : Array
Predicted image.

Returns
-------
Union[float, torch.Tensor]
Scale invariant PSNR value.
"""
range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt)
gt_ = _zero_mean(gt) / np.std(gt)
# cast tensors to double dtype
Collaborator: logic looks good.

gt = _torch_cast_to_double(gt)
pred = _torch_cast_to_double(pred)
# compute scale-invariant PSNR
range_parameter = (gt.max() - gt.min()) / gt.std()
gt_ = _zero_mean(gt) / gt.std()
return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter)


def _torch_cast_to_double(x: Array) -> Array:
"""
Cast a tensor to double precision (`torch.float64`).

Parameters
----------
x : Array
Input array or tensor.

Returns
-------
Array
Double-precision tensor; non-tensor inputs are returned unchanged.
"""
if isinstance(x, torch.Tensor) and x.dtype != torch.float64:
warn(
Collaborator: Is the warning necessary? I ask because it will never lead to any information loss. Also, we do not return this tensor; we just use it to compute the metrics.

Collaborator (author): I'd definitely remove it "in production", but I have temporarily put it there for debugging :)

f"Casting tensor of type `{x.dtype}` to double (`torch.float64`).",
UserWarning,
)
return x.double()
return x
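The two NOTEs above can be reproduced in isolation. The snippet below is a small illustration of why `_torch_cast_to_double` exists and why numpy and torch results can differ slightly; the values are illustrative:

```python
import numpy as np
import torch

# Integer tensors: torch refuses to compute mean()/std(), which is what
# motivates the cast performed by `_torch_cast_to_double`.
try:
    torch.tensor([1, 2, 3]).mean()
except RuntimeError as err:
    print(err)  # the message says mean() needs a floating point or complex dtype

# std(): torch defaults to the unbiased estimator (Bessel's correction),
# numpy to the biased one, hence the small discrepancies mentioned in the
# `scale_invariant_psnr` docstring.
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(np.std(np.array(x)))                  # divides by n
print(torch.tensor(x).std())                # divides by n - 1
print(torch.tensor(x).std(unbiased=False))  # matches numpy
```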
17 changes: 16 additions & 1 deletion tests/utils/test_metrics.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
import torch

from careamics.utils.metrics import (
_zero_mean,
@@ -10,21 +11,35 @@
@pytest.mark.parametrize(
"x",
[
5.6,
np.array([1, 2, 3, 4, 5]),
np.array([[1, 2, 3], [4, 5, 6]]),
torch.tensor([1, 2, 3, 4, 5]),
torch.tensor([[1, 2, 3], [4, 5, 6]]),
],
)
def test_zero_mean(x):
x = np.asarray(x)
assert np.allclose(_zero_mean(x), x - np.mean(x))


# NOTE: the behavior of the PSNR function for numpy arrays is odd: PSNR computed over
# identical vectors should be infinite, but the function returns a finite value.
# With torch inputs it returns `inf` instead.
@pytest.mark.parametrize(
"gt, pred, result",
[
(np.array([1, 2, 3, 4, 5, 6]), np.array([1, 2, 3, 4, 5, 6]), 332.22),
(np.array([[1, 2, 3], [4, 5, 6]]), np.array([[1, 2, 3], [4, 5, 6]]), 332.22),
(torch.tensor([1, 2, 3, 4, 5, 6]), torch.tensor([1, 2, 3, 4, 5, 6]), 332.22),
(
torch.tensor([[1, 2, 3], [4, 5, 6]]),
torch.tensor([[1, 2, 3], [4, 5, 6]]),
332.22,
),
],
)
def test_scale_invariant_psnr(gt, pred, result):
assert scale_invariant_psnr(gt, pred) == pytest.approx(result, rel=5e-3)


# TODO: add tests for RunningPSNR
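A possible follow-up test (not part of this PR) would be to check that numpy and torch inputs agree on non-identical images. The input values and the tolerance below are guesses and would need tuning against the std()/var() note in the docstring:

```python
import numpy as np
import pytest
import torch

from careamics.utils.metrics import scale_invariant_psnr


@pytest.mark.parametrize(
    "gt, pred",
    [
        ([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 7]),
        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 4], [4, 5, 6]]),
    ],
)
def test_scale_invariant_psnr_numpy_vs_torch(gt, pred):
    # Same metric computed once on numpy arrays and once on torch tensors.
    np_result = scale_invariant_psnr(np.array(gt), np.array(pred))
    torch_result = scale_invariant_psnr(torch.tensor(gt), torch.tensor(pred))
    assert np_result == pytest.approx(torch_result, rel=1e-2)
```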