Issues with backpropagation and adjoint construction #31
Thanks @ChrisRackauckas for reporting. Using ParallelStencil in inversion frameworks based on adjoint rules is definitely something that has potential (here we derived the adjoint for a multi-physics 3D problem, though "not automatically").
Yeah so I was playing around with a few things this week that were relevant. The key here is:
So one thing you could do is define https://github.com/JuliaDiff/ChainRules.jl rules over the stencil computation for Zygote. An example of doing this can be found in Tullio https://github.com/mcabbott/Tullio.jl. That doesn't solve the mutation problem, though. So I think an interesting solution here could be to get Enzyme compatibility (@wsmoses) https://github.com/wsmoses/Enzyme.jl. For the GPU code here, it might already have the ability to differentiate the kernel, and it is compatible with mutation. The key here is that if you get the derivative of that stencil computation working, then the adjoint functionality of DifferentialEquations.jl only needs that the u'=f(u,p,t) that right hand side
Thanks @ChrisRackauckas for your suggestions and further insights. I am still wondering which approach the workflows you describe generally follow under the hood to automatically retrieve the objects needed to perform inversions that use adjoints to compute gradients. Independently of the solution method, one would need an expression for the transposed Jacobian (or a function that applies it in a matrix-free fashion). After discussing with @greuber, deriving the expression for the transposed Jacobian for the adjoint is the challenge. One can do that analytically, which is what we did here (see Appendix 6), ending up with an infinite-dimensional system of equations that one can discretize and solve with one's preferred solver. One question is whether AD would deliver something similar (a transposed Jacobian), and if so, in what form? Maybe we could use the nonlinear diffusion example (from Appendix 6) as an MWE to try out those things and see whether your tools combined with ParallelStencil could retrieve the equations needed to solve the matrix-free system in a similar fashion to how we do it in the nonlinear 1D diffusion code here, or maybe even better. (If so, I can quickly rewrite the 1D Matlab example in Julia.)
Reverse-mode AD is exactly a matrix-free transposed Jacobian calculation. I would recommend taking a look at https://github.com/mitmath/18337 if you're curious about that, specifically https://mitmath.github.io/18337/lecture10/estimation_identification . With Zygote, the vector-transposed Jacobian product is computed via:

    y,back = Zygote.pullback(f,x)
    back(v)

The issue is that this fails on your stencil functions. I have been playing around with Enzyme on other projects this last week, and I think that might be the right one for your case. Since you're generating GPU code, it should be statically compilable, which allows Enzyme to run its AD passes on the LLVM IR. The vector-transposed Jacobian product in that case is done via:

    Enzyme.autodiff(Duplicated(y, v),
                    Duplicated(x, λ)) do _y,_x
        f!(_y,_x)
        nothing
    end

for a non-allocating mutating function.
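To make "matrix-free transposed Jacobian" concrete, here is a minimal hand-written sketch for a 1D linear diffusion stencil. It uses neither ParallelStencil nor Enzyme, and the names `diffusion1d!` and `diffusion1d_vjp!` are hypothetical, introduced only for illustration. It shows the kind of function a reverse-mode tool would have to generate for the stencil, and checks it with the usual dot-product test.

```julia
using LinearAlgebra  # standard library, used only for dot()

# Forward stencil (hypothetical name):
# du[i] = lam * (u[i-1] - 2*u[i] + u[i+1]) / dx^2 on interior points,
# boundary entries of du are set to zero.
function diffusion1d!(du, u, lam, dx)
    n = length(u)
    fill!(du, 0.0)
    for i in 2:n-1
        du[i] = lam * (u[i-1] - 2 * u[i] + u[i+1]) / dx^2
    end
    return nothing
end

# Hand-written vector-transposed Jacobian product: ubar = J' * w, with J = d(du)/d(u).
# Row i of J has entries on columns i-1, i, i+1, so w[i] is scattered onto those
# three entries of ubar. J itself is never formed.
function diffusion1d_vjp!(ubar, w, lam, dx)
    n = length(w)
    fill!(ubar, 0.0)
    c = lam / dx^2
    for i in 2:n-1
        ubar[i-1] += c * w[i]
        ubar[i]   -= 2 * c * w[i]
        ubar[i+1] += c * w[i]
    end
    return nothing
end

# Dot-product test: <J*v, w> must equal <v, J'*w>.
# The stencil is linear in u, so applying it to v gives J*v exactly.
n, lam, dx = 16, 1.0, 0.1
v = [sin(i) for i in 1:n]
w = [cos(2i) for i in 1:n]
Jv  = zeros(n); diffusion1d!(Jv, v, lam, dx)
JTw = zeros(n); diffusion1d_vjp!(JTw, w, lam, dx)
lhs = dot(Jv, w)
rhs = dot(v, JTw)
println(isapprox(lhs, rhs; rtol = 1e-10))
```

The reverse pass scatters into `ubar` where the forward pass gathers from `u`. In a threaded or GPU kernel those scattered writes can collide, which is one reason a naive transposition of a parallel stencil needs care.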
The final solution is a mixture. Discrete adjoint sensitivity analysis does the entire adjoint via reverse-mode AD, but that is costly memory-wise. Continuous adjoint sensitivity analysis uses the infinite dimensional system so you can solve forward and reverse with a preferred solver, but then the right-hand side of the adjoint equation is automatically generated via these reverse-mode AD tools to get fast vector-transposed Jacobian products. That's what the sensitivity methods are doing (https://diffeq.sciml.ai/stable/analysis/sensitivity/), and if you watch the video on the SciML adjoint system you'll see the trade-offs between all of the different choices of vector-transposed Jacobian product (vjp) and their mixtures with generated adjoints (https://www.youtube.com/watch?v=XRJ-rtP2fVE). All of this is done automatically though. In the code above, when I did Hopefully that explains what all is going on and how it's pulling in two different levels of reverse-mode AD to chain together gradient calculations.
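For reference, the continuous adjoint can be written in its standard textbook form as follows. This assumes an integral cost G(p) with running cost g; those two symbols are not from the thread and are introduced here only for illustration.

```latex
\begin{aligned}
u' &= f(u, p, t), \qquad G(p) = \int_{t_0}^{T} g(u, p, t)\, dt, \\
\frac{d\lambda}{dt} &= -\left(\frac{\partial f}{\partial u}\right)^{\!\top} \lambda
  - \left(\frac{\partial g}{\partial u}\right)^{\!\top}, \qquad \lambda(T) = 0, \\
\frac{dG}{dp} &= \lambda(t_0)^{\top} \frac{\partial u(t_0)}{\partial p}
  + \int_{t_0}^{T} \left( \lambda^{\top} \frac{\partial f}{\partial p}
  + \frac{\partial g}{\partial p} \right) dt .
\end{aligned}
```

The only places where f enters the reverse solve are the products (∂f/∂u)ᵀλ and λᵀ∂f/∂p. These are vector-transposed Jacobian products, so reverse-mode AD of the right-hand side alone is enough and the Jacobian never has to be assembled.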
@wsmoses on:
const USE_GPU = false
using ParallelStencil, OrdinaryDiffEq
using ParallelStencil.FiniteDifferences3D
@static if USE_GPU
    @init_parallel_stencil(CUDA, Float64, 3);
else
    @init_parallel_stencil(Threads, Float64, 3);
end
@parallel function diffusion3D_step!(T2, T, Ci, lam, dx, dy, dz)
    @inn(T2) = lam*@inn(Ci)*(@d2_xi(T)/dx^2 + @d2_yi(T)/dy^2 + @d2_zi(T)/dz^2);
    return
end
function diffusion3D(lam,alg)
    # Physics
    cp_min = 1.0; # Minimal heat capacity
    lx, ly, lz = 10.0, 10.0, 10.0; # Length of domain in dimensions x, y and z.
    # Numerics
    nx, ny, nz = 16, 16, 16; # Number of gridpoints dimensions x, y and z.
    nt = 100; # Number of time steps
    dx = lx/(nx-1); # Space step in x-dimension
    dy = ly/(ny-1); # Space step in y-dimension
    dz = lz/(nz-1); # Space step in z-dimension
    # Array initializations
    T = @zeros(nx, ny, nz);
    T2 = @zeros(nx, ny, nz);
    Ci = @zeros(nx, ny, nz);
    # Initial conditions (heat capacity and temperature with two Gaussian anomalies each)
    Ci .= 1.0./( cp_min .+ Data.Array([5*exp(-(((ix-1)*dx-lx/1.5))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) +
                                       5*exp(-(((ix-1)*dx-lx/3.0))^2-(((iy-1)*dy-ly/2))^2-(((iz-1)*dz-lz/1.5))^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)]) )
    T .= Data.Array([100*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/3.0)/2)^2) +
                     50*exp(-(((ix-1)*dx-lx/2)/2)^2-(((iy-1)*dy-ly/2)/2)^2-(((iz-1)*dz-lz/1.5)/2)^2) for ix=1:size(T,1), iy=1:size(T,2), iz=1:size(T,3)])
    T2 .= T; # Assign also T2 to get correct boundary conditions.
    dt = min(dx^2,dy^2,dz^2)*cp_min/8.1; # Time step for the 3D Heat diffusion
    # In-place right-hand side: writes the stencil result into du; p[1] is the parameter lam.
    function f(du,u,p,t)
        @show t
        @parallel diffusion3D_step!(du, u, Ci, p[1], dx, dy, dz);
    end
    prob = ODEProblem(f, T, (0.0,nt*dt), lam)
    sol = solve(prob, alg, save_everystep = false, save_start = false, sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
end
sol = diffusion3D([1.0],ROCK2())
using ForwardDiff, Zygote, DiffEqSensitivity
function loss(p)
    sum(diffusion3D(p,ROCK2()))
end
ForwardDiff.gradient(loss,[1.0])
Zygote.gradient(loss,[1.0])
I'm getting:
TypeError: in Type, in parameter, expected Type, got a value of type DiffEqSensitivity.var"#109#124"{var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}}
Val(x::Function) at essentials.jl:693
autodiff at Enzyme.jl:60 [inlined]
_vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Array{Float64, 3}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, DiffEqSensitivity.var"#109#124"{var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Tuple{Array{Float64, 3}, Vector{Float64}, Array{Float64, 3}, Array{Float64, 3}}, Nothing, Nothing, Array{Float64, 3}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, EnzymeVJP, Bool}, Array{Float64, 3}, ODESolution{Float64, 4, Vector{Array{Float64, 3}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, ODEProblem{Array{Float64, 3}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ROCK2{Nothing}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Array{Float64, 3}}, Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, OrdinaryDiffEq.ROCK2Cache{Array{Float64, 3}, Array{Float64, 3}, Array{Float64, 3}, OrdinaryDiffEq.ROCK2ConstantCache{Float64, Float64, Array{Float64, 3}}}}, DiffEqBase.DEStats}, DiffEqSensitivity.CheckpointSolution{ODESolution{Float64, 4, Vector{Array{Float64, 3}}, Nothing, Nothing, 
Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, ODEProblem{Array{Float64, 3}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ROCK2{Nothing}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Array{Float64, 3}}, Vector{Float64}, Vector{Vector{Array{Float64, 3}}}, OrdinaryDiffEq.ROCK2Cache{Array{Float64, 3}, Array{Float64, 3}, Array{Float64, 3}, OrdinaryDiffEq.ROCK2ConstantCache{Float64, Float64, Array{Float64, 3}}}}, DiffEqBase.DEStats}, Vector{Tuple{Float64, Float64}}, NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}}, Nothing}, ODEProblem{Array{Float64, 3}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::EnzymeVJP, dgrad::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, 
dy::Nothing, W::Nothing) at derivative_wrappers.jl:471
vecjacobian!(dλ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, y::Array{Float64, 3}, λ::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, p::Vector{Float64}, t::Float64, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, DiffEqSensitivity.var"#109#124"{var"#f#17"{Array{Float64, 3}, Float64, Float64, Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Tuple{Array{Float64, 3}, Vector{Float64}, Array{Float64, 3}, Array{Float64, 3}}, Nothing, Nothing, Array{Float64, 3}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, Enzym...
Can Enzyme not handle captured variables or is that fixed on some branch?
That should be fixed nowadays. (Except that this doesn't work on the GPU).
What would be needed for it to run on GPU?
I couldn't get either method for adjoint generation working over the @parallel stencil. For pure Zygote-based VJP calculations, this will fail because of the mutation. I was wondering if you were planning to support adjoint rules on @parallel constructors for this. That would be required for GPU usage. Trying to avoid that issue, using ReverseDiffVJPs had an issue highlighted by the MWE: while forward-mode runs, it looks like the tracing required for adjoints causes issues with the threading constructs.