-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
139 lines (108 loc) · 4.39 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from torch_geometric.utils import to_networkx, degree, to_scipy_sparse_matrix
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from scipy import sparse as sp
def convert_to_nodeDegreeFeatures(graphs):
    """Return copies of ``graphs`` whose node features are one-hot node degrees.

    Each graph is cloned (the inputs are not modified) and its ``x`` is
    replaced by ``F.one_hot(deg, maxdegree + 1)`` cast to float. Here ``deg``
    counts how often each node appears in ``edge_index[0]``. For graphs that
    store every undirected edge in both directions this is the node degree
    (assumption; verify for your datasets).

    All graphs share one one-hot width, ``maxdegree + 1``. ``maxdegree`` is
    the largest degree over all graphs, taking both of these into account:

    * the undirected networkx degree (the original sizing rule, kept so the
      feature width is unchanged for inputs that worked before), and
    * the ``edge_index``-based degree actually being encoded. This can exceed
      the networkx value when ``edge_index`` holds duplicate or one-directional
      edges; previously such inputs made ``F.one_hot`` raise.

    Args:
        graphs: iterable of PyG ``Data`` objects; it is consumed only once.

    Returns:
        list of cloned graphs with one-hot degree features in ``x``.
    """
    per_graph = []
    maxdegree = 0
    for graph in graphs:
        nx_graph = to_networkx(graph, to_undirected=True)
        # ``default=0`` keeps graphs without nodes from raising ValueError.
        nx_max = max((d for _, d in nx_graph.degree), default=0)
        deg = degree(graph.edge_index[0], graph.num_nodes, dtype=torch.long)
        deg_max = int(deg.max()) if deg.numel() > 0 else 0
        maxdegree = max(maxdegree, nx_max, deg_max)
        # Keep the graph itself so a one-shot iterable is not read twice.
        per_graph.append((graph, deg))

    new_graphs = []
    for graph, deg in per_graph:
        new_graph = graph.clone()
        new_graph['x'] = F.one_hot(deg, num_classes=maxdegree + 1).to(torch.float)
        new_graphs.append(new_graph)
    return new_graphs
def get_maxDegree(graphs):
    """Return the largest node degree found in ``graphs``.

    For each graph the maximum of two values is considered:

    * the undirected networkx degree (``to_networkx(..., to_undirected=True)``),
      which was the sole criterion previously, and
    * the count of each node in ``edge_index[0]``, which is the degree that
      ``convert_to_nodeDegreeFeatures`` one-hot encodes.

    Including the second value means ``get_maxDegree(graphs) + 1`` is always a
    valid one-hot width for those features. The result only differs from the
    networkx-only value when ``edge_index`` contains duplicate or
    one-directional edges.

    Graphs without nodes contribute 0 instead of raising ``ValueError``.
    """
    maxdegree = 0
    for graph in graphs:
        nx_graph = to_networkx(graph, to_undirected=True)
        nx_max = max((d for _, d in nx_graph.degree), default=0)
        deg = degree(graph.edge_index[0], graph.num_nodes, dtype=torch.long)
        deg_max = int(deg.max()) if deg.numel() > 0 else 0
        maxdegree = max(maxdegree, nx_max, deg_max)
    return maxdegree
def split_data(graphs, train=None, test=None, shuffle=True, seed=None):
    """Split ``graphs`` into a train/validation pool and a test set.

    Thin wrapper around ``sklearn.model_selection.train_test_split``. ``train``
    and ``test`` are forwarded as ``train_size`` / ``test_size`` and ``seed``
    as ``random_state``.

    When ``shuffle`` is True the split is stratified by the concatenated graph
    labels. Each ``graph.y`` is assumed to hold one label (TODO confirm for
    multi-label data). sklearn does not support stratification without
    shuffling and raises ``ValueError`` in that case, so for
    ``shuffle=False`` the split is done unstratified.

    Returns:
        ``(graphs_tv, graphs_test)``.
    """
    stratify = None
    if shuffle:
        # .cpu() so sklearn's numpy conversion also works for CUDA labels.
        stratify = torch.cat([graph.y for graph in graphs]).cpu()
    graphs_tv, graphs_test = train_test_split(
        graphs,
        train_size=train,
        test_size=test,
        stratify=stratify,
        shuffle=shuffle,
        random_state=seed,
    )
    return graphs_tv, graphs_test
def get_numGraphLabels(dataset):
    """Return how many distinct graph labels occur in ``dataset``.

    Each element's ``y`` is reduced with ``.item()``, so it must contain
    exactly one value.
    """
    distinct_labels = {graph.y.item() for graph in dataset}
    return len(distinct_labels)
