Skip to content

Commit

Permalink
Fix learning rate scheduler for the CNN model of cifar10 (sql-machine…
Browse files Browse the repository at this point in the history
…-learning#2296)

* Fix learning rate scheduler for the cifar10 model

* Add an unittest for error
  • Loading branch information
workingloong authored Sep 16, 2020
1 parent a2b920c commit 83e8424
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
3 changes: 1 addition & 2 deletions elasticdl/python/elasticdl/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

from elasticdl.proto import elasticdl_pb2
from elasticdl.python.common.constants import Mode
Expand Down Expand Up @@ -151,4 +150,4 @@ def on_train_batch_begin(self, batch, logs=None):
raise ValueError(
'The output of the "schedule" function should be float.'
)
K.set_value(self.model.optimizer.lr, K.get_value(lr))
self.model.optimizer.lr = lr
14 changes: 14 additions & 0 deletions elasticdl/python/tests/callbacks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ class LearningRateSchedulerTest(unittest.TestCase):
def _schedule(self, model_version):
return 0.2 if model_version < 2 else 0.1

def test_raise_error(self):
def _schedule(model_version):
return 1 if model_version < 2 else 2

learning_rate_scheduler = LearningRateScheduler(_schedule)
model = tf.keras.Model()
learning_rate_scheduler.set_model(model)
with self.assertRaises(ValueError):
learning_rate_scheduler.on_train_batch_begin(batch=1)

model.optimizer = tf.optimizers.SGD(0.1)
with self.assertRaises(ValueError):
learning_rate_scheduler.on_train_batch_begin(batch=1)

def test_learning_rate_scheduler(self):
learning_rate_scheduler = LearningRateScheduler(self._schedule)
model = tf.keras.Model()
Expand Down
2 changes: 1 addition & 1 deletion model_zoo/cifar10/cifar10_functional_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _schedule(model_version):
else:
return 0.001

LearningRateScheduler(_schedule)
return [LearningRateScheduler(_schedule)]


def dataset_fn(dataset, mode, _):
Expand Down
2 changes: 1 addition & 1 deletion model_zoo/cifar10/cifar10_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _schedule(model_version):
else:
return 0.001

LearningRateScheduler(_schedule)
return [LearningRateScheduler(_schedule)]


def dataset_fn(dataset, mode, _):
Expand Down

0 comments on commit 83e8424

Please sign in to comment.