Skip to content

Commit

Permalink
Do not scale learning rate of optimizer. (sql-machine-learning#2312)
Browse files (browse the repository at this point in the history)
* Don't scale learning rate

* Remove unused variables

* Add a unit test for training_process in AllReduceTrainer

* Pre-commit

* Fix unit test
  • Loading branch information
workingloong authored Sep 29, 2020
1 parent 8e9fa0a commit 0daf37b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 15 deletions.
7 changes: 7 additions & 0 deletions elasticdl/python/tests/allreduce_trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def setUp(self):
model.loss = loss
self._trainer = AllReduceTrainer(master_client, "", model)

# Smoke test for the trainer's _training_process (self._trainer is the
# AllReduceTrainer built in setUp): call init_horovod_if_needed(), run one
# call on a three-sample batch, and check that something is returned.
# Only "not None" is asserted; the loss value itself is not checked.
def test_training_process(self):
self._trainer.init_horovod_if_needed()
# Three single-feature samples with one label each.
features = tf.constant([[0.5], [0.6], [0.7]])
labels = tf.constant([[1.0], [0.0], [1.0]])
loss = self._trainer._training_process(features, labels)
self.assertIsNotNone(loss)

def test_train_minibatch(self):
self._trainer.init_horovod_if_needed()
features = tf.constant([[0.5], [0.6], [0.7]])
Expand Down
16 changes: 1 addition & 15 deletions elasticdl/python/worker/allreduce_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,9 @@ def __init__(self, master_client, master_addr, model):
self._rendezvous_id = None
self._need_broadcast = True
self._var_created = False
self._world_size = None
self._set_optimizer(model.optimizer)
self._optimizer = model.optimizer
self._set_horovod_env()

# Store the given optimizer on the trainer and make its `lr` attribute point
# at self._get_learning_rate, so the rate is obtained through that method.
# NOTE(review): this helper appears to be among the lines this commit
# deletes (the diff markers were lost in this view) -- confirm against the
# actual diff.
def _set_optimizer(self, optimizer):
self._optimizer = optimizer
# Keep the configured learning rate; _get_learning_rate reads self._lr.
# This is read before `lr` is reassigned on the next line.
self._lr = optimizer._hyper["learning_rate"]
# Assign the bound method itself (not its result).
self._optimizer.lr = self._get_learning_rate

# Return the stored learning rate (self._lr) multiplied by self._world_size.
# While self._world_size is still None the multiplier is 1, so the rate is
# returned unscaled.
def _get_learning_rate(self):
scaler = 1 if self._world_size is None else self._world_size
lr = self._lr
# The stored rate may be a callable; call it to get the current value.
if callable(lr):
lr = lr()
return lr * scaler

@tf.function
def _training_process(self, features, labels):
with tf.GradientTape() as tape:
Expand Down Expand Up @@ -143,7 +130,6 @@ def init_horovod_if_needed(self):
hvd.shutdown()
hvd.init()
os.environ[HorovodEnv.ELASTIC] = str(1)
self._world_size = hvd.size()
self._rendezvous_id = rank_response.rendezvous_id
self._need_broadcast = True

Expand Down

0 comments on commit 0daf37b

Please sign in to comment.