Merge pull request #113 from Yoctol/lookahead
impl LookAhead wrapper
noobOriented authored Aug 28, 2019
2 parents 76107bf + e0e0df3 commit 57d8d13
Showing 4 changed files with 94 additions and 3 deletions.
48 changes: 48 additions & 0 deletions talos/optimizers/look_ahead.py
@@ -0,0 +1,48 @@
import tensorflow as tf


class LookAhead(tf.train.Optimizer):

    '''Reference: https://arxiv.org/abs/1907.08610'''

    def __init__(
        self,
        optimizer: tf.train.Optimizer,
        alpha: float = 0.5,
        explore_steps: int = 5,
    ):
        self.optimizer = optimizer
        self.alpha = alpha
        self.explore_steps = explore_steps
        self.ema = tf.train.ExponentialMovingAverage(
            decay=1. - alpha,
            name="LookAheadSlowVariables",
        )

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        if global_step is None:
            global_step = tf.train.get_or_create_global_step()  # initialized to 0

        # global_step is incremented by this inner apply_gradients call
        update_op = self.optimizer.apply_gradients(grads_and_vars, global_step=global_step)
        var_list = [v for g, v in grads_and_vars if g is not None]

        with tf.control_dependencies([update_op]):
            finish_op = tf.cond(
                tf.equal(
                    tf.mod(global_step, self.explore_steps),
                    0,
                ),
                lambda: self._slow_fast_updates(var_list),
                tf.no_op,
                name=name,
            )

        return finish_op

    def _slow_fast_updates(self, var_list):
        with tf.control_dependencies([self.ema.apply(var_list)]):  # update the slow weights
            return tf.group(*[
                var.assign(self.ema.average(var))  # sync the fast weights to the slow weights
                for var in var_list
            ])
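
For orientation, a minimal usage sketch of the wrapper above (not part of this commit): the toy variable, loss, hyperparameters, and the import path inferred from the file location are assumptions.

import tensorflow as tf

from talos.optimizers.look_ahead import LookAhead

# Toy problem: minimize x^2 starting from x = 3.
x = tf.Variable(3.0)
loss = tf.square(x)

# Wrap any tf.train.Optimizer; every `explore_steps` updates, the fast
# weights are pulled back toward the EMA-tracked slow weights.
opt = LookAhead(
    tf.train.GradientDescentOptimizer(0.1),
    alpha=0.5,
    explore_steps=5,
)
train_op = opt.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # also initializes the EMA slow copies
    for _ in range(20):
        sess.run(train_op)
    print(sess.run(x))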
1 change: 1 addition & 0 deletions talos/optimizers/radam.py
@@ -7,6 +7,7 @@


class RAdamOptimizer(tf.train.AdamOptimizer):
    '''Reference: https://arxiv.org/abs/1908.03265v1'''

    # Add-On: create steps variable.
    def _create_slots(self, var_list):
41 changes: 41 additions & 0 deletions talos/optimizers/tests/test_look_ahead.py
@@ -0,0 +1,41 @@
import numpy as np
import tensorflow as tf

from ..look_ahead import LookAhead


def test_look_ahead(sess):
    alpha, lr = 0.2, 0.1
    explore_steps = 5
    slow_val, grad_val = 1., 2.
    opt = LookAhead(
        tf.train.GradientDescentOptimizer(lr),
        alpha=alpha,
        explore_steps=explore_steps,
    )
    with tf.variable_scope('test_look_ahead'):
        x = tf.Variable(slow_val)
        update_x = opt.minimize(grad_val * x)  # constant grad

    sess.run(tf.variables_initializer(
        tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='test_look_ahead'),
    ))

    for _ in range(5):
        fast_val = slow_val
        for _ in range(explore_steps - 1):
            sess.run(update_x)
            fast_val -= lr * grad_val

        np.testing.assert_almost_equal(sess.run(x), fast_val)

        sess.run(update_x)
        fast_val -= lr * grad_val

        # step % explore_steps == 0, fast interpolates with slow
        x_val = sess.run(x)
        np.testing.assert_almost_equal(
            x_val,
            slow_val * (1 - alpha) + fast_val * alpha,
        )
        slow_val = x_val
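
To make the numbers the test asserts explicit, here is the first explore cycle worked out by hand (an editor's sanity check, using the same alpha=0.2, lr=0.1, grad=2, explore_steps=5 as the test):

# First cycle of test_look_ahead, computed by hand.
alpha, lr, grad = 0.2, 0.1, 2.0
slow = 1.0

# explore_steps SGD steps on the fast weights; each step subtracts lr * grad = 0.2
fast = slow - 5 * lr * grad                     # 1.0 - 1.0 = 0.0

# When global_step % explore_steps == 0, the slow weights move by an EMA
# with decay 1 - alpha, and the fast weights are reset to them.
new_slow = (1 - alpha) * slow + alpha * fast    # 0.8 * 1.0 + 0.2 * 0.0 = 0.8
print(new_slow)                                 # 0.8 == sess.run(x) after the 5th update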
7 changes: 4 additions & 3 deletions talos/optimizers/tests/test_radam.py
@@ -5,7 +5,8 @@


 def test_radam(sess):
-    radam_opt = RAdamOptimizer(0.1)
+    lr = 0.1
+    radam_opt = RAdamOptimizer(lr)
     with tf.variable_scope('test_radam'):
         x = tf.Variable(1.)
         update_x = radam_opt.minimize(2 * x)  # constant grad 2
@@ -17,12 +18,12 @@ def test_radam(sess):
         sess.run(update_x)
 
     x_val = sess.run(x)
-    np.testing.assert_almost_equal(x_val, 1. - 4 * 0.1 * 2)  # without adaptive gradient
+    np.testing.assert_almost_equal(x_val, 1. - 4 * lr * 2)  # without adaptive gradient
 
     # N_sma > 4 now
     rectifier, _ = sess.run([radam_opt.rectifier, update_x])
     new_x_val = sess.run(x)
     np.testing.assert_almost_equal(
         new_x_val,
-        x_val - 0.1 * rectifier * 2 / 2,  # with adaptive gradient: divide by sqrt(v)
+        x_val - lr * rectifier * 2 / 2,  # with adaptive gradient: divide by sqrt(v)
     )
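
For context on why the test drives exactly four plain updates before expecting the rectified, adaptive one: with the default beta_2 = 0.999, the rectification length rho_t from the RAdam paper (https://arxiv.org/abs/1908.03265) only exceeds 4 at t = 5. The sketch below uses the paper's formulas; radam_opt.rectifier is presumably the same quantity, but the repo's internals are not traced here.

# rho_t and the rectifier r_t as defined in the RAdam paper, for beta2 = 0.999.
beta2 = 0.999
rho_inf = 2. / (1. - beta2) - 1.                # 1999.0

def rho(t):
    return rho_inf - 2. * t * beta2 ** t / (1. - beta2 ** t)

def rectifier(t):
    # only defined for rho(t) > 4
    r = rho(t)
    return ((r - 4.) * (r - 2.) * rho_inf
            / ((rho_inf - 4.) * (rho_inf - 2.) * r)) ** 0.5

for t in range(1, 6):
    print(t, rho(t))        # roughly 1, 2, 3, 4 (just under), 5
print(rectifier(5))         # ~0.017: the first rectified step is tiny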
