diff --git a/talos/__init__.py b/talos/__init__.py
index 0fb5b79..ec2e730 100644
--- a/talos/__init__.py
+++ b/talos/__init__.py
@@ -9,4 +9,5 @@
 import talos.layers
 import talos.networks
 import talos.ops
+import talos.optimizers
 import talos.recurrent
diff --git a/talos/__version__.py b/talos/__version__.py
index b6773aa..7195647 100644
--- a/talos/__version__.py
+++ b/talos/__version__.py
@@ -1,4 +1,4 @@
 __title__ = 'talos'
-__version__ = '1.5.1'
+__version__ = '1.6.0'
 __description__ = 'Powerful Neural Network Builder'
 __author__ = 'Jsaon'
diff --git a/talos/optimizers/__init__.py b/talos/optimizers/__init__.py
index 7a5016c..cdc718a 100644
--- a/talos/optimizers/__init__.py
+++ b/talos/optimizers/__init__.py
@@ -1 +1,4 @@
+from .gradient_clipping import GradientClipping
+from .look_ahead import LookAhead
 from .radam import RAdamOptimizer
+from .weight_decay import WeightDecay
diff --git a/talos/optimizers/gradient_clipping.py b/talos/optimizers/gradient_clipping.py
new file mode 100644
index 0000000..9f3e330
--- /dev/null
+++ b/talos/optimizers/gradient_clipping.py
@@ -0,0 +1,45 @@
+import tensorflow as tf
+
+
+class GradientClipping(tf.train.Optimizer):
+
+    _ALLOWED_CLIP_BY = {'value', 'norm'}
+
+    def __init__(
+            self,
+            optimizer,
+            value: float,
+            clip_by: str = 'value',
+            use_locking: bool = False,
+            name: str = 'GradientClipping',
+    ):
+        super().__init__(use_locking, name)
+        self.optimizer = optimizer
+        if clip_by not in self._ALLOWED_CLIP_BY:
+            raise ValueError(f"`clip_by` should be in {self._ALLOWED_CLIP_BY}! Found {clip_by}")
+        if value <= 0.:
+            raise ValueError("`value` should > 0.!")
+
+        self.value = value
+        self.value_tensor = tf.convert_to_tensor(value)
+        self.clip_by = clip_by
+
+    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+        processed_gvs = [
+            (self._process_grad(g), v) for g, v in grads_and_vars
+            if g is not None
+        ]
+        return self.optimizer.apply_gradients(
+            processed_gvs,
+            global_step=global_step,
+            name=name,
+        )
+
+    def _process_grad(self, grad):
+        value = tf.cast(self.value_tensor, grad.dtype.base_dtype)
+        if self.clip_by == 'value':
+            return tf.clip_by_value(grad, -value, value)
+        elif self.clip_by == 'norm':
+            return tf.clip_by_norm(grad, value)
+        else:
+            raise AssertionError("Invalid `clip_by` should be raised in `__init__`!")
diff --git a/talos/optimizers/look_ahead.py b/talos/optimizers/look_ahead.py
index 22fa66c..4b32223 100644
--- a/talos/optimizers/look_ahead.py
+++ b/talos/optimizers/look_ahead.py
@@ -10,7 +10,10 @@ def __init__(
         optimizer: tf.train.Optimizer,
         alpha: float = 0.5,
         explore_steps: int = 5,
+        use_locking: bool = False,
+        name: str = 'LookAhead',
     ):
+        super().__init__(use_locking, name)
         self.optimizer = optimizer
         self.alpha = alpha
         self.explore_steps = explore_steps
@@ -43,6 +46,9 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
     def _slow_fast_updates(self, var_list):
         with tf.control_dependencies([self.ema.apply(var_list)]):  # update slow
             return tf.group(*[
-                var.assign(self.ema.average(var))  # synchronize fast by slow
+                var.assign(
+                    self.ema.average(var),
+                    use_locking=self._use_locking,
+                )  # synchronize fast by slow
                 for var in var_list
             ])
diff --git a/talos/optimizers/radam.py b/talos/optimizers/radam.py
index c68f12f..52ecc64 100644
--- a/talos/optimizers/radam.py
+++ b/talos/optimizers/radam.py
@@ -7,6 +7,7 @@ class RAdamOptimizer(tf.train.AdamOptimizer):
 
+    '''Reference: https://arxiv.org/abs/1908.03265v1'''
 
     # Add-On: create steps variable.
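
Below is a minimal usage sketch, not part of the diff, showing how the wrapper optimizers introduced above can be composed under TF 1.x. The toy variable, loss, learning rate, and wrapping order are illustrative assumptions, not prescribed by this patch.

    import tensorflow as tf

    from talos.optimizers import GradientClipping, LookAhead

    # Hypothetical toy problem: a single scalar variable and a quadratic loss.
    x = tf.Variable(3.0)
    loss = tf.square(x)

    # Clip each gradient tensor to norm <= 1.0, then wrap the result with LookAhead,
    # which synchronizes the fast weights with the slow (EMA) weights every explore_steps steps.
    optimizer = LookAhead(
        GradientClipping(tf.train.GradientDescentOptimizer(0.1), value=1.0, clip_by='norm'),
        alpha=0.5,
        explore_steps=5,
    )
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(10):
            sess.run(train_op)
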
diff --git a/talos/optimizers/tests/test_gradient_clipping.py b/talos/optimizers/tests/test_gradient_clipping.py
new file mode 100644
index 0000000..aa97041
--- /dev/null
+++ b/talos/optimizers/tests/test_gradient_clipping.py
@@ -0,0 +1,44 @@
+import numpy as np
+import tensorflow as tf
+
+from ..gradient_clipping import GradientClipping
+
+
+def test_clip_value(sess):
+    lr, value = 0.2, 0.1
+    x_val = 1.
+    optimizer = GradientClipping(
+        tf.train.GradientDescentOptimizer(lr),
+        value,
+        clip_by='value',
+    )
+    x = tf.Variable(x_val)
+    y = 0.5 * x  # dy/dx = 0.5
+
+    train_op = optimizer.minimize(y)
+    sess.run(tf.variables_initializer([x]))
+    sess.run(train_op)
+    np.testing.assert_almost_equal(
+        sess.run(x),
+        x_val - lr * np.minimum(value, 0.5),
+    )
+
+
+def test_clip_norm(sess):
+    lr, value = 0.2, 0.5
+    x_val = np.array([3., 4.])
+    optimizer = GradientClipping(
+        tf.train.GradientDescentOptimizer(lr),
+        value,
+        clip_by='norm',
+    )
+    x = tf.Variable(x_val)
+    y = tf.nn.l2_loss(x)  # dy/dx = x
+
+    train_op = optimizer.minimize(y)
+    sess.run(tf.variables_initializer([x]))
+    sess.run(train_op)
+    np.testing.assert_array_almost_equal(
+        sess.run(x),
+        x_val - lr * x_val * np.minimum(value / np.linalg.norm(x_val), 1.),
+    )
diff --git a/talos/optimizers/tests/test_weight_decay.py b/talos/optimizers/tests/test_weight_decay.py
new file mode 100644
index 0000000..510044f
--- /dev/null
+++ b/talos/optimizers/tests/test_weight_decay.py
@@ -0,0 +1,25 @@
+import numpy as np
+import tensorflow as tf
+
+from ..weight_decay import WeightDecay
+
+
+def test_weight_decay(sess):
+    lr, decay_rate = 0.2, 0.1
+    x_val, z_val = 2., 1.
+    optimizer = WeightDecay(
+        tf.train.GradientDescentOptimizer(lr),
+        decay_rate=decay_rate,
+    )
+    x = tf.Variable(x_val)
+    z = tf.Variable(z_val)
+    y = tf.pow(x, 3)  # dy/dx = 3x^2
+    train_op = optimizer.minimize(y, var_list=[x])
+
+    sess.run(tf.variables_initializer([x, z]))
+    sess.run(train_op)
+    np.testing.assert_almost_equal(
+        sess.run(x),
+        x_val * (1. - decay_rate) - lr * 3 * (x_val ** 2),
+    )
+    np.testing.assert_almost_equal(sess.run(z), z_val)  # keep since it's not updated
diff --git a/talos/optimizers/weight_decay.py b/talos/optimizers/weight_decay.py
new file mode 100644
index 0000000..7675a5f
--- /dev/null
+++ b/talos/optimizers/weight_decay.py
@@ -0,0 +1,42 @@
+import tensorflow as tf
+
+
+class WeightDecay(tf.train.Optimizer):
+
+    '''Reference: https://arxiv.org/pdf/1711.05101.pdf'''
+
+    def __init__(
+            self,
+            optimizer,
+            decay_rate: float,
+            use_locking: bool = False,
+            name: str = 'WeightDecay',
+    ):
+        super().__init__(use_locking, name)
+        self.optimizer = optimizer
+        self.decay_rate = decay_rate
+        self.decay_rate_tensor = tf.convert_to_tensor(decay_rate)
+
+    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+        var_list = [v for g, v in grads_and_vars if g is not None]
+
+        decay_value = [
+            tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
+            for v in var_list
+        ]
+        with tf.control_dependencies(decay_value):  # cache the value before descent
+            grad_descent_op = self.optimizer.apply_gradients(
+                grads_and_vars,
+                global_step=global_step,
+            )
+
+        with tf.control_dependencies([grad_descent_op]):  # guarantee compute before decay.
+            decay_op = tf.group(
+                *[
+                    v.assign_sub(d_v, use_locking=self._use_locking)
+                    for v, d_v in zip(var_list, decay_value)
+                ],
+                name=name,
+            )
+
+        return decay_op
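
Similarly, a minimal sketch, not part of the diff, of the decoupled weight decay wrapper; the weight vector, loss, and hyperparameters below are hypothetical and mirror the behaviour asserted in test_weight_decay.py.

    import tensorflow as tf

    from talos.optimizers import WeightDecay

    # Hypothetical weight vector and loss for illustration.
    w = tf.Variable([2.0, -1.0])
    loss = tf.nn.l2_loss(w)

    # Decoupled weight decay (arXiv:1711.05101): the wrapped optimizer performs the
    # usual gradient step, then each variable is additionally shrunk by decay_rate * w,
    # where the decay term is computed from the pre-update value of w.
    optimizer = WeightDecay(tf.train.GradientDescentOptimizer(0.2), decay_rate=0.1)
    train_op = optimizer.minimize(loss, var_list=[w])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        print(sess.run(w))  # each entry: w * (1 - decay_rate) - lr * grad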