Skip to content

Commit

Permalink
Merge pull request #84 from Yoctol/spectral-norm-on-model
Browse files Browse the repository at this point in the history
let add_spectral_norm work for subclass Model
  • Loading branch information
noobOriented authored May 6, 2019
2 parents 9b0f100 + b03555d commit 97d82ed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
17 changes: 17 additions & 0 deletions talos/spectral_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ def add_spectral_norm(layer: tf.layers.Layer):
if isinstance(layer, tf.keras.Sequential):
for sub_layer in layer.layers:
add_spectral_norm(sub_layer)
elif isinstance(layer, tf.keras.Model):
add_spectral_norm_for_model(layer)
elif isinstance(layer, tf.keras.layers.RNN):
add_spectral_norm(layer.cell)
elif isinstance(layer, tf.keras.layers.StackedRNNCells):
Expand Down Expand Up @@ -74,6 +76,21 @@ def new_add_weight(self, name=None, shape=None, **kwargs):
layer.add_weight = types.MethodType(new_add_weight, layer)


def add_spectral_norm_for_model(model: tf.keras.Model):
if model.built:
raise ValueError("Can't add spectral norm on built layer!")

original_build = model.build

# Very Very Evil HACK
def new_build(self, input_shape):
original_build(input_shape)
for sub_layer in self.layers:
add_spectral_norm(sub_layer)

model.build = types.MethodType(new_build, model)


def to_rank2(tensor: tf.Tensor):
if tensor.shape.ndims > 2:
return tf.reshape(tensor, [-1, tensor.shape[-1].value])
Expand Down
6 changes: 4 additions & 2 deletions talos/tests/test_spectral_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import tensorflow as tf

from talos.compounds import TransformerBlock
from talos.layers import (
Bidirectional,
Conv1D,
Expand All @@ -20,7 +21,7 @@
from ..spectral_norm import add_spectral_norm


@pytest.mark.parametrize('layer,inputs', [
@pytest.mark.parametrize('layer, inputs', [
(Dense(10), tf.zeros([3, 4])),
(Conv1D(filters=10, kernel_size=3), tf.zeros([3, 4, 5])),
(Conv2D(filters=10, kernel_size=3), tf.zeros([3, 4, 5, 5])),
Expand All @@ -37,6 +38,7 @@
LSTMCell(5),
])), tf.zeros([3, 4, 5])),
(Bidirectional(LSTM(10)), tf.zeros([3, 4, 5])),
(TransformerBlock(5, heads=4), tf.zeros([3, 4, 5])),
])
def test_spectral_norm_for_layer(layer, inputs, sess):
add_spectral_norm(layer)
Expand Down Expand Up @@ -70,7 +72,7 @@ def test_spectral_norm_for_layer(layer, inputs, sess):


def recursive_get_kernel_attributes(layer):
if isinstance(layer, tf.keras.Sequential):
if isinstance(layer, tf.keras.Model):
return list(chain.from_iterable([
recursive_get_kernel_attributes(layer)
for layer in layer.layers
Expand Down

0 comments on commit 97d82ed

Please sign in to comment.