-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbatchnorm.py
62 lines (45 loc) · 2.13 KB
/
batchnorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import numpy as np
class BatchNorm1d:
def __init__(self, num_features, alpha=0.9):
self.alpha = alpha
self.eps = 1e-8
self.BW = np.ones((1, num_features))
self.Bb = np.zeros((1, num_features))
self.dLdBW = np.zeros((1, num_features))
self.dLdBb = np.zeros((1, num_features))
# Running mean and variance, updated during training, used during
# inference
self.running_M = np.zeros((1, num_features))
self.running_V = np.ones((1, num_features))
def forward(self, Z, eval=False):
"""
The eval parameter is to indicate whether we are in the
training phase of the problem or the inference phase.
So see what values you need to recompute when eval is False.
"""
self.Z = Z
self.N = Z.shape[0]
self.M = self.Z.sum(axis=0)/self.N
self.V = (((self.Z - self.M)**2).sum(axis=0))/self.N
if eval == False:
# training mode
self.NZ = np.subtract(self.Z,self.M)/np.sqrt(self.V+self.eps)
self.BZ = np.multiply(self.BW ,self.NZ)+self.Bb
self.running_M = self.alpha * self.running_M + ( 1-self.alpha) * self.M
self.running_V = self.alpha * self.running_V + (1-self.alpha) * self.V
else:
# inference mode
self.NZ = (self.Z - self.running_M) / np.sqrt(self.running_V + self.eps)
self.BZ = np.multiply(self.BW ,self.NZ)+self.Bb
return self.BZ
def backward(self, dLdBZ):
self.dLdBW = np.sum(dLdBZ*self.NZ,axis=0)
self.dLdBb = np.sum(dLdBZ,axis=0)
dLdNZ = dLdBZ * self.BW
dLdV = -0.5 * (np.sum(dLdNZ*(self.Z - self.M)*(self.V + self.eps)**(-1.5),axis=0))
t1 = -(2/self.N)*np.sum((self.Z -self.M),axis=0)
t2 = (self.V + self.eps)**(-1.5) * t1
dNZdM = -(self.V+self.eps)**(-0.5)-0.5 * (self.Z - self.M)*t2
dLdM = np.sum((dLdNZ * dNZdM),axis = 0)
dLdZ = dLdNZ * ((self.V+self.eps)**(-0.5))+ (dLdV * (2/self.N)*(self.Z - self.M) + ((1/self.N) * dLdM))
return dLdZ