From ac6eabb32c736d58a5dd4d4d1836a3c2aeb210f6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 31 Jan 2025 17:56:07 +0530 Subject: [PATCH] fix: fix `remake(::HomotopyNonlinearFunction)` --- src/remake.jl | 22 ++++++++++++---------- src/scimlfunctions.jl | 3 ++- test/remake_tests.jl | 20 ++++++++++++++++++++ 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/remake.jl b/src/remake.jl index 4c8f20a97..4a0f1726b 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -177,17 +177,19 @@ function remake( if !(f isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)}) f = split_function_f_wrapper(T){iip, spec}(f) end - # For SplitFunction - # we don't do the same thing as `g`, because for SDEs `g` is - # stored in the problem as well, whereas for Split ODEs etc - # f2 is a part of the function. Thus, if the user provides - # a SciMLFunction for `f` which contains `f2` we use that. - f2 = coalesce(f2, get(props, :f2, missing), func.f2) - if !(f2 isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)}) - f2 = split_function_f_wrapper(T){iip, spec}(f2) + if hasproperty(func, :f2) + # For SplitFunction + # we don't do the same thing as `g`, because for SDEs `g` is + # stored in the problem as well, whereas for Split ODEs etc + # f2 is a part of the function. Thus, if the user provides + # a SciMLFunction for `f` which contains `f2` we use that. + f2 = coalesce(f2, get(props, :f2, missing), func.f2) + if !(f2 isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)}) + f2 = split_function_f_wrapper(T){iip, spec}(f2) + end + props = @delete props.f2 + args = (args..., f2) end - props = @delete props.f2 - args = (args..., f2) end if isdefined(func, :g) # For SDEs/SDDEs where `g` is not a keyword diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index 7b247b5bc..222754b3d 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4757,7 +4757,7 @@ is_split_function(x) = is_split_function(typeof(x)) is_split_function(::Type) = false function is_split_function(::Type{T}) where {T <: Union{ SplitFunction, SplitSDEFunction, DynamicalODEFunction, - DynamicalDDEFunction, DynamicalSDEFunction}} + DynamicalDDEFunction, DynamicalSDEFunction, HomotopyNonlinearFunction}} true end @@ -4766,6 +4766,7 @@ split_function_f_wrapper(::Type{<:SplitSDEFunction}) = SDEFunction split_function_f_wrapper(::Type{<:DynamicalODEFunction}) = ODEFunction split_function_f_wrapper(::Type{<:DynamicalDDEFunction}) = DDEFunction split_function_f_wrapper(::Type{<:DynamicalSDEFunction}) = DDEFunction +split_function_f_wrapper(::Type{<:HomotopyNonlinearFunction}) = NonlinearFunction ######### Additional traits diff --git a/test/remake_tests.jl b/test/remake_tests.jl index 307c4d871..4af29623e 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -384,3 +384,23 @@ end prob2 = remake(prob; f = fn2) @test prob2.f.resid_prototype isa Vector{Float32} end + +@testset "`remake(::HomotopyNonlinearFunction)`" begin + f! = function (du, u, p) + du[1] = u[1] * u[1] - p[1] * u[2] + u[2]^3 + 1 + du[2] = u[2]^3 + 2 * p[2] * u[1] * u[2] + u[2] + end + + fjac! = function (j, u, p) + j[1, 1] = 2u[1] + j[1, 2] = -p[1] + 3 * u[2]^2 + j[2, 1] = 2 * p[2] * u[2] + j[2, 2] = 3 * u[2]^2 + 2 * p[2] * u[1] + 1 + end + fn = NonlinearFunction(f!; jac = fjac!) + fn = HomotopyNonlinearFunction(fn) + prob = NonlinearProblem(fn, ones(2), ones(2)) + @test prob.f.f.jac == fjac! + prob2 = remake(prob; u0 = zeros(2)) + @test prob2.f.f.jac == fjac! +end