-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassifier.py
102 lines (53 loc) · 1.84 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import math
import torch
import torch.nn as nn
import models
import utils
from .models import register
@register('classifier')
class Classifier(nn.Module):
    """Encoder followed by a classification head, both built via ``models.make``.

    Args:
        encoder: registry name of the encoder, passed to ``models.make``.
        encoder_args: keyword arguments for the encoder constructor.
        classifier: registry name of the classification head.
        classifier_args: keyword arguments for the head constructor. Its
            ``in_dim`` entry is always taken from ``self.encoder.out_dim``;
            any ``in_dim`` supplied by the caller is overridden. The mapping
            passed in is not modified.
    """

    def __init__(self, encoder, encoder_args,
                 classifier, classifier_args):
        super().__init__()
        self.encoder = models.make(encoder, **encoder_args)
        # Build a fresh mapping rather than writing into the caller's dict, so
        # the config object handed to us is not mutated as a side effect.
        head_args = {**classifier_args, 'in_dim': self.encoder.out_dim}
        self.classifier = models.make(classifier, **head_args)

    def forward(self, x):
        """Run ``x`` through the encoder, then the head, and return the head's output."""
        x = self.encoder(x)
        x = self.classifier(x)
        return x
@register('linear-classifier')
class LinearClassifier(nn.Module):
    """Single fully connected layer mapping ``in_dim`` inputs to ``n_classes`` outputs.

    Args:
        in_dim: size of the last dimension of the input.
        n_classes: number of output units.
    """

    def __init__(self, in_dim, n_classes):
        super().__init__()
        self.linear = nn.Linear(in_features=in_dim, out_features=n_classes)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return its output."""
        scores = self.linear(x)
        return scores
@register('nn-classifier')
class NNClassifier(nn.Module):
    """Scores inputs against one learned prototype vector per class.

    Scoring is delegated to ``utils.compute_logits`` using the configured
    ``metric`` and temperature.

    Args:
        in_dim: length of each prototype vector.
        n_classes: number of prototypes (rows of ``self.proto``).
        metric: metric name forwarded to ``utils.compute_logits``.
        temp: temperature forwarded to ``utils.compute_logits``. If omitted,
            it becomes a learnable parameter initialised to 10 when
            ``metric == 'cos'``, and the constant 1.0 for any other metric.
            An explicitly supplied value is stored as-is.
    """

    def __init__(self, in_dim, n_classes, metric='cos', temp=None):
        super().__init__()
        prototypes = torch.empty(n_classes, in_dim)
        self.proto = nn.Parameter(prototypes)
        nn.init.kaiming_uniform_(self.proto, a=math.sqrt(5))

        if temp is None:
            # Only the cosine metric gets a trainable temperature by default.
            temp = (nn.Parameter(torch.tensor(10.)) if metric == 'cos'
                    else 1.0)

        self.metric = metric
        self.temp = temp

    def forward(self, x):
        """Return ``utils.compute_logits`` of ``x`` against the prototypes."""
        return utils.compute_logits(x, self.proto, self.metric, self.temp)
@register('fine-tuning-classifier')
class FineTuningClassifier(nn.Module):
    """Linear head over average-pooled features from an externally attached encoder.

    ``self.encoder`` starts as ``None``; the caller must assign an encoder
    module to it before ``forward`` is called.

    Args:
        in_dim: input size of the linear layer, i.e. the per-sample feature
            count after pooling and flattening.
        n_classes: number of output units of the linear layer.
    """

    def __init__(self, in_dim, n_classes):
        super().__init__()
        # Placeholder; the encoder is attached after construction.
        self.encoder = None
        self.linear = nn.Linear(in_dim, n_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Encode ``x``, average-pool to 1x1, flatten per sample, apply the linear layer.

        Raises:
            RuntimeError: if no encoder has been assigned to ``self.encoder``.
        """
        if self.encoder is None:
            # Fail with an actionable message instead of "'NoneType' object
            # is not callable".
            raise RuntimeError(
                'FineTuningClassifier.encoder is not set; assign an encoder '
                'module before calling forward().')
        h = self.encoder(x)
        # assumes the encoder returns a (batch, channels, H, W) feature map — TODO confirm
        h = self.avgpool(h)
        h = torch.flatten(h, 1)
        return self.linear(h)