# encoder.py -- forked from google-research/bigbird
# Copyright 2021 The BigBird Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BigBird Encoder Layers."""
from bigbird.core import attention
from bigbird.core import recompute_grad
from bigbird.core import utils
import tensorflow.compat.v2 as tf
class PrenormEncoderLayer(tf.keras.layers.Layer):
  """Encoder layer of a transformer in Pegasus style.

  The layer_norm is taken before self-attention (and before the feed-forward
  sub-block); residual connections add the un-normalized inputs back.
  """

  def __init__(self,
               attention_type,
               hidden_size=768,
               intermediate_size=3072,
               intermediate_act_fn=utils.gelu,
               attention_probs_dropout_prob=0.0,
               hidden_dropout_prob=0.1,
               initializer_range=0.02,
               num_attention_heads=12,
               num_rand_blocks=3,
               seq_length=1024,
               block_size=64,
               use_bias=True,
               seed=None,
               name=None):
    """Constructor of an encoder layer of a transformer in Pegasus style.

    Args:
      attention_type: Type of attention, needs to be one of ['original_full',
        'simulated_sparse', 'block_sparse'].
      hidden_size: (optional) int. Size of hidden dimension.
      intermediate_size: (optional) int. Size of intermediate dimension.
      intermediate_act_fn: (optional) Activation function for intermediate
        layer.
      attention_probs_dropout_prob: (optional) float. Dropout probability of
        the attention probabilities.
      hidden_dropout_prob: (optional) float. Dropout probability applied to
        the projected attention output and to the feed-forward output.
      initializer_range: (optional) float. Range of the weight initializer.
      num_attention_heads: (optional) int. Number of attention heads.
      num_rand_blocks: (optional) int. Number of random chunks per row.
      seq_length: (optional) int. Length of sequence.
      block_size: (optional) int. Size of block in sequence.
      use_bias: (optional) bool. Whether key/query/value uses a bias vector.
      seed: (optional) int. Random seed for generating random mask.
      name: The name scope of this layer. When omitted, the name that Keras
        assigns to the layer is used as the scope.
    """
    super(PrenormEncoderLayer, self).__init__(name=name)

    # `variable_scope(None)` raises a TypeError, so open the scope with the
    # layer's resolved name; it equals `name` whenever one was supplied.
    # NOTE(review): the nesting of the scopes below was reconstructed from the
    # section comments and scope names -- confirm against upstream before
    # loading existing checkpoints.
    with tf.compat.v1.variable_scope(self.name):

      # Integer division: each head gets an equal slice of the hidden size.
      attention_head_size = hidden_size // num_attention_heads

      with tf.compat.v1.variable_scope("attention"):
        # Pre-Normalization layer
        with tf.compat.v1.variable_scope("self"):
          self.first_layer_norm = utils.NormLayer(hidden_size)

        # Self-Attention layer (sequence length and block size are passed
        # twice: once for the "from" side and once for the "to" side).
        self.attn_layer = attention.MultiHeadedAttentionLayer(
            attention_type, num_attention_heads, attention_head_size,
            num_rand_blocks, seq_length, seq_length, block_size, block_size,
            attention_probs_dropout_prob, initializer_range, use_bias,
            seed, name="self")

        # Feedforward layer
        with tf.compat.v1.variable_scope("output"):
          self.projection_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)

        # Dropout
        self.attention_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

      with tf.compat.v1.variable_scope("intermediate"):
        # Normalization layer
        self.second_layer_norm = utils.NormLayer(hidden_size)

        # Feedforward layer
        self.expand_layer = utils.Dense2dLayer(
            hidden_size, intermediate_size,
            utils.create_initializer(initializer_range),
            intermediate_act_fn, "dense")

      with tf.compat.v1.variable_scope("output"):
        # Feedforward layer
        self.contract_layer = utils.Dense2dLayer(
            intermediate_size, hidden_size,
            utils.create_initializer(initializer_range),
            None, "dense")

        # Dropout
        self.output_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

  def call(self,
           layer_input,
           attention_mask=None,
           band_mask=None,
           from_mask=None,
           to_mask=None,
           input_blocked_mask=None,
           training=None):
    """Implements an encoder layer of a transformer in Pegasus style.

    Args:
      layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
      attention_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length, seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      band_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length//block_size-4, block_size, 3*block_size].
        The values should be 1 or 0. The attention scores will effectively be
        set to -infinity for any positions in the mask that are 0, and will be
        unchanged for positions that are 1.
      from_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length, 1]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      to_mask: (optional) float32 Tensor of shape [batch_size, 1, 1,
        seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      input_blocked_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length//block_size, block_size]. Same as from/to_mask, just
        reshaped.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size].

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
      NotImplementedError: For unknown attention type.
      (Both are presumably raised by the attention layer; nothing is raised
      directly here.)
    """
    # Self-attention on the normalized input. `input_blocked_mask` is passed
    # twice -- presumably as the from-side and to-side blocked masks.
    normalized_layer_input = self.first_layer_norm(layer_input)
    attention_output = self.attn_layer(
        normalized_layer_input, normalized_layer_input, [
            attention_mask, band_mask, from_mask, to_mask, input_blocked_mask,
            input_blocked_mask
        ], training=training)

    # Run a linear projection of `hidden_size` then add a residual
    # with the un-normalized `layer_input`.
    attention_output = self.projection_layer(attention_output)
    attention_output = self.attention_dropout(attention_output,
                                              training=training)
    attention_output = attention_output + layer_input

    # The activation is only applied to the "intermediate" hidden layer.
    normalized_attention_output = self.second_layer_norm(attention_output)
    intermediate_output = self.expand_layer(normalized_attention_output)

    # Down-project back to `hidden_size` then add the residual
    # (again the un-normalized tensor).
    layer_output = self.contract_layer(intermediate_output)
    layer_output = self.output_dropout(layer_output, training=training)
    layer_output = layer_output + attention_output
    return layer_output
