Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
mansour committed Sep 18, 2016
1 parent 8180dae commit 30ea41c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/asset
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
17 changes: 10 additions & 7 deletions mnist_ebgan_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

batch_size = 32
z_dim = 50
margin = 10
margin = 1

#
# inputs
Expand Down Expand Up @@ -53,20 +53,23 @@
# create real + fake image input
xx = tf.concat(0, [x, gen])

with tf.sg_context(name='discriminator', size=4, stride=2, act='leaky_relu', bn=True):
with tf.sg_context(name='discriminator', size=4, stride=2, act='leaky_relu'):
disc = (xx.sg_conv(dim=64)
.sg_conv(dim=128)
.sg_upconv(dim=64)
.sg_upconv(dim=1, act='sigmoid', bn=False))
.sg_upconv(dim=1, act='linear'))


#
# loss & train ops
#

mse = tf.square(disc - xx) # squared error
loss_disc = mse[:batch_size, :, :, :] + tf.maximum(margin - mse[batch_size:, :, :, :], 0) # discriminator loss
loss_gen = mse[batch_size:, :, :, :] # generator loss
# squared errors
mse = tf.square(disc - xx)
mse_real, mse_fake = mse[:batch_size, :, :, :], mse[batch_size:, :, :, :]

loss_disc = mse_real + tf.maximum(margin - mse_fake, 0) # discriminator loss
loss_gen = mse_fake # generator loss

train_disc = tf.sg_optim(loss_disc, lr=0.001, category='discriminator') # discriminator train ops
train_gen = tf.sg_optim(loss_gen, lr=0.001, category='generator') # generator train ops
Expand All @@ -93,5 +96,5 @@ def alt_train(sess, opt):
return np.mean(l_disc) + np.mean(l_gen)

# do training
alt_train(log_interval=10, ep_max=100, ep_size=data.train.num_batch, early_stop=False)
alt_train(log_interval=10, ep_max=30, ep_size=data.train.num_batch, early_stop=False)

0 comments on commit 30ea41c

Please sign in to comment.