From 728cd42b79f4cd19e4642e6243c8e33c417da69b Mon Sep 17 00:00:00 2001 From: jon Date: Thu, 31 Aug 2023 23:28:32 +0300 Subject: [PATCH] add none in acc black'd --- elk/metrics/eval.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py index 653beae55..d0e2bf7a5 100644 --- a/elk/metrics/eval.py +++ b/elk/metrics/eval.py @@ -23,6 +23,8 @@ class EvalResult: roc_auc: RocAucResult """Area under the ROC curve. For multi-class classification, each class is treated as a one-vs-rest binary classification problem.""" + cal_thresh: float | None + """The threshold used to compute the calibrated accuracy.""" def to_dict(self, prefix: str = "") -> dict[str, float]: """Convert the result to a dictionary.""" @@ -38,7 +40,13 @@ def to_dict(self, prefix: str = "") -> dict[str, float]: else {} ) auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()} - return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict} + return { + **auroc_dict, + **cal_acc_dict, + **acc_dict, + **cal_dict, + f"{prefix}cal_thresh": self.cal_thresh, + } def evaluate_preds( @@ -64,7 +72,14 @@ def evaluate_preds( else: y_true = repeat(y_true, "n -> n v", v=v) - y_pred = y_logits.argmax(dim=-1) + THRESHOLD = 0.5 + if ensembling == "none": + y_pred = y_logits[..., 1].gt(THRESHOLD).to(torch.int) + else: + y_pred = y_logits.argmax(dim=-1) + + acc = accuracy_ci(y_true, y_pred) + if ensembling == "none": auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1)) elif ensembling in ("partial", "full"): @@ -76,22 +91,27 @@ def evaluate_preds( else: raise ValueError(f"Unknown mode: {ensembling}") - acc = accuracy_ci(y_true, y_pred) cal_acc = None cal_err = None + cal_thresh = None if c == 2: - pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0]) + pooled_logits = ( + y_logits[..., 1] + if ensembling == "none" + else y_logits[..., 1] - y_logits[..., 0] + ) + pos_probs = torch.sigmoid(pooled_logits) # Calibrated accuracy - cal_thresh = pos_probs.float().quantile(y_true.float().mean()) + cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item() cal_preds = pos_probs.gt(cal_thresh).to(torch.int) cal_acc = accuracy_ci(y_true, cal_preds) cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten()) cal_err = cal.compute() - return EvalResult(acc, cal_acc, cal_err, auroc) + return EvalResult(acc, cal_acc, cal_err, auroc, cal_thresh) def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: