-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathall_loss_aug.py
executable file
·145 lines (131 loc) · 8.78 KB
/
all_loss_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import os
import pickle
import random
import sys
class transition_loss_(torch.nn.Module):
    """Hinge penalty for a single "alpha and beta imply gamma" label rule.

    For every row of the batch the module computes

        max(0, log_y_alpha[:, a] + log_y_beta[:, b] - log_y_gamma[:, g]) * w

    where ``a``, ``b`` and ``g`` are the column indices passed to
    ``forward`` and ``w`` is an optional scalar derived from
    ``label_weight``.
    """

    def __init__(self):
        super(transition_loss_, self).__init__()
        # Constant used as the lower bound of the hinge. It is placed on the
        # GPU only when one is actually available, so the module can also be
        # constructed on CPU-only machines.
        self.zero = torch.zeros(1)
        if torch.cuda.is_available():
            self.zero = self.zero.cuda()

    def forward(self, log_y_alpha, log_y_beta, log_y_gamma,
                alpha_index, beta_index, gamma_index, label_weight=None):
        """Return the per-row hinge value for one (alpha, beta, gamma) triple.

        Args:
            log_y_alpha, log_y_beta, log_y_gamma: tensors indexed as
                ``[:, index]``; presumably log-probabilities produced by a
                log-softmax -- confirm against callers.
            alpha_index, beta_index, gamma_index: integer column indices.
            label_weight: optional 1-D tensor looked up at
                ``alpha_index * 16 + beta_index * 4 + gamma_index``. When
                given, the result is scaled by
                ``label_weight.sum() / label_weight[that index] / 64.0``.
                NOTE(review): a zero entry at that index produces an
                infinite weight; callers must avoid it.

        Returns:
            A 1-D tensor with one non-negative value per batch row
            (assuming a non-negative weight).
        """
        if label_weight is None:
            label_w = 1
        else:
            flat_index = alpha_index * 16 + beta_index * 4 + gamma_index
            label_w = label_weight.sum() / label_weight[flat_index] / 64.0

        violation = (log_y_alpha[:, alpha_index]
                     + log_y_beta[:, beta_index]
                     - log_y_gamma[:, gamma_index])
        # Follow the inputs' device so CPU tensors and non-default GPUs work.
        zero = self.zero.to(violation.device)
        return torch.max(zero, violation) * label_w
class transition_loss_not_(torch.nn.Module):
    """Hinge penalty for a single "alpha and beta exclude gamma" label rule.

    For every row of the batch the module computes

        max(0, log_y_alpha[:, a] + log_y_beta[:, b] - log(1 - y_gamma[:, g]))

    where ``y_gamma = exp(log_y_gamma)`` and ``1 - y_gamma`` is clamped from
    below at ``1e-8`` before taking the logarithm.
    """

    def __init__(self):
        super(transition_loss_not_, self).__init__()
        # Constants are placed on the GPU only when one is actually
        # available, so the module can also be constructed on CPU-only
        # machines.
        self.zero = torch.zeros(1)
        self.one = torch.ones(1)
        if torch.cuda.is_available():
            self.zero = self.zero.cuda()
            self.one = self.one.cuda()

    def forward(self, log_y_alpha, log_y_beta, log_y_gamma,
                alpha_index, beta_index, gamma_index):
        """Return the per-row hinge value for one excluded triple.

        Args:
            log_y_alpha, log_y_beta, log_y_gamma: tensors indexed as
                ``[:, index]``; presumably log-probabilities produced by a
                log-softmax -- confirm against callers.
            alpha_index, beta_index, gamma_index: integer column indices.

        Returns:
            A 1-D tensor with one non-negative value per batch row.
        """
        very_small = 1e-8  # floor that keeps log() finite when y_gamma == 1
        # Follow the inputs' device so CPU tensors and non-default GPUs work.
        one = self.one.to(log_y_gamma.device)
        log_not_y_gamma = (one - log_y_gamma.exp()).clamp(very_small).log()

        violation = (log_y_alpha[:, alpha_index]
                     + log_y_beta[:, beta_index]
                     - log_not_y_gamma[:, gamma_index])
        zero = self.zero.to(violation.device)
        return torch.max(zero, violation)
class transitivity_loss_H_(torch.nn.Module):
    """Sum of transitivity hinge penalties over a fixed table of label triples.

    Only the first four columns of each logits tensor are used; they are
    turned into log-probabilities with a log-softmax over dim 1.
    NOTE(review): the "H" suffix presumably stands for the hierarchical
    relation set -- confirm against the training script.
    """

    # (alpha, beta, gamma) triples scored with ``transition_loss_``
    # ("alpha and beta imply gamma"). Order matters only for the
    # floating-point summation order.
    _IMPLIED_TRIPLES = (
        (2, 1, 1),
        (2, 2, 2),
        (2, 0, 0),
        (0, 0, 0),
        (0, 2, 0),
        (1, 1, 1),
        (1, 2, 1),
        (2, 3, 3),
        (3, 2, 3),
    )

    # (alpha, beta, gamma) triples scored with ``transition_loss_not_``
    # ("alpha and beta exclude gamma").
    _EXCLUDED_TRIPLES = (
        (0, 3, 2),
        (0, 3, 1),
        (1, 3, 2),
        (1, 3, 0),
        (3, 0, 2),
        (3, 0, 1),
        (3, 1, 2),
        (3, 1, 0),
    )

    def __init__(self):
        super(transitivity_loss_H_, self).__init__()

    def forward(self, alpha_logits, beta_logits, gamma_logits, label_weight_H=None):
        """Return the per-row total penalty.

        Args:
            alpha_logits, beta_logits, gamma_logits: 2-D tensors whose first
                four columns hold the logits of interest.
            label_weight_H: optional weight tensor forwarded unchanged to
                ``transition_loss_`` for the implied triples; it is not used
                for the excluded triples.
        """
        to_log_prob = nn.LogSoftmax(1)
        log_y_alpha = to_log_prob(alpha_logits[:, 0:4])
        log_y_beta = to_log_prob(beta_logits[:, 0:4])
        log_y_gamma = to_log_prob(gamma_logits[:, 0:4])

        implies = transition_loss_()
        excludes = transition_loss_not_()

        total = None
        for a_idx, b_idx, g_idx in self._IMPLIED_TRIPLES:
            term = implies(log_y_alpha, log_y_beta, log_y_gamma,
                           a_idx, b_idx, g_idx, label_weight_H)
            if total is None:
                total = term
            else:
                total += term

        for a_idx, b_idx, g_idx in self._EXCLUDED_TRIPLES:
            total += excludes(log_y_alpha, log_y_beta, log_y_gamma,
                              a_idx, b_idx, g_idx)
        return total
