diff --git a/talos/__version__.py b/talos/__version__.py index c3e029b..c966ab5 100644 --- a/talos/__version__.py +++ b/talos/__version__.py @@ -1,4 +1,4 @@ __title__ = 'talos' -__version__ = '1.6.2' +__version__ = '1.6.3' __description__ = 'Powerful Neural Network Builder' __author__ = 'Jsaon' diff --git a/talos/optimizers/spectral_norm.py b/talos/optimizers/spectral_norm.py new file mode 100644 index 0000000..b9a8326 --- /dev/null +++ b/talos/optimizers/spectral_norm.py @@ -0,0 +1,99 @@ +from typing import Callable, Container, Union + +import tensorflow as tf + + +class SpectralWeightDecay(tf.train.Optimizer): + ''' + References: + 1. Decouple Weight Decay https://arxiv.org/abs/1711.05101 + 2. Spectral Regularization https://arxiv.org/abs/1705.10941 + ''' + + def __init__( + self, + optimizer, + decay_rate: float, + use_locking: bool = False, + name: str = 'SpectralWeightDecay', + variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None, + ): + super().__init__(use_locking, name) + self.optimizer = optimizer + self.decay_rate = decay_rate + self.decay_rate_tensor = tf.convert_to_tensor(decay_rate) + self.variable_filter = variable_filter + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + var_list, decay_value, update_list = self._get_decay_trips(grads_and_vars) + with tf.control_dependencies(decay_value): # cache the value before descent + grad_descent_op = self.optimizer.apply_gradients( + grads_and_vars, + global_step=global_step, + ) + + # guarantee compute before decay. + with tf.control_dependencies([grad_descent_op]): + decay_op = tf.group( + *[ + v.assign_sub(d_v, use_locking=self._use_locking) + for v, d_v in zip(var_list, decay_value) + ], + *update_list, + name=name, + ) + + return decay_op + + def _get_decay_trips(self, grads_and_vars): + if self.variable_filter is None: + def need_decay(var): + return 'kernel' in v.name and v.shape.ndims >= 2 + elif hasattr(self.variable_filter, '__contains__'): + def need_decay(var): + return var in self.variable_filter + else: + need_decay = self.variable_filter + + var_list, decay_list, update_list = [], [], [] + for g, v in grads_and_vars: + if g is None or not need_decay(v): + continue + if v.shape.ndims < 2: + raise ValueError("Can't apply spectral norm on variable with rank < 2!") + decay_value, update_u = self._build_spectral_norm_variables(v) + rate = tf.cast(self.decay_rate_tensor, dtype=v.dtype.base_dtype) + var_list.append(v) + decay_list.append(rate * decay_value) + update_list.append(update_u) + + return var_list, decay_list, update_list + + def _build_spectral_norm_variables(self, kernel): + kernel_matrix = to_rank2(kernel) # shape (U, V) + u = self._get_or_make_slot_with_initializer( + kernel, + initializer=tf.keras.initializers.lecun_normal(), # unit vector + shape=kernel_matrix.shape[:1], + dtype=kernel_matrix.dtype, + slot_name='u', + op_name=self._name, + ) # shape (U) + v = tf.nn.l2_normalize(tf.linalg.matvec(kernel_matrix, u, transpose_a=True)) # shape (V) + Wv = tf.linalg.matvec(kernel_matrix, v) # shape (U) + # NOTE + # sigma = u^T W v -> dsigma / dW = uv^T + # 0.5 dsigma^2 / dW = sigma u v^T = (sigma u) v^T = Wv v^T + decay_value = Wv[:, tf.newaxis] * v # shape (U, V) + if kernel.shape.ndims > 2: + decay_value = tf.reshape(decay_value, kernel.shape) + + new_u = tf.nn.l2_normalize(Wv) # shape (U) + update_u = tf.assign(u, new_u) + return decay_value, update_u + + +def to_rank2(tensor: tf.Tensor): + if tensor.shape.ndims > 2: + return tf.reshape(tensor, [-1, tensor.shape[-1].value]) + return tensor diff --git a/talos/optimizers/tests/test_spectral_norm.py b/talos/optimizers/tests/test_spectral_norm.py new file mode 100644 index 0000000..ba6f823 --- /dev/null +++ b/talos/optimizers/tests/test_spectral_norm.py @@ -0,0 +1,56 @@ +import pytest + +import numpy as np +import tensorflow as tf + +from ..spectral_norm import SpectralWeightDecay + + +def test_spectral_weight_decay_apply_low_rank_by_default(sess): + lr, decay_rate = 0.2, 0.1 + x_val = 2. + optimizer = SpectralWeightDecay( + tf.train.GradientDescentOptimizer(lr), + decay_rate=decay_rate, + ) + x = tf.Variable(x_val, name='x') # rank 0 + y = tf.pow(x, 3) # dy/dx = 3x^2 + train_op = optimizer.minimize(y, var_list=[x]) + + sess.run(tf.variables_initializer([x])) + sess.run(train_op) + np.testing.assert_almost_equal( + sess.run(x), + x_val - lr * 3 * (x_val ** 2), + ) + + +@pytest.mark.parametrize('shape', [ + [3, 4], + [3, 4, 5], +]) +def test_spectral_weight_decay(shape, sess): + lr, decay_rate = 0.2, 0.1 + optimizer = SpectralWeightDecay( + tf.train.GradientDescentOptimizer(lr), + decay_rate=decay_rate, + ) + + W = tf.Variable(np.random.rand(*shape), name='kernel') + y = tf.reduce_sum(W) # dy/dx = 1 + train_op = optimizer.minimize(y, var_list=[W]) + u = optimizer.get_slot(W, 'u') + + assert u.shape.as_list() == [np.prod(shape[:-1])] + + sess.run(tf.variables_initializer([W, u])) + W_val, u_val = sess.run([W, u]) + v_val = W_val.reshape([-1, shape[-1]]).T @ u_val + v_val /= np.linalg.norm(v_val) + decay_val = decay_rate * np.expand_dims(W_val @ v_val, -1) * v_val + + sess.run(train_op) + np.testing.assert_almost_equal( + sess.run(W), + W_val - decay_val - lr * 1., + ) diff --git a/talos/optimizers/weight_decay.py b/talos/optimizers/weight_decay.py index 7364c9e..88932ff 100644 --- a/talos/optimizers/weight_decay.py +++ b/talos/optimizers/weight_decay.py @@ -14,7 +14,7 @@ def __init__( use_locking: bool = False, name: str = 'WeightDecay', variable_filter: Union[Container[tf.Variable], Callable[[tf.Variable], bool]] = None, - sparse_update: bool = True, + sparse_update: bool = False, ): super().__init__(use_locking, name) self.optimizer = optimizer diff --git a/talos/spectral_norm/spectral_norm.py b/talos/spectral_norm/spectral_norm.py index 1ab2804..ff0b0e7 100644 --- a/talos/spectral_norm/spectral_norm.py +++ b/talos/spectral_norm/spectral_norm.py @@ -2,11 +2,7 @@ from typing import Set import tensorflow as tf -from tensorflow.python.keras.layers.cudnn_recurrent import ( - _CuDNNRNN, - CuDNNGRU, - CuDNNLSTM, -) +from tensorflow.python.keras.layers.cudnn_recurrent import _CuDNNRNN _WEIGHTS_VARIABLE_NAME = "kernel" @@ -36,21 +32,7 @@ def add_spectral_norm(layer: tf.layers.Layer): def add_spectral_norm_for_layer( layer: tf.layers.Layer, kernel_name: Set[str] = None, - ): - if isinstance(layer, (CuDNNGRU, tf.keras.layers.GRUCell)): - weight_split = 3 - elif isinstance(layer, (CuDNNLSTM, tf.keras.layers.LSTMCell)): - weight_split = 4 - else: - weight_split = 1 - - _add_spectral_norm_for_layer(layer, kernel_name, weight_split=weight_split) - - -def _add_spectral_norm_for_layer( - layer: tf.layers.Layer, - kernel_name: Set[str] = None, - weight_split: int = 1, + lipschitz: float = 1., ): if layer.built: raise ValueError("Can't add spectral norm on built layer!") @@ -72,41 +54,18 @@ def new_add_weight(self, name=None, shape=None, **kwargs): if len(shape) < 2: raise ValueError("Can't apply spectral norm on variable rank < 2!") - kernel_matrix = to_rank2(kernel) # shape (U, V) - if weight_split > 1: - assert shape[1] % weight_split == 0 - split_kernel = tf.split(kernel_matrix, weight_split, axis=1) - sn_list = [] - for i, sub_kernel in enumerate(split_kernel): - sn_val, update_u = _build_spectral_norm_variables( - f"{name}_{i}", - sub_kernel, - original_add_weight, - ) - sn_list.append( - tf.fill([sub_kernel.shape[1].value], value=sn_val), - ) # shape (V // split) - self.add_update(update_u) - - spectral_norm = tf.concat(sn_list, axis=0) # shape (V) - else: - spectral_norm, update_u = _build_spectral_norm_variables( - name, kernel_matrix, original_add_weight, - ) # shape () - self.add_update(update_u) - - normed_kernel = tf.truediv( - kernel, - spectral_norm + tf.keras.backend.epsilon(), - name=f'{name}_sn', - ) + spectral_norm, update_u = _build_spectral_norm_variables(name, kernel, original_add_weight) + self.add_update(update_u) + + scale = lipschitz / (spectral_norm + tf.keras.backend.epsilon()) + normed_kernel = tf.multiply(kernel, scale, name=f'{name}_sn') return normed_kernel layer.add_weight = types.MethodType(new_add_weight, layer) -def _build_spectral_norm_variables(name, kernel, add_weight_func): - assert kernel.shape.ndims == 2 +def _build_spectral_norm_variables(name, kernel, add_weight_func=tf.get_variable): + kernel = to_rank2(kernel) # shape (U, V) u_vector = add_weight_func( name=f'{name}/left_singular_vector', shape=(kernel.shape[0].value, ), @@ -119,12 +78,12 @@ def _build_spectral_norm_variables(name, kernel, add_weight_func): tf.nn.l2_normalize(tf.linalg.matvec(kernel, u_vector, transpose_a=True)), name=f'{name}/new_right_singular_vector', ) # shape (V) - unnormed_new_u = tf.linalg.matvec(kernel, new_v) # shape (U) + Wv = tf.linalg.matvec(kernel, new_v) # shape (U) new_u = tf.stop_gradient( - tf.nn.l2_normalize(unnormed_new_u), + tf.nn.l2_normalize(Wv), name=f'{name}/new_left_singular_vector', - ) - spectral_norm = tf.reduce_sum(new_u * unnormed_new_u, name=f'{name}/singular_value') + ) # shape (U) + spectral_norm = tf.tensordot(new_u, Wv, axes=1, name=f'{name}/singular_value') update_u = tf.assign(u_vector, new_u, name=f'{name}/power_iter') return spectral_norm, update_u diff --git a/talos/spectral_norm/tests/test_spectral_norm.py b/talos/spectral_norm/tests/test_spectral_norm.py index 3e15f51..d3bf008 100644 --- a/talos/spectral_norm/tests/test_spectral_norm.py +++ b/talos/spectral_norm/tests/test_spectral_norm.py @@ -82,7 +82,7 @@ def test_add_spectral_norm(layer, inputs, sess): u_vector_list = layer.non_trainable_variables # Since norm come from division - assert all([kernel.op.type == 'RealDiv' for kernel in kernel_list]) + assert all([kernel.op.type == 'Mul' for kernel in kernel_list]) sess.run(tf.variables_initializer(layer.variables))