Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add lazy initialization to remake #881

Merged
merged 6 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 58 additions & 32 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,24 @@ end

function Base.showerror(io::IO, e::CheckInitFailureError)
print(io,
"DAE initialization failed: your u0 did not satisfy the initialization requirements,
normresid = $(e.normresid) > abstol = $(e.abstol)."
)
"""
DAE initialization failed: your u0 did not satisfy the initialization requirements, \
normresid = $(e.normresid) > abstol = $(e.abstol).
""")

if e.isdae
print(io, " If you wish for the system to
automatically change the algebraic variables to satisfy the algebraic constraints,
please pass `initializealg = BrownBasicInit()` to solve (this option will require
`using OrdinaryDiffEqNonlinearSolve`). If you wish to perform an initialization on the
complete u0, please pass `initializealg = ShampineCollocationInit()` to solve. Note that
initialization can be a very difficult process for DAEs and in many cases can be
numerically intractable without symbolic manipulation of the system. For an automated
system that will generate numerically stable initializations, see ModelingToolkit.jl
structural simplification for more details."
)
print(io,
"""
If you wish for the system to automatically change the algebraic variables to \
satisfy the algebraic constraints, please pass `initializealg = BrownBasicInit()` \
to solve (this option will require `using OrdinaryDiffEqNonlinearSolve`). If you \
wish to perform an initialization on the complete u0, please pass \
`initializealg = ShampineCollocationInit()` to `solve`. Note that initialization \
can be a very difficult process for DAEs and in many cases can be numerically \
intractable without symbolic manipulation of the system. For an automated \
system that will generate numerically stable initializations, see \
ModelingToolkit.jl structural simplification for more details.
""")
end
end

Expand Down Expand Up @@ -188,6 +191,9 @@ Keyword arguments:
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.

In case the initialization problem is trivial, `nlsolve_alg`, `abstol` and `reltol` are
not required.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
Expand All @@ -201,35 +207,55 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
if nlsolve_alg === nothing && state_values(initprob) !== nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
if is_trivial_initialization(initdata)
nlsol = initprob
success = true
else
throw(OverrideInitNoTolerance(:abstol))
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
if nlsolve_alg === nothing && state_values(initprob) !== nothing
throw(OverrideInitMissingAlgorithm())
end
if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
else
throw(OverrideInitNoTolerance(:reltol))
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)
success = SciMLBase.successful_retcode(nlsol)
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
else
throw(OverrideInitNoTolerance(:reltol))
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(valp, nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
return u0, p, success
end

is_trivial_initialization(::Nothing) = true

function is_trivial_initialization(initdata::OverrideInitData)
!(initdata.initializeprob isa NonlinearLeastSquaresProblem) &&
state_values(initdata.initializeprob) === nothing
end

function is_trivial_initialization(f::AbstractSciMLFunction)
has_initialization_data(f) && is_trivial_initialization(f.initialization_data)
end

function is_trivial_initialization(prob::AbstractSciMLProblem)
is_trivial_initialization(prob.f)
end
20 changes: 19 additions & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ function remake(prob::ODEProblem; f = missing,
interpret_symbolicmap = true,
build_initializeprob = true,
use_defaults = false,
lazy_initialization = nothing,
_kwargs...)
if tspan === missing
tspan = prob.tspan
Expand All @@ -123,6 +124,8 @@ function remake(prob::ODEProblem; f = missing,

iip = isinplace(prob)

initialization_data = prob.f.initialization_data

if f === missing
if build_initializeprob
initialization_data = remake_initialization_data_compat_wrapper(
Expand Down Expand Up @@ -170,13 +173,28 @@ function remake(prob::ODEProblem; f = missing,
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
end

if kwargs === missing
prob = if kwargs === missing
ODEProblem{isinplace(prob)}(
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
_kwargs...)
else
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
end

if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end

"""
Expand Down
9 changes: 9 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,12 @@ end
@test sccprob4.p !== sccprob4.probs[1].p
@test sccprob4.p !== sccprob4.probs[2].p
end

@testset "Lazy initialization" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p=missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t)
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
prob2 = remake(prob; u0 = [x => 2.0])
@test prob2.ps[p] ≈ 3.0
end
26 changes: 26 additions & 0 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,4 +244,30 @@ end
@test p ≈ 0.0
@test success
end

@testset "Trivial initialization" begin
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
end
initprobmap = function (nlsol)
u1 = parameter_values(nlsol)[1]
return [u1, u1]
end
initprobpmap = function (_, nlsol)
return 0.0
end
initialization_data = SciMLBase.OverrideInitData(
initprob, update_initializeprob!, initprobmap, initprobpmap)
fn = ODEFunction(rhs2; initialization_data)
prob = ODEProblem(fn, [2.0, 0.0], (0.0, 1.0), 0.0)
integ = init(prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(), Val(false)
)
@test u0 ≈ [2.0, 2.0]
@test p ≈ 0.0
@test success
end
end
Loading