Merge pull request #78 from Yoctol/bidirectional-xl
Bidirectional xl
noobOriented authored Apr 26, 2019
2 parents e5c60de + 7c7465d commit 04a8af1
Showing 4 changed files with 111 additions and 24 deletions.
15 changes: 6 additions & 9 deletions talos/compounds/attention/recursive.py
@@ -94,10 +94,7 @@ def call(
mask = tf.cast(mask, inputs.dtype) # shape (N, T)

if state is not None:
concated = tf.concat(
[tf.stop_gradient(state), inputs],
axis=1,
)
concated = tf.concat([state, inputs], axis=1)
if state_mask is not None:
if mask is None:
raise TypeError("Invalid input!")
@@ -211,15 +208,15 @@ def _mask_logits(self, logits, mask):
triu_tensor = tf.constant(
np.triu(
np.full([q_length, kv_length], _LARGE_BIAS),
k=q_length + 1,
k=kv_length - q_length + 1,
)[:, np.newaxis], # shape (T, 1, t), 1 to broadcast on heads
dtype=logits.dtype,
)
self._computed_triu[(q_length, kv_length)] = triu_tensor
# example if (q_length, kv_length) = (3, 6)
# [[0, 0, 0, 0, 1e4, 1e4],
# [0, 0, 0, 0, 0 , 1e4],
# [0, 0, 0, 0, 0 , 0]]
# example if (q_length, kv_length) = (3, 5)
# [[0, 0, 0, 1e4, 1e4],
# [0, 0, 0, 0, 1e4],
# [0, 0, 0, 0, 0]]

logits -= triu_tensor

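
The corrected offset k = kv_length - q_length + 1 can be sanity-checked with a few lines of NumPy. A standalone sketch, not part of the diff; _LARGE_BIAS is assumed to be 1e4, matching the entries in the example comment above:

import numpy as np

_LARGE_BIAS = 1e4  # assumed value, consistent with the 1e4 entries in the comment above
q_length, kv_length = 3, 5

triu = np.triu(
    np.full([q_length, kv_length], _LARGE_BIAS),
    k=kv_length - q_length + 1,
)
print(triu != 0)  # True marks future key positions whose logits get suppressed (logits -= triu_tensor)
# [[False False False  True  True]
#  [False False False False  True]
#  [False False False False False]]
# i.e. each of the 3 query positions sees the 2 state positions plus keys up to its own position.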
21 changes: 13 additions & 8 deletions talos/compounds/attention/tests/test_recursive.py
@@ -13,12 +13,14 @@ def cell():

@pytest.fixture(scope='module')
def state(inputs):
return tf.placeholder(dtype=inputs.dtype, shape=inputs.shape.as_list())
maxlen, dim = inputs.shape.as_list()[1:]
return tf.placeholder(dtype=inputs.dtype, shape=[None, maxlen - 1, dim])


@pytest.fixture(scope='module')
def state_mask(mask):
return tf.placeholder(dtype=mask.dtype, shape=mask.shape.as_list())
def state_mask(state, mask):
state_maxlen = state.shape[1].value
return tf.placeholder(dtype=mask.dtype, shape=[None, state_maxlen])


def test_output_shape(cell, inputs, state):
@@ -29,12 +31,13 @@ def test_output_shape(cell, inputs, state):

def test_mask_gradients(inputs, state, mask, state_mask, cell, sess):
maxlen, channel = inputs.shape.as_list()[1:]
state_maxlen = state.shape[1].value

outputs = cell(inputs, state, mask=mask, state_mask=state_mask)
grads = tf.gradients(outputs, inputs)[0] # same shape as inputs

mask_val = np.random.choice(2, size=[5, maxlen]).astype(np.bool)
state_mask_val = np.random.choice(2, size=[5, maxlen]).astype(np.bool)
state_mask_val = np.random.choice(2, size=[5, state_maxlen]).astype(np.bool)
mask_val[:, :2] = True # to make sure at least 2 True
state_mask_val[:, :2] = True

@@ -43,7 +46,7 @@ def test_mask_gradients(inputs, state, mask, state_mask, cell, sess):
grads,
feed_dict={
inputs: np.random.rand(5, maxlen, channel),
state: np.random.rand(5, maxlen, channel),
state: np.random.rand(5, state_maxlen, channel),
mask: mask_val,
state_mask: state_mask_val,
},
@@ -57,11 +60,12 @@ def test_mask_gradients(inputs, state, mask, state_mask, cell, sess):
def test_forward_mask_gradients(inputs, state, sess):
layer = RelativeAttentionCell(units=3, output_dim=10, heads=5, use_forward_mask=True)
maxlen, channel = inputs.shape.as_list()[1:]
state_maxlen = state.shape[1].value

outputs = layer(inputs, state=state)
grads_list = tf.stack([
tf.gradients(outputs[:, t], inputs)[0]
for t in range(maxlen)
for t in range(outputs.shape[1].value)
], axis=1) # every elements have same shape as inputs
# shape (N, T, T, U)

Expand All @@ -70,10 +74,11 @@ def test_forward_mask_gradients(inputs, state, sess):
grads_list,
feed_dict={
inputs: np.random.rand(5, maxlen, channel),
state: np.random.rand(5, maxlen, channel),
state: np.random.rand(5, state_maxlen, channel),
},
)
assert np.equal(
grad_list_val != 0., # shape (N, T, T, U)
np.tril(np.ones([maxlen, maxlen], dtype=np.bool))[:, :, np.newaxis], # shape (T, T, 1)
np.tril(np.ones([maxlen, maxlen], dtype=np.bool))[:, :, np.newaxis],
# shape (T, T', 1)
).all(), grad_list_val != 0
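
The updated fixtures give the state fewer timesteps than the inputs, so the cell is now exercised with q_length != kv_length. A minimal sketch of that calling pattern, not part of the diff; the import path and the (N, T, output_dim) output shape are assumptions based on the file layout and the fixtures above:

import tensorflow as tf

from talos.compounds.attention.recursive import RelativeAttentionCell  # assumed import path

cell = RelativeAttentionCell(units=3, output_dim=10, heads=5)
inputs = tf.placeholder(tf.float32, shape=[None, 4, 10])  # current segment, T = 4
state = tf.placeholder(tf.float32, shape=[None, 3, 10])   # previous segment, T' = 3 != T

outputs = cell(inputs, state=state)  # queries come from inputs only; keys/values span T' + T
print(outputs.shape)  # expected (?, 4, 10), i.e. (N, T, output_dim)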
63 changes: 60 additions & 3 deletions talos/compounds/tests/test_transformer_xl.py
@@ -6,9 +6,10 @@
from ..transformer_xl import TransformerXL


@pytest.fixture(scope='module')
def layer():
return TransformerXL(block_size=2, units=3, heads=5)
@pytest.fixture(scope='module', params=[False, True])
def layer(request):
bidirectional = request.param
return TransformerXL(block_size=2, units=3, heads=5, bidirectional=bidirectional)


def test_output_shape(layer, inputs):
@@ -37,3 +38,59 @@ def test_mask_gradients(inputs, mask, layer, sess):
grads_val != 0.,
mask_val[:, :, np.newaxis],
).all()


@pytest.mark.parametrize('layer', [
TransformerXL(block_size=2, units=3, heads=2, use_forward_mask=True),
TransformerXL(block_size=2, units=3, heads=2, use_forward_mask=True, state_gradient=True),
TransformerXL(block_size=2, units=3, heads=2, bidirectional=True),
])
def test_blocklevel_gradients(layer, sess):
inputs = tf.random_normal([5, 5, 4])
maxlen, channel = inputs.shape.as_list()[1:]

outputs = layer(inputs)
grads_list = tf.stack([
tf.gradients(outputs[:, t], inputs)[0]
for t in range(maxlen)
], axis=1) # every elements have same shape as inputs
# shape (N, T, T, U)

sess.run(tf.variables_initializer(var_list=layer.variables))
grad_list_val = sess.run(grads_list)
attention_map = generate_attention_map(
maxlen,
layer.block_size,
bidirectional=layer.bidirectional,
forward_mask=layer.use_forward_mask,
state_gradient=layer.state_gradient,
)
assert np.equal(
grad_list_val != 0., # shape (N, T, T, U)
attention_map[:, :, np.newaxis], # shape (T, T, 1)
).all(), np.any(grad_list_val, axis=-1)[0]


def generate_attention_map(maxlen, block_size, bidirectional, forward_mask, state_gradient):
att_map = np.zeros([maxlen, maxlen], dtype=np.bool)
for t in range(0, maxlen, block_size):
# diagonal block
if forward_mask:
att_map[t: t + block_size, t: t + block_size] = np.tril(
np.ones(
[min(block_size, maxlen - t), min(block_size, maxlen - t)],
dtype=np.bool,
),
)
else:
att_map[t: t + block_size, t: t + block_size] = True

if state_gradient and t >= block_size:
att_map[t: t + block_size, t - block_size: t] = True # section from previous block

if bidirectional:
att_map = np.logical_or(
att_map,
att_map[::-1, ::-1],
)
return att_map
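
To make the expected block-level pattern concrete, the helper above can be evaluated for the forward-masked, state_gradient=True parametrization with maxlen=5 and block_size=2 (a check run alongside the test, where generate_attention_map is in scope; not part of the diff):

att_map = generate_attention_map(
    maxlen=5,
    block_size=2,
    bidirectional=False,
    forward_mask=True,
    state_gradient=True,
)
print(att_map.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 1 1 0 0]
#  [1 1 1 1 0]
#  [0 0 1 1 1]]
# each block attends causally within itself and, with state_gradient=True, to the whole previous block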
36 changes: 32 additions & 4 deletions talos/compounds/transformer_xl.py
@@ -32,6 +32,8 @@ def __init__(
block_size: int,
units: int,
heads: int,
state_gradient: bool = False,
bidirectional: bool = False,
activation: Union[str, Callable] = 'relu',
hidden_units: int = None,
dropout_rate: float = 0.1,
@@ -57,6 +59,8 @@ def __init__(
self.block_size = block_size
self.units = units
self.heads = heads
self.state_gradient = state_gradient
self.bidirectional = bidirectional
self.use_forward_mask = use_forward_mask
self._input_spec = tf.keras.layers.InputSpec(ndim=3)

@@ -72,6 +76,13 @@ def build(self, input_shape):
heads=self.heads,
use_forward_mask=self.use_forward_mask,
)
if self.bidirectional:
self.backward_cell = RelativeAttentionCell(
units=self.units,
output_dim=output_dim,
heads=self.heads,
use_forward_mask=self.use_forward_mask,
)
self.output_dense = tf.keras.layers.Dense(
units=output_dim,
use_bias=True,
@@ -92,7 +103,22 @@ def call(self, inputs, mask: tf.Tensor = None, training: tf.Tensor = None):
"""
# RelationAttention SubLayers
ln_inputs = self.ln_pre_cell(inputs) # layer norm
att_vec = self._block_wise_attention(ln_inputs, mask=mask)
att_vec = self._blockwise_attention(ln_inputs, mask=mask, cell=self.cell)
if self.bidirectional:
if mask is not None:
backward_mask = tf.reverse(mask, axis=[1])
else:
backward_mask = None
backward_vec = tf.reverse(
self._blockwise_attention(
tf.reverse(ln_inputs, axis=[1]),
mask=backward_mask,
cell=self.backward_cell,
),
axis=[1],
)
att_vec = att_vec + backward_vec

att_vec = self.dropout_cell(att_vec, training=training)

# Position-wise Feed Forward
Expand All @@ -105,7 +131,7 @@ def call(self, inputs, mask: tf.Tensor = None, training: tf.Tensor = None):
outputs *= tf.cast(mask, inputs.dtype)[:, :, tf.newaxis]
return outputs

def _block_wise_attention(self, inputs, mask):
def _blockwise_attention(self, inputs, mask, cell):
# split full_timesteps to blocks as possible
# with 1 < length <= self.block_size
maxlen = inputs.shape[1].value
@@ -133,7 +159,7 @@ def _block_wise_attention(self, inputs, mask):
if block_mask is not None:
block_output = tf.cond(
tf.reduce_any(block_mask),
lambda: self.cell(
lambda: cell(
block_input,
state=state,
mask=block_mask,
Expand All @@ -142,10 +168,12 @@ def _block_wise_attention(self, inputs, mask):
lambda: tf.zeros_like(block_input),
)
else:
block_output = self.cell(block_input, state=state)
block_output = cell(block_input, state=state)

output_list.append(block_output)
state = block_input
if not self.state_gradient:
state = tf.stop_gradient(state)
state_mask = block_mask

if len(output_list) > 1:
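
A minimal usage sketch of the two new constructor flags, mirroring the test fixtures above. Not part of the diff; the import path is an assumption based on the file layout:

import numpy as np
import tensorflow as tf

from talos.compounds.transformer_xl import TransformerXL  # assumed import path

layer = TransformerXL(
    block_size=2,
    units=3,
    heads=5,
    bidirectional=True,    # adds the reversed backward_cell pass and sums it with the forward one
    state_gradient=False,  # default: gradients are stopped through the cached previous block
)
inputs = tf.placeholder(tf.float32, shape=[None, 6, 4])  # (N, T, U)
outputs = layer(inputs)  # output keeps the input timestep length T

with tf.Session() as sess:
    sess.run(tf.variables_initializer(var_list=layer.variables))
    out_val = sess.run(outputs, feed_dict={inputs: np.random.rand(5, 6, 4)})
    print(out_val.shape)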
