Commit

Merge branch 'master' into dev
zxytim committed Aug 3, 2017
2 parents a7ac83f + 9a99451 commit 31d304d
Showing 5 changed files with 140 additions and 15 deletions.
129 changes: 129 additions & 0 deletions data_util.py
@@ -0,0 +1,129 @@
'''
This file is adapted from the Keras implementation of multi-threaded data
processing; see
https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
'''
import time
import numpy as np
import threading
import multiprocessing
try:
    import queue
except ImportError:
    import Queue as queue  # Python 2 fallback


class GeneratorEnqueuer(object):
    """Builds a queue out of a data generator.

    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

    # Arguments
        generator: a generator which endlessly yields data
        use_multiprocessing: use multiprocessing if True, otherwise threading
        wait_time: time to sleep in-between calls to `put()`
        random_seed: initial seed for workers,
            incremented by one for each worker
    """

    def __init__(self, generator,
                 use_multiprocessing=False,
                 wait_time=0.05,
                 random_seed=None):
        self.wait_time = wait_time
        self._generator = generator
        self._use_multiprocessing = use_multiprocessing
        self._threads = []
        self._stop_event = None
        self.queue = None
        self.random_seed = random_seed

    def start(self, workers=1, max_queue_size=10):
        """Kicks off threads which add data from the generator into the queue.

        # Arguments
            workers: number of worker threads
            max_queue_size: queue size
                (when full, workers could block on `put()`)
        """

        def data_generator_task():
            while not self._stop_event.is_set():
                try:
                    # multiprocessing.Queue enforces its own maxsize; in
                    # threading mode the queue is unbounded, so throttle
                    # manually via qsize().
                    if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
                        generator_output = next(self._generator)
                        self.queue.put(generator_output)
                    else:
                        time.sleep(self.wait_time)
                except Exception:
                    self._stop_event.set()
                    raise

        try:
            if self._use_multiprocessing:
                self.queue = multiprocessing.Queue(maxsize=max_queue_size)
                self._stop_event = multiprocessing.Event()
            else:
                self.queue = queue.Queue()
                self._stop_event = threading.Event()

            for _ in range(workers):
                if self._use_multiprocessing:
                    # Reset the random seed, otherwise all child processes
                    # would share the same seed.
                    np.random.seed(self.random_seed)
                    thread = multiprocessing.Process(target=data_generator_task)
                    thread.daemon = True
                    if self.random_seed is not None:
                        self.random_seed += 1
                else:
                    thread = threading.Thread(target=data_generator_task)
                self._threads.append(thread)
                thread.start()
        except:
            self.stop()
            raise

    def is_running(self):
        return self._stop_event is not None and not self._stop_event.is_set()

    def stop(self, timeout=None):
        """Stops running threads and waits for them to exit, if necessary.

        Should be called by the same thread which called `start()`.

        # Arguments
            timeout: maximum time to wait on `thread.join()`
        """
        if self.is_running():
            self._stop_event.set()

        for thread in self._threads:
            if thread.is_alive():
                if self._use_multiprocessing:
                    thread.terminate()
                else:
                    thread.join(timeout)

        if self._use_multiprocessing:
            if self.queue is not None:
                self.queue.close()

        self._threads = []
        self._stop_event = None
        self.queue = None

    def get(self):
        """Creates a generator to extract data from the queue.

        `None` entries are skipped.

        # Returns
            A generator
        """
        while self.is_running():
            if not self.queue.empty():
                inputs = self.queue.get()
                if inputs is not None:
                    yield inputs
            else:
                time.sleep(self.wait_time)
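
For reference, a minimal usage sketch of the new class (not part of the commit): `toy_batches` is a hypothetical stand-in for the project's real data generator, and a single threading worker keeps the sketch free of generator-sharing races.

```
import numpy as np

from data_util import GeneratorEnqueuer


def toy_batches():
    # Endless generator of fake image batches, standing in for icdar.generator.
    while True:
        yield np.random.rand(4, 512, 512, 3)


enqueuer = GeneratorEnqueuer(toy_batches(), use_multiprocessing=False)
enqueuer.start(workers=1, max_queue_size=10)
try:
    batches = enqueuer.get()  # generator that skips None entries
    for _ in range(3):
        print(next(batches).shape)  # (4, 512, 512, 3)
finally:
    enqueuer.stop()
```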
16 changes: 6 additions & 10 deletions icdar.py
@@ -12,6 +12,7 @@

 import tensorflow as tf

+from data_util import GeneratorEnqueuer

 tf.app.flags.DEFINE_string('training_data_path', '/data/ocr/icdar2015/',
                            'training dataset to use')
@@ -715,17 +716,14 @@ def generator(input_size=512, batch_size=32,
             continue


-def get_batch(num_workers=10, **kwargs):
-    from keras.engine.training import GeneratorEnqueuer
+def get_batch(num_workers, **kwargs):
     try:
-        enqueuer = GeneratorEnqueuer(generator(**kwargs),
-                                     use_multiprocessing=True) # , pickle_safe=True)
-        enqueuer.start(max_queue_size=12, workers=num_workers)
+        enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True)
+        enqueuer.start(max_queue_size=24, workers=num_workers)
         generator_output = None
         while True:
+            while enqueuer.is_running():
                 if not enqueuer.queue.empty():
-                    # print self.enqueuer.queue.qsize()
                     generator_output = enqueuer.queue.get()
                     break
                 else:
@@ -737,8 +735,6 @@ def get_batch(num_workers=10, **kwargs):
         enqueuer.stop()
-
-

 if __name__ == '__main__':
     gen = generator(input_size=512, batch_size=32, vis=True)
     while True:
         images, image_fns, score_maps, geo_maps, training_masks = next(gen)
         print(len(images))
         pass
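
With this change icdar.py no longer imports Keras at runtime, and the prefetch buffer grows from 12 to 24 batches. Below is a hedged sketch of how a training loop would consume `get_batch`; it assumes `get_batch` yields whatever `generator` yields (the five values unpacked in the `__main__` block above), and the keyword arguments are illustrative.

```
import icdar

# get_batch returns a generator backed by worker processes.
batch_source = icdar.get_batch(num_workers=16, input_size=512, batch_size=14)
for step in range(10):
    images, image_fns, score_maps, geo_maps, training_masks = next(batch_source)
    print(step, len(images))
```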
4 changes: 2 additions & 2 deletions model.py
@@ -129,8 +129,8 @@ def loss(y_true_cls, y_pred_cls,
     area_union = area_gt + area_pred - area_intersect
     L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0))
     L_theta = 1 - tf.cos(theta_pred - theta_gt)
-    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls))
-    tf.summary.scalar('geometry_theta', tf.reduce_mean(L_theta * y_true_cls))
+    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls * training_mask))
+    tf.summary.scalar('geometry_theta', tf.reduce_mean(L_theta * y_true_cls * training_mask))
     L_g = L_AABB + 20 * L_theta

     return tf.reduce_mean(L_g * y_true_cls * training_mask) + classification_loss
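
The added `training_mask` factor makes the logged geometry summaries consistent with the loss actually optimized, which already excludes don't-care regions. A small numpy sketch with made-up values shows the difference:

```
import numpy as np

L_AABB = np.array([0.5, 2.0, 1.0])         # per-pixel AABB losses
y_true_cls = np.array([1.0, 1.0, 0.0])     # 1 marks text pixels
training_mask = np.array([1.0, 0.0, 1.0])  # 0 marks don't-care regions

print(np.mean(L_AABB * y_true_cls))                  # ~0.833, inflated by a don't-care pixel
print(np.mean(L_AABB * y_true_cls * training_mask))  # ~0.167, matches the optimized loss
```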
2 changes: 1 addition & 1 deletion multigpu_train.py
@@ -4,7 +4,7 @@
 from tensorflow.contrib import slim

 tf.app.flags.DEFINE_integer('input_size', 512, '')
-tf.app.flags.DEFINE_integer('batch_size', 14, '')
+tf.app.flags.DEFINE_integer('batch_size_per_gpu', 14, '')
 tf.app.flags.DEFINE_integer('num_readers', 16, '')
 tf.app.flags.DEFINE_float('learning_rate', 0.0001, '')
 tf.app.flags.DEFINE_integer('max_steps', 100000, '')
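
The rename makes explicit that the flag counts samples per GPU: the effective batch size is `batch_size_per_gpu` times the number of GPUs. A hedged sketch of that arithmetic; the `gpu_list` flag is assumed to be a comma-separated string defined elsewhere in the file, as the readme's `--gpu_list` usage suggests.

```
import tensorflow as tf

tf.app.flags.DEFINE_integer('batch_size_per_gpu', 14, '')
tf.app.flags.DEFINE_string('gpu_list', '0,1,2,3', '')  # assumed flag; defined in the real file
FLAGS = tf.app.flags.FLAGS

num_gpus = len(FLAGS.gpu_list.split(','))
print(FLAGS.batch_size_per_gpu * num_gpus)  # effective batch size: 14 * 4 = 56
```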
4 changes: 2 additions & 2 deletions readme.md
@@ -26,12 +26,12 @@ If you want to train the model, you should provide the dataset path, in the data
 and run

 ```
-python multigpu_train.py --gpu_list=0 --input_size=512 --batch_size=14 --checkpoint_path=/tmp/east_icdar2015_resnet_v1_50_rbox/ \
+python multigpu_train.py --gpu_list=0 --input_size=512 --batch_size_per_gpu=14 --checkpoint_path=/tmp/east_icdar2015_resnet_v1_50_rbox/ \
 --text_scale=512 --training_data_path=/data/ocr/icdar2015/ --geometry=RBOX --learning_rate=0.0001 --num_readers=24 \
 --pretrained_model_path=/tmp/resnet_v1_50.ckpt
 ```

-If you have more than one gpu, you can pass gpu ids to gpu_list
+If you have more than one GPU, you can pass the GPU ids to gpu_list (e.g. --gpu_list=0,1,2,3)

 **Note: you should rename the ICDAR2015 ground-truth text files from gt_img_\*.txt to img_\*.txt (or change the code in icdar.py), and some extra characters should be removed from the files.
 See the examples in training_samples/**
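
A hedged helper for the rename step in the note above; the directory is the `training_data_path` from the training command and is an assumption about your layout.

```
import glob
import os

# Rename ICDAR2015 ground-truth files: gt_img_1.txt -> img_1.txt
for gt_path in glob.glob('/data/ocr/icdar2015/gt_img_*.txt'):
    dir_name, file_name = os.path.split(gt_path)
    os.rename(gt_path, os.path.join(dir_name, file_name[len('gt_'):]))
```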
