losses.py
import torch
from torch import Tensor, nn


class BinaryMacroSoftFBetaLoss(nn.Module):
    """Macro soft F-beta loss.

    Computes a differentiable F-beta score per label from "soft" confusion
    counts and returns one minus its mean over labels. Expects ``yhat`` to
    be probabilities in [0, 1] with shape (batch, n_labels).
    """

    def __init__(self, beta=1, eps=torch.finfo(torch.float32).eps):
        super().__init__()
        self.beta = beta
        self.beta2 = beta**2
        self.eps = eps

    def forward(self, yhat: Tensor, y: Tensor):
        # Soft confusion counts per label (summed over the batch dimension).
        tp = (yhat * y).sum(dim=0)
        precision = tp / (yhat.sum(dim=0) + self.eps)
        recall = tp / (y.sum(dim=0) + self.eps)
        fbeta = (
            (1 + self.beta2)
            * precision
            * recall
            / (self.beta2 * precision + recall + self.eps)
        )
        return 1 - fbeta.mean()
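
# A minimal usage sketch (shapes and values below are illustrative, not from
# the original file):
#
#     criterion = BinaryMacroSoftFBetaLoss(beta=1)
#     yhat = torch.sigmoid(torch.randn(8, 3))  # probabilities, (batch, labels)
#     y = (torch.rand(8, 3) > 0.5).float()     # binary targets
#     loss = criterion(yhat, y)                # scalar; 0 means perfect F-beta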


class BinarySurrogateFBetaLoss(nn.Module):
    """Differentiable surrogate for the macro F-beta objective.

    Positives are penalized with ``-log(yhat)``; negatives with
    ``log(beta^2 * p / (1 - p) + yhat)``, where ``p`` is the per-label
    positive rate of the batch. Expects ``yhat`` to be probabilities
    in [0, 1].
    """

    def __init__(self, beta=1, eps=torch.finfo(torch.float32).eps):
        super().__init__()
        self.beta = beta
        self.beta2 = beta**2
        # Floor for the log argument, clipping log(x) at -100.
        self.clip_log_x = torch.exp(torch.tensor(-100.0))
        self.eps = eps

    def forward(self, yhat: Tensor, y: Tensor):
        # Per-label positive rate in the batch.
        p = y.mean(dim=0)
        return (
            -y * self.log(yhat)
            + (1 - y)
            * self.log(self.beta2 * p / (1 - p + self.eps) + yhat)
        ).mean()

    def log(self, x: Tensor):
        return torch.log(torch.max(x, self.clip_log_x))
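
# Sketch of the negative-class reweighting (assumed numbers, for intuition):
# with beta=1 and a per-label positive rate p = 0.25, a negative example
# contributes log(0.25 / 0.75 + yhat) ~ log(0.33 + yhat), so confident false
# positives are still penalized but the term stays finite as yhat -> 0.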


class BinaryExpectedCostLoss(nn.Module):
    def __init__(
        self,
        ctp: float = 0.0,
        cfp: float = 1.0,
        cfn: float = 50.0,
        ctn: float = 0.0,
    ):
        """Expected-cost loss over soft confusion counts.

        Args:
            ctp: Cost of a true positive
            cfp: Cost of a false positive
            cfn: Cost of a false negative
            ctn: Cost of a true negative
        """
        super().__init__()
        self.ctp = ctp
        self.cfp = cfp
        self.cfn = cfn
        self.ctn = ctn

    def forward(self, yhat: Tensor, y: Tensor):
        # Soft confusion counts per label; yhat should be probabilities.
        tp = (yhat * y).sum(dim=0)
        fp = (yhat * (1 - y)).sum(dim=0)
        fn = ((1 - yhat) * y).sum(dim=0)
        tn = ((1 - yhat) * (1 - y)).sum(dim=0)
        n = tp + tn + fp + fn
        # Average cost per example, then averaged over labels.
        cost = (
            self.ctp * tp + self.cfp * fp + self.cfn * fn + self.ctn * tn
        ) / n
        return cost.mean()
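
# Worked example with the default costs (cfp=1, cfn=50; counts assumed): a
# batch of 100 predictions with two soft false positives and one soft false
# negative costs (1 * 2 + 50 * 1) / 100 = 0.52, so missed positives dominate.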


class HybridLoss(nn.Module):
    def __init__(self, loss_a: nn.Module, loss_b: nn.Module):
        """Use one loss for the first epoch and another afterwards.

        Args:
            loss_a: Loss function to use for the first epoch; it receives
                raw logits (e.g. ``nn.BCEWithLogitsLoss``)
            loss_b: Loss function to use for the remaining epochs; it
                receives sigmoid probabilities
        """
        super().__init__()
        self.loss_a = loss_a
        self.loss_b = loss_b
        self.max_batch_idx = -1

    def forward(self, logits: Tensor, y: Tensor, batch_idx: int):
        # Within the first epoch the (0-based, per-epoch) batch index keeps
        # exceeding the running maximum; once it repeats, every later batch
        # falls through to loss_b.
        if batch_idx > self.max_batch_idx:
            self.max_batch_idx = batch_idx
            return self.loss_a(logits, y)
        else:
            yhat = torch.sigmoid(logits)
            return self.loss_b(yhat, y)
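

if __name__ == "__main__":
    # Smoke-test sketch (assumed shapes and pairings, not part of the
    # original module): a multi-label batch of 8 examples and 3 labels.
    torch.manual_seed(0)
    logits = torch.randn(8, 3)
    y = (torch.rand(8, 3) > 0.5).float()
    yhat = torch.sigmoid(logits)

    for name, loss in [
        ("soft F1", BinaryMacroSoftFBetaLoss()(yhat, y)),
        ("surrogate F1", BinarySurrogateFBetaLoss()(yhat, y)),
        ("expected cost", BinaryExpectedCostLoss()(yhat, y)),
    ]:
        print(f"{name}: {loss.item():.4f}")

    # HybridLoss takes raw logits plus the batch index; pairing it with
    # BCE-with-logits and the soft F-beta loss is an assumed example.
    hybrid = HybridLoss(nn.BCEWithLogitsLoss(), BinaryMacroSoftFBetaLoss())
    first = hybrid(logits, y, batch_idx=0)  # first epoch -> loss_a
    later = hybrid(logits, y, batch_idx=0)  # repeated index -> loss_b
    print(f"hybrid first/later: {first.item():.4f} / {later.item():.4f}")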