Merge pull request #114 from Yoctol/more-optimizer-wrapper
More optimizer wrapper
Showing 9 changed files with 169 additions and 2 deletions.
@@ -9,4 +9,5 @@
 import talos.layers
 import talos.networks
 import talos.ops
+import talos.optimizers
 import talos.recurrent
@@ -1,4 +1,4 @@
 __title__ = 'talos'
-__version__ = '1.5.1'
+__version__ = '1.6.0'
 __description__ = 'Powerful Neural Network Builder'
 __author__ = 'Jsaon'
@@ -1 +1,4 @@
from .gradient_clipping import GradientClipping
from .look_ahead import LookAhead
from .radam import RAdamOptimizer
from .weight_decay import WeightDecay
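As a rough sketch of how the new exports could be combined (assuming the wrappers stack like ordinary `tf.train.Optimizer` instances; the stacking order and hyperparameters below are illustrative, not taken from this commit):

import tensorflow as tf

from talos.optimizers import GradientClipping, WeightDecay

# Illustrative only: gradients are clipped first, then the wrapped optimizer
# applies them and the weights are decayed afterwards.
base = tf.train.GradientDescentOptimizer(learning_rate=0.1)
optimizer = GradientClipping(
    WeightDecay(base, decay_rate=0.01),
    value=1.0,
    clip_by='norm',
)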
@@ -0,0 +1,45 @@
import tensorflow as tf


class GradientClipping(tf.train.Optimizer):

    _ALLOWED_CLIP_BY = {'value', 'norm'}

    def __init__(
        self,
        optimizer,
        value: float,
        clip_by: str = 'value',
        use_locking: bool = False,
        name: str = 'GradientClipping',
    ):
        super().__init__(use_locking, name)
        self.optimizer = optimizer
        if clip_by not in self._ALLOWED_CLIP_BY:
            raise ValueError(f"`clip_by` should be in {self._ALLOWED_CLIP_BY}! Found {clip_by}")
        if value <= 0.:
            raise ValueError("`value` should > 0.!")

        self.value = value
        self.value_tensor = tf.convert_to_tensor(value)
        self.clip_by = clip_by

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        processed_gvs = [
            (self._process_grad(g), v) for g, v in grads_and_vars
            if g is not None
        ]
        return self.optimizer.apply_gradients(
            processed_gvs,
            global_step=global_step,
            name=name,
        )

    def _process_grad(self, grad):
        value = tf.cast(self.value_tensor, grad.dtype.base_dtype)
        if self.clip_by == 'value':
            return tf.clip_by_value(grad, -value, value)
        elif self.clip_by == 'norm':
            return tf.clip_by_norm(grad, value)
        else:
            raise AssertionError("Invalid `clip_by` should be raised in `__init__`!")
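A minimal standalone usage sketch for the new wrapper (hypothetical variable and loss names; assumes the TensorFlow 1.x `tf.train` API used throughout this commit):

import tensorflow as tf

from talos.optimizers import GradientClipping

x = tf.Variable(3.0)
loss = tf.reduce_sum(tf.square(x))

# Gradients are clipped element-wise to [-1, 1] before the wrapped
# AdamOptimizer applies them.
optimizer = GradientClipping(tf.train.AdamOptimizer(1e-3), value=1.0, clip_by='value')
train_op = optimizer.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)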
@@ -0,0 +1,44 @@
import numpy as np
import tensorflow as tf

from ..gradient_clipping import GradientClipping


def test_clip_value(sess):
    lr, value = 0.2, 0.1
    x_val = 1.
    optimizer = GradientClipping(
        tf.train.GradientDescentOptimizer(lr),
        value,
        clip_by='value',
    )
    x = tf.Variable(x_val)
    y = 0.5 * x  # dy/dx = 0.5

    train_op = optimizer.minimize(y)
    sess.run(tf.variables_initializer([x]))
    sess.run(train_op)
    np.testing.assert_almost_equal(
        sess.run(x),
        x_val - lr * np.minimum(value, 0.5),
    )


def test_clip_norm(sess):
    lr, value = 0.2, 0.5
    x_val = np.array([3., 4.])
    optimizer = GradientClipping(
        tf.train.GradientDescentOptimizer(lr),
        value,
        clip_by='norm',
    )
    x = tf.Variable(x_val)
    y = tf.nn.l2_loss(x)  # dy/dx = x

    train_op = optimizer.minimize(y)
    sess.run(tf.variables_initializer([x]))
    sess.run(train_op)
    np.testing.assert_array_almost_equal(
        sess.run(x),
        x_val - lr * x_val * np.minimum(value / np.linalg.norm(x_val), 1.),
    )
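These tests rely on a `sess` fixture that is not part of this diff; a plausible minimal version (an assumption, not the project's actual conftest) would look like:

# Hypothetical conftest.py; the real fixture is not shown in this commit.
import pytest
import tensorflow as tf

@pytest.fixture
def sess():
    with tf.Graph().as_default(), tf.Session() as session:
        yield session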
@@ -0,0 +1,25 @@
import numpy as np
import tensorflow as tf

from ..weight_decay import WeightDecay


def test_weight_decay(sess):
    lr, decay_rate = 0.2, 0.1
    x_val, z_val = 2., 1.
    optimizer = WeightDecay(
        tf.train.GradientDescentOptimizer(lr),
        decay_rate=decay_rate,
    )
    x = tf.Variable(x_val)
    z = tf.Variable(z_val)
    y = tf.pow(x, 3)  # dy/dx = 3x^2
    train_op = optimizer.minimize(y, var_list=[x])

    sess.run(tf.variables_initializer([x, z]))
    sess.run(train_op)
    np.testing.assert_almost_equal(
        sess.run(x),
        x_val * (1. - decay_rate) - lr * 3 * (x_val ** 2),
    )
    np.testing.assert_almost_equal(sess.run(z), z_val)  # keep since it's not updated
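Working through the assertion: the decay term is computed from the initial weight (0.1 * 2.0 = 0.2) and cached before the descent step, the gradient step moves x from 2.0 to 2.0 - 0.2 * 3 * 2.0^2 = -0.4, and the decoupled decay then subtracts the cached 0.2, giving -0.6, which matches x_val * (1. - decay_rate) - lr * 3 * (x_val ** 2).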
@@ -0,0 +1,42 @@
import tensorflow as tf


class WeightDecay(tf.train.Optimizer):

    '''Reference: https://arxiv.org/pdf/1711.05101.pdf'''

    def __init__(
        self,
        optimizer,
        decay_rate: float,
        use_locking: bool = False,
        name: str = 'WeightDecay',
    ):
        super().__init__(use_locking, name)
        self.optimizer = optimizer
        self.decay_rate = decay_rate
        self.decay_rate_tensor = tf.convert_to_tensor(decay_rate)

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        var_list = [v for g, v in grads_and_vars if g is not None]

        decay_value = [
            tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
            for v in var_list
        ]
        with tf.control_dependencies(decay_value):  # cache the value before descent
            grad_descent_op = self.optimizer.apply_gradients(
                grads_and_vars,
                global_step=global_step,
            )

        with tf.control_dependencies([grad_descent_op]):  # guarantee compute before decay.
            decay_op = tf.group(
                *[
                    v.assign_sub(d_v, use_locking=self._use_locking)
                    for v, d_v in zip(var_list, decay_value)
                ],
                name=name,
            )

        return decay_op
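Since the docstring points to the decoupled weight decay paper (https://arxiv.org/pdf/1711.05101.pdf), a typical use would be wrapping an adaptive optimizer. A hedged sketch, with hypothetical names and hyperparameters rather than anything prescribed by this commit:

import tensorflow as tf

from talos.optimizers import WeightDecay

weights = tf.Variable(tf.ones([10, 10]))
loss = tf.reduce_sum(tf.square(weights))

# The decay is applied directly to the variables after the Adam update,
# instead of being folded into the gradients (AdamW-style decoupling).
optimizer = WeightDecay(tf.train.AdamOptimizer(1e-3), decay_rate=0.01)
train_op = optimizer.minimize(loss)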