-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
120 lines (87 loc) · 3.06 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
import random
import os
import math
import numpy as np
import pandas as pd
from torch_geometric.data import Data, Batch
from sklearn.metrics import roc_auc_score, average_precision_score
def init_params(module):
    """Re-initialise an ``nn.Linear`` layer in place; ignore every other module.

    The weight matrix receives Xavier/Glorot-uniform values and the bias, when
    the layer has one, is set to zero. Because non-Linear modules are skipped,
    this is presumably meant to be used as ``model.apply(init_params)``.

    Args:
        module: any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(module, nn.Linear):
        return
    # nn.init functions already run under torch.no_grad(), so the parameters
    # can be passed directly without going through ``.data``.
    nn.init.xavier_uniform_(module.weight)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN settings.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and every
            PyTorch generator; also exported as ``PYTHONHASHSEED``.
    """
    # PYTHONHASHSEED is only read at interpreter start-up, so this affects
    # child interpreters launched afterwards, not the current process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # torch.use_deterministic_algorithms(True) is not enabled here; turning it
    # on makes ops lacking a deterministic implementation raise at runtime.
    torch.manual_seed(seed)
    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def split(y, train_ratio=0.1, val_ratio=0.1):
    """Randomly partition the indices ``0 .. len(y) - 1`` into train/val/test sets.

    The permutation is drawn from NumPy's global RNG via ``np.random.shuffle``,
    so seed NumPy beforehand (e.g. with ``seed_everything``) for a reproducible
    split.

    Args:
        y: any sized sequence; only ``len(y)`` is used.
        train_ratio: fraction of samples placed in the training split
            (default 0.1).
        val_ratio: fraction of samples placed in the validation split
            (default 0.1). All remaining samples (0.8 by default) form the
            test split.

    Returns:
        ``(train_idx, val_idx, test_idx)``: 1-D integer tensors that are
        disjoint and together cover every index exactly once.

    Raises:
        ValueError: if either ratio is negative or their sum exceeds 1.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got {train_ratio} and {val_ratio}"
        )
    num_samples = len(y)
    train_end = int(num_samples * train_ratio)
    # End offset of the validation slice, floored from the cumulative ratio.
    val_end = int(num_samples * (train_ratio + val_ratio))

    perm = np.arange(num_samples)
    np.random.shuffle(perm)

    # torch.tensor copies, so the three tensors do not share NumPy memory.
    train_idx = torch.tensor(perm[:train_end])
    val_idx = torch.tensor(perm[train_end:val_end])
    test_idx = torch.tensor(perm[val_end:])
    return train_idx, val_idx, test_idx
def collate_basis(graphs, period):
    """Batch graphs, attaching a Fourier encoding of ``g.e`` projected by ``g.u`` as ``pos``.

    For each graph ``g``:

    * ``g.e`` is expanded into ``[sin(k * e), cos(k * e)]`` for
      ``k = 0 .. period - 1`` (assumes ``g.e`` is 1-D with ``num_nodes``
      entries, presumably eigenvalues -- TODO confirm);
    * ``g.u`` is viewed as a ``num_nodes x num_nodes`` matrix and multiplied
      with that expansion, giving a ``(num_nodes, 2 * period)`` tensor;
    * a fresh ``Data`` object receives ``x``, ``y``, ``edge_index`` and
      ``edge_attr`` from ``g`` plus the product as ``pos``.

    Args:
        graphs: iterable of graph objects exposing ``num_nodes``, ``e``,
            ``u``, ``x``, ``y``, ``edge_index`` and ``edge_attr``.
        period: number of frequencies ``k`` in the encoding.

    Returns:
        A ``Batch`` assembled from the new ``Data`` objects.
    """
    batch_items = []
    for graph in graphs:
        n = graph.num_nodes
        basis = graph.u.view(n, n)
        # Frequencies 0 .. period-1; k = 0 yields a constant sin=0 / cos=1 pair.
        freqs = torch.arange(period, device=basis.device, dtype=torch.float32)
        angles = graph.e.unsqueeze(1) * freqs
        encoding = torch.cat((angles.sin(), angles.cos()), dim=-1)

        item = Data()
        item.num_nodes = n
        item.x, item.y = graph.x, graph.y
        item.pos = basis @ encoding
        item.edge_index, item.edge_attr = graph.edge_index, graph.edge_attr
        batch_items.append(item)
    return Batch.from_data_list(batch_items)
def collate_basis_sign(graphs, period):
    """Batch graphs into a spatial batch and a companion spectral batch.

    For each input graph ``g`` two ``Data`` objects are built:

    * spatial: a copy of ``g``'s ``num_nodes``, ``x``, ``y``, ``edge_index``
      and ``edge_attr``;
    * spectral: node features ``[sin(k * e), cos(k * e)]`` for
      ``k = 0 .. period - 1`` computed from ``g.e`` (assumed 1-D, presumably
      eigenvalues -- TODO confirm), plus one edge per non-zero entry of
      ``g.u`` viewed as a ``num_nodes x num_nodes`` matrix, weighted by that
      entry.

    Args:
        graphs: iterable of graph objects exposing ``num_nodes``, ``e``,
            ``u``, ``x``, ``y``, ``edge_index`` and ``edge_attr``.
        period: number of frequencies ``k`` in the spectral node features.

    Returns:
        ``(spatial_batch, spectral_batch)``, two ``Batch`` objects in the same
        graph order.
    """
    spatial, spectral = [], []
    for graph in graphs:
        n = graph.num_nodes

        spatial_item = Data()
        spatial_item.num_nodes = n
        spatial_item.x = graph.x
        spatial_item.y = graph.y
        spatial_item.edge_index = graph.edge_index
        spatial_item.edge_attr = graph.edge_attr

        eigvals = graph.e
        freqs = torch.arange(period, device=eigvals.device, dtype=torch.float32)
        angles = eigvals.unsqueeze(1) * freqs
        node_features = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

        basis = graph.u.view(n, n)
        dst, src = basis.nonzero().t()

        spectral_item = Data()
        spectral_item.num_nodes = n
        spectral_item.x = node_features
        # A non-zero basis[i, j] becomes edge j -> i (edge_index[0] is the
        # source) carrying weight basis[i, j]. Under PyG's default
        # source_to_target flow, an edge-weighted sum aggregation at node i
        # therefore reproduces (basis @ x)[i].
        spectral_item.edge_index = torch.stack([src, dst], dim=0)
        spectral_item.edge_attr = basis[dst, src]

        spatial.append(spatial_item)
        spectral.append(spectral_item)
    return Batch.from_data_list(spatial), Batch.from_data_list(spectral)