-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathretrainUnariesNet.py
160 lines (130 loc) · 5.46 KB
/
retrainUnariesNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import pickle
import numpy as np
import os
os.environ["THEANO_FLAGS"] = "mode=FAST_RUN,device=gpu2,floatX=float32"
from PIL import Image
import json
from scipy import ndimage
import math
import random
from UnariesNet import unariesNet
import MyConfig
class unaryNet2(unariesNet):
    """Trainer for the unaries network on ground-truth box annotations.

    Extends ``unariesNet`` (defined in the UnariesNet module) with image
    loading, positive/negative ROI generation and an epoch loop. All paths
    are read from ``MyConfig``.
    """

    def __init__(self):
        unariesNet.__init__(self)
        self.trainImgsPath = MyConfig.trainImgPath
        self.trainLabelsPath = MyConfig.trainLabelPath
        # Directory where parameter checkpoints are written and read.
        self.path_save_params = MyConfig.unaries_params_path
        # Text file that receives one mean-loss line per epoch.
        self.train_logs_path = MyConfig.unaries_train_log
        self.jsonFile = MyConfig.jsonFile

    def loadRGBimg(self, dataPath, imgName):
        """Load an image as a ``(1, 3, H, W)`` uint8 array (NCHW layout).

        Side effect: stores the image size in ``self.imgH`` / ``self.imgW``,
        which ``getBoxes`` reads afterwards.
        """
        with Image.open(os.path.join(dataPath, imgName)) as img:
            # convert('RGB') also accepts grayscale / palette / RGBA files;
            # slicing [:, :, 0:3] crashed on single-channel (2-D) images.
            rgb = np.asarray(img.convert('RGB'))
        H, W = rgb.shape[:2]
        self.imgH = H
        self.imgW = W
        return rgb.transpose((2, 0, 1)).reshape((1, 3, H, W))

    def getIoU(self, boxA, boxB):
        """Return the intersection-over-union of two boxes as a float.

        Both boxes are ``[x0, y0, x1, y1]`` treated as *inclusive* pixel
        corners, which is why every side length is ``end - start + 1``.
        """
        ix0 = max(boxA[0], boxB[0])
        iy0 = max(boxA[1], boxB[1])
        ix1 = min(boxA[2], boxB[2])
        iy1 = min(boxA[3], boxB[3])
        inter = max(0, ix1 - ix0 + 1) * max(0, iy1 - iy0 + 1)
        areaA = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
        areaB = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
        return inter / float(areaA + areaB - inter)

    def _sampleNegativeBoxes(self, gtBoxes, num, minSize=8, iouThresh=0.6,
                             maxTriesPerBox=1000):
        """Randomly sample up to ``num`` boxes that do not match any GT box.

        A candidate ``[x0, y0, x1, y1]`` (at least ``minSize`` wide and high,
        inside ``[0, imgW] x [0, imgH]``) is kept when its IoU with every
        ground-truth box ``[x, y, w, h]`` is <= ``iouThresh``. At most
        ``maxTriesPerBox * num`` candidates are drawn, so a degenerate image
        cannot hang training. In that case fewer than ``num`` boxes are
        returned.
        """
        negatives = []
        if num <= 0 or self.imgW < minSize or self.imgH < minSize:
            # random.randint would raise on an empty range for tiny images.
            return negatives
        # NOTE(review): x + w is an exclusive end, but getIoU treats corners as
        # inclusive (+1 px). Kept as-is to preserve the sampling distribution.
        gtCorners = [[b[0], b[1], b[0] + b[2], b[1] + b[3]] for b in gtBoxes]
        tries = 0
        maxTries = maxTriesPerBox * num
        while len(negatives) < num and tries < maxTries:
            tries += 1
            x0 = random.randint(0, self.imgW - minSize)
            y0 = random.randint(0, self.imgH - minSize)
            x1 = random.randint(x0 + minSize, self.imgW)
            y1 = random.randint(y0 + minSize, self.imgH)
            cand = [x0, y0, x1, y1]
            if all(self.getIoU(gt, cand) <= iouThresh for gt in gtCorners):
                negatives.append(cand)
        return negatives

    def getBoxes(self, boxList):
        """Build the ROI array and labels for the most recently loaded image.

        Positive ROIs are the ground-truth boxes in ``boxList``. Each is read
        as ``[x, y, w, h]`` and stored as corners ``x, y, x + w, y + h``. An
        equal number of random negatives is sampled (see
        ``_sampleNegativeBoxes``).

        Must be called after ``loadRGBimg``, which sets ``self.imgH`` and
        ``self.imgW``.

        Returns:
            rois: float32 array of shape ``(n_pos + n_neg, 5)``. Column 0 is
                left at 0 (presumably the batch index consumed by ROI pooling;
                TODO confirm). Columns 1-4 hold ``x0, y0, x1, y1``. The whole
                array is divided by 4 (presumably the backbone's feature
                stride; TODO confirm against unariesNet).
            labels: list of bools, ``True`` for every positive followed by
                ``False`` for every negative.
        """
        numPos = len(boxList)
        negBoxes = self._sampleNegativeBoxes(boxList, numPos)
        numNeg = len(negBoxes)

        rois = np.zeros((numPos + numNeg, 5), dtype=np.single)
        if numPos > 0:
            gt = np.asarray(boxList)
            rois[:numPos, 1:3] = gt[:, 0:2]
            rois[:numPos, 3] = gt[:, 0] + gt[:, 2]
            rois[:numPos, 4] = gt[:, 1] + gt[:, 3]
        if numNeg > 0:
            rois[numPos:, 1:5] = negBoxes
        labels = [True] * numPos + [False] * numNeg
        return rois / 4, labels

    def load_batch_train(self, img, boxes):
        """Load training image ``img`` and build its ROIs/labels from ``boxes``.

        Returns ``(x, rois, labels)`` as produced by ``loadRGBimg`` and
        ``getBoxes``.
        """
        x = self.loadRGBimg(self.trainImgsPath, img)
        roi, label = self.getBoxes(boxes)
        return x, roi, label

    def _checkpointPath(self, fileName):
        """Return the full path of checkpoint ``fileName``."""
        return os.path.join(self.path_save_params, fileName)

    def _loadParams(self, fileName):
        """Unpickle a checkpoint file (binary mode, handle always closed)."""
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(self._checkpointPath(fileName), 'rb') as f:
            return pickle.load(f)

    def _saveParams(self, params, fileName):
        """Pickle ``params`` to a checkpoint file, flushing and closing it."""
        with open(self._checkpointPath(fileName), 'wb') as f:
            pickle.dump(params, f)

    def train(self, resume_epoch=0, fine_tune=True, save_every=2):
        """Train for epochs ``resume_epoch .. MyConfig.u_epochs - 1``.

        Args:
            resume_epoch: 0 starts from scratch and truncates the loss log.
                Any other value first restores the checkpoints written at
                epoch ``resume_epoch - 1``, so those files must exist. With
                the default ``save_every=2``, only even epochs and the final
                epoch are saved.
            fine_tune: when True, the backbone parameters (``self.mNet``,
                presumably provided by unariesNet) are also restored on
                resume and saved with every checkpoint.
            save_every: write a checkpoint when ``epoch % save_every == 0``.
                The last epoch is always written.

        Each epoch calls ``self.train_func`` once per image listed in
        ``MyConfig.unaries_imgList``, paired with the boxes at the same index
        of ``MyConfig.unaries_boxlist``. It then appends the epoch's mean loss
        to ``self.train_logs_path``.

        Raises:
            ValueError: if ``save_every < 1`` or the image and box lists have
                different lengths.
        """
        if save_every < 1:
            raise ValueError('save_every must be >= 1, got %r' % (save_every,))

        if resume_epoch == 0:
            # Fresh run: start with an empty loss log.
            with open(self.train_logs_path, 'w'):
                pass
        else:
            prev_epoch = resume_epoch - 1
            self.setParams(self._loadParams('params_Unaries_%d.pickle' % prev_epoch))
            if fine_tune:
                self.mNet.setParams(self._loadParams('params_VGG_%d.pickle' % prev_epoch))

        with open(MyConfig.unaries_boxlist) as read_file:
            u_boxList = json.load(read_file)
        with open(MyConfig.unaries_imgList) as read_file:
            u_imgList = json.load(read_file)
        if len(u_imgList) != len(u_boxList):
            # zip() would silently drop the unmatched tail.
            raise ValueError('unaries_imgList has %d entries but unaries_boxlist has %d'
                             % (len(u_imgList), len(u_boxList)))

        last_epoch = MyConfig.u_epochs - 1
        for epoch in range(resume_epoch, MyConfig.u_epochs):
            costs = []
            for img, boxes in zip(u_imgList, u_boxList):
                # Single-argument print() behaves the same on Python 2 and 3.
                print('Epoch %d  img= %s , num of bbox =  %d' % (epoch, img, len(boxes)))
                x, rois_np, labels = self.load_batch_train(img, boxes)
                _, loss = self.train_func(x, rois_np, labels)
                print('Loss Unaries %s' % (loss,))
                costs.append(loss)

            if epoch % save_every == 0 or epoch == last_epoch:
                self._saveParams(self.getParams(), 'params_Unaries_%d.pickle' % epoch)
                if fine_tune:
                    self._saveParams(self.mNet.getParams(), 'params_VGG_%d.pickle' % epoch)

            av_cost = np.mean(costs) if costs else float('nan')
            with open(self.train_logs_path, 'a') as f_logs:
                f_logs.write('%f\n' % av_cost)
def checkPath(path):
    """Create directory ``path`` (including parents) unless it already exists.

    Tries to create the directory first, then tolerates the failure only if
    the directory now exists. This avoids the race where another process
    creates it between an existence check and ``os.makedirs``.

    Raises:
        OSError: if the directory cannot be created, e.g. ``path`` exists but
            is a regular file.
    """
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise
def main():
    """Build the unaries trainer, ensure its checkpoint directory exists, then train.

    ``train()`` is called with its defaults, i.e. starting at epoch 0.
    """
    net = unaryNet2()
    checkPath(net.path_save_params)
    net.train()
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()