-
Notifications
You must be signed in to change notification settings - Fork 0
/
Function.py
54 lines (42 loc) · 1.14 KB
/
Function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from Variable import Variable
import Utils as u
import numpy as np
class Function:
    """Base class for operations that map ``Variable`` inputs to ``Variable`` outputs.

    Calling an instance unwraps the ``.data`` of each input, runs
    :meth:`forward` on those raw values, wraps every result (passed through
    ``Utils.as_array``) in a new ``Variable`` whose creator is set to this
    function, and stores the inputs/outputs on the instance so that
    :meth:`backward` implementations can read them afterwards.
    """

    def __call__(self, *inputs):
        """Run the forward pass on *inputs* (objects exposing ``.data``).

        Returns:
            A single ``Variable`` when ``forward`` yields one value, otherwise
            a list of ``Variable`` objects (one per value).
        """
        xs = [x.data for x in inputs]
        ys = self.forward(*xs)
        # Normalize to a tuple so single- and multi-output forwards share
        # the same wrapping code below.
        if not isinstance(ys, tuple):
            ys = (ys,)
        outputs = [Variable(u.as_array(y)) for y in ys]
        for output in outputs:
            # NOTE(review): presumably lets Variable walk the graph backward
            # from this output — confirm in Variable.set_creator.
            output.set_creator(self)
        # Kept as the original tuple of input Variables; subclasses read
        # self.inputs[i].data in backward().
        self.inputs = inputs
        self.outputs = outputs
        return outputs if len(outputs) > 1 else outputs[0]

    def forward(self, *xs):
        """Compute output value(s) from the raw input data; override in subclasses.

        Accepts a variable number of arguments because ``__call__`` passes
        each input's data as a separate positional argument.
        """
        raise NotImplementedError()

    def backward(self, gy):
        """Compute input gradient(s) from output gradient *gy*; override in subclasses."""
        raise NotImplementedError()
class Add(Function):
    """Addition of two inputs: y = x1 + x2."""

    def forward(self, x1, x2):
        # Returned as a 1-tuple; Function.__call__ accepts tuples directly.
        return (x1 + x2,)

    def backward(self, gy):
        """Return the gradients with respect to (x1, x2).

        Since d(x1 + x2)/dx1 = d(x1 + x2)/dx2 = 1, each input receives *gy*.
        """
        # NOTE(review): no reduction is applied for broadcast inputs of
        # differing shapes — confirm inputs always share a shape.
        return gy, gy
class Square(Function):
    """Square of the input: y = x ** 2."""

    def forward(self, x):
        return x ** 2

    def backward(self, gy):
        """Return the input gradient gy * 2 * x for the recorded input x."""
        # Function.__call__ stores inputs as ``self.inputs`` (a tuple); there
        # is no ``self.input`` attribute, so read the single input explicitly.
        x = self.inputs[0].data
        return gy * 2 * x
class Exp(Function):
    """Exponential of the input: y = exp(x), computed with ``np.exp``."""

    def forward(self, x):
        return np.exp(x)

    def backward(self, gy):
        """Return the input gradient gy * exp(x) for the recorded input x."""
        # Function.__call__ stores inputs as ``self.inputs`` (a tuple); there
        # is no ``self.input`` attribute, so read the single input explicitly.
        x = self.inputs[0].data
        return gy * np.exp(x)
def square(x):
    """Return ``x ** 2`` produced by a fresh ``Square`` function node."""
    op = Square()
    return op(x)
def exp(x):
    """Return ``exp(x)`` produced by a fresh ``Exp`` function node."""
    op = Exp()
    return op(x)
def add(x1, x2):
    """Return ``x1 + x2`` produced by a fresh ``Add`` function node."""
    op = Add()
    return op(x1, x2)