forked from GitEventhandler/H2GCN-PyTorch
train.py
import os
import time
import argparse
import torch
import torch.nn.functional as F
import torch.optim as optim
import utils
from utils import accuracy, set_seed, load_data
from model import H2GCN
# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0, help='seed')
parser.add_argument('--epochs', type=int, default=500, help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--k', type=int, default=2, help='number of embedding rounds')
parser.add_argument('--wd', type=float, default=5e-4, help='weight decay value')
parser.add_argument('--hidden', type=int, default=64, help='embedding output dim')
parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
parser.add_argument('--patience', type=int, default=50, help='patience for early stop')
parser.add_argument('--dataset', default='cora', help='dataset name')
parser.add_argument('--gpu', type=int, default=0, help='gpu id to use while training, set -1 to use cpu')
parser.add_argument('--split', type=str, default="DEFAULT", help='data split to use')
args = parser.parse_args()
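
# Example invocation (defaults are shown in the parser above):
#   python train.py --dataset cora --gpu 0 --epochs 500 --lr 0.01
# Pass --gpu -1 to train on CPU.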


def train():
    model.train()
    optimizer.zero_grad()
    output = model(adj, features)
    acc_train = accuracy(output[idx_train], labels[idx_train].to(device))
    loss_train = F.nll_loss(output[idx_train], labels[idx_train].to(device))
    loss_train.backward()
    optimizer.step()
    return loss_train.item(), acc_train.item()


def validate():
    model.eval()
    with torch.no_grad():
        output = model(adj, features)
        loss_val = F.nll_loss(output[idx_val], labels[idx_val].to(device))
        acc_val = accuracy(output[idx_val], labels[idx_val].to(device))
        return loss_val.item(), acc_val.item()


def test():
    # evaluate the best checkpoint (lowest validation loss) saved during training
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    with torch.no_grad():
        output = model(adj, features)
        loss_test = F.nll_loss(output[idx_test], labels[idx_test].to(device))
        acc_test = accuracy(output[idx_test], labels[idx_test].to(device))
        return loss_test.item(), acc_test.item()


def main():
    begin_time = time.time()
    tolerate = 0
    best_loss = 1000
    acc = 0
    for epoch in range(args.epochs):
        loss_train, acc_train = train()
        loss_validate, acc_validate = validate()
        if (epoch + 1) % 1 == 0:  # log every epoch
            print(
                'Epoch {:03d}'.format(epoch + 1),
                '|| train',
                'loss : {:.3f}'.format(loss_train),
                ', accuracy : {:.2f}%'.format(acc_train * 100),
                '|| val',
                'loss : {:.3f}'.format(loss_validate),
                ', accuracy : {:.2f}%'.format(acc_validate * 100)
            )
        # early stopping: keep the checkpoint with the lowest validation loss seen so far
        if loss_validate < best_loss:
            best_loss = loss_validate
            acc = acc_validate
            torch.save(model.state_dict(), checkpoint_path)
            tolerate = 0
        else:
            tolerate += 1
            if tolerate == args.patience:
                break
    print("Train cost : {:.2f}s".format(time.time() - begin_time))
    print("Test accuracy : {:.2f}%".format(test()[1] * 100), "on dataset", args.dataset)


if __name__ == '__main__':
    set_seed(args.seed)
    device = torch.device('cpu' if args.gpu == -1 else "cuda:%s" % args.gpu)
    split_path = utils.root + '/splits/' + args.dataset + '_split_0.6_0.2_0.npz' if args.split == 'DEFAULT' else args.split
    # load dataset
    adj, features, labels, idx_train, idx_val, idx_test, feat_dim, class_dim = load_data(
        args.dataset,
        split_path
    )
    adj = adj.to(device)
    features = features.to(device)
    checkpoint_path = utils.root + '/ckpt/%s.pt' % args.dataset
    if not os.path.exists(utils.root + '/ckpt'):
        os.makedirs(utils.root + '/ckpt')
    # note: --k and --dropout are parsed above but not forwarded here;
    # pass them through as well if this fork's H2GCN constructor accepts them
    model = H2GCN(
        feat_dim=feat_dim,
        hidden_dim=args.hidden,
        class_dim=class_dim,
    ).to(device)
    optimizer = optim.Adam([{'params': model.params, 'weight_decay': args.wd}], lr=args.lr)
    main()
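
# Expected layout under utils.root (inferred from the paths used above):
#   splits/<dataset>_split_0.6_0.2_0.npz  -- default train/val/test split when --split is DEFAULT
#   ckpt/<dataset>.pt                     -- best-validation checkpoint written during training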