Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix up initializesystem for hierarchical models #2403

Merged
merged 48 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
5ca99eb
Fix up initializesystem for hierarchical models
ChrisRackauckas Dec 29, 2023
4717aee
Handle overdetermined systems gracefully when `fully_determined = false`
YingboMa Jan 2, 2024
6d6d3fa
Setup guesses(sys) and passing override dictionaries
ChrisRackauckas Feb 24, 2024
1440304
Add NonlinearLeastSquaresProblem building
ChrisRackauckas Feb 24, 2024
ee77fb2
Build the initialization system generation into structural simplifica…
ChrisRackauckas Feb 24, 2024
6f960db
InitializationProblem works
ChrisRackauckas Feb 24, 2024
0f3ba28
Test the initialization problem
ChrisRackauckas Feb 24, 2024
aafddc9
format
ChrisRackauckas Feb 24, 2024
4702426
Automate tagging of the initialization system to the ODEProblem
ChrisRackauckas Feb 26, 2024
c443ac7
Fix guess and initial condition length checking
ChrisRackauckas Feb 26, 2024
67f6f24
Bump ordinarydiffeq
ChrisRackauckas Feb 26, 2024
7c0c423
format
ChrisRackauckas Feb 26, 2024
fb4b987
stop initialize on clocks
ChrisRackauckas Feb 26, 2024
16ce722
format
ChrisRackauckas Feb 26, 2024
7a84444
fix direct structural simplification for initialization systems
ChrisRackauckas Feb 26, 2024
ae4be0f
format
ChrisRackauckas Feb 26, 2024
dcffe8f
check for io first
ChrisRackauckas Feb 26, 2024
78b5832
format
ChrisRackauckas Feb 26, 2024
6e336ba
Properly drop from defaults
ChrisRackauckas Feb 26, 2024
dfee0ef
handle the Vector{Float} case
ChrisRackauckas Feb 27, 2024
f62163a
handle nothing equations case
ChrisRackauckas Feb 27, 2024
80ece19
handle no structural simplify
ChrisRackauckas Feb 27, 2024
97c1d9d
new error
ChrisRackauckas Feb 27, 2024
8521ddb
Remove scalarization in NonlinearSystem
ChrisRackauckas Feb 27, 2024
2efc30a
Try other tests
ChrisRackauckas Feb 27, 2024
9df0b4b
remove old initialize
ChrisRackauckas Feb 27, 2024
3d41e04
see if the non array tests all pass
ChrisRackauckas Feb 27, 2024
0763874
fix reduction test
ChrisRackauckas Feb 27, 2024
deaa824
better fix
ChrisRackauckas Feb 27, 2024
456478f
only build initialization if simplified
ChrisRackauckas Feb 27, 2024
a5e275d
reenable mass matrix test
ChrisRackauckas Feb 27, 2024
f85b219
format
ChrisRackauckas Feb 27, 2024
f2559ab
Update test/odesystem.jl
ChrisRackauckas Feb 27, 2024
7970cc1
let all tests run
ChrisRackauckas Feb 27, 2024
33c361e
Merge branch 'initializesystem' of https://github.com/SciML/ModelingT…
ChrisRackauckas Feb 27, 2024
3ea493a
handle dds and t
ChrisRackauckas Feb 27, 2024
734e195
format
ChrisRackauckas Feb 27, 2024
90ceeda
update initialization problem usage for t0
ChrisRackauckas Feb 27, 2024
920fcc2
Remove early caching initialization system
ChrisRackauckas Feb 27, 2024
0b6edee
guesses
ChrisRackauckas Feb 27, 2024
d78c680
handle empty parammap
ChrisRackauckas Feb 27, 2024
79eca95
format
ChrisRackauckas Feb 27, 2024
ab935d7
Fix schedule setting
ChrisRackauckas Feb 27, 2024
2ccf8fa
a few fixes
ChrisRackauckas Feb 28, 2024
9cf1768
Fix test cases
ChrisRackauckas Feb 28, 2024
537fb7c
format
ChrisRackauckas Feb 28, 2024
56421fb
return drop_expr
ChrisRackauckas Feb 28, 2024
c78321f
Fix warning
ChrisRackauckas Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,12 @@ Libdl = "1"
LinearAlgebra = "1"
MLStyle = "0.4.17"
NaNMath = "0.3, 1"
OrdinaryDiffEq = "6"
OrdinaryDiffEq = "6.72.0"
PrecompileTools = "1"
RecursiveArrayTools = "2.3, 3"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.0.1"
SciMLBase = "2.28.0"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
9 changes: 4 additions & 5 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,18 @@ include("systems/model_parsing.jl")
include("systems/connectors.jl")
include("systems/callbacks.jl")

