From c018d838d3b3ee2a4770f60fdfb44e754e8290bd Mon Sep 17 00:00:00 2001 From: Ivan Capalija Date: Sun, 30 Oct 2016 00:53:50 +0200 Subject: [PATCH] Stop adding new ops to the graph object with every snapshot and lr change. --- lib/fast_rcnn/train.py | 29 ++++++++++++++++------------- lib/networks/VGGnet_train.py | 17 +++++++++++++++-- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/lib/fast_rcnn/train.py b/lib/fast_rcnn/train.py index d9792381..c80159cc 100644 --- a/lib/fast_rcnn/train.py +++ b/lib/fast_rcnn/train.py @@ -52,13 +52,14 @@ def snapshot(self, sess, iter): weights = tf.get_variable("weights") biases = tf.get_variable("biases") - orig_0 = weights.eval() - orig_1 = biases.eval() + orig_0 = weights.eval() + orig_1 = biases.eval() - # scale and shift with bbox reg unnormalization; then save snapshot - weights_shape = weights.get_shape().as_list() - sess.run(weights.assign(orig_0 * np.tile(self.bbox_stds, (weights_shape[0],1)))) - sess.run(biases.assign(orig_1 * self.bbox_stds + self.bbox_means)) + # scale and shift with bbox reg unnormalization; then save snapshot + weights_shape = weights.get_shape().as_list() + + sess.run(net.bbox_weights_assign, feed_dict={net.bbox_weights: orig_0 * np.tile(self.bbox_stds, (weights_shape[0], 1))}) + sess.run(net.bbox_bias_assign, feed_dict={net.bbox_biases: orig_1 * self.bbox_stds + self.bbox_means}) if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) @@ -73,9 +74,10 @@ def snapshot(self, sess, iter): print 'Wrote snapshot to: {:s}'.format(filename) if cfg.TRAIN.BBOX_REG and net.layers.has_key('bbox_pred'): - # restore net to original state - sess.run(weights.assign(orig_0)) - sess.run(biases.assign(orig_1)) + with tf.variable_scope('bbox_pred', reuse=True): + # restore net to original state + sess.run(net.bbox_weights_assign, feed_dict={net.bbox_weights: orig_0}) + sess.run(net.bbox_bias_assign, feed_dict={net.bbox_biases: orig_1}) def train_model(self, sess, max_iters): @@ -123,6 +125,8 @@ def train_model(self, sess, max_iters): # optimizer lr = tf.Variable(cfg.TRAIN.LEARNING_RATE, trainable=False) + lr_placeholder = tf.placeholder(tf.float32) + lr_assign = lr.assign(lr_placeholder) momentum = cfg.TRAIN.MOMENTUM train_op = tf.train.MomentumOptimizer(lr, momentum).minimize(loss) @@ -137,10 +141,9 @@ def train_model(self, sess, max_iters): timer = Timer() for iter in range(max_iters): # learning rate - if iter >= cfg.TRAIN.STEPSIZE: - sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE * cfg.TRAIN.GAMMA)) - else: - sess.run(tf.assign(lr, cfg.TRAIN.LEARNING_RATE)) + if (iter+1) % cfg.TRAIN.STEPSIZE == 0: + new_lr = lr.eval() * cfg.TRAIN.GAMMA + sess.run(lr_assign, feed_dict={lr_placeholder: new_lr}) # get one batch blobs = data_layer.forward() diff --git a/lib/networks/VGGnet_train.py b/lib/networks/VGGnet_train.py index 8c8cb6a1..afb973ce 100644 --- a/lib/networks/VGGnet_train.py +++ b/lib/networks/VGGnet_train.py @@ -17,8 +17,21 @@ def __init__(self, trainable=True): self.keep_prob = tf.placeholder(tf.float32) self.layers = dict({'data':self.data, 'im_info':self.im_info, 'gt_boxes':self.gt_boxes}) self.trainable = trainable + + # setup self.setup() + # create ops and placeholders for bbox normalization process + with tf.variable_scope('bbox_pred', reuse=True): + weights = tf.get_variable("weights") + biases = tf.get_variable("biases") + + self.bbox_weights = tf.placeholder(weights.dtype, shape=weights.get_shape()) + self.bbox_biases = tf.placeholder(biases.dtype, shape=biases.get_shape()) + + self.bbox_weights_assign = weights.assign(self.bbox_weights) + self.bbox_bias_assign = biases.assign(self.bbox_biases) + def setup(self): (self.feed('data') .conv(3, 3, 64, 1, 1, name='conv1_1', trainable=False) @@ -42,7 +55,7 @@ def setup(self): (self.feed('conv5_3') .conv(3,3,512,1,1,name='rpn_conv/3x3') .conv(1,1,len(anchor_scales)*3*2 ,1 , 1, padding='VALID', relu = False, name='rpn_cls_score')) - + (self.feed('rpn_cls_score','gt_boxes','im_info','data') .anchor_target_layer(_feat_stride, anchor_scales, name = 'rpn-data' )) @@ -66,7 +79,7 @@ def setup(self): .proposal_target_layer(n_classes,name = 'roi-data')) - #========= RCNN ============ + #========= RCNN ============ (self.feed('conv5_3', 'roi-data') .roi_pool(7, 7, 1.0/16, name='pool_5') .fc(4096, name='fc6')