Skip to content

Commit

Permalink
Add average ROC function + test (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rubenkl authored Sep 16, 2020
1 parent 9d5c946 commit ddf63b5
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 0 deletions.
60 changes: 60 additions & 0 deletions evalutils/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,66 @@ def get_bootstrapped_roc_ci_curves(
)


def average_roc_curves(
roc_curves: List[BootstrappedROCCICurves], bins: int = 200
) -> BootstrappedROCCICurves:
"""
Averages ROC curves using vertical averaging (fixed FP rates),
which gives a 1D measure of variability.
Parameters
----------
curves
List of BootstrappedROCCICurves to be averaged
bins (optional)
Number of false-positives to iterate over. (Default: 200)
Returns
-------
BootstrappedROCCICurves
ROC class containing the average over all ROCs.
"""
tprs = []
low_tprs = []
high_tprs = []
low_azs = []
high_azs = []

mean_fpr = np.linspace(0, 1, bins)

for roc in roc_curves:
# get values at fixed fpr locations
interp_tpr = np.interp(mean_fpr, roc.fpr_vals, roc.mean_tpr_vals)
interp_tpr[0] = 0.0

interp_low_tpr = np.interp(mean_fpr, roc.fpr_vals, roc.low_tpr_vals)
interp_high_tpr = np.interp(mean_fpr, roc.fpr_vals, roc.high_tpr_vals)

tprs.append(interp_tpr)
low_tprs.append(interp_low_tpr)
high_tprs.append(interp_high_tpr)
low_azs.append(roc.low_az_val)
high_azs.append(roc.high_az_val)

# get the mean tpr of all ROCs
mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0

mean_low_tpr = np.mean(low_tprs, axis=0)
mean_high_tpr = np.mean(high_tprs, axis=0)
mean_low_az = np.mean(low_azs, axis=0)
mean_high_az = np.mean(high_azs, axis=0)

return BootstrappedROCCICurves(
fpr_vals=mean_fpr,
mean_tpr_vals=mean_tpr,
low_tpr_vals=mean_low_tpr,
high_tpr_vals=mean_high_tpr,
low_az_val=mean_low_az,
high_az_val=mean_high_az,
)


class BootstrappedCIPointError(NamedTuple):
mean_fprs: ndarray
mean_tprs: ndarray
Expand Down
53 changes: 53 additions & 0 deletions tests/test_roc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import numpy as np
import pytest

from evalutils import roc


@pytest.fixture(autouse=True)
def reset_seeds():
np.random.seed(42)
yield


def test_get_bootstrapped_roc_ci_curves():
y_true = np.random.randint(0, 2, 500).astype(np.int)
y_pred = np.random.random_sample(500)
Expand Down Expand Up @@ -49,6 +56,52 @@ def test_get_bootstrapped_roc_ci_curves():
assert roc_95.high_az_val >= roc_65.high_az_val


def test_average_roc_curves():
y_true = np.random.randint(0, 2, 500).astype(np.int)
y_pred = np.random.random_sample(500)
roc_95 = roc.get_bootstrapped_roc_ci_curves(
y_pred, y_true, num_bootstraps=3, ci_to_use=0.95
)

# average of 3 identical ROCs should be close to the
# individual ROC.
assert np.isclose(
roc.average_roc_curves([roc_95, roc_95, roc_95], bins=101).fpr_vals,
roc_95.fpr_vals,
).all()

assert np.isclose(
roc.average_roc_curves(
[roc_95, roc_95, roc_95], bins=101
).mean_tpr_vals,
roc_95.mean_tpr_vals,
).all()

assert np.isclose(
roc.average_roc_curves(
[roc_95, roc_95, roc_95], bins=101
).low_tpr_vals,
roc_95.low_tpr_vals,
).all()

assert np.isclose(
roc.average_roc_curves(
[roc_95, roc_95, roc_95], bins=101
).high_tpr_vals,
roc_95.high_tpr_vals,
).all()

assert np.isclose(
roc.average_roc_curves([roc_95, roc_95, roc_95], bins=101).low_az_val,
roc_95.low_az_val,
)

assert np.isclose(
roc.average_roc_curves([roc_95, roc_95, roc_95], bins=101).high_az_val,
roc_95.high_az_val,
)


def test_get_bootstrapped_ci_point_error():
y_true = np.random.randint(0, 2, 500).astype(np.int)
y_pred = np.random.randint(1, 10, 500).astype(np.int)
Expand Down

0 comments on commit ddf63b5

Please sign in to comment.