import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from setting import (
STEPS,
DT,
SIMWIN,
ALPHA,
VTH,
TAU,
)
class SpikeAct(torch.autograd.Function):
""" 定义脉冲激活函数,并根据论文公式进行梯度的近似。
Implementation of the spiking activation function with an approximation of gradient.
"""
alpha = ALPHA
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
        # input is u - Vth: fire (output 1) when the membrane potential exceeds the threshold
output = torch.gt(input, 0)
return output.float()
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
# hu is an approximate func of df/du
hu = abs(input) < SpikeAct.alpha
hu = hu.float() / (2 * SpikeAct.alpha)
return grad_input * hu
def state_update(u_t_n1, o_t_n1, W_mul_o_t1_n):
    # LIF membrane update: leaky decay with TAU, reset wherever a spike was
    # emitted at the previous step, then add the weighted presynaptic input.
    u_t1_n1 = TAU * u_t_n1 * (1 - o_t_n1) + W_mul_o_t1_n
    o_t1_n1 = SpikeAct.apply(u_t1_n1 - VTH)
    return u_t1_n1, o_t1_n1
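
# A minimal usage sketch (hypothetical helper, not part of the original file):
# SpikeAct emits binary spikes in the forward pass, while its rectangular
# surrogate gradient of width 2 * ALPHA keeps the backward pass differentiable.
def _demo_spike_act():
    u = torch.randn(4, 8, requires_grad=True)  # membrane potentials
    spikes = SpikeAct.apply(u - VTH)           # 1.0 where u > VTH, else 0.0
    spikes.sum().backward()
    # u.grad is 1 / (2 * ALPHA) where |u - VTH| < ALPHA and 0 elsewhere
    print(spikes.unique(), u.grad.unique())
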
class tdLayer(nn.Module):
"""将普通的层转换到时间域上。输入张量需要额外带有时间维,此处时间维在数据的最后一维上。前传时,对该时间维中的每一个时间步的数据都执行一次普通层的前传。
Converts a common layer to the time domain. The input tensor needs to have an additional time dimension, which in this case is on the last dimension of the data. When forwarding, a normal layer forward is performed for each time step of the data in that time dimension.
Args:
layer (nn.Module): 需要转换的层。
The layer needs to convert.
bn (nn.Module): 如果需要加入BN,则将BN层一起当做参数传入。
If batch-normalization is needed, the BN layer should be passed in together as a parameter.
"""
def __init__(self, layer, bn=None, steps=STEPS):
super(tdLayer, self).__init__()
self.layer = layer
self.bn = bn
self.steps = steps
    def forward(self, x):
        # Run the wrapped layer once to get the per-step output shape, then
        # apply it to every time step along the last dimension.
        x_ = torch.zeros(self.layer(x[..., 0]).shape + (self.steps,), device=x.device)
for step in range(self.steps):
x_[..., step] = self.layer(x[..., step])
if self.bn is not None:
x_ = self.bn(x_)
return x_
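
# A minimal usage sketch (hypothetical helper, not part of the original file):
# wrapping an ordinary nn.Conv2d in tdLayer applies the same convolution
# independently at every time step of an input shaped (N, C, H, W, T).
def _demo_td_layer():
    conv = tdLayer(nn.Conv2d(2, 16, 3, padding=1))
    x = torch.rand(1, 2, 32, 32, STEPS)  # last dimension is time
    y = conv(x)
    print(y.shape)                       # -> (1, 16, 32, 32, STEPS)
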
class LIFSpike(nn.Module):
"""对带有时间维度的张量进行一次LIF神经元的发放模拟,可以视为一个激活函数,用法类似ReLU。
Generates spikes based on LIF module. It can be considered as an activation function and is used similar to ReLU. The input tensor needs to have an additional time dimension, which in this case is on the last dimension of the data.
"""
def __init__(self, steps=STEPS):
super(LIFSpike, self).__init__()
self.steps = steps
def forward(self, x):
        u = torch.zeros(x.shape[:-1], device=x.device)  # membrane potential
        out = torch.zeros(x.shape, device=x.device)     # output spike train
for step in range(self.steps):
u, out[..., step] = state_update(u, out[..., max(step-1, 0)], x[..., step])
return out
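
# A minimal usage sketch (hypothetical helper, not part of the original file):
# LIFSpike consumes the time-expanded activations produced by tdLayer and
# returns a binary spike train of the same shape, used much like ReLU.
def _demo_lif_spike():
    x = torch.rand(1, 16, 32, 32, STEPS)  # synaptic input over time
    out = LIFSpike()(x)
    print(out.shape, out.unique())        # same shape; entries are 0. or 1.
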
class tdBatchNorm(nn.BatchNorm2d):
"""tdBN的实现。相关论文链接:https://arxiv.org/pdf/2011.05280。具体是在BN时,也在时间域上作平均;并且在最后的系数中引入了alpha变量以及Vth。
Implementation of tdBN. Link to related paper: https://arxiv.org/pdf/2011.05280. In short it is averaged over the time domain as well when doing BN.
Args:
num_features (int): same with nn.BatchNorm2d
eps (float): same with nn.BatchNorm2d
momentum (float): same with nn.BatchNorm2d
alpha (float): an addtional parameter which may change in resblock.
affine (bool): same with nn.BatchNorm2d
track_running_stats (bool): same with nn.BatchNorm2d
"""
def __init__(self, num_features, eps=1e-05, momentum=0.1, alpha=1, affine=True, track_running_stats=True):
super(tdBatchNorm, self).__init__(
num_features, eps, momentum, affine, track_running_stats)
self.alpha = alpha
def forward(self, input):
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
# calculate running estimates
if self.training:
mean = input.mean([0, 2, 3, 4])
# use biased var in train
var = input.var([0, 2, 3, 4], unbiased=False)
n = input.numel() / input.size(1)
with torch.no_grad():
self.running_mean = exponential_average_factor * mean\
+ (1 - exponential_average_factor) * self.running_mean
# update running_var with unbiased var
self.running_var = exponential_average_factor * var * n / (n - 1)\
+ (1 - exponential_average_factor) * self.running_var
else:
mean = self.running_mean
var = self.running_var
        # tdBN normalization: scale the standardized input by alpha * Vth
        input = self.alpha * VTH * (input - mean[None, :, None, None, None]) / (torch.sqrt(var[None, :, None, None, None] + self.eps))
if self.affine:
input = input * self.weight[None, :, None, None, None] + self.bias[None, :, None, None, None]
return input
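
# A minimal end-to-end sketch (added for illustration, not part of the original
# file): one convolutional SNN block built from the layers above. tdBatchNorm
# is handed to tdLayer as its bn argument, so normalization also averages over
# the time dimension before LIFSpike generates the spikes.
if __name__ == "__main__":
    block = tdLayer(nn.Conv2d(2, 16, 3, padding=1), bn=tdBatchNorm(16))
    spike = LIFSpike()
    x = torch.rand(4, 2, 32, 32, STEPS)  # (N, C, H, W, T)
    out = spike(block(x))
    print(out.shape)                     # -> (4, 16, 32, 32, STEPS)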