-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
51 lines (40 loc) · 2.29 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import numpy as np
import copy
from tqdm import trange
from torch_geometric.data import DataLoader
from torch_geometric.utils import negative_sampling
def train(model, link_predictor, emb, edge_index, pos_train_edge, batch_size, optimizer):
    """
    Runs one epoch of offline training for the GNN, the link predictor, and the
    node embeddings, given the message edges and supervision edges.

    :param model: Torch graph model that updates node embeddings via message passing
    :param link_predictor: Torch model predicting the probability that an edge exists
    :param emb: (N, d) initial node embeddings for all N nodes in the graph
    :param edge_index: (2, E) edge index for all message-passing edges in the graph
    :param pos_train_edge: (PE, 2) positive edges used for the supervision loss
    :param batch_size: number of positive (and negative) supervision edges per batch
    :param optimizer: Torch optimizer updating all trainable parameters
    :return: Average supervision loss over all batches (0.0 if there are no
             positive training edges, avoiding a ZeroDivisionError)
    """
    model.train()
    link_predictor.train()

    train_losses = []
    # DataLoader over edge indices acts as a shuffled batch sampler of rows
    # of pos_train_edge.
    for edge_id in DataLoader(range(pos_train_edge.shape[0]), batch_size, shuffle=True):
        optimizer.zero_grad()

        # Re-run message passing each batch so gradients flow through the GNN.
        node_emb = model(emb, edge_index)  # (N, d)

        # Predict class probabilities on the batch of positive edges.
        pos_edge = pos_train_edge[edge_id].T  # (2, B)
        pos_pred = link_predictor(node_emb[pos_edge[0]], node_emb[pos_edge[1]])  # (B,)

        # Sample as many negative edges as positives. NOTE: negative_sampling
        # returns a (2, num_neg_samples) edge index (rows = endpoints), which
        # is why neg_edge[0] / neg_edge[1] below index source / target nodes.
        neg_edge = negative_sampling(edge_index, num_nodes=emb.shape[0],
                                     num_neg_samples=edge_id.shape[0], method='dense')  # (2, B)
        neg_pred = link_predictor(node_emb[neg_edge[0]], node_emb[neg_edge[1]])  # (B,)

        # Negative log-likelihood on positive and negative edges; the 1e-15
        # epsilon guards log(0) when a prediction saturates at 0 or 1.
        loss = -torch.log(pos_pred + 1e-15).mean() - torch.log(1 - neg_pred + 1e-15).mean()

        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    # No positive edges -> no batches; return 0.0 instead of dividing by zero.
    if not train_losses:
        return 0.0
    return sum(train_losses) / len(train_losses)