forked from JDAI-CV/FaceX-Zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAM_Softmax.py
31 lines (29 loc) · 1.03 KB
/
AM_Softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""
@author:Jun Wang
@date: 20201123
@contact: [email protected]
"""
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter
class AM_Softmax(Module):
    """Classification head for "Additive Margin Softmax for Face Verification".

    Produces scaled logits ``scale * (cos_theta - margin)`` for each sample's
    target class and ``scale * cos_theta`` for every other class, where
    ``cos_theta`` is the product of the input features with the
    column-normalized class weights.

    Args:
        feat_dim: Number of rows of the weight matrix (feature dimension).
        num_class: Number of columns of the weight matrix (one per class).
        margin: Additive margin subtracted from the target-class cosine.
        scale: Factor applied to all logits after the margin.
    """

    def __init__(self, feat_dim, num_class, margin=0.35, scale=32):
        super(AM_Softmax, self).__init__()
        self.weight = Parameter(torch.empty(feat_dim, num_class))
        # Uniform init, then renorm_ caps every column's L2 norm at 1e-5 and
        # mul_ scales it back up, leaving columns with (at most) unit norm.
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin = margin
        self.scale = scale

    def forward(self, feats, labels):
        """Return margin-adjusted, scaled logits.

        Args:
            feats: 2-D tensor of shape ``(batch, feat_dim)`` (required by
                ``torch.mm``). Only the weights are normalized here.
                NOTE(review): the result is a true cosine only if ``feats``
                rows are already L2-normalized by the caller -- confirm.
            labels: Integer (int64) tensor holding one class index per row of
                ``feats``; it is reshaped to ``(batch, 1)`` for ``scatter_``.

        Returns:
            Tensor of shape ``(batch, num_class)``.
        """
        # Normalize each class column so the product depends on direction only.
        kernel_norm = F.normalize(self.weight, dim=0)
        cos_theta = torch.mm(feats, kernel_norm).clamp(-1, 1)

        # One-hot encoding of the target class for every row. Subtracting
        # ``margin * one_hot`` touches only the target entries and avoids the
        # deprecated uint8-mask indexing the previous implementation relied on.
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)

        return (cos_theta - self.margin * one_hot) * self.scale