From 0637208f21df656d48022df588a6348b162cc702 Mon Sep 17 00:00:00 2001
From: noobOriented
Date: Wed, 16 Oct 2019 11:48:49 +0800
Subject: [PATCH 1/5] extract weight decay value to function

---
 talos/optimizers/weight_decay.py | 33 ++++++++++++++++++---------------
 1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/talos/optimizers/weight_decay.py b/talos/optimizers/weight_decay.py
index 4ad0fb3..4f16490 100644
--- a/talos/optimizers/weight_decay.py
+++ b/talos/optimizers/weight_decay.py
@@ -22,21 +22,7 @@ def __init__(
         self.variable_filter = variable_filter
 
     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
-        if self.variable_filter is None:
-            def need_decay(var):
-                return True
-        elif hasattr(self.variable_filter, '__contains__'):
-            def need_decay(var):
-                return var in self.variable_filter
-        else:
-            need_decay = self.variable_filter
-
-        var_list = [v for g, v in grads_and_vars if g is not None and need_decay(v)]
-
-        decay_value = [
-            tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
-            for v in var_list
-        ]
+        var_list, decay_value = self._get_decay_pairs(grads_and_vars)
         with tf.control_dependencies(decay_value):  # cache the value before descent
             grad_descent_op = self.optimizer.apply_gradients(
                 grads_and_vars,
@@ -53,3 +39,20 @@ def need_decay(var):
         )
 
         return decay_op
+
+    def _get_decay_pairs(self, grads_and_vars):
+        if self.variable_filter is None:
+            def need_decay(var):
+                return True
+        elif hasattr(self.variable_filter, '__contains__'):
+            def need_decay(var):
+                return var in self.variable_filter
+        else:
+            need_decay = self.variable_filter
+
+        var_list = [v for g, v in grads_and_vars if g is not None and need_decay(v)]
+        decay_value = [
+            tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
+            for v in var_list
+        ]
+        return var_list, decay_value

From ae0925c43cdcc0d97c37ffac2c8025b4f4f0b8c4 Mon Sep 17 00:00:00 2001
From: noobOriented
Date: Wed, 16 Oct 2019 12:07:19 +0800
Subject: [PATCH 2/5] impl sparse_update

---
 talos/optimizers/weight_decay.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/talos/optimizers/weight_decay.py b/talos/optimizers/weight_decay.py
index 4f16490..7364c9e 100644
--- a/talos/optimizers/weight_decay.py
+++ b/talos/optimizers/weight_decay.py
@@ -14,12 +14,14 @@ def __init__(
             use_locking: bool = False,
             name: str = 'WeightDecay',
             variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None,
+            sparse_update: bool = True,
     ):
         super().__init__(use_locking, name)
         self.optimizer = optimizer
         self.decay_rate = decay_rate
         self.decay_rate_tensor = tf.convert_to_tensor(decay_rate)
         self.variable_filter = variable_filter
+        self.sparse_update = sparse_update
 
     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
         var_list, decay_value = self._get_decay_pairs(grads_and_vars)
@@ -50,9 +52,20 @@ def need_decay(var):
         else:
             need_decay = self.variable_filter
 
-        var_list = [v for g, v in grads_and_vars if g is not None and need_decay(v)]
-        decay_value = [
-            tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
-            for v in var_list
-        ]
-        return var_list, decay_value
+        var_list, decay_list = [], []
+        for g, v in grads_and_vars:
+            if g is None or not need_decay(v):
+                continue
+            var_list.append(v)
+            if self.sparse_update and isinstance(g, tf.IndexedSlices):
+                decay_value = tf.IndexedSlices(
+                    values=tf.gather(v, g.indices),
+                    indices=g.indices,
+                    dense_shape=g.dense_shape,
+                )
+            else:
+                decay_value = v
+            rate = tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype)
+            decay_list.append(tf.math.scalar_mul(rate, decay_value))
+
+        return var_list, decay_list
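
Note on the sparse branch above: it only ever touches the rows that appear in a gradient's indices. A rough standalone sketch of that idea, not part of the patch (TF 1.x graph mode; E, lookup_ids, decay and decay_op are made-up names for illustration, and tf.scatter_sub merely stands in for whatever update the wrapped optimizer ultimately applies):

import numpy as np
import tensorflow as tf

decay_rate = 0.1
E = tf.Variable(np.arange(12.0).reshape(4, 3), dtype=tf.float32)  # toy 4-row embedding table
lookup_ids = tf.constant([0, 2, 2])  # row 2 is gathered twice, rows 1 and 3 not at all

# Mirror _get_decay_pairs: build the decay term only for the gathered rows.
decay = tf.IndexedSlices(
    values=decay_rate * tf.gather(E, lookup_ids),
    indices=lookup_ids,
    dense_shape=tf.shape(E, out_type=tf.int64),
)
decay_op = tf.scatter_sub(E, decay.indices, decay.values)  # duplicate indices subtract twice

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(decay_op))  # row 0 scaled by 0.9, row 2 by 0.8, rows 1 and 3 untouched

In the patch itself the decay terms are simply grouped with the inner optimizer's apply_gradients through control dependencies, so this sketch only illustrates which rows participate in the decay.
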
From 145906105a8f7a20dbf70468dbe996a05a1d2a58 Mon Sep 17 00:00:00 2001
From: noobOriented
Date: Wed, 16 Oct 2019 13:37:29 +0800
Subject: [PATCH 3/5] test sparse weight decay

---
 talos/optimizers/tests/test_weight_decay.py | 28 ++++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/talos/optimizers/tests/test_weight_decay.py b/talos/optimizers/tests/test_weight_decay.py
index fbe6e4e..fccc21e 100644
--- a/talos/optimizers/tests/test_weight_decay.py
+++ b/talos/optimizers/tests/test_weight_decay.py
@@ -52,3 +52,31 @@ def test_weight_decay_with_filter(var_filter, sess):
         sess.run(z),
         z_val - lr,
     )  # doesn't decay since it's not in filter
+
+
+def test_sparse_weight_decay(sess):
+    lr, decay_rate = 0.2, 0.1
+    E_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    x = tf.constant([[0, 1, 1]])
+    E = tf.Variable(E_val, dtype=tf.float32, name='E')
+
+    optimizer = WeightDecay(
+        tf.train.GradientDescentOptimizer(lr),
+        decay_rate=decay_rate,
+        sparse_update=True,
+    )
+    e = tf.nn.embedding_lookup(E, x)
+    y = tf.pow(e, 3)  # dy/de = 3e^2
+    train_op = optimizer.minimize(y, var_list=[E])
+
+    sess.run(E.initializer)
+    sess.run(train_op)
+    np.testing.assert_array_almost_equal(
+        sess.run(E),
+        [
+            E_val[0] * (1 - decay_rate) - lr * (3 * E_val[0] ** 2),  # occurrence 1
+            E_val[1] * (1 - 2 * decay_rate) - 2 * lr * (3 * E_val[1] ** 2),  # occurrence 2
+            E_val[2],
+        ],
+        decimal=4,
+    )
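
For concreteness, the rows asserted in the sparse test above evaluate as follows; this is just the test's own formulas re-computed with its constants in plain NumPy, nothing beyond what the patch states:

import numpy as np

lr, decay_rate = 0.2, 0.1
E_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)

# Row 0 appears once in x: decayed once and descended once.
row0 = E_val[0] * (1 - decay_rate) - lr * (3 * E_val[0] ** 2)          # approx. [0.3, -0.6, -2.7]
# Row 1 appears twice: the cached decay and the gradient are both applied twice.
row1 = E_val[1] * (1 - 2 * decay_rate) - 2 * lr * (3 * E_val[1] ** 2)  # approx. [-16.0, -26.0, -38.4]
# Row 2 never appears, so with sparse_update=True it stays [7, 8, 9].
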
From 675a39b77e591d864d1d57a33bfa3162dbe6517b Mon Sep 17 00:00:00 2001
From: noobOriented
Date: Wed, 16 Oct 2019 13:39:10 +0800
Subject: [PATCH 4/5] version 1.6.2

---
 talos/__version__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/talos/__version__.py b/talos/__version__.py
index 5ba1e57..c3e029b 100644
--- a/talos/__version__.py
+++ b/talos/__version__.py
@@ -1,4 +1,4 @@
 __title__ = 'talos'
-__version__ = '1.6.1'
+__version__ = '1.6.2'
 __description__ = 'Powerful Neural Network Builder'
 __author__ = 'Jsaon'

From 7b9d73b30ded893c5c60d1e41e9da7c13fbd1446 Mon Sep 17 00:00:00 2001
From: noobOriented
Date: Wed, 16 Oct 2019 13:43:20 +0800
Subject: [PATCH 5/5] add test case for sparse_update=False

---
 talos/optimizers/tests/test_weight_decay.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/talos/optimizers/tests/test_weight_decay.py b/talos/optimizers/tests/test_weight_decay.py
index fccc21e..2889c5e 100644
--- a/talos/optimizers/tests/test_weight_decay.py
+++ b/talos/optimizers/tests/test_weight_decay.py
@@ -54,7 +54,8 @@ def test_weight_decay_with_filter(var_filter, sess):
     )  # doesn't decay since it's not in filter
 
 
-def test_sparse_weight_decay(sess):
+@pytest.mark.parametrize('sparse_update', [True, False])
+def test_sparse_weight_decay(sparse_update, sess):
     lr, decay_rate = 0.2, 0.1
     E_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
     x = tf.constant([[0, 1, 1]])
     E = tf.Variable(E_val, dtype=tf.float32, name='E')
@@ -63,7 +64,7 @@ def test_sparse_weight_decay(sess):
     optimizer = WeightDecay(
         tf.train.GradientDescentOptimizer(lr),
         decay_rate=decay_rate,
-        sparse_update=True,
+        sparse_update=sparse_update,
     )
     e = tf.nn.embedding_lookup(E, x)
     y = tf.pow(e, 3)  # dy/de = 3e^2
     train_op = optimizer.minimize(y, var_list=[E])
@@ -71,12 +72,17 @@ def test_sparse_weight_decay(sess):
 
     sess.run(E.initializer)
     sess.run(train_op)
-    np.testing.assert_array_almost_equal(
-        sess.run(E),
-        [
+    if sparse_update:
+        expected_E_val = [
             E_val[0] * (1 - decay_rate) - lr * (3 * E_val[0] ** 2),  # occurrence 1
             E_val[1] * (1 - 2 * decay_rate) - 2 * lr * (3 * E_val[1] ** 2),  # occurrence 2
             E_val[2],
-        ],
-        decimal=4,
-    )
+        ]
+    else:
+        expected_E_val = [
+            E_val[0] * (1 - decay_rate) - lr * (3 * E_val[0] ** 2),  # occurrence 1
+            E_val[1] * (1 - decay_rate) - 2 * lr * (3 * E_val[1] ** 2),  # occurrence 2
+            E_val[2] * (1 - decay_rate),
+        ]
+
+    np.testing.assert_array_almost_equal(sess.run(E), expected_E_val, decimal=4)
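
Taken together, the series might be used roughly like this. A minimal end-to-end sketch in the spirit of the tests, not something the patches themselves contain (TF 1.x graph mode; the import path from talos.optimizers and the AdamOptimizer choice are assumptions):

import numpy as np
import tensorflow as tf
from talos.optimizers import WeightDecay  # import path assumed from talos/optimizers/weight_decay.py

E = tf.Variable(np.random.rand(1000, 32), dtype=tf.float32, name='embeddings')
ids = tf.placeholder(tf.int32, shape=[None])
loss = tf.reduce_sum(tf.nn.embedding_lookup(E, ids) ** 2)

optimizer = WeightDecay(
    tf.train.AdamOptimizer(1e-3),
    decay_rate=1e-4,
    sparse_update=True,  # decay only the rows that appear in ids
)
train_op = optimizer.minimize(loss, var_list=[E])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict={ids: [1, 2, 2, 7]})

With sparse_update=True the decay work per step scales with the rows actually looked up rather than with the full table, which is the point of patch 2; sparse_update=False falls back to decaying the whole variable, matching the second test case.
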