-
-
Notifications
You must be signed in to change notification settings - Fork 633
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* move contrib * make changes to ignite/metrics * add tests and import fixes * move docs for deprecated contrib.metrics * move tests for deprecated contrib.metrics * adjust references * rename test modules * fix version of deprecation * fix doctest * add deprecation warnings * adjust precision of comparison in test this test fails intermittently. the main difference between this branch and master is the a chance difference in the order the tests are run which reliably triggers the failure.
- Loading branch information
Showing
82 changed files
with
2,504 additions
and
1,894 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,15 @@ | ||
ignite.contrib.metrics | ||
====================== | ||
======================= | ||
|
||
Contrib module metrics | ||
---------------------- | ||
Contrib module metrics [deprecated] | ||
----------------------------------- | ||
|
||
.. currentmodule:: ignite.contrib.metrics | ||
.. deprecated:: 0.5.1 | ||
All metrics moved to :ref:`Complete list of metrics`. | ||
|
||
.. autosummary:: | ||
:nosignatures: | ||
:toctree: ../generated | ||
|
||
AveragePrecision | ||
CohenKappa | ||
GpuInfo | ||
PrecisionRecallCurve | ||
ROC_AUC | ||
RocCurve | ||
Regression metrics [deprecated] | ||
-------------------------------- | ||
|
||
Regression metrics | ||
------------------ | ||
|
||
.. currentmodule:: ignite.contrib.metrics.regression | ||
|
||
.. automodule:: ignite.contrib.metrics.regression | ||
|
||
|
||
Module :mod:`ignite.contrib.metrics.regression` provides implementations of | ||
metrics useful for regression tasks. Definitions of metrics are based on `Botchkarev 2018`_, page 30 "Appendix 2. Metrics mathematical definitions". | ||
|
||
.. _`Botchkarev 2018`: | ||
https://arxiv.org/abs/1809.03006 | ||
|
||
Complete list of metrics: | ||
|
||
.. currentmodule:: ignite.contrib.metrics.regression | ||
|
||
.. autosummary:: | ||
:nosignatures: | ||
:toctree: ../generated | ||
|
||
CanberraMetric | ||
FractionalAbsoluteError | ||
FractionalBias | ||
GeometricMeanAbsoluteError | ||
GeometricMeanRelativeAbsoluteError | ||
ManhattanDistance | ||
MaximumAbsoluteError | ||
MeanAbsoluteRelativeError | ||
MeanError | ||
MeanNormalizedBias | ||
MedianAbsoluteError | ||
MedianAbsolutePercentageError | ||
MedianRelativeAbsoluteError | ||
R2Score | ||
WaveHedgesDistance | ||
.. deprecated:: 0.5.1 | ||
All metrics moved to :ref:`Complete list of metrics`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import ignite.contrib.metrics.regression | ||
from ignite.contrib.metrics.average_precision import AveragePrecision | ||
from ignite.contrib.metrics.cohen_kappa import CohenKappa | ||
from ignite.contrib.metrics.gpu_info import GpuInfo | ||
from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve | ||
from ignite.contrib.metrics.roc_auc import ROC_AUC, RocCurve | ||
import ignite.metrics.regression | ||
from ignite.metrics import average_precision, cohen_kappa, gpu_info, precision_recall_curve, roc_auc | ||
from ignite.metrics.average_precision import AveragePrecision | ||
from ignite.metrics.cohen_kappa import CohenKappa | ||
from ignite.metrics.gpu_info import GpuInfo | ||
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve | ||
from ignite.metrics.roc_auc import ROC_AUC, RocCurve |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,22 @@ | ||
from typing import Callable, Union | ||
|
||
import torch | ||
|
||
from ignite.metrics import EpochMetric | ||
|
||
|
||
def average_precision_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> float: | ||
from sklearn.metrics import average_precision_score | ||
|
||
y_true = y_targets.cpu().numpy() | ||
y_pred = y_preds.cpu().numpy() | ||
return average_precision_score(y_true, y_pred) | ||
|
||
|
||
class AveragePrecision(EpochMetric): | ||
"""Computes Average Precision accumulating predictions and the ground-truth during an epoch | ||
and applying `sklearn.metrics.average_precision_score <https://scikit-learn.org/stable/modules/generated/ | ||
sklearn.metrics.average_precision_score.html#sklearn.metrics.average_precision_score>`_ . | ||
Args: | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
check_compute_fn: Default False. If True, `average_precision_score | ||
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html | ||
#sklearn.metrics.average_precision_score>`_ is run on the first batch of data to ensure there are | ||
no issues. User will be warned in case there are any issues computing the function. | ||
device: optional device specification for internal storage. | ||
Note: | ||
AveragePrecision expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or | ||
confidence values. To apply an activation to y_pred, use output_transform as shown below: | ||
.. code-block:: python | ||
def activated_output_transform(output): | ||
y_pred, y = output | ||
y_pred = torch.softmax(y_pred, dim=1) | ||
return y_pred, y | ||
avg_precision = AveragePrecision(activated_output_transform) | ||
Examples: | ||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
.. testcode:: | ||
y_pred = torch.tensor([[0.79, 0.21], [0.30, 0.70], [0.46, 0.54], [0.16, 0.84]]) | ||
y_true = torch.tensor([[1, 1], [1, 1], [0, 1], [0, 1]]) | ||
avg_precision = AveragePrecision() | ||
avg_precision.attach(default_evaluator, 'average_precision') | ||
state = default_evaluator.run([[y_pred, y_true]]) | ||
print(state.metrics['average_precision']) | ||
.. testoutput:: | ||
0.9166... | ||
""" | ||
|
||
def __init__( | ||
self, | ||
output_transform: Callable = lambda x: x, | ||
check_compute_fn: bool = False, | ||
device: Union[str, torch.device] = torch.device("cpu"), | ||
): | ||
try: | ||
from sklearn.metrics import average_precision_score # noqa: F401 | ||
except ImportError: | ||
raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.") | ||
|
||
super(AveragePrecision, self).__init__( | ||
average_precision_compute_fn, | ||
output_transform=output_transform, | ||
check_compute_fn=check_compute_fn, | ||
device=device, | ||
) | ||
""" ``ignite.contrib.metrics.average_precision`` was moved to ``ignite.metrics.average_precision``. | ||
Note: | ||
``ignite.contrib.metrics.average_precision`` was moved to ``ignite.metrics.average_precision``. | ||
Please refer to :mod:`~ignite.metrics.average_precision`. | ||
""" | ||
|
||
import warnings | ||
|
||
removed_in = "0.6.0" | ||
deprecation_warning = ( | ||
f"{__file__} has been moved to /ignite/metrics/average_precision.py" | ||
+ (f" and will be removed in version {removed_in}" if removed_in else "") | ||
+ ".\n Please refer to the documentation for more details." | ||
) | ||
warnings.warn(deprecation_warning, DeprecationWarning, stacklevel=2) | ||
from ignite.metrics.average_precision import AveragePrecision | ||
|
||
__all__ = [ | ||
"AveragePrecision", | ||
] | ||
|
||
AveragePrecision = AveragePrecision |
Oops, something went wrong.