Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up generate_initializesystem() #3051

Merged
merged 12 commits into from
Oct 5, 2024
132 changes: 53 additions & 79 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,109 +5,83 @@ Generate `NonlinearSystem` which initializes an ODE problem from specified initi
"""
function generate_initializesystem(sys::ODESystem;
u0map = Dict(),
name = nameof(sys),
guesses = Dict(), check_defguess = false,
default_dd_value = 0.0,
algebraic_only = false,
initialization_eqs = [],
check_units = true,
kwargs...)
sts, eqs = unknowns(sys), equations(sys)
guesses = Dict(),
default_dd_guess = 0.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the old name documented anywhere?

algebraic_only = false,
check_units = true, check_defguess = false,
name = nameof(sys), kwargs...)
vars = unique([unknowns(sys); getfield.((observed(sys)), :lhs)])
vars_set = Set(vars) # for efficient in-lookup

eqs = equations(sys)
idxs_diff = isdiffeq.(eqs)
idxs_alge = .!idxs_diff
num_alge = sum(idxs_alge)

# Start the equations list with algebraic equations
eqs_ics = eqs[idxs_alge]
u0 = Vector{Pair}(undef, 0)

# prepare map for dummy derivative substitution
eqs_diff = eqs[idxs_diff]
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
full_diffmap = merge(diffmap, observed_diffmap)
D = Differential(get_iv(sys))
diffmap = merge(
Dict(eq.lhs => eq.rhs for eq in eqs_diff),
Dict(D(eq.lhs) => D(eq.rhs) for eq in observed(sys))
)

full_states = unique([sts; getfield.((observed(sys)), :lhs)])
set_full_states = Set(full_states)
# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
guesses = merge(get_guesses(sys), todict(guesses))
schedule = getfield(sys, :schedule)

if schedule !== nothing
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
for x in schedule.dummy_sub]
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
if u0map === nothing || isempty(u0map)
filtered_u0 = u0map
else
filtered_u0 = Pair[]
for x in u0map
y = get(schedule.dummy_sub, x[1], x[1])
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)

if y ∈ set_full_states
# defer initialization until defaults are merged below
push!(filtered_u0, y => x[2])
if !isnothing(schedule)
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
# set dummy derivatives to default_dd_guess unless specified
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
end
if !isnothing(u0map)
for (y, x) in u0map
y = get(schedule.dummy_sub, y, y)
y = fixpoint_sub(y, diffmap)
if y ∈ vars_set
# variables specified in u0 overrides defaults
push!(defs, y => x)
elseif y isa Symbolics.Arr
# scalarize array # TODO: don't scalarize arrays
_y = collect(y)
for i in eachindex(_y)
push!(filtered_u0, _y[i] => x[2][i])
end
# TODO: don't scalarize arrays
push!(defs, (collect(y) .=> x)...)
hersle marked this conversation as resolved.
Show resolved Hide resolved
elseif y isa Symbolics.BasicSymbolic
# y is a derivative expression expanded
# add to the initialization equations
push!(eqs_ics, y ~ x[2])
# y is a derivative expression expanded; add it to the initialization equations
push!(eqs_ics, y ~ x)
else
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
end
end
filtered_u0 = todict(filtered_u0)
end
else
dd_guess = Dict()
filtered_u0 = todict(u0map)
end

defs = merge(defaults(sys), filtered_u0)

for st in full_states
if st ∈ keys(defs)
def = defs[st]

if def isa Equation
st ∉ keys(guesses) && check_defguess &&
error("Invalid setup: unknown $(st) has an initial condition equation with no guess.")
push!(eqs_ics, def)
push!(u0, st => guesses[st])
else
push!(eqs_ics, st ~ def)
push!(u0, st => def)
end
elseif st ∈ keys(guesses)
push!(u0, st => guesses[st])
# 2) process other variables
for var in vars
if var ∈ keys(defs)
push!(eqs_ics, var ~ defs[var])
elseif var ∈ keys(guesses)
push!(defs, var => guesses[var])
elseif check_defguess
error("Invalid setup: unknown $(st) has no default value or initial guess")
error("Invalid setup: variable $(var) has no default value or initial guess")
end
end

# 3) process explicitly provided initialization equations
if !algebraic_only
for eq in [get_initialization_eqs(sys); initialization_eqs]
_eq = ModelingToolkit.fixpoint_sub(eq, full_diffmap)
push!(eqs_ics, _eq)
initialization_eqs = [get_initialization_eqs(sys); initialization_eqs]
for eq in initialization_eqs
eq = fixpoint_sub(eq, diffmap) # expand dummy derivatives
push!(eqs_ics, eq)
end
end

pars = [parameters(sys); get_iv(sys)]
nleqs = [eqs_ics; observed(sys)]

sys_nl = NonlinearSystem(nleqs,
full_states,
pars;
defaults = merge(ModelingToolkit.defaults(sys), todict(u0), dd_guess),
parameter_dependencies = parameter_dependencies(sys),
pars = [parameters(sys); get_iv(sys)] # include independent variable as pseudo-parameter
eqs_ics = [eqs_ics; observed(sys)]
return NonlinearSystem(
eqs_ics, vars, pars;
defaults = defs, parameter_dependencies = parameter_dependencies(sys),
checks = check_units,
name,
kwargs...)

return sys_nl
name, kwargs...
)
end
Loading