Skip to content

Commit

Permalink
Merge pull request #290 from NotodAI-Research/not-372-none-ensembling…
Browse files Browse the repository at this point in the history
…-for-accuracy

"None" ensembling for classification accuracy
  • Loading branch information
derpyplops authored Sep 7, 2023
2 parents a470603 + 728cd42 commit 14669b1
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions elk/metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class EvalResult:
roc_auc: RocAucResult
"""Area under the ROC curve. For multi-class classification, each class is treated
as a one-vs-rest binary classification problem."""
cal_thresh: float | None
"""The threshold used to compute the calibrated accuracy."""

def to_dict(self, prefix: str = "") -> dict[str, float]:
"""Convert the result to a dictionary."""
Expand All @@ -38,7 +40,13 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
else {}
)
auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()}
return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}
return {
**auroc_dict,
**cal_acc_dict,
**acc_dict,
**cal_dict,
f"{prefix}cal_thresh": self.cal_thresh,
}


def evaluate_preds(
Expand All @@ -64,7 +72,14 @@ def evaluate_preds(
else:
y_true = repeat(y_true, "n -> n v", v=v)

y_pred = y_logits.argmax(dim=-1)
THRESHOLD = 0.5
if ensembling == "none":
y_pred = y_logits[..., 1].gt(THRESHOLD).to(torch.int)
else:
y_pred = y_logits.argmax(dim=-1)

acc = accuracy_ci(y_true, y_pred)

if ensembling == "none":
auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
elif ensembling in ("partial", "full"):
Expand All @@ -76,22 +91,27 @@ def evaluate_preds(
else:
raise ValueError(f"Unknown mode: {ensembling}")

acc = accuracy_ci(y_true, y_pred)
cal_acc = None
cal_err = None
cal_thresh = None

if c == 2:
pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])
pooled_logits = (
y_logits[..., 1]
if ensembling == "none"
else y_logits[..., 1] - y_logits[..., 0]
)
pos_probs = torch.sigmoid(pooled_logits)

# Calibrated accuracy
cal_thresh = pos_probs.float().quantile(y_true.float().mean())
cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item()
cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
cal_acc = accuracy_ci(y_true, cal_preds)

cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
cal_err = cal.compute()

return EvalResult(acc, cal_acc, cal_err, auroc)
return EvalResult(acc, cal_acc, cal_err, auroc, cal_thresh)


def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
Expand Down

0 comments on commit 14669b1

Please sign in to comment.