'''
Utility helpers: seeding, optimizer/scheduler construction, visual-prompt
setup, and label mapping.
'''
import json
import random
import warnings
from functools import partial

import numpy as np
import torch

from visual_prompt import ExpansiveVisualPrompt, PadVisualPrompt, FixVisualPrompt, RandomVisualPrompt
from label_mapping import label_mapping_base, generate_label_mapping_by_frequency, generate_label_mapping_by_frequency_ordinary

__all__ = [
    'set_seed',
    'setup_optimizer_and_prompt',
    'calculate_label_mapping',
    'obtain_label_mapping',
    'save_args',
    'load_args',
]


def set_seed(seed):
    # Seed every RNG in play (Python, NumPy, PyTorch CPU and CUDA) and make
    # cuDNN deterministic for reproducible runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def get_optimizer(parameters, optimizer, scheduler, lr, weight_decay, args):
    # Build an optimizer ('sgd' or 'adam') and a learning-rate scheduler
    # ('cosine' or 'multistep') for the given parameter group.
    if optimizer == 'sgd':
        optimizer = torch.optim.SGD(parameters, lr=lr, momentum=args.momentum, weight_decay=weight_decay)
    elif optimizer == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError('optimizer should be one of [sgd, adam]')
    if scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    elif scheduler == 'multistep':
        # args.decreasing_step holds fractions of the total epoch budget.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(args.epochs * frac) for frac in args.decreasing_step], gamma=0.1)
    else:
        raise ValueError('scheduler should be one of [cosine, multistep]')
    return optimizer, scheduler


def setup_optimizer_and_prompt(network, args):
    # Create the visual prompt (when the prune method/mode calls for one) and
    # the optimizer/scheduler pairs for the weight, score, and prompt
    # parameter groups. Score parameters are identified by an `is_score`
    # attribute set by the pruning code.
    device = args.device
    normalize = args.normalize
    visual_prompt = None
    score_optimizer, score_scheduler = None, None
    score_vp_optimizer, score_vp_scheduler = None, None
    weight_vp_optimizer, weight_vp_scheduler = None, None
    weight_params = [param for param in network.parameters() if not hasattr(param, 'is_score')]
    weight_optimizer, weight_scheduler = get_optimizer(weight_params, args.weight_optimizer, args.weight_scheduler, args.weight_lr, args.weight_weight_decay, args)
    if args.prune_method == 'vpns' or 'vp' in args.prune_mode:
        if args.prompt_method == 'pad':
            visual_prompt = PadVisualPrompt(args, normalize=normalize).to(device)
        elif args.prompt_method == 'fix':
            visual_prompt = FixVisualPrompt(args, normalize=normalize).to(device)
        elif args.prompt_method == 'random':
            visual_prompt = RandomVisualPrompt(args, normalize=normalize).to(device)
        else:
            raise ValueError('prompt method should be one of [pad, fix, random]')
        if args.prune_method in ('vpns', 'bip', 'hydra'):
            score_vp_optimizer, score_vp_scheduler = get_optimizer(visual_prompt.parameters(), args.score_vp_optimizer, args.score_vp_scheduler, args.score_vp_lr, args.score_vp_weight_decay, args)
            weight_vp_optimizer, weight_vp_scheduler = get_optimizer(visual_prompt.parameters(), args.weight_vp_optimizer, args.weight_vp_scheduler, args.weight_vp_lr, args.weight_vp_weight_decay, args)
    if args.prune_method in ('vpns', 'bip', 'hydra'):
        score_params = [param for param in network.parameters() if hasattr(param, 'is_score') and param.is_score]
        score_optimizer, score_scheduler = get_optimizer(score_params, args.score_optimizer, args.score_scheduler, args.score_lr, args.score_weight_decay, args)
    return visual_prompt, score_optimizer, score_scheduler, score_vp_optimizer, score_vp_scheduler, weight_optimizer, weight_scheduler, weight_vp_optimizer, weight_vp_scheduler
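
# Call-site sketch for the nine-value return above (hypothetical: `network`
# and `args` come from the training script, which lives outside this file):
#
#   (visual_prompt,
#    score_optimizer, score_scheduler,
#    score_vp_optimizer, score_vp_scheduler,
#    weight_optimizer, weight_scheduler,
#    weight_vp_optimizer, weight_vp_scheduler) = setup_optimizer_and_prompt(network, args)
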

def calculate_label_mapping(visual_prompt, network, train_loader, args):
    # Build the source-class -> target-class mapping. 'rlm' draws a random
    # subset of the 1000 source (ImageNet) classes; 'flm'/'ilm' pick classes
    # by prediction frequency on the training set, with or without the visual
    # prompt applied.
    if args.label_mapping_mode == 'rlm':
        print('Random Label Mapping')
        mapping_sequence = torch.randperm(1000)[:args.class_cnt]
    elif args.label_mapping_mode in ('flm', 'ilm'):
        if visual_prompt:
            mapping_sequence = generate_label_mapping_by_frequency(visual_prompt, network, train_loader)
        else:
            mapping_sequence = generate_label_mapping_by_frequency_ordinary(network, train_loader)
    else:
        warnings.warn('No Label Mapping!')
        return None, None
    label_mapping = partial(label_mapping_base, mapping_sequence=mapping_sequence)
    return label_mapping, mapping_sequence
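
# Usage sketch (an assumption about label_mapping_base: it projects the source
# model's logits onto the target classes via mapping_sequence; the shapes
# below are illustrative):
#
#   logits = network(visual_prompt(images))   # (batch, 1000) source logits
#   target_logits = label_mapping(logits)     # (batch, args.class_cnt)
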

def obtain_label_mapping(mapping_sequence):
    # Rebuild the label-mapping callable from a saved mapping sequence.
    return partial(label_mapping_base, mapping_sequence=mapping_sequence)


def save_args(args, file_path):
    # Persist the argparse namespace as JSON so a run can be reproduced.
    with open(file_path, 'w') as file:
        json.dump(vars(args), file)


def load_args(file_path):
    # Load previously saved arguments; returns a plain dict.
    with open(file_path, 'r') as file:
        loaded_args = json.load(file)
    return loaded_args


def get_init_ckpt(args):
    # The classic pruning baselines ('random', 'imp', 'omp') in normal prune
    # mode start from a pretrained checkpoint; everything else returns None.
    ckpt = None
    if args.prune_mode == 'normal' and args.prune_method in ('random', 'imp', 'omp'):
        ckpt = torch.load('ckpts/0best.pth')
    return ckpt


def get_masks(mask):
    # Collect the binary pruning masks tracked by the `mask` object into a
    # plain dict keyed by parameter name.
    masks = {}
    for module in mask.modules:
        for name, tensor in module.named_parameters():
            if name in mask.masks:
                masks[name] = mask.masks[name]
    return masks
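

if __name__ == '__main__':
    # Minimal smoke test, a sketch only: it exercises the pure-torch helpers
    # (set_seed, get_optimizer, save_args/load_args) with a throwaway
    # Namespace. The field values below are illustrative, not the repo's
    # defaults.
    import argparse
    import torch.nn as nn

    set_seed(42)
    demo_args = argparse.Namespace(momentum=0.9, epochs=10, decreasing_step=[0.5, 0.72])
    model = nn.Linear(4, 2)
    optimizer, scheduler = get_optimizer(model.parameters(), 'sgd', 'cosine', lr=0.1, weight_decay=5e-4, args=demo_args)
    print(type(optimizer).__name__, type(scheduler).__name__)  # SGD CosineAnnealingLR

    save_args(demo_args, 'demo_args.json')
    print(load_args('demo_args.json'))  # round-trips as a plain dict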