-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
420 lines (364 loc) · 16.1 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
import torch
import torch.nn as nn
import numpy as np
from torch.utils import data
import argparse
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from PIL import ImageFile
from tensorboardX import SummaryWriter
from torchvision import transforms
from tqdm import tqdm
from apex import amp
def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation of *feat*.

    Args:
        feat: 4-D tensor; the first two dimensions are kept (N, C) and all
            remaining elements of each (n, c) slice are reduced together.
        eps: small value added to the variance before the square root to
            avoid a divide-by-zero when the result is used as a denominator.

    Returns:
        A ``(mean, std)`` tuple, each of shape ``(N, C, 1, 1)``.

    Raises:
        AssertionError: if *feat* is not 4-D.
    """
    size = feat.size()
    assert len(size) == 4, 'calc_mean_std expects a 4-D tensor'
    N, C = size[:2]
    # reshape (rather than view) so non-contiguous inputs, e.g. permuted
    # feature maps, are accepted instead of raising a RuntimeError.
    flat = feat.reshape(N, C, -1)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = flat.mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std
def mean_variance_norm(feat):
    """Normalise *feat* with the statistics returned by ``calc_mean_std``.

    The mean is subtracted and the result divided by the standard deviation,
    both expanded to the full size of *feat*. The output has the same size
    as the input.
    """
    channel_mean, channel_std = calc_mean_std(feat)
    full_size = feat.size()
    centered = feat - channel_mean.expand(full_size)
    return centered / channel_std.expand(full_size)
def _calc_feat_flatten_mean_std(feat):
    """Flatten a 3-channel tensor and return it with per-channel statistics.

    *feat* must have 3 as its first dimension and be a ``torch.FloatTensor``;
    both conditions are enforced with asserts.

    Returns:
        ``(flat, mean, std)`` where ``flat`` is *feat* viewed as ``(3, -1)``
        and ``mean`` / ``std`` are reduced over the last dimension with
        ``keepdim=True``.
    """
    assert feat.size(0) == 3
    assert isinstance(feat, torch.FloatTensor)
    flat = feat.view(3, -1)
    channel_mean = flat.mean(dim=-1, keepdim=True)
    channel_std = flat.std(dim=-1, keepdim=True)
    return flat, channel_mean, channel_std
def _decoder_modules():
    """Build the decoder's layer list in its fixed order.

    Every entry of ``plan`` expands to ReflectionPad2d -> 3x3 Conv2d -> ReLU,
    optionally followed by a x2 nearest-neighbour Upsample. A final
    pad + 3x3 conv (64 -> 3, no activation) closes the list. The resulting
    order fixes the numeric keys of the decoder's state dict, so it must not
    be changed.
    """
    # (in_channels, out_channels, upsample_afterwards)
    plan = (
        (512, 256, True),
        (256, 256, False),
        (256, 256, False),
        (256, 256, False),
        (256, 128, True),
        (128, 128, False),
        (128, 64, True),
        (64, 64, False),
    )
    modules = []
    for c_in, c_out, upsample in plan:
        modules.append(nn.ReflectionPad2d((1, 1, 1, 1)))
        modules.append(nn.Conv2d(c_in, c_out, (3, 3)))
        modules.append(nn.ReLU())
        if upsample:
            modules.append(nn.Upsample(scale_factor=2, mode='nearest'))
    # Output layer: 3 channels, no ReLU after it.
    modules.append(nn.ReflectionPad2d((1, 1, 1, 1)))
    modules.append(nn.Conv2d(64, 3, (3, 3)))
    return modules


# Decoder: 512-channel feature map -> 3-channel image, upsampled x2 three times.
decoder = nn.Sequential(*_decoder_modules())
def _vgg_modules():
    """Build the encoder's layer list in its fixed order.

    Layout: a 1x1 conv (3 -> 3), then five stages of
    ReflectionPad2d -> 3x3 Conv2d -> ReLU units, with a 2x2 ceil-mode
    MaxPool2d between consecutive stages. The ReLU closing unit ``k`` of
    stage ``s`` is the layer referred to as ``relu{s}-{k}``; relu1-1,
    relu2-1, relu3-1, relu4-1 and relu5-1 sit at indices 3, 10, 17, 30 and
    43. ``Net`` slices this list by index and the state-dict keys are the
    indices, so the order must not be changed.
    """
    def conv_unit(c_in, c_out):
        return [
            nn.ReflectionPad2d((1, 1, 1, 1)),
            nn.Conv2d(c_in, c_out, (3, 3)),
            nn.ReLU(),
        ]

    # (output channels, number of conv units) for stages 1..5
    stage_plan = ((64, 2), (128, 2), (256, 4), (512, 4), (512, 4))

    modules = [nn.Conv2d(3, 3, (1, 1))]
    channels = 3
    for stage, (width, unit_count) in enumerate(stage_plan):
        if stage > 0:
            modules.append(nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True))
        for _ in range(unit_count):
            modules.extend(conv_unit(channels, width))
            channels = width
    return modules


vgg = nn.Sequential(*_vgg_modules())
# %%
class SANet(nn.Module):
    """Attention block that re-weights style features per content position.

    Three 1x1 convolutions (``f``, ``g``, ``h``) project the normalised
    content, the normalised style and the raw style features. A softmax over
    the product of the ``f`` and ``g`` projections weights the ``h``
    projection, and the result goes through ``out_conv`` and is added to the
    content input. Attribute names are the state-dict keys and are kept as is.
    """

    def __init__(self, in_planes):
        super(SANet, self).__init__()
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.sm = nn.Softmax(dim=-1)
        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))

    @staticmethod
    def _flatten_spatial(feat):
        """View a (b, c, h, w) tensor as (b, c, h*w)."""
        batch, _, height, width = feat.size()
        return feat.view(batch, -1, width * height)

    def forward(self, content, style):
        """Return the content features plus the attended style features."""
        query = self._flatten_spatial(self.f(mean_variance_norm(content)))
        key = self._flatten_spatial(self.g(mean_variance_norm(style)))
        value = self._flatten_spatial(self.h(style))

        # (b, hw_content, hw_style): softmax over the style positions.
        attention = self.sm(torch.bmm(query.transpose(1, 2), key))

        # (b, c, hw_content), folded back to the content's 4-D size.
        attended = torch.bmm(value, attention.transpose(1, 2))
        batch, channels, height, width = content.size()
        out = self.out_conv(attended.view(batch, channels, height, width))
        out += content
        return out
class Transform(nn.Module):
    """Fuse two SANet outputs computed at two feature levels.

    The second level's output is upsampled x2 (nearest), added to the first
    level's output, and the sum is passed through a reflection-padded 3x3
    convolution.
    """

    def __init__(self, in_planes):
        super(Transform, self).__init__()
        self.sanet4_1 = SANet(in_planes=in_planes)
        self.sanet5_1 = SANet(in_planes=in_planes)
        self.upsample5_1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.merge_conv_pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.merge_conv = nn.Conv2d(in_planes, in_planes, (3, 3))

    def forward(self, content4_1, style4_1, content5_1, style5_1):
        """Return the merged feature map for the two content/style pairs."""
        attended_4 = self.sanet4_1(content4_1, style4_1)
        attended_5 = self.sanet5_1(content5_1, style5_1)
        fused = attended_4 + self.upsample5_1(attended_5)
        return self.merge_conv(self.merge_conv_pad(fused))
class Net(nn.Module):
    """Training network: frozen encoder slices, a Transform and a decoder.

    Args:
        encoder: module whose children are split by index into the five
            sub-encoders ``enc_1`` .. ``enc_5`` (ending at relu1_1, relu2_1,
            relu3_1, relu4_1 and relu5_1 respectively).
        decoder: module that maps the transformed features back to an image.
        start_iter: when greater than zero, the transform and decoder weights
            are loaded from ``transformer_iter_<n>.pth`` and
            ``decoder_iter_<n>.pth`` in the current working directory, where
            ``<n>`` is ``int(start_iter)``.
    """

    def __init__(self, encoder, decoder, start_iter):
        super(Net, self).__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.enc_5 = nn.Sequential(*enc_layers[31:44])  # relu4_1 -> relu5_1
        # transform
        self.transform = Transform(in_planes=512)
        self.decoder = decoder
        if start_iter > 0:
            # int() keeps the file name in the "<name>_iter_1000.pth" form even
            # when start_iter arrives as a float (str(1000.0) is "1000.0").
            tag = str(int(start_iter))
            # map_location='cpu': the modules are still on the CPU here; the
            # caller moves the whole network to its device afterwards.
            self.transform.load_state_dict(
                torch.load('transformer_iter_' + tag + '.pth', map_location='cpu'))
            self.decoder.load_state_dict(
                torch.load('decoder_iter_' + tag + '.pth', map_location='cpu'))
        self.mse_loss = nn.MSELoss()
        # fix the encoder
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

    def encode_with_intermediate(self, input):
        """Return [relu1_1, relu2_1, relu3_1, relu4_1, relu5_1] features of *input*.

        Each sub-encoder is applied to the output of the previous one.
        """
        feats = []
        current = input
        for level in range(1, 6):
            current = getattr(self, 'enc_{:d}'.format(level))(current)
            feats.append(current)
        return feats

    def calc_content_loss(self, input, target, norm=False):
        """Return the MSE between *input* and *target*.

        With ``norm=True`` both tensors are passed through
        ``mean_variance_norm`` before the comparison.
        """
        if norm:
            return self.mse_loss(mean_variance_norm(input), mean_variance_norm(target))
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        """Return MSE of the channel means plus MSE of the channel stds."""
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
            self.mse_loss(input_std, target_std)

    def forward(self, content, style):
        """Compute the four training losses for a content/style batch.

        Returns:
            ``(loss_c, loss_s, l_identity1, l_identity2)``:
            normalised content loss on the relu4_1 and relu5_1 features,
            style loss summed over all five feature levels, image-space
            identity loss, and feature-space identity loss summed over all
            five levels. All are unweighted.
        """
        style_feats = self.encode_with_intermediate(style)
        content_feats = self.encode_with_intermediate(content)
        stylized = self.transform(content_feats[3], style_feats[3],
                                  content_feats[4], style_feats[4])
        g_t = self.decoder(stylized)
        g_t_feats = self.encode_with_intermediate(g_t)

        loss_c = (self.calc_content_loss(g_t_feats[3], content_feats[3], norm=True)
                  + self.calc_content_loss(g_t_feats[4], content_feats[4], norm=True))

        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for out_feat, style_feat in zip(g_t_feats[1:], style_feats[1:]):
            loss_s = loss_s + self.calc_style_loss(out_feat, style_feat)

        # Identity losses: an image used as both content and style should be
        # reproduced, in pixel space (l_identity1) and feature space (l_identity2).
        Icc = self.decoder(self.transform(content_feats[3], content_feats[3],
                                          content_feats[4], content_feats[4]))
        Iss = self.decoder(self.transform(style_feats[3], style_feats[3],
                                          style_feats[4], style_feats[4]))
        l_identity1 = self.calc_content_loss(Icc, content) + self.calc_content_loss(Iss, style)

        Fcc = self.encode_with_intermediate(Icc)
        Fss = self.encode_with_intermediate(Iss)
        l_identity2 = (self.calc_content_loss(Fcc[0], content_feats[0])
                       + self.calc_content_loss(Fss[0], style_feats[0]))
        for level in range(1, 5):
            l_identity2 = l_identity2 + (
                self.calc_content_loss(Fcc[level], content_feats[level])
                + self.calc_content_loss(Fss[level], style_feats[level]))
        return loss_c, loss_s, l_identity1, l_identity2
def InfiniteSampler(n):
    """Yield dataset indices in ``range(n)`` forever.

    The first value is the last element of an initial random permutation.
    After that the generator repeatedly re-seeds NumPy's global RNG with
    ``np.random.seed()``, draws a fresh permutation and yields all of it.
    """
    initial_order = np.random.permutation(n)
    yield initial_order[n - 1]
    while True:
        np.random.seed()
        for index in np.random.permutation(n):
            yield index
class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that draws indices from ``InfiniteSampler`` without ever ending."""

    def __init__(self, data_source):
        # Only the dataset size is kept; the data source itself is not stored.
        self.num_samples = len(data_source)

    def __iter__(self):
        endless_indices = InfiniteSampler(self.num_samples)
        return iter(endless_indices)

    def __len__(self):
        # Nominal length (2 ** 31); the iterator itself never stops.
        return 1 << 31
