Skip to content

Commit

Permalink
A complete and proper preconditioner interface
Browse files Browse the repository at this point in the history
Basically solves #1551 sans any extra performance issues. The interface is simply that the user gives:

```julia
Pr, Pl = precs(W,du,u,p,t,newW,Plprev,Prprev,solverdata)
```

in the associated algorithm definitions. This gives a slight generalization over the Sundials interface so it should be enough, sans the `jcur` thing that is still an ongoing question. The setup phase is just a different dispatch on this function, i.e.

```julia
Pr, Pl = precs(W,du,u,p,t,::Nothing,::Nothing,::Nothing,solverdata)
```

which is rather clean. `solverdata` is for backwards compatibility: it's going to not be documented for a bit but allow for slapping things like `gamma` or `dt` in there, and adding to the struct won't be breaking like changing the call signature would be.

This PR is currently untested and only implements it for one algorithm, and so getting this to merge will essentially just require adding proper tests of using preconditioners which require updated Jacobians, like incomplete LU-factorizations, showing that it improves the convergence of GMRES. And setting up the rest of the algorithms to have this preconditioner interface as well, which is the simple but tedious part.
  • Loading branch information
ChrisRackauckas committed Jan 7, 2022
1 parent d870734 commit fc64c07
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ using DocStringExtensions

struct OrdinaryDiffEqTag end

DEFAULT_PRECS(W,du,u,p,t,newW,Plprev,Prprev,solverdata) = nothing,nothing

include("misc_utils.jl")
include("algorithms.jl")
include("alg_utils.jl")
Expand Down
5 changes: 3 additions & 2 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3638,10 +3638,11 @@ _unwrap_val(B) = B

for Alg in [:Rosenbrock23, :Rosenbrock32, :ROS3P, :Rodas3, :ROS34PW1a, :ROS34PW1b, :ROS34PW2, :ROS34PW3, :RosShamp4, :Veldd4, :Velds4, :GRK4T, :GRK4A, :Ros4LStab, :Rodas4, :Rodas42, :Rodas4P, :Rodas4P2, :Rodas5]
@eval begin
struct $Alg{CS,AD,F,FDT,ST} <: OrdinaryDiffEqRosenbrockAdaptiveAlgorithm{CS,AD,FDT,ST}
struct $Alg{CS,AD,F,P,FDT,ST} <: OrdinaryDiffEqRosenbrockAdaptiveAlgorithm{CS,AD,FDT,ST}
linsolve::F
precs::P
end
$Alg(;chunk_size=Val{0}(),autodiff=Val{true}(), standardtag = Val{true}(),diff_type=Val{:forward},linsolve=nothing) = $Alg{_unwrap_val(chunk_size),_unwrap_val(autodiff),typeof(linsolve),diff_type,_unwrap_val(standardtag)}(linsolve)
$Alg(;chunk_size=Val{0}(),autodiff=Val{true}(), standardtag = Val{true}(),diff_type=Val{:forward},linsolve=nothing,precs = DEFAULT_PRECS) = $Alg{_unwrap_val(chunk_size),_unwrap_val(autodiff),typeof(linsolve),typeof(precs),diff_type,_unwrap_val(standardtag)}(linsolve,precs)
end
end

Expand Down
5 changes: 3 additions & 2 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,10 @@ function alg_cache(alg::Rosenbrock32,u,rate_prototype,::Type{uEltypeNoUnits},::T
uf = UJacobianWrapper(f,t,p)
linsolve_tmp = zero(rate_prototype)
linprob = LinearProblem(W,_vec(linsolve_tmp); u0=_vec(tmp))

Pl,Pr = wrapprecs(integrator.alg.precs(W,nothing,u,p,t,nothing,nothing,nothing,nothing)...,weight)
linsolve = init(linprob,alg.linsolve,alias_A=true,alias_b=true,
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),
Pr = Diagonal(_vec(weight)))
Pl = Pl, Pr = Pr)
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2,Val(false))
Rosenbrock32Cache(u,uprev,k₁,k₂,k₃,du1,du2,f₁,fsalfirst,fsallast,dT,J,W,tmp,atmp,weight,tab,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
Expand Down
34 changes: 29 additions & 5 deletions src/misc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,41 @@ macro threaded(option, ex)
end
end

function dolinsolve(integrator, linsolve; A = nothing, u = nothing, b = nothing,
Pl = nothing, Pr = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
du = nothing, u = nothing, p = nothing, t = nothing, solverdata = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)

A !== nothing && (linsolve = LinearSolve.set_A(linsolve,A))
b !== nothing && (linsolve = LinearSolve.set_b(linsolve,b))
u !== nothing && (linsolve = LinearSolve.set_u(linsolve,u))
(Pl !== nothing || Pr !== nothing) && (linsolve = LinearSolve.set_prec(Pl,Pr))
u !== nothing && (linsolve = LinearSolve.set_u(linsolve,linu))

Plprev = linsolve.Pl isa ComposePreconditioner ? linsolve.Pl.outer : linsolve.Pl
Prprev = linsolve.Pr isa ComposePreconditioner ? linsolve.Pr.outer : linsolve.Pr

_Pl,_Pr = integrator.alg.precs(linsolve.A,du,u,p,t,A !== nothing,Plprev,Prprev,solverdata)
if _Pl !== nothing || _Pr !== nothing
Pl, Pr = wrapprecs(_Pl,_Pr,weight)
linsolve = LinearSolve.set_prec(Pl,Pr)
end

linres = if reltol === nothing
solve(linsolve;reltol)
else
solve(linsolve;reltol)
end
end

function wrapprecs(_Pl,_Pr,weight)
if _Pl !== nothing
Pl = LinearSolve.ComposePreconditioner(LinearSolve.InvPreconditioner(Diagonal(_vec(weight))),_Pl)
else
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
end

if _Pr !== nothing
Pr = LinearSolve.ComposePreconditioner(Diagonal(_vec(weight)),_Pr)
else
Pr = Diagonal(_vec(weight))
end
Pl, Pr
end
4 changes: 3 additions & 1 deletion src/perform_step/rosenbrock_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ end
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t)

linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))
linres = dolinsolve(integrator, cache.linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp),
du = fsalfirst, u = u, p = p, t = t, solverdata = (;γ = γ))
vecu = _vec(linres.u)
veck₁ = _vec(k₁)

Expand Down Expand Up @@ -117,6 +118,7 @@ end
calculate_residuals!(weight, fill!(weight, one(eltype(u))), uprev, uprev,
integrator.opts.abstol, integrator.opts.reltol, integrator.opts.internalnorm, t)
linsolve = cache.linsolve
Pl,Pr = alg.precs(W,du,u,p,t,newW,solverdata)
linres = dolinsolve(integrator, linsolve; A = repeat_step ? nothing : W, b = _vec(linsolve_tmp))

@inbounds @simd ivdep for i in eachindex(u)
Expand Down

0 comments on commit fc64c07

Please sign in to comment.