Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialization on non-DAE models #2512

Merged
merged 27 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
66c29ef
WIP: Initialization on non-DAE models
ChrisRackauckas Feb 29, 2024
e61a0f9
use new ODE solver changes
ChrisRackauckas Feb 29, 2024
fbd3e1f
format
ChrisRackauckas Feb 29, 2024
a905587
Handle empty u0map
ChrisRackauckas Feb 29, 2024
186302f
late binding initialization_eqs
ChrisRackauckas Feb 29, 2024
fad548f
make sure u0map isn't a vector of numbers
ChrisRackauckas Feb 29, 2024
82cb87c
format
ChrisRackauckas Feb 29, 2024
7fb84b3
fix a few tests
ChrisRackauckas Feb 29, 2024
e7511d7
format
ChrisRackauckas Feb 29, 2024
d08fefe
fix condition
ChrisRackauckas Feb 29, 2024
59eae13
don't initialize SDEs
ChrisRackauckas Feb 29, 2024
3938441
fix a typo from tests
ChrisRackauckas Feb 29, 2024
d20095a
handle static arrays
ChrisRackauckas Feb 29, 2024
752f9e3
format
ChrisRackauckas Feb 29, 2024
6d5b6dd
Fix up a few tests / examples
ChrisRackauckas Feb 29, 2024
11d096f
format
ChrisRackauckas Feb 29, 2024
6542bba
Handle dummy derivative u0's and throw custom incomplete init error
ChrisRackauckas Feb 29, 2024
746386c
format
ChrisRackauckas Feb 29, 2024
7002310
don't filter empty u0maps
ChrisRackauckas Feb 29, 2024
4e5e723
handle arrays
ChrisRackauckas Feb 29, 2024
2147dc6
Handle the scalar u0map case
ChrisRackauckas Mar 1, 2024
674b8c4
fix filter
ChrisRackauckas Mar 1, 2024
1274908
fix components test
ChrisRackauckas Mar 1, 2024
49d8119
Handle steady state initializations
ChrisRackauckas Mar 1, 2024
96bfc35
format
ChrisRackauckas Mar 1, 2024
8f2c780
Fix some odd test choices
ChrisRackauckas Mar 1, 2024
a477f43
fix a few last tests
ChrisRackauckas Mar 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Libdl = "1"
LinearAlgebra = "1"
MLStyle = "0.4.17"
NaNMath = "0.3, 1"
OrdinaryDiffEq = "6.72.0"
OrdinaryDiffEq = "6.73.0"
PrecompileTools = "1"
RecursiveArrayTools = "2.3, 3"
Reexport = "0.2, 1"
Expand Down
6 changes: 3 additions & 3 deletions examples/electrical_components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkit: t_nounits as t, D_nounits as D

@connector function Pin(; name)
sts = @variables v(t)=1.0 i(t)=1.0 [connect = Flow]
sts = @variables v(t) [guess = 1.0] i(t) [guess = 1.0, connect = Flow]
ODESystem(Equation[], t, sts, []; name = name)
end

Expand All @@ -16,7 +16,7 @@ end
@component function OnePort(; name)
@named p = Pin()
@named n = Pin()
sts = @variables v(t)=1.0 i(t)=1.0
sts = @variables v(t) [guess = 1.0] i(t) [guess = 1.0]
eqs = [v ~ p.v - n.v
0 ~ p.i + n.i
i ~ p.i]
Expand Down Expand Up @@ -64,7 +64,7 @@ end
end

@connector function HeatPort(; name)
@variables T(t)=293.15 Q_flow(t)=0.0 [connect = Flow]
@variables T(t) [guess = 293.15] Q_flow(t) [guess = 0.0, connect = Flow]
ODESystem(Equation[], t, [T, Q_flow], [], name = name)
end