# cudnn autotuning is currently switched off; uncomment the next line to enable it.
# cudnn.benchmark = True
# Remove Pillow's pixel-count limit so very large images do not raise
# DecompressionBombError.
Image.MAX_IMAGE_PIXELS = None
# Let Pillow load truncated files instead of raising
# "OSError: image file is truncated".
ImageFile.LOAD_TRUNCATED_IMAGES = True
def train_transform():
    """Return the training preprocessing pipeline.

    Images are resized to 1024x1024, a random 512 crop is taken, and the
    crop is converted to a tensor.
    """
    return transforms.Compose([
        transforms.Resize(size=(1024, 1024)),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
    ])
class FlatFolderDataset(data.Dataset):
    """Dataset over the files that sit directly inside one folder.

    Args:
        root: directory to read from. Only regular files directly inside it
            are used; sub-directories are skipped.
        transform: callable applied to each loaded RGB ``PIL.Image``.
    """

    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        # os.listdir also returns sub-directories, which would make
        # Image.open fail in __getitem__, and its order is arbitrary.
        # Keep files only and sort so that index -> file is reproducible.
        self.paths = sorted(
            entry for entry in os.listdir(self.root)
            if os.path.isfile(os.path.join(self.root, entry))
        )
        self.transform = transform

    def __getitem__(self, index):
        """Load file *index*, convert it to RGB and return the transformed image."""
        path = self.paths[index]
        # The context manager closes the underlying file right after the
        # pixel data has been converted, instead of leaving it to the GC.
        with Image.open(os.path.join(self.root, path)) as handle:
            img = handle.convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'
