import torch
import torch.nn.functional as F
import torch.nn as nn

# Paper: Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting
# Paper link: https://arxiv.org/pdf/2007.02842


class AVWGCN(nn.Module):
    """Adaptive graph convolution: both the supports (learned adjacency) and the
    per-node weights are generated from a learnable node-embedding matrix."""
    def __init__(self, dim_in, dim_out, cheb_k, embed_dim):
        super(AVWGCN, self).__init__()
        self.cheb_k = cheb_k
        # Weight/bias pools shared across nodes; each node draws its own
        # parameters from the pool via its embedding.
        self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out))
        self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out))
        # torch.FloatTensor allocates uninitialized memory; initialize explicitly
        # so the module is usable standalone (a training script may otherwise
        # be expected to handle initialization).
        nn.init.xavier_uniform_(self.weights_pool)
        nn.init.xavier_uniform_(self.bias_pool)

    def forward(self, x, node_embeddings):
        # x: [B, N, C]; node_embeddings: [N, D] -> supports: [N, N]
        # output: [B, N, dim_out]
        node_num = node_embeddings.shape[0]
        # Learned adjacency: row-softmax over the ReLU'd embedding similarity
        supports = F.softmax(F.relu(torch.mm(node_embeddings, node_embeddings.transpose(0, 1))), dim=1)
        support_set = [torch.eye(node_num).to(supports.device), supports]
        # Chebyshev polynomial recursion T_k = 2 * A * T_{k-1} - T_{k-2} (default cheb_k = 3)
        for k in range(2, self.cheb_k):
            support_set.append(torch.matmul(2 * supports, support_set[-1]) - support_set[-2])
        supports = torch.stack(support_set, dim=0)
        weights = torch.einsum('nd,dkio->nkio', node_embeddings, self.weights_pool)  # N, cheb_k, dim_in, dim_out
        bias = torch.matmul(node_embeddings, self.bias_pool)  # N, dim_out
        x_g = torch.einsum("knm,bmc->bknc", supports, x)  # B, cheb_k, N, dim_in
        x_g = x_g.permute(0, 2, 1, 3)  # B, N, cheb_k, dim_in
        x_gconv = torch.einsum('bnki,nkio->bno', x_g, weights) + bias  # B, N, dim_out
        return x_gconv
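
# A minimal shape sketch for AVWGCN (the sizes below are hypothetical, chosen
# only to illustrate the tensor shapes):
#   gcn = AVWGCN(dim_in=1, dim_out=64, cheb_k=3, embed_dim=20)
#   x = torch.randn(2, 10, 1)         # [B=2, N=10, dim_in=1]
#   emb = torch.randn(10, 20)         # [N=10, embed_dim=20]
#   out = gcn(x, emb)                 # -> [2, 10, 64]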


class AGCRNCell(nn.Module):
    """One GRU-style recurrent cell whose gates are computed by AVWGCN."""
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super(AGCRNCell, self).__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # The gate branch produces both gates at once (2 * dim_out channels)
        self.gate = AVWGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
        self.update = AVWGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings):
        # x: [B, num_nodes, input_dim]
        # state: [B, num_nodes, hidden_dim]
        state = state.to(x.device)
        input_and_state = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, node_embeddings))
        # GRU-style update: convex combination of previous state and candidate
        h = r * state + (1 - r) * hc
        return h

    def init_hidden_state(self, batch_size):
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)
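
# A minimal shape sketch for AGCRNCell (hypothetical sizes, for illustration):
#   cell = AGCRNCell(node_num=10, dim_in=1, dim_out=64, cheb_k=3, embed_dim=20)
#   x = torch.randn(2, 10, 1)         # [B, N, input_dim]
#   h0 = cell.init_hidden_state(2)    # [B, N, hidden_dim] of zeros
#   emb = torch.randn(10, 20)         # [N, embed_dim]
#   h1 = cell(x, h0, emb)             # -> [2, 10, 64]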


class AVWDCRNN(nn.Module):
    """Stacked AGCRN encoder: num_layers AGCRNCells unrolled over time."""
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super(AVWDCRNN, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.dcrnn_cells = nn.ModuleList()
        self.dcrnn_cells.append(AGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
        for _ in range(1, num_layers):
            self.dcrnn_cells.append(AGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim))

    def forward(self, x, init_state, node_embeddings):
        # x: (B, T, N, D)
        # init_state: (num_layers, B, N, hidden_dim)
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        seq_length = x.shape[1]
        current_inputs = x
        output_hidden = []
        for i in range(self.num_layers):
            state = init_state[i]
            inner_states = []
            for t in range(seq_length):
                state = self.dcrnn_cells[i](current_inputs[:, t, :, :], state, node_embeddings)
                inner_states.append(state)
            output_hidden.append(state)
            # The outputs of this layer become the inputs of the next layer
            current_inputs = torch.stack(inner_states, dim=1)
        # current_inputs: outputs of the last layer, (B, T, N, hidden_dim)
        # output_hidden: the final state of each layer, num_layers * (B, N, hidden_dim)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(self.dcrnn_cells[i].init_hidden_state(batch_size))
        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)
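
# A minimal shape sketch for AVWDCRNN (hypothetical sizes, for illustration):
#   enc = AVWDCRNN(node_num=10, dim_in=1, dim_out=64, cheb_k=3, embed_dim=20, num_layers=2)
#   x = torch.randn(2, 12, 10, 1)     # [B, T, N, D]
#   h0 = enc.init_hidden(2)           # [num_layers, B, N, hidden]
#   emb = torch.randn(10, 20)         # [N, embed_dim]
#   out, hidden = enc(x, h0, emb)     # out: [2, 12, 10, 64]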


class AGCRN(nn.Module):
    """Full AGCRN model: AVWDCRNN encoder followed by a CNN predictor that
    emits the whole forecast horizon in one shot."""
    def __init__(self, args):
        super(AGCRN, self).__init__()
        self.num_node = args.num_nodes
        self.input_dim = args.input_dim
        self.hidden_dim = args.rnn_units
        self.output_dim = args.output_dim
        self.horizon = args.horizon
        self.num_layers = args.num_layers
        # Learnable node embeddings drive both the adaptive adjacency and the
        # node-specific convolution weights.
        self.node_embeddings = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
        self.encoder = AVWDCRNN(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_k,
                                args.embed_dim, args.num_layers)
        # Predictor: maps the last hidden state to horizon * output_dim channels
        self.end_conv = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)

    def forward(self, source, targets, teacher_forcing_ratio=0.5):
        # source: (B, T_1, N, D); targets: (B, T_2, N, D)
        # targets and teacher_forcing_ratio are unused here: the decoder is a
        # one-shot CNN predictor, not an autoregressive RNN.
        init_state = self.encoder.init_hidden(source.shape[0])
        output, _ = self.encoder(source, init_state, self.node_embeddings)  # B, T, N, hidden
        output = output[:, -1:, :, :]  # B, 1, N, hidden (last time step only)
        # CNN-based predictor
        output = self.end_conv(output)  # B, horizon * output_dim, N, 1
        output = output.squeeze(-1).reshape(-1, self.horizon, self.output_dim, self.num_node)
        output = output.permute(0, 1, 3, 2)  # B, T, N, C
        return output
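
# Sketch of the predictor's reshaping (hypothetical sizes: B=1, N=10,
# horizon=3, output_dim=1, hidden=64):
#   end_conv in : [1, 1, 10, 64]   last encoder step, treated as one channel
#   end_conv out: [1, 3, 10, 1]    kernel (1, hidden) collapses the hidden axis
#   squeeze/reshape -> [1, 3, 1, 10]; permute -> [1, 3, 10, 1]  (B, T, N, C)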


if __name__ == '__main__':
    # Minimal stand-in for the argparse namespace the model expects
    class Args:
        def __init__(self):
            self.num_nodes = 10   # number of nodes in the graph
            self.input_dim = 1    # feature dimension per node
            self.rnn_units = 64   # number of RNN units (hidden dim)
            self.output_dim = 1   # output dimension
            self.horizon = 3      # predict 3 future time steps
            self.num_layers = 2   # use 2 RNN layers
            self.cheb_k = 3       # order of the Chebyshev polynomial
            self.embed_dim = 20   # dimension of the node embeddings

    # Instantiate the hyperparameters and the model
    args = Args()
    model = AGCRN(args)

    # Create dummy input data
    input_tensor = torch.randn(1, 3, args.num_nodes, args.input_dim)
    print("Input tensor size: ", input_tensor.size())

    # Create dummy target data (unused by the forward pass, see above)
    target_tensor = torch.randn(1, args.horizon, args.num_nodes, args.output_dim)
    print("Target tensor size:", target_tensor.size())

    # Switch to training mode and run a forward pass
    model.train()
    output = model(input_tensor, target_tensor)
    print("Output size: ", output.size())