-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvit_models.py
358 lines (312 loc) · 15.7 KB
/
vit_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
import torch.nn as nn
from fairseq.models import FairseqEncoder, register_model, FairseqEncoderDecoderModel, register_model_architecture
from fairseq.models.transformer import TransformerDecoder, Embedding, TransformerModel
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq import utils
# from timm.models.vision_transformer import HybridEmbed, PatchEmbed, Block
from timm.models.layers import trunc_normal_
import torch
from torch.hub import load_state_dict_from_url
from functools import partial
import logging
logger = logging.getLogger(__name__)
DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('ViT_TR')
class ViTTRModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model pairing a ``ViTTREncoder`` with a fairseq ``TransformerDecoder``.

    Registered with fairseq under the model name ``ViT_TR``.
    """

    @staticmethod
    def add_args(parser):
        """Add the standard Transformer arguments plus the ViT-specific ones to ``parser``."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            '--vit-img-size', type=int, metavar='N',
            help='the image size of h and w (h=w) of the ViT'
        )
        parser.add_argument(
            '--vit-patch-size', type=int, metavar='N',
            help='the patch size of h and w (h=w) of the ViT'
        )
        parser.add_argument(
            '--vit-dim', type=int, metavar='N',
            help='the hidden size of the ViT'
        )
        parser.add_argument(
            '--vit-depth', type=int, metavar='N',
            help='the layer num of the ViT'
        )
        parser.add_argument(
            '--vit-heads', type=int, metavar='N',
            help='the head num of the ViT'
        )
        parser.add_argument(
            '--vit-channels', type=int, metavar='N', default=3,
            help='the input image channels of the ViT'
        )
        parser.add_argument(
            '--vit-dropout', type=float, default=0.0,
            help='the dropout ratio of the ViT'
        )
        # This value is handed to the ViT blocks as their attention dropout,
        # so the help text describes it as such.
        parser.add_argument(
            '--vit-atten-dropout', type=float, default=0.0,
            help='the attention dropout ratio of the ViT'
        )
        parser.add_argument(
            '--encoder-pretrained-url', type=str,
            help='the pretrained parameter url for the ViT encoder'
        )

    @classmethod
    def build_model(cls, args, task):
        """Build the ViT encoder and Transformer decoder described by ``args``.

        Args:
            args: parsed options; read for the ViT settings, the optional
                ``encoder_pretrained_url`` and the decoder settings.
            task: fairseq task providing ``source_dictionary`` and
                ``target_dictionary``.

        Returns:
            A ``ViTTRModel`` wrapping the freshly built encoder and decoder.
        """
        encoder = ViTTREncoder(
            args=args,
            dictionary=task.source_dictionary
        )

        pretrained_url = getattr(args, "encoder_pretrained_url", None)
        if pretrained_url:
            logger.info('load pretrianed encoder parameter from: {}'.format(pretrained_url))
            # Load onto CPU first so a checkpoint saved on a GPU can still be
            # restored on a host without that device; load_state_dict copies
            # the values into the encoder's own parameters afterwards.
            encoder_state_dict = load_state_dict_from_url(pretrained_url, map_location='cpu')
            incompatible = encoder.load_state_dict(encoder_state_dict, strict=False)
            # strict=False tolerates key mismatches; report them instead of
            # dropping them silently.
            if incompatible.missing_keys:
                logger.warning('pretrained encoder is missing keys: %s', incompatible.missing_keys)
            if incompatible.unexpected_keys:
                logger.warning('pretrained encoder has unexpected keys: %s', incompatible.unexpected_keys)

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        decoder_embed_tokens = cls.build_embedding(
            args, task.target_dictionary, args.decoder_embed_dim, args.decoder_embed_path
        )
        decoder = TransformerDecoder(
            args=args,
            dictionary=task.target_dictionary,
            embed_tokens=decoder_embed_tokens,
            no_encoder_attn=False
        )
        return cls(encoder, decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        """Create the token embedding for ``dictionary``.

        Args:
            args: parsed options (not read by this method).
            dictionary: vocabulary supplying the size and padding index.
            embed_dim: embedding dimension.
            path: optional path to pre-trained embeddings to load into the table.

        Returns:
            The embedding module.
        """
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    def forward(self, imgs, prev_output_tokens, **kwargs):
        """Encode ``imgs`` and decode conditioned on ``prev_output_tokens``.

        Extra keyword arguments are forwarded to both the encoder and the decoder.

        Returns:
            Whatever the decoder returns.
        """
        encoder_out = self.encoder(imgs, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out
def _fill_defaults(args, defaults):
    """Assign each ``(name, default)`` pair on ``args``, keeping values already present."""
    for name, default in defaults:
        setattr(args, name, getattr(args, name, default))


def _apply_vit_defaults(args, img_size, dim, depth, heads, pretrained_url):
    """Fill in the ViT encoder options that differ between the presets."""
    _fill_defaults(args, (("vit_img_size", img_size),))
    # resize_img_size always mirrors the (possibly user-supplied) ViT image size.
    args.resize_img_size = args.vit_img_size
    _fill_defaults(args, (
        ("vit_patch_size", 16),
        ("vit_dim", dim),
        ("vit_depth", depth),
        ("vit_heads", heads),
        ("encoder_pretrained_url", pretrained_url),
    ))


def _apply_transformer_defaults(args, encoder_embed_dim):
    """Fill in the Transformer options shared by every ``ViT_TR`` preset.

    Only ``encoder_embed_dim`` differs between presets; every other default is
    identical. Defaults derived from another option are applied in a later
    group than the option they read.
    """
    _fill_defaults(args, (
        ("encoder_embed_path", None),
        ("encoder_embed_dim", encoder_embed_dim),
        ("encoder_ffn_embed_dim", 2048),
        ("encoder_layers", 6),
        ("encoder_attention_heads", 8),
        ("encoder_normalize_before", False),
        ("encoder_learned_pos", False),
        ("decoder_embed_path", None),
    ))
    # Decoder sizes fall back to the encoder sizes resolved above.
    _fill_defaults(args, (
        ("decoder_embed_dim", args.encoder_embed_dim),
        ("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim),
        ("decoder_layers", 6),
        ("decoder_attention_heads", 8),
        ("decoder_normalize_before", False),
        ("decoder_learned_pos", False),
        ("attention_dropout", 0.0),
        ("activation_dropout", 0.0),
        ("activation_fn", "relu"),
        ("dropout", 0.1),
        ("adaptive_softmax_cutoff", None),
        ("adaptive_softmax_dropout", 0),
        ("share_decoder_input_output_embed", False),
        ("share_all_embeddings", False),
        ("no_token_positional_embeddings", False),
        ("adaptive_input", False),
        ("no_cross_attention", False),
        ("cross_self_attention", False),
    ))
    # Decoder input/output widths fall back to the decoder embedding width.
    _fill_defaults(args, (
        ("decoder_output_dim", args.decoder_embed_dim),
        ("decoder_input_dim", args.decoder_embed_dim),
        ("no_scale_embedding", False),
        ("layernorm_embedding", False),
        ("tie_adaptive_weights", False),
        ("checkpoint_activations", False),
        ("offload_activations", False),
    ))
    # Offloading activations forces activation checkpointing on.
    if args.offload_activations:
        args.checkpoint_activations = True
    _fill_defaults(args, (
        ("encoder_layers_to_keep", None),
        ("decoder_layers_to_keep", None),
        ("encoder_layerdrop", 0),
        ("decoder_layerdrop", 0),
        ("quant_noise_pq", 0),
        ("quant_noise_pq_block_size", 8),
        ("quant_noise_scalar", 0),
    ))


@register_model_architecture('ViT_TR', 'ViT_TR_base')
def ViT_TR_base(args):
    """Fill in defaults for the ``ViT_TR_base`` architecture.

    ViT: 224px images, 16px patches, width 768, 12 layers, 12 heads, with the
    ``jx_vit_base_p16_224`` checkpoint as the default pretrained URL.
    Transformer embedding width defaults to 768. Options already set on
    ``args`` are left as they are.
    """
    _apply_vit_defaults(
        args,
        img_size=224,
        dim=768,
        depth=12,
        heads=12,
        pretrained_url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth",
    )
    _apply_transformer_defaults(args, encoder_embed_dim=768)


@register_model_architecture('ViT_TR', 'ViT_TR_large')
def large_architecture(args):
    """Fill in defaults for the ``ViT_TR_large`` architecture.

    ViT: 384px images, 16px patches, width 1024, 24 layers, 16 heads, with the
    ``jx_vit_large_p16_384`` checkpoint as the default pretrained URL.
    Transformer embedding width defaults to 1024. Options already set on
    ``args`` are left as they are.
    """
    _apply_vit_defaults(
        args,
        img_size=384,
        dim=1024,
        depth=24,
        heads=16,
        pretrained_url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth",
    )
    _apply_transformer_defaults(args, encoder_embed_dim=1024)
class ViTTREncoder(FairseqEncoder):
    """Vision Transformer image encoder that returns a fairseq-style encoder-output dict.

    Images are split into patch embeddings, a class token is prepended, a
    learned position embedding is added, and the sequence is passed through
    ``args.vit_depth`` transformer blocks followed by a LayerNorm.
    """

    def __init__(self, args, dictionary):
        """Build the encoder from the ``vit_*`` options on ``args``.

        Args:
            args: parsed options; reads ``vit_img_size``, ``vit_patch_size``,
                ``vit_channels``, ``vit_dim``, ``vit_depth``, ``vit_heads``,
                ``vit_dropout`` and ``vit_atten_dropout``.
            dictionary: passed through to ``FairseqEncoder``.
        """
        super().__init__(dictionary)
        # The module-level import of these timm classes is commented out, which
        # left ``PatchEmbed`` and ``Block`` undefined; import them here.
        # NOTE(review): ``Block`` is called with ``qk_scale``/``drop`` below, so
        # this presumably targets an older timm release - confirm the pinned version.
        from timm.models.vision_transformer import Block, PatchEmbed

        embed_dim = args.vit_dim
        depth = args.vit_depth
        drop_rate = args.vit_dropout

        # Fixed ViT hyper-parameters (not exposed as command-line options).
        mlp_ratio = 4.
        qkv_bias = True
        qk_scale = None
        drop_path_rate = 0.
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        # The former hybrid-backbone branch could never run (the backbone was
        # hard-coded to None), so only the plain patch embedding is built.
        self.patch_embed = PatchEmbed(
            img_size=args.vit_img_size,
            patch_size=args.vit_patch_size,
            in_chans=args.vit_channels,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=args.vit_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=args.vit_atten_dropout,
                drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise ``nn.Linear`` and ``nn.LayerNorm`` submodules; other modules are left alone."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Run the ViT on a batch of images.

        Args:
            x: image batch; the first dimension is the batch size.

        Returns:
            Tuple ``(features, encoder_embedding)``: the normalised output of
            the last block, and the embedded input (patches + class token +
            position embedding) before dropout. Both are bs x (n + 1) x dim.
        """
        B = x.shape[0]
        x = self.patch_embed(x)  # bs, num_patches, dim

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        encoder_embedding = x  # bs, n + 1, dim
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)  # bs, n + 1, dim
        return x, encoder_embedding

    def forward(self, imgs):
        """Encode ``imgs`` into the dict layout consumed by the decoder.

        Returns:
            dict with single-element lists ``encoder_out`` (T x B x C),
            ``encoder_padding_mask`` (B x T, all zeros) and
            ``encoder_embedding`` (B x T x C), plus empty ``encoder_states``,
            ``src_tokens`` and ``src_lengths`` lists.
        """
        x, encoder_embedding = self.forward_features(imgs)  # bs, n + 1, dim
        bsz, seq_len = x.shape[:2]
        x = x.transpose(0, 1)  # n + 1, bs, dim

        # All-zero mask, allocated directly on the input's device rather than
        # built on CPU and copied over.
        # NOTE(review): presumably zero means "not padding" for the decoder - confirm.
        encoder_padding_mask = torch.zeros(bsz, seq_len, device=imgs.device)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """
        Reorder encoder output according to `new_order`.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            `encoder_out` rearranged according to `new_order`
        """
        _encoder_out = encoder_out['encoder_out'][0]
        _encoder_padding_mask = encoder_out['encoder_padding_mask'][0]
        _encoder_embedding = encoder_out['encoder_embedding'][0]
        return {
            # encoder_out is T x B x C, so the batch axis is dim 1.
            "encoder_out": [_encoder_out.index_select(1, new_order)],
            "encoder_padding_mask": [_encoder_padding_mask.index_select(0, new_order)],  # B x T
            # Reorder the embedding itself (previously the padding mask was
            # returned under this key).
            "encoder_embedding": [_encoder_embedding.index_select(0, new_order)],  # B x T x C
            "encoder_states": [],
            "src_tokens": [],
            "src_lengths": [],
        }
# Script entry point is intentionally a no-op: running this module directly does nothing.
if __name__ == '__main__':
    pass