-
-
Notifications
You must be signed in to change notification settings - Fork 623
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bug Related to Calculation of Binary Metrics (#349)
* Updated Precision and Recall to fix binary calculation, ensure that calculation type remains the same during training, calculate binary precision using a threshold function vs categorical. * Fixed flake8 errors * Updated handling of threshold_function. * _classification_support.py - Refactored checks into ClassificationSupport, precision/recall calculation into PrecisionRecallSupport. precision.py and recall.py - use PrecisionRecallSupport * Fixed flake8 errors, passing tests locally * Refactored Precision and Recall functions into _BasePrecisionRecallSupport (child of _BaseClassification). Refactored Accuracy into _BaseClassification. * Removed .idea folder. * Renamed precision.py to precision_recall.py which now contains _BasePrecisionRecallSupport, Precision, Recall. * Added tests similar to Precision/Recall and vfdev5 modified tests for accuracy from #333 * Updated Precision and Recall with separate update functions * Updated accuracy.py by removing threshold_function and adding categorical mapping. * Updated precision_recall to remove threshold_function and add binary to categorical mapping. * removed constructor from Accuracy * Updated docstring, fixed bug in Accuracy, changed condition for binary to categorical mapping. * Added warning for num_classes=2 * Added warning for num_classes=2, included tests in Precision/Recall with warning check, type check and binary multiclass case. * Updated warning to print appropriate class name. * Corrected warning format. * Removed copied line from the docs * Removed self._updated, track changing update type using self._type per comment * Fixed docstring for Precision. * Added DeprecationWarning to binary accuracy, categorical accuracy. Updated tests to catch warning, updated accuracy tests to assert _type. * Fixed correct calculation for accuracy, separated precision_recall.py to precision.py and recall.py * Added docstring regarding y_pred being probabilities. 
* Change warning message 0.1.2 -> 0.2.0, remove tests from binary/categorical accuracy * Update tests and accuracy code, minor changes on precision * [WIP] Updated code and tests on precision * Updated precision, recall and tests. * Updated test_running_average with Accuracy. * Added warning exception for sklearn.exceptions.UndefinedMetricWarning. * Updated docs * Removed identical method _check_type from _BasePrecisionRecall
- Loading branch information
1 parent
e343f77
commit 8558a8f
Showing
11 changed files
with
985 additions
and
569 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1,48 @@ | ||
from __future__ import division | ||
|
||
import torch | ||
|
||
from ignite.metrics.metric import Metric | ||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.precision import _BasePrecisionRecall | ||
from ignite._utils import to_onehot | ||
|
||
|
||
class Recall(Metric): | ||
class Recall(_BasePrecisionRecall): | ||
""" | ||
Calculates recall. | ||
Calculates recall for binary and multiclass data. | |
- `update` must receive output of the form `(y_pred, y)`. | ||
- `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...) | ||
- `y` must be in the following shape (batch_size, ...) | ||
If `average` is True, returns the unweighted average across all classes. | ||
Otherwise, returns a tensor with the recall for each class. | ||
""" | ||
In the binary case, when `y` has 0 or 1 values, the elements of `y_pred` must be between 0 and 1. Recall is | |
computed over the positive class, assumed to be 1. | |
def __init__(self, average=False, output_transform=lambda x: x): | ||
super(Recall, self).__init__(output_transform) | ||
self._average = average | ||
Args: | ||
average (bool, optional): if True, recall is computed as the unweighted average (across all classes | |
in the multiclass case); otherwise, a tensor with the recall (for each class in the multiclass case) is returned. | |
def reset(self): | ||
self._actual = None | ||
self._true_positives = None | ||
""" | ||
|
||
def update(self, output): | ||
y_pred, y = output | ||
dtype = y_pred.type() | ||
|
||
if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()): | ||
raise ValueError("y must have shape of (batch_size, ...) and y_pred " | ||
"must have shape of (batch_size, num_classes, ...) or (batch_size, ...).") | ||
|
||
if y.ndimension() > 1 and y.shape[1] == 1: | ||
y = y.squeeze(dim=1) | ||
y_pred, y = self._check_shape(output) | ||
self._check_type((y_pred, y)) | ||
|
||
if y_pred.ndimension() > 1 and y_pred.shape[1] == 1: | ||
y_pred = y_pred.squeeze(dim=1) | ||
|
||
y_shape = y.shape | ||
y_pred_shape = y_pred.shape | ||
|
||
if y.ndimension() + 1 == y_pred.ndimension(): | ||
y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:] | ||
|
||
if not (y_shape == y_pred_shape): | ||
raise ValueError("y and y_pred must have compatible shapes.") | ||
|
||
if y_pred.ndimension() == y.ndimension(): | ||
# Maps Binary Case to Categorical Case with 2 classes | ||
y_pred = y_pred.unsqueeze(dim=1) | ||
y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1) | ||
dtype = y_pred.type() | ||
|
||
y = to_onehot(y.view(-1), num_classes=y_pred.size(1)) | ||
indices = torch.max(y_pred, dim=1)[1].view(-1) | ||
y_pred = to_onehot(indices, num_classes=y_pred.size(1)) | ||
if self._type == "binary": | ||
y_pred = torch.round(y_pred).view(-1) | ||
y = y.view(-1) | ||
elif self._type == "multiclass": | ||
num_classes = y_pred.size(1) | ||
y = to_onehot(y.view(-1), num_classes=num_classes) | ||
indices = torch.max(y_pred, dim=1)[1].view(-1) | ||
y_pred = to_onehot(indices, num_classes=num_classes) | ||
|
||
y_pred = y_pred.type(dtype) | ||
y = y.type(dtype) | ||
|
||
correct = y * y_pred | ||
actual = y.sum(dim=0) | ||
actual_positives = y.sum(dim=0) | ||
|
||
if correct.sum() == 0: | ||
true_positives = torch.zeros_like(actual) | ||
true_positives = torch.zeros_like(actual_positives) | ||
else: | ||
true_positives = correct.sum(dim=0) | ||
if self._actual is None: | ||
self._actual = actual | ||
self._true_positives = true_positives | ||
else: | ||
self._actual += actual | ||
self._true_positives += true_positives | ||
|
||
def compute(self): | ||
if self._actual is None: | ||
raise NotComputableError('Recall must have at least one example before it can be computed') | ||
result = self._true_positives / self._actual | ||
result[result != result] = 0.0 | ||
if self._average: | ||
return result.mean().item() | ||
else: | ||
return result | ||
self._true_positives += true_positives | ||
self._positives += actual_positives |
Oops, something went wrong.