AdaM_Softmax.py

"""
@author: Hang Du, Jun Wang
@date: 20201128
@contact: [email protected]
"""
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter


class Adam_Softmax(Module):
    """Implementation for "AdaptiveFace: Adaptive Margin and Sampling for Face Recognition" (CVPR 2019).

    A CosFace-style head whose per-class margin is a learnable parameter,
    regularized so that the margins stay large on average.
    """
    def __init__(self, feat_dim, num_class, scale=30.0, lamda=70.0):
        super(Adam_Softmax, self).__init__()
        self.num_class = num_class
        self.scale = scale
        self.lamda = lamda
        # Class-center weight matrix, re-normalized onto the unit hypersphere.
        self.kernel = Parameter(torch.Tensor(feat_dim, num_class))
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        # Learnable per-class margins, initialized uniformly in [0.3, 0.4].
        self.adam = Parameter(torch.Tensor(1, num_class))
        self.adam.data.uniform_(0.3, 0.4)

    def forward(self, feats, labels):
        # feats are expected to be L2-normalized embeddings, so the matrix
        # product below yields cosine similarities to each class center.
        kernel_norm = F.normalize(self.kernel, dim=0)
        cos_theta = torch.mm(feats, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)
        # Keep every margin in [0, 1], as in CosFace.
        self.adam.data.clamp_(0, 1)
        # Gather each sample's ground-truth-class margin: shape (batch_size, 1).
        margin = self.adam[:, labels].view(-1, 1)
        cos_theta_m = cos_theta - margin
        # Subtract the margin only at the ground-truth positions.
        index = torch.zeros_like(cos_theta)
        index.scatter_(1, labels.data.view(-1, 1), 1)
        index = index.bool()
        output = cos_theta * 1.0
        output[index] = cos_theta_m[index]
        output *= self.scale
        # Margin regularizer Lm = 1 - mean(adam); the constant 1 keeps the
        # loss positive, while minimizing Lm pushes the average margin up.
        Lm = 1 - torch.sum(self.adam, dim=1) / self.num_class
        return output, self.lamda * Lm
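

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file): shows how the two
    # outputs of Adam_Softmax are typically combined into one training loss.
    # The feature dimension, class count, and random tensors below are
    # hypothetical placeholders for a real backbone's embeddings and labels.
    feat_dim, num_class, batch_size = 512, 1000, 8
    head = Adam_Softmax(feat_dim, num_class, scale=30.0, lamda=70.0)

    # Embeddings must be L2-normalized before being passed to the head.
    feats = F.normalize(torch.randn(batch_size, feat_dim), dim=1)
    labels = torch.randint(0, num_class, (batch_size,))

    logits, margin_reg = head(feats, labels)
    # Classification term plus the adaptive-margin regularizer lamda * Lm.
    loss = F.cross_entropy(logits, labels) + margin_reg.mean()
    loss.backward()
    print(loss.item())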