class PostnormEncoderLayer(tf.keras.layers.Layer):
  """Encoder layer of a transformer in BERT style.

  The layer_norm is taken after self-attention (and after the feed-forward
  sub-block), i.e. on the sum of each sub-block's output and its residual.
  """

  def __init__(self,
               attention_type,
               hidden_size=768,
               intermediate_size=3072,
               intermediate_act_fn=utils.gelu,
               attention_probs_dropout_prob=0.0,
               hidden_dropout_prob=0.1,
               initializer_range=0.02,
               num_attention_heads=12,
               num_rand_blocks=3,
               seq_length=1024,
               block_size=64,
               use_bias=True,
               seed=None,
               name=None):
    """Constructor of an encoder layer of a transformer in BERT style.

    Args:
      attention_type: Type of attention, needs to be one of ['original_full',
        'simulated_sparse', 'block_sparse'].
      hidden_size: (optional) int. Size of hidden dimension.
      intermediate_size: (optional) int. Size of intermediate dimension.
      intermediate_act_fn: (optional) Activation function for intermediate
        layer.
      attention_probs_dropout_prob: (optional) float. Dropout probability of
        the attention probabilities.
      hidden_dropout_prob: (optional) float. Dropout probability applied to
        the projected attention output and to the feed-forward output.
      initializer_range: (optional) float. Range of the weight initializer.
      num_attention_heads: (optional) int. Number of attention heads.
      num_rand_blocks: (optional) int. Number of random chunks per row.
      seq_length: (optional) int. Length of sequence.
      block_size: (optional) int. Size of block in sequence.
      use_bias: (optional) bool. Whether key/query/value uses a bias vector.
      seed: (optional) int. Random seed for generating random mask.
      name: The name scope of this layer. When omitted, the name that Keras
        assigns to the layer is used as the scope.
    """
    super(PostnormEncoderLayer, self).__init__(name=name)

    # `variable_scope(None)` raises a TypeError, so open the scope with the
    # layer's resolved name; it equals `name` whenever one was supplied.
    # NOTE(review): the nesting of the scopes below was reconstructed from the
    # section comments and scope names -- confirm against upstream before
    # loading existing checkpoints.
    with tf.compat.v1.variable_scope(self.name):

      # Integer division: each head gets an equal slice of the hidden size.
      attention_head_size = hidden_size // num_attention_heads

      with tf.compat.v1.variable_scope("attention"):
        # Self-Attention layer (sequence length and block size are passed
        # twice: once for the "from" side and once for the "to" side).
        self.attn_layer = attention.MultiHeadedAttentionLayer(
            attention_type, num_attention_heads, attention_head_size,
            num_rand_blocks, seq_length, seq_length, block_size, block_size,
            attention_probs_dropout_prob, initializer_range, use_bias,
            seed, name="self")

        with tf.compat.v1.variable_scope("output"):
          # Feedforward layer
          self.projection_layer = utils.Dense3dProjLayer(
              num_attention_heads, attention_head_size,
              utils.create_initializer(initializer_range), None,
              "dense", use_bias)
          # Post-Normalization layer
          self.first_layer_norm = utils.NormLayer(hidden_size)

        # Dropout
        self.attention_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

      with tf.compat.v1.variable_scope("intermediate"):
        # Feedforward layer
        self.expand_layer = utils.Dense2dLayer(
            hidden_size, intermediate_size,
            utils.create_initializer(initializer_range),
            intermediate_act_fn, "dense")

      with tf.compat.v1.variable_scope("output"):
        # Feedforward layer
        self.contract_layer = utils.Dense2dLayer(
            intermediate_size, hidden_size,
            utils.create_initializer(initializer_range),
            None, "dense")
        # Normalization layer
        self.second_layer_norm = utils.NormLayer(hidden_size)

        # Dropout
        self.output_dropout = recompute_grad.RecomputingDropout(
            hidden_dropout_prob)

  def call(self,
           layer_input,
           attention_mask=None,
           band_mask=None,
           from_mask=None,
           to_mask=None,
           input_blocked_mask=None,
           training=None):
    """Implements an encoder layer of a transformer in BERT style.

    Args:
      layer_input: float Tensor of shape [batch_size, seq_length, hidden_size].
      attention_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length, seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      band_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length//block_size-4, block_size, 3*block_size].
        The values should be 1 or 0. The attention scores will effectively be
        set to -infinity for any positions in the mask that are 0, and will be
        unchanged for positions that are 1.
      from_mask: (optional) float32 Tensor of shape [batch_size, 1,
        seq_length, 1]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      to_mask: (optional) float32 Tensor of shape [batch_size, 1, 1,
        seq_length]. The values should be 1 or 0. The
        attention scores will effectively be set to -infinity for any positions
        in the mask that are 0, and will be unchanged for positions that are 1.
      input_blocked_mask: (optional) float32 Tensor of shape [batch_size,
        seq_length//block_size, block_size]. Same as from/to_mask, just
        reshaped.
      training: Boolean indicating whether the call is training or inference.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size].

    Raises:
      ValueError: Any of the arguments or tensor shapes are invalid.
      NotImplementedError: For unknown attention type.
      (Both are presumably raised by the attention layer; nothing is raised
      directly here.)
    """
    # Self-attention on the raw input. `input_blocked_mask` is passed twice --
    # presumably as the from-side and to-side blocked masks.
    attention_output = self.attn_layer(
        layer_input, layer_input, [
            attention_mask, band_mask, from_mask, to_mask, input_blocked_mask,
            input_blocked_mask
        ], training=training)

    # Run a linear projection of `hidden_size`, add a residual with
    # `layer_input`, then normalize the sum.
    attention_output = self.projection_layer(attention_output)
    attention_output = self.attention_dropout(attention_output,
                                              training=training)
    attention_output = self.first_layer_norm(attention_output + layer_input)

    # The activation is only applied to the "intermediate" hidden layer.
    intermediate_output = self.expand_layer(attention_output)

    # Down-project back to `hidden_size`, add the residual, then normalize.
    layer_output = self.contract_layer(intermediate_output)
    layer_output = self.output_dropout(layer_output, training=training)
    layer_output = self.second_layer_norm(layer_output + attention_output)
    return layer_output
