forked from devnkong/GOAT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvq.py
68 lines (52 loc) · 2.73 KB
/
vq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from re import X
import numpy as np
import torch
from torch import nn
from torch.autograd import Function, Variable
import torch.nn.functional as F
class VectorQuantizerEMA(nn.Module):
def __init__(
self,
num_embeddings,
embedding_dim,
decay=0.99
):
super(VectorQuantizerEMA, self).__init__()
self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings
self.register_buffer('_embedding', torch.randn(self._num_embeddings, self._embedding_dim*2))
self.register_buffer('_embedding_output', torch.randn(self._num_embeddings, self._embedding_dim*2))
self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
self.register_buffer('_ema_w', torch.randn(self._num_embeddings, self._embedding_dim*2))
self._decay = decay
self.bn = torch.nn.BatchNorm1d(self._embedding_dim*2, affine=False)
def get_k(self) :
return self._embedding_output
def get_v(self) :
return self._embedding_output[:, :self._embedding_dim]
def update(self, x):
inputs_normalized = self.bn(x)
embedding_normalized = self._embedding
# Calculate distances
distances = (torch.sum(inputs_normalized ** 2, dim=1, keepdim=True)
+ torch.sum(embedding_normalized ** 2, dim=1)
- 2 * torch.matmul(inputs_normalized, embedding_normalized.t()))
# Encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=x.device)
encodings.scatter_(1, encoding_indices, 1)
# Use EMA to update the embedding vectors
if self.training:
self._ema_cluster_size.data = self._ema_cluster_size * self._decay + (1 - self._decay) * torch.sum(encodings, 0)
# Laplace smoothing of the cluster size
n = torch.sum(self._ema_cluster_size.data)
self._ema_cluster_size.data = (self._ema_cluster_size + 1e-5) / (n + self._num_embeddings * 1e-5) * n
# if torch.count_nonzero(self._ema_cluster_size) != self._ema_cluster_size.shape[0] :
# raise ValueError('Bad Init!')
dw = torch.matmul(encodings.t(), inputs_normalized)
self._ema_w.data = self._ema_w * self._decay + (1 - self._decay) * dw
self._embedding.data = self._ema_w / self._ema_cluster_size.unsqueeze(1)
running_std = torch.sqrt(self.bn.running_var + 1e-5).unsqueeze(dim=0)
running_mean = self.bn.running_mean.unsqueeze(dim=0)
self._embedding_output.data = self._embedding*running_std + running_mean
return encoding_indices