This repository has been archived by the owner on Dec 5, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path get_class_weights.py
101 lines (79 loc) · 4.04 KB
/
get_class_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import numpy as np
import os
from scipy.misc import imread
# Directory scanned for label images (the name suggests training annotations — confirm).
image_dir = "./dataset/trainannot"
# Full paths of every file in image_dir whose name ends in '.png'. The directory is
# listed once, when the module is imported; the weighting functions use this list
# as their default `image_files` argument.
image_files = [
    os.path.join(image_dir, png_name)
    for png_name in os.listdir(image_dir)
    if png_name.endswith('.png')
]
def ENet_weighting(image_files=image_files, num_classes=12, c=1.02):
    '''
    The custom class weighing function as seen in the ENet paper.
    "We have used a custom class weighing scheme defined as w_class = 1 / ln(c + p_class ).
    In contrast to the inverse class probability weighing, the weights are bounded
    as the probability approaches 0. c is an additional hyper-parameter,
    which we set to 1.02 (i.e. we restrict the class weights to be in the interval of [1, 50])."

    p_class is estimated as the number of pixels labelled `class` divided by the
    number of pixels carrying any label in [0, num_classes), summed over all images.
    Pixel values outside [0, num_classes) are ignored.

    INPUTS:
    - image_files(list): paths of label images; each is read with imread.
    - num_classes(int): number of class labels (must be >= 1).
    - c(float): the ENet hyper-parameter c; the paper uses 1.02. It should be > 1,
      otherwise ln(c + p_class) can be <= 0 and the weights become negative or infinite.
    OUTPUTS:
    - class_weights(list): class_weights[i] is the weight of label i. The last entry
      is always set to 0.0.
    RAISES:
    - ValueError: if num_classes < 1, or if no pixel in the images has a label in
      [0, num_classes) (the class probabilities would be undefined).
    '''
    if num_classes < 1:
        raise ValueError("num_classes must be at least 1")

    # Exact integer pixel counts per label. Summing float32 masks (as before) can
    # lose precision once a class total exceeds 2**24 pixels.
    label_to_frequency = np.zeros(num_classes, dtype=np.int64)
    for path in image_files:
        image = imread(path)
        for label in range(num_classes):
            label_to_frequency[label] += np.count_nonzero(np.equal(image, label))

    total_frequency = int(label_to_frequency.sum())
    if total_frequency == 0:
        # Without this guard the division below raises ZeroDivisionError (no images)
        # or silently yields NaN weights (images without any in-range label).
        raise ValueError("no pixels with a label in [0, num_classes) were found")

    class_weights = [
        1 / np.log(c + frequency / total_frequency)
        for frequency in label_to_frequency
    ]
    # The last label gets zero weight. Presumably it is the background/void class
    # (median_frequency_balancing in this file calls it "background") — confirm
    # against the dataset's label map.
    class_weights[-1] = 0.0
    return class_weights
def median_frequency_balancing(image_files=image_files, num_classes=12):
    '''
    Perform median frequency balancing on the label images:
        w_c = median_freq / freq_c
    where freq_c is the number of pixels of class c divided by the total number of
    pixels in the images where c is present, and median_freq is the median of freq_c
    over all classes that are present at all.

    Only pixels whose value lies in [0, num_classes) are counted, both for the
    per-class totals and for an image's pixel total.

    INPUTS:
    - image_files(list): paths of label images; each is read with imread.
    - num_classes(int): the number of classes of pixels in all images (must be >= 1).
    OUTPUTS:
    - class_weights(list): class_weights[i] is the weight of label i. Classes that
      never appear get 0.0 instead of NaN/inf, and the last entry is always 0.0.
    RAISES:
    - ValueError: if num_classes < 1, or if no pixel in the images has a label in
      [0, num_classes).
    '''
    if num_classes < 1:
        raise ValueError("num_classes must be at least 1")

    # class_pixels[c]: pixels of class c summed over all images.
    # present_image_pixels[c]: labelled pixels of every image in which c appears.
    class_pixels = np.zeros(num_classes, dtype=np.int64)
    present_image_pixels = np.zeros(num_classes, dtype=np.int64)
    for path in image_files:
        image = imread(path)
        counts = np.array(
            [np.count_nonzero(np.equal(image, label)) for label in range(num_classes)],
            dtype=np.int64,
        )
        class_pixels += counts
        # Each class present in this image is charged with the image's labelled pixels.
        present_image_pixels[counts > 0] += counts.sum()

    present = class_pixels > 0
    if not present.any():
        raise ValueError("no pixels with a label in [0, num_classes) were found")

    # A present class appears in at least one image, so its denominator is > 0.
    freq = np.zeros(num_classes, dtype=np.float64)
    freq[present] = class_pixels[present] / present_image_pixels[present]
    median_freq = np.median(freq[present])

    # Absent classes get 0.0: a NaN or inf weight would poison a weighted loss
    # even though the class never occurs in the labels.
    class_weights = [float(median_freq / f) if f > 0 else 0.0 for f in freq]
    # Set the last class_weight to 0.0 as it's the background class
    class_weights[-1] = 0.0
    return class_weights
if __name__ == '__main__':
    # Report both weightings. Computing them without printing or saving the results
    # would make running this script a no-op.
    print("Median frequency balancing class weights:")
    print([float(w) for w in median_frequency_balancing(image_files, num_classes=12)])
    print("ENet class weights:")
    print([float(w) for w in ENet_weighting(image_files, num_classes=12)])