forked from mit-han-lab/proxylessnas
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path model_zoo.py
50 lines (38 loc) · 1.74 KB
/
model_zoo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from functools import partial
import json
import torch
from .utils import download_url
from .nas_modules import ProxylessNASNets
def proxyless_base(pretrained=True, net_config=None, net_weight=None):
    """Build a ProxylessNAS network from a JSON config and optionally load weights.

    Args:
        pretrained: When true, fetch ``net_weight`` and load its
            ``'state_dict'`` entry into the network.
        net_config: Location of the JSON network configuration; it is passed
            to ``download_url`` to obtain a local path. Required.
        net_weight: Location of the checkpoint file, also resolved through
            ``download_url``. Required when ``pretrained`` is true.

    Returns:
        The network produced by ``ProxylessNASNets.build_from_config``, with
        its batch-norm parameters set and (if ``pretrained``) weights loaded.

    Raises:
        AssertionError: if ``net_config`` is None, or if ``pretrained`` is
            true and ``net_weight`` is None.
    """
    # NOTE(review): these asserts are stripped under ``python -O``; they are
    # kept (rather than switched to ValueError) so the exception type seen by
    # existing callers does not change.
    assert net_config is not None, "Please input a network config"
    net_config_path = download_url(net_config)
    # Use a context manager so the config file is closed deterministically
    # instead of leaking the handle until garbage collection; JSON text is
    # UTF-8 by specification, so do not depend on the platform locale.
    with open(net_config_path, 'r', encoding='utf-8') as config_file:
        net_config_json = json.load(config_file)
    net = ProxylessNASNets.build_from_config(net_config_json)

    # Batch-norm hyper-parameters come from the config when it has a 'bn'
    # section; otherwise fall back to momentum=0.1, eps=1e-3.
    if 'bn' in net_config_json:
        net.set_bn_param(
            bn_momentum=net_config_json['bn']['momentum'],
            bn_eps=net_config_json['bn']['eps'])
    else:
        net.set_bn_param(bn_momentum=0.1, bn_eps=1e-3)

    if pretrained:
        assert net_weight is not None, "Please specify network weights"
        checkpoint_path = download_url(net_weight)
        # SECURITY NOTE(review): torch.load unpickles the downloaded file, so
        # only point net_weight at trusted sources. map_location='cpu' lets
        # GPU-saved checkpoints load on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
    return net
# Every released ProxylessNAS variant lives in the same download directory,
# with a "<name>.config" network description and a "<name>.pth" checkpoint.
_PROXYLESS_URL_ROOT = "https://hanlab.mit.edu/files/proxylessNAS/"


def _proxyless_variant(name):
    """Return ``proxyless_base`` pre-bound to the config/weight URLs of *name*."""
    stem = _PROXYLESS_URL_ROOT + name
    return partial(
        proxyless_base,
        net_config=stem + ".config",
        net_weight=stem + ".pth",
    )


# Ready-made constructors for the published variants; remaining arguments of
# proxyless_base (e.g. ``pretrained``) can still be passed at call time.
proxyless_cpu = _proxyless_variant("proxyless_cpu")
proxyless_gpu = _proxyless_variant("proxyless_gpu")
proxyless_mobile = _proxyless_variant("proxyless_mobile")
proxyless_mobile_14 = _proxyless_variant("proxyless_mobile_14")