forked from JDAI-CV/FaceX-Zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path CurricularFace.py
50 lines (44 loc) · 1.86 KB
/
CurricularFace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
@author: Jun Wang
@date: 20201126
@contact: [email protected]
"""
# based on
# https://github.com/HuangYG123/CurricularFace/blob/master/head/metrics.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
class CurricularFace(nn.Module):
    """Implementation for "CurricularFace: Adaptive Curriculum Learning Loss for Deep Face Recognition".

    Produces scaled logits to be fed into a softmax cross-entropy loss.

    - The target-class logit receives an additive angular margin ``m``, i.e.
      ``cos(theta + m)``. When ``theta + m`` would pass ``pi``, it falls back to
      ``cos(theta) - mm``.
    - Non-target logits larger than the margined target logit ("hard examples")
      are re-weighted to ``cos * (t + cos)``.
    - ``t`` is an exponential moving average of the mean target cosine, kept in a
      buffer.

    Args:
        feat_dim: dimensionality of the input feature vectors.
        num_class: number of classes (columns of ``kernel``).
        m: additive angular margin, in radians.
        s: scale applied to the final logits.
    """

    def __init__(self, feat_dim, num_class, m=0.5, s=64.):
        super(CurricularFace, self).__init__()
        self.m = m
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Past this cosine, theta + m exceeds pi and cos(theta + m) stops being
        # monotonic in theta, so a linear penalty ``cos - mm`` is used instead.
        self.threshold = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        # One class-center column per class; normalized on every forward pass.
        self.kernel = Parameter(torch.Tensor(feat_dim, num_class))
        # Running estimate of the mean target cosine (the adaptive "t" term).
        self.register_buffer('t', torch.zeros(1))
        nn.init.normal_(self.kernel, std=0.01)

    def forward(self, feats, labels):
        """Compute CurricularFace logits.

        Args:
            feats: (batch, feat_dim) features; L2-normalized along dim 1 here.
            labels: (batch,) integer class indices.

        Returns:
            (batch, num_class) logits scaled by ``s``.
        """
        # Features and class centers must both be unit length so that the
        # matrix product is a true cosine similarity (as in the reference code).
        # Normalizing is idempotent, so pre-normalized inputs are unaffected.
        feats = F.normalize(feats, dim=1)
        kernel_norm = F.normalize(self.kernel, dim=0)
        cos_theta = torch.mm(feats, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        batch_idx = torch.arange(feats.size(0), device=cos_theta.device)
        target_logit = cos_theta[batch_idx, labels].view(-1, 1)
        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m  # cos(target + margin)
        # Non-target cosines above the margined target cosine are "hard".
        # The target column also matches here, but it is overwritten by the
        # scatter_ below.
        mask = cos_theta > cos_theta_m
        final_target_logit = torch.where(
            target_logit > self.threshold, cos_theta_m, target_logit - self.mm)
        hard_example = cos_theta[mask]
        with torch.no_grad():
            # EMA update of t (momentum 0.99). Skip empty batches: their mean is
            # NaN, and the NaN would otherwise persist in the buffer forever.
            if target_logit.numel() > 0:
                self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
        cos_theta[mask] = hard_example * (self.t + hard_example)
        cos_theta.scatter_(1, labels.view(-1, 1).long(), final_target_logit)
        output = cos_theta * self.s
        return output