add training support, close qqwweee#6
qqwweee committed Apr 29, 2018
1 parent b4895a6 commit fda7b44
Showing 3 changed files with 213 additions and 8 deletions.
34 changes: 26 additions & 8 deletions README.md
@@ -6,20 +6,38 @@

A Keras implementation of YOLOv3 (TensorFlow backend) inspired by [allanzelener/YAD2K](https://github.com/allanzelener/YAD2K).

---

## Quick Start

1. Download YOLOv3 weights from [YOLO website](http://pjreddie.com/darknet/yolo/).
2. Convert the Darknet YOLO model to a Keras model.
3. Run YOLO detection.

```
wget https://pjreddie.com/media/files/yolov3.weights
python convert.py yolov3.cfg yolov3.weights model_data/yolo.h5
python yolo.py OR python yolo_video.py
```

---

## Training

1. Generate your own annotation file (see the example rows after this list).

    One row per image.

    Row format: `image_file_path box1 box2 ... boxN`

    Box format: `x_min,y_min,x_max,y_max,class_id` (no spaces).

    For the VOC dataset, try `python voc_annotation.py`.

2. Generate your own class names file, one class name per line.

3. Modify `train.py` and start training:

    `python train.py`

    You will get the trained model `model_data/my_yolo.h5`.
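
For reference, a hypothetical annotation file could contain rows like these (the paths, coordinates, and class ids are illustrative only):

```
path/to/img1.jpg 50,100,150,200,0 30,50,200,120,3
path/to/img2.jpg 120,300,250,600,2
```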
153 changes: 153 additions & 0 deletions train.py
@@ -0,0 +1,153 @@
"""
Retrain the YOLO model for your own dataset.
"""
import os
import warnings

import numpy as np
from PIL import Image
from keras.layers import Input, Lambda
from keras.models import load_model, Model
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping

from yolo3.model import preprocess_true_boxes, yolo_body, yolo_loss
from yolo3.utils import letterbox_image

# Default anchor boxes
YOLO_ANCHORS = np.array(((10,13), (16,30), (33,23), (30,61),
                         (62,45), (59,119), (116,90), (156,198), (373,326)))

def _main():
    annotation_path = 'train.txt'
    data_path = 'train.npz'
    output_path = 'model_data/my_yolo.h5'
    log_dir = 'logs/000/'
    classes_path = 'model_data/voc_classes.txt'
    anchors_path = 'model_data/yolo_anchors.txt'
    class_names = get_classes(classes_path)
    anchors = get_anchors(anchors_path)

    input_shape = (416,416) # multiple of 32
    image_data, box_data = get_training_data(annotation_path, data_path,
        input_shape, max_boxes=100, load_previous=True)
    y_true = preprocess_true_boxes(box_data, input_shape, anchors, len(class_names))

    infer_model, model = create_model(input_shape, anchors, len(class_names),
        load_pretrained=True, freeze_body=True)

    train(model, image_data/255., y_true, log_dir=log_dir)  # scale pixels to [0, 1]

    infer_model.save(output_path)



def get_classes(classes_path):
    '''loads the classes'''
    with open(classes_path) as f:
        class_names = f.readlines()
    class_names = [c.strip() for c in class_names]
    return class_names

def get_anchors(anchors_path):
    '''loads the anchors from a file'''
    if os.path.isfile(anchors_path):
        # anchors file: a single line of comma-separated numbers, read as (w,h) pairs
        with open(anchors_path) as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        return np.array(anchors).reshape(-1, 2)
    else:
        warnings.warn("Could not open anchors file, using default.")
        return YOLO_ANCHORS

def get_training_data(annotation_path, data_path, input_shape, max_boxes=100, load_previous=True):
    '''processes the data into standard shape
    annotation row format: image_file_path box1 box2 ... boxN
    box format: x_min,y_min,x_max,y_max,class_index (no space)
    '''
    if load_previous and os.path.isfile(data_path):
        data = np.load(data_path)
        print('Loading training data from ' + data_path)
        return data['image_data'], data['box_data']
    image_data = []
    box_data = []
    with open(annotation_path) as f:
        for line in f.readlines():
            line = line.split(' ')
            filename = line[0]
            image = Image.open(filename)
            boxed_image = letterbox_image(image, tuple(reversed(input_shape)))
            image_data.append(np.array(boxed_image, dtype='uint8'))

            boxes = np.zeros((max_boxes, 5), dtype='int32')
            for i, box in enumerate(line[1:]):
                if i < max_boxes:
                    boxes[i] = np.array(list(map(int, box.split(','))))
                else:
                    break
            # Rescale box coordinates so they match the letterboxed (resized and padded) image.
            image_size = np.array(image.size)
            input_size = np.array(input_shape[::-1])
            new_size = (image_size * np.min(input_size/image_size)).astype('int32')
            boxes[:i+1, 0:2] = (boxes[:i+1, 0:2]*new_size/image_size + (input_size-new_size)/2).astype('int32')
            boxes[:i+1, 2:4] = (boxes[:i+1, 2:4]*new_size/image_size + (input_size-new_size)/2).astype('int32')
            box_data.append(boxes)
    image_data = np.array(image_data)
    box_data = np.array(box_data)
    np.savez(data_path, image_data=image_data, box_data=box_data)
    print('Saving training data into ' + data_path)
    return image_data, box_data


def create_model(input_shape, anchors, num_classes, load_pretrained=True, freeze_body=True):
    '''create the training model'''
    image_input = Input(shape=(None, None, 3))
    h, w = input_shape
    num_anchors = len(anchors)//3
    # One y_true input per detection scale (strides 32, 16 and 8).
    y_true = [Input(shape=(h//32, w//32, num_anchors, num_classes+5)),
              Input(shape=(h//16, w//16, num_anchors, num_classes+5)),
              Input(shape=(h//8, w//8, num_anchors, num_classes+5))]

    model_body = yolo_body(image_input, num_anchors, num_classes)

    if load_pretrained:
        weights_path = os.path.join('model_data', 'yolo_weights.h5')
        if not os.path.exists(weights_path):
            print("Creating weights file " + weights_path)
            yolo_path = os.path.join('model_data', 'yolo.h5')
            model_body = load_model(yolo_path, compile=False)
            model_body.save_weights(weights_path)
        model_body.load_weights(weights_path, by_name=True, skip_mismatch=True)
        if freeze_body:
            # Do not freeze the 3 output layers.
            for i in range(len(model_body.layers)-3):
                model_body.layers[i].trainable = False

    model_loss = Lambda(yolo_loss, output_shape=(1,), name='yolo_loss',
        arguments={'anchors': anchors, 'num_classes': num_classes})(
        [*model_body.output, *y_true])
    model = Model([model_body.input, *y_true], model_loss)

    return model_body, model

def train(model, image_data, y_true, log_dir='logs/'):
    '''retrain/fine-tune the model'''
    model.compile(optimizer='adam', loss={
        # use custom yolo_loss Lambda layer.
        'yolo_loss': lambda y_true, y_pred: y_pred})

    logging = TensorBoard(log_dir=log_dir)
    checkpoint = ModelCheckpoint(log_dir + "ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5",
        monitor='val_loss', save_weights_only=True, save_best_only=True)
    early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='auto')

    model.fit([image_data, *y_true],
              np.zeros(len(image_data)),
              validation_split=.1,
              batch_size=32,
              epochs=30,
              callbacks=[logging, checkpoint, early_stopping])
    model.save_weights(log_dir + 'trained_weights.h5')
    # Further training.



if __name__ == '__main__':
    _main()
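
After training, a quick way to confirm the exported model loads cleanly is a minimal sketch like the one below. It is not part of this commit; it assumes `model_data/my_yolo.h5` was produced by `infer_model.save(output_path)` in `_main` and contains only standard Keras layers.

```python
from keras.models import load_model

# Hypothetical sanity check (not part of train.py): reload the saved
# inference model and inspect its outputs. YOLOv3 should expose three
# output tensors, one per detection scale.
model = load_model('model_data/my_yolo.h5', compile=False)
model.summary()
print([o.shape for o in model.outputs])
```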
34 changes: 34 additions & 0 deletions voc_annotation.py
@@ -0,0 +1,34 @@
import xml.etree.ElementTree as ET
from os import getcwd

sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')]

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]


def convert_annotation(year, image_id, list_file):
    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id))
    tree = ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))

wd = getcwd()

for year, image_set in sets:
    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split()
    list_file = open('%s_%s.txt'%(year, image_set), 'w')
    for image_id in image_ids:
        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg'%(wd, year, image_id))
        convert_annotation(year, image_id, list_file)
        list_file.write('\n')
    list_file.close()
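
For reference, each line that `voc_annotation.py` writes follows the same row format the README and `train.py` expect: an absolute image path followed by space-separated `x_min,y_min,x_max,y_max,class_id` boxes. A hypothetical output line (the path and numbers are illustrative only) might look like:

```
/home/user/keras-yolo3/VOCdevkit/VOC2007/JPEGImages/000012.jpg 156,97,351,270,6
```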