Expand Down
1 change: 1 addition & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ for prop in [:eqs
:preface
:torn_matching
:initializesystem
:initialization_eqs
:schedule
:tearing_state
:substitutions
Expand Down
35 changes: 34 additions & 1 deletion src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -862,14 +862,22 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
ps = full_parameters(sys)
iv = get_iv(sys)

# TODO: Pass already computed information to varmap_to_vars call
# in process_u0? That would just be a small optimization
varmap = u0map === nothing || isempty(u0map) || eltype(u0map) <: Number ?
defaults(sys) :
merge(defaults(sys), todict(u0map))
varlist = collect(map(unwrap, dvs))
missingvars = setdiff(varlist, collect(keys(varmap)))

# Append zeros to the variables which are determined by the initialization system
# This essentially bypasses the check for if initial conditions are defined for DAEs
# since they will be checked in the initialization problem's construction
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
ci = infer_clocks!(ClockInference(TearingState(sys)))
# TODO: make it work with clocks
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if (implicit_dae || calculate_massmatrix(sys) !== I) &&
if sys isa ODESystem && (implicit_dae || !isempty(missingvars)) &&
all(isequal(Continuous()), ci.var_domain) &&
ModelingToolkit.get_tearing_state(sys) !== nothing
if eltype(u0map) <: Number
Expand All @@ -881,6 +889,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;

zerovars = setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0
trueinit = identity.([zerovars; u0map])
u0map isa StaticArraysCore.StaticArray &&
(trueinit = SVector{length(trueinit)}(trueinit))
else
initializeprob = nothing
initializeprobmap = nothing
Expand Down Expand Up @@ -1530,6 +1540,21 @@ function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

const INCOMPLETE_INITIALIZATION_MESSAGE = """
Initialization incomplete. Not all of the state variables of the
DAE system can be determined by the initialization. Missing
variables:
"""

struct IncompleteInitializationError <: Exception
uninit::Any
end

function Base.showerror(io::IO, e::IncompleteInitializationError)
println(io, INCOMPLETE_INITIALIZATION_MESSAGE)
println(io, e.uninit)
end

function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
t::Number, u0map = [],
parammap = DiffEqBase.NullParameters();
Expand All @@ -1550,6 +1575,14 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
generate_initializesystem(sys; u0map); fully_determined = false)
end

uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])

# TODO: throw on uninitialized arrays
filter!(x -> !(x isa Symbolics.Arr), uninit)
if !isempty(uninit)
throw(IncompleteInitializationError(uninit))
end

neqs = length(equations(isys))
nunknown = length(unknowns(isys))

