diff --git a/README.md b/README.md
index d8ace78..d1a989a 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ This repository contains codes of the reimplementation of [SSD: Single Shot Mult
 There are already some TensorFlow based SSD reimplementation codes on GitHub, the main special features of this repo inlcude:
 
-- state of the art performance(77.8%mAP) when training from VGG-16 pre-trained model (SSD300-VGG16).
+- state-of-the-art performance (77.4% mAP) when training from the VGG-16 pre-trained model (SSD300-VGG16).
 - the model is trained using TensorFlow high level API [tf.estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator). Although TensorFlow provides many APIs, the Estimator API is highly recommended to yield scalable, high-performance models.
 - all codes were writen by pure TensorFlow ops (no numpy operation) to ensure the performance and portability.
 - using ssd augmentation pipeline discribed in the original paper.
@@ -65,15 +65,15 @@ All the codes was tested under TensorFlow 1.6, Python 3.5, Ubuntu 16.04 with CUD
 ## Results (VOC07 Metric)
 
-This implementation(SSD300-VGG16) yield **mAP 77.8%** on PASCAL VOC 2007 test dataset(the original performance described in the paper is 77.2%mAP), the details are as follows:
+This implementation (SSD300-VGG16) yields **mAP 77.4%** on the PASCAL VOC 2007 test dataset (the original performance described in the paper is 77.2% mAP); the details are as follows:
 
 | sofa | bird | pottedplant | bus | diningtable | cow | bottle | horse | aeroplane | motorbike |
 |:-------|:-----:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|
-| 79.6 | 76.0 | 52.8 | 85.9 | 76.9 | 83.5 | 49.9 | 86.0 | 82.9 | 81.0 |
+| 78.8 | 76.3 | 53.3 | 86.2 | 77.7 | 83.0 | 52.7 | 85.5 | 82.3 | 82.2 |
 | **sheep** | **train** | **boat** | **bicycle** | **chair** | **cat** | **tvmonitor** | **person** | **car** | **dog** |
-| 81.6 | 86.2 | 71.8 | 84.2 | 60.2 | 87.8 | 76.7 | 80.5 | 85.5 | 86.2 |
+| 77.2 | 87.3 | 69.7 | 83.3 | 59.0 | 88.2 | 74.6 | 79.6 | 84.8 | 85.1 |
 
-You can download the trained model(VOC07+12 Train) from [GoogleDrive](https://drive.google.com/open?id=1yeYcfcOURcZ4DaElEn9C2xY1NymGzG5W) for further research.
+You can download the trained model (VOC07+12 Train) from [GoogleDrive](https://drive.google.com/open?id=1sr3khWzrXZtcS5mmkQDL00y07Rj7erW5) for further research.
 
 For Chinese friends, you can also download both the trained model and pre-trained vgg16 weights from [BaiduYun Drive](https://pan.baidu.com/s/1kRhZd4p-N46JFpVkMgU3fg), access code: **tg64**.
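Editor's note: the README leans on tf.estimator, so for orientation here is a minimal, self-contained sketch of the Estimator pattern the training and eval scripts follow. The model_fn/input_fn bodies below are toy stand-ins with hypothetical names, not this repo's actual code.

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # toy stand-in for the repo's ssd_model_fn: one conv head and a dummy loss
    logits = tf.layers.conv2d(features['image'], params['num_classes'], 3, padding='same')
    loss = tf.reduce_mean(tf.square(logits))
    train_op = tf.train.GradientDescentOptimizer(1e-3).minimize(
        loss, global_step=tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def input_fn():
    # toy stand-in for the repo's input_pipeline()
    return {'image': tf.zeros([1, 300, 300, 3])}, None

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='./model_demo',
                                   params={'num_classes': 21})
estimator.train(input_fn=input_fn, max_steps=1)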
diff --git a/dataset/convert_tfrecords.py b/dataset/convert_tfrecords.py
index 4ce3ad3..0cb86f5 100644
--- a/dataset/convert_tfrecords.py
+++ b/dataset/convert_tfrecords.py
@@ -43,13 +43,13 @@
 |        |->Annotations/
 |        |->...
 '''
-tf.app.flags.DEFINE_string('dataset_directory', '/media/rs/7A0EE8880EE83EAF/Detections/PASCAL/VOC',
+tf.app.flags.DEFINE_string('dataset_directory', './dataset/VOC',
                            'All datas directory')
 tf.app.flags.DEFINE_string('train_splits', 'VOC2007, VOC2012',
                            'Comma-separated list of the training data sub-directory')
 tf.app.flags.DEFINE_string('validation_splits', 'VOC2007TEST',
                            'Comma-separated list of the validation data sub-directory')
-tf.app.flags.DEFINE_string('output_directory', '/media/rs/7A0EE8880EE83EAF/Detections/SSD/dataset/tfrecords',
+tf.app.flags.DEFINE_string('output_directory', './dataset/tfrecords',
                            'Output data directory')
 tf.app.flags.DEFINE_integer('train_shards', 16,
                             'Number of shards in training TFRecord files.')
@@ -228,7 +228,7 @@ def _find_image_bounding_boxes(directory, cur_record):
     difficult = []
     truncated = []
     for obj in root.findall('object'):
-        label = obj.find('name').text
+        label = obj.find('name').text.strip()
         labels.append(int(dataset_common.VOC_LABELS[label][0]))
         labels_text.append(label.encode('ascii'))
@@ -245,10 +245,10 @@ def _find_image_bounding_boxes(directory, cur_record):
             truncated.append(0)
 
         bbox = obj.find('bndbox')
-        bboxes.append((float(bbox.find('ymin').text) / shape[0],
-                       float(bbox.find('xmin').text) / shape[1],
-                       float(bbox.find('ymax').text) / shape[0],
-                       float(bbox.find('xmax').text) / shape[1]
+        bboxes.append((float(bbox.find('ymin').text) - 1.,
+                       float(bbox.find('xmin').text) - 1.,
+                       float(bbox.find('ymax').text) - 1.,
+                       float(bbox.find('xmax').text) - 1.
                        ))
     return bboxes, labels, labels_text, difficult, truncated
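Editor's note: the bounding-box hunk above switches the stored coordinates from normalized [0, 1] values to absolute pixel values, subtracting 1 because PASCAL VOC XML coordinates are 1-based. A small sketch of the convention this implies, with made-up numbers:

# VOC XML stores 1-based, inclusive pixel coordinates.
xmin_voc, xmax_voc = 48, 195     # as read from <bndbox> in the annotation
xmin = float(xmin_voc) - 1.      # 0-based coordinate written to the TFRecord
xmax = float(xmax_voc) - 1.
# treating boxes as inclusive is why "+ 1." shows up in the width/area math
# later in this patch (ssd_preprocessing.py and anchor_manipulator.py):
width = xmax - xmin + 1.
print(xmin, xmax, width)         # 47.0 194.0 148.0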
diff --git a/dataset/dataset_common.py b/dataset/dataset_common.py
index 046dcca..64bfd7c 100644
--- a/dataset/dataset_common.py
+++ b/dataset/dataset_common.py
@@ -221,17 +221,17 @@ def slim_get_batch(num_classes, batch_size, split_name, file_pattern, num_reader
         gbboxes_raw = tf.boolean_mask(gbboxes_raw, isdifficult_mask)
 
         # Pre-processing image, labels and bboxes.
-
+        tensors_to_batch = []
         if is_training:
             image, glabels, gbboxes = image_preprocessing_fn(org_image, glabels_raw, gbboxes_raw)
+            gt_targets, gt_labels, gt_scores = anchor_encoder(glabels, gbboxes)
+            tensors_to_batch = [image, filename, shape, gt_targets, gt_labels, gt_scores]
         else:
-            image = image_preprocessing_fn(org_image, glabels_raw, gbboxes_raw)
-            glabels, gbboxes = glabels_raw, gbboxes_raw
-
-        gt_targets, gt_labels, gt_scores = anchor_encoder(glabels, gbboxes)
+            image, output_shape = image_preprocessing_fn(org_image, glabels_raw, gbboxes_raw)
+            tensors_to_batch = [image, filename, shape, output_shape]
 
-        return tf.train.batch([image, filename, shape, gt_targets, gt_labels, gt_scores],
-                              dynamic_pad=False,
+        return tf.train.batch(tensors_to_batch,
+                              dynamic_pad=(not is_training),
                               batch_size=batch_size,
                               allow_smaller_final_batch=(not is_training),
                               num_threads=num_preprocessing_threads,
diff --git a/dataset/dataset_inspect.py b/dataset/dataset_inspect.py
index a94e6a6..b23fa5a 100644
--- a/dataset/dataset_inspect.py
+++ b/dataset/dataset_inspect.py
@@ -31,5 +31,5 @@ def count_split_examples(split_path, file_prefix='.tfrecord'):
     return num_samples
 
 if __name__ == '__main__':
-    print('train:', count_split_examples('/media/rs/7A0EE8880EE83EAF/Detections/SSD/dataset/tfrecords', 'train-?????-of-?????'))
-    print('val:', count_split_examples('/media/rs/7A0EE8880EE83EAF/Detections/SSD/dataset/tfrecords', 'val-?????-of-?????'))
+    print('train:', count_split_examples('./dataset/tfrecords', 'train-?????-of-?????'))
+    print('val:', count_split_examples('./dataset/tfrecords', 'val-?????-of-?????'))
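Editor's note: dynamic_pad=(not is_training) lets evaluation batches carry variable-sized tensors by padding every element of a batch to the largest shape in that batch. A stand-alone toy run of tf.train.batch with dynamic_pad, unrelated to the repo's actual tensors:

import tensorflow as tf

length = tf.random_uniform([], minval=1, maxval=5, dtype=tf.int32)
example = tf.ones([length])  # variable-length element, static shape (None,)
batch = tf.train.batch([example], batch_size=4, dynamic_pad=True)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    print(sess.run(batch).shape)  # (4, longest_length_in_this_batch)
    coord.request_stop()
    coord.join(threads)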
diff --git a/eval_ssd.py b/eval_ssd.py
index 722b21c..912d749 100644
--- a/eval_ssd.py
+++ b/eval_ssd.py
@@ -22,6 +22,7 @@
 import tensorflow as tf
 import numpy as np
+from scipy.misc import imread, imsave, imshow, imresize
 
 from net import ssd_net
 
@@ -29,6 +30,8 @@
 from preprocessing import ssd_preprocessing
 from utility import anchor_manipulator
 from utility import scaffolds
+from utility import bbox_util
+from utility import draw_toolbox
 
 # hardware related configuration
 tf.app.flags.DEFINE_integer(
@@ -54,44 +57,29 @@
 tf.app.flags.DEFINE_integer(
     'log_every_n_steps', 10,
     'The frequency with which logs are printed.')
-tf.app.flags.DEFINE_integer(
-    'save_summary_steps', 500,
-    'The frequency with which summaries are saved, in seconds.')
 # model related configuration
-tf.app.flags.DEFINE_integer(
-    'train_image_size', 300,
-    'The size of the input image for the model to use.')
-tf.app.flags.DEFINE_integer(
-    'train_epochs', 1,
-    'The number of epochs to use for training.')
 tf.app.flags.DEFINE_integer(
     'batch_size', 1,
     'Batch size for training and evaluation.')
+tf.app.flags.DEFINE_integer(
+    'train_image_size', 300,
+    'The size of the input image for the model to use.')
 tf.app.flags.DEFINE_string(
     'data_format', 'channels_last', # 'channels_first' or 'channels_last'
     'A flag to override the data format used in the model. channels_first '
     'provides a performance boost on GPU but is not always compatible '
    'with CPU. If left unspecified, the data format will be chosen '
     'automatically based on whether TensorFlow was built for CPU or GPU.')
-tf.app.flags.DEFINE_float(
-    'negative_ratio', 3., 'Negative ratio in the loss function.')
-tf.app.flags.DEFINE_float(
-    'match_threshold', 0.5, 'Matching threshold in the loss function.')
-tf.app.flags.DEFINE_float(
-    'neg_threshold', 0.5, 'Matching threshold for the negtive examples in the loss function.')
 tf.app.flags.DEFINE_float(
     'select_threshold', 0.01, 'Class-specific confidence score threshold for selecting a box.')
 tf.app.flags.DEFINE_float(
-    'min_size', 0.03, 'The min size of bboxes to keep.')
+    'min_size', 4., 'The min size of bboxes to keep.')
 tf.app.flags.DEFINE_float(
     'nms_threshold', 0.45, 'Matching threshold in NMS algorithm.')
 tf.app.flags.DEFINE_integer(
     'nms_topk', 200, 'Number of total object to keep after NMS.')
 tf.app.flags.DEFINE_integer(
     'keep_topk', 400, 'Number of total object to keep for each image before nms.')
-# optimizer related configuration
-tf.app.flags.DEFINE_float(
-    'weight_decay', 5e-4, 'The weight decay on the model weights.')
 # checkpoint related configuration
 tf.app.flags.DEFINE_string(
     'checkpoint_path', './model',
@@ -115,166 +103,82 @@ def get_checkpoint():
 
     return checkpoint_path
 
+def save_image_with_bbox(image, labels_, scores_, bboxes_):
+    if not hasattr(save_image_with_bbox, "counter"):
+        save_image_with_bbox.counter = 0  # it doesn't exist yet, so initialize it
+    save_image_with_bbox.counter += 1
+
+    img_to_draw = np.copy(image).astype(np.uint8)
+
+    img_to_draw = draw_toolbox.bboxes_draw_on_img(img_to_draw, labels_, scores_, bboxes_, thickness=2)
+    imsave(os.path.join('./debug/{}.jpg').format(save_image_with_bbox.counter), img_to_draw)
+    return save_image_with_bbox.counter
+
 # couldn't find better way to pass params from input_fn to model_fn
 # some tensors used by model_fn must be created in input_fn to ensure they are in the same graph
 # but when we put these tensors to labels's dict, the replicate_model_fn will split them into each GPU
 # the problem is that they shouldn't be splited
 global_anchor_info = dict()
-def input_pipeline(dataset_pattern='train-*', is_training=True, batch_size=FLAGS.batch_size):
+def input_pipeline(dataset_pattern='val-*', is_training=True, batch_size=FLAGS.batch_size):
     def input_fn():
-        out_shape = [FLAGS.train_image_size] * 2
-        anchor_creator = anchor_manipulator.AnchorCreator(out_shape,
-                                                          layers_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)],
-                                                          anchor_scales = [(0.1,), (0.2,), (0.375,), (0.55,), (0.725,), (0.9,)],
-                                                          extra_anchor_scales = [(0.1414,), (0.2739,), (0.4541,), (0.6315,), (0.8078,), (0.9836,)],
-                                                          anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)],
-                                                          layer_steps = [8, 16, 32, 64, 100, 300])
-        all_anchors, all_num_anchors_depth, all_num_anchors_spatial = anchor_creator.get_all_anchors()
-
-        num_anchors_per_layer = []
-        for ind in range(len(all_anchors)):
-            num_anchors_per_layer.append(all_num_anchors_depth[ind] * all_num_anchors_spatial[ind])
-
-        anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(allowed_borders = [1.0] * 6,
-                                                                  positive_threshold = FLAGS.match_threshold,
-                                                                  ignore_threshold = FLAGS.neg_threshold,
-                                                                  prior_scaling=[0.1, 0.1, 0.2, 0.2])
-
-        image_preprocessing_fn = lambda image_, labels_, bboxes_ : ssd_preprocessing.preprocess_image(image_, labels_, bboxes_, out_shape, is_training=is_training, data_format=FLAGS.data_format, output_rgb=False)
-        anchor_encoder_fn = lambda glabels_, gbboxes_: anchor_encoder_decoder.encode_all_anchors(glabels_, gbboxes_, all_anchors, all_num_anchors_depth, all_num_anchors_spatial)
-
-        image, filename, shape, loc_targets, cls_targets, match_scores = dataset_common.slim_get_batch(FLAGS.num_classes,
-                                                                                                       batch_size,
-                                                                                                       ('train' if is_training else 'val'),
-                                                                                                       os.path.join(FLAGS.data_dir, dataset_pattern),
-                                                                                                       FLAGS.num_readers,
-                                                                                                       FLAGS.num_preprocessing_threads,
-                                                                                                       image_preprocessing_fn,
-                                                                                                       anchor_encoder_fn,
-                                                                                                       num_epochs=FLAGS.train_epochs,
-                                                                                                       is_training=is_training)
-        global global_anchor_info
-        global_anchor_info = {'decode_fn': lambda pred : anchor_encoder_decoder.decode_all_anchors(pred, num_anchors_per_layer),
-                              'num_anchors_per_layer': num_anchors_per_layer,
-                              'all_num_anchors_depth': all_num_anchors_depth }
-
-        return {'image': image, 'filename': filename, 'shape': shape, 'loc_targets': loc_targets, 'cls_targets': cls_targets, 'match_scores': match_scores}, None
+        assert batch_size == 1, 'We only support batch size 1 for evaluation.'
+        target_shape = [FLAGS.train_image_size] * 2
+        image_preprocessing_fn = lambda image_, labels_, bboxes_ : ssd_preprocessing.preprocess_image(image_, labels_, bboxes_, target_shape, is_training=is_training, data_format=FLAGS.data_format, output_rgb=False)
+
+        image, filename, shape, output_shape = dataset_common.slim_get_batch(FLAGS.num_classes,
+                                                                             batch_size,
+                                                                             ('train' if is_training else 'val'),
+                                                                             os.path.join(FLAGS.data_dir, dataset_pattern),
+                                                                             FLAGS.num_readers,
+                                                                             FLAGS.num_preprocessing_threads,
+                                                                             image_preprocessing_fn,
+                                                                             None,
+                                                                             num_epochs=1,
+                                                                             is_training=is_training)
+
+        return {'image': image, 'filename': filename, 'shape': shape, 'output_shape': output_shape}, None
     return input_fn
-def modified_smooth_l1(bbox_pred, bbox_targets, bbox_inside_weights=1., bbox_outside_weights=1., sigma=1.):
-    """
-        ResultLoss = outside_weights * SmoothL1(inside_weights * (bbox_pred - bbox_targets))
-        SmoothL1(x) = 0.5 * (sigma * x)^2, if |x| < 1 / sigma^2
-                      |x| - 0.5 / sigma^2, otherwise
-    """
-    with tf.name_scope('smooth_l1', [bbox_pred, bbox_targets]):
-        sigma2 = sigma * sigma
-
-        inside_mul = tf.multiply(bbox_inside_weights, tf.subtract(bbox_pred, bbox_targets))
-
-        smooth_l1_sign = tf.cast(tf.less(tf.abs(inside_mul), 1.0 / sigma2), tf.float32)
-        smooth_l1_option1 = tf.multiply(tf.multiply(inside_mul, inside_mul), 0.5 * sigma2)
-        smooth_l1_option2 = tf.subtract(tf.abs(inside_mul), 0.5 / sigma2)
-        smooth_l1_result = tf.add(tf.multiply(smooth_l1_option1, smooth_l1_sign),
-                                  tf.multiply(smooth_l1_option2, tf.abs(tf.subtract(smooth_l1_sign, 1.0))))
-
-        outside_mul = tf.multiply(bbox_outside_weights, smooth_l1_result)
-
-        return outside_mul
-
-def select_bboxes(scores_pred, bboxes_pred, num_classes, select_threshold):
-    selected_bboxes = {}
-    selected_scores = {}
-    with tf.name_scope('select_bboxes', [scores_pred, bboxes_pred]):
-        for class_ind in range(1, num_classes):
-            class_scores = scores_pred[:, class_ind]
-            select_mask = class_scores > select_threshold
-
-            select_mask = tf.cast(select_mask, tf.float32)
-            selected_bboxes[class_ind] = tf.multiply(bboxes_pred, tf.expand_dims(select_mask, axis=-1))
-            selected_scores[class_ind] = tf.multiply(class_scores, select_mask)
-
-    return selected_bboxes, selected_scores
-
-def clip_bboxes(ymin, xmin, ymax, xmax, name):
-    with tf.name_scope(name, 'clip_bboxes', [ymin, xmin, ymax, xmax]):
-        ymin = tf.maximum(ymin, 0.)
-        xmin = tf.maximum(xmin, 0.)
-        ymax = tf.minimum(ymax, 1.)
-        xmax = tf.minimum(xmax, 1.)
-
-        ymin = tf.minimum(ymin, ymax)
-        xmin = tf.minimum(xmin, xmax)
-
-        return ymin, xmin, ymax, xmax
-
-def filter_bboxes(scores_pred, ymin, xmin, ymax, xmax, min_size, name):
-    with tf.name_scope(name, 'filter_bboxes', [scores_pred, ymin, xmin, ymax, xmax]):
-        width = xmax - xmin
-        height = ymax - ymin
-
-        filter_mask = tf.logical_and(width > min_size, height > min_size)
-
-        filter_mask = tf.cast(filter_mask, tf.float32)
-        return tf.multiply(ymin, filter_mask), tf.multiply(xmin, filter_mask), \
-               tf.multiply(ymax, filter_mask), tf.multiply(xmax, filter_mask), tf.multiply(scores_pred, filter_mask)
-
-def sort_bboxes(scores_pred, ymin, xmin, ymax, xmax, keep_topk, name):
-    with tf.name_scope(name, 'sort_bboxes', [scores_pred, ymin, xmin, ymax, xmax]):
-        cur_bboxes = tf.shape(scores_pred)[0]
-        scores, idxes = tf.nn.top_k(scores_pred, k=tf.minimum(keep_topk, cur_bboxes), sorted=True)
-
-        ymin, xmin, ymax, xmax = tf.gather(ymin, idxes), tf.gather(xmin, idxes), tf.gather(ymax, idxes), tf.gather(xmax, idxes)
-
-        paddings_scores = tf.expand_dims(tf.stack([0, tf.maximum(keep_topk-cur_bboxes, 0)], axis=0), axis=0)
-
-        return tf.pad(ymin, paddings_scores, "CONSTANT"), tf.pad(xmin, paddings_scores, "CONSTANT"),\
-               tf.pad(ymax, paddings_scores, "CONSTANT"), tf.pad(xmax, paddings_scores, "CONSTANT"),\
-               tf.pad(scores, paddings_scores, "CONSTANT")
-
-def nms_bboxes(scores_pred, bboxes_pred, nms_topk, nms_threshold, name):
-    with tf.name_scope(name, 'nms_bboxes', [scores_pred, bboxes_pred]):
-        idxes = tf.image.non_max_suppression(bboxes_pred, scores_pred, nms_topk, nms_threshold)
-        return tf.gather(scores_pred, idxes), tf.gather(bboxes_pred, idxes)
-
-def parse_by_class(cls_pred, bboxes_pred, num_classes, select_threshold, min_size, keep_topk, nms_topk, nms_threshold):
-    with tf.name_scope('select_bboxes', [cls_pred, bboxes_pred]):
-        scores_pred = tf.nn.softmax(cls_pred)
-        selected_bboxes, selected_scores = select_bboxes(scores_pred, bboxes_pred, num_classes, select_threshold)
-        for class_ind in range(1, num_classes):
-            ymin, xmin, ymax, xmax = tf.unstack(selected_bboxes[class_ind], 4, axis=-1)
-            #ymin, xmin, ymax, xmax = tf.split(selected_bboxes[class_ind], 4, axis=-1)
-            #ymin, xmin, ymax, xmax = tf.squeeze(ymin), tf.squeeze(xmin), tf.squeeze(ymax), tf.squeeze(xmax)
-            ymin, xmin, ymax, xmax = clip_bboxes(ymin, xmin, ymax, xmax, 'clip_bboxes_{}'.format(class_ind))
-            ymin, xmin, ymax, xmax, selected_scores[class_ind] = filter_bboxes(selected_scores[class_ind],
-                                                                               ymin, xmin, ymax, xmax, min_size, 'filter_bboxes_{}'.format(class_ind))
-            ymin, xmin, ymax, xmax, selected_scores[class_ind] = sort_bboxes(selected_scores[class_ind],
-                                                                             ymin, xmin, ymax, xmax, keep_topk, 'sort_bboxes_{}'.format(class_ind))
-            selected_bboxes[class_ind] = tf.stack([ymin, xmin, ymax, xmax], axis=-1)
-            selected_scores[class_ind], selected_bboxes[class_ind] = nms_bboxes(selected_scores[class_ind], selected_bboxes[class_ind], nms_topk, nms_threshold, 'nms_bboxes_{}'.format(class_ind))
-
-    return selected_bboxes, selected_scores
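Editor's note: the helpers deleted above implement the per-class post-processing chain (softmax, per-class score threshold, clip to image, drop tiny boxes, top-k sort, NMS) and, judging from the new import, now live in utility/bbox_util.py. For reference, here is the greedy NMS step the chain ends in, as a plain numpy sketch rather than the repo's TF implementation:

import numpy as np

def nms(boxes, scores, iou_thresh=0.45, topk=200):
    # boxes: (N, 4) as [ymin, xmin, ymax, xmax]; greedy hard-NMS sketch
    order = scores.argsort()[::-1][:topk]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        yx1 = np.maximum(boxes[i, :2], boxes[order[1:], :2])
        yx2 = np.minimum(boxes[i, 2:], boxes[order[1:], 2:])
        wh = np.maximum(yx2 - yx1, 0.)
        inter = wh[:, 0] * wh[:, 1]
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        areas = (boxes[order[1:], 2] - boxes[order[1:], 0]) * (boxes[order[1:], 3] - boxes[order[1:], 1])
        iou = inter / (area_i + areas - inter)
        order = order[1:][iou <= iou_thresh]  # drop boxes overlapping the kept one
    return keep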
 def ssd_model_fn(features, labels, mode, params):
     """model_fn for SSD to be used with our Estimator."""
     filename = features['filename']
+    filename = tf.identity(filename, name='filename')
     shape = features['shape']
-    loc_targets = features['loc_targets']
-    cls_targets = features['cls_targets']
-    match_scores = features['match_scores']
+    output_shape = features['output_shape']
     features = features['image']
 
-    global global_anchor_info
-    decode_fn = global_anchor_info['decode_fn']
-    num_anchors_per_layer = global_anchor_info['num_anchors_per_layer']
-    all_num_anchors_depth = global_anchor_info['all_num_anchors_depth']
+    anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(positive_threshold=None, ignore_threshold=None, prior_scaling=[0.1, 0.1, 0.2, 0.2])
+    all_anchor_scales = [(30.,), (60.,), (112.5,), (165.,), (217.5,), (270.,)]
+    all_extra_scales = [(42.43,), (82.17,), (136.23,), (189.45,), (242.34,), (295.08,)]
+    all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]
 
     with tf.variable_scope(params['model_scope'], default_name=None, values=[features], reuse=tf.AUTO_REUSE):
         backbone = ssd_net.VGG16Backbone(params['data_format'])
+        # forward features
         feature_layers = backbone.forward(features, training=(mode == tf.estimator.ModeKeys.TRAIN))
-        #print(feature_layers)
-        location_pred, cls_pred = ssd_net.multibox_head(feature_layers, params['num_classes'], all_num_anchors_depth, data_format=params['data_format'])
+        # generate anchors according to the feature map size
+        with tf.device('/cpu:0'):
+            if params['data_format'] == 'channels_first':
+                all_layer_shapes = [tf.shape(feat)[2:] for feat in feature_layers]
+            else:
+                all_layer_shapes = [tf.shape(feat)[1:3] for feat in feature_layers]
+            all_layer_strides = [8, 16, 32, 64, 100, 300]
+            total_layers = len(all_layer_shapes)
+            anchors_height = list()
+            anchors_width = list()
+            anchors_depth = list()
+            for ind in range(total_layers):
+                _anchors_height, _anchors_width, _anchor_depth = anchor_encoder_decoder.get_anchors_width_height(all_anchor_scales[ind], all_extra_scales[ind], all_anchor_ratios[ind], name='get_anchors_width_height{}'.format(ind))
+                anchors_height.append(_anchors_height)
+                anchors_width.append(_anchors_width)
+                anchors_depth.append(_anchor_depth)
+            anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, _ = anchor_encoder_decoder.get_all_anchors(tf.squeeze(output_shape, axis=0),
+                                                                                                               anchors_height, anchors_width, anchors_depth,
+                                                                                                               [0.5] * total_layers, all_layer_shapes, all_layer_strides,
+                                                                                                               [0.] * total_layers, [False] * total_layers)
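Editor's note: the pixel-unit scales above follow the usual SSD convention: each layer's "extra" scale is the geometric mean of its own scale and the next layer's (e.g. sqrt(30 * 60) is approximately 42.43), and each aspect ratio r presumably yields a box of width s*sqrt(r) and height s/sqrt(r). A sketch of that arithmetic, not the repo's actual get_anchors_width_height:

import math

def anchor_sizes(scale, extra_scale, ratios):
    # one square box per scale, one for the geometric-mean "extra" scale,
    # then one box per aspect ratio: w = s*sqrt(r), h = s/sqrt(r)
    sizes = [(scale, scale), (extra_scale, extra_scale)]
    for r in ratios:
        sizes.append((scale * math.sqrt(r), scale / math.sqrt(r)))
    return sizes

print(math.sqrt(30. * 60.))           # 42.426..., matching all_extra_scales[0]
for w, h in anchor_sizes(30., 42.43, (2., .5)):
    print(round(w, 1), round(h, 1))   # (30, 30), (42.4, 42.4), (42.4, 21.2), (21.2, 42.4)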
+        # generate predictions based on anchors
+        location_pred, cls_pred = ssd_net.multibox_head(feature_layers, params['num_classes'], anchors_depth, data_format=params['data_format'])
         if params['data_format'] == 'channels_first':
             cls_pred = [tf.transpose(pred, [0, 2, 3, 1]) for pred in cls_pred]
             location_pred = [tf.transpose(pred, [0, 2, 3, 1]) for pred in location_pred]
@@ -287,87 +191,40 @@ def ssd_model_fn(features, labels, mode, params):
     cls_pred = tf.reshape(cls_pred, [-1, params['num_classes']])
     location_pred = tf.reshape(location_pred, [-1, 4])
 
-
+    # decode predictions
     with tf.device('/cpu:0'):
-        bboxes_pred = decode_fn(location_pred)
-        bboxes_pred = tf.concat(bboxes_pred, axis=0)
-        selected_bboxes, selected_scores = parse_by_class(cls_pred, bboxes_pred,
+        bboxes_pred = anchor_encoder_decoder.decode_anchors(location_pred, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax)
+        selected_bboxes, selected_scores = bbox_util.parse_by_class(tf.squeeze(output_shape, axis=0), cls_pred, bboxes_pred,
                                                         params['num_classes'], params['select_threshold'], params['min_size'],
                                                         params['keep_topk'], params['nms_topk'], params['nms_threshold'])
 
-    predictions = {'filename': filename, 'shape': shape }
+    labels_list = []
+    scores_list = []
+    bboxes_list = []
+    for k, v in selected_scores.items():
+        labels_list.append(tf.ones_like(v, tf.int32) * k)
+        scores_list.append(v)
+        bboxes_list.append(selected_bboxes[k])
+    all_labels = tf.concat(labels_list, axis=0)
+    all_scores = tf.concat(scores_list, axis=0)
+    all_bboxes = tf.concat(bboxes_list, axis=0)
+    save_image_op = tf.py_func(save_image_with_bbox,
+                               [ssd_preprocessing.unwhiten_image(tf.squeeze(features, axis=0), output_rgb=False),
+                                all_labels * tf.to_int32(all_scores > 0.3),
+                                all_scores,
+                                all_bboxes],
+                               tf.int64, stateful=True)
+    tf.identity(save_image_op, name='save_image_op')
+    predictions = {'filename': filename, 'shape': shape, 'output_shape': output_shape }
     for class_ind in range(1, params['num_classes']):
         predictions['scores_{}'.format(class_ind)] = tf.expand_dims(selected_scores[class_ind], axis=0)
         predictions['bboxes_{}'.format(class_ind)] = tf.expand_dims(selected_bboxes[class_ind], axis=0)
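Editor's note: save_image_op above wraps the numpy drawing helper in tf.py_func so it runs as ordinary Python at graph execution time, purely as a debugging aid. A minimal stand-alone illustration of the pattern, with a toy callback:

import numpy as np
import tensorflow as tf

def _debug_dump(image):
    # ordinary python executed when the graph runs
    print('image shape seen at runtime:', image.shape)
    return np.int64(0)

image = tf.zeros([300, 300, 3], tf.uint8)
dump_op = tf.py_func(_debug_dump, [image], tf.int64, stateful=True)

with tf.Session() as sess:
    sess.run(dump_op)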
-    flaten_cls_targets = tf.reshape(cls_targets, [-1])
-    flaten_match_scores = tf.reshape(match_scores, [-1])
-    flaten_loc_targets = tf.reshape(loc_targets, [-1, 4])
-
-    # each positive examples has one label
-    positive_mask = flaten_cls_targets > 0
-    n_positives = tf.count_nonzero(positive_mask)
-
-    batch_n_positives = tf.count_nonzero(cls_targets, -1)
-
-    batch_negtive_mask = tf.equal(cls_targets, 0)#tf.logical_and(tf.equal(cls_targets, 0), match_scores > 0.)
-    batch_n_negtives = tf.count_nonzero(batch_negtive_mask, -1)
-
-    batch_n_neg_select = tf.cast(params['negative_ratio'] * tf.cast(batch_n_positives, tf.float32), tf.int32)
-    batch_n_neg_select = tf.minimum(batch_n_neg_select, tf.cast(batch_n_negtives, tf.int32))
-
-    # hard negative mining for classification
-    predictions_for_bg = tf.nn.softmax(tf.reshape(cls_pred, [tf.shape(features)[0], -1, params['num_classes']]))[:, :, 0]
-    prob_for_negtives = tf.where(batch_negtive_mask,
-                                 0. - predictions_for_bg,
-                                 # ignore all the positives
-                                 0. - tf.ones_like(predictions_for_bg))
-    topk_prob_for_bg, _ = tf.nn.top_k(prob_for_negtives, k=tf.shape(prob_for_negtives)[1])
-    score_at_k = tf.gather_nd(topk_prob_for_bg, tf.stack([tf.range(tf.shape(features)[0]), batch_n_neg_select - 1], axis=-1))
-
-    selected_neg_mask = prob_for_negtives >= tf.expand_dims(score_at_k, axis=-1)
-
-    # include both selected negtive and all positive examples
-    final_mask = tf.stop_gradient(tf.logical_or(tf.reshape(tf.logical_and(batch_negtive_mask, selected_neg_mask), [-1]), positive_mask))
-    total_examples = tf.count_nonzero(final_mask)
-
-    cls_pred = tf.boolean_mask(cls_pred, final_mask)
-    location_pred = tf.boolean_mask(location_pred, tf.stop_gradient(positive_mask))
-    flaten_cls_targets = tf.boolean_mask(tf.clip_by_value(flaten_cls_targets, 0, params['num_classes']), final_mask)
-    flaten_loc_targets = tf.stop_gradient(tf.boolean_mask(flaten_loc_targets, positive_mask))
-
-    # Calculate loss, which includes softmax cross entropy and L2 regularization.
-    #cross_entropy = (params['negative_ratio'] + 1.) * tf.cond(n_positives > 0, lambda: tf.losses.sparse_softmax_cross_entropy(labels=glabels, logits=cls_pred), lambda: 0.)
-    cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=flaten_cls_targets, logits=cls_pred) * (params['negative_ratio'] + 1.)
-    # Create a tensor named cross_entropy for logging purposes.
-    tf.identity(cross_entropy, name='cross_entropy_loss')
-    tf.summary.scalar('cross_entropy_loss', cross_entropy)
-
-    #loc_loss = tf.cond(n_positives > 0, lambda: modified_smooth_l1(location_pred, tf.stop_gradient(flaten_loc_targets), sigma=1.), lambda: tf.zeros_like(location_pred))
-    loc_loss = modified_smooth_l1(location_pred, flaten_loc_targets, sigma=1.)
-    loc_loss = tf.reduce_mean(tf.reduce_sum(loc_loss, axis=-1), name='location_loss')
-    tf.summary.scalar('location_loss', loc_loss)
-    tf.losses.add_loss(loc_loss)
-
-    # Add weight decay to the loss. We exclude the batch norm variables because
-    # doing so leads to a small improvement in accuracy.
-    total_loss = tf.add(cross_entropy, loc_loss, name='total_loss')
-
-    cls_accuracy = tf.metrics.accuracy(flaten_cls_targets, tf.argmax(cls_pred, axis=-1))
-
-    # Create a tensor named train_accuracy for logging purposes.
-    tf.identity(cls_accuracy[1], name='cls_accuracy')
-    tf.summary.scalar('cls_accuracy', cls_accuracy[1])
-
-    summary_hook = tf.train.SummarySaverHook(save_steps=params['save_summary_steps'],
-                                             output_dir=params['summary_dir'],
-                                             summary_op=tf.summary.merge_all())
     if mode == tf.estimator.ModeKeys.PREDICT:
         return tf.estimator.EstimatorSpec(
             mode=mode,
             predictions=predictions,
-            prediction_hooks=[summary_hook],
-            loss=None, train_op=None)
+            prediction_hooks=None, loss=None, train_op=None)
     else:
         raise ValueError('This script only support "PREDICT" mode!')
@@ -385,13 +242,13 @@ def main(_):
     run_config = tf.estimator.RunConfig().replace(
                                         save_checkpoints_secs=None).replace(
                                         save_checkpoints_steps=None).replace(
-                                        save_summary_steps=FLAGS.save_summary_steps).replace(
+                                        save_summary_steps=None).replace(
                                         keep_checkpoint_max=5).replace(
                                         log_step_count_steps=FLAGS.log_every_n_steps).replace(
                                         session_config=config)
 
     summary_dir = os.path.join(FLAGS.model_dir, 'predict')
-
+    tf.gfile.MakeDirs(summary_dir)
     ssd_detector = tf.estimator.Estimator(
         model_fn=ssd_model_fn, model_dir=FLAGS.model_dir, config=run_config,
         params={
@@ -403,22 +260,13 @@ def main(_):
             'data_format': FLAGS.data_format,
             'batch_size': FLAGS.batch_size,
             'model_scope': FLAGS.model_scope,
-            'save_summary_steps': FLAGS.save_summary_steps,
-            'summary_dir': summary_dir,
             'num_classes': FLAGS.num_classes,
-            'negative_ratio': FLAGS.negative_ratio,
-            'match_threshold': FLAGS.match_threshold,
-            'neg_threshold': FLAGS.neg_threshold,
-            'weight_decay': FLAGS.weight_decay,
         })
     tensors_to_log = {
-        'ce': 'cross_entropy_loss',
-        'loc': 'location_loss',
-        'loss': 'total_loss',
-        'acc': 'cls_accuracy',
+        'cur_image': 'filename',
+        'cur_ind': 'save_image_op'
     }
-    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=FLAGS.log_every_n_steps,
-                                              formatter=lambda dicts: (', '.join(['%s=%.6f' % (k, v) for k, v in dicts.items()])))
+    logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=FLAGS.log_every_n_steps)
 
     print('Starting a predict cycle.')
     pred_results = ssd_detector.predict(input_fn=input_pipeline(dataset_pattern='val-*', is_training=False, batch_size=FLAGS.batch_size),
@@ -427,20 +275,21 @@ def main(_):
     det_results = list(pred_results)
     #print(list(det_results))
-    #[{'bboxes_1': array([[0.        , 0.        , 0.28459054, 0.5679505 ], [0.3158835 , 0.34792888, 0.7312541 , 1.        ]], dtype=float32), 'scores_17': array([0.01333667, 0.01152573], dtype=float32), 'filename': b'000703.jpg', 'shape': array([334, 500, 3])}]
+    #[{'bboxes_1': array([[0.       , 0.       , 284.59054, 567.9505 ], [31.58835 , 34.792888, 73.12541 , 100.      ]], dtype=float32), 'scores_17': array([0.01333667, 0.01152573], dtype=float32), 'filename': b'000703.jpg', 'shape': array([334, 500, 3])}]
     for class_ind in range(1, FLAGS.num_classes):
         with open(os.path.join(summary_dir, 'results_{}.txt'.format(class_ind)), 'wt') as f:
             for image_ind, pred in enumerate(det_results):
                 filename = pred['filename']
                 shape = pred['shape']
+                output_shape = pred['output_shape']
                 scores = pred['scores_{}'.format(class_ind)]
                 bboxes = pred['bboxes_{}'.format(class_ind)]
-                bboxes[:, 0] = (bboxes[:, 0] * shape[0]).astype(np.int32, copy=False) + 1
-                bboxes[:, 1] = (bboxes[:, 1] * shape[1]).astype(np.int32, copy=False) + 1
-                bboxes[:, 2] = (bboxes[:, 2] * shape[0]).astype(np.int32, copy=False) + 1
-                bboxes[:, 3] = (bboxes[:, 3] * shape[1]).astype(np.int32, copy=False) + 1
+                bboxes[:, 0] = bboxes[:, 0] * shape[0] / output_shape[0]
+                bboxes[:, 1] = bboxes[:, 1] * shape[1] / output_shape[1]
+                bboxes[:, 2] = bboxes[:, 2] * shape[0] / output_shape[0]
+                bboxes[:, 3] = bboxes[:, 3] * shape[1] / output_shape[1]
 
-                valid_mask = np.logical_and((bboxes[:, 2] - bboxes[:, 0] > 0), (bboxes[:, 3] - bboxes[:, 1] > 0))
+                valid_mask = np.logical_and((bboxes[:, 2] - bboxes[:, 0] > 1.), (bboxes[:, 3] - bboxes[:, 1] > 1.))
 
                 for det_ind in range(valid_mask.shape[0]):
                     if not valid_mask[det_ind]:
@@ -453,4 +302,5 @@ def main(_):
 
 if __name__ == '__main__':
     tf.logging.set_verbosity(tf.logging.INFO)
+    tf.gfile.MakeDirs('./debug')
     tf.app.run()
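Editor's note: the new result-writing code above maps boxes from the 300x300 network-input space (output_shape) back to the original image size before writing VOC-format results. With made-up numbers:

import numpy as np

shape = np.array([334, 500])              # original (height, width), e.g. 000703.jpg
output_shape = np.array([300, 300])       # network input size
bbox = np.array([30., 45., 210., 420.])   # [ymin, xmin, ymax, xmax] in input space
bbox[[0, 2]] *= shape[0] / float(output_shape[0])  # y coordinates: * 334/300
bbox[[1, 3]] *= shape[1] / float(output_shape[1])  # x coordinates: * 500/300
print(bbox)                                # [ 33.4  75.  233.8 700. ]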
diff --git a/preprocessing/preprocessing_unittest.py b/preprocessing/preprocessing_unittest.py
index 92e4167..3cd5184 100644
--- a/preprocessing/preprocessing_unittest.py
+++ b/preprocessing/preprocessing_unittest.py
@@ -102,7 +102,7 @@ def slim_get_split(file_pattern='{}_????'):
     return save_image_op
 
 if __name__ == '__main__':
-    save_image_op = slim_get_split('/media/rs/7A0EE8880EE83EAF/Detections/SSD/dataset/tfrecords/*')
+    save_image_op = slim_get_split('./dataset/tfrecords/*')
 
     # Create the graph, etc.
     init_op = tf.group([tf.local_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()])
diff --git a/preprocessing/ssd_preprocessing.py b/preprocessing/ssd_preprocessing.py
index 3ab8dcc..b79f571 100644
--- a/preprocessing/ssd_preprocessing.py
+++ b/preprocessing/ssd_preprocessing.py
@@ -27,6 +27,21 @@
 More information can be obtained from the VGG website:
 www.robots.ox.ac.uk/~vgg/research/very_deep/
 """
+# ==============================================================================
+# Copyright 2018 Changan Wang
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
 from __future__ import absolute_import
 from __future__ import division
@@ -156,7 +171,7 @@ def condition(index, sampled_width, sampled_height, width, height):
 
 def body(index, sampled_width, sampled_height, width, height):
     sampled_width = tf.random_uniform([1], minval=0.3, maxval=0.999, dtype=tf.float32)[0] * width
-    sampled_height = tf.random_uniform([1], minval=0.3, maxval=0.999, dtype=tf.float32)[0] *height
+    sampled_height = tf.random_uniform([1], minval=0.3, maxval=0.999, dtype=tf.float32)[0] * height
 
     return index+1, sampled_width, sampled_height, width, height
@@ -171,30 +186,31 @@ def jaccard_with_anchors(roi, bboxes):
         int_xmin = tf.maximum(roi[1], bboxes[:, 1])
         int_ymax = tf.minimum(roi[2], bboxes[:, 2])
         int_xmax = tf.minimum(roi[3], bboxes[:, 3])
-        h = tf.maximum(int_ymax - int_ymin, 0.)
-        w = tf.maximum(int_xmax - int_xmin, 0.)
+        h = tf.maximum(int_ymax - int_ymin + 1., 0.)
+        w = tf.maximum(int_xmax - int_xmin + 1., 0.)
         inter_vol = h * w
-        union_vol = (roi[3] - roi[1]) * (roi[2] - roi[0]) + ((bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1]) - inter_vol)
+        union_vol = (roi[3] - roi[1] + 1.) * (roi[2] - roi[0] + 1.) + ((bboxes[:, 2] - bboxes[:, 0] + 1.) * (bboxes[:, 3] - bboxes[:, 1] + 1.) - inter_vol)
         jaccard = tf.div(inter_vol, union_vol)
         return jaccard
 
 def areas(bboxes):
     with tf.name_scope('bboxes_areas'):
-        vol = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
+        vol = (bboxes[:, 3] - bboxes[:, 1] + 1.) * (bboxes[:, 2] - bboxes[:, 0] + 1.)
         return vol
 
 def check_roi_center(width, height, labels, bboxes):
     with tf.name_scope('check_roi_center'):
         index = 0
         max_attempt = 20
-        roi = [0., 0., 0., 0.]
-        float_width = tf.cast(width, tf.float32)
-        float_height = tf.cast(height, tf.float32)
+        float_width = tf.to_float(width)
+        float_height = tf.to_float(height)
+        roi = [0., 0., float_height - 1., float_width - 1.]
+
         mask = tf.cast(tf.zeros_like(labels, dtype=tf.uint8), tf.bool)
         center_x, center_y = (bboxes[:, 1] + bboxes[:, 3]) / 2, (bboxes[:, 0] + bboxes[:, 2]) / 2
 
         def condition(index, roi, mask):
-            return tf.logical_or(tf.logical_and(tf.reduce_sum(tf.cast(mask, tf.int32)) < 1,
+            return tf.logical_or(tf.logical_and(tf.reduce_sum(tf.to_int32(mask)) < 1,
                                                 tf.less(index, max_attempt)),
                                  tf.less(index, 1))
@@ -204,10 +220,7 @@ def body(index, roi, mask):
     x = tf.random_uniform([], minval=0, maxval=width - sampled_width, dtype=tf.int32)
     y = tf.random_uniform([], minval=0, maxval=height - sampled_height, dtype=tf.int32)
 
-    roi = [tf.cast(y, tf.float32) / float_height,
-           tf.cast(x, tf.float32) / float_width,
-           tf.cast(y + sampled_height, tf.float32) / float_height,
-           tf.cast(x + sampled_width, tf.float32) / float_width]
+    roi = [tf.to_float(y), tf.to_float(x), tf.to_float(y + sampled_height), tf.to_float(x + sampled_width)]
 
     mask_min = tf.logical_and(tf.greater(center_y, roi[0]), tf.greater(center_x, roi[1]))
     mask_max = tf.logical_and(tf.less(center_y, roi[2]), tf.less(center_x, roi[3]))
@@ -225,12 +238,15 @@ def check_roi_overlap(width, height, labels, bboxes, min_iou):
     with tf.name_scope('check_roi_overlap'):
         index = 0
         max_attempt = 50
-        roi = [0., 0., 1., 1.]
+        float_width = tf.to_float(width)
+        float_height = tf.to_float(height)
+        roi = [0., 0., float_height - 1., float_width - 1.]
+
         mask_labels = labels
         mask_bboxes = bboxes
 
         def condition(index, roi, mask_labels, mask_bboxes):
-            return tf.logical_or(tf.logical_or(tf.logical_and(tf.reduce_sum(tf.cast(jaccard_with_anchors(roi, mask_bboxes) < min_iou, tf.int32)) > 0,
+            return tf.logical_or(tf.logical_or(tf.logical_and(tf.reduce_sum(tf.to_int32(jaccard_with_anchors(roi, mask_bboxes) < min_iou)) > 0,
                                                               tf.less(index, max_attempt)),
                                                tf.less(index, 1)),
                                  tf.less(tf.shape(mask_labels)[0], 1))
@@ -242,11 +258,8 @@ def body(index, roi, mask_labels, mask_bboxes):
 
     [index, roi, mask_labels, mask_bboxes] = tf.while_loop(condition, body, [index, roi, mask_labels, mask_bboxes], parallel_iterations=16, back_prop=False, swap_memory=True)
 
     return tf.cond(tf.greater(tf.shape(mask_labels)[0], 0),
-                   lambda : (tf.cast([roi[0] * tf.cast(height, tf.float32),
-                                      roi[1] * tf.cast(width, tf.float32),
-                                      (roi[2] - roi[0]) * tf.cast(height, tf.float32),
-                                      (roi[3] - roi[1]) * tf.cast(width, tf.float32)], tf.int32), mask_labels, mask_bboxes),
-                   lambda : (tf.cast([0, 0, height, width], tf.int32), labels, bboxes))
+                   lambda : (tf.to_int32([roi[0], roi[1], roi[2] - roi[0] + 1., roi[3] - roi[1] + 1.]), mask_labels, mask_bboxes),
+                   lambda : (tf.to_int32([0., 0., float_height, float_width]), labels, bboxes))
 
 def sample_patch(image, labels, bboxes, min_iou):
@@ -255,27 +268,21 @@ def sample_patch(image, labels, bboxes, min_iou):
 
         roi_slice_range, mask_labels, mask_bboxes = check_roi_overlap(width, height, labels, bboxes, min_iou)
 
-        scale = tf.cast(tf.stack([height, width, height, width]), mask_bboxes.dtype)
-        mask_bboxes = mask_bboxes * scale
-
         # Add offset.
         offset = tf.cast(tf.stack([roi_slice_range[0], roi_slice_range[1], roi_slice_range[0], roi_slice_range[1]]), mask_bboxes.dtype)
         mask_bboxes = mask_bboxes - offset
 
         cliped_ymin = tf.maximum(0., mask_bboxes[:, 0])
         cliped_xmin = tf.maximum(0., mask_bboxes[:, 1])
-        cliped_ymax = tf.minimum(tf.cast(roi_slice_range[2], tf.float32), mask_bboxes[:, 2])
-        cliped_xmax = tf.minimum(tf.cast(roi_slice_range[3], tf.float32), mask_bboxes[:, 3])
+        cliped_ymax = tf.minimum(tf.to_float(roi_slice_range[2]) - 1., mask_bboxes[:, 2])
+        cliped_xmax = tf.minimum(tf.to_float(roi_slice_range[3]) - 1., mask_bboxes[:, 3])
         mask_bboxes = tf.stack([cliped_ymin, cliped_xmin, cliped_ymax, cliped_xmax], axis=-1)
 
-        # Rescale to target dimension.
-        scale = tf.cast(tf.stack([roi_slice_range[2], roi_slice_range[3],
-                                  roi_slice_range[2], roi_slice_range[3]]), mask_bboxes.dtype)
 
         return tf.cond(tf.logical_or(tf.less(roi_slice_range[2], 1), tf.less(roi_slice_range[3], 1)),
                        lambda: (image, labels, bboxes),
                        lambda: (tf.slice(image, [roi_slice_range[0], roi_slice_range[1], 0], [roi_slice_range[2], roi_slice_range[3], -1]),
-                                mask_labels, mask_bboxes / scale))
+                                mask_labels, mask_bboxes))
 
     with tf.name_scope('ssd_random_sample_patch'):
         image = tf.convert_to_tensor(image, name='image')
@@ -310,53 +317,21 @@ def ssd_random_expand(image, bboxes, ratio=2., name=None):
                                   tf.pad(image[:, :, 1], paddings, "CONSTANT", constant_values = mean_color_of_image[1]),
                                   tf.pad(image[:, :, 2], paddings, "CONSTANT", constant_values = mean_color_of_image[2])], axis=-1)
 
-        scale = tf.cast(tf.stack([height, width, height, width]), bboxes.dtype)
-        absolute_bboxes = bboxes * scale + tf.cast(tf.stack([y, x, y, x]), bboxes.dtype)
-
-        return big_canvas, absolute_bboxes / tf.cast(tf.stack([canvas_height, canvas_width, canvas_height, canvas_width]), bboxes.dtype)
-
-# def ssd_random_sample_patch_wrapper(image, labels, bboxes):
-#     with tf.name_scope('ssd_random_sample_patch_wrapper'):
-#         orgi_image, orgi_labels, orgi_bboxes = image, labels, bboxes
-#         def check_bboxes(bboxes):
-#             areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
-#             return tf.logical_and(tf.logical_and(areas < 0.9, areas > 0.001),
-#                                   tf.logical_and((bboxes[:, 3] - bboxes[:, 1]) > 0.025, (bboxes[:, 2] - bboxes[:, 0]) > 0.025))
-
-#         index = 0
-#         max_attempt = 3
-#         def condition(index, image, labels, bboxes):
-#             return tf.logical_or(tf.logical_and(tf.reduce_sum(tf.cast(check_bboxes(bboxes), tf.int64)) < 1, tf.less(index, max_attempt)), tf.less(index, 1))
-
-#         def body(index, image, labels, bboxes):
-#             image, bboxes = tf.cond(tf.random_uniform([], minval=0., maxval=1., dtype=tf.float32) < 0.5,
-#                                     lambda: (image, bboxes),
-#                                     lambda: ssd_random_expand(image, bboxes, tf.random_uniform([1], minval=1.1, maxval=4., dtype=tf.float32)[0]))
-#             # Distort image and bounding boxes.
-#             random_sample_image, labels, bboxes = ssd_random_sample_patch(image, labels, bboxes, ratio_list=[-0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.])
-#             random_sample_image.set_shape([None, None, 3])
-#             return index+1, random_sample_image, labels, bboxes
-
-#         [index, image, labels, bboxes] = tf.while_loop(condition, body, [index, orgi_image, orgi_labels, orgi_bboxes], parallel_iterations=4, back_prop=False, swap_memory=True)
-
-#         valid_mask = check_bboxes(bboxes)
-#         labels, bboxes = tf.boolean_mask(labels, valid_mask), tf.boolean_mask(bboxes, valid_mask)
-#         return tf.cond(tf.less(index, max_attempt),
-#                        lambda : (image, labels, bboxes),
-#                        lambda : (orgi_image, orgi_labels, orgi_bboxes))
+        return big_canvas, bboxes + tf.cast(tf.stack([y, x, y, x]), bboxes.dtype)
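Editor's note: all the "+ 1." terms in this file implement IoU over inclusive, absolute pixel coordinates, matching the 1-pixel shift introduced in convert_tfrecords.py. A tiny numeric check of that convention:

def iou_inclusive(a, b):
    # boxes as [ymin, xmin, ymax, xmax] in inclusive pixel coordinates,
    # so a one-pixel box has ymin == ymax and still gets area 1
    iy = max(0., min(a[2], b[2]) - max(a[0], b[0]) + 1.)
    ix = max(0., min(a[3], b[3]) - max(a[1], b[1]) + 1.)
    inter = iy * ix
    area = lambda r: (r[2] - r[0] + 1.) * (r[3] - r[1] + 1.)
    return inter / (area(a) + area(b) - inter)

print(iou_inclusive([0, 0, 9, 9], [5, 5, 14, 14]))  # 25 / (100 + 100 - 25) = 0.142857...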
 
 def ssd_random_sample_patch_wrapper(image, labels, bboxes):
     with tf.name_scope('ssd_random_sample_patch_wrapper'):
         orgi_image, orgi_labels, orgi_bboxes = image, labels, bboxes
-        def check_bboxes(bboxes):
-            areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
-            return tf.logical_and(tf.logical_and(areas < 0.9, areas > 0.001),
-                                  tf.logical_and((bboxes[:, 3] - bboxes[:, 1]) > 0.025, (bboxes[:, 2] - bboxes[:, 0]) > 0.025))
+        def check_bboxes(image, bboxes):
+            image_shape = tf.shape(image)
+            areas = (bboxes[:, 3] - bboxes[:, 1] + 1.) * (bboxes[:, 2] - bboxes[:, 0] + 1.)
+            return tf.logical_and(tf.logical_and(areas > 64., areas < tf.to_float(image_shape[0] * image_shape[1]) * 0.96),
+                                  tf.logical_and((bboxes[:, 3] - bboxes[:, 1] + 1.) > 5., (bboxes[:, 2] - bboxes[:, 0] + 1.) > 5.))
 
         index = 0
         max_attempt = 3
         def condition(index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes):
-            return tf.logical_or(tf.logical_and(tf.reduce_sum(tf.cast(check_bboxes(bboxes), tf.int64)) < 1, tf.less(index, max_attempt)), tf.less(index, 1))
+            return tf.logical_or(tf.logical_and(tf.reduce_sum(tf.cast(check_bboxes(image, bboxes), tf.int64)) < 1, tf.less(index, max_attempt)), tf.less(index, 1))
 
         def body(index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes):
             image, bboxes = tf.cond(tf.random_uniform([], minval=0., maxval=1., dtype=tf.float32) < 0.5,
@@ -369,7 +344,7 @@ def body(index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes):
 
         [index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes] = tf.while_loop(condition, body, [index, image, labels, bboxes, orgi_image, orgi_labels, orgi_bboxes], parallel_iterations=4, back_prop=False, swap_memory=True)
 
-        valid_mask = check_bboxes(bboxes)
+        valid_mask = check_bboxes(image, bboxes)
         labels, bboxes = tf.boolean_mask(labels, valid_mask), tf.boolean_mask(bboxes, valid_mask)
         return tf.cond(tf.less(index, max_attempt),
                        lambda : (image, labels, bboxes),
@@ -407,8 +382,11 @@ def _mean_image_subtraction(image, means):
         channels[i] -= means[i]
     return tf.concat(axis=2, values=channels)
 
-def unwhiten_image(image):
+def unwhiten_image(image, output_rgb=True):
     means=[_R_MEAN, _G_MEAN, _B_MEAN]
+    if not output_rgb:
+        image_channels = tf.unstack(image, axis=-1, name='split_bgr')
+        image = tf.stack([image_channels[2], image_channels[1], image_channels[0]], axis=-1, name='merge_rgb')
     num_channels = image.get_shape().as_list()[-1]
     channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
     for i in range(num_channels):
@@ -417,13 +395,15 @@ def unwhiten_image(image):
 
 def random_flip_left_right(image, bboxes):
     with tf.name_scope('random_flip_left_right'):
+        float_width = tf.to_float(_ImageDimensions(image, rank=3)[1])
+
         uniform_random = tf.random_uniform([], 0, 1.0)
         mirror_cond = tf.less(uniform_random, .5)
         # Flip image.
         result = tf.cond(mirror_cond, lambda: tf.image.flip_left_right(image), lambda: image)
         # Flip bboxes.
-        mirror_bboxes = tf.stack([bboxes[:, 0], 1 - bboxes[:, 3],
-                                  bboxes[:, 2], 1 - bboxes[:, 1]], axis=-1)
+        mirror_bboxes = tf.stack([bboxes[:, 0], float_width - 1. - bboxes[:, 3],
+                                  bboxes[:, 2], float_width - 1. - bboxes[:, 1]], axis=-1)
         bboxes = tf.cond(mirror_cond, lambda: mirror_bboxes, lambda: bboxes)
         return result, bboxes
@@ -463,6 +443,13 @@ def preprocess_for_train(image, labels, bboxes, out_shape, data_format='channels
     # Randomly flip the image horizontally.
     random_sample_flip_image, bboxes = random_flip_left_right(random_sample_image, bboxes)
     # Rescale to VGG input scale.
+    height, width, _ = _ImageDimensions(random_sample_flip_image, rank=3)
+    float_height, float_width = tf.to_float(height), tf.to_float(width)
+    ymin, xmin, ymax, xmax = tf.unstack(bboxes, 4, axis=-1)
+    target_height, target_width = tf.to_float(out_shape[0]), tf.to_float(out_shape[1])
+    ymin, ymax = ymin * target_height / float_height, ymax * target_height / float_height
+    xmin, xmax = xmin * target_width / float_width, xmax * target_width / float_width
+    bboxes = tf.stack([ymin, xmin, ymax, xmax], axis=-1)
     random_sample_flip_resized_image = tf.image.resize_images(random_sample_flip_image, out_shape, method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
     random_sample_flip_resized_image.set_shape([None, None, 3])
@@ -489,8 +476,12 @@ def preprocess_for_eval(image, out_shape, data_format='channels_first', scope='s
     """
     with tf.name_scope(scope, 'ssd_preprocessing_eval', [image]):
         image = tf.to_float(image)
-        image = tf.image.resize_images(image, out_shape, method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
-        image.set_shape(out_shape + [3])
+        if out_shape is not None:
+            image = tf.image.resize_images(image, out_shape, method=tf.image.ResizeMethod.BILINEAR, align_corners=False)
+            image.set_shape(out_shape + [3])
+
+        height, width, _ = _ImageDimensions(image, rank=3)
+        output_shape = tf.stack([height, width])
 
         image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
         if not output_rgb:
@@ -499,7 +490,7 @@ def preprocess_for_eval(image, out_shape, data_format='channels_first', scope='s
         # Image data format.
         if data_format == 'channels_first':
             image = tf.transpose(image, perm=(2, 0, 1))
-        return image
+        return image, output_shape
 
 def preprocess_image(image, labels, bboxes, out_shape, is_training=False, data_format='channels_first', output_rgb=True):
     """Preprocesses the given image.
diff --git a/simple_ssd_demo.py b/simple_ssd_demo.py
index cb64da6..2e9b102 100644
--- a/simple_ssd_demo.py
+++ b/simple_ssd_demo.py
@@ -29,6 +29,7 @@
 from preprocessing import ssd_preprocessing
 from utility import anchor_manipulator
 from utility import draw_toolbox
+from utility import bbox_util
 
 # scaffold related configuration
 tf.app.flags.DEFINE_integer(
@@ -46,7 +47,7 @@
 tf.app.flags.DEFINE_float(
     'select_threshold', 0.2, 'Class-specific confidence score threshold for selecting a box.')
 tf.app.flags.DEFINE_float(
-    'min_size', 0.03, 'The min size of bboxes to keep.')
+    'min_size', 4., 'The min size of bboxes to keep.')
 tf.app.flags.DEFINE_float(
     'nms_threshold', 0.45, 'Matching threshold in NMS algorithm.')
 tf.app.flags.DEFINE_integer(
@@ -72,78 +73,6 @@ def get_checkpoint():
 
     return checkpoint_path
 
-def select_bboxes(scores_pred, bboxes_pred, num_classes, select_threshold):
-    selected_bboxes = {}
-    selected_scores = {}
-    with tf.name_scope('select_bboxes', [scores_pred, bboxes_pred]):
-        for class_ind in range(1, num_classes):
-            class_scores = scores_pred[:, class_ind]
-
-            select_mask = class_scores > select_threshold
-            select_mask = tf.cast(select_mask, tf.float32)
-            selected_bboxes[class_ind] = tf.multiply(bboxes_pred, tf.expand_dims(select_mask, axis=-1))
-            selected_scores[class_ind] = tf.multiply(class_scores, select_mask)
-
-    return selected_bboxes, selected_scores
-
-def clip_bboxes(ymin, xmin, ymax, xmax, name):
-    with tf.name_scope(name, 'clip_bboxes', [ymin, xmin, ymax, xmax]):
-        ymin = tf.maximum(ymin, 0.)
-        xmin = tf.maximum(xmin, 0.)
-        ymax = tf.minimum(ymax, 1.)
-        xmax = tf.minimum(xmax, 1.)
-
-        ymin = tf.minimum(ymin, ymax)
-        xmin = tf.minimum(xmin, xmax)
-
-        return ymin, xmin, ymax, xmax
-
-def filter_bboxes(scores_pred, ymin, xmin, ymax, xmax, min_size, name):
-    with tf.name_scope(name, 'filter_bboxes', [scores_pred, ymin, xmin, ymax, xmax]):
-        width = xmax - xmin
-        height = ymax - ymin
-
-        filter_mask = tf.logical_and(width > min_size, height > min_size)
-
-        filter_mask = tf.cast(filter_mask, tf.float32)
-        return tf.multiply(ymin, filter_mask), tf.multiply(xmin, filter_mask), \
-               tf.multiply(ymax, filter_mask), tf.multiply(xmax, filter_mask), tf.multiply(scores_pred, filter_mask)
-
-def sort_bboxes(scores_pred, ymin, xmin, ymax, xmax, keep_topk, name):
-    with tf.name_scope(name, 'sort_bboxes', [scores_pred, ymin, xmin, ymax, xmax]):
-        cur_bboxes = tf.shape(scores_pred)[0]
-        scores, idxes = tf.nn.top_k(scores_pred, k=tf.minimum(keep_topk, cur_bboxes), sorted=True)
-
-        ymin, xmin, ymax, xmax = tf.gather(ymin, idxes), tf.gather(xmin, idxes), tf.gather(ymax, idxes), tf.gather(xmax, idxes)
-
-        paddings_scores = tf.expand_dims(tf.stack([0, tf.maximum(keep_topk-cur_bboxes, 0)], axis=0), axis=0)
-
-        return tf.pad(ymin, paddings_scores, "CONSTANT"), tf.pad(xmin, paddings_scores, "CONSTANT"),\
-               tf.pad(ymax, paddings_scores, "CONSTANT"), tf.pad(xmax, paddings_scores, "CONSTANT"),\
-               tf.pad(scores, paddings_scores, "CONSTANT")
-
-def nms_bboxes(scores_pred, bboxes_pred, nms_topk, nms_threshold, name):
-    with tf.name_scope(name, 'nms_bboxes', [scores_pred, bboxes_pred]):
-        idxes = tf.image.non_max_suppression(bboxes_pred, scores_pred, nms_topk, nms_threshold)
-        return tf.gather(scores_pred, idxes), tf.gather(bboxes_pred, idxes)
-
-def parse_by_class(cls_pred, bboxes_pred, num_classes, select_threshold, min_size, keep_topk, nms_topk, nms_threshold):
-    with tf.name_scope('select_bboxes', [cls_pred, bboxes_pred]):
-        scores_pred = tf.nn.softmax(cls_pred)
-        selected_bboxes, selected_scores = select_bboxes(scores_pred, bboxes_pred, num_classes, select_threshold)
-        for class_ind in range(1, num_classes):
-            ymin, xmin, ymax, xmax = tf.unstack(selected_bboxes[class_ind], 4, axis=-1)
-            #ymin, xmin, ymax, xmax = tf.squeeze(ymin), tf.squeeze(xmin), tf.squeeze(ymax), tf.squeeze(xmax)
-            ymin, xmin, ymax, xmax = clip_bboxes(ymin, xmin, ymax, xmax, 'clip_bboxes_{}'.format(class_ind))
-            ymin, xmin, ymax, xmax, selected_scores[class_ind] = filter_bboxes(selected_scores[class_ind],
-                                                                               ymin, xmin, ymax, xmax, min_size, 'filter_bboxes_{}'.format(class_ind))
-            ymin, xmin, ymax, xmax, selected_scores[class_ind] = sort_bboxes(selected_scores[class_ind],
-                                                                             ymin, xmin, ymax, xmax, keep_topk, 'sort_bboxes_{}'.format(class_ind))
-            selected_bboxes[class_ind] = tf.stack([ymin, xmin, ymax, xmax], axis=-1)
-            selected_scores[class_ind], selected_bboxes[class_ind] = nms_bboxes(selected_scores[class_ind], selected_bboxes[class_ind], nms_topk, nms_threshold, 'nms_bboxes_{}'.format(class_ind))
-
-    return selected_bboxes, selected_scores
-
 def main(_):
     with tf.Graph().as_default():
         out_shape = [FLAGS.train_image_size] * 2
@@ -151,28 +80,39 @@ def main(_):
         image_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
         shape_input = tf.placeholder(tf.int32, shape=(2,))
 
-        features = ssd_preprocessing.preprocess_for_eval(image_input, out_shape, data_format=FLAGS.data_format, output_rgb=False)
+        features, output_shape = ssd_preprocessing.preprocess_for_eval(image_input, out_shape, data_format=FLAGS.data_format, output_rgb=False)
         features = tf.expand_dims(features, axis=0)
+        output_shape = tf.expand_dims(output_shape, axis=0)
-        anchor_creator = anchor_manipulator.AnchorCreator(out_shape,
-                                                          layers_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)],
-                                                          anchor_scales = [(0.1,), (0.2,), (0.375,), (0.55,), (0.725,), (0.9,)],
-                                                          extra_anchor_scales = [(0.1414,), (0.2739,), (0.4541,), (0.6315,), (0.8078,), (0.9836,)],
-                                                          anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)],
-                                                          layer_steps = [8, 16, 32, 64, 100, 300])
-        all_anchors, all_num_anchors_depth, all_num_anchors_spatial = anchor_creator.get_all_anchors()
-
-        anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(allowed_borders = [1.0] * 6,
-                                                                  positive_threshold = None,
-                                                                  ignore_threshold = None,
-                                                                  prior_scaling=[0.1, 0.1, 0.2, 0.2])
-
-        decode_fn = lambda pred : anchor_encoder_decoder.ext_decode_all_anchors(pred, all_anchors, all_num_anchors_depth, all_num_anchors_spatial)
+        all_anchor_scales = [(30.,), (60.,), (112.5,), (165.,), (217.5,), (270.,)]
+        all_extra_scales = [(42.43,), (82.17,), (136.23,), (189.45,), (242.34,), (295.08,)]
+        all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]
 
         with tf.variable_scope(FLAGS.model_scope, default_name=None, values=[features], reuse=tf.AUTO_REUSE):
             backbone = ssd_net.VGG16Backbone(FLAGS.data_format)
             feature_layers = backbone.forward(features, training=False)
-            location_pred, cls_pred = ssd_net.multibox_head(feature_layers, FLAGS.num_classes, all_num_anchors_depth, data_format=FLAGS.data_format)
+            with tf.device('/cpu:0'):
+                anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(positive_threshold=None, ignore_threshold=None, prior_scaling=[0.1, 0.1, 0.2, 0.2])
+
+                if FLAGS.data_format == 'channels_first':
+                    all_layer_shapes = [tf.shape(feat)[2:] for feat in feature_layers]
+                else:
+                    all_layer_shapes = [tf.shape(feat)[1:3] for feat in feature_layers]
+                all_layer_strides = [8, 16, 32, 64, 100, 300]
+                total_layers = len(all_layer_shapes)
+                anchors_height = list()
+                anchors_width = list()
+                anchors_depth = list()
+                for ind in range(total_layers):
+                    _anchors_height, _anchors_width, _anchor_depth = anchor_encoder_decoder.get_anchors_width_height(all_anchor_scales[ind], all_extra_scales[ind], all_anchor_ratios[ind], name='get_anchors_width_height{}'.format(ind))
+                    anchors_height.append(_anchors_height)
+                    anchors_width.append(_anchors_width)
+                    anchors_depth.append(_anchor_depth)
+                anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, _ = anchor_encoder_decoder.get_all_anchors(tf.squeeze(output_shape, axis=0),
+                                                                                                                   anchors_height, anchors_width, anchors_depth,
+                                                                                                                   [0.5] * total_layers, all_layer_shapes, all_layer_strides,
+                                                                                                                   [0.] * total_layers, [False] * total_layers)
+            location_pred, cls_pred = ssd_net.multibox_head(feature_layers, FLAGS.num_classes, anchors_depth, data_format=FLAGS.data_format)
             if FLAGS.data_format == 'channels_first':
                 cls_pred = [tf.transpose(pred, [0, 2, 3, 1]) for pred in cls_pred]
                 location_pred = [tf.transpose(pred, [0, 2, 3, 1]) for pred in location_pred]
@@ -184,9 +124,8 @@ def main(_):
         location_pred = tf.concat(location_pred, axis=0)
 
         with tf.device('/cpu:0'):
-            bboxes_pred = decode_fn(location_pred)
-            bboxes_pred = tf.concat(bboxes_pred, axis=0)
-            selected_bboxes, selected_scores = parse_by_class(cls_pred, bboxes_pred,
+            bboxes_pred = anchor_encoder_decoder.decode_anchors(location_pred, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax)
+            selected_bboxes, selected_scores = bbox_util.parse_by_class(tf.squeeze(output_shape, axis=0), cls_pred, bboxes_pred,
                                                             FLAGS.num_classes, FLAGS.select_threshold, FLAGS.min_size,
                                                             FLAGS.keep_topk, FLAGS.nms_topk, FLAGS.nms_threshold)
@@ -209,7 +148,11 @@ def main(_):
             saver.restore(sess, get_checkpoint())
 
             np_image = imread('./demo/test.jpg')
-            labels_, scores_, bboxes_ = sess.run([all_labels, all_scores, all_bboxes], feed_dict = {image_input : np_image, shape_input : np_image.shape[:-1]})
+            labels_, scores_, bboxes_, output_shape_ = sess.run([all_labels, all_scores, all_bboxes, output_shape], feed_dict = {image_input : np_image, shape_input : np_image.shape[:-1]})
+            bboxes_[:, 0] = bboxes_[:, 0] * np_image.shape[0] / output_shape_[0, 0]
+            bboxes_[:, 1] = bboxes_[:, 1] * np_image.shape[1] / output_shape_[0, 1]
+            bboxes_[:, 2] = bboxes_[:, 2] * np_image.shape[0] / output_shape_[0, 0]
+            bboxes_[:, 3] = bboxes_[:, 3] * np_image.shape[1] / output_shape_[0, 1]
 
             img_to_draw = draw_toolbox.bboxes_draw_on_img(np_image, labels_, scores_, bboxes_, thickness=2)
             imsave('./demo/test_out.jpg', img_to_draw)
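Editor's note: decode_anchors with prior_scaling=[0.1, 0.1, 0.2, 0.2] presumably performs the standard SSD box decoding. A self-contained sketch of that math (generic SSD convention, not necessarily identical to the repo's implementation):

import math

def decode_one(pred, anchor, prior_scaling=(0.1, 0.1, 0.2, 0.2)):
    # pred = (dy, dx, dh, dw) regressed offsets; anchor = (ymin, xmin, ymax, xmax)
    a_h = anchor[2] - anchor[0]
    a_w = anchor[3] - anchor[1]
    a_cy = anchor[0] + a_h / 2.
    a_cx = anchor[1] + a_w / 2.
    cy = pred[0] * prior_scaling[0] * a_h + a_cy
    cx = pred[1] * prior_scaling[1] * a_w + a_cx
    h = math.exp(pred[2] * prior_scaling[2]) * a_h
    w = math.exp(pred[3] * prior_scaling[3]) * a_w
    return (cy - h / 2., cx - w / 2., cy + h / 2., cx + w / 2.)

# zero offsets decode back to the anchor itself:
print(decode_one((0., 0., 0., 0.), (10., 10., 40., 70.)))  # (10.0, 10.0, 40.0, 70.0)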
diff --git a/train_ssd.py b/train_ssd.py
index aaab83c..fcf12fa 100644
--- a/train_ssd.py
+++ b/train_ssd.py
@@ -56,8 +56,11 @@
     'save_summary_steps', 500,
     'The frequency with which summaries are saved, in seconds.')
 tf.app.flags.DEFINE_integer(
-    'save_checkpoints_secs', 7200,
+    'save_checkpoints_secs', 7200, # not used
    'The frequency with which the model is saved, in seconds.')
+tf.app.flags.DEFINE_integer(
+    'save_checkpoints_steps', 20000,
+    'The frequency with which the model is saved, in steps.')
 # model related configuration
 tf.app.flags.DEFINE_integer(
     'train_image_size', 300,
@@ -165,26 +168,37 @@ def get_init_fn():
 
 def input_pipeline(dataset_pattern='train-*', is_training=True, batch_size=FLAGS.batch_size):
     def input_fn():
-        out_shape = [FLAGS.train_image_size] * 2
-        anchor_creator = anchor_manipulator.AnchorCreator(out_shape,
-                                                          layers_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)],
-                                                          anchor_scales = [(0.1,), (0.2,), (0.375,), (0.55,), (0.725,), (0.9,)],
-                                                          extra_anchor_scales = [(0.1414,), (0.2739,), (0.4541,), (0.6315,), (0.8078,), (0.9836,)],
-                                                          anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)],
-                                                          layer_steps = [8, 16, 32, 64, 100, 300])
-        all_anchors, all_num_anchors_depth, all_num_anchors_spatial = anchor_creator.get_all_anchors()
-
-        num_anchors_per_layer = []
-        for ind in range(len(all_anchors)):
-            num_anchors_per_layer.append(all_num_anchors_depth[ind] * all_num_anchors_spatial[ind])
-
-        anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(allowed_borders = [1.0] * 6,
-                                                                  positive_threshold = FLAGS.match_threshold,
-                                                                  ignore_threshold = FLAGS.neg_threshold,
-                                                                  prior_scaling=[0.1, 0.1, 0.2, 0.2])
-
-        image_preprocessing_fn = lambda image_, labels_, bboxes_ : ssd_preprocessing.preprocess_image(image_, labels_, bboxes_, out_shape, is_training=is_training, data_format=FLAGS.data_format, output_rgb=False)
-        anchor_encoder_fn = lambda glabels_, gbboxes_: anchor_encoder_decoder.encode_all_anchors(glabels_, gbboxes_, all_anchors, all_num_anchors_depth, all_num_anchors_spatial)
+        target_shape = [FLAGS.train_image_size] * 2
+
+        anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(positive_threshold = FLAGS.match_threshold,
+                                                                  ignore_threshold = FLAGS.neg_threshold,
+                                                                  prior_scaling=[0.1, 0.1, 0.2, 0.2])
+
+        all_anchor_scales = [(30.,), (60.,), (112.5,), (165.,), (217.5,), (270.,)]
+        all_extra_scales = [(42.43,), (82.17,), (136.23,), (189.45,), (242.34,), (295.08,)]
+        all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]
+        all_layer_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)]
+        all_layer_strides = [8, 16, 32, 64, 100, 300]
+        total_layers = len(all_layer_shapes)
+        anchors_height = list()
+        anchors_width = list()
+        anchors_depth = list()
+        for ind in range(total_layers):
+            _anchors_height, _anchors_width, _anchor_depth = anchor_encoder_decoder.get_anchors_width_height(all_anchor_scales[ind], all_extra_scales[ind], all_anchor_ratios[ind], name='get_anchors_width_height{}'.format(ind))
+            anchors_height.append(_anchors_height)
+            anchors_width.append(_anchors_width)
+            anchors_depth.append(_anchor_depth)
+        anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, inside_mask = anchor_encoder_decoder.get_all_anchors(target_shape, anchors_height, anchors_width, anchors_depth,
+                                                                                                                     [0.5] * total_layers, all_layer_shapes, all_layer_strides,
+                                                                                                                     [FLAGS.train_image_size * 1.] * total_layers, [False] * total_layers)
+
+        num_anchors_per_layer = list()
+        for ind, layer_shape in enumerate(all_layer_shapes):
+            _, _num_anchors_per_layer = anchor_encoder_decoder.get_anchors_count(anchors_depth[ind], layer_shape, name='get_anchor_count{}'.format(ind))
+            num_anchors_per_layer.append(_num_anchors_per_layer)
+
+        image_preprocessing_fn = lambda image_, labels_, bboxes_ : ssd_preprocessing.preprocess_image(image_, labels_, bboxes_, target_shape, is_training=is_training, data_format=FLAGS.data_format, output_rgb=False)
+        anchor_encoder_fn = lambda glabels_, gbboxes_: anchor_encoder_decoder.encode_anchors(glabels_, gbboxes_, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, inside_mask)
 
         image, _, shape, loc_targets, cls_targets, match_scores = dataset_common.slim_get_batch(FLAGS.num_classes,
                                                                                 batch_size,
@@ -197,9 +211,9 @@ def input_fn():
                                                                                 num_epochs=FLAGS.train_epochs,
                                                                                 is_training=is_training)
         global global_anchor_info
-        global_anchor_info = {'decode_fn': lambda pred : anchor_encoder_decoder.decode_all_anchors(pred, num_anchors_per_layer),
+        global_anchor_info = {'decode_fn': lambda pred : anchor_encoder_decoder.batch_decode_anchors(pred, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax),
                             'num_anchors_per_layer': num_anchors_per_layer,
-                            'all_num_anchors_depth': all_num_anchors_depth }
+                            'all_num_anchors_depth': anchors_depth }
 
         return image, {'shape': shape, 'loc_targets': loc_targets, 'cls_targets': cls_targets, 'match_scores': match_scores}
     return input_fn
@@ -288,12 +302,9 @@ def ssd_model_fn(features, labels, mode, params):
     with tf.control_dependencies([cls_pred, location_pred]):
         with tf.name_scope('post_forward'):
             #bboxes_pred = decode_fn(location_pred)
-            bboxes_pred = tf.map_fn(lambda _preds : decode_fn(_preds),
-                                    tf.reshape(location_pred, [tf.shape(features)[0], -1, 4]),
-                                    dtype=[tf.float32] * len(num_anchors_per_layer), back_prop=False)
+            bboxes_pred = decode_fn(tf.reshape(location_pred, [tf.shape(features)[0], -1, 4]))
             #cls_targets = tf.Print(cls_targets, [tf.shape(bboxes_pred[0]),tf.shape(bboxes_pred[1]),tf.shape(bboxes_pred[2]),tf.shape(bboxes_pred[3])])
-            bboxes_pred = [tf.reshape(preds, [-1, 4]) for preds in bboxes_pred]
-            bboxes_pred = tf.concat(bboxes_pred, axis=0)
+            bboxes_pred = tf.reshape(bboxes_pred, [-1, 4])
 
             flaten_cls_targets = tf.reshape(cls_targets, [-1])
             flaten_match_scores = tf.reshape(match_scores, [-1])
@@ -308,8 +319,8 @@ def ssd_model_fn(features, labels, mode, params):
         batch_negtive_mask = tf.equal(cls_targets, 0)#tf.logical_and(tf.equal(cls_targets, 0), match_scores > 0.)
         batch_n_negtives = tf.count_nonzero(batch_negtive_mask, -1)
 
-        batch_n_neg_select = tf.cast(params['negative_ratio'] * tf.cast(batch_n_positives, tf.float32), tf.int32)
-        batch_n_neg_select = tf.minimum(batch_n_neg_select, tf.cast(batch_n_negtives, tf.int32))
+        batch_n_neg_select = tf.to_int32(params['negative_ratio'] * tf.to_float(batch_n_positives))
+        batch_n_neg_select = tf.minimum(batch_n_neg_select, tf.to_int32(batch_n_negtives))
 
         # hard negative mining for classification
         predictions_for_bg = tf.nn.softmax(tf.reshape(cls_pred, [tf.shape(features)[0], -1, params['num_classes']]))[:, :, 0]
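Note: the hunk above caps the number of negatives at negative_ratio times the positives per image, then (in the surrounding unchanged code) keeps the negatives with the lowest background probability. A self-contained sketch of that selection for a single image, under assumed shapes (this helper is not from the repo):

import tensorflow as tf

def select_hard_negatives(bg_prob, positive_mask, negative_ratio=3.):
    # bg_prob: [num_anchors] softmax probability of the background class
    # positive_mask: [num_anchors] bool, True for matched (positive) anchors
    n_positives = tf.count_nonzero(positive_mask, dtype=tf.int32)
    n_neg_select = tf.to_int32(negative_ratio * tf.to_float(n_positives))
    # the real code also caps n_neg_select by the number of available negatives
    # positives must not compete as negatives, so give them the highest bg score
    masked_prob = tf.where(positive_mask, tf.ones_like(bg_prob), bg_prob)
    # the hardest negatives are the ones with the *lowest* background probability
    _, neg_indices = tf.nn.top_k(-masked_prob, k=n_neg_select)
    return neg_indices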
@@ -416,8 +427,8 @@ def main(_):
     # Set up a RunConfig to only save checkpoints once per training cycle.
     run_config = tf.estimator.RunConfig().replace(
-                                        save_checkpoints_secs=FLAGS.save_checkpoints_secs).replace(
-                                        save_checkpoints_steps=None).replace(
+                                        save_checkpoints_secs=None).replace(
+                                        save_checkpoints_steps=FLAGS.save_checkpoints_steps).replace(
                                         save_summary_steps=FLAGS.save_summary_steps).replace(
                                         keep_checkpoint_max=5).replace(
                                         tf_random_seed=FLAGS.tf_random_seed).replace(
diff --git a/utility/anchor_manipulator.py b/utility/anchor_manipulator.py
index a5d14c7..24069ec 100644
--- a/utility/anchor_manipulator.py
+++ b/utility/anchor_manipulator.py
@@ -19,10 +19,11 @@
 from tensorflow.contrib.image.python.ops import image_ops
 
+
 def areas(gt_bboxes):
     with tf.name_scope('bboxes_areas', [gt_bboxes]):
         ymin, xmin, ymax, xmax = tf.split(gt_bboxes, 4, axis=1)
-        return (xmax - xmin) * (ymax - ymin)
+        return (xmax - xmin + 1.) * (ymax - ymin + 1.)
 
 def intersection(gt_bboxes, default_bboxes):
     with tf.name_scope('bboxes_intersection', [gt_bboxes, default_bboxes]):
@@ -35,66 +36,71 @@ def intersection(gt_bboxes, default_bboxes):
         int_xmin = tf.maximum(xmin, gt_xmin)
         int_ymax = tf.minimum(ymax, gt_ymax)
         int_xmax = tf.minimum(xmax, gt_xmax)
-        h = tf.maximum(int_ymax - int_ymin, 0.)
-        w = tf.maximum(int_xmax - int_xmin, 0.)
+        h = tf.maximum(int_ymax - int_ymin + 1., 0.)
+        w = tf.maximum(int_xmax - int_xmin + 1., 0.)
 
         return h * w
 
 def iou_matrix(gt_bboxes, default_bboxes):
     with tf.name_scope('iou_matrix', [gt_bboxes, default_bboxes]):
         inter_vol = intersection(gt_bboxes, default_bboxes)
         # broadcast
-        union_vol = areas(gt_bboxes) + tf.transpose(areas(default_bboxes), perm=[1, 0]) - inter_vol
+        areas_gt = areas(gt_bboxes)
+        union_vol = areas_gt + tf.transpose(areas(default_bboxes), perm=[1, 0]) - inter_vol
 
-        return tf.where(tf.equal(union_vol, 0.0),
-                        tf.zeros_like(inter_vol), tf.truediv(inter_vol, union_vol))
+        #areas_gt = tf.Print(areas_gt, [areas_gt], summarize=100)
+        return tf.where(tf.equal(union_vol, 0.0), tf.zeros_like(inter_vol), tf.truediv(inter_vol, union_vol))
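Note: the `+ 1.` terms switch `areas` and `intersection` to the endpoint-inclusive pixel convention used by the new absolute-coordinate anchors: a box spanning columns 0 through 9 is 10 pixels wide. A worked example of the resulting IoU:

# Two 10x10 boxes, [ymin, xmin, ymax, xmax] = [0, 0, 9, 9] and [5, 5, 14, 14]:
#   area         = (9 - 0 + 1) * (9 - 0 + 1) = 100 each
#   intersection = (9 - 5 + 1) * (9 - 5 + 1) = 25
#   union        = 100 + 100 - 25 = 175
#   IoU          = 25 / 175 ≈ 0.143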
 def do_dual_max_match(overlap_matrix, low_thres, high_thres, ignore_between=True, gt_max_first=True):
-    '''
-    overlap_matrix: num_gt * num_anchors
+    '''do_dual_max_match, but using the transposed overlap matrix; this may be faster due to cache friendliness
+
+    Args:
+        overlap_matrix: num_anchors * num_gt
     '''
     with tf.name_scope('dual_max_match', [overlap_matrix]):
         # first match from anchors' side
-        anchors_to_gt = tf.argmax(overlap_matrix, axis=0)
+        anchors_to_gt = tf.argmax(overlap_matrix, axis=1)
         # the matching degree
-        match_values = tf.reduce_max(overlap_matrix, axis=0)
+        match_values = tf.reduce_max(overlap_matrix, axis=1)
 
         #positive_mask = tf.greater(match_values, high_thres)
         less_mask = tf.less(match_values, low_thres)
         between_mask = tf.logical_and(tf.less(match_values, high_thres), tf.greater_equal(match_values, low_thres))
         negative_mask = less_mask if ignore_between else between_mask
         ignore_mask = between_mask if ignore_between else less_mask
+        # comment following two lines
+        # over_pos_mask = tf.greater(match_values, 0.7)
+        # ignore_mask = tf.logical_or(ignore_mask, over_pos_mask)
         # fill all negative positions with -1, all ignore positions is -2
         match_indices = tf.where(negative_mask, -1 * tf.ones_like(anchors_to_gt), anchors_to_gt)
         match_indices = tf.where(ignore_mask, -2 * tf.ones_like(match_indices), match_indices)
 
         # negtive values has no effect in tf.one_hot, that means all zeros along that axis
         # so all positive match positions in anchors_to_gt_mask is 1, all others are 0
-        anchors_to_gt_mask = tf.one_hot(tf.clip_by_value(match_indices, -1, tf.cast(tf.shape(overlap_matrix)[0], tf.int64)),
-                                        tf.shape(overlap_matrix)[0], on_value=1, off_value=0, axis=0, dtype=tf.int32)
+        anchors_to_gt_mask = tf.one_hot(tf.clip_by_value(match_indices, -1, tf.cast(tf.shape(overlap_matrix)[1], tf.int64)),
+                                        tf.shape(overlap_matrix)[1], on_value=1, off_value=0, axis=1, dtype=tf.int32)
         # match from ground truth's side
-        gt_to_anchors = tf.argmax(overlap_matrix, axis=1)
+        gt_to_anchors = tf.argmax(overlap_matrix, axis=0)
+        gt_to_anchors_overlap = tf.reduce_max(overlap_matrix, axis=0, keepdims=True)
 
-        if gt_max_first:
-            # the max match from ground truth's side has higher priority
-            left_gt_to_anchors_mask = tf.one_hot(gt_to_anchors, tf.shape(overlap_matrix)[1], on_value=1, off_value=0, axis=1, dtype=tf.int32)
-        else:
+        #gt_to_anchors = tf.Print(gt_to_anchors, [tf.equal(overlap_matrix, gt_to_anchors_overlap)], message='gt_to_anchors_indices:', summarize=100)
+        # the max match from ground truth's side has higher priority
+        left_gt_to_anchors_mask = tf.equal(overlap_matrix, gt_to_anchors_overlap)#tf.one_hot(gt_to_anchors, tf.shape(overlap_matrix)[0], on_value=True, off_value=False, axis=0, dtype=tf.bool)
+        if not gt_max_first:
             # the max match from anchors' side has higher priority
             # use match result from ground truth's side only when the the matching degree from anchors' side is lower than position threshold
-            left_gt_to_anchors_mask = tf.cast(tf.logical_and(tf.reduce_max(anchors_to_gt_mask, axis=1, keep_dims=True) < 1,
-                                                            tf.one_hot(gt_to_anchors, tf.shape(overlap_matrix)[1],
-                                                                        on_value=True, off_value=False, axis=1, dtype=tf.bool)
-                                                            ), tf.int64)
+            left_gt_to_anchors_mask = tf.logical_and(tf.reduce_max(anchors_to_gt_mask, axis=0, keep_dims=True) < 1, left_gt_to_anchors_mask)
         # can not use left_gt_to_anchors_mask here, because there are many ground truthes match to one anchor, we should pick the highest one even when we are merging matching from ground truth side
+        left_gt_to_anchors_mask = tf.to_int64(left_gt_to_anchors_mask)
         left_gt_to_anchors_scores = overlap_matrix * tf.to_float(left_gt_to_anchors_mask)
         # merge matching results from ground truth's side with the original matching results from anchors' side
         # then select all the overlap score of those matching pairs
-        selected_scores = tf.gather_nd(overlap_matrix, tf.stack([tf.where(tf.reduce_max(left_gt_to_anchors_mask, axis=0) > 0,
-                                                                        tf.argmax(left_gt_to_anchors_scores, axis=0),
-                                                                        anchors_to_gt),
-                                                                tf.range(tf.cast(tf.shape(overlap_matrix)[1], tf.int64))], axis=1))
+        selected_scores = tf.gather_nd(overlap_matrix, tf.stack([tf.range(tf.cast(tf.shape(overlap_matrix)[0], tf.int64)),
+                                                                tf.where(tf.reduce_max(left_gt_to_anchors_mask, axis=1) > 0,
+                                                                        tf.argmax(left_gt_to_anchors_scores, axis=1),
+                                                                        anchors_to_gt)], axis=1))
         # return the matching results for both foreground anchors and background anchors, also with overlap scores
-        return tf.where(tf.reduce_max(left_gt_to_anchors_mask, axis=0) > 0,
-                        tf.argmax(left_gt_to_anchors_scores, axis=0),
+        return tf.where(tf.reduce_max(left_gt_to_anchors_mask, axis=1) > 0,
+                        tf.argmax(left_gt_to_anchors_scores, axis=1),
                         match_indices), selected_scores
 
 # def save_anchors(bboxes, labels, anchors_point):
@@ -108,83 +114,187 @@ def do_dual_max_match(overlap_matrix, low_thres, high_thres, ignore_between=True
 #     return save_image_with_bbox.counter
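Note: a small worked example of the dual matching rule above (hand-computed, assuming the default gt_max_first=True):

# Suppose 2 ground truths and 4 anchors with IoU (rows = anchors, cols = gt):
#   A0: [0.10, 0.05]   A1: [0.55, 0.20]   A2: [0.45, 0.30]   A3: [0.25, 0.48]
# With low_thres = 0.3 and high_thres = 0.5:
#   anchors' side:  A1 -> gt0 (0.55 >= 0.5, positive), A0 -> negative (-1),
#                   A2, A3 -> ignored (-2, best IoU falls in [0.3, 0.5))
#   gt side:        gt1's best anchor is A3 (0.48), so A3 is forced to be a
#                   positive match for gt1 even though 0.48 < high_thres.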
 class AnchorEncoder(object):
-    def __init__(self, allowed_borders, positive_threshold, ignore_threshold, prior_scaling, clip=False):
+    def __init__(self, positive_threshold, ignore_threshold, prior_scaling):
         super(AnchorEncoder, self).__init__()
         self._all_anchors = None
-        self._allowed_borders = allowed_borders
         self._positive_threshold = positive_threshold
         self._ignore_threshold = ignore_threshold
         self._prior_scaling = prior_scaling
-        self._clip = clip
 
     def center2point(self, center_y, center_x, height, width):
-        return center_y - height / 2., center_x - width / 2., center_y + height / 2., center_x + width / 2.,
+        with tf.name_scope('center2point'):
+            return center_y - (height - 1.) / 2., center_x - (width - 1.) / 2., center_y + (height - 1.) / 2., center_x + (width - 1.) / 2.,
 
     def point2center(self, ymin, xmin, ymax, xmax):
-        height, width = (ymax - ymin), (xmax - xmin)
-        return ymin + height / 2., xmin + width / 2., height, width
-
-    def encode_all_anchors(self, labels, bboxes, all_anchors, all_num_anchors_depth, all_num_anchors_spatial, debug=False):
-        # y, x, h, w are all in range [0, 1] relative to the original image size
-        # shape info:
-        # y_on_image, x_on_image: layers_shapes[0] * layers_shapes[1]
-        # h_on_image, w_on_image: num_anchors
-        assert (len(all_num_anchors_depth)==len(all_num_anchors_spatial)) and (len(all_num_anchors_depth)==len(all_anchors)), 'inconsist num layers for anchors.'
-        with tf.name_scope('encode_all_anchors'):
-            num_layers = len(all_num_anchors_depth)
-            list_anchors_ymin = []
-            list_anchors_xmin = []
-            list_anchors_ymax = []
-            list_anchors_xmax = []
-            tiled_allowed_borders = []
-            for ind, anchor in enumerate(all_anchors):
-                anchors_ymin_, anchors_xmin_, anchors_ymax_, anchors_xmax_ = self.center2point(anchor[0], anchor[1], anchor[2], anchor[3])
-
-                list_anchors_ymin.append(tf.reshape(anchors_ymin_, [-1]))
-                list_anchors_xmin.append(tf.reshape(anchors_xmin_, [-1]))
-                list_anchors_ymax.append(tf.reshape(anchors_ymax_, [-1]))
-                list_anchors_xmax.append(tf.reshape(anchors_xmax_, [-1]))
-
-                tiled_allowed_borders.extend([self._allowed_borders[ind]] * all_num_anchors_depth[ind] * all_num_anchors_spatial[ind])
-
-            anchors_ymin = tf.concat(list_anchors_ymin, 0, name='concat_ymin')
-            anchors_xmin = tf.concat(list_anchors_xmin, 0, name='concat_xmin')
-            anchors_ymax = tf.concat(list_anchors_ymax, 0, name='concat_ymax')
-            anchors_xmax = tf.concat(list_anchors_xmax, 0, name='concat_xmax')
-
-            if self._clip:
-                anchors_ymin = tf.clip_by_value(anchors_ymin, 0., 1.)
-                anchors_xmin = tf.clip_by_value(anchors_xmin, 0., 1.)
-                anchors_ymax = tf.clip_by_value(anchors_ymax, 0., 1.)
-                anchors_xmax = tf.clip_by_value(anchors_xmax, 0., 1.)
-
-            anchor_allowed_borders = tf.stack(tiled_allowed_borders, 0, name='concat_allowed_borders')
-
-            inside_mask = tf.logical_and(tf.logical_and(anchors_ymin > -anchor_allowed_borders * 1.,
-                                                        anchors_xmin > -anchor_allowed_borders * 1.),
-                                        tf.logical_and(anchors_ymax < (1. + anchor_allowed_borders * 1.),
-                                                        anchors_xmax < (1. + anchor_allowed_borders * 1.)))
-
-            anchors_point = tf.stack([anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax], axis=-1)
-
-            # save_anchors_op = tf.py_func(save_anchors,
-            #                 [bboxes,
-            #                 labels,
-            #                 anchors_point],
-            #                 tf.int64, stateful=True)
-
-            # with tf.control_dependencies([save_anchors_op]):
-            overlap_matrix = iou_matrix(bboxes, anchors_point) * tf.cast(tf.expand_dims(inside_mask, 0), tf.float32)
+        with tf.name_scope('point2center'):
+            height, width = (ymax - ymin + 1.), (xmax - xmin + 1.)
+            return (ymin + ymax) / 2., (xmin + xmax) / 2., height, width
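Note: with the inclusive-pixel convention the two conversions are exact inverses; a quick hand-computed round trip:

# point2center(0, 0, 9, 19)       -> cy=4.5, cx=9.5, h=10, w=20   (h = 9 - 0 + 1)
# center2point(4.5, 9.5, 10, 20)  -> (0.0, 0.0, 9.0, 19.0)        (the original box)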
+
+    def get_anchors_width_height(self, anchor_scale, extra_anchor_scale, anchor_ratio, name=None):
+        '''get_anchors_width_height
+
+        Given scales and ratios, generate the anchor widths and heights along the depth dimension (use absolute scales in input-image pixels)
+        Args:
+            anchor_scale: base scales of the window size used to transform the anchors; each scale is paired with every ratio in 'anchor_ratio'
+            extra_anchor_scale: base scales of the window size used to transform the anchors; each scale uses a 1:1 ratio
+            anchor_ratio: all the aspect ratios of the anchors for each scale in 'anchor_scale'
+        '''
+        with tf.name_scope(name, 'get_anchors_width_height'):
+            all_num_anchors_depth = len(anchor_scale) * len(anchor_ratio) + len(extra_anchor_scale)
+
+            list_h_on_image = []
+            list_w_on_image = []
+
+            # for square anchors
+            for _, scale in enumerate(extra_anchor_scale):
+                list_h_on_image.append(scale)
+                list_w_on_image.append(scale)
+            # for other aspect ratio anchors
+            for scale_index, scale in enumerate(anchor_scale):
+                for ratio_index, ratio in enumerate(anchor_ratio):
+                    list_h_on_image.append(scale / math.sqrt(ratio))
+                    list_w_on_image.append(scale * math.sqrt(ratio))
+            # shape info:
+            # h_on_image, w_on_image: num_anchors_along_depth
+            return tf.constant(list_h_on_image, dtype=tf.float32), tf.constant(list_w_on_image, dtype=tf.float32), all_num_anchors_depth
+
+    def generate_anchors_by_offset(self, anchors_height, anchors_width, anchor_depth, image_shape, layer_shape, feat_stride, offset=0.5, name=None):
+        '''generate_anchors_by_offset
+
+        Given anchor widths and heights, generate anchors tiled across the 'layer_shape'
+        Args:
+            anchors_height, anchors_width, anchor_depth: generated by the above function 'get_anchors_width_height'
+            image_shape: the input image size, since we generate anchors in absolute coordinates, [height, width]
+            layer_shape: the size of the layer on which the anchors are tiled, [height, width]
+            feat_stride: the stride from the input image to the layer on which the anchors are generated
+            offset: the (height, width) offset in the feature map used when tiling anchors; either a single scalar or a list of two scalars
+        '''
+        with tf.name_scope(name, 'generate_anchors'):
+            image_height, image_width, feat_stride = tf.to_float(image_shape[0]), tf.to_float(image_shape[1]), tf.to_float(feat_stride)
+
+            x_on_layer, y_on_layer = tf.meshgrid(tf.range(layer_shape[1]), tf.range(layer_shape[0]))
+
+            if isinstance(offset, list):
+                tf.logging.info('{}: Using separate offset: height: {}, width: {}.'.format(name, offset[0], offset[1]))
+                offset_h = offset[0]
+                offset_w = offset[1]
+            else:
+                offset_h = offset
+                offset_w = offset
+            y_on_image = (tf.to_float(y_on_layer) + offset_h) * feat_stride
+            x_on_image = (tf.to_float(x_on_layer) + offset_w) * feat_stride
+
+            anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax = self.center2point(tf.expand_dims(y_on_image, axis=-1),
+                                                                                    tf.expand_dims(x_on_image, axis=-1),
+                                                                                    anchors_height, anchors_width)
+
+            anchors_ymin = tf.reshape(anchors_ymin, [-1, anchor_depth])
+            anchors_xmin = tf.reshape(anchors_xmin, [-1, anchor_depth])
+            anchors_ymax = tf.reshape(anchors_ymax, [-1, anchor_depth])
+            anchors_xmax = tf.reshape(anchors_xmax, [-1, anchor_depth])
+
+            return anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax
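Note: for the first layer configured in train_ssd.py above (scale 30, extra square scale 42.43, ratios 2 and 0.5) this yields an anchor depth of 3; a minimal sketch of the resulting sizes and tiling:

import math

# Layer 1: anchor_scale=(30.,), extra_anchor_scale=(42.43,), anchor_ratio=(2., .5)
heights = [42.43] + [30. / math.sqrt(r) for r in (2., .5)]
widths  = [42.43] + [30. * math.sqrt(r) for r in (2., .5)]
# heights ≈ [42.43, 21.21, 42.43], widths ≈ [42.43, 42.43, 21.21]
# With feat_stride=8 and offset=0.5, the anchor centers sit at 4, 12, 20, ... pixels.
# (The extra scales are roughly the geometric means of adjacent scales,
#  e.g. 42.43 ≈ sqrt(30 * 60).)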
+
+    def get_anchors_count(self, anchors_depth, layer_shape, name=None):
+        '''get_anchors_count
+
+        Return the total number of anchors on a specific layer
+        Args:
+            anchors_depth: generated by the above function 'get_anchors_width_height'
+            layer_shape: the size of the layer on which the anchors are tiled, [height, width]
+        '''
+        with tf.name_scope(name, 'get_anchors_count'):
+            all_num_anchors_spatial = layer_shape[0] * layer_shape[1]
+            all_num_anchors = all_num_anchors_spatial * anchors_depth
+            return all_num_anchors_spatial, all_num_anchors
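Note: with the configuration used in train_ssd.py above, each layer's depth is len(scales) * len(ratios) + len(extra scales), i.e. 3, 5, 5, 5, 3 and 3; the grand total then follows by arithmetic:

anchors_depth    = [3, 5, 5, 5, 3, 3]
all_layer_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)]
total = sum(h * w * d for (h, w), d in zip(all_layer_shapes, anchors_depth))
# 38*38*3 + 19*19*5 + 10*10*5 + 5*5*5 + 3*3*3 + 1*1*3 = 6792 anchors in total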
+
+    def get_all_anchors(self, image_shape, anchors_height, anchors_width, anchors_depth, anchors_offsets, layer_shapes, feat_strides, allowed_borders, should_clips, name=None):
+        '''get_all_anchors
+
+        Return all the anchors from all layers
+        Args:
+            image_shape: the input image size, since we generate anchors in absolute coordinates, [height, width]
+            anchors_height: list, each element generated by the above function 'get_anchors_width_height'
+            anchors_width: list, each element generated by the above function 'get_anchors_width_height'
+            anchors_depth: list, each element generated by the above function 'get_anchors_width_height'
+            anchors_offsets: list, each element used by 'generate_anchors_by_offset'
+            layer_shapes: list, each element used by 'generate_anchors_by_offset'
+            feat_strides: list, each element used by 'generate_anchors_by_offset'
+            allowed_borders: list, each element is the border margin within which out-of-image anchors are still kept for that layer
+            should_clips: list, each element indicates whether the anchors of that layer should be clipped to the image border
+        '''
+        with tf.name_scope(name, 'get_all_anchors'):
+            image_height, image_width = tf.to_float(image_shape[0]), tf.to_float(image_shape[1])
+
+            anchors_ymin = []
+            anchors_xmin = []
+            anchors_ymax = []
+            anchors_xmax = []
+            anchor_allowed_borders = []
+            for ind, anchor_depth in enumerate(anchors_depth):
+                with tf.name_scope('generate_anchors_{}'.format(ind)):
+                    _anchors_ymin, _anchors_xmin, _anchors_ymax, _anchors_xmax = self.generate_anchors_by_offset(anchors_height[ind], anchors_width[ind], anchor_depth, image_shape, layer_shapes[ind], feat_strides[ind], offset=anchors_offsets[ind])
+
+                    if should_clips[ind]:
+                        _anchors_ymin = tf.clip_by_value(_anchors_ymin, 0., image_height - 1.)
+                        _anchors_xmin = tf.clip_by_value(_anchors_xmin, 0., image_width - 1.)
+                        _anchors_ymax = tf.clip_by_value(_anchors_ymax, 0., image_height - 1.)
+                        _anchors_xmax = tf.clip_by_value(_anchors_xmax, 0., image_width - 1.)
+
+                    _anchors_ymin = tf.reshape(_anchors_ymin, [-1])
+                    _anchors_xmin = tf.reshape(_anchors_xmin, [-1])
+                    _anchors_ymax = tf.reshape(_anchors_ymax, [-1])
+                    _anchors_xmax = tf.reshape(_anchors_xmax, [-1])
+
+                    anchors_ymin.append(_anchors_ymin)
+                    anchors_xmin.append(_anchors_xmin)
+                    anchors_ymax.append(_anchors_ymax)
+                    anchors_xmax.append(_anchors_xmax)
+                    anchor_allowed_borders.append(tf.ones_like(_anchors_ymin, dtype=tf.float32) * allowed_borders[ind])
+
+            # anchors_ymin = tf.reshape(tf.concat(anchors_ymin, axis=0), [-1])
+            # anchors_xmin = tf.reshape(tf.concat(anchors_xmin, axis=0), [-1])
+            # anchors_ymax = tf.reshape(tf.concat(anchors_ymax, axis=0), [-1])
+            # anchors_xmax = tf.reshape(tf.concat(anchors_xmax, axis=0), [-1])
+            # anchor_allowed_borders = tf.reshape(tf.concat(anchor_allowed_borders, axis=0), [-1])
+            anchors_ymin = tf.concat(anchors_ymin, axis=0)
+            anchors_xmin = tf.concat(anchors_xmin, axis=0)
+            anchors_ymax = tf.concat(anchors_ymax, axis=0)
+            anchors_xmax = tf.concat(anchors_xmax, axis=0)
+            anchor_allowed_borders = tf.concat(anchor_allowed_borders, axis=0)
+
+            inside_mask = tf.logical_and(tf.logical_and(anchors_ymin > -anchor_allowed_borders,
+                                                        anchors_xmin > -anchor_allowed_borders),
+                                        tf.logical_and(anchors_ymax < (image_height - 1. + anchor_allowed_borders),
+                                                        anchors_xmax < (image_width - 1. + anchor_allowed_borders)))
+
+            return anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, inside_mask
+
+    def encode_anchors(self, labels, bboxes, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, inside_mask, debug=False):
+        '''encode anchors with ground truth on the fly
+
+        We generate the prediction targets for all anchor positions on the fly, so this routine is called once the final feature map shapes are known. There is a performance bottleneck here, but it is hard to avoid because we must perform multi-scale training; maybe this needs to be placed on the CPU (left for later).
+
+        Args:
+            bboxes: [num_bboxes, 4] in [ymin, xmin, ymax, xmax] format
+            anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, inside_mask: generated by 'get_all_anchors'
+        '''
+        with tf.name_scope('encode_anchors'):
+            all_anchors = tf.stack([anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax], axis=-1)
+            overlap_matrix = iou_matrix(all_anchors, bboxes) * tf.cast(tf.expand_dims(inside_mask, 1), tf.float32)
+            #overlap_matrix = tf.Print(overlap_matrix, [tf.shape(overlap_matrix)])
+            #matched_gt, gt_scores = custom_op.small_mining_match(overlap_matrix, 0., self._rpn_negtive_threshold, self._rpn_positive_threshold, 5, 0.35)
             matched_gt, gt_scores = do_dual_max_match(overlap_matrix, self._ignore_threshold, self._positive_threshold)
             # get all positive matching positions
             matched_gt_mask = matched_gt > -1
             matched_indices = tf.clip_by_value(matched_gt, 0, tf.int64.max)
-            # the labels here maybe chaos at those non-positive positions
+
             gt_labels = tf.gather(labels, matched_indices)
             # filter the invalid labels
-            gt_labels = gt_labels * tf.cast(matched_gt_mask, tf.int64)
+            gt_labels = gt_labels * tf.to_int64(matched_gt_mask)
             # set those ignored positions to -1
-            gt_labels = gt_labels + (-1 * tf.cast(matched_gt < -1, tf.int64))
+            gt_labels = gt_labels + (-1 * tf.to_int64(matched_gt < -1))
+
             gt_ymin, gt_xmin, gt_ymax, gt_xmax = tf.unstack(tf.gather(bboxes, matched_indices), 4, axis=-1)
@@ -203,131 +313,43 @@ def encode_all_anchors(self, labels, bboxes, all_anchors, all_num_anchors_depth,
             else:
                 gt_targets = tf.stack([gt_cy, gt_cx, gt_h, gt_w], axis=-1)
             # set all targets of non-positive positions to 0
-            gt_targets = tf.expand_dims(tf.cast(matched_gt_mask, tf.float32), -1) * gt_targets
-            self._all_anchors = (anchor_cy, anchor_cx, anchor_h, anchor_w)
+            gt_targets = tf.expand_dims(tf.to_float(matched_gt_mask), -1) * gt_targets
+            return gt_targets, gt_labels, gt_scores
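Note: the elided lines compute gt_cy/gt_cx/gt_h/gt_w via the standard SSD box parameterization, which the decode methods below invert exactly. A hand-computed sketch with prior_scaling=[0.1, 0.1, 0.2, 0.2]:

# Anchor [0, 0, 99, 99] -> (cy, cx, h, w) = (49.5, 49.5, 100, 100)
# Matched ground truth [10, 20, 59, 79] -> (34.5, 49.5, 50, 60) encodes to
#   gt_cy = (34.5 - 49.5) / 100 / 0.1 = -1.5
#   gt_cx = (49.5 - 49.5) / 100 / 0.1 =  0.0
#   gt_h  = log(50 / 100) / 0.2 ≈ -3.466
#   gt_w  = log(60 / 100) / 0.2 ≈ -2.554
# Decoding: h = exp(-3.466 * 0.2) * 100 = 50, cy = -1.5 * 0.1 * 100 + 49.5 = 34.5,
# and center2point(34.5, 49.5, 50, 60) recovers [10., 20., 59., 79.].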
 
-    # return a list, of which each is:
-    #     shape: [feature_h, feature_w, num_anchors, 4]
-    #     order: ymin, xmin, ymax, xmax
-    def decode_all_anchors(self, pred_location, num_anchors_per_layer):
-        assert self._all_anchors is not None, 'no anchors to decode.'
-        with tf.name_scope('decode_all_anchors', [pred_location]):
-            anchor_cy, anchor_cx, anchor_h, anchor_w = self._all_anchors
+    def batch_decode_anchors(self, pred_location, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax):
+        '''batch_decode_anchors
 
-            pred_h = tf.exp(pred_location[:, -2] * self._prior_scaling[2]) * anchor_h
-            pred_w = tf.exp(pred_location[:, -1] * self._prior_scaling[3]) * anchor_w
-            pred_cy = pred_location[:, 0] * self._prior_scaling[0] * anchor_h + anchor_cy
-            pred_cx = pred_location[:, 1] * self._prior_scaling[1] * anchor_w + anchor_cx
+        Args:
+            pred_location: [batch, num_preds, 4] regression targets in yxhw format
+            anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax: generated by 'get_all_anchors'
+        '''
+        with tf.name_scope('decode_rpn', [pred_location, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax]):
+            anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax = tf.expand_dims(anchors_ymin, axis=0), \
+                                                                    tf.expand_dims(anchors_xmin, axis=0),\
+                                                                    tf.expand_dims(anchors_ymax, axis=0),\
+                                                                    tf.expand_dims(anchors_xmax, axis=0)
+            anchor_cy, anchor_cx, anchor_h, anchor_w = self.point2center(anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax)
 
-            return tf.split(tf.stack(self.center2point(pred_cy, pred_cx, pred_h, pred_w), axis=-1), num_anchors_per_layer, axis=0)
-
-    def ext_decode_all_anchors(self, pred_location, all_anchors, all_num_anchors_depth, all_num_anchors_spatial):
-        assert (len(all_num_anchors_depth)==len(all_num_anchors_spatial)) and (len(all_num_anchors_depth)==len(all_anchors)), 'inconsist num layers for anchors.'
-        with tf.name_scope('ext_decode_all_anchors', [pred_location]):
-            num_anchors_per_layer = []
-            for ind in range(len(all_anchors)):
-                num_anchors_per_layer.append(all_num_anchors_depth[ind] * all_num_anchors_spatial[ind])
-
-            num_layers = len(all_num_anchors_depth)
-            list_anchors_ymin = []
-            list_anchors_xmin = []
-            list_anchors_ymax = []
-            list_anchors_xmax = []
-            tiled_allowed_borders = []
-            for ind, anchor in enumerate(all_anchors):
-                anchors_ymin_, anchors_xmin_, anchors_ymax_, anchors_xmax_ = self.center2point(anchor[0], anchor[1], anchor[2], anchor[3])
-
-                list_anchors_ymin.append(tf.reshape(anchors_ymin_, [-1]))
-                list_anchors_xmin.append(tf.reshape(anchors_xmin_, [-1]))
-                list_anchors_ymax.append(tf.reshape(anchors_ymax_, [-1]))
-                list_anchors_xmax.append(tf.reshape(anchors_xmax_, [-1]))
-
-            anchors_ymin = tf.concat(list_anchors_ymin, 0, name='concat_ymin')
-            anchors_xmin = tf.concat(list_anchors_xmin, 0, name='concat_xmin')
-            anchors_ymax = tf.concat(list_anchors_ymax, 0, name='concat_ymax')
-            anchors_xmax = tf.concat(list_anchors_xmax, 0, name='concat_xmax')
+            pred_h = tf.exp(pred_location[:, :, -2] * self._prior_scaling[2]) * anchor_h
+            pred_w = tf.exp(pred_location[:, :, -1] * self._prior_scaling[3]) * anchor_w
+            pred_cy = pred_location[:, :, 0] * self._prior_scaling[0] * anchor_h + anchor_cy
+            pred_cx = pred_location[:, :, 1] * self._prior_scaling[1] * anchor_w + anchor_cx
+            return tf.stack(self.center2point(pred_cy, pred_cx, pred_h, pred_w), axis=-1)
 
+    def decode_anchors(self, pred_location, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax):
+        '''decode_anchors
+
+        Args:
+            pred_location: [num_preds, 4] regression targets in yxhw format
+            anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax: generated by 'get_all_anchors'
+        '''
+        with tf.name_scope('decode_rpn', [pred_location, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax]):
             anchor_cy, anchor_cx, anchor_h, anchor_w = self.point2center(anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax)
 
-            pred_h = tf.exp(pred_location[:,-2] * self._prior_scaling[2]) * anchor_h
+            pred_h = tf.exp(pred_location[:, -2] * self._prior_scaling[2]) * anchor_h
             pred_w = tf.exp(pred_location[:, -1] * self._prior_scaling[3]) * anchor_w
             pred_cy = pred_location[:, 0] * self._prior_scaling[0] * anchor_h + anchor_cy
             pred_cx = pred_location[:, 1] * self._prior_scaling[1] * anchor_w + anchor_cx
 
-            return tf.split(tf.stack(self.center2point(pred_cy, pred_cx, pred_h, pred_w), axis=-1), num_anchors_per_layer, axis=0)
-
-class AnchorCreator(object):
-    def __init__(self, img_shape, layers_shapes, anchor_scales, extra_anchor_scales, anchor_ratios, layer_steps):
-        super(AnchorCreator, self).__init__()
-        # img_shape -> (height, width)
-        self._img_shape = img_shape
-        self._layers_shapes = layers_shapes
-        self._anchor_scales = anchor_scales
-        self._extra_anchor_scales = extra_anchor_scales
-        self._anchor_ratios = anchor_ratios
-        self._layer_steps = layer_steps
-        self._anchor_offset = [0.5] * len(self._layers_shapes)
-
-    def get_layer_anchors(self, layer_shape, anchor_scale, extra_anchor_scale, anchor_ratio, layer_step, offset = 0.5):
-        ''' assume layer_shape[0] = 6, layer_shape[1] = 5
-        x_on_layer = [[0, 1, 2, 3, 4],
-                      [0, 1, 2, 3, 4],
-                      [0, 1, 2, 3, 4],
-                      [0, 1, 2, 3, 4],
-                      [0, 1, 2, 3, 4],
-                      [0, 1, 2, 3, 4]]
-        y_on_layer = [[0, 0, 0, 0, 0],
-                      [1, 1, 1, 1, 1],
-                      [2, 2, 2, 2, 2],
-                      [3, 3, 3, 3, 3],
-                      [4, 4, 4, 4, 4],
-                      [5, 5, 5, 5, 5]]
-        '''
-        with tf.name_scope('get_layer_anchors'):
-            x_on_layer, y_on_layer = tf.meshgrid(tf.range(layer_shape[1]), tf.range(layer_shape[0]))
-
-            y_on_image = (tf.cast(y_on_layer, tf.float32) + offset) * layer_step / self._img_shape[0]
-            x_on_image = (tf.cast(x_on_layer, tf.float32) + offset) * layer_step / self._img_shape[1]
-
-            num_anchors_along_depth = len(anchor_scale) * len(anchor_ratio) + len(extra_anchor_scale)
-            num_anchors_along_spatial = layer_shape[1] * layer_shape[0]
-
-            list_h_on_image = []
-            list_w_on_image = []
-
-            global_index = 0
-            # for square anchors
-            for _, scale in enumerate(extra_anchor_scale):
-                list_h_on_image.append(scale)
-                list_w_on_image.append(scale)
-                global_index += 1
-            # for other aspect ratio anchors
-            for scale_index, scale in enumerate(anchor_scale):
-                for ratio_index, ratio in enumerate(anchor_ratio):
-                    list_h_on_image.append(scale / math.sqrt(ratio))
-                    list_w_on_image.append(scale * math.sqrt(ratio))
-                    global_index += 1
-            # shape info:
-            # y_on_image, x_on_image: layers_shapes[0] * layers_shapes[1]
-            # h_on_image, w_on_image: num_anchors_along_depth
-            return tf.expand_dims(y_on_image, axis=-1), tf.expand_dims(x_on_image, axis=-1), \
-                    tf.constant(list_h_on_image, dtype=tf.float32), \
-                    tf.constant(list_w_on_image, dtype=tf.float32), num_anchors_along_depth, num_anchors_along_spatial
-
-    def get_all_anchors(self):
-        all_anchors = []
-        all_num_anchors_depth = []
-        all_num_anchors_spatial = []
-        for layer_index, layer_shape in enumerate(self._layers_shapes):
-            anchors_this_layer = self.get_layer_anchors(layer_shape,
-                                                    self._anchor_scales[layer_index],
-                                                    self._extra_anchor_scales[layer_index],
-                                                    self._anchor_ratios[layer_index],
-                                                    self._layer_steps[layer_index],
-                                                    self._anchor_offset[layer_index])
-            all_anchors.append(anchors_this_layer[:-2])
-            all_num_anchors_depth.append(anchors_this_layer[-2])
-            all_num_anchors_spatial.append(anchors_this_layer[-1])
-        return all_anchors, all_num_anchors_depth, all_num_anchors_spatial
-
+            return tf.stack(self.center2point(pred_cy, pred_cx, pred_h, pred_w), axis=-1)
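Note: the only difference between the two decoders is a leading batch axis; batch_decode_anchors expands the flat anchor vectors to [1, num_anchors] so the arithmetic broadcasts across the batch. A tiny NumPy sketch of that broadcasting (illustrative values only):

import numpy as np

anchor_h = np.array([10., 20.])             # [num_anchors] per-anchor heights
pred_dy  = np.array([[0.5, -0.5],           # image 0
                     [0.0,  1.0]])          # image 1, shape [batch, num_anchors]
pred_cy_offset = pred_dy * 0.1 * anchor_h[None, :]   # broadcasts to [2, 2]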
diff --git a/utility/anchor_manipulator_unittest.py b/utility/anchor_manipulator_unittest.py
index bbacc64..38dfbfb 100644
--- a/utility/anchor_manipulator_unittest.py
+++ b/utility/anchor_manipulator_unittest.py
@@ -91,32 +91,42 @@ def slim_get_split(file_pattern='{}_????'):
                                                 'object/difficult'])
 
     image, glabels, gbboxes = ssd_preprocessing.preprocess_image(org_image, glabels_raw, gbboxes_raw, [300, 300], is_training=True, data_format='channels_last', output_rgb=True)
 
-    anchor_creator = anchor_manipulator.AnchorCreator([300] * 2,
-                                                layers_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)],
-                                                anchor_scales = [(0.1,), (0.2,), (0.375,), (0.55,), (0.725,), (0.9,)],
-                                                extra_anchor_scales = [(0.1414,), (0.2739,), (0.4541,), (0.6315,), (0.8078,), (0.9836,)],
-                                                anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)],
-                                                layer_steps = [8, 16, 32, 64, 100, 300])
-
-    all_anchors, all_num_anchors_depth, all_num_anchors_spatial = anchor_creator.get_all_anchors()
-
-    num_anchors_per_layer = []
-    for ind in range(len(all_anchors)):
-        num_anchors_per_layer.append(all_num_anchors_depth[ind] * all_num_anchors_spatial[ind])
-
-    anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(allowed_borders=[1.0] * 6,
-                                                    positive_threshold = 0.5,
+    anchor_encoder_decoder = anchor_manipulator.AnchorEncoder(positive_threshold = 0.5,
                                                     ignore_threshold = 0.5,
                                                     prior_scaling=[0.1, 0.1, 0.2, 0.2])
 
-    gt_targets, gt_labels, gt_scores = anchor_encoder_decoder.encode_all_anchors(glabels, gbboxes, all_anchors, all_num_anchors_depth, all_num_anchors_spatial, True)
+    all_anchor_scales = [(30.,), (60.,), (112.5,), (165.,), (217.5,), (270.,)]
+    all_extra_scales = [(42.43,), (82.17,), (136.23,), (189.45,), (242.34,), (295.08,)]
+    all_anchor_ratios = [(2., .5), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., 3., .5, 0.3333), (2., .5), (2., .5)]
+    all_layer_shapes = [(38, 38), (19, 19), (10, 10), (5, 5), (3, 3), (1, 1)]
+    all_layer_strides = [8, 16, 32, 64, 100, 300]
+    total_layers = len(all_layer_shapes)
+    anchors_height = list()
+    anchors_width = list()
+    anchors_depth = list()
+    for ind in range(total_layers):
+        _anchors_height, _anchors_width, _anchor_depth = anchor_encoder_decoder.get_anchors_width_height(all_anchor_scales[ind], all_extra_scales[ind], all_anchor_ratios[ind])
+        anchors_height.append(_anchors_height)
+        anchors_width.append(_anchors_width)
+        anchors_depth.append(_anchor_depth)
+    anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, inside_mask = anchor_encoder_decoder.get_all_anchors([300] * 2, anchors_height, anchors_width, anchors_depth,
+                                                                    [0.5] * total_layers, all_layer_shapes, all_layer_strides,
+                                                                    [300.] * total_layers, [False] * total_layers)
+
+    gt_targets, gt_labels, gt_scores = anchor_encoder_decoder.encode_anchors(glabels, gbboxes, anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax, inside_mask, True)
+
+    num_anchors_per_layer = list()
+    for ind, layer_shape in enumerate(all_layer_shapes):
+        _, _num_anchors_per_layer = anchor_encoder_decoder.get_anchors_count(anchors_depth[ind], layer_shape)
+        num_anchors_per_layer.append(_num_anchors_per_layer)
 
-    anchors = anchor_encoder_decoder._all_anchors
     # split by layers
+    all_anchors = tf.stack([anchors_ymin, anchors_xmin, anchors_ymax, anchors_xmax], axis=-1)
+
     gt_targets, gt_labels, gt_scores, anchors = tf.split(gt_targets, num_anchors_per_layer, axis=0),\
                                                 tf.split(gt_labels, num_anchors_per_layer, axis=0),\
                                                 tf.split(gt_scores, num_anchors_per_layer, axis=0),\
-                                                [tf.split(anchor, num_anchors_per_layer, axis=0) for anchor in anchors]
+                                                tf.split(all_anchors, num_anchors_per_layer, axis=0)
 
     save_image_op = tf.py_func(save_image_with_bbox,
                             [ssd_preprocessing.unwhiten_image(image),
@@ -127,7 +137,7 @@ def slim_get_split(file_pattern='{}_????'):
     return save_image_op
 
 if __name__ == '__main__':
-    save_image_op = slim_get_split('/media/rs/7A0EE8880EE83EAF/Detections/SSD/dataset/tfrecords/train*')
+    save_image_op = slim_get_split('./dataset/tfrecords/train*')
     # Create the graph, etc.
     init_op = tf.group([tf.local_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer()])
diff --git a/utility/bbox_util.py b/utility/bbox_util.py
new file mode 100644
index 0000000..d0651e9
--- /dev/null
+++ b/utility/bbox_util.py
@@ -0,0 +1,119 @@
+# Copyright 2018 Changan Wang
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+import tensorflow as tf
+
+def select_bboxes(scores_pred, bboxes_pred, num_classes, select_threshold, name=None):
+    selected_bboxes = {}
+    selected_scores = {}
+    with tf.name_scope(name, 'select_bboxes', [scores_pred, bboxes_pred]):
+        for class_ind in range(1, num_classes):
+            class_scores = scores_pred[:, class_ind]
+            select_mask = class_scores > select_threshold
+
+            select_mask = tf.to_float(select_mask)
+            selected_bboxes[class_ind] = tf.multiply(bboxes_pred, tf.expand_dims(select_mask, axis=-1))
+            selected_scores[class_ind] = tf.multiply(class_scores, select_mask)
+
+    return selected_bboxes, selected_scores
+
+def clip_bboxes(ymin, xmin, ymax, xmax, height, width, name=None):
+    with tf.name_scope(name, 'clip_bboxes', [ymin, xmin, ymax, xmax]):
+        ymin = tf.maximum(ymin, 0.)
+        xmin = tf.maximum(xmin, 0.)
+        ymax = tf.minimum(ymax, tf.to_float(height) - 1.)
+        xmax = tf.minimum(xmax, tf.to_float(width) - 1.)
+
+        ymin = tf.minimum(ymin, ymax)
+        xmin = tf.minimum(xmin, xmax)
+
+        return ymin, xmin, ymax, xmax
+
+def filter_bboxes(scores_pred, ymin, xmin, ymax, xmax, min_size, name=None):
+    with tf.name_scope(name, 'filter_bboxes', [scores_pred, ymin, xmin, ymax, xmax]):
+        width = xmax - xmin + 1.
+        height = ymax - ymin + 1.
+
+        filter_mask = tf.logical_and(width > min_size + 1., height > min_size + 1.)
+
+        filter_mask = tf.cast(filter_mask, tf.float32)
+        return tf.multiply(scores_pred, filter_mask), tf.multiply(ymin, filter_mask), \
+                tf.multiply(xmin, filter_mask), tf.multiply(ymax, filter_mask), tf.multiply(xmax, filter_mask)
+
+def sort_bboxes(scores_pred, ymin, xmin, ymax, xmax, keep_topk, name=None):
+    with tf.name_scope(name, 'sort_bboxes', [scores_pred, ymin, xmin, ymax, xmax]):
+        cur_bboxes = tf.shape(scores_pred)[0]
+        scores, idxes = tf.nn.top_k(scores_pred, k=tf.minimum(keep_topk, cur_bboxes), sorted=True)
+
+        ymin, xmin, ymax, xmax = tf.gather(ymin, idxes), tf.gather(xmin, idxes), tf.gather(ymax, idxes), tf.gather(xmax, idxes)
+
+        paddings = tf.expand_dims(tf.stack([0, tf.maximum(keep_topk-cur_bboxes, 0)], axis=0), axis=0)
+
+        return tf.pad(scores, paddings, "CONSTANT"), \
+                tf.pad(ymin, paddings, "CONSTANT"), tf.pad(xmin, paddings, "CONSTANT"),\
+                tf.pad(ymax, paddings, "CONSTANT"), tf.pad(xmax, paddings, "CONSTANT"),\
+
+
+def nms_bboxes(scores_pred, bboxes_pred, nms_topk, nms_threshold, name=None):
+    with tf.name_scope(name, 'nms_bboxes', [scores_pred, bboxes_pred]):
+        idxes = tf.image.non_max_suppression(bboxes_pred, scores_pred, nms_topk, nms_threshold)
+        return tf.gather(scores_pred, idxes), tf.gather(bboxes_pred, idxes)
+
+def nms_bboxes_with_padding(scores_pred, bboxes_pred, nms_topk, nms_threshold, name=None):
+    with tf.name_scope(name, 'nms_bboxes_with_padding', [scores_pred, bboxes_pred]):
+        idxes = tf.image.non_max_suppression(bboxes_pred, scores_pred, nms_topk, nms_threshold)
+        scores = tf.gather(scores_pred, idxes)
+        bboxes = tf.gather(bboxes_pred, idxes)
+
+        nms_bboxes = tf.shape(idxes)[0]
+        scores_paddings = tf.expand_dims(tf.stack([0, tf.maximum(nms_topk - nms_bboxes, 0)], axis=0), axis=0)
+        bboxes_paddings = tf.stack([[0, 0], [tf.maximum(nms_topk - nms_bboxes, 0), 0]], axis=1)
+
+        return tf.pad(scores, scores_paddings, "CONSTANT"), tf.pad(bboxes, bboxes_paddings, "CONSTANT")
+
+def bbox_point2center(bboxes, name=None):
+    with tf.name_scope(name, 'bbox_point2center', [bboxes]):
+        ymin, xmin, ymax, xmax = tf.unstack(bboxes, 4, axis=-1)
+        height, width = (ymax - ymin + 1.), (xmax - xmin + 1.)
+        return tf.stack([(ymin + ymax) / 2., (xmin + xmax) / 2., height, width], axis=-1)
+
+def bbox_center2point(bboxes, name=None):
+    with tf.name_scope(name, 'bbox_center2point', [bboxes]):
+        y, x, h, w = tf.unstack(bboxes, 4, axis=-1)
+        return tf.stack([y - (h - 1.) / 2., x - (w - 1.) / 2., y + (h - 1.) / 2., x + (w - 1.) / 2.], axis=-1)
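Note: sort_bboxes and nms_bboxes_with_padding both zero-pad their outputs, so every class contributes tensors of fixed, statically-known size, which is what lets callers stack per-class results. A hedged sketch of the shape contract implied by parse_by_class below:

# scores after sort_bboxes:              [keep_topk]     (zero-padded if fewer candidates)
# scores after nms_bboxes_with_padding:  [nms_topk]
# boxes  after nms_bboxes_with_padding:  [nms_topk, 4]   rows of [ymin, xmin, ymax, xmax]
#
# e.g. with keep_topk=400, nms_topk=20 and only 7 surviving boxes for a class,
# that class still yields a [20] score vector and a [20, 4] box tensor, so
# results for all classes can be stacked into one batch-friendly tensor.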
+
+def parse_by_class(image_shape, cls_pred, bboxes_pred, num_classes, select_threshold, min_size, keep_topk, nms_topk, nms_threshold):
+    with tf.name_scope('select_bboxes', [cls_pred, bboxes_pred]):
+        scores_pred = tf.nn.softmax(cls_pred)
+        selected_bboxes, selected_scores = select_bboxes(scores_pred, bboxes_pred, num_classes, select_threshold)
+        for class_ind in range(1, num_classes):
+            ymin, xmin, ymax, xmax = tf.unstack(selected_bboxes[class_ind], 4, axis=-1)
+            #ymin, xmin, ymax, xmax = tf.split(selected_bboxes[class_ind], 4, axis=-1)
+            #ymin, xmin, ymax, xmax = tf.squeeze(ymin), tf.squeeze(xmin), tf.squeeze(ymax), tf.squeeze(xmax)
+            ymin, xmin, ymax, xmax = clip_bboxes(ymin, xmin, ymax, xmax, image_shape[0], image_shape[1], 'clip_bboxes_{}'.format(class_ind))
+            selected_scores[class_ind], ymin, xmin, ymax, xmax = filter_bboxes(selected_scores[class_ind],
+                                                ymin, xmin, ymax, xmax, min_size, 'filter_bboxes_{}'.format(class_ind))
+            selected_scores[class_ind], ymin, xmin, ymax, xmax = sort_bboxes(selected_scores[class_ind],
+                                                ymin, xmin, ymax, xmax, keep_topk, 'sort_bboxes_{}'.format(class_ind))
+            selected_bboxes[class_ind] = tf.stack([ymin, xmin, ymax, xmax], axis=-1)
+            selected_scores[class_ind], selected_bboxes[class_ind] = nms_bboxes_with_padding(selected_scores[class_ind], selected_bboxes[class_ind], nms_topk, nms_threshold, 'nms_bboxes_{}'.format(class_ind))
+
+        return selected_bboxes, selected_scores
diff --git a/utility/draw_toolbox.py b/utility/draw_toolbox.py
index a72ae50..8d11f22 100644
--- a/utility/draw_toolbox.py
+++ b/utility/draw_toolbox.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # =============================================================================
 import cv2
-import matplotlib.cm as mpcm
 
 from dataset import dataset_common
 
@@ -35,8 +34,9 @@ def colors_subselect(colors, num_classes=21):
         else:
             sub_colors.append([c for c in color])
     return sub_colors
+# import matplotlib.cm as mpcm
+# colors = colors_subselect(mpcm.plasma.colors, num_classes=21)
 
-colors = colors_subselect(mpcm.plasma.colors, num_classes=21)
 colors_tableau = [(255, 255, 255), (31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),
                 (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),
                 (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),
@@ -53,8 +53,8 @@ def bboxes_draw_on_img(img, classes, scores, bboxes, thickness=2):
         bbox = bboxes[i]
         color = colors_tableau[classes[i]]
         # Draw bounding boxes
-        p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
-        p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
+        p1 = (int(round(bbox[0])), int(round(bbox[1])))
+        p2 = (int(round(bbox[2])), int(round(bbox[3])))
 
         if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
             continue
diff --git a/voc_eval.py b/voc_eval.py
index d15d848..ebccaf9 100644
--- a/voc_eval.py
+++ b/voc_eval.py
@@ -34,7 +34,7 @@
 ...
 ImageSets
 '''
-dataset_path = '/media/rs/7A0EE8880EE83EAF/Detections/PASCAL/VOC/VOC2007TEST'
+dataset_path = './dataset/VOC/VOC2007TEST' # change above path according to your system settings
 pred_path = './logs/predict'
 pred_file = 'results_{}.txt' # from 1-num_classes