forked from jeya-maria-jose/KiU-Net-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
96 lines (69 loc) · 3.03 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _WeightedLoss
# Smoothing term added to both numerator and denominator in classwise_iou so
# that a class with an empty union yields 1 instead of 0/0.
EPSILON: float = 1e-32
class LogNLLLoss(_WeightedLoss):
    """Cross-entropy loss over raw class scores, with optional class weights.

    Despite the name, ``forward`` does not take the log of its input itself:
    it delegates to :func:`torch.nn.functional.cross_entropy`, which applies
    ``log_softmax`` to ``y_input`` followed by the negative log-likelihood.
    ``y_input`` should therefore be unnormalized logits, not probabilities.

    Args:
        weight: optional per-class rescaling weights, forwarded to
            ``cross_entropy``.
        size_average, reduce: deprecated legacy reduction flags; when either
            is given, the ``_WeightedLoss`` base class derives
            ``self.reduction`` from them.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``. ``None`` (the
            historical default) is treated as ``'mean'``.
        ignore_index: target value that is excluded from the loss.
    """

    __constants__ = ['weight', 'reduction', 'ignore_index']

    def __init__(self, weight=None, size_average=None, reduce=None, reduction=None,
                 ignore_index=-100):
        # Previously the reduction was never forwarded, so the effective
        # behaviour of the default ``None`` was 'mean'; keep that default.
        if reduction is None:
            reduction = 'mean'
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, y_input, y_target):
        """Return the cross-entropy of logits ``y_input`` w.r.t. ``y_target``."""
        # Forward the configured reduction; it used to be silently dropped,
        # so e.g. reduction='sum' still produced a mean.
        return cross_entropy(y_input, y_target, weight=self.weight,
                             ignore_index=self.ignore_index,
                             reduction=self.reduction)
def classwise_iou(output, gt, eps=None):
    """Compute the intersection-over-union (Jaccard index) of every class.

    ``output`` is compared against a one-hot encoding of ``gt``; with soft
    (probability) outputs this yields a soft IoU. Sums run over the batch
    dimension and every spatial dimension, leaving one score per class.

    Args:
        output: torch.Tensor of shape (n_batch, n_classes, *image.shape).
        gt: torch.LongTensor of shape (n_batch, *image.shape) holding class
            indices in ``[0, n_classes)``. ``image.shape`` may be empty,
            i.e. (n_batch,) labels for an (n_batch, n_classes) output.
        eps: smoothing term added to numerator and denominator, so a class
            with an empty union scores 1 instead of 0/0. Defaults to the
            module-level ``EPSILON``.

    Returns:
        Tensor of shape (n_classes,) with the IoU of each class.
    """
    if eps is None:
        eps = EPSILON
    # Reduce over every dimension except the class dimension (dim 1).
    dims = (0, *range(2, output.dim()))
    # unsqueeze(1) works for labels of any rank; the former gt[:, None, :]
    # raised IndexError for 1-D labels.
    one_hot = torch.zeros_like(output).scatter_(1, gt.unsqueeze(1), 1)
    intersection = output * one_hot
    union = output + one_hot - intersection
    return (intersection.sum(dim=dims).float() + eps) / (union.sum(dim=dims) + eps)
def classwise_f1(output, gt):
    """Compute the F1 score of every class from hard (argmax) predictions.

    Args:
        output: torch.Tensor of shape (n_batch, n_classes, *image.shape)
            with per-class scores; the prediction is the argmax over dim 1.
        gt: torch.LongTensor of shape (n_batch, *image.shape) with class
            indices. Labels outside ``[0, n_classes)`` (e.g. an ignore
            value) are not counted for any class.

    Returns:
        CPU float tensor of shape (n_classes,). A class that is neither
        predicted nor present in ``gt`` scores 1.
    """
    epsilon = 1e-20
    n_classes = output.shape[1]
    pred = torch.argmax(output, dim=1)

    def _class_counts(labels):
        # Per-class histogram of in-range labels; out-of-range values are
        # dropped, matching the per-class ``labels == i`` tests this replaces.
        labels = labels.reshape(-1)
        in_range = (labels >= 0) & (labels < n_classes)
        return torch.bincount(labels[in_range], minlength=n_classes).float().cpu()

    # Broadcast exactly as the elementwise ``pred == gt`` comparison would.
    pred_b, gt_b = torch.broadcast_tensors(pred, gt)
    # One vectorised pass per statistic instead of a Python loop over classes.
    true_positives = _class_counts(pred_b[pred_b == gt_b])
    selected = _class_counts(pred)
    relevant = _class_counts(gt)
    precision = (true_positives + epsilon) / (selected + epsilon)
    recall = (true_positives + epsilon) / (relevant + epsilon)
    return 2 * (precision * recall) / (precision + recall)
def make_weighted_metric(classwise_metric):
    """Wrap a classwise metric so it accepts optional per-class weights.

    Args:
        classwise_metric: callable ``(output, gt) -> Tensor[n_classes]`` such
            as ``classwise_iou`` or ``classwise_f1``.

    Returns:
        ``weighted_metric(output, gt, weights=None, *, reduce=False)``.
        By default it returns the classwise scores (moved to CPU), as it
        always has. With ``reduce=True`` it returns the scalar weighted
        average ``sum_i w_i * score_i``, where the weights are normalized to
        sum to 1 (uniform when ``weights`` is None).

    Raises:
        ValueError: (from ``weighted_metric``) if ``len(weights)`` differs
            from ``output.shape[1]``.
    """
    def weighted_metric(output, gt, weights=None, *, reduce=False):
        n_classes = output.shape[1]
        if weights is None:
            weights = torch.ones(n_classes) / n_classes
        else:
            if len(weights) != n_classes:
                raise ValueError("The number of weights must match with the number of classes")
            # Accept lists, numpy arrays and integer tensors; normalize out of
            # place so the caller's tensor is never modified.
            weights = torch.as_tensor(weights, dtype=torch.float32)
            weights = weights / weights.sum()
        classwise_scores = classwise_metric(output, gt).cpu()
        if not reduce:
            return classwise_scores
        # Match dtype/device of the scores before the weighted sum.
        return (classwise_scores * weights.to(classwise_scores)).sum()
    return weighted_metric
# Public metric callables built by wrapping the classwise implementations.
jaccard_index = make_weighted_metric(classwise_metric=classwise_iou)
f1_score = make_weighted_metric(classwise_metric=classwise_f1)
if __name__ == '__main__':
    # Smoke run: all-zero scores against all-zero (class 0) labels.
    demo_scores = torch.zeros(3, 2, 5, 5)
    demo_labels = torch.zeros(3, 5, 5, dtype=torch.long)
    print(classwise_iou(demo_scores, demo_labels))