include("systems/nonlinear/nonlinearsystem.jl")
include("systems/diffeqs/odesystem.jl")
include("systems/diffeqs/sdesystem.jl")
include("systems/diffeqs/abstractodesystem.jl")
include("systems/nonlinear/modelingtoolkitize.jl")
include("systems/nonlinear/initializesystem.jl")
include("systems/diffeqs/first_order_transform.jl")
include("systems/diffeqs/modelingtoolkitize.jl")
include("systems/diffeqs/basic_transformations.jl")

include("systems/jumps/jumpsystem.jl")

include("systems/nonlinear/nonlinearsystem.jl")
include("systems/nonlinear/modelingtoolkitize.jl")
include("systems/nonlinear/initializesystem.jl")

include("systems/optimization/constraints_system.jl")
include("systems/optimization/optimizationsystem.jl")
include("systems/optimization/modelingtoolkitize.jl")
Expand Down Expand Up @@ -253,7 +252,7 @@ export toexpr, get_variables
export simplify, substitute
export build_function
export modelingtoolkitize
export initializesystem
export initializesystem, generate_initializesystem

export @variables, @parameters, @constants, @brownian
export @named, @nonamespace, @namespace, extend, compose, complete
Expand Down
2 changes: 1 addition & 1 deletion src/bipartite_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ return `false` may not be matched.
"""
function maximal_matching(g::BipartiteGraph, srcfilter = vsrc -> true,
dstfilter = vdst -> true, ::Type{U} = Unassigned) where {U}
matching = Matching{U}(ndsts(g))
matching = Matching{U}(max(nsrcs(g), ndsts(g)))
foreach(Iterators.filter(srcfilter, 𝑠vertices(g))) do vsrc
construct_augmenting_path!(matching, g, vsrc, dstfilter)
end
Expand Down
3 changes: 2 additions & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
IncrementalCycleTracker, add_edge_checked!, topological_sort,
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
fast_substitute, get_fullvars, has_equations, observed
fast_substitute, get_fullvars, has_equations, observed,
Schedule

using ModelingToolkit.BipartiteGraphs
import .BipartiteGraphs: invview, complete
Expand Down
6 changes: 6 additions & 0 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,12 @@ function tearing_reassemble(state::TearingState, var_eq_matching;
# TODO: compute the dependency correctly so that we don't have to do this
obs = [fast_substitute(observed(sys), obs_sub); subeqs]
@set! sys.observed = obs

# Only makes sense for time-dependent
# TODO: generalize to SDE
if sys isa ODESystem
@set! sys.schedule = Schedule(var_eq_matching, dummy_sub)
end
@set! state.sys = sys
@set! sys.tearing_state = state
return invalidate_cache!(sys)
Expand Down
17 changes: 14 additions & 3 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,9 @@ function complete(sys::AbstractSystem; split = true)
if split && has_index_cache(sys)
@set! sys.index_cache = IndexCache(sys)
end
if isdefined(sys, :initializesystem) && get_initializesystem(sys) !== nothing
@set! sys.initializesystem = complete(get_initializesystem(sys); split)
end
isdefined(sys, :complete) ? (@set! sys.complete = true) : sys
end

Expand All @@ -551,6 +554,7 @@ for prop in [:eqs
:var_to_name
:ctrls
:defaults
:guesses
:observed
:tgrad
:jac
Expand All @@ -571,6 +575,8 @@ for prop in [:eqs
:connections
:preface
:torn_matching
:initializesystem
:schedule
:tearing_state
:substitutions
:metadata
Expand Down Expand Up @@ -933,6 +939,10 @@ function full_parameters(sys::AbstractSystem)
vcat(parameters(sys), dependent_parameters(sys))
end

function guesses(sys::AbstractSystem)
get_guesses(sys)
end

# required in `src/connectors.jl:437`
parameters(_) = []

Expand Down Expand Up @@ -2259,14 +2269,15 @@ function UnPack.unpack(sys::ModelingToolkit.AbstractSystem, ::Val{p}) where {p}
end

"""
missing_variable_defaults(sys::AbstractSystem, default = 0.0)
missing_variable_defaults(sys::AbstractSystem, default = 0.0; subset = unknowns(sys))

