-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathyolov5-dhaka_ai-csv-generator.py
131 lines (114 loc) · 4.51 KB
/
yolov5-dhaka_ai-csv-generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import os
import glob
import csv
import time
import cv2
def get_class_names(path):
    """Read a class-name file and return its lines as a list of names.

    Every line of the file becomes one entry, with surrounding whitespace
    stripped; file order is preserved, so the class id used elsewhere in this
    script (``classes[int(row[0])]``) is simply the line index.

    Args:
        path: path to a plain-text file with one class name per line.

    Returns:
        list[str]: the stripped lines, in file order.
    """
    with open(path) as name_file:
        return [raw_line.strip() for raw_line in name_file]
def get_img_size(file):
    """Return the pixel size of the image at *file* as ``(height, width)``.

    Args:
        file: path to an image readable by OpenCV.

    Returns:
        tuple[int, int]: ``(height, width)`` in pixels.

    Raises:
        OSError: if OpenCV cannot read the image. ``cv2.imread`` signals
            failure by returning ``None`` rather than raising.
    """
    img = cv2.imread(file)
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O`` and would let a ``None`` through to ``img.shape``.
    if img is None:
        raise OSError("Image could not be read, path: " + file)
    # Take only the first two dimensions so single-channel arrays also work.
    height, width = img.shape[:2]
    return height, width
def check_badbox(file, img_height, img_width, x_min, y_min, x_max, y_max):
    """Report box coordinates that fall outside the image.

    Prints one diagnostic line per violated bound and returns whether any
    violation was found. A box touching the edge (for example
    ``x_max == img_width`` or ``x_min == 0``) is accepted.

    Args:
        file: image name, used only in the printed messages.
        img_height: image height in pixels.
        img_width: image width in pixels.
        x_min, y_min, x_max, y_max: box corners in pixels.

    Returns:
        bool: True if at least one coordinate is out of bounds.
    """
    flag = False
    # Horizontal bounds. x_min < 0 was previously never checked: the old
    # code tested x_max < 0 in its place.
    if x_min < 0:
        print("Badbox (x_min < 0) found in: " + file)
        flag = True
    if x_max > img_width:
        print("Badbox (x_max > img_width) found in: " + file)
        print("x_max: ", x_max)
        print("img_width: ", img_width)
        flag = True
    if x_max < 0:
        print("Badbox (x_max < 0) found in: " + file)
        flag = True
    # Vertical bounds, checked symmetrically with the horizontal ones.
    if y_min < 0:
        print("Badbox (y_min < 0) found in: " + file)
        flag = True
    if y_max > img_height:
        print("Badbox (y_max > img_height) found in: " + file)
        flag = True
    if y_max < 0:
        print("Badbox (y_max < 0) found in: " + file)
        flag = True
    return flag
def _read_label_file(label_path, classes):
    """Convert one YOLO label file into submission rows in pixel coordinates."""
    basename = os.path.splitext(os.path.basename(label_path))[0]
    image_name = basename + ".jpg"
    records = []
    h = w = None
    with open(label_path, newline='') as label_file:
        for row in csv.reader(label_file, delimiter=' '):
            if not row:
                # Skip blank lines (e.g. a trailing empty line) instead of
                # failing on row[0].
                continue
            if h is None:
                # Every box in a file belongs to the same image, so decode
                # it once per file rather than once per row. The image is
                # only read when the file has at least one box.
                h, w = get_img_size(os.path.join(opt.image_dir, image_name))
            class_name = classes[int(row[0])]
            # Columns 1-4 are scaled by image size, i.e. they are treated as
            # normalized centre-x, centre-y, width and height.
            x_center = float(row[1]) * w
            y_center = float(row[2]) * h
            box_width = float(row[3]) * w
            box_height = float(row[4]) * h
            # Convert centre/size to corners, clamped to the image.
            x_min = max(x_center - box_width / 2, 0)
            y_min = max(y_center - box_height / 2, 0)
            x_max = min(x_center + box_width / 2, w)
            y_max = min(y_center + box_height / 2, h)
            check_badbox(image_name, h, w, x_min, y_min, x_max, y_max)
            # Column 5 is written through unchanged as the score string.
            records.append([image_name, class_name, row[5],
                            x_min, y_min, x_max, y_max, w, h])
    return records


def process(classes):
    """Convert YOLO label files into one timestamped submission CSV.

    Reads every ``*.txt`` file in ``opt.label_dir`` (``opt`` is the
    module-level argparse namespace created under ``__main__``). Each
    non-empty row is read as ``class_idx cx cy w h score``. The matching
    image ``<basename>.jpg`` in ``opt.image_dir`` supplies the pixel size
    used to convert the box to clamped pixel corners.

    The output is written to ``submission_files/``. The file name embeds
    ``opt.conf_thres``, ``opt.iou_thres`` and the current local time.

    Args:
        classes: class names, indexed by the integer id in column 0 of
            each label row.
    """
    records = []
    # sorted() makes the output row order deterministic across filesystems.
    for label_path in sorted(glob.glob(os.path.join(opt.label_dir, "*.txt"))):
        records.extend(_read_label_file(label_path, classes))

    save_path = 'submission_files/arafat_yolo-result_conf-{}_IOUthr-{}_{}_ac-0.0_epc-0.csv'.format(
        opt.conf_thres, opt.iou_thres, time.strftime("%Y-%m-%d_%H-%M-%S"))
    # open() does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fieldnames = ['image_id', 'class', 'score', 'xmin',
                  'ymin', 'xmax', 'ymax', 'width', 'height']
    # newline='' is required by the csv module; without it, extra blank rows
    # appear on platforms that translate line endings.
    with open(save_path, mode='w', newline='') as result_file:
        result_file_writer = csv.writer(
            result_file, delimiter=',', quotechar='"',
            quoting=csv.QUOTE_MINIMAL)
        result_file_writer.writerow(fieldnames)
        # Each record is already in header order, with width before height.
        result_file_writer.writerows(records)
    print("stored in: ", save_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # The three path arguments are mandatory: process() and get_class_names()
    # pass them straight to os.path.join()/open(), which fail on None.
    parser.add_argument('--image-dir', type=str, required=True,
                        help='source directory to read test images')
    parser.add_argument('--label-dir', type=str, required=True,
                        help='source directory to read darknet labels')
    parser.add_argument('--classname-file', type=str, required=True,
                        help='class name text file path')
    # In this script the two thresholds below are only embedded in the
    # output file name by process(); no filtering or NMS is done here.
    parser.add_argument('--conf-thres', type=float,
                        default=0.25, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float,
                        default=0.45, help='IOU threshold for NMS')
    # process() reads `opt` as a module-level global.
    opt = parser.parse_args()
    print(opt)
    classes = get_class_names(opt.classname_file)
    print(classes)
    process(classes)