-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhubconf.py
73 lines (55 loc) · 2.43 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Optional: packages that torch.hub requires to be importable before it
# loads any entrypoint defined in this hubconf.
dependencies = ["torch", "torchvision"]
import torch
import torchvision
# Checkpoint download locations, keyed by the ``model_arch`` string that
# ``_resnet50`` uses to look up which state dict to fetch.
model_urls = dict(
    resnet50_1m="https://cornell.box.com/shared/static/r36nd2o0w5ch6ujuaxj0mtasxaqg0l5t.pth",
    resnet50_250m="https://cornell.box.com/shared/static/y210vs3iktungg7wrf72ibzl87jrojna.pth",
)
def _resnet50(model_arch: str, pretrained: bool = False, **kwargs):
    """
    Build a torchvision ResNet-50 trunk with its last two children removed.

    Args:
        model_arch (str): key into ``model_urls`` selecting which checkpoint
            to download when ``pretrained`` is True.
        pretrained (bool): If True, download the checkpoint for
            ``model_arch`` and load it into the trunk; otherwise the weights
            stay randomly initialized.
        progress (bool, keyword-only via ``**kwargs``): If True (the
            default), displays a progress bar of the download to stderr.
        **kwargs: all other keyword arguments are forwarded to
            ``torchvision.models.resnet50``.

    Returns:
        torch.nn.Sequential: the children of the torchvision ResNet-50 up to
        (but excluding) the global average-pooling and final linear layers.
    """
    # ``progress`` is meant for the checkpoint download below. Pop it so it
    # is honored there rather than forwarded to the randomly-initialized
    # torchvision constructor, where it has no effect.
    progress = kwargs.pop("progress", True)
    # Create a torchvision resnet50 with randomly initialized weights.
    # NOTE: ``pretrained`` is deprecated in torchvision >= 0.13 in favour of
    # ``weights=None``; kept for compatibility with older releases.
    model = torchvision.models.resnet50(pretrained=False, **kwargs)
    # Drop the last two children (global average pooling and fc) so the
    # model outputs spatial feature maps.
    model = torch.nn.Sequential(*list(model.children())[:-2])
    if pretrained:
        # map_location="cpu" lets the checkpoint load even if it was saved
        # from GPU tensors; load_state_dict copies the values into the
        # model's own parameters either way.
        state_dict = torch.hub.load_state_dict_from_url(
            model_urls[model_arch], map_location="cpu", progress=progress
        )
        model.load_state_dict(state_dict)
    return model
def resnet50_1m(pretrained: bool = False, **kwargs):
    """
    ResNet-50 pre-trained on 1.2M visual engagement data, from
    `"Exploring Visual Engagement Signals for Representation Learning"
    <https://arxiv.org/abs/2104.07767>`_.

    Behaves like a torchvision model that stops before global pooling: a
    batch of images of size ``(B, 3, 224, 224)`` is mapped to spatial
    features of size ``(B, 2048, 7, 7)``, where B is the batch size.

    Args:
        pretrained (bool): If True, load the ``"resnet50_1m"`` checkpoint;
            otherwise the weights are randomly initialized.
        **kwargs: forwarded unchanged to ``_resnet50``.
    """
    return _resnet50(model_arch="resnet50_1m", pretrained=pretrained, **kwargs)
def resnet50_250m(pretrained: bool = False, **kwargs):
    """
    ResNet-50 pre-trained on 250M visual engagement data, from
    `"Exploring Visual Engagement Signals for Representation Learning"
    <https://arxiv.org/abs/2104.07767>`_.

    Behaves like a torchvision model that stops before global pooling: a
    batch of images of size ``(B, 3, 224, 224)`` is mapped to spatial
    features of size ``(B, 2048, 7, 7)``, where B is the batch size.

    Args:
        pretrained (bool): If True, load the ``"resnet50_250m"`` checkpoint;
            otherwise the weights are randomly initialized.
        **kwargs: forwarded unchanged to ``_resnet50``.
    """
    return _resnet50(model_arch="resnet50_250m", pretrained=pretrained, **kwargs)