returns a `Vector{Pair}` of variables set to `default` which are missing from `get_defaults(sys)`. The `default` argument can be a single value or vector to set the missing defaults respectively.
"""
function missing_variable_defaults(sys::AbstractSystem, default = 0.0)
function missing_variable_defaults(
sys::AbstractSystem, default = 0.0; subset = unknowns(sys))
varmap = get_defaults(sys)
varmap = Dict(Symbolics.diff2term(value(k)) => value(varmap[k]) for k in keys(varmap))
missingvars = setdiff(unknowns(sys), keys(varmap))
missingvars = setdiff(subset, keys(varmap))
ds = Pair[]

n = length(missingvars)
Expand Down
140 changes: 134 additions & 6 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
struct Schedule
var_eq_matching::Any
dummy_sub::Any
end

function filter_kwargs(kwargs)
kwargs = Dict(kwargs)
for key in keys(kwargs)
Expand Down Expand Up @@ -316,6 +321,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
sparsity = false,
analytic = nothing,
split_idxs = nothing,
initializeprob = nothing,
initializeprobmap = nothing,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`")
Expand Down Expand Up @@ -487,6 +494,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
end

@set! sys.split_idxs = split_idxs

ODEFunction{iip, specialize}(f;
sys = sys,
jac = _jac === nothing ? nothing : _jac,
Expand All @@ -495,7 +503,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
jac_prototype = jac_prototype,
observed = observedfun,
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
analytic = analytic)
analytic = analytic,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap)
end

"""
Expand Down Expand Up @@ -525,6 +535,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
sparse = false, simplify = false,
eval_module = @__MODULE__,
checkbounds = false,
initializeprob = nothing,
initializeprobmap = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
Expand Down Expand Up @@ -596,7 +608,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
sys = sys,
jac = _jac === nothing ? nothing : _jac,
jac_prototype = jac_prototype,
observed = observedfun)
observed = observedfun,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand Down Expand Up @@ -839,18 +853,46 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
tofloat = true,
symbolic_u0 = false,
u0_constructor = identity,
guesses = Dict(),
t = nothing,
warn_initialize_determined = true,
kwargs...)
eqs = equations(sys)
dvs = unknowns(sys)
ps = full_parameters(sys)
iv = get_iv(sys)

# Append zeros to the variables which are determined by the initialization system
# This essentially bypasses the check for if initial conditions are defined for DAEs
# since they will be checked in the initialization problem's construction
# TODO: make check for if a DAE cheaper than calculating the mass matrix a second time!
ci = infer_clocks!(ClockInference(TearingState(sys)))
# TODO: make it work with clocks
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if (implicit_dae || calculate_massmatrix(sys) !== I) &&
all(isequal(Continuous()), ci.var_domain) &&
ModelingToolkit.get_tearing_state(sys) !== nothing
if eltype(u0map) <: Number
u0map = unknowns(sys) .=> u0map
end
initializeprob = ModelingToolkit.InitializationProblem(
sys, t, u0map, parammap; guesses, warn_initialize_determined)
initializeprobmap = getu(initializeprob, unknowns(sys))

zerovars = setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0
trueinit = identity.([zerovars; u0map])
else
initializeprob = nothing
initializeprobmap = nothing
trueinit = u0map
end

