-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
90 lines (76 loc) · 3.72 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import argparse
import torch
from ogb.linkproppred import PygLinkPropPredDataset, Evaluator
from torch.optim import optimizer
import torch.optim as optim
from torch_geometric.data import DataLoader
from gnn_stack import GNNStack
from train import train
from link_predictor import LinkPredictor
from evaluate import test
from utils import print_and_log
def main():
    """Train a GNN-based link predictor on the ogbl-ddi dataset.

    Parses hyperparameters from the command line, builds the node embedding,
    GNN encoder and link-prediction head, trains for ``--epochs`` epochs,
    and every 10 epochs checkpoints all three components and evaluates on
    the official OGB split.  Progress is appended to a per-experiment log
    file whose directory name encodes the hyperparameters.
    """
    parser = argparse.ArgumentParser(
        description="Script to train link prediction in offline graph setting")
    parser.add_argument('--epochs', type=int, default=300,
                        help="Number of epochs for training")
    parser.add_argument('--lr', type=float, default=3e-3,
                        help="Learning rate training")
    parser.add_argument('--node_dim', type=int, default=256,
                        help='Embedding dimension for nodes')
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--batch_size', type=int, default=64 * 1024)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden_channels', type=int, default=256)
    parser.add_argument('--exp_dir', type=str, default=None,
                        help="Path to exp dir for model checkpoints and experiment logs")
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    optim_wd = 0  # weight decay passed to Adam
    epochs = args.epochs
    hidden_dim = args.hidden_channels
    dropout = args.dropout
    num_layers = args.num_layers
    lr = args.lr
    node_emb_dim = args.node_dim
    batch_size = args.batch_size

    exp_dir = args.exp_dir
    if exp_dir is None:
        exp_dir = "./experiments"
    # Encode the run's hyperparameters in the directory name so runs with
    # different settings never collide.  (Renamed from `dir`, which shadowed
    # the builtin of the same name.)
    run_name = f"offline.epochs:{epochs}.lr{lr}.layers:{num_layers}" \
               f".hidden_dim:{hidden_dim}.node_dim:{node_emb_dim}.init_batch_size:{batch_size}"
    exp_dir = os.path.join(exp_dir, run_name)
    model_dir = os.path.join(exp_dir, 'checkpoints')
    logs_dir = os.path.join(exp_dir, 'logs')
    os.makedirs(exp_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(logs_dir, exist_ok=True)

    logfile_path = os.path.join(logs_dir, 'log.txt')

    # Download and process data at './dataset/ogbl-ddi/'
    dataset = PygLinkPropPredDataset(name="ogbl-ddi", root='./dataset/')
    split_edge = dataset.get_edge_split()
    pos_train_edge = split_edge['train']['edge'].to(device)

    graph = dataset[0]
    edge_index = graph.edge_index.to(device)

    evaluator = Evaluator(name='ogbl-ddi')

    # Create embedding, model, and optimizer.  A single Adam instance
    # optimizes the GNN, the predictor head, and the free node embeddings.
    emb = torch.nn.Embedding(graph.num_nodes, node_emb_dim).to(device)
    model = GNNStack(node_emb_dim, hidden_dim, hidden_dim, num_layers,
                     dropout, emb=True).to(device)
    link_predictor = LinkPredictor(hidden_dim, hidden_dim, 1,
                                   num_layers + 1, dropout).to(device)
    optimizer = optim.Adam(
        list(model.parameters()) + list(link_predictor.parameters()) + list(emb.parameters()),
        lr=lr, weight_decay=optim_wd
    )

    # Mode "a" creates the file when it does not exist, so the previous
    # os.path.isfile() check was redundant.  Line buffering (buffering=1)
    # keeps the log current even if the process dies mid-run, and the
    # `with` block guarantees the handle is closed if training raises
    # (the original only closed it on a clean exit).
    with open(logfile_path, "a", buffering=1) as logfile:
        for e in range(epochs):
            loss = train(model, link_predictor, emb.weight, edge_index,
                         pos_train_edge, batch_size, optimizer)
            print_and_log(logfile, f"Epoch {e + 1}: loss: {round(loss, 5)}")

            if (e + 1) % 10 == 0:
                # Checkpoint all three trainable components, then evaluate.
                torch.save(model.state_dict(), os.path.join(model_dir, f"model_{e + 1}.pt"))
                torch.save(emb.state_dict(), os.path.join(model_dir, f"emb_{e + 1}.pt"))
                torch.save(link_predictor.state_dict(),
                           os.path.join(model_dir, f"link_pred_{e + 1}.pt"))

                result = test(model, link_predictor, emb.weight, edge_index,
                              split_edge, batch_size, evaluator)
                print_and_log(logfile, f"{result}")
# Entry point: run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()