# embeddings.py (forked from tangshengeng/ProgressiveTransformersSLP)
import math

import torch
from torch import nn, Tensor

from helpers import freeze_params


class MaskedNorm(nn.Module):
"""
Original Code from:
https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8
"""
def __init__(self, norm_type, num_groups, num_features):
super().__init__()
self.norm_type = norm_type
if self.norm_type == "batch":
self.norm = nn.BatchNorm1d(num_features=num_features)
elif self.norm_type == "group":
self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features)
elif self.norm_type == "layer":
self.norm = nn.LayerNorm(normalized_shape=num_features)
else:
raise ValueError("Unsupported Normalization Layer")
self.num_features = num_features

    def forward(self, x: Tensor, mask: Tensor):
        if self.training:
            # Flatten to [batch * time, num_features] and keep only the
            # positions whose mask entry is set (i.e. the non-padded ones).
            reshaped = x.reshape([-1, self.num_features])
            reshaped_mask = mask.reshape([-1, 1]) > 0
            selected = torch.masked_select(reshaped, reshaped_mask).reshape(
                [-1, self.num_features]
            )
            batch_normed = self.norm(selected)
            # Scatter the normalized values back into their original positions.
            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
            return scattered.reshape([x.shape[0], -1, self.num_features])
        else:
            # At inference time the running statistics are used, so no masking is needed.
            reshaped = x.reshape([-1, self.num_features])
            batched_normed = self.norm(reshaped)
            return batched_normed.reshape([x.shape[0], -1, self.num_features])
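
# A minimal usage sketch, not part of the original file: it assumes a padded
# batch of shape [batch, time, features] and a binary mask with one entry per
# time step, so that only the non-padded positions contribute to the
# normalization statistics. Shapes and values below are illustrative only.
def _masked_norm_example() -> Tensor:
    norm = MaskedNorm(norm_type="batch", num_groups=1, num_features=16)
    norm.train()
    x = torch.randn(2, 4, 16)               # two sequences of length 4
    mask = torch.tensor([[1, 1, 1, 1],      # first sequence fully valid
                         [1, 1, 0, 0]])     # second sequence padded after step 2
    return norm(x, mask)                    # -> [2, 4, 16]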

class Embeddings(nn.Module):
    """
    Simple embeddings class
    """

    # pylint: disable=unused-argument
    def __init__(self,
                 embedding_dim: int = 64,
                 scale: bool = False,
                 vocab_size: int = 0,
                 padding_idx: int = 1,
                 freeze: bool = False,
                 **kwargs):
        """
        Create new embeddings for the vocabulary.
        Use scaling for the Transformer.

        :param embedding_dim: size of the embedding vectors
        :param scale: scale embeddings by sqrt(embedding_dim)
        :param vocab_size: size of the vocabulary
        :param padding_idx: index of the padding token
        :param freeze: freeze the embeddings during training
        """
        super(Embeddings, self).__init__()

        self.embedding_dim = embedding_dim
        self.scale = scale
        self.vocab_size = vocab_size
        self.lut = nn.Embedding(vocab_size, self.embedding_dim,
                                padding_idx=padding_idx)

        if freeze:
            freeze_params(self)

    # pylint: disable=arguments-differ
    def forward(self, x: Tensor) -> Tensor:
        """
        Perform lookup for input `x` in the embedding table.

        :param x: index in the vocabulary
        :return: embedded representation for `x`
        """
        if self.scale:
            # Scale by sqrt(embedding_dim), as in the Transformer.
            return self.lut(x) * math.sqrt(self.embedding_dim)

        return self.lut(x)

    def __repr__(self):
        return "%s(embedding_dim=%d, vocab_size=%d)" % (
            self.__class__.__name__, self.embedding_dim, self.vocab_size)
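
# A minimal usage sketch, not part of the original file: the vocabulary size,
# token indices, and padding index below are illustrative only. With
# scale=True the lookup is multiplied by sqrt(embedding_dim).
def _embeddings_example() -> Tensor:
    emb = Embeddings(embedding_dim=64, scale=True, vocab_size=100, padding_idx=1)
    tokens = torch.tensor([[5, 7, 9, 1],    # index 1 is the padding token
                           [2, 3, 1, 1]])
    return emb(tokens)                      # -> [2, 4, 64]


if __name__ == "__main__":
    print(_masked_norm_example().shape)     # torch.Size([2, 4, 16])
    print(_embeddings_example().shape)      # torch.Size([2, 4, 64])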