forked from cool-xuan/msflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
84 lines (74 loc) · 3.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os, random
import numpy as np
import torch
import argparse
import wandb
from train import train
def init_seeds(seed=0):
    """Seed Python, NumPy and torch (CPU and CUDA) random generators.

    Args:
        seed: integer seed applied to every generator below.

    The generators are seeded in this order: ``random``, NumPy's global
    generator, the torch CPU generator, the current CUDA device's generator,
    then the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
def _input_size_for_class(class_name):
    """Return the model input size tuple to use for ``class_name``.

    'transistor' uses (256, 256); every other class uses (512, 512).
    """
    return (256, 256) if class_name == 'transistor' else (512, 512)


def parsing_args(c, argv=None):
    """Parse command-line options and copy them onto the config object ``c``.

    Every parsed option becomes an attribute of ``c`` (argparse turns dashes
    in option names into underscores, e.g. ``--class-names`` ->
    ``class_names``). Afterwards the dataset-specific ``data_path`` is set,
    and when ``--class-names`` is left at its default ``['all']`` it is
    expanded to the dataset's full class list from ``datasets``.

    ``input_size`` is intentionally not set here: it depends on the class
    currently being trained, which is only known inside ``main``'s loop.

    Args:
        c: mutable config object (the ``default`` module when run as a script).
        argv: optional list of argument strings; ``None`` parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        The same object ``c``, updated in place.
    """
    parser = argparse.ArgumentParser(description='msflow')
    parser.add_argument('--dataset', default='mvtec', type=str,
                        choices=['mvtec', 'visa'], help='dataset name')
    parser.add_argument('--mode', default='train', type=str,
                        help='train or test.')
    parser.add_argument('--amp_enable', action='store_true', default=False,
                        help='use amp or not.')
    parser.add_argument('--wandb_enable', action='store_true', default=False,
                        help='use wandb for result logging or not.')
    parser.add_argument('--resume', action='store_true', default=False,
                        help='resume training or not.')
    parser.add_argument('--eval_ckpt', default='', type=str,
                        help='checkpoint path for evaluation.')
    parser.add_argument('--class-names', default=['all'], type=str, nargs='+',
                        help='class names for training')
    parser.add_argument('--lr', default=1e-4, type=float,
                        help='learning rate')
    parser.add_argument('--batch-size', default=8, type=int,
                        help='train batch size')
    parser.add_argument('--meta-epochs', default=25, type=int,
                        help='number of meta epochs to train')
    parser.add_argument('--sub-epochs', default=4, type=int,
                        help='number of sub epochs to train')
    parser.add_argument('--extractor', default='wide_resnet50_2', type=str,
                        help='feature extractor')
    parser.add_argument('--pool-type', default='avg', type=str,
                        help='pool type for extracted feature maps')
    parser.add_argument('--parallel-blocks', default=[2, 5, 8], type=int, metavar='L', nargs='+',
                        help='number of flow blocks used in parallel flows.')
    parser.add_argument('--pro-eval', action='store_true', default=False,
                        help='evaluate the pro score or not.')
    parser.add_argument('--pro-eval-interval', default=4, type=int,
                        help='interval for pro evaluation.')
    args = parser.parse_args(argv)

    for key, value in vars(args).items():
        setattr(c, key, value)

    # `choices` above restricts dataset to exactly these two values.
    if c.dataset == 'mvtec':
        from datasets import MVTEC_CLASS_NAMES
        c.data_path = './data/MVTec'
        if c.class_names == ['all']:
            c.class_names = MVTEC_CLASS_NAMES
    elif c.dataset == 'visa':
        from datasets import VISA_CLASS_NAMES
        c.data_path = './data/VisA_pytorch/1cls'
        if c.class_names == ['all']:
            c.class_names = VISA_CLASS_NAMES
    return c


def main(c):
    """Parse CLI options into ``c`` and run ``train`` once per requested class.

    ``c.seed`` and ``c.work_dir`` are not CLI options, so they must already be
    present on the incoming config object. Before each ``train`` call, the
    per-class fields ``class_name``, ``input_size`` and ``ckpt_dir`` are
    (re)assigned on ``c``.

    Args:
        c: mutable config object (the ``default`` module when run as a script).
    """
    c = parsing_args(c)
    init_seeds(seed=c.seed)
    c.version_name = 'msflow_{}_{}pool_pl{}'.format(
        c.extractor, c.pool_type, ''.join(map(str, c.parallel_blocks)))
    print(c.class_names)
    for class_name in c.class_names:
        c.class_name = class_name
        # Resolved per class, after class_name is known: the input size
        # differs for 'transistor', so it cannot be computed once up front.
        c.input_size = _input_size_for_class(class_name)
        print('-+'*5, class_name, '+-'*5)
        c.ckpt_dir = os.path.join(c.work_dir, c.version_name, c.dataset, c.class_name)
        train(c)
if __name__ == '__main__':
    # Script entry point: the project's `default` module serves as the
    # mutable config object that main() fills in from the command line.
    import default as config
    main(config)