forked from JDAI-CV/FaceX-Zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSST_Prototype.py
79 lines (72 loc) · 3.1 KB
/
SST_Prototype.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
@author: Hang Du, Jun Wang
@date: 20201020
@contact: [email protected]
"""
import torch
from torch.nn import Module
import math
import random
import torch.nn.functional as F
class SST_Prototype(Module):
    """Implementation for "Semi-Siamese Training for Shallow Face Learning".

    Keeps a queue of gallery prototypes (unit-norm feature columns) that act
    as negatives in a classification-style loss. Each forward pass splices
    the current detached gallery batch into the queue so every probe has its
    positive at a known column, then overwrites one slot range of the queue.

    Args:
        feat_dim (int): dimensionality of the face embeddings.
        queue_size (int): number of prototype columns held in the queue.
        scale (float): multiplicative scale applied to the cosine logits.
        loss_type (str): 'softmax', 'am_softmax', or 'arc_softmax'.
        margin (float): additive margin used by the *_softmax loss types.
    """
    def __init__(self, feat_dim=512, queue_size=16384, scale=30.0, loss_type='softmax', margin=0.0):
        super(SST_Prototype, self).__init__()
        self.queue_size = queue_size
        self.feat_dim = feat_dim
        self.scale = scale
        self.margin = margin
        self.loss_type = loss_type
        # Initialize the prototype queue with random directions; renorm+mul
        # then F.normalize leave each column with unit L2 norm.
        self.register_buffer('queue', torch.rand(feat_dim,queue_size).uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5))
        self.queue = F.normalize(self.queue, p=2, dim=0) # normalize the initial queue.
        self.index = 0                      # next queue column to overwrite
        self.label_list = [-1] * queue_size # identity id stored per column (-1 = empty)

    def add_margin(self, cos_theta, label, batch_size):
        """Apply the configured margin to each sample's ground-truth logit.

        Args:
            cos_theta (Tensor): (batch_size, queue_size) cosine similarities.
            label (Tensor): (batch_size,) ground-truth column indices.
            batch_size (int): number of rows in cos_theta.

        Returns:
            Tensor: cos_theta with margins written into the label positions.
        """
        # Guard against numerical drift outside [-1, 1] before sqrt below.
        cos_theta = cos_theta.clamp(-1, 1)
        if self.loss_type == 'am_softmax':
            # Additive cosine margin: cos(theta) - m on the positive logit.
            cos_theta_m = cos_theta[torch.arange(0, batch_size), label].view(-1, 1) - self.margin
            cos_theta.scatter_(1, label.data.view(-1, 1), cos_theta_m)
        elif self.loss_type == 'arc_softmax':
            # Additive angular margin: cos(theta + m) expanded via the
            # angle-addition identity.
            gt = cos_theta[torch.arange(0, batch_size), label].view(-1, 1)
            sin_theta = torch.sqrt(1.0 - torch.pow(gt, 2))
            cos_theta_m = gt * math.cos(self.margin) - sin_theta * math.sin(self.margin)
            cos_theta.scatter_(1, label.data.view(-1, 1), cos_theta_m)
        return cos_theta

    def compute_theta(self, p, g, label, batch_size):
        """Cosine logits of probe features against the queue with g spliced in.

        The gallery batch g is written (on a clone) into the columns it will
        occupy after the queue update, so probe i's positive is column
        index + i. Returns (batch_size, queue_size) margin-adjusted logits.
        """
        queue = self.queue.clone()
        queue[:,self.index:self.index+batch_size] = g.transpose(0,1)
        cos_theta = torch.mm(p, queue.detach())
        cos_theta = self.add_margin(cos_theta, label,batch_size)
        return cos_theta

    def update_queue(self, g, cur_ids, batch_size):
        """Overwrite the next batch_size queue columns with g; record ids."""
        with torch.no_grad():
            self.queue[:,self.index:self.index+batch_size] = g.transpose(0,1)
            for image_id in range(batch_size):
                self.label_list[self.index + image_id] = cur_ids[image_id].item()
            # Advance the write cursor circularly.
            self.index = (self.index + batch_size) % self.queue_size

    def get_id_set(self):
        """Return the set of identity ids currently stored in the queue."""
        return {label for label in self.label_list if label != -1}

    def forward(self, p1, g2, p2, g1, cur_ids):
        """Compute scaled logits for both probe/gallery pairings.

        Args:
            p1, p2 (Tensor): probe embeddings, shape (batch_size, feat_dim).
            g1, g2 (Tensor): gallery embeddings, same shape (detached here).
            cur_ids (Tensor): (batch_size,) identity ids of the batch.

        Returns:
            tuple: (output1, output2, label, id_set) where output* are
            (batch_size, queue_size) scaled logits, label holds the positive
            column index per sample, and id_set is the ids in the queue.
        """
        batch_size = p1.shape[0]
        # BUGFIX: when queue_size is not a multiple of batch_size the slice
        # index:index+batch_size would run past the queue end and raise a
        # shape-mismatch error; wrap the cursor back to 0 instead.
        if self.index + batch_size > self.queue_size:
            self.index = 0
        # BUGFIX: build the label on the features' device instead of a
        # hard-coded .cuda() call, so the module also runs on CPU. Also
        # replaces the deprecated torch.LongTensor([range(...)]) idiom.
        label = torch.arange(batch_size, device=p1.device) + self.index
        # Semi-siamese: the gallery branch contributes no gradient.
        g1 = g1.detach()
        g2 = g2.detach()
        output1 = self.compute_theta(p1, g2, label, batch_size)
        output2 = self.compute_theta(p2, g1, label, batch_size)
        output1 *= self.scale
        output2 *= self.scale
        # Randomly pick which gallery batch refreshes the queue.
        if random.random() > 0.5:
            self.update_queue(g1, cur_ids, batch_size)
        else:
            self.update_queue(g2, cur_ids, batch_size)
        id_set = self.get_id_set()
        return output1, output2, label, id_set