forked from graphdeeplearning/benchmarking-gnns
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_superpixels_graph_classification.py
111 lines (92 loc) · 3.45 KB
/
train_superpixels_graph_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
Utility functions for training one epoch
and evaluating one epoch
"""
import torch
import torch.nn as nn
import math
from train.metrics import accuracy_MNIST_CIFAR as accuracy
"""
For GCNs
"""
def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
    """Train ``model`` for one pass over ``data_loader`` (graph-batch inputs).

    Each item of ``data_loader`` is a ``(batch_graphs, batch_labels)`` pair.
    Node and edge features are read from ``batch_graphs.ndata['feat']`` and
    ``batch_graphs.edata['feat']``, moved to ``device`` and handed to
    ``model.forward`` together with the graph object. One optimizer step is
    taken per batch.

    Args:
        model: Object exposing ``train()``, ``forward(graphs, x, e)`` and
            ``loss(scores, labels)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        device: Device the feature and label tensors are moved to.
        data_loader: Iterable of ``(batch_graphs, batch_labels)`` pairs.
        epoch: Current epoch index; not used by this function.

    Returns:
        Tuple ``(epoch_loss, epoch_train_acc, optimizer)``: the loss averaged
        over batches, the sum of the per-batch ``accuracy(...)`` values
        divided by the number of labels seen, and the optimizer passed in.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.train()
    epoch_loss = 0
    epoch_train_acc = 0
    nb_data = 0
    num_batches = 0
    for batch_graphs, batch_labels in data_loader:
        # NOTE(review): only the feature/label tensors are moved to `device`;
        # the graph object itself is passed through as-is -- confirm the
        # model / graph library in use is fine with that.
        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_scores = model.forward(batch_graphs, batch_x, batch_e)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_train_acc += accuracy(batch_scores, batch_labels)
        nb_data += batch_labels.size(0)
        num_batches += 1

    # An explicit counter (instead of the leaked loop index) lets an empty
    # loader fail with a clear message rather than an unbound-variable error.
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    epoch_loss /= num_batches
    epoch_train_acc /= nb_data
    return epoch_loss, epoch_train_acc, optimizer