def add_gradient_recomputation(original_class):
  """Creates a subclass which enables gradient checkpointing.

  Args:
    original_class: Layer class to extend. Its `call` must accept
      (layer_input, attention_mask, band_mask, from_mask, to_mask,
      input_blocked_mask, training=...).

  Returns:
    A new class named `RecomputeLayer` deriving from `original_class` whose
    `call` routes the parent's forward pass through
    `recompute_grad.recompute_grad`.
  """

  class RecomputeLayer(original_class):
    """Transformer layer that recomputes the forward pass during backprop."""

    def call(self,
             layer_input,
             attention_mask=None,
             band_mask=None,
             from_mask=None,
             to_mask=None,
             input_blocked_mask=None,
             training=None):
      """Runs the parent layer's `call` wrapped by `recompute_grad`."""

      # `training` is captured by the closure rather than passed as an
      # argument, so only the input and the masks go through the wrapper.
      def forward(layer_input, attention_mask, band_mask,
                  from_mask, to_mask, input_blocked_mask):
        # Explicit two-argument super(): the zero-argument form would not
        # resolve correctly inside this nested function.
        return super(RecomputeLayer, self).call(
            layer_input, attention_mask, band_mask, from_mask, to_mask,
            input_blocked_mask, training=training)

      recomputing_forward = recompute_grad.recompute_grad(forward)
      return recomputing_forward(layer_input, attention_mask, band_mask,
                                 from_mask, to_mask, input_blocked_mask)

  return RecomputeLayer
