
Commit

initial commit
mansour committed Sep 18, 2016
1 parent 30ea41c commit bd1fece
Showing 4 changed files with 55 additions and 17 deletions.
39 changes: 37 additions & 2 deletions README.md
@@ -1,2 +1,37 @@
# ebgan
tensorflow implementation of Junbo et al's Energy-based generative adversarial network ( EBGAN ) paper.
# EBGAN
A TensorFlow implementation of Junbo Zhao et al.'s Energy-Based Generative Adversarial Network (EBGAN) paper.
(See: [https://arxiv.org/pdf/1609.03126v2.pdf](https://arxiv.org/pdf/1609.03126v2.pdf))

## Dependencies

1. tensorflow >= rc0.10
1. sugartensor >= 0.0.1

## Training the network

Execute
<pre><code>
python mnist_ebgan_train.py
</code></pre>
to train the network. The resulting checkpoint (ckpt) files and log files are written to the 'asset/train' directory.
Launch <code>tensorboard --logdir asset/train/log</code> to monitor the training process.


## Generating images

Execute
<pre><code>
python mnist_ebgan_generate.py
</code></pre>
to generate a sample image. The 'sample.png' file will be written to the 'asset/train' directory.

## Sample image

This image was generated by the EBGAN network.
<p align="center">
<img src="https://raw.githubusercontent.com/buriburisuri/ebgan/master/ebgan/png/sample.png" width="350"/>
</p>


# Authors
Namju Kim ([email protected]) at Jamonglabs Co., Ltd.
1 change: 0 additions & 1 deletion mnist_ebgan_generate.py
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import sugartensor as tf
import numpy as np
import matplotlib.pyplot as plt

# set log level to debug
32 changes: 18 additions & 14 deletions mnist_ebgan_train.py
@@ -9,9 +9,10 @@
# hyper parameters
#

batch_size = 32
z_dim = 50
margin = 1
batch_size = 32 # batch size
z_dim = 50 # noise dimension
margin = 1 # max-margin for hinge loss
pt_weight = 0.1 # PT regularizer's weight

#
# inputs
@@ -23,12 +24,6 @@
# input images
x = data.train.image

# generator labels ( all ones )
y = tf.ones(batch_size, dtype=tf.sg_floatx)

# discriminator labels ( half 1s, half 0s )
y_disc = tf.concat(0, [y, y * 0])

#
# create generator
#
@@ -45,7 +40,6 @@
.sg_upconv(dim=64)
.sg_upconv(dim=1, act='sigmoid', bn=False))
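# note: the last layer (dim=1, sigmoid, bn=False) emits a single-channel image
# with pixel values in [0, 1], i.e. one generated MNIST-style sample per z vector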


#
# create discriminator
#
@@ -59,17 +53,27 @@
.sg_upconv(dim=64)
.sg_upconv(dim=1, act='linear'))
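# note: the discriminator is an auto-encoder -- 'disc' is a reconstruction of its
# input 'xx' (real images stacked with generated ones, given the slicing below),
# and the per-image reconstruction error serves as the energy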

#
# pull-away term ( PT ) regularizer
#

sample = gen.sg_flatten()
nom = tf.matmul(sample, tf.transpose(sample, perm=[1, 0]))
denom = tf.reduce_sum(tf.square(sample), reduction_indices=[1], keep_dims=True)
pt = tf.square(nom/denom)
pt -= tf.diag(tf.diag_part(pt))
pt = tf.reduce_sum(pt) / (batch_size * (batch_size - 1))
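# reference: the pull-away term in the EBGAN paper is
#   f_PT(S) = 1/(N(N-1)) * sum_{i != j} ( S_i . S_j / (||S_i|| * ||S_j||) )^2,
# i.e. the mean squared cosine similarity between distinct generated samples;
# the block above normalizes by the squared row norm ||S_i||^2 rather than
# ||S_i|| * ||S_j||, but likewise pushes samples apart to discourage mode collapse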

#
# loss & train ops
#

# squared errors
mse = tf.square(disc - xx)
mse_real, mse_fake = mse[:batch_size, :, :, :], mse[batch_size:, :, :, :]
# mean squared errors
mse = tf.reduce_mean(tf.square(disc - xx), reduction_indices=[1, 2, 3])
mse_real, mse_fake = mse[:batch_size], mse[batch_size:]
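# note: averaging over the height, width and channel axes leaves one reconstruction
# error per image; the first batch_size entries belong to the real images and the
# remaining entries to the generated ones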

loss_disc = mse_real + tf.maximum(margin - mse_fake, 0) # discriminator loss
loss_gen = mse_fake # generator loss
loss_gen = mse_fake + pt * pt_weight # generator loss + PT regularizer

train_disc = tf.sg_optim(loss_disc, lr=0.001, category='discriminator') # discriminator train ops
train_gen = tf.sg_optim(loss_gen, lr=0.001, category='generator') # generator train ops
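
For reference, these two losses correspond to the EBGAN objectives from the paper, with the pull-away regularizer of the EBGAN-PT variant added to the generator loss. The energy D(.) is realized here as the per-image reconstruction MSE of the auto-encoding discriminator, m is the 'margin' hyperparameter and lambda_PT is 'pt_weight':
<pre><code>
% EBGAN losses (Zhao et al., 2016); [t]^+ denotes max(0, t)
\mathcal{L}_D(x, z) = D(x) + \big[\, m - D(G(z)) \,\big]^{+}               % loss_disc
\mathcal{L}_G(z)    = D\big(G(z)\big) + \lambda_{PT}\, f_{PT}(S)           % loss_gen, S = flattened generated batch
</code></pre>
The margin keeps the discriminator from wasting capacity on samples it already rejects: once the fake energy D(G(z)) rises above m, the hinge term vanishes and only the real reconstruction term drives the discriminator.
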
Binary file added png/sample.png
