diff --git a/src/problems/problem_utils.jl b/src/problems/problem_utils.jl index 3c10f5fdf..f2faaf20b 100644 --- a/src/problems/problem_utils.jl +++ b/src/problems/problem_utils.jl @@ -180,3 +180,6 @@ function Base.summary(io::IO, prob::AbstractPDEProblem) end Base.copy(p::SciMLBase.NullParameters) = p + +SymbolicIndexingInterface.is_time_dependent(::AbstractDEProblem) = true +SymbolicIndexingInterface.is_time_dependent(::AbstractNonlinearProblem) = false diff --git a/src/remake.jl b/src/remake.jl index dac9c3e12..73d7634ef 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -257,18 +257,9 @@ function remake(prob::ODEProblem; f = missing, ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...) end - if lazy_initialization === nothing - lazy_initialization = !is_trivial_initialization(initialization_data) - end - if initialization_data !== nothing && !lazy_initialization - u0, p, _ = get_initial_values( - prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - if u0 !== nothing && eltype(u0) == Any && isempty(u0) - u0 = nothing - end - @reset prob.u0 = u0 - @reset prob.p = p - end + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p return prob end @@ -453,18 +444,10 @@ function remake(prob::SDEProblem; else SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...) end - if lazy_initialization === nothing - lazy_initialization = !is_trivial_initialization(initialization_data) - end - if initialization_data !== nothing && !lazy_initialization - u0, p, _ = get_initial_values( - prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - if u0 !== nothing && eltype(u0) == Any && isempty(u0) - u0 = nothing - end - @reset prob.u0 = u0 - @reset prob.p = p - end + + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p return prob end @@ -520,18 +503,10 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, DDEProblem{iip}(f, newu0, h, tspan, newp; constant_lags, dependent_lags, order_discontinuity_t0, neutral, kwargs...) end - if lazy_initialization === nothing - lazy_initialization = !is_trivial_initialization(initialization_data) - end - if initialization_data !== nothing && !lazy_initialization - u0, p, _ = get_initial_values( - prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - if u0 !== nothing && eltype(u0) == Any && isempty(u0) - u0 = nothing - end - @reset prob.u0 = u0 - @reset prob.p = p - end + + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p return prob end @@ -619,18 +594,9 @@ function remake(prob::SDDEProblem; dependent_lags, order_discontinuity_t0, neutral, kwargs...) end - if lazy_initialization === nothing - lazy_initialization = !is_trivial_initialization(initialization_data) - end - if initialization_data !== nothing && !lazy_initialization - u0, p, _ = get_initial_values( - prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - if u0 !== nothing && eltype(u0) == Any && isempty(u0) - u0 = nothing - end - @reset prob.u0 = u0 - @reset prob.p = p - end + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p return prob end @@ -741,18 +707,9 @@ function remake(prob::NonlinearProblem; problem_type = problem_type; kwargs...) end - if lazy_initialization === nothing - lazy_initialization = !is_trivial_initialization(initialization_data) - end - if initialization_data !== nothing && !lazy_initialization - u0, p, _ = get_initial_values( - prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - if u0 !== nothing && eltype(u0) == Any && isempty(u0) - u0 = nothing - end - @reset prob.u0 = u0 - @reset prob.p = p - end + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p return prob end @@ -792,18 +749,9 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p f, u0 = newu0, p = newp, kwargs...) end - if lazy_initialization === nothing - lazy_initialization = !is_trivial_initialization(initialization_data) - end - if initialization_data !== nothing && !lazy_initialization - u0, p, _ = get_initial_values( - prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) - if u0 !== nothing && eltype(u0) == Any && isempty(u0) - u0 = nothing - end - @reset prob.u0 = u0 - @reset prob.p = p - end + u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization) + @reset prob.u0 = u0 + @reset prob.p = p return prob end @@ -1134,6 +1082,23 @@ function process_p_u0_symbolic(prob, p, u0) end end +function maybe_eager_initialize_problem(prob::AbstractSciMLProblem, initialization_data, lazy_initialization::Union{Nothing, Bool}) + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if initialization_data !== nothing && !lazy_initialization && (!is_time_dependent(prob) || current_time(prob) !== nothing) + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end + else + u0 = state_values(prob) + p = parameter_values(prob) + end + return u0, p +end + function remake(thing::AbstractJumpProblem; kwargs...) parameterless_type(thing)(remake(thing.prob; kwargs...)) end diff --git a/test/initialization.jl b/test/initialization.jl index 1bf5886f5..e52b3c80c 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -263,7 +263,10 @@ end @testset "Trivial initialization" begin initprob = NonlinearProblem(Returns(nothing), nothing, [1.0]) update_initializeprob! = function (iprob, integ) - iprob.p[1] = integ.u[1] + # just to access the current time and use it as a number, so this errors + # if run on a problem with `current_time(prob) === nothing` + iprob.p[1] = current_time(integ) + 1 + iprob.p[1] = state_values(integ)[1] end initprobmap = function (nlsol) u1 = parameter_values(nlsol)[1] @@ -284,6 +287,11 @@ end @test u0 ≈ [2.0, 2.0] @test p ≈ 0.0 @test success + + @testset "Doesn't run in `remake` if `tspan == (nothing, nothing)`" begin + prob = ODEProblem(fn, [2.0, 0.0], (nothing, nothing), 0.0) + @test_nowarn remake(prob) + end end end diff --git a/test/remake_tests.jl b/test/remake_tests.jl index 3fdf797ff..307c4d871 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -15,6 +15,7 @@ u0 = [1.0; 2.0; 3.0] tspan = (0.0, 100.0) p = [10.0, 20.0, 30.0] sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +indep_sys = SymbolCache([:x, :y, :z], [:a, :b, :c]) fn = ODEFunction(lorenz!; sys) for T in containerTypes push!(probs, ODEProblem(fn, u0, tspan, T(p))) @@ -64,7 +65,7 @@ function loss(x, p) return sum(du) end -fn = OptimizationFunction(loss; sys) +fn = OptimizationFunction(loss; sys = indep_sys) for T in containerTypes push!(probs, OptimizationProblem(fn, u0, T(p))) end @@ -73,7 +74,7 @@ function nllorenz!(du, u, p) lorenz!(du, u, p, 0.0) end -fn = NonlinearFunction(nllorenz!; sys) +fn = NonlinearFunction(nllorenz!; sys = indep_sys) for T in containerTypes push!(probs, NonlinearProblem(fn, u0, T(p))) end