diff --git a/.gitignore b/.gitignore index 7bbc71c09..f82bd7f85 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +*.jpg +*.weights +*.h5 + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/convert.py b/convert.py index c987669a2..4d8404f23 100644 --- a/convert.py +++ b/convert.py @@ -29,7 +29,11 @@ '--plot_model', help='Plot generated Keras model and save as image.', action='store_true') - +parser.add_argument( + '-n', + '--not_fixed_input', + help='Set input layer\'s width and height to None.', + action='store_true') def unique_config_sections(config_file): """Convert all config sections to have unique names. @@ -76,9 +80,12 @@ def _main(args): cfg_parser.read_file(unique_config_file) print('Creating Keras model.') - image_height = int(cfg_parser['net_0']['height']) - image_width = int(cfg_parser['net_0']['width']) - input_layer = Input(shape=(image_height, image_width, 3)) + if args.not_fixed_input: + input_layer = Input(shape=(None, None, 3)) + else: + image_height = int(cfg_parser['net_0']['height']) + image_width = int(cfg_parser['net_0']['width']) + input_layer = Input(shape=(image_height, image_width, 3)) prev_layer = input_layer all_layers = [] diff --git a/yolo.py b/yolo.py index f76bb327f..3187e8152 100644 --- a/yolo.py +++ b/yolo.py @@ -7,16 +7,16 @@ import colorsys import os import random -import time -import cv2 +from timeit import time +from timeit import default_timer as timer ### to calculate FPS + import numpy as np from keras import backend as K from keras.models import load_model -from PIL import Image, ImageDraw, ImageFont -from timeit import time -from timeit import default_timer as timer ### to calculate FPS +from PIL import Image, ImageFont, ImageDraw from yolo3.model import yolo_eval +from yolo3.utils import letterbox_image class YOLO(object): def __init__(self): @@ -53,6 +53,7 @@ def generate(self): print('{} model, anchors, and classes loaded.'.format(model_path)) self.model_image_size = self.yolo_model.layers[0].input_shape[1:3] + self.is_fixed_size = self.model_image_size != (None, None) # Generate colors for drawing bounding boxes. hsv_tuples = [(x / len(self.class_names), 1., 1.) @@ -66,15 +67,22 @@ def generate(self): random.seed(None) # Reset seed to default. # Generate output tensor targets for filtered bounding boxes. - # TODO: Wrap these backend operations with Keras layers. self.input_image_shape = K.placeholder(shape=(2, )) - boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors, len(self.class_names), self.input_image_shape, score_threshold=self.score, iou_threshold=self.iou) + boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors, + len(self.class_names), self.input_image_shape, + score_threshold=self.score, iou_threshold=self.iou) return boxes, scores, classes def detect_image(self, image): start = time.time() - resized_image = image.resize(tuple(reversed(self.model_image_size)), Image.BICUBIC) - image_data = np.array(resized_image, dtype='float32') + + if self.is_fixed_size: + boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size))) + else: + new_image_size = (image.width - (image.width % 32), + image.height - (image.height % 32)) + boxed_image = letterbox_image(image, new_image_size) + image_data = np.array(boxed_image, dtype='float32') print(image_data.shape) image_data /= 255. @@ -90,7 +98,8 @@ def detect_image(self, image): print('Found {} boxes for {}'.format(len(out_boxes), 'img')) - font = ImageFont.truetype(font='font/FiraMono-Medium.otf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32')) + font = ImageFont.truetype(font='font/FiraMono-Medium.otf', + size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32')) thickness = (image.size[0] + image.size[1]) // 300 for i, c in reversed(list(enumerate(out_classes))): @@ -133,10 +142,11 @@ def close_session(self): self.sess.close() -def detect_video(yolo,video_path): - vid = cv2.VideoCapture(video_path) ### TODO: will video path other than 0 be used? +def detect_video(yolo, video_path): + import cv2 + vid = cv2.VideoCapture(video_path) if not vid.isOpened(): - raise IOError("Couldn't open webcam") + raise IOError("Couldn't open webcam or video") accum_time = 0 curr_fps = 0 fps = "FPS: ??" @@ -155,10 +165,10 @@ def detect_video(yolo,video_path): accum_time = accum_time - 1 fps = "FPS: " + str(curr_fps) curr_fps = 0 - cv2.putText(result, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.50, - color=(255, 0, 0), thickness=2) + cv2.putText(result, text=fps, org=(3, 15), fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.50, color=(255, 0, 0), thickness=2) cv2.namedWindow("result", cv2.WINDOW_NORMAL) - cv2.imshow("result",result) + cv2.imshow("result", result) if cv2.waitKey(1) & 0xFF == ord('q'): break yolo.close_session() diff --git a/yolo3/model.py b/yolo3/model.py index 8c8770ea4..b3fa6f1d1 100644 --- a/yolo3/model.py +++ b/yolo3/model.py @@ -68,36 +68,35 @@ def yolo_body(inputs, num_anchors, num_classes): """Create YOLO_V3 model CNN body in Keras.""" darknet = Model(inputs, darknet_body(inputs)) x, y1 = make_last_layers(darknet.output, 512, num_anchors*(num_classes+5)) - + x = compose( DarknetConv2D_BN_Leaky(256, (1,1)), UpSampling2D(2))(x) x = Concatenate()([x,darknet.layers[148].output]) x, y2 = make_last_layers(x, 256, num_anchors*(num_classes+5)) - + x = compose( DarknetConv2D_BN_Leaky(128, (1,1)), UpSampling2D(2))(x) x = Concatenate()([x,darknet.layers[89].output]) x, y3 = make_last_layers(x, 128, num_anchors*(num_classes+5)) - + return Model(inputs, [y1,y2,y3]) -def yolo_head(feats, anchors, num_classes, n): +def yolo_head(feats, anchors, num_classes, input_shape): """Convert final layer features to bounding box parameters.""" num_anchors = len(anchors) # Reshape to batch, height, width, num_anchors, box_params. anchors_tensor = K.reshape(K.constant(anchors), [1, 1, 1, num_anchors, 2]) - conv_dims = K.shape(feats)[1:3] # assuming channels last - # In YOLO the height index is the inner most iteration. - conv_height_index = K.arange(0, stop=conv_dims[0]) - conv_width_index = K.arange(0, stop=conv_dims[1]) - conv_height_index = K.tile(conv_height_index, [conv_dims[1]]) + conv_dims = K.shape(feats)[1:3] + conv_height_index = K.arange(0, stop=conv_dims[1]) + conv_width_index = K.arange(0, stop=conv_dims[0]) + conv_height_index = K.tile(conv_height_index, [conv_dims[0]]) conv_width_index = K.tile( - K.expand_dims(conv_width_index, 0), [conv_dims[0], 1]) + K.expand_dims(conv_width_index, 0), [conv_dims[1], 1]) conv_width_index = K.flatten(K.transpose(conv_width_index)) conv_index = K.transpose(K.stack([conv_height_index, conv_width_index])) conv_index = K.reshape(conv_index, [1, conv_dims[0], conv_dims[1], 1, 2]) @@ -105,7 +104,7 @@ def yolo_head(feats, anchors, num_classes, n): feats = K.reshape( feats, [-1, conv_dims[0], conv_dims[1], num_anchors, num_classes + 5]) - conv_dims = K.cast(K.reshape(conv_dims, [1, 1, 1, 1, 2]), K.dtype(feats)) + conv_dims = K.cast(conv_dims[::-1], K.dtype(feats)) box_xy = K.sigmoid(feats[..., :2]) box_wh = K.exp(feats[..., 2:4]) @@ -116,30 +115,42 @@ def yolo_head(feats, anchors, num_classes, n): # Note: YOLO iterates over height index before width index. # TODO: It works with +1, don't know why. box_xy = (box_xy + conv_index + 1) / conv_dims - # TODO: Input layer size - box_wh = box_wh * anchors_tensor / conv_dims / {0:32, 1:16, 2:8}[n] + box_wh = box_wh * anchors_tensor / K.cast(input_shape[::-1], K.dtype(box_wh)) return box_xy, box_wh, box_confidence, box_class_probs -def yolo_boxes_to_corners(box_xy, box_wh): - """Convert YOLO box predictions to bounding box corners.""" - box_mins = box_xy - (box_wh / 2.) - box_maxes = box_xy + (box_wh / 2.) - - return K.concatenate([ - box_mins[..., 1:2], # y_min - box_mins[..., 0:1], # x_min - box_maxes[..., 1:2], # y_max - box_maxes[..., 0:1] # x_max +def yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape): + '''Get corrected boxes''' + box_yx = box_xy[..., ::-1] + box_hw = box_wh[..., ::-1] + input_shape = K.cast(input_shape, K.dtype(box_yx)) + image_shape = K.cast(image_shape, K.dtype(box_yx)) + new_shape = K.round(image_shape * K.min(input_shape/image_shape)) + offset = (input_shape-new_shape)/2./input_shape + scale = input_shape/new_shape + box_yx = (box_yx - offset) * scale + box_hw *= scale + + box_mins = box_yx - (box_hw / 2.) + box_maxes = box_yx + (box_hw / 2.) + boxes = K.concatenate([ + box_mins[..., 0:1], # y_min + box_mins[..., 1:2], # x_min + box_maxes[..., 0:1], # y_max + box_maxes[..., 1:2] # x_max ]) + # Scale boxes back to original image shape. + boxes *= K.concatenate([image_shape, image_shape]) + return boxes -def yolo_boxes_and_scores(feats, anchors, num_classes, n): +def yolo_boxes_and_scores(feats, anchors, num_classes, input_shape, image_shape): '''Process Conv layer output''' - box_xy, box_wh, box_confidence, box_class_probs = yolo_head(feats, anchors, num_classes, n) - boxes = yolo_boxes_to_corners(box_xy, box_wh) + box_xy, box_wh, box_confidence, box_class_probs = yolo_head(feats, + anchors, num_classes, input_shape) + boxes = yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape) boxes = K.reshape(boxes, [-1, 4]) box_scores = box_confidence * box_class_probs box_scores = K.reshape(box_scores, [-1, num_classes]) @@ -150,24 +161,20 @@ def yolo_eval(yolo_outputs, anchors, num_classes, image_shape, - max_boxes=10, + max_boxes=20, score_threshold=.6, iou_threshold=.5): - """Evaluate YOLO model on given input batch and return filtered boxes.""" + """Evaluate YOLO model on given input and return filtered boxes.""" + input_shape = K.shape(yolo_outputs[0])[1:3] * 32 for i in range(0,3): - _boxes, _box_scores = yolo_boxes_and_scores(yolo_outputs[i], anchors[6-3*i:9-3*i], num_classes, i) + _boxes, _box_scores = yolo_boxes_and_scores(yolo_outputs[i], + anchors[6-3*i:9-3*i], num_classes, input_shape, image_shape) if i==0: - boxes, box_scores= _boxes, _box_scores + boxes, box_scores = _boxes, _box_scores else: boxes = K.concatenate([boxes,_boxes], axis=0) box_scores = K.concatenate([box_scores,_box_scores], axis=0) - # Scale boxes back to original image shape. - height = image_shape[0] - width = image_shape[1] - image_dims = K.stack([height, width, height, width]) - image_dims = K.reshape(image_dims, [1, 4]) - boxes = boxes * image_dims mask = box_scores >= score_threshold max_boxes_tensor = K.constant(max_boxes, dtype='int32') @@ -175,13 +182,11 @@ def yolo_eval(yolo_outputs, # TODO: use keras backend instead of tf. class_boxes = tf.boolean_mask(boxes, mask[:, i]) class_box_scores = tf.boolean_mask(box_scores[:, i], mask[:, i]) - # TODO: 13*13 + 26*26 + 52*52 - classes = K.constant(i, shape=(3549,), dtype='int32') nms_index = tf.image.non_max_suppression( class_boxes, class_box_scores, max_boxes_tensor, iou_threshold=iou_threshold) class_boxes = K.gather(class_boxes, nms_index) class_box_scores = K.gather(class_box_scores, nms_index) - classes = K.gather(classes, nms_index) + classes = K.ones_like(class_box_scores, 'int32') * i if i==0: boxes_, scores_, classes_ = class_boxes, class_box_scores, classes else: diff --git a/yolo3/utils.py b/yolo3/utils.py index 194f3ac89..7d4138aab 100644 --- a/yolo3/utils.py +++ b/yolo3/utils.py @@ -2,6 +2,7 @@ from functools import reduce +from PIL import Image def compose(*funcs): """Compose arbitrarily many functions, evaluated left to right. @@ -13,3 +14,15 @@ def compose(*funcs): return reduce(lambda f, g: lambda *a, **kw: g(f(*a, **kw)), funcs) else: raise ValueError('Composition of empty sequence not supported.') + +def letterbox_image(image, size): + '''resize image with unchanged aspect ratio using padding''' + image_w, image_h = image.size + w, h = size + new_w = int(image_w * min(w/image_w, h/image_h)) + new_h = int(image_h * min(w/image_w, h/image_h)) + resized_image = image.resize((new_w,new_h), Image.BICUBIC) + + boxed_image = Image.new('RGB', size, (128,128,128)) + boxed_image.paste(resized_image, ((w-new_w)//2,(h-new_h)//2)) + return boxed_image diff --git a/yolo_video b/yolo_video.py similarity index 61% rename from yolo_video rename to yolo_video.py index 1cfc3c664..b07123ea6 100644 --- a/yolo_video +++ b/yolo_video.py @@ -6,6 +6,4 @@ if __name__ == '__main__': video_path='path2your-video' - yolo = YOLO() - #detect_img(yolo) - detect_video(yolo,video_path) + detect_video(YOLO(), video_path)