def evaluate_network_sparse(model, device, data_loader, epoch):
    """Evaluate ``model`` over ``data_loader`` (graph-batch inputs).

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``. Each item of ``data_loader`` is a
    ``(batch_graphs, batch_labels)`` pair; node and edge features are read
    from ``batch_graphs.ndata['feat']`` and ``batch_graphs.edata['feat']``.

    Args:
        model: Object exposing ``eval()``, ``forward(graphs, x, e)`` and
            ``loss(scores, labels)``.
        device: Device the feature and label tensors are moved to.
        data_loader: Iterable of ``(batch_graphs, batch_labels)`` pairs.
        epoch: Current epoch index; not used by this function.

    Returns:
        Tuple ``(epoch_test_loss, epoch_test_acc)``: the loss averaged over
        batches and the sum of the per-batch ``accuracy(...)`` values divided
        by the number of labels seen.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    epoch_test_loss = 0
    epoch_test_acc = 0
    nb_data = 0
    num_batches = 0
    with torch.no_grad():
        for batch_graphs, batch_labels in data_loader:
            # NOTE(review): the graph object itself is not moved to `device`;
            # confirm the model / graph library in use is fine with that.
            batch_x = batch_graphs.ndata['feat'].to(device)
            batch_e = batch_graphs.edata['feat'].to(device)
            batch_labels = batch_labels.to(device)

            batch_scores = model.forward(batch_graphs, batch_x, batch_e)
            loss = model.loss(batch_scores, batch_labels)

            epoch_test_loss += loss.item()
            epoch_test_acc += accuracy(batch_scores, batch_labels)
            nb_data += batch_labels.size(0)
            num_batches += 1

    # Fail clearly on an empty loader instead of tripping over an unbound
    # loop variable in the averaging step.
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    epoch_test_loss /= num_batches
    epoch_test_acc /= nb_data
    return epoch_test_loss, epoch_test_acc
"""
For WL-GNNs
"""
def train_epoch_dense(model, optimizer, device, data_loader, epoch, batch_size):
    """Train ``model`` for one pass over ``data_loader`` with gradient accumulation.

    Each item of ``data_loader`` is an ``(x_with_node_feat, labels)`` pair
    that is moved to ``device`` and fed to ``model.forward``. Gradients are
    accumulated across loader iterations and the optimizer steps once every
    ``batch_size`` iterations. Any remainder left at the end of the pass is
    applied in a final step so no computed gradient is thrown away.

    Args:
        model: Object exposing ``train()``, ``forward(x)`` and
            ``loss(scores, labels)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        device: Device the input and label tensors are moved to.
        data_loader: Iterable of ``(x_with_node_feat, labels)`` pairs.
        epoch: Current epoch index; not used by this function.
        batch_size: Number of loader iterations whose gradients are
            accumulated before each optimizer step.

    Returns:
        Tuple ``(epoch_loss, epoch_train_acc, optimizer)``: the loss averaged
        over loader iterations, the sum of the per-iteration
        ``accuracy(...)`` values divided by the number of labels seen, and
        the optimizer passed in.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.train()
    epoch_loss = 0
    epoch_train_acc = 0
    nb_data = 0
    num_batches = 0
    optimizer.zero_grad()
    for x_with_node_feat, labels in data_loader:
        x_with_node_feat = x_with_node_feat.to(device)
        labels = labels.to(device)

        scores = model.forward(x_with_node_feat)
        loss = model.loss(scores, labels)
        loss.backward()  # gradients add up until the next zero_grad()
        num_batches += 1

        # Step once a full group of `batch_size` iterations has been
        # accumulated. Testing the 1-based count avoids stepping right after
        # the very first sample.
        if num_batches % batch_size == 0:
            optimizer.step()
            optimizer.zero_grad()

        epoch_loss += loss.item()
        epoch_train_acc += accuracy(scores, labels)
        nb_data += labels.size(0)

    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    # Apply the gradients of a trailing, incomplete accumulation group;
    # otherwise they would be discarded by the zero_grad() at the start of
    # the next call.
    if num_batches % batch_size != 0:
        optimizer.step()
        optimizer.zero_grad()

    epoch_loss /= num_batches
    epoch_train_acc /= nb_data
    return epoch_loss, epoch_train_acc, optimizer
def evaluate_network_dense(model, device, data_loader, epoch):
    """Evaluate ``model`` over ``data_loader`` (dense tensor inputs).

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``. Each item of ``data_loader`` is an
    ``(x_with_node_feat, labels)`` pair that is moved to ``device`` and fed
    to ``model.forward``.

    Args:
        model: Object exposing ``eval()``, ``forward(x)`` and
            ``loss(scores, labels)``.
        device: Device the input and label tensors are moved to.
        data_loader: Iterable of ``(x_with_node_feat, labels)`` pairs.
        epoch: Current epoch index; not used by this function.

    Returns:
        Tuple ``(epoch_test_loss, epoch_test_acc)``: the loss averaged over
        loader iterations and the sum of the per-iteration ``accuracy(...)``
        values divided by the number of labels seen.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()
    epoch_test_loss = 0
    epoch_test_acc = 0
    nb_data = 0
    num_batches = 0
    with torch.no_grad():
        for x_with_node_feat, labels in data_loader:
            x_with_node_feat = x_with_node_feat.to(device)
            labels = labels.to(device)

            scores = model.forward(x_with_node_feat)
            loss = model.loss(scores, labels)

            epoch_test_loss += loss.item()
            epoch_test_acc += accuracy(scores, labels)
            nb_data += labels.size(0)
            num_batches += 1

    # Fail clearly on an empty loader instead of tripping over an unbound
    # loop variable in the averaging step.
    if num_batches == 0:
        raise ValueError("data_loader yielded no batches")

    epoch_test_loss /= num_batches
    epoch_test_acc /= nb_data
    return epoch_test_loss, epoch_test_acc