diff --git a/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl b/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl index 3919062..95ace1d 100644 --- a/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl +++ b/ext/BatchedRoutinesSciMLSensitivityExt/BatchedRoutinesSciMLSensitivityExt.jl @@ -10,4 +10,4 @@ using Zygote: Zygote include("steadystateadjoint.jl") -end \ No newline at end of file +end diff --git a/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl b/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl index a97f1ae..6df0ef9 100644 --- a/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl +++ b/ext/BatchedRoutinesSciMLSensitivityExt/steadystateadjoint.jl @@ -29,9 +29,6 @@ function SteadyStateAdjointProblem( the derivative of the solution with respect to the parameters. Your model \ must have parameters to use parameter sensitivity calculations!") - # sense = SteadyStateAdjointSensitivityFunction(g, sensealg, alg, sol, dgdu, dgdp, - # f, f.colorvec, false) # Dont allocate the Jacobian yet in diffcache - # @show sense.vjp y = sol.u if needs_jac @@ -57,32 +54,28 @@ function SteadyStateAdjointProblem( if dgdu !== nothing dgdu(dgdu_val, y, p, nothing, nothing) else - # TODO: Implement this part error("Not implemented yet") - # if g !== nothing - # if dgdp_val !== nothing - # gradient!(vec(dgdu_val), diffcache.g[1], y, sensealg, - # diffcache.g_grad_config[1]) - # else - # gradient!(vec(dgdu_val), diffcache.g, y, sensealg, diffcache.g_grad_config) - # end - # end end if !needs_jac # Construct an operator and use Jacobian-Free Linear Solve - error("Todo Jacobian Free Linear Solve") - # usize = size(y) - # __f = y -> vec(f(reshape(y, usize), p, nothing)) - # operator = VecJac(__f, vec(y); - # autodiff = get_autodiff_from_vjp(sensealg.autojacvec)) - # linear_problem = LinearProblem(operator, vec(dgdu_val); u0 = vec(λ)) - # solve(linear_problem, linsolve; alias_A = true, sensealg.linsolve_kwargs...) # u is vec(λ) + linsolve = if sensealg.linsolve === nothing + LinearSolve.SimpleGMRES(; blocksize=size(u0, 1)) + else + sensealg.linsolve + end + usize = size(y) + __f = @closure y -> vec(f(reshape(y, usize), p, nothing)) + operator = SciMLSensitivity.VecJac(__f, vec(y); + autodiff=SciMLSensitivity.get_autodiff_from_vjp(sensealg.autojacvec)) + linear_problem = SciMLBase.LinearProblem(operator, dgdu_val) + linsol = SciMLBase.solve( + linear_problem, linsolve; alias_A=true, sensealg.linsolve_kwargs...) else linear_problem = SciMLBase.LinearProblem(J', dgdu_val) linsol = SciMLBase.solve( linear_problem, sensealg.linsolve; alias_A=true, sensealg.linsolve_kwargs...) - λ = linsol.u end + λ = linsol.u _, pb_f = Zygote.pullback(@closure(p->vec(f(y, p, nothing))), p) ∂p = only(pb_f(λ)) @@ -92,15 +85,6 @@ function SteadyStateAdjointProblem( if g !== nothing || dgdp !== nothing error("Not implemented yet") - # compute del g/del p - # if dgdp !== nothing - # dgdp(dgdp_val, y, p, nothing, nothing) - # else - # @unpack g_grad_config = diffcache - # gradient!(dgdp_val, diffcache.g[2], p, sensealg, g_grad_config[2]) - # end - # recursive_sub!(dgdp_val, vjp) - # return dgdp_val else SciMLSensitivity.recursive_neg!(∂p) return ∂p diff --git a/test/nlsolve_tests.jl b/test/nlsolve_tests.jl index fd9eaf1..110ab61 100644 --- a/test/nlsolve_tests.jl +++ b/test/nlsolve_tests.jl @@ -1,6 +1,6 @@ @testitem "Batched Nonlinear Solvers" setup=[SharedTestSetup] begin - using Chairmarks, ForwardDiff, SciMLBase, SciMLSensitivity, SimpleNonlinearSolve, - Statistics, Zygote + using Chairmarks, ForwardDiff, LinearSolve, SciMLBase, SciMLSensitivity, + SimpleNonlinearSolve, Statistics, Zygote testing_f(u, p) = u .^ 2 .+ u .^ 3 .- u .- p @@ -55,8 +55,15 @@ return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson()).u) end) - @test ∂p3 ≈ ∂p4 - @test ∂p1 ≈ ∂p4 + ∂p5 = only(Zygote.gradient(p) do p + prob = NonlinearProblem(testing_f, u0, p) + sensealg = SteadyStateAdjoint(; linsolve=KrylovJL_GMRES()) + return sum(abs2, solve(prob, BatchedSimpleNewtonRaphson(); sensealg).u) + end) + + @test ∂p1≈∂p3 atol=1e-5 + @test ∂p3≈∂p4 atol=1e-5 + @test ∂p4≈∂p5 atol=1e-5 zygote_nlsolve_timing = @be Zygote.gradient($p) do p prob = NonlinearProblem(testing_f, u0, p)