-
Notifications
You must be signed in to change notification settings - Fork 323
/
Copy pathdeit.py
525 lines (464 loc) · 19.9 KB
/
deit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DeiT in Paddle
A Paddle Implementation of Data Efficient Image Transformer (DeiT) as described in:
"Training data-efficient image transformers & distillation through attention"
- Paper Link: https://arxiv.org/abs/2012.12877
"""
import paddle
import paddle.nn as nn
from droppath import DropPath
class Identity(nn.Layer):
    """Pass-through layer.

    Hands back whatever tensor it receives, untouched. It serves as a
    placeholder (for example where drop-path or a pre-logits head is disabled)
    so that `forward` methods can call a sublayer unconditionally instead of
    branching with `if`.
    """

    def forward(self, x):
        """Return the input tensor `x` unchanged."""
        output = x
        return output
class PatchEmbedding(nn.Layer):
    """Patch embedding implemented with a single strided Conv2D.

    The convolution's kernel size and stride both equal `patch_size`, so each
    non-overlapping patch is projected to one `embed_dim`-channel vector.

    Attributes:
        image_size: input image size (used for both height and width, so a
            square input is assumed when counting patches)
        patch_size: patch size (kernel size and stride of the projection)
        num_patches: number of patches, (image_size // patch_size) squared
        patch_embedding: the Conv2D projection
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dim=768):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.patch_embedding = nn.Conv2D(in_channels=in_channels,
                                         out_channels=embed_dim,
                                         kernel_size=patch_size,
                                         stride=patch_size)

    def forward(self, x):
        """Project image patches to token embeddings.

        Presumably `x` is [B, C_in, H, W] (TODO confirm with callers); the
        result is [B, N, embed_dim] where N = h * w patches.
        """
        feature_map = self.patch_embedding(x)   # [B, C, h, w]
        tokens = feature_map.flatten(2)          # [B, C, h*w]
        return tokens.transpose([0, 2, 1])       # [B, h*w, C] = [B, N, C]
class Attention(nn.Layer):
    """Multi-head self-attention for ViT.

    Queries, keys and values all come from the same input. They are produced by
    one fused nn.Linear (`qkv`) whose output is split into three equal chunks.

    Attributes:
        num_heads: number of attention heads
        attn_head_size: feature dim of a single head
        all_head_size: attn_head_size * num_heads
        qkv: fused nn.Linear producing q, k and v
        scales: 1 / sqrt(attn_head_size), applied to q before the dot product
        out: output projection back to embed_dim
        attn_dropout: dropout on the attention probabilities
        proj_dropout: dropout after the output projection
        softmax: softmax over the last axis (the key axis)
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 attn_head_size=None,
                 qkv_bias=True,
                 dropout=0.,
                 attention_dropout=0.):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if attn_head_size is None:
            # derive the per-head width from embed_dim only when not given explicitly
            assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
            attn_head_size = embed_dim // num_heads
        self.attn_head_size = attn_head_size
        self.all_head_size = self.attn_head_size * num_heads

        qkv_weight_attr, qkv_bias_attr = self._init_weights()
        self.qkv = nn.Linear(embed_dim,
                             3 * self.all_head_size,  # q, k and v stacked
                             weight_attr=qkv_weight_attr,
                             bias_attr=qkv_bias_attr if qkv_bias else False)
        self.scales = self.attn_head_size ** -0.5

        out_weight_attr, out_bias_attr = self._init_weights()
        self.out = nn.Linear(self.all_head_size,
                             embed_dim,
                             weight_attr=out_weight_attr,
                             bias_attr=out_bias_attr)
        self.attn_dropout = nn.Dropout(attention_dropout)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(axis=-1)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated-normal(std=.02) weights, zero bias."""
        return (paddle.ParamAttr(initializer=nn.initializer.TruncatedNormal(std=.02)),
                paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)))

    def transpose_multihead(self, x):
        """Split the channel axis into heads: [B, N, C] -> [B, n_heads, N, head_dim]."""
        leading_dims = x.shape[:-1]
        x = x.reshape(leading_dims + [self.num_heads, self.attn_head_size])
        return x.transpose([0, 2, 1, 3])

    def forward(self, x):
        """Apply scaled dot-product self-attention followed by the output projection."""
        q, k, v = (self.transpose_multihead(chunk)
                   for chunk in self.qkv(x).chunk(3, axis=-1))
        scores = paddle.matmul(q * self.scales, k, transpose_y=True)  # [B, n_heads, N, N]
        probs = self.attn_dropout(self.softmax(scores))
        context = paddle.matmul(probs, v)                 # [B, n_heads, N, head_dim]
        context = context.transpose([0, 2, 1, 3])         # [B, N, n_heads, head_dim]
        context = context.reshape(context.shape[:-2] + [self.all_head_size])
        return self.proj_dropout(self.out(context))