class transitivity_loss_T_(torch.nn.Module):
    """Sum of transitivity hinge penalties over a fixed table of label triples.

    Only the first four columns of each logits tensor are used; they are
    turned into log-probabilities with a log-softmax over dim 1.
    NOTE(review): the "T" suffix presumably stands for the temporal
    relation set -- confirm against the training script.
    """

    # (alpha, beta, gamma) triples scored with ``transition_loss_``
    # ("alpha and beta imply gamma"). Order matters only for the
    # floating-point summation order.
    _IMPLIED_TRIPLES = (
        (0, 0, 0),
        (0, 2, 0),
        (1, 1, 1),
        (1, 2, 1),
        (2, 0, 0),
        (2, 1, 1),
        (2, 2, 2),
        (2, 3, 3),
        (3, 2, 3),
    )

    # (alpha, beta, gamma) triples scored with ``transition_loss_not_``
    # ("alpha and beta exclude gamma").
    _EXCLUDED_TRIPLES = (
        (0, 3, 1),
        (0, 3, 2),
        (1, 3, 0),
        (1, 3, 2),
        (3, 0, 1),
        (3, 0, 2),
        (3, 1, 0),
        (3, 1, 2),
    )

    def __init__(self):
        super(transitivity_loss_T_, self).__init__()

    def forward(self, alpha_logits, beta_logits, gamma_logits, label_weight_T=None):
        """Return the per-row total penalty.

        Args:
            alpha_logits, beta_logits, gamma_logits: 2-D tensors whose first
                four columns hold the logits of interest.
            label_weight_T: optional weight tensor forwarded unchanged to
                ``transition_loss_`` for the implied triples; it is not used
                for the excluded triples.
        """
        to_log_prob = nn.LogSoftmax(1)
        log_y_alpha = to_log_prob(alpha_logits[:, 0:4])
        log_y_beta = to_log_prob(beta_logits[:, 0:4])
        log_y_gamma = to_log_prob(gamma_logits[:, 0:4])

        implies = transition_loss_()
        excludes = transition_loss_not_()

        total = None
        for a_idx, b_idx, g_idx in self._IMPLIED_TRIPLES:
            term = implies(log_y_alpha, log_y_beta, log_y_gamma,
                           a_idx, b_idx, g_idx, label_weight_T)
            if total is None:
                total = term
            else:
                total += term

        for a_idx, b_idx, g_idx in self._EXCLUDED_TRIPLES:
            total += excludes(log_y_alpha, log_y_beta, log_y_gamma,
                              a_idx, b_idx, g_idx)
        return total
class cross_category_loss_(torch.nn.Module):
    """Sum of hinge penalties for rules that span label indices 0-7.

    Unlike the per-category transitivity losses, the log-softmax here is
    taken over every column of the logits, and column indices up to 7 are
    read, so each logits tensor needs at least eight columns.
    """

    # Each rule is (alpha, beta, implied_gamma, excluded_gammas). The implied
    # triple is scored with ``transition_loss_`` first, followed by one
    # ``transition_loss_not_`` term per excluded gamma, in the listed order.
    _RULES = (
        (0, 4, 4, (1, 2)),
        (0, 6, 4, (1, 2)),
        (1, 5, 5, (0, 2)),
        (1, 6, 5, (0, 2)),
        (2, 4, 4, (1, 2)),
        (2, 5, 5, (0, 2)),
        (2, 6, 6, ()),
        (2, 7, 7, (2,)),
        (4, 0, 4, (1, 2)),
        (4, 2, 4, (1, 2)),
        (5, 1, 5, (0, 2)),
        (5, 2, 5, (0, 2)),
        (6, 2, 6, ()),
        (7, 2, 7, (2,)),
    )

    def __init__(self):
        super(cross_category_loss_, self).__init__()

    def forward(self, alpha_logits, beta_logits, gamma_logits):
        """Return the per-row total penalty over all cross-category rules.

        Args:
            alpha_logits, beta_logits, gamma_logits: 2-D logits tensors;
                the log-softmax is applied along dim 1.
        """
        to_log_prob = nn.LogSoftmax(1)
        log_y_alpha = to_log_prob(alpha_logits)
        log_y_beta = to_log_prob(beta_logits)
        log_y_gamma = to_log_prob(gamma_logits)

        implies = transition_loss_()
        excludes = transition_loss_not_()

        total = None
        for a_idx, b_idx, implied_g, excluded_gs in self._RULES:
            term = implies(log_y_alpha, log_y_beta, log_y_gamma,
                           a_idx, b_idx, implied_g)
            if total is None:
                total = term
            else:
                total += term
            for excluded_g in excluded_gs:
                total += excludes(log_y_alpha, log_y_beta, log_y_gamma,
                                  a_idx, b_idx, excluded_g)
        return total