-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Making functions in utils/metrics.py
compatible with torch.Tensor
inputs
#216
base: main
Are you sure you want to change the base?
Changes from 10 commits
c465fe6
09b2f5f
0d0d3da
8f700a8
5ffc6bb
c1661aa
ef0cf74
ccc0a09
f5d17d7
d950b1d
e5cd386
7f82938
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import numpy as np\n", | ||
"from skimage.metrics import peak_signal_noise_ratio as psnr\n", | ||
"\n", | ||
"from careamics.utils.metrics import _zero_mean, scale_invariant_psnr, torch_cast_to_float, _fix, _fix_range" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x, y = np.array([1, 2, 3, 4, 5, 6]), np.array([1, 2, 3, 4, 5, 6])\n", | ||
"# x, y = torch.tensor([1, 2, 3, 4, 5, 6]), torch.tensor([1, 2, 3, 4, 5, 6])\n", | ||
"# x, y = torch_cast_to_float(x), torch_cast_to_float(y)\n", | ||
"\n", | ||
"# scale_invariant_psnr(x, y)\n", | ||
"psnr(np.asarray(x), np.asarray(y), data_range=5.)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"range_ = (x.max() - x.min()) / x.std()\n", | ||
"print(range_)\n", | ||
"x_ = _zero_mean(x) / x.std()\n", | ||
"print(x_)\n", | ||
"# fix\n", | ||
"y_ = _zero_mean(y)\n", | ||
"print(y_)\n", | ||
"x__ = _zero_mean(x_)\n", | ||
"print(x__)\n", | ||
"# fix range\n", | ||
"a = (x__ * y_).sum() / (y_**2).sum()\n", | ||
"print(a)\n", | ||
"y__ = y_ * a\n", | ||
"print(y__)\n", | ||
"print(psnr(np.asarray(x__), np.asarray(y__), data_range=range_))\n", | ||
"# psnr(\n", | ||
"# np.asarray([-1.3363, -0.8018, -0.2673, 0.2673, 0.8018, 1.3363]),\n", | ||
"# np.asarray([-1.3363, -0.8018, -0.2673, 0.2673, 0.8018, 1.3363]),\n", | ||
"# data_range=2.6726\n", | ||
"# )\n", | ||
"print(psnr(\n", | ||
" np.asarray([-1.46385011, -0.87831007, -0.29277002, 0.29277002, 0.87831007, 1.46385011]),\n", | ||
" np.asarray([-1.46385011, -0.87831007, -0.29277002, 0.29277002, 0.87831007, 1.46385011]),\n", | ||
" data_range=2.9277\n", | ||
"))\n", | ||
"np.allclose(np.asarray(x__), np.asarray([-1.46385011, -0.87831007, -0.29277002, 0.29277002, 0.87831007, 1.46385011]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.float32)\n", | ||
"torch_cast_to_float(x)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "train_lvae", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.19" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,20 @@ | |
This module contains various metrics and a metrics tracking class. | ||
""" | ||
|
||
# NOTE: this doesn't work with torch tensors, since `torch` refuses to | ||
# compute the `mean()` or `std()` of a tensor whose dtype is not float. | ||
|
||
from typing import Union | ||
from warnings import warn | ||
|
||
import numpy as np | ||
import torch | ||
from skimage.metrics import peak_signal_noise_ratio | ||
|
||
Array = Union[np.ndarray, torch.Tensor] | ||
|
||
|
||
def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float: | ||
def psnr(gt: Array, pred: Array, range: float = 255.0) -> float: | ||
""" | ||
Peak Signal to Noise Ratio. | ||
|
||
|
@@ -20,9 +26,9 @@ def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float: | |
|
||
Parameters | ||
---------- | ||
gt : NumPy array | ||
gt : Array | ||
Ground truth image. | ||
pred : NumPy array | ||
pred : Array | ||
Predicted image. | ||
range : float, optional | ||
The images pixel range, by default 255.0. | ||
|
@@ -32,84 +38,129 @@ def psnr(gt: np.ndarray, pred: np.ndarray, range: float = 255.0) -> float: | |
float | ||
PSNR value. | ||
""" | ||
return peak_signal_noise_ratio(gt, pred, data_range=range) | ||
# TODO: replace with explicit formula (?) it'd be a couple lines of code | ||
# and won't impact performance. On the contrary it would make the code | ||
# more explicit and easier to test. | ||
return peak_signal_noise_ratio( | ||
np.asarray(gt), | ||
np.asarray(pred), | ||
data_range=range, | ||
) | ||
|
||
|
||
def _zero_mean(x: np.ndarray) -> np.ndarray: | ||
def _zero_mean(x: Array) -> Array: | ||
""" | ||
Zero the mean of an array. | ||
|
||
NOTE: `torch` does not support the `mean()` method for tensors whose | ||
`dtype` is not floating point. Hence, this function will raise a warning and | ||
automatically cast the input to double (`torch.float64`) if it is a `torch.Tensor` that is not already of that dtype. | ||
|
||
Parameters | ||
---------- | ||
x : NumPy array | ||
x : Array | ||
Input array. | ||
|
||
Returns | ||
------- | ||
NumPy array | ||
Array | ||
Zero-mean array. | ||
""" | ||
return x - np.mean(x) | ||
x = _torch_cast_to_double(x) | ||
return x - x.mean() | ||
|
||
|
||
def _fix_range(gt: np.ndarray, x: np.ndarray) -> np.ndarray: | ||
def _fix_range(gt: Array, x: Array) -> Array: | ||
""" | ||
Adjust the range of an array based on a reference ground-truth array. | ||
|
||
Parameters | ||
---------- | ||
gt : np.ndarray | ||
gt : Array | ||
Ground truth image. | ||
x : np.ndarray | ||
x : Array | ||
Input array. | ||
|
||
Returns | ||
------- | ||
np.ndarray | ||
Array | ||
Range-adjusted array. | ||
""" | ||
a = np.sum(gt * x) / (np.sum(x * x)) | ||
a = (gt * x).sum() / (x * x).sum() | ||
return x * a | ||
|
||
|
||
def _fix(gt: np.ndarray, x: np.ndarray) -> np.ndarray: | ||
def _fix(gt: Array, x: Array) -> Array: | ||
""" | ||
Zero mean a ground truth array and adjust the range of the array. | ||
|
||
Parameters | ||
---------- | ||
gt : np.ndarray | ||
gt : Array | ||
Ground truth image. | ||
x : np.ndarray | ||
x : Array | ||
Input array. | ||
|
||
Returns | ||
------- | ||
np.ndarray | ||
Array | ||
Zero-mean and range-adjusted array. | ||
""" | ||
gt_ = _zero_mean(gt) | ||
return _fix_range(gt_, _zero_mean(x)) | ||
|
||
|
||
def scale_invariant_psnr( | ||
gt: np.ndarray, pred: np.ndarray | ||
) -> Union[float, torch.tensor]: | ||
def scale_invariant_psnr(gt: Array, pred: Array) -> Union[float, torch.tensor]: | ||
""" | ||
Scale invariant PSNR. | ||
|
||
NOTE: `torch` does not support the `mean()` method for tensors whose | ||
`dtype` is not floating point. Hence, this function will raise a warning and | ||
automatically cast the inputs to double (`torch.float64`) if they are `torch.Tensor`s that are not already of that dtype. | ||
|
||
NOTE: results may vary slightly between `numpy` and `torch` due to the way | ||
`std()` is computed. In `torch`, the unbiased variance estimator is used by default (i.e., SSE/(n-1)), | ||
while in `numpy` the biased estimator is used by default (i.e., SSE/n). | ||
|
||
Parameters | ||
---------- | ||
gt : np.ndarray | ||
gt : Array | ||
Ground truth image. | ||
pred : np.ndarray | ||
pred : Array | ||
Predicted image. | ||
|
||
Returns | ||
------- | ||
Union[float, torch.tensor] | ||
Scale invariant PSNR value. | ||
""" | ||
range_parameter = (np.max(gt) - np.min(gt)) / np.std(gt) | ||
gt_ = _zero_mean(gt) / np.std(gt) | ||
# cast tensors to double dtype | ||
gt = _torch_cast_to_double(gt) | ||
pred = _torch_cast_to_double(pred) | ||
# compute scale-invariant PSNR | ||
range_parameter = (gt.max() - gt.min()) / gt.std() | ||
gt_ = _zero_mean(gt) / gt.std() | ||
return psnr(_zero_mean(gt_), _fix(gt_, pred), range_parameter) | ||
|
||
|
||
def _torch_cast_to_double(x: Array) -> Array: | ||
""" | ||
Cast a tensor to double (`torch.float64`). | ||
|
||
Parameters | ||
---------- | ||
x : Array | ||
Input tensor. | ||
|
||
Returns | ||
------- | ||
Array | ||
Double-precision (`torch.float64`) tensor; inputs that are not `torch.Tensor` are returned unchanged. | ||
""" | ||
if isinstance(x, torch.Tensor) and x.dtype != torch.float64: | ||
warn( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer: Is the warning necessary? I ask because it will never lead to any information loss. Also, we do not return this tensor; we just use it to compute the metrics. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Author: I'd definitely remove it "in production", but I have temporarily put it there for debugging :) |
||
f"Casting tensor of type `{x.dtype}` to double (`torch.float64`).", | ||
UserWarning, | ||
) | ||
return x.double() | ||
return x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic looks good.