class Mlp(nn.Layer):
    """MLP (feed-forward) block of a transformer layer.

    Ops: fc1 -> GELU -> dropout -> fc2 -> dropout

    Attributes:
        fc1: nn.Linear, embed_dim -> int(embed_dim * mlp_ratio)
        fc2: nn.Linear, int(embed_dim * mlp_ratio) -> embed_dim
        act: GELU activation
        dropout: dropout applied after the activation and again after fc2
    """

    def __init__(self,
                 embed_dim,
                 mlp_ratio,
                 dropout=0.):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        w_attr_1, b_attr_1 = self._init_weights()
        self.fc1 = nn.Linear(embed_dim,
                             hidden_dim,
                             weight_attr=w_attr_1,
                             bias_attr=b_attr_1)
        w_attr_2, b_attr_2 = self._init_weights()
        self.fc2 = nn.Linear(hidden_dim,
                             embed_dim,
                             weight_attr=w_attr_2,
                             bias_attr=b_attr_2)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated-normal(std=.02) weights, zero bias.

        std=.02 matches the truncated-normal init used for the attention
        projections in this file; the previous std=0.2 was ten times larger.
        """
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def forward(self, x):
        """Apply fc1 -> GELU -> dropout -> fc2 -> dropout."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
class TransformerLayer(nn.Layer):
    """Pre-norm transformer layer.

    Computes
        x = x + drop_path(attn(attn_norm(x)))
        x = x + drop_path(mlp(mlp_norm(x)))

    Attributes:
        attn_norm: nn.LayerNorm applied before attention
        attn: multi-head self-attention module
        drop_path: DropPath when droppath > 0, otherwise Identity
        mlp_norm: nn.LayerNorm applied before the MLP
        mlp: feed-forward module
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 attn_head_size=None,
                 qkv_bias=True,
                 mlp_ratio=4.,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.):
        super().__init__()
        attn_norm_w, attn_norm_b = self._init_weights()
        self.attn_norm = nn.LayerNorm(embed_dim,
                                      weight_attr=attn_norm_w,
                                      bias_attr=attn_norm_b,
                                      epsilon=1e-6)
        self.attn = Attention(embed_dim,
                              num_heads,
                              attn_head_size,
                              qkv_bias,
                              dropout,
                              attention_dropout)
        if droppath > 0.:
            self.drop_path = DropPath(droppath)
        else:
            self.drop_path = Identity()
        mlp_norm_w, mlp_norm_b = self._init_weights()
        self.mlp_norm = nn.LayerNorm(embed_dim,
                                     weight_attr=mlp_norm_w,
                                     bias_attr=mlp_norm_b,
                                     epsilon=1e-6)
        self.mlp = Mlp(embed_dim, mlp_ratio, dropout)

    def _init_weights(self):
        """Return LayerNorm (weight_attr, bias_attr): ones for the scale, zeros for the shift."""
        return (paddle.ParamAttr(initializer=nn.initializer.Constant(1.0)),
                paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)))

    def forward(self, x):
        """Run the attention sub-block, then the MLP sub-block, each with a residual."""
        x = self.drop_path(self.attn(self.attn_norm(x))) + x
        x = self.drop_path(self.mlp(self.mlp_norm(x))) + x
        return x
class Encoder(nn.Layer):
    """Transformer encoder: a stack of TransformerLayer followed by a final LayerNorm.

    Attributes:
        layers: nn.LayerList holding `depth` TransformerLayer instances
        encoder_norm: nn.LayerNorm applied after the last layer
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 depth,
                 attn_head_size=None,
                 qkv_bias=True,
                 mlp_ratio=4.0,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.):
        super().__init__()
        # stochastic depth: per-layer drop-path rate rises linearly from 0 to `droppath`
        drop_rates = [rate.item() for rate in paddle.linspace(0, droppath, depth)]
        self.layers = nn.LayerList([
            TransformerLayer(embed_dim,
                             num_heads,
                             attn_head_size,
                             qkv_bias,
                             mlp_ratio,
                             dropout,
                             attention_dropout,
                             drop_rates[layer_idx])
            for layer_idx in range(depth)
        ])
        norm_w, norm_b = self._init_weights()
        self.encoder_norm = nn.LayerNorm(embed_dim,
                                         weight_attr=norm_w,
                                         bias_attr=norm_b,
                                         epsilon=1e-6)

    def _init_weights(self):
        """Return LayerNorm (weight_attr, bias_attr): ones for the scale, zeros for the shift."""
        return (paddle.ParamAttr(initializer=nn.initializer.Constant(1.0)),
                paddle.ParamAttr(initializer=nn.initializer.Constant(0.0)))

    def forward(self, x):
        """Apply every layer in order, then the final LayerNorm."""
        for block in self.layers:
            x = block(x)
        return self.encoder_norm(x)