def _get_avg_nodes_edges(graphs):
numNodes = 0.
numEdges = 0.
numGraphs = len(graphs)
for g in graphs:
numNodes += g.num_nodes
numEdges += g.num_edges / 2. # undirected
return numNodes/numGraphs, numEdges/numGraphs
def get_stats(df, ds, graphs_train, graphs_val=None, graphs_test=None):
    """Write per-split size statistics into row ``ds`` of ``df``.

    For each recorded split ``<s>`` three cells are set via ``df.loc``:
    ``#graphs_<s>`` (number of graphs), ``avgNodes_<s>`` and ``avgEdges_<s>``
    (averages from ``_get_avg_nodes_edges``). The train split is always
    recorded. The val and test splits are recorded only when truthy.
    ``df`` is presumably a pandas DataFrame (it is accessed via ``.loc``).

    Returns:
        the same ``df`` object, modified in place.
    """
    def _record(split_name, split_graphs):
        # The graph count is stored before the averages are computed.
        df.loc[ds, f'#graphs_{split_name}'] = len(split_graphs)
        mean_nodes, mean_edges = _get_avg_nodes_edges(split_graphs)
        df.loc[ds, f'avgNodes_{split_name}'] = mean_nodes
        df.loc[ds, f'avgEdges_{split_name}'] = mean_edges

    _record('train', graphs_train)
    for split_name, split_graphs in (('val', graphs_val), ('test', graphs_test)):
        if split_graphs:
            _record(split_name, split_graphs)
    return df
def _random_walk_encoding(g, n_rw):
    """Return the random-walk structural encoding of graph ``g``.

    With ``A`` the adjacency matrix built from ``g.edge_index`` and ``D`` the
    diagonal of ``edge_index[0]`` counts, ``M = A @ D^-1``. Column ``k`` of the
    result is the diagonal of ``M ** (k + 1)`` for ``k`` in
    ``0 .. max(n_rw, 1) - 1``. The result has shape
    ``(num_nodes, max(n_rw, 1))`` and dtype float32.
    """
    num_nodes = g.num_nodes
    adj = to_scipy_sparse_matrix(g.edge_index, num_nodes=num_nodes)
    deg_inv = degree(g.edge_index[0], num_nodes=num_nodes) ** -1.0
    # Nodes with no entries in edge_index[0] get 1/0 = inf. Zero those so an
    # inf can never reach the product (it could when such a node still has
    # incoming edges).
    deg_inv.masked_fill_(torch.isinf(deg_inv), 0.0)
    # '@' is matrix multiplication for both scipy sparse matrices and sparse
    # arrays, whereas '*' is element-wise for the latter.
    walk = adj @ sp.diags(deg_inv.numpy())
    diagonals = [torch.from_numpy(walk.diagonal()).float()]
    walk_power = walk
    for _ in range(n_rw - 1):
        walk_power = walk_power @ walk
        diagonals.append(torch.from_numpy(walk_power.diagonal()).float())
    return torch.stack(diagonals, dim=-1)


def _degree_encoding(g, n_dg):
    """Return a ``(num_nodes, n_dg)`` one-hot encoding of node degree in ``g``.

    The degree is the count of each node in ``edge_index[0]``, clamped to
    ``[1, n_dg]``. Degree ``d`` sets column ``d - 1``, so degree 0 shares
    column 0 with degree 1, and every degree ``>= n_dg`` lands in the last
    column.
    """
    deg = degree(g.edge_index[0], num_nodes=g.num_nodes).clamp(1, n_dg)
    cols = deg.long() - 1
    encoding = torch.zeros([g.num_nodes, n_dg])
    encoding[torch.arange(cols.numel()), cols] = 1
    return encoding


def init_structure_encoding(args, gs, type_init):
    """Attach a structural encoding to every graph in ``gs`` as ``g['stc_enc']``.

    Args:
        args: object providing ``n_rw`` (number of random-walk steps) and/or
            ``n_dg`` (number of degree buckets), depending on ``type_init``.
        gs: iterable of PyG ``Data`` objects; each is modified in place.
        type_init: ``'rw'`` for the random-walk encoding, ``'dg'`` for the
            one-hot degree encoding, or ``'rw_dg'`` for both concatenated
            along dim 1 (random-walk columns first). Any other value leaves
            the graphs untouched.

    Returns:
        ``gs`` (the same object that was passed in).
    """
    if type_init == 'rw':
        for g in gs:
            g['stc_enc'] = _random_walk_encoding(g, args.n_rw)
    elif type_init == 'dg':
        for g in gs:
            g['stc_enc'] = _degree_encoding(g, args.n_dg)
    elif type_init == 'rw_dg':
        for g in gs:
            g['stc_enc'] = torch.cat(
                [_random_walk_encoding(g, args.n_rw), _degree_encoding(g, args.n_dg)],
                dim=1,
            )
    return gs