-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrad.jl
69 lines (56 loc) · 1.23 KB
/
grad.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
    Variable(data)
    Variable(data, grad, creator)

A node in the computation graph: a value together with its gradient and the
function object whose `forward` produced it.

# Fields
- `data`: the stored value; must be an `Array` or `nothing`.
- `grad`: gradient with respect to `data`; `nothing` until `backward` assigns it.
- `creator`: the function object that produced this variable, or `nothing` for a leaf.
"""
mutable struct Variable
    data::Union{Array,Nothing}
    grad
    creator
end

# Leaf constructor: wraps `data` with no gradient and no creator yet.
function Variable(data)
    return Variable(data, nothing, nothing)
end
"""
    Func

Abstract supertype for the differentiable operations in this file. Each concrete
subtype caches the `input` and `output` of its most recent `forward` call.
"""
abstract type Func end

"""
    Square()

Element-wise squaring operation. Both cached fields start as `nothing`.
"""
mutable struct Square <: Func
    input   # Variable given to the last `forward` call, or `nothing`
    output  # Variable returned by the last `forward` call, or `nothing`
end

function Square()
    return Square(nothing, nothing)
end

"""
    Expo()

Element-wise exponential operation. Both cached fields start as `nothing`.
"""
mutable struct Expo <: Func
    input   # Variable given to the last `forward` call, or `nothing`
    output  # Variable returned by the last `forward` call, or `nothing`
end

function Expo()
    return Expo(nothing, nothing)
end
"""
    forward(square::Square, x::Variable) -> Variable

Square `x.data` element-wise and return the result as a new `Variable` whose
`creator` is `square`. `x` and the result are cached in `square.input` and
`square.output` so that `backward` can use them later.
"""
function forward(square::Square, x::Variable)
    # Cache the input before computing, matching the original side-effect order.
    square.input = x
    squared = Variable(x.data .* x.data, nothing, square)
    square.output = squared
    return squared
end
"""
    backward(square::Square, grad)

Local gradient of `x .^ 2`: returns `2 .* x .* grad`, where `x` is the data
cached in `square.input` by `forward` and `grad` is the upstream gradient.
"""
backward(square::Square, grad) = 2 .* square.input.data .* grad
"""
    forward(expo::Expo, x::Variable) -> Variable

Apply `exp` element-wise to `x.data` and return the result as a new `Variable`
whose `creator` is `expo`. `x` and the result are cached in `expo.input` and
`expo.output` for the backward pass.
"""
function forward(expo::Expo, x::Variable)
    # Cache the input before computing, matching the original side-effect order.
    expo.input = x
    exponentiated = Variable(exp.(x.data), nothing, expo)
    expo.output = exponentiated
    return exponentiated
end
"""
    backward(expo::Expo, grad)

Local gradient of `exp.(x)`: since d/dx exp(x) = exp(x), returns
`exp.(x) .* grad`, where `x` is the data cached in `expo.input` by `forward`.
"""
function backward(expo::Expo, grad)
    xdata = expo.input.data
    return exp.(xdata) .* grad
end
"""
    ones_like(data)

Return a freshly allocated `Array` of ones with the same element type and size
as `data`. Used to seed the output gradient of a backward pass.
"""
function ones_like(data)
    T = eltype(data)
    buffer = Array{T}(undef, size(data))
    return fill!(buffer, one(T))
end
"""
    backward(x::Variable)

Back-propagate gradients from `x` through the chain of functions that created it.

If `x.grad` is `nothing`, it is first seeded with `ones_like(x.data)`. Then,
starting at `x.creator`, for each function `f` the result of
`backward(f, f.output.grad)` is stored in `f.input.grad`. The walk continues
with `f.input.creator` until a variable with no creator (a leaf) is reached.

Each function has a single `input`, and its gradient is overwritten rather than
accumulated. This is therefore correct for linear chains (e.g. `Expo` after
`Square`), not for graphs where one variable feeds several functions.

Returns `nothing`.
"""
function backward(x::Variable)
    if isnothing(x.grad)
        x.grad = ones_like(x.data)
    end

    # A leaf variable has nothing to propagate through. Previously `nothing`
    # was pushed onto the work stack and the loop failed on `nothing.output`.
    isnothing(x.creator) && return nothing

    pending = Any[x.creator]
    while !isempty(pending)
        f = pop!(pending)
        out, inp = f.output, f.input
        # Was an unconditional `println(z)`; enable with the JULIA_DEBUG env var.
        @debug "backward step" f out.grad
        inp.grad = backward(f, out.grad)
        if !isnothing(inp.creator)
            push!(pending, inp.creator)
        end
    end
    return nothing
end