Expand Down
15 changes: 11 additions & 4 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ struct ODESystem <: AbstractODESystem
"""
initializesystem::Union{Nothing, NonlinearSystem}
"""
Extra equations to be enforced during the initialization sequence.
"""
initialization_eqs::Vector{Equation}
"""
The schedule for the code generation process.
"""
schedule::Any
Expand Down Expand Up @@ -171,7 +175,8 @@ struct ODESystem <: AbstractODESystem

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses,
torn_matching, initializesystem, schedule, connector_type, preface, cevents,
torn_matching, initializesystem, initialization_eqs, schedule,
connector_type, preface, cevents,
devents, parameter_dependencies,
metadata = nothing, gui_metadata = nothing,
tearing_state = nothing,
Expand All @@ -190,8 +195,8 @@ struct ODESystem <: AbstractODESystem
end
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, torn_matching,
initializesystem, schedule, connector_type, preface, cevents, devents, parameter_dependencies,
metadata,
initializesystem, initialization_eqs, schedule, connector_type, preface,
cevents, devents, parameter_dependencies, metadata,
gui_metadata, tearing_state, substitutions, complete, index_cache,
discrete_subsystems, solved_unknowns, split_idxs, parent)
end
Expand All @@ -208,6 +213,7 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
defaults = _merge(Dict(default_u0), Dict(default_p)),
guesses = Dict(),
initializesystem = nothing,
initialization_eqs = Equation[],
schedule = nothing,
connector_type = nothing,
preface = nothing,
Expand Down Expand Up @@ -260,7 +266,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, guesses, nothing, initializesystem,
schedule, connector_type, preface, cont_callbacks, disc_callbacks, parameter_dependencies,
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
disc_callbacks, parameter_dependencies,
metadata, gui_metadata, checks = checks)
end

Expand Down
42 changes: 36 additions & 6 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,51 @@ function generate_initializesystem(sys::ODESystem;
# Start the equations list with algebraic equations
eqs_ics = eqs[idxs_alge]
u0 = Vector{Pair}(undef, 0)
defs = merge(defaults(sys), todict(u0map))

full_states = [sts; getfield.((observed(sys)), :lhs)]
eqs_diff = eqs[idxs_diff]
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))

full_states = unique([sts; getfield.((observed(sys)), :lhs)])
set_full_states = Set(full_states)
guesses = todict(guesses)
schedule = getfield(sys, :schedule)

dd_guess = if schedule !== nothing
if schedule !== nothing
guessmap = [x[2] => get(guesses, x[1], default_dd_value)
for x in schedule.dummy_sub]
Dict(filter(x -> !isnothing(x[1]), guessmap))
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
if u0map === nothing || isempty(u0map)
filtered_u0 = u0map
else
filtered_u0 = Pair[]
for x in u0map
y = get(schedule.dummy_sub, x[1], x[1])
y = get(diffmap, y, y)
if y isa Symbolics.Arr
_y = collect(y)

# TODO: Don't scalarize arrays
for i in 1:length(_y)
push!(filtered_u0, _y[i] => x[2][i])
end
elseif y isa ModelingToolkit.BasicSymbolic
# y is a derivative expression expanded
# add to the initialization equations
push!(eqs_ics, y ~ x[2])
elseif y ∈ set_full_states
push!(filtered_u0, y => x[2])
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
end
filtered_u0 = filtered_u0 isa Pair ? todict([filtered_u0]) : todict(filtered_u0)
end
else
Dict()
dd_guess = Dict()
filtered_u0 = u0map
end

defs = merge(defaults(sys), filtered_u0)
guesses = merge(get_guesses(sys), todict(guesses), dd_guess)

for st in full_states
Expand All @@ -55,7 +85,7 @@ function generate_initializesystem(sys::ODESystem;
end

pars = [parameters(sys); get_iv(sys)]
nleqs = [eqs_ics; observed(sys)]
nleqs = [eqs_ics; get_initialization_eqs(sys); observed(sys)]

sys_nl = NonlinearSystem(nleqs,
full_states,
Expand Down
19 changes: 17 additions & 2 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,23 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
end
end

const MISSING_VARIABLES_MESSAGE = """
Initial condition underdefined. Some are missing from the variable map.
Please provide a default (`u0`), initialization equation, or guess
for the following variables:
"""

struct MissingVariablesError <: Exception
vars::Any
end

function Base.showerror(io::IO, e::MissingVariablesError)
println(io, MISSING_VARIABLES_MESSAGE)
println(io, e.vars)
end

function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false,
toterm = Symbolics.diff2term)
toterm = Symbolics.diff2term, initialization_phase = false)
varmap = merge(defaults, varmap) # prefers the `varmap`
varmap = Dict(toterm(value(k)) => value(varmap[k]) for k in keys(varmap))
# resolve symbolic parameter expressions
Expand All @@ -180,7 +195,7 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
end

missingvars = setdiff(varlist, collect(keys(varmap)))
check && (isempty(missingvars) || throw_missingvars(missingvars))
check && (isempty(missingvars) || throw(MissingVariablesError(missingvars)))

out = [varmap[var] for var in varlist]
end
Expand Down
6 changes: 3 additions & 3 deletions test/components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ let
@named rc_model2 = compose(_rc_model2,
[resistor, resistor2, capacitor, source, ground])
sys2 = structural_simplify(rc_model2)
prob2 = ODEProblem(sys2, [], (0, 10.0), guesses = u0)
prob2 = ODEProblem(sys2, [source.p.i => 0.0], (0, 10.0), guesses = u0)
sol2 = solve(prob2, Rosenbrock23())
@test sol2[source.p.i] ≈ sol2[rc_model2.source.p.i] ≈ sol2[capacitor.i]
@test sol2[source.p.i] ≈ sol2[rc_model2.source.p.i] ≈ -sol2[capacitor.i]
end

# Outer/inner connections
Expand Down Expand Up @@ -157,7 +157,7 @@ sys = structural_simplify(ll_model)
u0 = unknowns(sys) .=> 0
@test_nowarn ODEProblem(
sys, [], (0, 10.0), guesses = u0, warn_initialize_determined = false)
prob = DAEProblem(sys, D.(unknowns(sys)) .=> 0, u0, (0, 0.5))
prob = DAEProblem(sys, D.(unknowns(sys)) .=> 0, [], (0, 0.5), guesses = u0)
sol = solve(prob, DFBDF())
@test sol.retcode == SciMLBase.ReturnCode.Success

Expand Down
88 changes: 87 additions & 1 deletion test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,90 @@ p = [σ => 28.0,
β => 8 / 3]

tspan = (0.0, 100.0)
@test_throws ArgumentError prob=ODEProblem(sys, u0, tspan, p, jac = true)
@test_throws ModelingToolkit.IncompleteInitializationError prob=ODEProblem(
sys, u0, tspan, p, jac = true)

# DAE Initialization on ODE with nonlinear system for initial conditions
# https://github.com/SciML/ModelingToolkit.jl/issues/2508

using ModelingToolkit, OrdinaryDiffEq, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

function System2(; name)
vars = @variables begin
dx(t), [guess = 0]
ddx(t), [guess = 0]
end
eqs = [D(dx) ~ ddx
0 ~ ddx + dx + 1]
return ODESystem(eqs, t, vars, []; name)
end

@mtkbuild sys = System2()
prob = ODEProblem(sys, [sys.dx => 1], (0, 1)) # OK
prob = ODEProblem(sys, [sys.ddx => -2], (0, 1), guesses = [sys.dx => 1])
sol = solve(prob, Tsit5())
@test SciMLBase.successful_retcode(sol)
@test sol[1] == [1.0]

## Late binding initialization_eqs

function System3(; name)
vars = @variables begin
dx(t), [guess = 0]
ddx(t), [guess = 0]
end
eqs = [D(dx) ~ ddx
0 ~ ddx + dx + 1]
initialization_eqs = [
ddx ~ -2
]
return ODESystem(eqs, t, vars, []; name, initialization_eqs)
end

@mtkbuild sys = System3()
prob = ODEProblem(sys, [], (0, 1), guesses = [sys.dx => 1])
sol = solve(prob, Tsit5())
@test SciMLBase.successful_retcode(sol)
@test sol[1] == [1.0]

# Steady state initialization

@parameters σ ρ β
@variables x(t) y(t) z(t)

eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z]

@named sys = ODESystem(eqs, t)
sys = structural_simplify(sys)

u0 = [D(x) => 2.0,
x => 1.0,
D(y) => 0.0,
z => 0.0]

p = [σ => 28.0,
ρ => 10.0,
β => 8 / 3]

tspan = (0.0, 0.2)
prob_mtk = ODEProblem(sys, u0, tspan, p)
sol = solve(prob_mtk, Tsit5())
@test sol[x * (ρ - z) - y][1] == 0.0

@variables x(t) y(t) z(t)
@parameters α=1.5 β=1.0 γ=3.0 δ=1.0

eqs = [D(x) ~ α * x - β * x * y
D(y) ~ -γ * y + δ * x * y
z ~ x + y]

@named sys = ODESystem(eqs, t)
simpsys = structural_simplify(sys)
tspan = (0.0, 10.0)

prob = ODEProblem(simpsys, [D(x) => 0.0, y => 0.0], tspan, guesses = [x => 0.0])
sol = solve(prob, Tsit5())
@test sol[1] == [0.0, 0.0]
2 changes: 1 addition & 1 deletion test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ model = ODESystem(eqs, t; systems = [torque, inertia1, inertia2, spring, damper]
name = :name)
model_outputs = [inertia1.w, inertia2.w, inertia1.phi, inertia2.phi]
model_inputs = [torque.tau.u]
matrices, ssys = linearize(model, model_inputs, model_outputs)
matrices, ssys = linearize(model, model_inputs, model_outputs);
@test length(ModelingToolkit.outputs(ssys)) == 4

if VERSION >= v"1.8" # :opaque_closure not supported before
Expand Down
Loading
Loading