class EncoderStack(tf.keras.layers.Layer):
  """Transformer encoder stack.

  Builds `params["num_hidden_layers"]` encoder layers of the configured norm
  type plus one extra normalization layer, which is applied to the stack input
  for "postnorm" and to the stack output for "prenorm".
  """

  def __init__(self, params):
    """Builds the encoder layers described by `params`.

    Args:
      params: Mapping of hyper-parameters. Keys read here: "norm_type"
        ("prenorm" or "postnorm"), "use_gradient_checkpointing",
        "attention_type", "hidden_size", "intermediate_size", "hidden_act",
        "attention_probs_dropout_prob", "hidden_dropout_prob",
        "initializer_range", "num_attention_heads", "num_rand_blocks",
        "max_encoder_length", "block_size", "use_bias" and
        "num_hidden_layers".

    Raises:
      NotImplementedError: If `params["norm_type"]` is neither "prenorm" nor
        "postnorm".
    """
    name = "encoder"
    super(EncoderStack, self).__init__(name=name)
    self.params = params

    if params["norm_type"] == "prenorm":
      encoder_class = PrenormEncoderLayer
    elif params["norm_type"] == "postnorm":
      encoder_class = PostnormEncoderLayer
    else:
      raise NotImplementedError(
          "Norm type {} is not implemented".format(params["norm_type"]))

    if params["use_gradient_checkpointing"]:
      encoder_class = add_gradient_recomputation(encoder_class)

    # NOTE(review): scope nesting was reconstructed from the section comments;
    # the final layer norm is assumed to live inside the "encoder" scope --
    # confirm against upstream before loading existing checkpoints.
    with tf.compat.v1.variable_scope(name):
      # Encoder layers. The layer index doubles as the random-mask seed, so
      # every layer gets a different but reproducible seed.
      self.encoder_layers = [
          encoder_class(  # pylint: disable=g-complex-comprehension
              self.params["attention_type"],
              self.params["hidden_size"],
              self.params["intermediate_size"],
              utils.get_activation(self.params["hidden_act"]),
              self.params["attention_probs_dropout_prob"],
              self.params["hidden_dropout_prob"],
              self.params["initializer_range"],
              self.params["num_attention_heads"],
              self.params["num_rand_blocks"],
              self.params["max_encoder_length"],
              self.params["block_size"],
              self.params["use_bias"],
              seed=layer_idx,
              name="layer_%d" % layer_idx)
          for layer_idx in range(self.params["num_hidden_layers"])
      ]

      # Normalization layer
      self.layer_norm = utils.NormLayer(self.params["hidden_size"])

  def call(self,
           encoder_inputs,
           encoder_inputs_mask,
           training=None):
    """Returns the output of the encoder layer stack.

    Args:
      encoder_inputs: tensor with shape
        [batch_size, input_length, hidden_size]
      encoder_inputs_mask: Mask for encoder input. [batch_size, input_length]
      training: Boolean indicating whether the call is training or inference.

    Returns:
      Final layer encoder output. float tensor with shape
        [batch_size, input_length, hidden_size]
    """
    # Both attention paths need the mask as float32.
    encoder_inputs_mask = tf.cast(encoder_inputs_mask, tf.float32)

    if self.params["attention_type"] == "block_sparse":
      # Reshape the mask into the layouts used for blocked attention.
      encoder_length = self.params["max_encoder_length"]
      encoder_block_size = self.params["block_size"]
      blocked_encoder_mask = tf.reshape(
          encoder_inputs_mask,
          (-1, encoder_length // encoder_block_size, encoder_block_size))
      encoder_from_mask = tf.reshape(encoder_inputs_mask,
                                     (-1, 1, encoder_length, 1))
      encoder_to_mask = tf.reshape(encoder_inputs_mask,
                                   (-1, 1, 1, encoder_length))

      # create band padding
      band_mask = attention.create_band_mask_from_inputs(
          blocked_encoder_mask, blocked_encoder_mask)

      # For unused masks 0 instead of None for compatibility with
      # recompute_grad.
      attention_mask = 0.0
    else:
      # For unused masks 0 instead of None for compatibility with
      # recompute_grad.
      blocked_encoder_mask = 0.0
      encoder_to_mask = 0.0
      encoder_from_mask = 0.0
      band_mask = 0.0

      attention_mask = attention.create_attention_mask_from_input_mask(
          encoder_inputs_mask, encoder_inputs_mask)

    # Postnorm layers expect an already-normalized input.
    if self.params["norm_type"] == "postnorm":
      encoder_inputs = self.layer_norm(encoder_inputs)

    layer_output = encoder_inputs
    for layer in self.encoder_layers:
      layer_output = layer(
          layer_output, attention_mask, band_mask,
          encoder_from_mask, encoder_to_mask, blocked_encoder_mask,
          training=training)

    # Prenorm layers leave their output un-normalized, so normalize once at
    # the end of the stack.
    if self.params["norm_type"] == "prenorm":
      layer_output = self.layer_norm(layer_output)

    return layer_output