update running_average, ssim, top_k_categorical_accuracy
simeetnayan81 committed Jul 16, 2024
1 parent 757e198 commit 743fc4f
Showing 5 changed files with 26 additions and 5 deletions.
11 changes: 10 additions & 1 deletion ignite/metrics/running_average.py
@@ -27,6 +27,9 @@ class RunningAverage(Metric):
None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
from the metric is a tensor.
skip_unrolling: specifies whether the output should be unrolled before being fed to the update method. Should be
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
Alternatively, ``output_transform`` can be used to handle this.
Examples:
@@ -84,6 +87,9 @@ def log_running_avg_metrics():
0.039208...
0.038423...
0.057655...
.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.
"""

required_output_keys = None
@@ -96,6 +102,7 @@ def __init__(
output_transform: Optional[Callable] = None,
epoch_bound: Optional[bool] = None,
device: Optional[Union[str, torch.device]] = None,
skip_unrolling: bool = False,
):
if not (isinstance(src, Metric) or src is None):
raise TypeError("Argument src should be a Metric or None.")
@@ -131,7 +138,9 @@ def output_transform(x: Any) -> Any:
)
self.epoch_bound = epoch_bound
self.alpha = alpha
super(RunningAverage, self).__init__(output_transform=output_transform, device=device)
super(RunningAverage, self).__init__(
output_transform=output_transform, device=device, skip_unrolling=skip_unrolling
)

@reinit__is_reduced
def reset(self) -> None:
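The hunks above simply forward the new ``skip_unrolling`` argument to the ``Metric`` base class. A minimal usage sketch, assuming a toy step function and hand-made batches that are not part of this commit:

```python
import torch

from ignite.engine import Engine
from ignite.metrics import Accuracy, RunningAverage

# Toy step: each batch is already a (y_pred, y) pair.
def train_step(engine, batch):
    return batch

trainer = Engine(train_step)

# Wrap Accuracy in a RunningAverage; skip_unrolling is forwarded to the Metric
# base class and left at its default (False) for this single-output setup.
RunningAverage(src=Accuracy(), alpha=0.98, skip_unrolling=False).attach(trainer, "running_acc")

batches = [
    (torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 1])),
    (torch.tensor([[0.4, 0.6], [0.7, 0.3]]), torch.tensor([0, 0])),
]
state = trainer.run(batches)
print(state.metrics["running_acc"])
```

Per the updated docstring, ``skip_unrolling=True`` would only be needed if the step returned multi-output tuples such as ``((y_pred_a, y_pred_b), (y_a, y_b))``.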
9 changes: 8 additions & 1 deletion ignite/metrics/ssim.py
@@ -33,6 +33,9 @@ class SSIM(Metric):
device: specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
skip_unrolling: specifies whether the output should be unrolled before being fed to the update method. Should be
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
Alternatively, ``output_transform`` can be used to handle this.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
@@ -62,6 +65,9 @@ class SSIM(Metric):
0.9218971...
.. versionadded:: 0.4.2
.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.
"""

_state_dict_all_req_keys = ("_sum_of_ssim", "_num_examples", "_kernel")
@@ -76,6 +82,7 @@ def __init__(
gaussian: bool = True,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
):
if isinstance(kernel_size, int):
self.kernel_size: Sequence[int] = [kernel_size, kernel_size]
@@ -97,7 +104,7 @@ def __init__(
if any(y <= 0 for y in self.sigma):
raise ValueError(f"Expected sigma to have positive number. Got {sigma}.")

super(SSIM, self).__init__(output_transform=output_transform, device=device)
super(SSIM, self).__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)
self.gaussian = gaussian
self.data_range = data_range
self.c1 = (k1 * data_range) ** 2
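A similar sketch for ``SSIM``, again with an illustrative evaluation step and random image batches that are not part of the commit:

```python
import torch

from ignite.engine import Engine
from ignite.metrics import SSIM

# Evaluation step returning a single (y_pred, y) pair of image batches.
def eval_step(engine, batch):
    return batch

evaluator = Engine(eval_step)

# Single-output case, so the default skip_unrolling=False is kept. For a step
# returning multi-output tuples such as ((y_pred_a, y_pred_b), (y_a, y_b)),
# skip_unrolling=True would pass the tuple to update() as-is.
SSIM(data_range=1.0, skip_unrolling=False).attach(evaluator, "ssim")

y = torch.rand(4, 3, 32, 32)
y_pred = y * 0.9  # slightly perturbed copies of the targets
state = evaluator.run([(y_pred, y)])
print(state.metrics["ssim"])
```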
9 changes: 8 additions & 1 deletion ignite/metrics/top_k_categorical_accuracy.py
@@ -24,6 +24,9 @@ class TopKCategoricalAccuracy(Metric):
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether the output should be unrolled before being fed to the update method. Should be
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
Alternatively, ``output_transform`` can be used to handle this.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
@@ -71,6 +74,9 @@ def one_hot_to_binary_output_transform(output):
.. testoutput::
0.75
.. versionchanged:: 0.5.1
``skip_unrolling`` argument is added.
"""

_state_dict_all_req_keys = ("_num_correct", "_num_examples")
@@ -80,8 +86,9 @@ def __init__(
k: int = 5,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device)
super(TopKCategoricalAccuracy, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)
self._k = k

@reinit__is_reduced
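And the corresponding sketch for ``TopKCategoricalAccuracy``; the logits and targets are made up for illustration:

```python
import torch

from ignite.engine import Engine
from ignite.metrics import TopKCategoricalAccuracy

# Evaluation step returning a (y_pred, y) pair.
def eval_step(engine, batch):
    return batch

evaluator = Engine(eval_step)

# skip_unrolling defaults to False; as the docstring notes, an output_transform
# is the alternative way to reshape multi-output results before update().
TopKCategoricalAccuracy(k=2, skip_unrolling=False).attach(evaluator, "top2_acc")

y_pred = torch.tensor([[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]])
y = torch.tensor([1, 2])
state = evaluator.run([(y_pred, y)])
# Target 1 is within the top-2 of row 0; target 2 is not within the top-2 of row 1.
print(state.metrics["top2_acc"])  # 0.5
```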
1 change: 0 additions & 1 deletion tests/ignite/metrics/test_mean_pairwise_distance.py
@@ -5,7 +5,6 @@
import torch

import ignite.distributed as idist

from ignite.exceptions import NotComputableError
from ignite.metrics import MeanPairwiseDistance

1 change: 0 additions & 1 deletion tests/ignite/metrics/test_multilabel_confusion_matrix.py
@@ -4,7 +4,6 @@
from sklearn.metrics import multilabel_confusion_matrix

import ignite.distributed as idist

from ignite.exceptions import NotComputableError
from ignite.metrics.multilabel_confusion_matrix import MultiLabelConfusionMatrix