if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, u0map, parammap; symbolic_u0)
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0)
p = MTKParameters(sys, parammap)
else
u0, p, defs = get_u0_p(sys,
u0map,
trueinit,
parammap;
tofloat,
use_union,
Expand Down Expand Up @@ -881,6 +923,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
checkbounds = checkbounds, p = p,
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
sparse = sparse, eval_expression = eval_expression,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap,
kwargs...)
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
end
Expand Down Expand Up @@ -984,13 +1028,14 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
parammap = DiffEqBase.NullParameters();
callback = nothing,
check_length = true,
warn_initialize_determined = true,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end
f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
check_length, kwargs...)
check_length, warn_initialize_determined, kwargs...)
cbs = process_events(sys; callback, kwargs...)
inits = []
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
Expand Down Expand Up @@ -1055,13 +1100,15 @@ end

function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
parammap = DiffEqBase.NullParameters();
warn_initialize_determined = true,
check_length = true, kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`")
end
f, du0, u0, p = process_DEProblem(DAEFunction{iip}, sys, u0map, parammap;
implicit_dae = true, du0map = du0map, check_length,
kwargs...)
t = tspan !== nothing ? tspan[1] : tspan,
warn_initialize_determined, kwargs...)
diffvars = collect_differential_variables(sys)
sts = unknowns(sys)
differential_vars = map(Base.Fix2(in, diffvars), sts)
Expand Down Expand Up @@ -1237,6 +1284,7 @@ function ODEProblemExpr{iip}(sys::AbstractODESystem, u0map, tspan,
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `ODEProblemExpr`")
end
f, u0, p = process_DEProblem(ODEFunctionExpr{iip}, sys, u0map, parammap; check_length,
t = tspan !== nothing ? tspan[1] : tspan,
kwargs...)
linenumbers = get(kwargs, :linenumbers, true)
kwargs = filter_kwargs(kwargs)
Expand Down Expand Up @@ -1282,6 +1330,7 @@ function DAEProblemExpr{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblemExpr`")
end
f, du0, u0, p = process_DEProblem(DAEFunctionExpr{iip}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
implicit_dae = true, du0map = du0map, check_length,
kwargs...)
linenumbers = get(kwargs, :linenumbers, true)
Expand Down Expand Up @@ -1442,3 +1491,82 @@ function flatten_equations(eqs)
end
end
end

struct InitializationProblem{iip, specialization} end

"""
```julia
InitializationProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
version = nothing, tgrad = false,
jac = false,
checkbounds = false, sparse = false,
simplify = false,
linenumbers = true, parallel = SerialForm(),
kwargs...) where {iip}
```

Generates a NonlinearProblem or NonlinearLeastSquaresProblem from an ODESystem
which represents the initialization, i.e. the calculation of the consistent
initial conditions for the given DAE.
"""
function InitializationProblem(sys::AbstractODESystem, args...; kwargs...)
InitializationProblem{true}(sys, args...; kwargs...)
end

function InitializationProblem(sys::AbstractODESystem, t,
u0map::StaticArray,
args...;
kwargs...)
InitializationProblem{false, SciMLBase.FullSpecialize}(
sys, t, u0map, args...; kwargs...)
end

function InitializationProblem{true}(sys::AbstractODESystem, args...; kwargs...)
InitializationProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...)
InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
t::Number, u0map = [],
parammap = DiffEqBase.NullParameters();
guesses = [],
check_length = true,
warn_initialize_determined = true,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
end

if isempty(u0map) && get_initializesystem(sys) !== nothing
isys = get_initializesystem(sys)
elseif isempty(u0map) && get_initializesystem(sys) === nothing
isys = structural_simplify(generate_initializesystem(sys); fully_determined = false)
else
isys = structural_simplify(
generate_initializesystem(sys; u0map); fully_determined = false)
end

neqs = length(equations(isys))
nunknown = length(unknowns(isys))

if warn_initialize_determined && neqs > nunknown
@warn "Initialization system is overdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
end
if warn_initialize_determined && neqs < nunknown
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
end

parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
[get_iv(sys) => t] :
merge(todict(parammap), Dict(get_iv(sys) => t))

if neqs == nunknown
NonlinearProblem(isys, guesses, parammap)
else
NonlinearLeastSquaresProblem(isys, guesses, parammap)
end
end
Loading
Loading