Writing Model
In this section, we take DeepWalk and GraphSage as examples to show how to implement a graph embedding learning model using tf_euler and tensorflow.
First, import the required libraries:
import tensorflow as tf
import tf_euler
Following the encoder/decoder paradigm (arXiv:1709.05584), a graph embedding learning model can generally be divided into three steps:
- Generate samples from the node set;
- Encode the samples into embedding vectors;
- Decode the loss and evaluation metrics from the embedding vectors.
Next, we write the graph embedding models following this three-step paradigm.
Here we use tf_euler.layers as building blocks (you can also use tf.keras or other deep learning toolkits) to implement the graph embedding learning model as a tf_euler.layers.Layer, whose call function receives a tf.Tensor of input nodes and outputs a 4-tuple: the embeddings of the input nodes, the loss of the current mini-batch, the evaluation metric name of the model, and the evaluation score of the current mini-batch.
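Concretely, the models below all follow roughly this skeleton (a sketch only; sampler, encoder, and decoder are illustrative method names, and the real implementations organize these steps slightly differently):
class EmbeddingModel(tf_euler.layers.Layer):
  def call(self, inputs):
    # 1. generate samples (e.g. positive/negative nodes) from the input nodes
    samples = self.sampler(inputs)
    # 2. encode the samples into embedding vectors
    embeddings = self.encoder(samples)
    # 3. decode the loss and evaluation metric from the embedding vectors
    loss, metric = self.decoder(embeddings)
    return (self.encoder(inputs), loss, 'metric_name', metric)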
Putting DeepWalk into the above paradigm, its implementation can be divided into the following three steps:
- Generate positive nodes by random walks from the source node, and sample negative nodes;
- Embed the source node, positive nodes, and negative nodes into vectors;
- Calculate the cross-entropy loss and MRR from the embedding vectors of the source node, positive nodes, and negative nodes.
class DeepWalk(tf_euler.layers.Layer):
  def __init__(self, node_type, edge_type, max_id, dim,
               num_negs=8, walk_len=3, left_win_size=1, right_win_size=1):
    super(DeepWalk, self).__init__()
    self.node_type = node_type
    self.edge_type = edge_type
    self.max_id = max_id
    self.num_negs = num_negs
    self.walk_len = walk_len
    self.left_win_size = left_win_size
    self.right_win_size = right_win_size
    self.target_encoder = tf_euler.layers.Embedding(max_id + 1, dim)
    self.context_encoder = tf_euler.layers.Embedding(max_id + 1, dim)

  def call(self, inputs):
    src, pos, negs = self.sampler(inputs)     # <source, positive, negative> samples
    embedding = self.target_encoder(src)
    embedding_pos = self.context_encoder(pos)
    embedding_negs = self.context_encoder(negs)
    loss, mrr = self.decoder(embedding, embedding_pos, embedding_negs)
    embedding = self.target_encoder(inputs)   # embeddings of the original input nodes
    return (embedding, loss, 'mrr', mrr)

  def sampler(self, inputs):
    batch_size = tf.size(inputs)
    # Random-walk from the source nodes, then build <source, context> pairs
    # from each path; context nodes within the window serve as positives.
    path = tf_euler.random_walk(
        inputs, [self.edge_type] * self.walk_len,
        default_node=self.max_id + 1)
    pair = tf_euler.gen_pair(path, self.left_win_size, self.right_win_size)
    num_pairs = pair.shape[1]
    src, pos = tf.split(pair, [1, 1], axis=-1)
    # Sample num_negs negative nodes for every pair.
    negs = tf_euler.sample_node(batch_size * num_pairs * self.num_negs,
                                self.node_type)
    src = tf.reshape(src, [batch_size * num_pairs, 1])
    pos = tf.reshape(pos, [batch_size * num_pairs, 1])
    negs = tf.reshape(negs, [batch_size * num_pairs, self.num_negs])
    return src, pos, negs

  def decoder(self, embedding, embedding_pos, embedding_negs):
    # Dot-product logits against the positive and negative context embeddings.
    logits = tf.matmul(embedding, embedding_pos, transpose_b=True)
    neg_logits = tf.matmul(embedding, embedding_negs, transpose_b=True)
    true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(logits), logits=logits)
    negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(neg_logits), logits=neg_logits)
    loss = tf.reduce_sum(true_xent) + tf.reduce_sum(negative_xent)
    mrr = tf_euler.metrics.mrr_score(logits, neg_logits)
    return loss, mrr
tf_euler provides a series of operators (Euler-OPs) for accessing the Euler graph engine from within the TensorFlow computation graph. Here we use tf_euler.random_walk to generate paths following the configured edge types, tf_euler.gen_pair to generate <source node, positive node> pairs from the paths, and tf_euler.sample_node to sample negative nodes of the configured node type.
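As a rough sketch of the tensor shapes involved, based on how these ops are used above (illustrative, not an exact specification of the Euler-OPs):
# With a batch of B source nodes, walk_len = 3, left/right window size 1, num_negs = 8:
#   path = tf_euler.random_walk(...)                  -> [B, 4]: each source node plus 3 walked nodes
#   pair = tf_euler.gen_pair(path, 1, 1)              -> [B, num_pairs, 2]: <source, context> node pairs
#   negs = tf_euler.sample_node(B * num_pairs * 8, node_type) -> B * num_pairs * 8 sampled node ids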
We can use tf_euler.sample_node to sample nodes over the whole graph to form mini-batches for training:
tf_euler.initialize_embedded_graph('ppi')  # graph data directory

source = tf_euler.sample_node(128, tf_euler.ALL_NODE_TYPE)
source.set_shape([128])
model = DeepWalk(tf_euler.ALL_NODE_TYPE, [0, 1], 56944, 256)
_, loss, metric_name, metric = model(source)

global_step = tf.train.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss, global_step)
tf.logging.set_verbosity(tf.logging.INFO)
with tf.train.MonitoredTrainingSession(
    hooks=[
        tf.train.LoggingTensorHook({'step': global_step,
                                     'loss': loss, metric_name: metric}, 100),
        tf.train.StopAtStepHook(2000)
    ]) as sess:
  while not sess.should_stop():
    sess.run(train_op)
Running the above code produces output like the following:
INFO:tensorflow:loss = 4804.9565, mrr = 0.3264798, step = 1
INFO:tensorflow:loss = 4770.668, mrr = 0.39584208, step = 101 (0.765 sec)
INFO:tensorflow:loss = 4713.837, mrr = 0.37533116, step = 201 (0.676 sec)
INFO:tensorflow:loss = 4120.8774, mrr = 0.42687973, step = 301 (0.653 sec)
INFO:tensorflow:loss = 3288.204, mrr = 0.439512, step = 401 (0.674 sec)
INFO:tensorflow:loss = 2826.6309, mrr = 0.46083882, step = 501 (0.662 sec)
INFO:tensorflow:loss = 2562.7861, mrr = 0.5067806, step = 601 (0.656 sec)
INFO:tensorflow:loss = 2336.0562, mrr = 0.55503833, step = 701 (0.670 sec)
INFO:tensorflow:loss = 2101.0967, mrr = 0.6194568, step = 801 (0.664 sec)
INFO:tensorflow:loss = 1984.6118, mrr = 0.65155166, step = 901 (0.647 sec)
INFO:tensorflow:loss = 1855.1826, mrr = 0.6955864, step = 1001 (0.621 sec)
INFO:tensorflow:loss = 1680.2745, mrr = 0.74010307, step = 1101 (0.648 sec)
INFO:tensorflow:loss = 1525.5436, mrr = 0.7830129, step = 1201 (0.628 sec)
INFO:tensorflow:loss = 1325.8943, mrr = 0.84210175, step = 1301 (0.672 sec)
INFO:tensorflow:loss = 1274.5737, mrr = 0.85022587, step = 1401 (0.689 sec)
INFO:tensorflow:loss = 1153.6146, mrr = 0.8824446, step = 1501 (0.645 sec)
INFO:tensorflow:loss = 1144.9847, mrr = 0.88094825, step = 1601 (0.645 sec)
INFO:tensorflow:loss = 961.09924, mrr = 0.92628604, step = 1701 (0.616 sec)
INFO:tensorflow:loss = 940.64496, mrr = 0.91833764, step = 1801 (0.634 sec)
INFO:tensorflow:loss = 888.75397, mrr = 0.946753, step = 1901 (0.656 sec)
GraphSage is an improvement over GCN that can be used for supervised learning on labeled graphs. GraphSage samples the neighbors of each node and aggregates their features to obtain the node's embedding vector. Putting supervised GraphSage into the above paradigm, its implementation can be divided into the following three steps:
- Get the labels of the nodes;
- Perform multi-hop neighbor sampling for each node, use the nodes' features/attributes as the initial embedding vectors, and apply multiple layers of aggregation to obtain the final embedding vectors;
- Linearly classify the embedding vectors of the nodes to obtain the sigmoid cross-entropy loss and the F1 score.
Note that in each layer, GraphSage aggregates the intermediate embedding vectors of each node and its neighbors to generate the next layer's embedding vectors. Here is an example using a mean aggregator:
class MeanAggregator(tf_euler.layers.Layer):
  def __init__(self, dim, activation=tf.nn.relu):
    super(MeanAggregator, self).__init__()
    self.self_layer = tf_euler.layers.Dense(
        dim // 2, activation=activation, use_bias=False)
    self.neigh_layer = tf_euler.layers.Dense(
        dim // 2, activation=activation, use_bias=False)

  def call(self, inputs):
    self_embedding, neigh_embedding = inputs
    agg_embedding = tf.reduce_mean(neigh_embedding, axis=1)
    from_self = self.self_layer(self_embedding)
    from_neighs = self.neigh_layer(agg_embedding)
    return tf.concat([from_self, from_neighs], 1)
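For intuition, the aggregator takes a pair of (self embeddings, neighbor embeddings) and returns a vector whose two halves come from the node itself and from the mean of its neighbors. A minimal sketch with made-up shapes (not part of the tutorial):
# 5 nodes, each with 10 sampled neighbors and 32-dimensional input features.
self_emb = tf.zeros([5, 32])       # [batch_size, input_dim]
neigh_emb = tf.zeros([5, 10, 32])  # [batch_size, num_neighbors, input_dim]
agg = MeanAggregator(dim=256)
out = agg((self_emb, neigh_emb))   # [5, 256]: 128 dims from self + 128 dims from the neighbor mean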
We use tf_euler.sample_fanout to perform multi-hop neighbor sampling, use tf_euler.get_dense_feature to fetch the features of the nodes in each hop, and then iteratively call the MeanAggregator defined above for aggregation:
class SageEncoder(tf_euler.layers.Layer):
  def __init__(self, metapath, fanouts, dim, feature_idx, feature_dim):
    super(SageEncoder, self).__init__()
    self.metapath = metapath
    self.fanouts = fanouts
    self.num_layers = len(metapath)
    self.feature_idx = feature_idx
    self.feature_dim = feature_dim
    self.aggregators = []
    for layer in range(self.num_layers):
      activation = tf.nn.relu if layer < self.num_layers - 1 else None
      self.aggregators.append(MeanAggregator(dim, activation=activation))
    self.dims = [feature_dim] + [dim] * self.num_layers

  def call(self, inputs):
    # Sample multi-hop neighbors and fetch the dense features of every hop.
    samples = tf_euler.sample_fanout(inputs, self.metapath, self.fanouts)[0]
    hidden = [
        tf_euler.get_dense_feature(sample,
                                   [self.feature_idx], [self.feature_dim])[0]
        for sample in samples]
    # Aggregate layer by layer; each layer combines every hop with its next hop.
    for layer in range(self.num_layers):
      aggregator = self.aggregators[layer]
      next_hidden = []
      for hop in range(self.num_layers - layer):
        neigh_shape = [-1, self.fanouts[hop], self.dims[layer]]
        h = aggregator((hidden[hop], tf.reshape(hidden[hop + 1], neigh_shape)))
        next_hidden.append(h)
      hidden = next_hidden
    return hidden[0]
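The inner hop loop shrinks the hidden list by one entry per layer. With the two-layer configuration used in the training snippet below (fanouts [10, 10], feature_dim 50, dim 256) and a batch of B nodes, the shapes evolve roughly as follows:
# hidden after feature lookup : [B, 50], [B*10, 50], [B*100, 50]
# after aggregation layer 0   : [B, 256], [B*10, 256]
# after aggregation layer 1   : [B, 256]   <- returned as hidden[0]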
Finally, we use tf_euler.get_dense_feature to fetch the labels of the nodes from the graph and linearly classify the nodes using the final layer's embedding vectors:
class GraphSage(tf_euler.layers.Layer):
  def __init__(self, label_idx, label_dim,
               metapath, fanouts, dim, feature_idx, feature_dim):
    super(GraphSage, self).__init__()
    self.label_idx = label_idx
    self.label_dim = label_dim
    self.encoder = SageEncoder(metapath, fanouts, dim, feature_idx, feature_dim)
    self.predict_layer = tf_euler.layers.Dense(label_dim)

  def call(self, inputs):
    nodes, labels = self.sampler(inputs)
    embedding = self.encoder(nodes)
    loss, f1 = self.decoder(embedding, labels)
    return (embedding, loss, 'f1', f1)

  def sampler(self, inputs):
    labels = tf_euler.get_dense_feature(inputs, [self.label_idx],
                                        [self.label_dim])[0]
    return inputs, labels

  def decoder(self, embedding, labels):
    logits = self.predict_layer(embedding)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    predictions = tf.floor(tf.nn.sigmoid(logits) + 0.5)
    f1 = tf_euler.metrics.f1_score(labels, predictions)
    return tf.reduce_mean(loss), f1
We can train the model in the same way as in the DeepWalk section above:
tf_euler.initialize_embedded_graph('ppi')

source = tf_euler.sample_node(512, 0)
source.set_shape([512])
model = GraphSage(0, 121, [[0], [0]], [10, 10], 256, 1, 50)
_, loss, metric_name, metric = model(source)

global_step = tf.train.get_or_create_global_step()
train_op = tf.train.AdamOptimizer(0.01).minimize(loss, global_step)
tf.logging.set_verbosity(tf.logging.INFO)
with tf.train.MonitoredTrainingSession(
    hooks=[
        tf.train.LoggingTensorHook({'step': global_step,
                                     'loss': loss, metric_name: metric}, 100),
        tf.train.StopAtStepHook(2000)
    ]) as sess:
  while not sess.should_stop():
    sess.run(train_op)
Running the above code produces output like the following:
INFO:tensorflow:f1 = 0.3850271, loss = 0.69317585, step = 1
INFO:tensorflow:f1 = 0.42160043, loss = 0.5167424, step = 101 (4.987 sec)
INFO:tensorflow:f1 = 0.4489097, loss = 0.5023754, step = 201 (4.788 sec)
INFO:tensorflow:f1 = 0.4701608, loss = 0.49763823, step = 301 (4.866 sec)
INFO:tensorflow:f1 = 0.4902702, loss = 0.48410782, step = 401 (4.809 sec)
INFO:tensorflow:f1 = 0.5044798, loss = 0.4730545, step = 501 (4.851 sec)
INFO:tensorflow:f1 = 0.5104125, loss = 0.4705497, step = 601 (4.866 sec)
INFO:tensorflow:f1 = 0.51712954, loss = 0.47582737, step = 701 (4.844 sec)
INFO:tensorflow:f1 = 0.5240817, loss = 0.46666723, step = 801 (4.871 sec)
INFO:tensorflow:f1 = 0.53172356, loss = 0.45738563, step = 901 (4.837 sec)
INFO:tensorflow:f1 = 0.53270173, loss = 0.4746988, step = 1001 (4.802 sec)
INFO:tensorflow:f1 = 0.53611106, loss = 0.46039847, step = 1101 (4.882 sec)
INFO:tensorflow:f1 = 0.5402253, loss = 0.46644467, step = 1201 (4.808 sec)
INFO:tensorflow:f1 = 0.5420937, loss = 0.47356603, step = 1301 (4.820 sec)
INFO:tensorflow:f1 = 0.5462865, loss = 0.45834514, step = 1401 (4.872 sec)
INFO:tensorflow:f1 = 0.5511238, loss = 0.45826617, step = 1501 (4.848 sec)
INFO:tensorflow:f1 = 0.5543519, loss = 0.4414709, step = 1601 (4.865 sec)
INFO:tensorflow:f1 = 0.5557352, loss = 0.4589582, step = 1701 (4.836 sec)
INFO:tensorflow:f1 = 0.5591235, loss = 0.45354822, step = 1801 (4.869 sec)
INFO:tensorflow:f1 = 0.56102884, loss = 0.44353116, step = 1901 (4.885 sec)