class VisionTransformer(nn.Layer):
    """ViT transformer

    ViT Transformer; the classifier is a single Linear layer (for finetuning).
    An optional pre-logits layer (Linear + ReLU) can be enabled via
    `representation_size`. Classification is done using the cls_token.

    Args:
        image_size: int, input image size, default: 224
        patch_size: int, patch size, default: 16
        in_channels: int, input image channels, default: 3
        num_classes: int, number of classes for classification, default: 1000
        embed_dim: int, embedding dimension (patch embed out dim), default: 768
        depth: int, number of transformer blocks, default: 12
        num_heads: int, number of attention heads, default: 12
        attn_head_size: int, dim of head, if none, set to embed_dim // num_heads, default: None
        mlp_ratio: float, ratio of mlp hidden dim to embed dim (mlp in dim), default: 4
        qkv_bias: bool, If True, enable qkv(nn.Linear) layer with bias, default: True
        dropout: float, dropout rate for linear layers, default: 0.
        attention_dropout: float, dropout rate for attention layers, default: 0.
        droppath: float, droppath rate for droppath layers, default: 0.
        representation_size: int, set representation layer (pre-logits) if set, default: None

    Attributes:
        num_features: width of the features fed to the classifier
            (representation_size when set, otherwise embed_dim)
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 attn_head_size=None,
                 mlp_ratio=4,
                 qkv_bias=True,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.,
                 representation_size=None):
        super().__init__()
        # create patch embedding
        self.patch_embedding = PatchEmbedding(image_size,
                                              patch_size,
                                              in_channels,
                                              embed_dim)
        # create position embedding (one slot for the cls token + one per patch)
        self.position_embedding = paddle.create_parameter(
            shape=[1, 1 + self.patch_embedding.num_patches, embed_dim],
            dtype='float32',
            default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        # create cls token
        self.cls_token = paddle.create_parameter(
            shape=[1, 1, embed_dim],
            dtype='float32',
            default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        self.pos_dropout = nn.Dropout(dropout)
        # create multi head self-attention layers
        self.encoder = Encoder(embed_dim,
                               num_heads,
                               depth,
                               attn_head_size,
                               qkv_bias,
                               mlp_ratio,
                               dropout,
                               attention_dropout,
                               droppath)
        # pre-logits
        if representation_size is not None:
            self.num_features = representation_size
            w_attr_1, b_attr_1 = self._init_weights()
            self.pre_logits = nn.Sequential(
                nn.Linear(embed_dim,
                          representation_size,
                          weight_attr=w_attr_1,
                          bias_attr=b_attr_1),
                nn.ReLU())
        else:
            self.num_features = embed_dim
            self.pre_logits = Identity()
        # classifier head: consumes the pre_logits output, whose width is
        # num_features (not embed_dim when a representation layer is used)
        w_attr_2, b_attr_2 = self._init_weights()
        self.classifier = nn.Linear(self.num_features,
                                    num_classes,
                                    weight_attr=w_attr_2,
                                    bias_attr=b_attr_2)

    def _init_weights(self):
        """Return Linear (weight_attr, bias_attr): truncated-normal(std=.02) weights, zero bias.

        A constant weight init would give every output unit the same weights,
        making all class logits identical at initialization.
        """
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def forward_features(self, x):
        """Embed patches, prepend the cls token, encode, and return pre-logits of the cls token."""
        x = self.patch_embedding(x)
        cls_tokens = self.cls_token.expand((x.shape[0], -1, -1))
        x = paddle.concat((cls_tokens, x), axis=1)
        x = x + self.position_embedding
        x = self.pos_dropout(x)
        x = self.encoder(x)
        x = self.pre_logits(x[:, 0])  # cls_token only
        return x

    def forward(self, x):
        """Return classification logits for input `x`."""
        x = self.forward_features(x)
        logits = self.classifier(x)
        return logits
class DistilledVisionTransformer(VisionTransformer):
    """Distilled ViT transformer (DeiT).

    Extends VisionTransformer with a distillation token and a second classifier
    head that reads it. In training mode `forward` returns both heads' logits;
    in eval mode it returns their average.

    Args:
        image_size: int, input image size, default: 224
        patch_size: int, patch size, default: 16
        in_channels: int, input image channels, default: 3
        num_classes: int, number of classes for classification, default: 1000
        embed_dim: int, embedding dimension (patch embed out dim), default: 768
        depth: int, number of transformer blocks, default: 12
        num_heads: int, number of attention heads, default: 12
        attn_head_size: int, dim of head, if none, set to embed_dim // num_heads, default: None
        mlp_ratio: float, ratio of mlp hidden dim to embed dim (mlp in dim), default: 4
        qkv_bias: bool, If True, enable qkv(nn.Linear) layer with bias, default: True
        dropout: float, dropout rate for linear layers, default: 0.
        attention_dropout: float, dropout rate for attention layers, default: 0.
        droppath: float, droppath rate for droppath layers, default: 0.
    """

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 attn_head_size=None,
                 mlp_ratio=4,
                 qkv_bias=True,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.):
        super().__init__(image_size=image_size,
                         patch_size=patch_size,
                         in_channels=in_channels,
                         num_classes=num_classes,
                         embed_dim=embed_dim,
                         depth=depth,
                         num_heads=num_heads,
                         attn_head_size=attn_head_size,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         dropout=dropout,
                         attention_dropout=attention_dropout,
                         droppath=droppath,
                         representation_size=None)
        # replace the parent's position embedding: it needs one extra slot
        # because both a cls token and a distillation token are prepended
        num_tokens = self.patch_embedding.num_patches + 2
        self.position_embedding = paddle.create_parameter(
            shape=[1, num_tokens, embed_dim],
            dtype='float32',
            default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        # distillation token
        self.dist_token = paddle.create_parameter(
            shape=[1, 1, embed_dim],
            dtype='float32',
            default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        # classifier head for the distillation token
        dist_weight_attr, dist_bias_attr = self._init_weights()
        self.classifier_dist = nn.Linear(embed_dim,
                                         num_classes,
                                         weight_attr=dist_weight_attr,
                                         bias_attr=dist_bias_attr)

    def forward_features(self, x):
        """Encode `x` and return the (cls token, distillation token) features."""
        x = self.patch_embedding(x)
        batch_size = x.shape[0]
        prefix_tokens = [token.expand((batch_size, -1, -1))
                         for token in (self.cls_token, self.dist_token)]
        x = paddle.concat(prefix_tokens + [x], axis=1)
        x = self.pos_dropout(x + self.position_embedding)
        encoded = self.encoder(x)
        return encoded[:, 0], encoded[:, 1]

    def forward(self, x):
        """Return (logits, logits_dist) when training, else their mean."""
        cls_feature, dist_feature = self.forward_features(x)
        logits = self.classifier(cls_feature)
        logits_dist = self.classifier_dist(dist_feature)
        if not self.training:
            return (logits + logits_dist) / 2
        return logits, logits_dist
def build_vit(config):
    """Build a plain VisionTransformer (no distillation token) from `config`.

    Reads image settings from `config.DATA` and architecture settings from
    `config.MODEL`; no representation (pre-logits) layer is used.
    """
    data_cfg = config.DATA
    model_cfg = config.MODEL
    return VisionTransformer(image_size=data_cfg.IMAGE_SIZE,
                             patch_size=model_cfg.PATCH_SIZE,
                             in_channels=data_cfg.IMAGE_CHANNELS,
                             num_classes=model_cfg.NUM_CLASSES,
                             embed_dim=model_cfg.EMBED_DIM,
                             depth=model_cfg.DEPTH,
                             num_heads=model_cfg.NUM_HEADS,
                             attn_head_size=model_cfg.ATTN_HEAD_SIZE,
                             mlp_ratio=model_cfg.MLP_RATIO,
                             qkv_bias=model_cfg.QKV_BIAS,
                             dropout=model_cfg.DROPOUT,
                             attention_dropout=model_cfg.ATTENTION_DROPOUT,
                             droppath=model_cfg.DROPPATH,
                             representation_size=None)
def build_deit(config):
    """Build a DistilledVisionTransformer (DeiT) from `config`.

    Reads image settings from `config.DATA` and architecture settings from
    `config.MODEL`.
    """
    data_cfg = config.DATA
    model_cfg = config.MODEL
    return DistilledVisionTransformer(
        image_size=data_cfg.IMAGE_SIZE,
        patch_size=model_cfg.PATCH_SIZE,
        in_channels=data_cfg.IMAGE_CHANNELS,
        num_classes=model_cfg.NUM_CLASSES,
        embed_dim=model_cfg.EMBED_DIM,
        depth=model_cfg.DEPTH,
        num_heads=model_cfg.NUM_HEADS,
        attn_head_size=model_cfg.ATTN_HEAD_SIZE,
        mlp_ratio=model_cfg.MLP_RATIO,
        qkv_bias=model_cfg.QKV_BIAS,
        dropout=model_cfg.DROPOUT,
        attention_dropout=model_cfg.ATTENTION_DROPOUT,
        droppath=model_cfg.DROPPATH)