forked from JDAI-CV/FaceX-Zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathArcFace.py
41 lines (38 loc) · 1.62 KB
/
ArcFace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
@author:Jun Wang
@date: 20201123
@contact: [email protected]
"""
import math
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter
class ArcFace(Module):
    """Implementation for "ArcFace: Additive Angular Margin Loss for Deep Face Recognition"

    Produces scaled classification logits in which each sample's target-class
    logit is ``cos(theta + margin_arc)`` instead of ``cos(theta)``, where
    ``theta`` is the angle between the feature and that class's weight column.

    Args:
        feat_dim: dimensionality of the input features (rows of ``self.weight``).
        num_class: number of classes (columns of ``self.weight``).
        margin_arc: additive angular margin ``m``, in radians.
        margin_am: additive cosine margin used as a fallback for target logits
            whose angle is too large for the angular margin, i.e. when
            ``cos_theta <= cos(pi - margin_arc)``; those logits become
            ``cos_theta - margin_am``.
        scale: multiplier applied to every output logit.
    """

    def __init__(self, feat_dim, num_class, margin_arc=0.35, margin_am=0.0, scale=32):
        super().__init__()
        # One weight column per class. renorm_ caps each column's L2 norm at
        # 1e-5 and mul_ rescales it, so every column starts at ~unit norm.
        self.weight = Parameter(torch.empty(feat_dim, num_class))
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin_arc = margin_arc
        self.margin_am = margin_am
        self.scale = scale
        # Precomputed for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_margin = math.cos(margin_arc)
        self.sin_margin = math.sin(margin_arc)
        # theta + m stays within [0, pi] only while cos(theta) > cos(pi - m).
        self.min_cos_theta = math.cos(math.pi - margin_arc)

    def forward(self, feats, labels):
        """Compute margin-adjusted, scaled logits.

        Args:
            feats: 2-D tensor ``(batch, feat_dim)``. It is NOT normalized here,
                so the logits are true cosines only if each row has unit L2
                norm (presumably the caller normalizes; TODO confirm).
            labels: integer (int64) class index for each row of ``feats``.

        Returns:
            ``(batch, num_class)`` tensor equal to ``scale * cos(theta)`` for
            non-target classes and ``scale * cos(theta + margin_arc)`` (or the
            ``cos(theta) - margin_am`` fallback) at each row's target class.
        """
        # Normalize each class column so feats @ kernel_norm is a cosine.
        kernel_norm = F.normalize(self.weight, dim=0)
        cos_theta = torch.mm(feats, kernel_norm).clamp(-1, 1)
        # Keep the sqrt argument strictly positive. At |cos_theta| == 1 the
        # derivative of sqrt is infinite, and backprop would produce inf/NaN
        # gradients, even for entries later discarded by torch.where.
        # The effect on the forward value is negligible (sin_theta >= 1e-6).
        sin_theta = torch.sqrt((1.0 - torch.pow(cos_theta, 2)).clamp(min=1e-12))
        cos_theta_m = cos_theta * self.cos_margin - sin_theta * self.sin_margin
        # 0 <= theta + m <= pi  ==>  -m <= theta <= pi - m.
        # Because 0 <= theta <= pi, only theta <= pi - m needs checking, i.e.
        # cos_theta >= cos(pi - m). Otherwise use cos_theta - margin_am.
        cos_theta_m = torch.where(
            cos_theta > self.min_cos_theta,
            cos_theta_m,
            cos_theta - self.margin_am,
        )
        # Boolean one-hot mask of each row's target class. Bool masks replace
        # uint8 masks, which PyTorch deprecates for indexing/selection.
        target_mask = (
            torch.zeros_like(cos_theta)
            .scatter_(1, labels.reshape(-1, 1), 1)
            .bool()
        )
        # Margin-adjusted logit at the target class, plain cosine elsewhere.
        output = torch.where(target_mask, cos_theta_m, cos_theta)
        return output * self.scale