def adjust_learning_rate(optimizer, iteration_count):
    """Set every parameter group's learning rate to the decayed value.

    The rate is ``args.lr / (1 + args.lr_decay * iteration_count)``, where
    ``args`` is the module-level namespace created by the training script.
    """
    decay_divisor = 1.0 + args.lr_decay * iteration_count
    lr = args.lr / decay_divisor
    for group in optimizer.param_groups:
        group['lr'] = lr
if __name__ == "__main__":
    # Training entry point. Everything stays at module level on purpose:
    # adjust_learning_rate() reads the global `args` defined here.
    opt_level = 'O1'  # only referenced by the commented-out amp.initialize call below
    parser = argparse.ArgumentParser()
    # Basic options
    parser.add_argument('--content_dir', type=str, default='./train2014',
                        help='Directory path to a batch of content images')
    parser.add_argument('--style_dir', type=str, default='./train',
                        help='Directory path to a batch of style images')
    parser.add_argument('--vgg_pth', type=str, default='vgg_normalised.pth')
    # training options
    parser.add_argument('--save_dir', default='./experiments',
                        help='Directory to save the model')
    parser.add_argument('--log_dir', default='./logs',
                        help='Directory to save the log')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--lr_decay', type=float, default=5e-5)
    parser.add_argument('--max_iter', type=int, default=500000)
    parser.add_argument('--batch_size', type=int, default=5)
    parser.add_argument('--style_weight', type=float, default=2.0)
    parser.add_argument('--content_weight', type=float, default=3.0)
    parser.add_argument('--n_threads', type=int, default=16)
    parser.add_argument('--save_model_interval', type=int, default=1000)
    # int, not float: the value feeds range() and the "<name>_iter_<n>.pth"
    # checkpoint names, both of which need an integer.
    parser.add_argument('--start_iter', type=int, default=0)
    # NOTE(review): parse_args('') parses an empty argument list, so flags
    # given on the command line are ignored and only the defaults plus the
    # overrides below apply. Kept as is; confirm this is intended.
    args = parser.parse_args('')
    args.content_dir = r'data/src_img'
    args.style_dir = r'data/alien_img'
    args.batch_size = 2

    # Fall back to the CPU instead of failing when CUDA is not available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # torch.save does not create missing directories.
    os.makedirs(args.save_dir, exist_ok=True)

    vg = vgg
    vg.load_state_dict(torch.load(args.vgg_pth))
    vg = nn.Sequential(*list(vg.children())[:44])
    network = Net(vg, decoder, args.start_iter)
    network.train()
    network.to(device)

    content_tf = train_transform()
    style_tf = train_transform()
    content_dataset = FlatFolderDataset(args.content_dir, content_tf)
    style_dataset = FlatFolderDataset(args.style_dir, style_tf)
    content_iter = iter(data.DataLoader(
        content_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(content_dataset),
        num_workers=args.n_threads))
    style_iter = iter(data.DataLoader(
        style_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(style_dataset),
        num_workers=args.n_threads))

    # Only the decoder and the transform are optimised; the encoder is frozen.
    optimizer = torch.optim.Adam([
        {'params': network.decoder.parameters()},
        {'params': network.transform.parameters()}], lr=args.lr)
    # network, optimizer = amp.initialize(network, optimizer, opt_level=opt_level)
    writer = SummaryWriter(comment=f'LR_{args.lr}_BS_{args.batch_size}')
    if args.start_iter > 0:
        optimizer.load_state_dict(torch.load('optimizer_iter_' + str(args.start_iter) + '.pth'))

    cpu = torch.device('cpu')
    global_step = 0
    with tqdm(total=args.max_iter - args.start_iter,
              desc='train',
              unit='batch') as pbar:
        for i in range(args.start_iter, args.max_iter):
            adjust_learning_rate(optimizer, iteration_count=i)
            content_images = next(content_iter).to(device)
            style_images = next(style_iter).to(device)
            loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)
            loss_c = args.content_weight * loss_c
            loss_s = args.style_weight * loss_s
            loss = loss_c + loss_s + l_identity1 + l_identity2 * 50

            loss_value = loss.item()
            pbar.set_postfix(**{'loss (batch)': loss_value})
            writer.add_scalar('Loss/train', loss_value, global_step)
            global_step += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % 1000 == 0:
                for tag, value in network.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                    # Frozen encoder parameters never receive a gradient, so
                    # their .grad is None and has nothing to log.
                    if value.grad is not None:
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

            if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
                # Weights are moved to the CPU before saving.
                state_dict = decoder.state_dict()
                for key in state_dict.keys():
                    state_dict[key] = state_dict[key].to(cpu)
                torch.save(state_dict,
                           '{:s}/decoder_iter_{:d}.pth'.format(args.save_dir,
                                                               i + 1))
                state_dict = network.transform.state_dict()
                for key in state_dict.keys():
                    state_dict[key] = state_dict[key].to(cpu)
                torch.save(state_dict,
                           '{:s}/transformer_iter_{:d}.pth'.format(args.save_dir,
                                                                   i + 1))
                state_dict = optimizer.state_dict()
                torch.save(state_dict,
                           '{:s}/optimizer_iter_{:d}.pth'.format(args.save_dir,
                                                                 i + 1))

            # Advance the bar once per iteration; without this it stays at 0.
            pbar.update(1)

    # Flush pending TensorBoard events to disk.
    writer.close()