train.py
#!/usr/bin/env python
# coding: utf-8
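"""Train a teacher diffusion model (the starting point for later distillation).

Builds the model and dataset from the user-supplied --module, wraps the model
in a diffusion process with a cosine beta schedule, and trains it while
maintaining an EMA copy that is used for sampling and checkpointing.
"""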
import argparse
import importlib
import os

import torch
from torch.utils.tensorboard import SummaryWriter

from v_diffusion import make_beta_schedule
from moving_average import init_ema_model
from train_utils import *
def make_argument_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--module", help="Model module.", type=str, required=True)
    parser.add_argument("--name", help="Experiment name. Data will be saved to ./checkpoints/<name>/<dname>/.", type=str, required=True)
    parser.add_argument("--dname", help="Distillation name. Data will be saved to ./checkpoints/<name>/<dname>/.", type=str, required=True)
    parser.add_argument("--checkpoint_to_continue", help="Path to checkpoint.", type=str, default="")
    parser.add_argument("--num_timesteps", help="Num diffusion steps.", type=int, default=1024)
    parser.add_argument("--num_iters", help="Num iterations.", type=int, default=100000)
    parser.add_argument("--batch_size", help="Batch size.", type=int, default=1)
    parser.add_argument("--lr", help="Learning rate.", type=float, default=5e-5)
    parser.add_argument("--scheduler", help="Learning rate scheduler.", type=str, default="StrategyConstantLR")
    parser.add_argument("--diffusion", help="Diffusion model.", type=str, default="GaussianDiffusion")
    parser.add_argument("--log_interval", help="Log interval in minutes.", type=int, default=15)
    parser.add_argument("--ckpt_interval", help="Checkpoints saving interval in minutes.", type=int, default=30)
    parser.add_argument("--num_workers", help="DataLoader workers; -1 means 2 * batch_size.", type=int, default=-1)
    return parser
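# Example invocation (argument values are illustrative only; --module must name
# an importable module that exposes make_model() and make_dataset()):
#
#   python ./train.py --module <your_model_module> --name my_experiment \
#       --dname base --batch_size 4 --lr 5e-5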
def train_model(args, make_model, make_dataset):
    if args.num_workers == -1:
        args.num_workers = args.batch_size * 2
    print(' '.join(f'{k}={v}' for k, v in vars(args).items()))

    device = torch.device("cuda")

    # The same infinitely repeating dataset backs both loaders.
    train_dataset = test_dataset = InfinityDataset(make_dataset(), args.num_iters)
    img, anno = train_dataset[0]  # fetch one sample to fail fast on dataset errors

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    checkpoints_dir = os.path.join("checkpoints", args.name, args.dname)
    os.makedirs(checkpoints_dir, exist_ok=True)

    def make_scheduler():
        # Resolve the learning-rate strategy class named by --scheduler from train_utils.
        M = importlib.import_module("train_utils")
        D = getattr(M, args.scheduler)
        return D()

    scheduler = make_scheduler()
    def make_diffusion(model, n_timestep, time_scale, device):
        # Cosine beta schedule wrapped in the diffusion class named by --diffusion.
        betas = make_beta_schedule("cosine", cosine_s=8e-3, n_timestep=n_timestep).to(device)
        M = importlib.import_module("v_diffusion")
        D = getattr(M, args.diffusion)
        return D(model, betas, time_scale=time_scale)

    teacher = make_model().to(device)
    teacher_ema = make_model().to(device)

    if args.checkpoint_to_continue != "":
        ckpt = torch.load(args.checkpoint_to_continue)
        teacher.load_state_dict(ckpt["G"])
        teacher_ema.load_state_dict(ckpt["G"])
        del ckpt
        print("Continue training...")
    else:
        print("Training new model...")
    # Synchronize the EMA copy with the teacher weights before training starts.
    init_ema_model(teacher, teacher_ema)

    tensorboard = SummaryWriter(os.path.join(checkpoints_dir, "tensorboard"))

    teacher_diffusion = make_diffusion(teacher, args.num_timesteps, 1, device)
    # The iteration callback samples and checkpoints from the EMA weights.
    teacher_ema_diffusion = make_diffusion(teacher_ema, args.num_timesteps, 1, device)

    image_size = teacher.image_size
    on_iter = make_iter_callback(teacher_ema_diffusion, device, checkpoints_dir, image_size, tensorboard, args.log_interval, args.ckpt_interval, False)
    diffusion_train = DiffusionTrain(scheduler)
    diffusion_train.train(train_loader, teacher_diffusion, teacher_ema, args.lr, device, make_extra_args=make_condition, on_iter=on_iter)
    print("Finished.")
if __name__ == "__main__":
    parser = make_argument_parser()
    args = parser.parse_args()

    # Dynamically load the user-supplied module; it must expose
    # make_model() and make_dataset() factories.
    M = importlib.import_module(args.module)
    make_model = getattr(M, "make_model")
    make_dataset = getattr(M, "make_dataset")
    train_model(args, make_model, make_dataset)