mnist_ebgan_train.py (forked from buriburisuri/ebgan)
# -*- coding: utf-8 -*-
import sugartensor as tf
import numpy as np
# set log level to debug
tf.sg_verbosity(10)
#
# hyper parameters
#
batch_size = 32 # batch size
z_dim = 50 # noise dimension
margin = 1 # max-margin for hinge loss
pt_weight = 0.1 # PT regularizer's weight
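# NOTE: `margin` and `pt_weight` correspond to the hinge margin m and the
# pull-away-term coefficient in the EBGAN paper (Zhao et al., 2016); the
# values above simply follow the original repository.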
#
# inputs
#
# MNIST input tensor (with QueueRunner)
data = tf.sg_data.Mnist(batch_size=batch_size)
# input images
x = data.train.image
#
# create generator
#
# random uniform noise z ~ U[0, 1)
z = tf.random_uniform((batch_size, z_dim))
with tf.sg_context(name='generator', size=4, stride=2, act='relu', bn=True):
    # generator network
    gen = (z.sg_dense(dim=1024)
            .sg_dense(dim=7*7*128)
            .sg_reshape(shape=(-1, 7, 7, 128))
            .sg_upconv(dim=64)
            .sg_upconv(dim=1, act='sigmoid', bn=False))
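# Shape trace through the generator, for reference (a sketch assuming the
# sugartensor defaults set in the context above):
#   z                   (32, 50)
#   sg_dense(1024)      (32, 1024)
#   sg_dense(7*7*128)   (32, 6272)
#   sg_reshape          (32, 7, 7, 128)
#   sg_upconv(64)       (32, 14, 14, 64)   # stride-2 upconv doubles H and W
#   sg_upconv(1)        (32, 28, 28, 1)    # sigmoid keeps pixels in [0, 1]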
#
# create discriminator
#
# concat real and generated images into a single batch
xx = tf.concat([x, gen], axis=0)
with tf.sg_context(name='discriminator', size=4, stride=2, act='leaky_relu'):
    # discriminator network: an auto-encoder (conv encoder + upconv decoder)
    disc = (xx.sg_conv(dim=64)
              .sg_conv(dim=128)
              .sg_upconv(dim=64)
              .sg_upconv(dim=1, act='linear'))
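# In EBGAN the discriminator is an auto-encoder, not a classifier: the two
# sg_conv layers encode 28x28 -> 14x14 -> 7x7 and the two sg_upconv layers
# decode back to 28x28x1. Its "energy" for an image is the reconstruction
# error computed below, rather than a real/fake probability.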
#
# pull-away term (PT) regularizer
#
sample = gen.sg_flatten()
# pairwise dot products between flattened generated samples
nom = tf.matmul(sample, tf.transpose(sample, perm=[1, 0]))
# per-sample L2 norms, kept as a column vector so the outer product below
# yields the full matrix of norm products
norm = tf.sqrt(tf.reduce_sum(tf.square(sample), reduction_indices=[1], keep_dims=True))
# squared cosine similarity between every pair of samples
pt = tf.square(nom / tf.matmul(norm, tf.transpose(norm, perm=[1, 0])))
pt -= tf.diag(tf.diag_part(pt))  # zero out the diagonal (i == j) terms
pt = tf.reduce_sum(pt) / (batch_size * (batch_size - 1))
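# A minimal NumPy cross-check of the same pull-away term (a hypothetical
# helper, not part of the original script; handy for verifying the graph ops
# above on a small array):
#
#   def pt_numpy(s):                      # s: (N, D) array of flat samples
#       c = s.dot(s.T)                    # pairwise dot products
#       n = np.linalg.norm(s, axis=1, keepdims=True)
#       cos2 = (c / n.dot(n.T)) ** 2      # squared cosine similarities
#       np.fill_diagonal(cos2, 0.0)       # drop the i == j terms
#       return cos2.sum() / (len(s) * (len(s) - 1))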
#
# loss & train ops
#
# per-sample reconstruction errors (the auto-encoder "energy")
mse = tf.reduce_mean(tf.square(disc - xx), reduction_indices=[1, 2, 3])
mse_real, mse_fake = mse[:batch_size], mse[batch_size:]
loss_disc = mse_real + tf.maximum(margin - mse_fake, 0) # discriminator loss
loss_gen = mse_fake + pt * pt_weight # generator loss + PT regularizer
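# These two losses implement the EBGAN objectives, with the auto-encoder
# reconstruction error D(.) acting as the energy:
#   L_D = D(x) + max(0, margin - D(G(z)))
#   L_G = D(G(z)) + pt_weight * PT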
train_disc = tf.sg_optim(loss_disc, lr=0.001, category='discriminator') # discriminator train ops
train_gen = tf.sg_optim(loss_gen, lr=0.001, category='generator') # generator train ops
#
# add summary
#
tf.sg_summary_loss(tf.identity(loss_disc, name='disc'))
tf.sg_summary_loss(tf.identity(loss_gen, name='gen'))
tf.sg_summary_image(gen)
#
# training
#
# define the alternating training function
@tf.sg_train_func
def alt_train(sess, opt):
    l_disc = sess.run([loss_disc, train_disc])[0]  # train the discriminator
    l_gen = sess.run([loss_gen, train_gen])[0]     # train the generator
    return np.mean(l_disc) + np.mean(l_gen)
# do training
alt_train(log_interval=10, max_ep=30, ep_size=data.train.num_batch, early_stop=False)
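# NOTE: with sugartensor defaults (nothing is overridden above), checkpoints
# and summaries should land under asset/train/, so -- assuming that default
# save directory -- progress can be watched with:
#   tensorboard --logdir asset/train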