
Commit

Merge pull request #116 from Yoctol/sparse-weight-decay
Sparse weight decay
noobOriented authored Oct 16, 2019
2 parents f22957d + 7b9d73b commit 6f11aa8
Showing 3 changed files with 66 additions and 16 deletions.
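In short (my reading of the diff below, not part of the original commit message): WeightDecay gains a sparse_update flag, default True. When a gradient arrives as tf.IndexedSlices (as it does for variables read through tf.nn.embedding_lookup), the decay is built as an IndexedSlices over the gathered rows, so only rows that actually participated in the step are decayed. A sketch of the per-row update rules the new test checks, with eta the learning rate and lam the decay rate:

```python
# Sketch of the two decay modes (my reading of the diff, not code from the PR).
# grad_row is the gradient already summed over the k times the row was gathered.
def dense_row_update(row, grad_row, eta, lam):
    # sparse_update=False: every row decays exactly once per step,
    # whether or not it was gathered.
    return row * (1 - lam) - eta * grad_row

def sparse_row_update(row, grad_row, eta, lam, k):
    # sparse_update=True: duplicate indices in the IndexedSlices accumulate
    # on apply, so a row gathered k times decays k times; rows never
    # gathered (k == 0) are not decayed at all.
    return row * (1 - k * lam) - eta * grad_row
```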
2 changes: 1 addition & 1 deletion talos/__version__.py
@@ -1,4 +1,4 @@
 __title__ = 'talos'
-__version__ = '1.6.1'
+__version__ = '1.6.2'
 __description__ = 'Powerful Neural Network Builder'
 __author__ = 'Jsaon'
34 changes: 34 additions & 0 deletions talos/optimizers/tests/test_weight_decay.py
@@ -52,3 +52,37 @@ def test_weight_decay_with_filter(var_filter, sess):
         sess.run(z),
         z_val - lr,
     )  # doesn't decay since it's not in the filter
+
+
+@pytest.mark.parametrize('sparse_update', [True, False])
+def test_sparse_weight_decay(sparse_update, sess):
+    lr, decay_rate = 0.2, 0.1
+    E_val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    x = tf.constant([[0, 1, 1]])
+    E = tf.Variable(E_val, dtype=tf.float32, name='E')
+
+    optimizer = WeightDecay(
+        tf.train.GradientDescentOptimizer(lr),
+        decay_rate=decay_rate,
+        sparse_update=sparse_update,
+    )
+    e = tf.nn.embedding_lookup(E, x)
+    y = tf.pow(e, 3)  # dy/de = 3e^2
+    train_op = optimizer.minimize(y, var_list=[E])
+
+    sess.run(E.initializer)
+    sess.run(train_op)
+    if sparse_update:
+        expected_E_val = [
+            E_val[0] * (1 - decay_rate) - lr * (3 * E_val[0] ** 2),  # index 0: gathered once
+            E_val[1] * (1 - 2 * decay_rate) - 2 * lr * (3 * E_val[1] ** 2),  # index 1: gathered twice
+            E_val[2],  # index 2: never gathered, so no decay
+        ]
+    else:
+        expected_E_val = [
+            E_val[0] * (1 - decay_rate) - lr * (3 * E_val[0] ** 2),  # dense decay applies once
+            E_val[1] * (1 - decay_rate) - 2 * lr * (3 * E_val[1] ** 2),  # decay once, gradient twice
+            E_val[2] * (1 - decay_rate),  # decays even though never gathered
+        ]
+
+    np.testing.assert_array_almost_equal(sess.run(E), expected_E_val, decimal=4)
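As a sanity check (not part of the commit), the expected values above can be reproduced with plain NumPy; here for the row [4, 5, 6], which x = [[0, 1, 1]] gathers twice:

```python
import numpy as np

# Numbers from the test above: row index 1 is gathered twice.
lr, decay_rate = 0.2, 0.1
row = np.array([4., 5., 6.])
grad = 2 * (3 * row ** 2)  # dy/de = 3e^2, summed over the 2 lookups

sparse = row * (1 - 2 * decay_rate) - lr * grad  # decays twice
dense = row * (1 - decay_rate) - lr * grad       # decays once
print(sparse)  # [-16.  -26.  -38.4]
print(dense)   # [-15.6 -25.5 -37.8]
```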
46 changes: 31 additions & 15 deletions talos/optimizers/weight_decay.py
@@ -14,29 +14,17 @@ def __init__(
         use_locking: bool = False,
         name: str = 'WeightDecay',
         variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None,
+        sparse_update: bool = True,
     ):
         super().__init__(use_locking, name)
         self.optimizer = optimizer
         self.decay_rate = decay_rate
         self.decay_rate_tensor = tf.convert_to_tensor(decay_rate)
         self.variable_filter = variable_filter
+        self.sparse_update = sparse_update
 
     def apply_gradients(self, grads_and_vars, global_step=None, name=None):
-        if self.variable_filter is None:
-            def need_decay(var):
-                return True
-        elif hasattr(self.variable_filter, '__contains__'):
-            def need_decay(var):
-                return var in self.variable_filter
-        else:
-            need_decay = self.variable_filter
-
-        var_list = [v for g, v in grads_and_vars if g is not None and need_decay(v)]
-
-        decay_value = [
-            tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) * v
-            for v in var_list
-        ]
+        var_list, decay_value = self._get_decay_pairs(grads_and_vars)
         with tf.control_dependencies(decay_value):  # cache the value before descent
             grad_descent_op = self.optimizer.apply_gradients(
                 grads_and_vars,
@@ -53,3 +41,31 @@ def need_decay(var):
         )
 
         return decay_op
+
+    def _get_decay_pairs(self, grads_and_vars):
+        if self.variable_filter is None:
+            def need_decay(var):
+                return True
+        elif hasattr(self.variable_filter, '__contains__'):
+            def need_decay(var):
+                return var in self.variable_filter
+        else:
+            need_decay = self.variable_filter
+
+        var_list, decay_list = [], []
+        for g, v in grads_and_vars:
+            if g is None or not need_decay(v):
+                continue
+            var_list.append(v)
+            if self.sparse_update and isinstance(g, tf.IndexedSlices):
+                # Decay only the rows touched by this sparse gradient.
+                decay_value = tf.IndexedSlices(
+                    values=tf.gather(v, g.indices),
+                    indices=g.indices,
+                    dense_shape=g.dense_shape,
+                )
+            else:
+                decay_value = v
+            rate = tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype)
+            decay_list.append(tf.math.scalar_mul(rate, decay_value))
+
+        return var_list, decay_list
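For context, a minimal usage sketch of the new flag (a hypothetical TF 1.x script; the import path is assumed from the repo layout, and the model is made up for illustration):

```python
import tensorflow as tf
from talos.optimizers import WeightDecay  # import path assumed, not shown in the diff

E = tf.get_variable('E', shape=[10000, 64])  # embedding matrix
ids = tf.placeholder(tf.int32, shape=[None])
loss = tf.reduce_sum(tf.nn.embedding_lookup(E, ids) ** 2)

optimizer = WeightDecay(
    tf.train.GradientDescentOptimizer(0.1),
    decay_rate=0.01,
    sparse_update=True,  # decay only the rows looked up in this step
)
train_op = optimizer.minimize(loss, var_list=[E])
```

With sparse_update=False, every row of E would decay on every step, which both densifies the update and decays rows that never appear in a batch.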
