
Bounds error for Flux.reset! in loss function #2057

Closed
FedeClaudi opened this issue Aug 31, 2022 · 5 comments

@FedeClaudi

Hey,

I'm trying to implement an RNN in Flux but I'm having some problems. Here's the MWE.

Starting by generating some data

using Flux
import Flux.Losses: mse

function make_trial()::Tuple{Vector, Vector}
    T = (1:100) .+ rand() |> collect
    x::Vector{Vector{Float32}} = [Float32[sin(t)] for t in T]
    y::Vector{Vector{Float32}} = [Float32[cos(t)] for t in T]
    return x,y
end

# generate trials
trials = [make_trial() for _ in 1:100]

# split inputs/outputs
xs, ys = [x for (x,y) in trials], [y for (x,y) in trials]
# xs[1] is a Vector{Vector{Float32}} with size(xs[1][1]) = (1,)

then we define the model

# create a network
rnn = Chain(
    Dense(1 => 64),
    RNN(64 => 64),
    Dense(64 => 1)
)

# set optimizer
opt = Adam()

and the loss function

function ℓ(xs, ys)
    Flux.reset!(rnn)
    mse.([rnn(x) for x in xs], ys) |> sum
end

and finally we train:

evalcb() = @show Float64(ℓ(xs[1], ys[1]))
Flux.@epochs 10 Flux.train!(ℓ, params(rnn), zip(xs, ys), opt, cb = throttle(evalcb, .5))

This gives an error:
[screenshots: BoundsError stack trace raised during training]

As far as I can tell, Flux.reset!(rnn) turns rnn.layers[2].state (the RNN state) from a Vector into a Tuple, but ONLY when called within the loss function. If I call Flux.reset! outside of training, nothing odd happens.
If I replace the loss function with

function ℓ(xs, ys)
    rnn.layers[2].state *= Float32(0.0)
    mse.([rnn(x) for x in xs], ys) |> sum
end

I don't get any errors, but it feels like this should be unnecessary. It also seems like the network is not learning at all, but that's for another discussion.


Now, I'm just getting started with Flux so maybe I'm missing something obvious, but I've based the example code on examples/tutorials found online, including in the docs, so I'm not sure what's going on here.

Thank you,
Fede

@mcabbott
Member

mcabbott commented Aug 31, 2022

Looks like the same bug as [Edit] FluxML/Zygote.jl#1297. Can you try with ChainRulesCore.@non_differentiable foreach(f, ::Tuple{}) as suggested here?
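
Something like this, evaluated once before train! (a sketch; assumes the fix has not yet landed in the installed Zygote release):

using ChainRulesCore

# Tell the AD machinery not to differentiate foreach over an empty tuple,
# which reset! reportedly hits while walking the model (see the linked issue).
ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})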

@FedeClaudi
Author

Thanks for replying so quickly.

As far as I can tell, that's a PR on the CUDA docs; I'm not sure how it relates?

@mcabbott
Member

Oh sorry, I meant FluxML/Zygote.jl#1297 on Zygote, not here.

@FedeClaudi
Author

Beautiful, that works perfectly. Thank you

@FedeClaudi
Author

For future reference, "contradict" on Slack suggested this alternative solution:

function ℓ(xs, ys)
    ChainRules.ignore_derivatives() do
        Flux.reset!(rnn)
    end
    mse.([rnn(x) for x in xs], ys) |> sum
end

which has the nice feature of being very explicit about what gets fixed where and why.
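
Putting the pieces together, a sketch of the full fixed loss and training call (same names and data as the MWE above; the implicit-params train! API matches the Flux version current when this issue was opened, and ChainRulesCore.ignore_derivatives is the same function that ChainRules re-exports):

using Flux
using Flux: params, throttle
import Flux.Losses: mse
import ChainRulesCore

function ℓ(xs, ys)
    # reset the hidden state, but keep the reset itself out of the gradient
    ChainRulesCore.ignore_derivatives() do
        Flux.reset!(rnn)
    end
    mse.([rnn(x) for x in xs], ys) |> sum
end

evalcb() = @show Float64(ℓ(xs[1], ys[1]))
Flux.@epochs 10 Flux.train!(ℓ, params(rnn), zip(xs, ys), opt, cb = throttle(evalcb, 0.5))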
