train_CSL_graph_classification.py
"""
Utility functions for training one epoch
and evaluating one epoch
"""
import torch
import torch.nn as nn

from train.metrics import accuracy_TU as accuracy
"""
For GCNs
"""
def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
    model.train()
    epoch_loss = 0
    epoch_train_acc = 0
    nb_data = 0
    for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
        batch_x = batch_graphs.ndata['feat'].to(device)  # num nodes x feat dim
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        try:
            # Randomly flip the sign of each positional-encoding dimension:
            # Laplacian eigenvectors are defined only up to sign, so this
            # augmentation stops the model from overfitting to one sign choice.
            batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
            sign_flip = torch.rand(batch_pos_enc.size(1), device=device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            batch_pos_enc = batch_pos_enc * sign_flip.unsqueeze(0)
            batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_pos_enc)
        except KeyError:
            # No 'pos_enc' field: the model runs without positional encodings.
            batch_scores = model.forward(batch_graphs, batch_x, batch_e)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        epoch_train_acc += accuracy(batch_scores, batch_labels)
        nb_data += batch_labels.size(0)
    epoch_loss /= (iter + 1)
    epoch_train_acc /= nb_data
    return epoch_loss, epoch_train_acc, optimizer
def evaluate_network_sparse(model, device, data_loader, epoch):
    model.eval()
    epoch_test_loss = 0
    epoch_test_acc = 0
    nb_data = 0
    with torch.no_grad():
        for iter, (batch_graphs, batch_labels) in enumerate(data_loader):
            batch_x = batch_graphs.ndata['feat'].to(device)
            batch_e = batch_graphs.edata['feat'].to(device)
            batch_labels = batch_labels.to(device)
            try:
                # No sign flipping at evaluation time: use the encodings as-is.
                batch_pos_enc = batch_graphs.ndata['pos_enc'].to(device)
                batch_scores = model.forward(batch_graphs, batch_x, batch_e, batch_pos_enc)
            except KeyError:
                batch_scores = model.forward(batch_graphs, batch_x, batch_e)
            loss = model.loss(batch_scores, batch_labels)
            epoch_test_loss += loss.detach().item()
            epoch_test_acc += accuracy(batch_scores, batch_labels)
            nb_data += batch_labels.size(0)
        epoch_test_loss /= (iter + 1)
        epoch_test_acc /= nb_data
    return epoch_test_loss, epoch_test_acc
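
# A minimal usage sketch, not part of the original benchmark pipeline: it
# shows how the two sparse helpers above compose into a training loop. The
# names `model`, `optimizer`, `train_loader` and `val_loader` are assumptions
# here; in the benchmark they are built by the main scripts from the config.
def example_train_sparse(model, optimizer, device, train_loader, val_loader,
                         max_epochs=100):
    for epoch in range(max_epochs):
        train_loss, train_acc, optimizer = train_epoch_sparse(
            model, optimizer, device, train_loader, epoch)
        val_loss, val_acc = evaluate_network_sparse(model, device, val_loader, epoch)
        print(f"epoch {epoch:03d} | train loss {train_loss:.4f} acc {train_acc:.4f}"
              f" | val loss {val_loss:.4f} acc {val_acc:.4f}")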
"""
For WL-GNNs
"""
def train_epoch_dense(model, optimizer, device, data_loader, epoch, batch_size):
    model.train()
    epoch_loss = 0
    epoch_train_acc = 0
    nb_data = 0
    optimizer.zero_grad()
    for iter, (x_with_node_feat, labels) in enumerate(data_loader):
        x_with_node_feat = x_with_node_feat.to(device)
        labels = labels.to(device)
        scores = model.forward(x_with_node_feat)
        loss = model.loss(scores, labels)
        loss.backward()
        # The dense loader yields one graph at a time, so gradients are
        # accumulated and the optimizer is stepped once per batch_size graphs.
        if not (iter % batch_size):
            optimizer.step()
            optimizer.zero_grad()
        epoch_loss += loss.detach().item()
        epoch_train_acc += accuracy(scores, labels)
        nb_data += labels.size(0)
    epoch_loss /= (iter + 1)
    epoch_train_acc /= nb_data
    return epoch_loss, epoch_train_acc, optimizer
def evaluate_network_dense(model, device, data_loader, epoch):
    model.eval()
    epoch_test_loss = 0
    epoch_test_acc = 0
    nb_data = 0
    with torch.no_grad():
        for iter, (x_with_node_feat, labels) in enumerate(data_loader):
            x_with_node_feat = x_with_node_feat.to(device)
            labels = labels.to(device)
            scores = model.forward(x_with_node_feat)
            loss = model.loss(scores, labels)
            epoch_test_loss += loss.detach().item()
            epoch_test_acc += accuracy(scores, labels)
            nb_data += labels.size(0)
        epoch_test_loss /= (iter + 1)
        epoch_test_acc /= nb_data
    return epoch_test_loss, epoch_test_acc
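
# A usage sketch mirroring the sparse loop above; the loader names are
# assumptions. WL-GNN models consume dense tensors one graph at a time, so
# the physical loader batch size is 1 and `batch_size` acts as the
# gradient-accumulation factor that sets the effective batch size.
def example_train_dense(model, optimizer, device, train_loader, val_loader,
                        epoch=0, batch_size=32):
    train_loss, train_acc, optimizer = train_epoch_dense(
        model, optimizer, device, train_loader, epoch, batch_size)
    val_loss, val_acc = evaluate_network_dense(model, device, val_loader, epoch)
    return train_loss, train_acc, val_loss, val_acc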
def check_patience(all_losses, best_loss, best_epoch, curr_loss, curr_epoch, counter):
    # Early-stopping bookkeeping: reset the patience counter on improvement,
    # otherwise increment it. `all_losses` is accepted for interface
    # compatibility but is not used here.
    if curr_loss < best_loss:
        counter = 0
        best_loss = curr_loss
        best_epoch = curr_epoch
    else:
        counter += 1
    return best_loss, best_epoch, counter
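
# A minimal standalone example (made-up loss values, for illustration only)
# of how check_patience drives early stopping across epochs.
if __name__ == '__main__':
    val_losses = [0.9, 0.7, 0.71, 0.72, 0.73]  # hypothetical validation losses
    best_loss, best_epoch, counter = float('inf'), 0, 0
    for epoch, vl in enumerate(val_losses):
        best_loss, best_epoch, counter = check_patience(
            val_losses, best_loss, best_epoch, vl, epoch, counter)
    print(best_loss, best_epoch, counter)  # -> 0.7 1 3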