Bug Related to Calculation of Binary Metrics (#349)
* Updated Precision and Recall to fix the binary calculation: ensure the calculation type remains the same during training, and compute binary precision with a threshold function instead of the categorical mapping (illustrated in the sketch after this list).

* Fixed flake8 errors

* Updated handling of threshold_function.

* _classification_support.py - refactored checks into ClassificationSupport and precision/recall calculation into PrecisionRecallSupport.
  precision.py and recall.py - use PrecisionRecallSupport.

* Fixed flake8 errors; tests passing locally.

* Refactored Precision and Recall functions into _BasePrecisionRecallSupport (child of _BaseClassification). Refactored Accuracy into _BaseClassification.

* Removed .idea folder.

* Renamed precision.py to precision_recall.py which now contains _BasePrecisionRecallSupport, Precision, Recall.

* Added tests similar to Precision/Recall, plus vfdev5's modified accuracy tests from #333.

* Updated Precision and Recall with separate update functions

* Updated accuracy.py by removing threshold_function and adding categorical mapping.

* Updated precision_recall to remove threshold_function and add binary to categorical mapping.

* Removed constructor from Accuracy.

* Updated docstring, fixed bug in Accuracy, changed condition for binary to categorical mapping.

* Added warning for num_classes=2

* Added warning for num_classes=2, included tests in Precision/Recall with warning check, type check and binary multiclass case.

* Updated warning to print appropriate class name.

* Corrected warning format.

* Removed copied line from the docs

* Removed self._updated; track the changing update type using self._type, per review comment.

* Fixed docstring for Precision.

* Added DeprecationWarning to BinaryAccuracy and CategoricalAccuracy. Updated tests to catch the warning; updated accuracy tests to assert _type.

* Fixed accuracy calculation; separated precision_recall.py into precision.py and recall.py.

* Added docstring regarding y_pred being probabilities.

* Changed warning message 0.1.2 -> 0.2.0; removed tests from binary/categorical accuracy.

* Updated tests and accuracy code; minor changes to precision.

* [WIP] Updated code and tests on precision

* Updated precision, recall and tests.

* Updated test_running_average with Accuracy.

* Added warning exception for sklearn.exceptions.UndefinedMetricWarning.

* Updated docs

* Removed identical method _check_type from _BasePrecisionRecall
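
To make the first bullet concrete (an editorial sketch, not part of the commit): previously the binary case was mapped onto a two-class categorical problem, so precision came back as a 2-element tensor that included the negative class; after this change, probabilities are thresholded and the metric is computed over the positive class only, assumed to be 1. The tensors below are illustrative:

import torch

# Binary predictions (already thresholded) and targets.
y_pred = torch.tensor([1., 0., 1., 1.])
y = torch.tensor([1., 0., 0., 1.])

# Old behavior: stack (1 - p, p) into two categorical columns and score
# both classes, negative class included.
y_pred_two_col = torch.stack([1.0 - y_pred, y_pred], dim=1)  # shape (4, 2)

# New behavior: score the positive class directly.
# TP = 2, predicted positives = 3 -> precision 2/3.
tp = (y * y_pred).sum()
precision_pos = tp / y_pred.sum()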
anmolsjoshi authored and vfdev-5 committed Dec 13, 2018
1 parent e343f77 commit 8558a8f
Showing 11 changed files with 985 additions and 569 deletions.
91 changes: 69 additions & 22 deletions ignite/metrics/accuracy.py
@@ -6,48 +6,95 @@
 from ignite.exceptions import NotComputableError


-class Accuracy(Metric):
-    """
-    Calculates the accuracy.
-
-    - `update` must receive output of the form `(y_pred, y)`.
-    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...)
-    - `y` must be in the following shape (batch_size, ...)
-    """
-    def reset(self):
-        self._num_correct = 0
-        self._num_examples = 0
+class _BaseClassification(Metric):
+
+    def __init__(self, output_transform=lambda x: x):
+        self._type = None
+        super(_BaseClassification, self).__init__(output_transform=output_transform)

-    def update(self, output):
+    def _check_shape(self, output):
         y_pred, y = output

-        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
-            raise ValueError("y must have shape of (batch_size, ...) and y_pred "
-                             "must have shape of (batch_size, num_classes, ...) or (batch_size, ...).")
-
         if y.ndimension() > 1 and y.shape[1] == 1:
+            # (N, 1, ...) -> (N, ...)
             y = y.squeeze(dim=1)

         if y_pred.ndimension() > 1 and y_pred.shape[1] == 1:
+            # (N, 1, ...) -> (N, ...)
             y_pred = y_pred.squeeze(dim=1)

+        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
+            raise ValueError("y must have shape of (batch_size, ...) and y_pred must have "
+                             "shape of (batch_size, num_categories, ...) or (batch_size, ...), "
+                             "but given {} vs {}".format(y.shape, y_pred.shape))
+
         y_shape = y.shape
         y_pred_shape = y_pred.shape

         if y.ndimension() + 1 == y_pred.ndimension():
-            y_pred_shape = (y_pred_shape[0], ) + y_pred_shape[2:]
+            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]

         if not (y_shape == y_pred_shape):
             raise ValueError("y and y_pred must have compatible shapes.")

-        if y_pred.ndimension() == y.ndimension():
-            # Maps Binary Case to Categorical Case with 2 classes
-            y_pred = y_pred.unsqueeze(dim=1)
-            y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)
+        return y_pred, y

-        indices = torch.max(y_pred, dim=1)[1]
-        correct = torch.eq(indices, y).view(-1)
+    def _check_type(self, output):
+        y_pred, y = output
+
+        if y.ndimension() + 1 == y_pred.ndimension():
+            update_type = "multiclass"
+        elif y.ndimension() == y_pred.ndimension():
+            update_type = "binary"
+            if not torch.equal(y, y ** 2):
+                raise ValueError("For binary cases, y must be comprised of 0's and 1's.")
+            # TODO: Uncomment the following after 0.1.2 release
+            # if not torch.equal(y_pred, y_pred ** 2):
+            #     raise ValueError("For binary cases, y_pred must be comprised of 0's and 1's.")
+        else:
+            raise RuntimeError("Invalid shapes of y (shape={}) and y_pred (shape={}), check documentation"
+                               " for expected shapes of y and y_pred.".format(y.shape, y_pred.shape))
+        if self._type is None:
+            self._type = update_type
+        else:
+            if self._type != update_type:
+                raise RuntimeError("update_type has changed from {} to {}.".format(self._type, update_type))
+
+
+class Accuracy(_BaseClassification):
+    """
+    Calculates the accuracy for binary and multiclass data
+
+    - `update` must receive output of the form `(y_pred, y)`.
+    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...)
+    - `y` must be in the following shape (batch_size, ...)
+
+    In binary case, when `y` has 0 or 1 values, the elements of `y_pred` must be between 0 and 1.
+    """
+    # TODO: Include the following into the docstring after 0.1.2 release
+    # .. code-block:: python
+    #
+    #     def thresholded_output_transform(output):
+    #         y_pred, y = output
+    #         y_pred = torch.round(y_pred)
+    #         return y_pred, y
+    #
+    #     binary_accuracy = Accuracy(thresholded_output_transform)
+
+    def reset(self):
+        self._num_correct = 0
+        self._num_examples = 0
+
+    def update(self, output):
+        y_pred, y = self._check_shape(output)
+        self._check_type((y_pred, y))
+
+        if self._type == "binary":
+            indices = torch.round(y_pred).type(y.type())
+        elif self._type == "multiclass":
+            indices = torch.max(y_pred, dim=1)[1]
+
+        correct = torch.eq(indices, y).view(-1)
         self._num_correct += torch.sum(correct).item()
         self._num_examples += correct.shape[0]

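As a usage note (editorial, not part of the diff) — a minimal sketch of the refactored Accuracy, assuming the standard update/compute Metric API; the tensors and values are illustrative:

import torch
from ignite.metrics import Accuracy

# Binary case: y_pred holds probabilities in [0, 1] and is thresholded
# internally via torch.round, so 0.7 counts as class 1.
binary_acc = Accuracy()
binary_acc.update((torch.tensor([0.7, 0.2, 0.9, 0.4]),
                   torch.tensor([1, 0, 1, 1])))
print(binary_acc.compute())  # 0.75 -> three of four predictions match

# Multiclass case: y_pred has shape (batch_size, num_categories) and the
# predicted class is the argmax over dim=1.
multi_acc = Accuracy()
multi_acc.update((torch.tensor([[0.1, 0.8, 0.1],
                                [0.6, 0.3, 0.1]]),
                  torch.tensor([1, 0])))
print(multi_acc.compute())  # 1.0

# Feeding a binary update and then a multiclass update to the same
# instance now raises a RuntimeError, since self._type is fixed by the
# first update.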
2 changes: 1 addition & 1 deletion ignite/metrics/binary_accuracy.py
@@ -10,5 +10,5 @@ class BinaryAccuracy(Accuracy):
     """
     def __init__(self, *args, **kwargs):
         warnings.warn("The use of ignite.metrics.BinaryAccuracy is deprecated, it will be "
-                      "removed in 0.1.2. Please use ignite.metrics.Accuracy instead.")
+                      "removed in 0.2.0. Please use ignite.metrics.Accuracy instead.", DeprecationWarning)
         super(Accuracy, self).__init__(*args, **kwargs)
2 changes: 1 addition & 1 deletion ignite/metrics/categorical_accuracy.py
@@ -10,5 +10,5 @@ class CategoricalAccuracy(Accuracy):
     """
     def __init__(self, *args, **kwargs):
         warnings.warn("The use of ignite.metrics.CategoricalAccuracy is deprecated, it will be "
-                      "removed in 0.1.2. Please use ignite.metrics.Accuracy instead.")
+                      "removed in 0.2.0. Please use ignite.metrics.Accuracy instead.", DeprecationWarning)
         super(Accuracy, self).__init__(*args, **kwargs)
94 changes: 42 additions & 52 deletions ignite/metrics/precision.py
@@ -2,83 +2,73 @@

 import torch

-from ignite.metrics.metric import Metric
+from ignite.metrics.accuracy import _BaseClassification
 from ignite.exceptions import NotComputableError
 from ignite._utils import to_onehot


-class Precision(Metric):
-    """
-    Calculates precision.
-
-    - `update` must receive output of the form `(y_pred, y)`.
-
-    If `average` is True, returns the unweighted average across all classes.
-    Otherwise, returns a tensor with the precision for each class.
-    """
-    def __init__(self, average=False, output_transform=lambda x: x):
-        super(Precision, self).__init__(output_transform)
+class _BasePrecisionRecall(_BaseClassification):
+
+    def __init__(self, output_transform=lambda x: x, average=False):
         self._average = average
+        super(_BasePrecisionRecall, self).__init__(output_transform=output_transform)

     def reset(self):
-        self._all_positives = None
-        self._true_positives = None
+        self._true_positives = 0
+        self._positives = 0

-    def update(self, output):
-        y_pred, y = output
-        dtype = y_pred.type()
+    def compute(self):
+        if not isinstance(self._positives, torch.Tensor):
+            raise NotComputableError("{} must have at least one example before"
+                                     " it can be computed".format(self.__class__.__name__))

-        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
-            raise ValueError("y must have shape of (batch_size, ...) and y_pred "
-                             "must have shape of (batch_size, num_classes, ...) or (batch_size, ...).")
+        result = self._true_positives / self._positives
+        result[result != result] = 0.0
+        if self._average:
+            return result.mean().item()
+        else:
+            return result

-        if y.ndimension() > 1 and y.shape[1] == 1:
-            y = y.squeeze(dim=1)

-        if y_pred.ndimension() > 1 and y_pred.shape[1] == 1:
-            y_pred = y_pred.squeeze(dim=1)
+class Precision(_BasePrecisionRecall):
+    """
+    Calculates precision for binary and multiclass data
+
+    - `update` must receive output of the form `(y_pred, y)`.
+    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...)
+    - `y` must be in the following shape (batch_size, ...)
+
+    In binary case, when `y` has 0 or 1 values, the elements of `y_pred` must be between 0 and 1. Precision is
+    computed over positive class, assumed to be 1.
+
+    Args:
+        average (bool, optional): if True, precision is computed as the unweighted average (across all classes
+            in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case).
+    """
+    def update(self, output):
+        y_pred, y = self._check_shape(output)
+        self._check_type((y_pred, y))

-        y_shape = y.shape
-        y_pred_shape = y_pred.shape
-
-        if y.ndimension() + 1 == y_pred.ndimension():
-            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
-
-        if not (y_shape == y_pred_shape):
-            raise ValueError("y and y_pred must have compatible shapes.")
-
-        if y_pred.ndimension() == y.ndimension():
-            # Maps Binary Case to Categorical Case with 2 classes
-            y_pred = y_pred.unsqueeze(dim=1)
-            y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)
+        dtype = y_pred.type()

-        y = to_onehot(y.view(-1), num_classes=y_pred.size(1))
-        indices = torch.max(y_pred, dim=1)[1].view(-1)
-        y_pred = to_onehot(indices, num_classes=y_pred.size(1))
+        if self._type == "binary":
+            y_pred = torch.round(y_pred).view(-1)
+            y = y.view(-1)
+        elif self._type == "multiclass":
+            num_classes = y_pred.size(1)
+            y = to_onehot(y.view(-1), num_classes=num_classes)
+            indices = torch.max(y_pred, dim=1)[1].view(-1)
+            y_pred = to_onehot(indices, num_classes=num_classes)

         y_pred = y_pred.type(dtype)
+        y = y.type(dtype)

         correct = y * y_pred
         all_positives = y_pred.sum(dim=0)

         if correct.sum() == 0:
             true_positives = torch.zeros_like(all_positives)
         else:
             true_positives = correct.sum(dim=0)
-        if self._all_positives is None:
-            self._all_positives = all_positives
-            self._true_positives = true_positives
-        else:
-            self._all_positives += all_positives
-            self._true_positives += true_positives
-
-    def compute(self):
-        if self._all_positives is None:
-            raise NotComputableError('Precision must have at least one example before it can be computed')
-        result = self._true_positives / self._all_positives
-        result[result != result] = 0.0
-        if self._average:
-            return result.mean().item()
-        else:
-            return result
+
+        self._true_positives += true_positives
+        self._positives += all_positives
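A usage sketch for the refactored Precision (editorial, not part of the diff); the scores and expected outputs are illustrative:

import torch
from ignite.metrics import Precision

# Multiclass: compute() returns a tensor with one precision per class.
precision = Precision(average=False)
y_pred = torch.tensor([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.6, 0.3, 0.1]])
y = torch.tensor([0, 1, 2])
precision.update((y_pred, y))
print(precision.compute())  # tensor([0.5000, 1.0000, 0.0000]); class 2 is 0/0 -> 0.0

# With average=True the same counts collapse to a scalar:
# (0.5 + 1.0 + 0.0) / 3 = 0.5.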
89 changes: 26 additions & 63 deletions ignite/metrics/recall.py
@@ -1,85 +1,48 @@
 from __future__ import division

 import torch

-from ignite.metrics.metric import Metric
-from ignite.exceptions import NotComputableError
+from ignite.metrics.precision import _BasePrecisionRecall
 from ignite._utils import to_onehot


-class Recall(Metric):
+class Recall(_BasePrecisionRecall):
     """
-    Calculates recall.
+    Calculates recall for binary and multiclass data

     - `update` must receive output of the form `(y_pred, y)`.
+    - `y_pred` must be in the following shape (batch_size, num_categories, ...) or (batch_size, ...)
+    - `y` must be in the following shape (batch_size, ...)

-    If `average` is True, returns the unweighted average across all classes.
-    Otherwise, returns a tensor with the recall for each class.
-    """
+    In binary case, when `y` has 0 or 1 values, the elements of `y_pred` must be between 0 and 1. Recall is
+    computed over positive class, assumed to be 1.

-    def __init__(self, average=False, output_transform=lambda x: x):
-        super(Recall, self).__init__(output_transform)
-        self._average = average
+    Args:
+        average (bool, optional): if True, recall is computed as the unweighted average (across all classes
+            in multiclass case), otherwise, returns a tensor with the recall (for each class in multiclass case).
+    """

-    def reset(self):
-        self._actual = None
-        self._true_positives = None
-
     def update(self, output):
-        y_pred, y = output
-        dtype = y_pred.type()
-
-        if not (y.ndimension() == y_pred.ndimension() or y.ndimension() + 1 == y_pred.ndimension()):
-            raise ValueError("y must have shape of (batch_size, ...) and y_pred "
-                             "must have shape of (batch_size, num_classes, ...) or (batch_size, ...).")
-
-        if y.ndimension() > 1 and y.shape[1] == 1:
-            y = y.squeeze(dim=1)
+        y_pred, y = self._check_shape(output)
+        self._check_type((y_pred, y))

-        if y_pred.ndimension() > 1 and y_pred.shape[1] == 1:
-            y_pred = y_pred.squeeze(dim=1)
-
-        y_shape = y.shape
-        y_pred_shape = y_pred.shape
-
-        if y.ndimension() + 1 == y_pred.ndimension():
-            y_pred_shape = (y_pred_shape[0],) + y_pred_shape[2:]
-
-        if not (y_shape == y_pred_shape):
-            raise ValueError("y and y_pred must have compatible shapes.")
-
-        if y_pred.ndimension() == y.ndimension():
-            # Maps Binary Case to Categorical Case with 2 classes
-            y_pred = y_pred.unsqueeze(dim=1)
-            y_pred = torch.cat([1.0 - y_pred, y_pred], dim=1)
+        dtype = y_pred.type()

-        y = to_onehot(y.view(-1), num_classes=y_pred.size(1))
-        indices = torch.max(y_pred, dim=1)[1].view(-1)
-        y_pred = to_onehot(indices, num_classes=y_pred.size(1))
+        if self._type == "binary":
+            y_pred = torch.round(y_pred).view(-1)
+            y = y.view(-1)
+        elif self._type == "multiclass":
+            num_classes = y_pred.size(1)
+            y = to_onehot(y.view(-1), num_classes=num_classes)
+            indices = torch.max(y_pred, dim=1)[1].view(-1)
+            y_pred = to_onehot(indices, num_classes=num_classes)

         y_pred = y_pred.type(dtype)
+        y = y.type(dtype)

         correct = y * y_pred
-        actual = y.sum(dim=0)
+        actual_positives = y.sum(dim=0)

         if correct.sum() == 0:
-            true_positives = torch.zeros_like(actual)
+            true_positives = torch.zeros_like(actual_positives)
         else:
             true_positives = correct.sum(dim=0)
-        if self._actual is None:
-            self._actual = actual
-            self._true_positives = true_positives
-        else:
-            self._actual += actual
-            self._true_positives += true_positives
-
-    def compute(self):
-        if self._actual is None:
-            raise NotComputableError('Recall must have at least one example before it can be computed')
-        result = self._true_positives / self._actual
-        result[result != result] = 0.0
-        if self._average:
-            return result.mean().item()
-        else:
-            return result
+
+        self._true_positives += true_positives
+        self._positives += actual_positives
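And the matching sketch for Recall in the binary case (editorial, not part of the diff); values are illustrative:

import torch
from ignite.metrics import Recall

# Binary: recall is computed over the positive class, assumed to be 1.
recall = Recall()
y_pred = torch.tensor([0.9, 0.1, 0.8, 0.7])  # probabilities, thresholded internally
y = torch.tensor([1, 0, 0, 1])
recall.update((y_pred, y))
print(recall.compute())  # tensor(1.) -> both actual positives were predicted positive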