
rrule for Recur #2077

Closed
wants to merge 1 commit into from

Conversation


@mcabbott mcabbott commented Oct 5, 2022

Before this change, Diffractor does not like mutation and thus fails on RNNs:

julia> using Flux, Zygote, Diffractor

julia> Zygote.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
((cell = (σ = nothing, Wi = Float32[0.0027779222 0.008236016; 0.0027779222 0.008236016; 0.0027779222 0.008236016], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], b = Float32[0.0027290469, 0.0027290469, 0.0027290469], state0 = nothing), state = Float32[0.00818714; 0.00818714; 0.00818714;;]),)

julia> Diffractor.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
ERROR: MethodError: no method matching copy(::Nothing)
...
Stacktrace:
  [1] perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{1}}, args::Any)
    @ Diffractor ~/.julia/packages/Diffractor/XDXfC/src/stage1/generated.jl:22
...
  [4] setproperty!
    @ ./Base.jl:38 [inlined]

After:

julia> Diffractor.gradient(m -> sum(abs2, m([1 2; 3 4f0])), RNN(2 => 3; init=Flux.ones32))
(Tangent{Flux.Recur}(cell = Tangent{Flux.RNNCell{typeof(tanh), Matrix{Float32}, Vector{Float32}, Matrix{Float32}}}(σ = ChainRulesCore.NoTangent(), b = Float32[0.0027290469, 0.0027290469, 0.0027290469], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], Wi = Float32[0.0027779222 0.008236016; 0.0027779222 0.008236016; 0.0027779222 0.008236016]), state = Float32[0.00818714; 0.00818714; 0.00818714;;]),)

# And with Array{T,3}

julia> Zygote.gradient(m -> sum(abs2, m(reshape(1:24, 2,3,4).+0f0)), RNN(2 => 3; init=Flux.ones32))
((cell = (σ = nothing, Wi = Float32[0.019653967 0.03929506; 0.019653967 0.03929506; 0.019653967 0.03929506], Wh = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], b = Float32[0.019641092, 0.019641092, 0.019641092], state0 = nothing), state = Float32[0.058923274 0.058923274 0.058923274; 0.058923274 0.058923274 0.058923274; 0.058923274 0.058923274 0.058923274]),)

The @opt_out is needed to keep the Array{T,3} case working with Zygote; it does not yet work with Diffractor, for lack of a rule.
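A hedged sketch of what such an opt-out could look like (the method signature below is illustrative, not the PR's exact code):

```julia
# Sketch only: tell ChainRules-aware ADs to skip the Recur rrule for 3-array
# input, so Zygote falls back to its own handling of Recur over Array{T,3}.
using ChainRulesCore, Flux

ChainRulesCore.@opt_out ChainRulesCore.rrule(
    ::ChainRulesCore.RuleConfig, m::Flux.Recur, x::AbstractArray{<:Any,3}
)
```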

This rrule may mean that Zygote stores all the state in repeated applications of the rule, instead of in its configuration IdDict, which may matter for performance. On one very crude test it seems to be an improvement; perhaps others have more serious tests?

julia> @btime Zygote.gradient(m -> sum(abs2, m($(randn(Float32, 2, 100)))), $(RNN(2 => 3)));
  min 23.833 μs, mean 94.316 μs (97 allocations, 23.12 KiB)  # before
  min 23.458 μs, mean 27.116 μs (97 allocations, 23.12 KiB)  # after
  
julia> @btime Zygote.gradient(m -> sum(abs2, m($(randn(Float32, 20, 100)))), $(LSTM(20 => 30)));
  min 138.375 μs, mean 1.508 ms (112 allocations, 505.38 KiB)  # before
  min 127.416 μs, mean 190.538 μs (62 allocations, 496.33 KiB)  # after

More serious correctness tests might also be a good idea. I haven't looked at the test file for this.

Edit: marked as draft since many tests fail; I presume this is giving wrong gradients.

@mcabbott mcabbott marked this pull request as draft October 5, 2022 04:29
(m.state, y), back = rrule_via_ad(cfg, m.cell, m.state, x)
function Recur_pullback(dy)
    cell, state, dx = back((NoTangent(), dy))
    Tangent{Recur}(; cell, state), dx
end
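Assembled into a complete method, the excerpt above might read as follows (a reconstruction for readability, not necessarily the PR's exact code):

```julia
using Flux
using ChainRulesCore
using ChainRulesCore: rrule_via_ad, RuleConfig, HasReverseMode, NoTangent, Tangent

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, m::Flux.Recur, x)
    # The cell returns (new_state, y); destructuring writes the new state back
    # into the Recur struct, as Recur's ordinary forward pass does.
    (m.state, y), back = rrule_via_ad(cfg, m.cell, m.state, x)
    function Recur_pullback(dy)
        # back expects a tangent for the whole (state, y) primal tuple;
        # here no tangent is propagated for the stored state.
        cell, state, dx = back((NoTangent(), dy))
        Tangent{Flux.Recur}(; cell, state), dx
    end
    return y, Recur_pullback
end
```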
@ToucheSir ToucheSir Oct 6, 2022
I don't know of a way around storing some running accumulated tangent outside of the rule, as Zygote currently does for mutable structs. As-is, I believe this will generate a tangent per timestep, and they will never be accumulated.

@mcabbott mcabbott closed this Mar 27, 2024