-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlayers.py
26 lines (20 loc) · 823 Bytes
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules import Module
class GraphConvolution(Module):
    """Graph-convolution style layer: ``(inp @ weight) @ adj + bias``.

    The layer owns two learnable parameters:

    * ``weight`` with shape ``(in_features, out_features)``
    * ``bias`` with shape ``(out_features,)``

    Both are float32 and are initialised by :meth:`reset_parameters`.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        """Create the layer and initialise its parameters.

        Args:
            in_features: Size of the last dimension of ``inp`` in ``forward``.
            out_features: Number of columns of ``weight`` and length of
                ``bias``.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialised; reset_parameters() fills it
        # right below, so no stale memory is ever observed by callers.
        self.weight = Parameter(
            torch.empty(self.in_features, self.out_features, dtype=torch.float32)
        )
        self.bias = Parameter(torch.empty(self.out_features, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-draw ``weight`` and ``bias`` from ``U(-b, b)``.

        The bound is ``b = 1 / sqrt(out_features)`` (the second dimension of
        ``weight``). ``weight`` is sampled first, then ``bias``.
        """
        bound = 1.0 / math.sqrt(self.weight.size(1))
        # In-place initialisation must not be recorded by autograd; using
        # no_grad() instead of touching ``.data`` keeps that explicit.
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            self.bias.uniform_(-bound, bound)

    def forward(self, inp: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            inp: 2-D tensor whose second dimension equals ``in_features``
                (``torch.mm`` only accepts matrices).
            adj: 2-D tensor that is right-multiplied onto ``inp @ weight``;
                it therefore needs ``out_features`` rows, and its column
                count must broadcast against ``bias``.

        Returns:
            ``torch.mm(torch.mm(inp, weight), adj) + bias``.
        """
        projected = torch.mm(inp, self.weight)
        # NOTE(review): the common GCN formulation is ``adj @ (inp @ weight)``
        # with ``adj`` indexed by nodes. Here ``adj`` is applied on the right
        # (over the output-feature axis) -- confirm this is intended.
        return torch.mm(projected, adj) + self.bias

    def extra_repr(self) -> str:
        """Show the layer dimensions in ``repr(module)`` / ``print(model)``."""
        return f"in_features={self.in_features}, out_features={self.out_features}"