Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Add the jacobian free version
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 3, 2024
1 parent b6f5491 commit d191d1b
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ using Zygote: Zygote

include("steadystateadjoint.jl")

end
end
42 changes: 13 additions & 29 deletions ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ function SteadyStateAdjointProblem(
the derivative of the solution with respect to the parameters. Your model \
must have parameters to use parameter sensitivity calculations!")

# sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp,
# f, f.colorvec, false) # Dont allocate the Jacobian yet in diffcache
# @show sense.vjp
y = sol.u

if needs_jac
Expand All @@ -57,32 +54,28 @@ function SteadyStateAdjointProblem(
if dgdu !== nothing
dgdu(dgdu_val, y, p, nothing, nothing)
else
# TODO: Implement this part
error("Not implemented yet")
# if g !== nothing
# if dgdp_val !== nothing
# gradient!(vec(dgdu_val), diffcache.g[1], y, sensealg,
# diffcache.g_grad_config[1])
# else
# gradient!(vec(dgdu_val), diffcache.g, y, sensealg, diffcache.g_grad_config)
# end
# end
end

if !needs_jac # Construct an operator and use Jacobian-Free Linear Solve
error("Todo Jacobian Free Linear Solve")
# usize = size(y)
# __f = y -> vec(f(reshape(y, usize), p, nothing))
# operator = VecJac(__f, vec(y);
# autodiff = get_autodiff_from_vjp(sensealg.autojacvec))
# linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ))
# solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ)
linsolve = if sensealg.linsolve === nothing
LinearSolve.SimpleGMRES(; blocksize=size(u0, 1))
else
sensealg.linsolve
end
usize = size(y)
__f = @closure y -> vec(f(reshape(y, usize), p, nothing))
operator = SciMLSensitivity.VecJac(__f, vec(y);
autodiff=SciMLSensitivity.get_autodiff_from_vjp(sensealg.autojacvec))
linear_problem = SciMLBase.LinearProblem(operator, dgdu_val)
linsol = SciMLBase.solve(
linear_problem, linsolve; alias_A=true, sensealg.linsolve_kwargs...)
else
linear_problem = SciMLBase.LinearProblem(J', dgdu_val)
linsol = SciMLBase.solve(
linear_problem, sensealg.linsolve; alias_A=true, sensealg.linsolve_kwargs...)
λ = linsol.u
end
λ = linsol.u

_, pb_f = Zygote.pullback(@closure(p->vec(f(y, p, nothing))), p)
∂p = only(pb_f(λ))
Expand All @@ -92,15 +85,6 @@ function SteadyStateAdjointProblem(

if g !== nothing || dgdp !== nothing
error("Not implemented yet")
# compute del g/del p
# if dgdp !== nothing
# dgdp(dgdp_val, y, p, nothing, nothing)
# else
# @unpack g_grad_config = diffcache
# gradient!(dgdp_val, diffcache.g[2], p, sensealg, g_grad_config[2])
# end
# recursive_sub!(dgdp_val, vjp)
# return dgdp_val
else
SciMLSensitivity.recursive_neg!(∂p)
return ∂p
Expand Down
15 changes: 11 additions & 4 deletions test/nlsolve_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin
using Chairmarks, ForwardDiff, SciMLBase, SciMLSensitivity, SimpleNonlinearSolve,
Statistics, Zygote
using Chairmarks, ForwardDiff, LinearSolve, SciMLBase, SciMLSensitivity,
SimpleNonlinearSolve, Statistics, Zygote

testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p

Expand Down Expand Up @@ -55,8 +55,15 @@
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u)
end)

@test ∂p3 ∂p4
@test ∂p1 ∂p4
∂p5 = only(Zygote.gradient(p) do p
prob = NonlinearProblem(testing_f, u0, p)
sensealg = SteadyStateAdjoint(; linsolve=KrylovJL_GMRES())
return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson(); sensealg).u)
end)

@test ∂p1∂p3 atol=1e-5
@test ∂p3∂p4 atol=1e-5
@test ∂p4∂p5 atol=1e-5

zygote_nlsolve_timing = @be Zygote.gradient($p) do p
prob = NonlinearProblem(testing_f, u0, p)
Expand Down

0 comments on commit d191d1b

Please sign in to comment.