diff --git a/talos/__version__.py b/talos/__version__.py
index 7195647..5ba1e57 100644
--- a/talos/__version__.py
+++ b/talos/__version__.py
@@ -1,4 +1,4 @@
 __title__ = 'talos'
-__version__ = '1.6.0'
+__version__ = '1.6.1'
 __description__ = 'Powerful Neural Network Builder'
 __author__ = 'Jsaon'
diff --git a/talos/optimizers/tests/test_weight_decay.py b/talos/optimizers/tests/test_weight_decay.py
index 510044f..fbe6e4e 100644
--- a/talos/optimizers/tests/test_weight_decay.py
+++ b/talos/optimizers/tests/test_weight_decay.py
@@ -1,3 +1,5 @@
+import pytest
+
 import numpy as np
 import tensorflow as tf
 
@@ -14,7 +16,7 @@ def test_weight_decay(sess):
     x = tf.Variable(x_val)
     z = tf.Variable(z_val)
     y = tf.pow(x, 3)  # dy/dx = 3x^2
-    train_op = optimizer.minimize(y, var_list=[x])
+    train_op = optimizer.minimize(y, var_list=[x, z])
 
     sess.run(tf.variables_initializer([x, z]))
     sess.run(train_op)
@@ -23,3 +25,30 @@ def test_weight_decay(sess):
         x_val * (1. - decay_rate) - lr * 3 * (x_val ** 2),
     )
     np.testing.assert_almost_equal(sess.run(z), z_val)  # keep since it's not updated
+
+
+@pytest.mark.parametrize('var_filter', ['collection', 'callable'])
+def test_weight_decay_with_filter(var_filter, sess):
+    lr, decay_rate = 0.2, 0.1
+    x_val, z_val = 2., 1.
+    x = tf.Variable(x_val, name='x')
+    z = tf.Variable(z_val, name='z')
+
+    optimizer = WeightDecay(
+        tf.train.GradientDescentOptimizer(lr),
+        decay_rate=decay_rate,
+        variable_filter={x} if var_filter == 'collection' else lambda v: 'x' in v.name,
+    )
+    y = tf.pow(x, 3) + z  # dy/dx = 3x^2, dy/dz = 1
+    train_op = optimizer.minimize(y, var_list=[x, z])
+
+    sess.run(tf.variables_initializer([x, z]))
+    sess.run(train_op)
+    np.testing.assert_almost_equal(
+        sess.run(x),
+        x_val * (1. - decay_rate) - lr * 3 * (x_val ** 2),
+    )
+    np.testing.assert_almost_equal(
+        sess.run(z),
+        z_val - lr,
+    )  # doesn't decay since it's not in the filter
diff --git a/talos/optimizers/weight_decay.py b/talos/optimizers/weight_decay.py
index 7675a5f..4ad0fb3 100644
--- a/talos/optimizers/weight_decay.py
+++ b/talos/optimizers/weight_decay.py
@@ -1,3 +1,5 @@
+from typing import Callable, Container, Union
+
 import tensorflow as tf
 
 
@@ -11,14 +13,25 @@ def __init__(
         decay_rate: float,
         use_locking: bool = False,
         name: str = 'WeightDecay',
+        variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None,
     ):
         super().__init__(use_locking, name)
         self.optimizer = optimizer
         self.decay_rate = decay_rate
         self.decay_rate_tensor = tf.convert_to_tensor(decay_rate)
+        self.variable_filter = variable_filter
 
     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
-        var_list = [v for g, v in grads_and_vars if g is not None]
+        if self.variable_filter is None:
+            def need_decay(var):
+                return True
+        elif hasattr(self.variable_filter, '__contains__'):
+            def need_decay(var):
+                return var in self.variable_filter
+        else:
+            need_decay = self.variable_filter
+
+        var_list = [v for g, v in grads_and_vars if g is not None and need_decay(v)]
        decay_value = [
             tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
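Usage note (a reviewer sketch, not part of the diff): the new variable_filter argument accepts either a container of variables (anything supporting __contains__) or a predicate callable. Variables excluded by the filter still receive their plain gradient update; they are only exempt from the decay, so a filtered-in variable ends up at v * (1 - decay_rate) - lr * grad, exactly as asserted in test_weight_decay. A minimal sketch of both forms, assuming TensorFlow 1.x graph mode and that WeightDecay is importable from talos.optimizers (the import path is assumed, since the diff doesn't show the tests' imports):

    import tensorflow as tf

    from talos.optimizers import WeightDecay  # assumed import path

    w = tf.Variable(tf.random_normal([4, 2]), name='kernel')
    b = tf.Variable(tf.zeros([2]), name='bias')
    loss = tf.reduce_sum(tf.matmul(tf.ones([1, 4]), w) + b)

    # Callable form: decay only variables whose name contains 'kernel',
    # so the bias keeps its plain gradient-descent update.
    optimizer = WeightDecay(
        tf.train.GradientDescentOptimizer(0.1),
        decay_rate=0.01,
        variable_filter=lambda v: 'kernel' in v.name,
    )
    train_op = optimizer.minimize(loss, var_list=[w, b])

    # Container form: an explicit set of variables behaves the same way,
    # since the wrapper only probes the filter for __contains__.
    optimizer_from_set = WeightDecay(
        tf.train.GradientDescentOptimizer(0.1),
        decay_rate=0.01,
        variable_filter={w},
    )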