main.py
from __future__ import division
from __future__ import print_function
import os
import time
import tensorflow as tf
import numpy as np
import sklearn
from sklearn import metrics
from src.models import GraphHAT, LayerInfo
from src.sampler import UniformNeighborSampler, NodeMinibatchIterator
from src.utils import load_data
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Set random seed
seed = 123
np.random.seed(seed)
tf.set_random_seed(seed)
# Settings
flags = tf.app.flags
FLAGS = flags.FLAGS
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")
# Core params
flags.DEFINE_string('model', 'graphhat', 'model names')
flags.DEFINE_float('learning_rate', 0.001, 'initial learning rate')
flags.DEFINE_string('train_prefix', '', 'prefix identifying training data. must be specified.')
# left to default values in main experiments
flags.DEFINE_integer('epochs', 100, 'number of epochs to train.')
flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).')
flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.')
flags.DEFINE_integer('max_degree', 128, 'maximum node degree.')
flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
flags.DEFINE_integer('samples_2', 20, 'number of samples in layer 2')
flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')
flags.DEFINE_boolean('random_context', False, 'Whether to use random context or direct edges')
flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
flags.DEFINE_boolean('sigmoid', False, 'whether to use sigmoid loss')
flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')
# Logging, saving, validation settings, etc.
flags.DEFINE_string('base_log_dir', '.', 'base directory for logging and saving embeddings')
flags.DEFINE_integer('validate_iter', 5000, "how often to run a validation minibatch.")
flags.DEFINE_integer('validate_batch_size', 256, "how many nodes per validation sample.")
flags.DEFINE_integer('gpu', 1, "which gpu to use.")
flags.DEFINE_integer('print_every', 5, "How often to print training info.")
flags.DEFINE_integer('max_total_steps', 10**10, "Maximum total number of iterations")
os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu)
GPU_MEM_FRACTION = 0.8
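# Collapse predictions to hard labels (argmax for softmax outputs, 0.5 threshold
# for sigmoid outputs) and return micro-F1, macro-F1, and accuracy.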
def calc_f1(y_true, y_pred):
    if not FLAGS.sigmoid:
        y_true = np.argmax(y_true, axis=1)
        y_pred = np.argmax(y_pred, axis=1)
    else:
        y_pred[y_pred > 0.5] = 1
        y_pred[y_pred <= 0.5] = 0
    return (metrics.f1_score(y_true, y_pred, average="micro"),
            metrics.f1_score(y_true, y_pred, average="macro"),
            metrics.accuracy_score(y_true, y_pred))
# Define model evaluation function
def evaluate(sess, model, minibatch_iter, size=None):
    t_test = time.time()
    feed_dict_val, labels = minibatch_iter.node_val_feed_dict(size)
    node_outs_val = sess.run([model.preds, model.loss],
                             feed_dict=feed_dict_val)
    mic, mac, acc = calc_f1(labels, node_outs_val[0])
    return node_outs_val[1], mic, mac, acc, (time.time() - t_test)
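# Build the log directory path from the dataset prefix, model name, and learning
# rate, creating it on disk if it does not already exist.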
def log_dir():
    log_dir = FLAGS.base_log_dir + "/sup-" + FLAGS.train_prefix.split("/")[-2]
    log_dir += "/{model:s}_{lr:0.4f}/".format(
        model=FLAGS.model,
        lr=FLAGS.learning_rate)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir
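# Evaluate the full validation (or test) set in fixed-size minibatches and
# aggregate the per-batch losses and predictions before scoring.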
def incremental_evaluate(sess, model, minibatch_iter, size, test=False):
    t_test = time.time()
    finished = False
    val_losses = []
    val_preds = []
    labels = []
    iter_num = 0
    while not finished:
        feed_dict_val, batch_labels, finished, _ = minibatch_iter.incremental_node_val_feed_dict(size, iter_num, test=test)
        node_outs_val = sess.run([model.preds, model.loss],
                                 feed_dict=feed_dict_val)
        val_preds.append(node_outs_val[0])
        labels.append(batch_labels)
        val_losses.append(node_outs_val[1])
        iter_num += 1
    val_preds = np.vstack(val_preds)
    labels = np.vstack(labels)
    mic, mac, acc = calc_f1(labels, val_preds)
    return np.mean(val_losses), mic, mac, acc, (time.time() - t_test)
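# Placeholders carry the label matrix, the batch of node ids, the dropout rate,
# and the batch size into the TensorFlow graph.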
def construct_placeholders(num_classes):
    # Define placeholders
    placeholders = {
        'labels': tf.placeholder(tf.float32, shape=(None, num_classes), name='labels'),
        'batch': tf.placeholder(tf.int32, shape=(None), name='batch1'),
        'dropout': tf.placeholder_with_default(0., shape=(), name='dropout'),
        'batch_size': tf.placeholder(tf.int32, name='batch_size'),
    }
    return placeholders
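# End-to-end training driver: build the minibatch iterator and the GraphHAT model,
# run minibatch SGD with periodic validation, then write final validation and test stats.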
def train(train_data, test_data=None):
    G = train_data[0]
    features = train_data[1]
    id_map = train_data[2]
    class_map = train_data[4]
    if isinstance(list(class_map.values())[0], list):
        num_classes = len(list(class_map.values())[0])
    else:
        num_classes = len(set(class_map.values()))
    if features is not None:
        # pad with dummy zero vector
        features = np.vstack([features, np.zeros((features.shape[1],))])
    context_pairs = train_data[3] if FLAGS.random_context else None
    placeholders = construct_placeholders(num_classes)
    minibatch = NodeMinibatchIterator(G,
                                      id_map,
                                      placeholders,
                                      class_map,
                                      num_classes,
                                      batch_size=FLAGS.batch_size,
                                      max_degree=FLAGS.max_degree,
                                      context_pairs=context_pairs)
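    # The adjacency lists live in a non-trainable variable so they can be swapped
    # between the training graph and the validation/test graph via tf.assign below.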
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
    if FLAGS.model == 'graphhat':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        if FLAGS.samples_2 != 0:
            layer_infos = [LayerInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                           LayerInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        else:
            layer_infos = [LayerInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1)]
        model = GraphHAT(num_classes, placeholders,
                         features,
                         adj_info,
                         minibatch.deg,
                         layer_infos,
                         sigmoid_loss=FLAGS.sigmoid,
                         identity_dim=FLAGS.identity_dim,
                         logging=True)
    else:
        raise Exception('Error: model name unrecognized.')
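    # Session configuration: fall back to CPU for ops without GPU kernels and let
    # GPU memory grow on demand instead of reserving a fixed fraction.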
    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    config.allow_soft_placement = True
    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
    # Train model
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
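    # Main training loop: shuffle the training nodes each epoch and step through
    # minibatches, validating every `validate_iter` iterations.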
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()
        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict, labels = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})
            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.preds], feed_dict=feed_dict)
            train_cost = outs[2]
            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                if FLAGS.validate_batch_size == -1:
                    val_cost, val_f1_mic, val_f1_mac, val_acc, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
                else:
                    val_cost, val_f1_mic, val_f1_mac, val_acc, duration = evaluate(sess, model, minibatch, FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)
            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)
            if total_steps % FLAGS.print_every == 0:
                train_f1_mic, train_f1_mac, train_accuracy = calc_f1(labels, outs[-1])
                print("Iter:", '%04d' % iter,
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_f1_mic=", "{:.5f}".format(train_f1_mic),
                      "train_f1_mac=", "{:.5f}".format(train_f1_mac),
                      "train_acc=", "{:.5f}".format(train_accuracy),
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_f1_mic=", "{:.5f}".format(val_f1_mic),
                      "val_f1_mac=", "{:.5f}".format(val_f1_mac),
                      "val_acc=", "{:.5f}".format(val_acc),
                      "time=", "{:.5f}".format(avg_time))
            iter += 1
            total_steps += 1
            if total_steps > FLAGS.max_total_steps:
                break
        if total_steps > FLAGS.max_total_steps:
            break
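    # Final evaluation: switch to the held-out adjacency and run full validation
    # and test passes, writing the stats to the log directory.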
print("Optimization Finished!")
sess.run(val_adj_info.op)
val_cost, val_f1_mic, val_f1_mac, val_acc, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size)
print("Full validation stats:",
"loss=", "{:.5f}".format(val_cost),
"f1_micro=", "{:.5f}".format(val_f1_mic),
"f1_macro=", "{:.5f}".format(val_f1_mac),
"accuracy=", "{:.5f}".format(val_acc),
"time=", "{:.5f}".format(duration))
with open(log_dir() + "val_stats.txt", "w") as fp:
fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f} time={:.5f}".
format(val_cost, val_f1_mic, val_f1_mac, duration))
print("Writing test set stats to file (don't peak!)")
test_cost, test_f1_mic, test_f1_mac, test_acc, duration = incremental_evaluate(sess, model, minibatch, FLAGS.batch_size, test=True)
print("Full test stats:",
"loss=", "{:.5f}".format(test_cost),
"f1_micro=", "{:.5f}".format(test_f1_mic),
"f1_macro=", "{:.5f}".format(test_f1_mac),
"accuracy=", "{:.5f}".format(test_acc),
"time=", "{:.5f}".format(duration))
with open(log_dir() + "test_stats.txt", "w") as fp:
fp.write("loss={:.5f} f1_micro={:.5f} f1_macro={:.5f}".
format(test_cost, test_f1_mic, test_f1_mac))
def main(argv=None):
    print("Loading training data..")
    train_data = load_data(FLAGS.train_prefix)
    print("Done loading training data..")
    train(train_data)

if __name__ == '__main__':
    tf.app.run()