diff --git a/src/initialization.jl b/src/initialization.jl index 7ec14b12d..8b45bb6a6 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -60,21 +60,24 @@ end function Base.showerror(io::IO, e::CheckInitFailureError) print(io, - "DAE initialization failed: your u0 did not satisfy the initialization requirements, - normresid = $(e.normresid) > abstol = $(e.abstol)." - ) + """ + DAE initialization failed: your u0 did not satisfy the initialization requirements, \ + normresid = $(e.normresid) > abstol = $(e.abstol). + """) if e.isdae - print(io, " If you wish for the system to - automatically change the algebraic variables to satisfy the algebraic constraints, - please pass `initializealg = BrownBasicInit()` to solve (this option will require - `using OrdinaryDiffEqNonlinearSolve`). If you wish to perform an initialization on the - complete u0, please pass `initializealg = ShampineCollocationInit()` to solve. Note that - initialization can be a very difficult process for DAEs and in many cases can be - numerically intractable without symbolic manipulation of the system. For an automated - system that will generate numerically stable initializations, see ModelingToolkit.jl - structural simplification for more details." - ) + print(io, + """ + If you wish for the system to automatically change the algebraic variables to \ + satisfy the algebraic constraints, please pass `initializealg = BrownBasicInit()` \ + to solve (this option will require `using OrdinaryDiffEqNonlinearSolve`). If you \ + wish to perform an initialization on the complete u0, please pass \ + `initializealg = ShampineCollocationInit()` to `solve`. Note that initialization \ + can be a very difficult process for DAEs and in many cases can be numerically \ + intractable without symbolic manipulation of the system. For an automated \ + system that will generate numerically stable initializations, see \ + ModelingToolkit.jl structural simplification for more details. + """) end end @@ -188,6 +191,9 @@ Keyword arguments: provided to the `OverrideInit` constructor takes priority over this keyword argument. If the former is `nothing`, this keyword argument will be used. If it is also not provided, an error will be thrown. + +In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are +not required. """ function get_initial_values(prob, valp, f, alg::OverrideInit, iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...) @@ -201,35 +207,55 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata::OverrideInitData = f.initialization_data initprob = initdata.initializeprob - nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) - if nlsolve_alg === nothing && state_values(initprob) !== nothing - throw(OverrideInitMissingAlgorithm()) - end - if initdata.update_initializeprob! !== nothing initdata.update_initializeprob!(initprob, valp) end - if alg.abstol !== nothing - _abstol = alg.abstol - elseif abstol !== nothing - _abstol = abstol + if is_trivial_initialization(initdata) + nlsol = initprob + success = true else - throw(OverrideInitNoTolerance(:abstol)) + nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) + if nlsolve_alg === nothing && state_values(initprob) !== nothing + throw(OverrideInitMissingAlgorithm()) + end + if alg.abstol !== nothing + _abstol = alg.abstol + elseif abstol !== nothing + _abstol = abstol + else + throw(OverrideInitNoTolerance(:abstol)) + end + if alg.reltol !== nothing + _reltol = alg.reltol + elseif reltol !== nothing + _reltol = reltol + else + throw(OverrideInitNoTolerance(:reltol)) + end + nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) + success = SciMLBase.successful_retcode(nlsol) end - if alg.reltol !== nothing - _reltol = alg.reltol - elseif reltol !== nothing - _reltol = reltol - else - throw(OverrideInitNoTolerance(:reltol)) - end - nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) u0 = initdata.initializeprobmap(nlsol) if initdata.initializeprobpmap !== nothing p = initdata.initializeprobpmap(valp, nlsol) end - return u0, p, SciMLBase.successful_retcode(nlsol) + return u0, p, success +end + +is_trivial_initialization(::Nothing) = true + +function is_trivial_initialization(initdata::OverrideInitData) + !(initdata.initializeprob isa NonlinearLeastSquaresProblem) && + state_values(initdata.initializeprob) === nothing +end + +function is_trivial_initialization(f::AbstractSciMLFunction) + has_initialization_data(f) && is_trivial_initialization(f.initialization_data) +end + +function is_trivial_initialization(prob::AbstractSciMLProblem) + is_trivial_initialization(prob.f) end diff --git a/src/remake.jl b/src/remake.jl index 026a9d2a4..b16476473 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -114,6 +114,7 @@ function remake(prob::ODEProblem; f = missing, interpret_symbolicmap = true, build_initializeprob = true, use_defaults = false, + lazy_initialization = nothing, _kwargs...) if tspan === missing tspan = prob.tspan @@ -123,6 +124,8 @@ function remake(prob::ODEProblem; f = missing, iip = isinplace(prob) + initialization_data = prob.f.initialization_data + if f === missing if build_initializeprob initialization_data = remake_initialization_data_compat_wrapper( @@ -170,13 +173,28 @@ function remake(prob::ODEProblem; f = missing, _f = ODEFunction{isinplace(prob), specialization(prob.f)}(f) end - if kwargs === missing + prob = if kwargs === missing ODEProblem{isinplace(prob)}( _f, newu0, tspan, newp, prob.problem_type; prob.kwargs..., _kwargs...) else ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...) end + + if lazy_initialization === nothing + lazy_initialization = !is_trivial_initialization(initialization_data) + end + if !lazy_initialization + u0, p, _ = get_initial_values( + prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + if u0 !== nothing && eltype(u0) == Any && isempty(u0) + u0 = nothing + end + @reset prob.u0 = u0 + @reset prob.p = p + end + + return prob end """ diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index ef0bca75f..7b69e7746 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -336,3 +336,12 @@ end @test sccprob4.p !== sccprob4.probs[1].p @test sccprob4.p !== sccprob4.probs[2].p end + +@testset "Lazy initialization" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p=missing [guess = 1.0] + @mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t) + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + prob2 = remake(prob; u0 = [x => 2.0]) + @test prob2.ps[p] ≈ 3.0 +end diff --git a/test/initialization.jl b/test/initialization.jl index 1ea1da694..7d5dfb01d 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -244,4 +244,30 @@ end @test p ≈ 0.0 @test success end + + @testset "Trivial initialization" begin + initprob = NonlinearProblem(Returns(nothing), nothing, [1.0]) + update_initializeprob! = function (iprob, integ) + iprob.p[1] = integ.u[1] + end + initprobmap = function (nlsol) + u1 = parameter_values(nlsol)[1] + return [u1, u1] + end + initprobpmap = function (_, nlsol) + return 0.0 + end + initialization_data = SciMLBase.OverrideInitData( + initprob, update_initializeprob!, initprobmap, initprobpmap) + fn = ODEFunction(rhs2; initialization_data) + prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0) + integ = init(prob; initializealg = NoInit()) + + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), Val(false) + ) + @test u0 ≈ [2.0, 2.0] + @test p ≈ 0.0 + @test success + end end