forked from qqwweee/keras-yolo3
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add training support, close qqwweee#6
- Loading branch information
Showing
3 changed files
with
213 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
""" | ||
Retrain the YOLO model for your own dataset. | ||
""" | ||
import os | ||
|
||
import numpy as np | ||
from PIL import Image | ||
from keras.layers import Input, Lambda | ||
from keras.models import load_model, Model | ||
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping | ||
|
||
from yolo3.model import preprocess_true_boxes, yolo_body, yolo_loss | ||
from yolo3.utils import letterbox_image | ||
|
||
# Default anchor boxes | ||
YOLO_ANCHORS = np.array(((10,13), (16,30), (33,23), (30,61), | ||
(62,45), (59,119), (116,90), (156,198), (373,326))) | ||
|
||
def _main(): | ||
annotation_path = 'train.txt' | ||
data_path = 'train.npz' | ||
output_path = 'model_data/my_yolo.h5' | ||
log_dir = 'logs/000/' | ||
classes_path = 'model_data/voc_classes.txt' | ||
anchors_path = 'model_data/yolo_anchors.txt' | ||
class_names = get_classes(classes_path) | ||
anchors = get_anchors(anchors_path) | ||
|
||
input_shape = (416,416) # multiple of 32 | ||
image_data, box_data = get_training_data(annotation_path, data_path, | ||
input_shape, max_boxes=100, load_previous=True) | ||
y_true = preprocess_true_boxes(box_data, input_shape, anchors, len(class_names)) | ||
|
||
infer_model, model = create_model(input_shape, anchors, len(class_names), | ||
load_pretrained=True, freeze_body=True) | ||
|
||
train(model, image_data/255., y_true, log_dir=log_dir) | ||
|
||
infer_model.save(output_path) | ||
|
||
|
||
|
||
def get_classes(classes_path): | ||
'''loads the classes''' | ||
with open(classes_path) as f: | ||
class_names = f.readlines() | ||
class_names = [c.strip() for c in class_names] | ||
return class_names | ||
|
||
def get_anchors(anchors_path): | ||
'''loads the anchors from a file''' | ||
if os.path.isfile(anchors_path): | ||
with open(anchors_path) as f: | ||
anchors = f.readline() | ||
anchors = [float(x) for x in anchors.split(',')] | ||
return np.array(anchors).reshape(-1, 2) | ||
else: | ||
Warning("Could not open anchors file, using default.") | ||
return YOLO_ANCHORS | ||
|
||
def get_training_data(annotation_path, data_path, input_shape, max_boxes=100, load_previous=True): | ||
'''processes the data into standard shape | ||
annotation row format: image_file_path box1 box2 ... boxN | ||
box format: x_min,y_min,x_max,y_max,class_index (no space) | ||
''' | ||
if load_previous==True and os.path.isfile(data_path): | ||
data = np.load(data_path) | ||
print('Loading training data from ' + data_path) | ||
return data['image_data'], data['box_data'] | ||
image_data = [] | ||
box_data = [] | ||
with open(annotation_path) as f: | ||
for line in f.readlines(): | ||
line = line.split(' ') | ||
filename = line[0] | ||
image = Image.open(filename) | ||
boxed_image = letterbox_image(image, tuple(reversed(input_shape))) | ||
image_data.append(np.array(boxed_image,dtype='uint8')) | ||
|
||
boxes = np.zeros((max_boxes,5), dtype='int32') | ||
for i, box in enumerate(line[1:]): | ||
if i < max_boxes: | ||
boxes[i] = np.array(list(map(int,box.split(',')))) | ||
else: | ||
break | ||
image_size = np.array(image.size) | ||
input_size = np.array(input_shape[::-1]) | ||
new_size = (image_size * np.min(input_size/image_size)).astype('int32') | ||
boxes[:i+1, 0:2] = (boxes[:i+1, 0:2]*new_size/image_size + (input_size-new_size)/2).astype('int32') | ||
boxes[:i+1, 2:4] = (boxes[:i+1, 2:4]*new_size/image_size + (input_size-new_size)/2).astype('int32') | ||
box_data.append(boxes) | ||
image_data = np.array(image_data) | ||
box_data = np.array(box_data) | ||
np.savez(data_path, image_data=image_data, box_data=box_data) | ||
print('Saving training data into ' + data_path) | ||
return image_data, box_data | ||
|
||
|
||
def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=True): | ||
'''create the training model''' | ||
image_input = Input(shape=(None, None, 3)) | ||
h, w = input_shape | ||
num_anchors = len(anchors)//3 | ||
y_true = [Input(shape=(h//32, w//32, num_anchors, num_classes+5)), | ||
Input(shape=(h//16, w//16, num_anchors, num_classes+5)), | ||
Input(shape=(h//8, w//8, num_anchors, num_classes+5))] | ||
|
||
model_body = yolo_body(image_input, num_anchors, num_classes) | ||
|
||
if load_pretrained: | ||
weights_path = os.path.join('model_data', 'yolo_weights.h5') | ||
if not os.path.exists(weights_path): | ||
print("CREATING WEIGHTS FILE" + weights_path) | ||
yolo_path = os.path.join('model_data', 'yolo.h5') | ||
model_body = load_model(yolo_path, compile=False) | ||
model_body.save_weights(weights_path) | ||
model_body.load_weights(weights_path, by_name=True, skip_mismatch=True) | ||
if freeze_body: | ||
# Do not freeze 3 output layers. | ||
for i in range(len(model_body.layers)-3): | ||
model_body.layers[i].trainable = False | ||
|
||
model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss', | ||
arguments={'anchors': anchors, 'num_classes': num_classes})( | ||
[*model_body.output, *y_true]) | ||
model = Model([model_body.input, *y_true], model_loss) | ||
|
||
return model_body, model | ||
|
||
def train(model, image_data, y_true, log_dir='logs/'): | ||
'''retrain/fine-tune the model''' | ||
model.compile(optimizer='adam', loss={ | ||
# use custom yolo_loss Lambda layer. | ||
'yolo_loss': lambda y_true, y_pred: y_pred}) | ||
|
||
logging = TensorBoard(log_dir=log_dir) | ||
checkpoint = ModelCheckpoint(log_dir + "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5", | ||
monitor='val_loss', save_weights_only=True, save_best_only=True) | ||
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='auto') | ||
|
||
model.fit([image_data, *y_true], | ||
np.zeros(len(image_data)), | ||
validation_split=.1, | ||
batch_size=32, | ||
epochs=30, | ||
callbacks=[logging, checkpoint, early_stopping]) | ||
model.save_weights(log_dir + 'trained_weights.h5') | ||
# Further training. | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import xml.etree.ElementTree as ET | ||
from os import getcwd | ||
|
||
sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')] | ||
|
||
classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] | ||
|
||
|
||
def convert_annotation(year, image_id, list_file): | ||
in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id)) | ||
tree=ET.parse(in_file) | ||
root = tree.getroot() | ||
|
||
for obj in root.iter('object'): | ||
difficult = obj.find('difficult').text | ||
cls = obj.find('name').text | ||
if cls not in classes or int(difficult)==1: | ||
continue | ||
cls_id = classes.index(cls) | ||
xmlbox = obj.find('bndbox') | ||
b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text)) | ||
list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id)) | ||
|
||
wd = getcwd() | ||
|
||
for year, image_set in sets: | ||
image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split() | ||
list_file = open('%s_%s.txt'%(year, image_set), 'w') | ||
for image_id in image_ids: | ||
list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg'%(wd, year, image_id)) | ||
convert_annotation(year, image_id, list_file) | ||
list_file.write('\n') | ||
list_file.close() | ||
|