#!/usr/bin/env python
# Generative Adversarial Networks (GAN) example in PyTorch. Tested with PyTorch 0.4.1, Python 3.6.7 (Nov 2018)
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from matplotlib import pyplot as plt
# Data params
data_mean = 4
data_stddev = 1.25
# ### Uncomment only one of these to define what data is actually sent to the Discriminator
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
#(name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)
print("Using data [%s]" % (name))
# ##### DATA: Target data and generator input data
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian
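
# A quick shape sketch (illustrative only, not part of the training run; the
# sampler names below are made up for the example): the real-data sampler
# yields a (1, n) Gaussian row vector, while the generator-input sampler
# yields an (m, n) uniform matrix.
#
#     real_sampler = get_distribution_sampler(data_mean, data_stddev)
#     real_sampler(500).shape       # torch.Size([1, 500])
#     noise_sampler = get_generator_input_sampler()
#     noise_sampler(500, 1).shape   # torch.Size([500, 1])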
# ##### MODELS: Generator model and discriminator model
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f  # activation function

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f  # activation function

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return torch.sigmoid(self.map3(x))  # torch.sigmoid; F.sigmoid is deprecated
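
# Illustrative sketch (not executed here): under the "Only 4 moments"
# preprocessing, d_input_func(d_input_size) == 4, so D maps a 4-vector of
# moments to a single probability in (0, 1) via the final sigmoid:
#
#     D = Discriminator(input_size=4, hidden_size=20, output_size=1, f=torch.sigmoid)
#     D(torch.randn(4))   # e.g. tensor([0.48]) -- P(input is "real")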
def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]
def get_moments(d):
    # Return the first 4 moments of the data provided
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final
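
# Sanity-check sketch (illustrative; assumes a large sample): for draws from
# N(data_mean, data_stddev) the four moments should approach
# [4.0, 1.25, 0.0, 0.0], i.e. (mean, std, skew, excess kurtosis).
#
#     d = torch.Tensor(np.random.normal(data_mean, data_stddev, (1, 10000)))
#     get_moments(d)   # roughly tensor([4.0, 1.25, 0.0, 0.0])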
def decorate_with_diffs(data, exponent, remove_raw_data=False):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)
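
# Illustrative sketch: decorating a (1, n) batch appends the element-wise
# (x - mean)^exponent terms, doubling the width -- which is why the
# "Data and diffs"/"Data and variances" options use d_input_func = lambda x: x * 2.
#
#     d = torch.Tensor([[1.0, 2.0, 3.0, 4.0]])
#     decorate_with_diffs(d, 2.0)
#     # tensor([[1.00, 2.00, 3.00, 4.00, 2.25, 0.25, 0.25, 2.25]])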
def train():
    # Model parameters
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5*2   # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10*2  # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    d_learning_rate = 1e-3*2
    g_learning_rate = 1e-3*0.5
    d_sgd_momentum = 0.9
    g_sgd_momentum = 0.9

    num_epochs = 1000
    print_interval = 100
    d_steps = 20*3
    g_steps = 20*3

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh
    # generator_activation_function = torch.relu
    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=d_sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=g_sgd_momentum)
    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            # 1A: Train D on real
            d_real_data = Variable(d_sampler(d_input_size))
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1])))  # ones = true
            d_real_error.backward()  # compute/store gradients, but don't change params

            # 1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]
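
        # Each D step thus minimizes the standard GAN discriminator loss
        #   L_D = -[ log D(x_real) + log(1 - D(G(z))) ],
        # accumulated as two separate BCE terms before the single optimizer step.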
        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1])))  # Train G to pretend it's genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters

            ge = extract(g_error)[0]
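
        # Labeling G's output as real (target = ones) makes each G step minimize
        # the non-saturating generator loss
        #   L_G = -log D(G(z)),
        # which gives stronger gradients early in training than maximizing
        # log(1 - D(G(z))).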
        if epoch % print_interval == 0:
            print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s), Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))

        if epoch % 5 == 0:
            print("Plotting the generated distribution...")
            values = extract(g_fake_data)
            plt.clf()  # clear the previous histogram so successive plots don't accumulate
            plt.hist(values, range=(-2, 10), bins=50)
            plt.xlabel('Value')
            plt.ylabel('Count')
            plt.title('Histogram of Generated Distribution')
            plt.grid(True)
            plt.pause(0.1)
    # Final plot after training completes
    print("Plotting the generated distribution...")
    values = extract(g_fake_data)
    print("  Values: %s" % (str(values)))
    plt.hist(values, bins=50)
    plt.xlabel('Value')
    plt.ylabel('Count')
    plt.title('Histogram of Generated Distribution')
    plt.grid(True)
    plt.show()

train()