-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path texture_nets.py
79 lines (55 loc) · 2.26 KB
/
texture_nets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn
from .common import *
# Normalization layer class used for every conv stage in this module.
normalization = torch.nn.BatchNorm2d
def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero'):
    """Build a 2-D convolution padded by ``(kernel_size - 1) // 2`` on each side.

    For odd ``kernel_size`` and ``stride=1`` this preserves the spatial size.

    Args:
        in_f: Number of input channels.
        out_f: Number of output channels.
        kernel_size: Square kernel size (an int).
        stride: Convolution stride.
        bias: Whether the convolution has a learnable bias.
        pad: ``'zero'`` to let ``nn.Conv2d`` zero-pad, or ``'reflection'`` to
            prepend an ``nn.ReflectionPad2d`` and run the convolution unpadded.

    Returns:
        An ``nn.Conv2d`` for ``pad='zero'``; an ``nn.Sequential`` of
        ``(ReflectionPad2d, Conv2d)`` for ``pad='reflection'``.

    Raises:
        ValueError: If ``pad`` is not one of the supported modes.
    """
    # Integer division: torch requires integer padding, and ``/`` produces a
    # float on Python 3.
    to_pad = (kernel_size - 1) // 2
    if pad == 'zero':
        return nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)
    if pad == 'reflection':
        return nn.Sequential(
            nn.ReflectionPad2d(to_pad),
            nn.Conv2d(in_f, out_f, kernel_size, stride, padding=0, bias=bias),
        )
    # Previously an unknown mode fell through and returned None.
    raise ValueError(
        "Unsupported pad mode {!r}; expected 'zero' or 'reflection'".format(pad))
def _add_conv_bn_act(seq, in_f, out_f, pad):
    """Append three conv -> normalization -> act stages (kernels 3, 3, 1) to ``seq``.

    Only the first convolution changes the channel count (``in_f`` -> ``out_f``);
    the other two map ``out_f`` -> ``out_f``. Layers are appended with
    ``seq.add`` in the same order as the original inline code, so layer order
    (and presumably the registered submodule names) is unchanged.
    """
    for in_ch, kernel_size in ((in_f, 3), (out_f, 3), (out_f, 1)):
        seq.add(conv(in_ch, out_f, kernel_size, pad=pad))
        seq.add(normalization(out_f))
        seq.add(act())


def get_texture_nets(inp=3, ratios=(32, 16, 8, 4, 2, 1), fill_noise=False, pad='zero',
                     need_sigmoid=False, conv_num=8, upsample_mode='nearest'):
    """Build a multi-scale texture network as an ``nn.Sequential``.

    Each factor in ``ratios`` gets its own branch:
    ``AvgPool2d(ratio, ratio)``, then ``GenNoise(inp)`` if ``fill_noise`` is
    set, then three conv/normalization/act stages producing ``conv_num``
    channels.

    The first branch is upsampled by 2. Each later branch is concatenated via
    ``Concat(1, ...)`` with everything built so far, then refined by three
    more conv/normalization/act stages at ``conv_num * (i + 1)`` channels.
    The result is upsampled by 2, except after the last ratio, where a 1x1
    convolution projects it to 3 channels.

    Args:
        inp: Channel count fed to each branch's first convolution (also passed
            to ``GenNoise``).
        ratios: Downsampling factors, one per branch, processed in order. Any
            iterable is accepted. Each merge upsamples by exactly 2, so
            consecutive factors presumably should halve, as the default does
            (TODO confirm).
        fill_noise: If truthy, insert ``GenNoise(inp)`` after pooling in
            every branch.
        pad: Padding mode forwarded to ``conv`` (``'zero'`` or ``'reflection'``).
        need_sigmoid: If truthy, append ``nn.Sigmoid()`` to the model.
        conv_num: Channels produced by each branch.
        upsample_mode: ``mode`` argument passed to every ``nn.Upsample``.

    Returns:
        The assembled ``nn.Sequential`` model.

    Raises:
        ValueError: If ``ratios`` is empty.

    NOTE(review): with a single ratio, the 1x1 projection to 3 channels is
    never added. The output then keeps ``conv_num`` channels and is upsampled
    by 2. Confirm whether this is intended.
    """
    ratios = tuple(ratios)  # accept any iterable; the default is no longer a shared list
    if not ratios:
        raise ValueError('ratios must contain at least one downsampling factor')

    last = len(ratios) - 1
    for i, ratio in enumerate(ratios):
        n_branches = i + 1  # branches merged once this iteration completes

        # Per-scale branch: downsample, optionally inject noise, extract features.
        seq = nn.Sequential()
        seq.add(nn.AvgPool2d(ratio, ratio))
        if fill_noise:
            seq.add(GenNoise(inp))
        _add_conv_bn_act(seq, inp, conv_num, pad)

        if i == 0:
            seq.add(nn.Upsample(scale_factor=2, mode=upsample_mode))
            cur = seq
            continue

        # Batch norm before merging: normalize both the new branch and the
        # accumulated coarser path, which carries conv_num * i channels.
        seq.add(normalization(conv_num))
        cur.add(normalization(conv_num * i))

        merged = nn.Sequential()
        merged.add(Concat(1, cur, seq))
        # Assumes Concat(1, ...) stacks along dim 1 (channels):
        # conv_num * i + conv_num == conv_num * n_branches.
        width = conv_num * n_branches
        _add_conv_bn_act(merged, width, width, pad)

        if i == last:
            merged.add(conv(width, 3, 1, pad=pad))  # final 1x1 projection to 3 channels
        else:
            merged.add(nn.Upsample(scale_factor=2, mode=upsample_mode))
        cur = merged

    model = cur
    if need_sigmoid:
        model.add(nn.Sigmoid())
    return model