Skip to content

Commit

Permalink
Merge pull request #3226 from AayushSabharwal/as/remake-propagate-gue…
Browse files Browse the repository at this point in the history
…sses

feat: propagate `ODEProblem` guesses to `remake`
  • Loading branch information
AayushSabharwal authored Dec 2, 2024
2 parents e9fe9a1 + 10cc9c1 commit 4d4ff85
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 82 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ REPL = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.57.1"
SciMLBase = "2.64"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1310,11 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
elseif isempty(u0map) && get_initializesystem(sys) === nothing
isys = structural_simplify(
generate_initializesystem(
sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
else
isys = structural_simplify(
generate_initializesystem(
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined)
end

ts = get_tearing_state(isys)
Expand Down
136 changes: 64 additions & 72 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ function generate_initializesystem(sys::ODESystem;
# 1) process dummy derivatives and u0map into initialization system
eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations
defs = copy(defaults(sys)) # copy so we don't modify sys.defaults
guesses = merge(get_guesses(sys), todict(guesses))
additional_guesses = anydict(guesses)
guesses = merge(get_guesses(sys), additional_guesses)
schedule = getfield(sys, :schedule)
if !isnothing(schedule)
for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub)
Expand Down Expand Up @@ -178,7 +179,7 @@ function generate_initializesystem(sys::ODESystem;
for k in keys(defs)
defs[k] = substitute(defs[k], paramsubs)
end
meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap))
meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses)
return NonlinearSystem(eqs_ics,
vars,
pars;
Expand All @@ -193,6 +194,7 @@ end
struct InitializationSystemMetadata
u0map::Dict{Any, Any}
pmap::Dict{Any, Any}
additional_guesses::Dict{Any, Any}
end

function is_parameter_solvable(p, pmap, defs, guesses)
Expand All @@ -208,17 +210,16 @@ function is_parameter_solvable(p, pmap, defs, guesses)
_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
end

function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp)
if u0 === missing && p === missing
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
odefn.initializeprobpmap
return odefn.initialization_data
end
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
oldinitprob = odefn.initializeprob
if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) ||
!(oldinitprob.f.sys isa NonlinearSystem)
return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap,
odefn.initializeprobpmap
oldinitprob === nothing && return nothing
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!,
odefn.initializeprobmap, odefn.initializeprobpmap)
end
pidxs = ParameterIndex[]
pvals = []
Expand Down Expand Up @@ -260,78 +261,69 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
end
initprob = remake(oldinitprob; u0 = newu0, p = newp)
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
odefn.initializeprobpmap
return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!,
odefn.initializeprobmap, odefn.initializeprobpmap)
end
if u0 === missing || isempty(u0)
u0 = Dict()
elseif !(eltype(u0) <: Pair)
u0 = Dict(unknowns(sys) .=> u0)
end
if p === missing
p = Dict()
end
if t0 === nothing
t0 = 0.0
end
u0 = todict(u0)
dvs = unknowns(sys)
ps = parameters(sys)
u0map = to_varmap(u0, dvs)
symbols_to_symbolics!(sys, u0map)
pmap = to_varmap(p, ps)
symbols_to_symbolics!(sys, pmap)
guesses = Dict()
defs = defaults(sys)
varmap = merge(defs, u0)
for k in collect(keys(varmap))
if varmap[k] === nothing
delete!(varmap, k)
if SciMLBase.has_initializeprob(odefn)
oldsys = odefn.initializeprob.f.sys
meta = get_metadata(oldsys)
if meta isa InitializationSystemMetadata
u0map = merge(meta.u0map, u0map)
pmap = merge(meta.pmap, pmap)
merge!(guesses, meta.additional_guesses)
end
end
varmap = canonicalize_varmap(varmap)
missingvars = setdiff(unknowns(sys), collect(keys(varmap)))
setobserved = filter(keys(varmap)) do var
has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var))
end
p = todict(p)
guesses = ModelingToolkit.guesses(sys)
solvablepars = [par
for par in parameters(sys)
if is_parameter_solvable(par, p, defs, guesses)]
pvarmap = merge(defs, p)
setparobserved = filter(keys(pvarmap)) do var
has_parameter_dependency_with_lhs(sys, var)
end
if (((!isempty(missingvars) || !isempty(solvablepars) ||
!isempty(setobserved) || !isempty(setparobserved)) &&
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys)))
if SciMLBase.has_initializeprob(odefn)
oldsys = odefn.initializeprob.f.sys
meta = get_metadata(oldsys)
if meta isa InitializationSystemMetadata
u0 = merge(meta.u0map, u0)
p = merge(meta.pmap, p)
else
# there is no initializeprob, so the original problem construction
# had no solvable parameters and had the differential variables
# specified in `u0map`.
if u0 === missing
# the user didn't pass `u0` to `remake`, so they want to retain
# existing values. Fill the differential variables in `u0map`,
# initialization will either be elided or solve for the algebraic
# variables
diff_idxs = isdiffeq.(equations(sys))
for i in eachindex(dvs)
diff_idxs[i] || continue
u0map[dvs[i]] = newu0[i]
end
end
for k in collect(keys(u0))
if u0[k] === nothing
delete!(u0, k)
if p === missing
# the user didn't pass `p` to `remake`, so they want to retain
# existing values. Fill all parameters in `pmap` so that none of
# them are solvable.
for p in ps
pmap[p] = getp(sys, p)(newp)
end
end
for k in collect(keys(p))
if p[k] === nothing
delete!(p, k)
end
# all non-solvable parameters need values regardless
for p in ps
haskey(pmap, p) && continue
is_parameter_solvable(p, pmap, defs, guesses) && continue
pmap[p] = getp(sys, p)(newp)
end

initprob = InitializationProblem(sys, t0, u0, p)
initprobmap = getu(initprob, unknowns(sys))
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
getpunknowns = getu(initprob, punknowns)
setpunknowns = setp(sys, punknowns)
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
reqd_syms = parameter_symbols(initprob)
update_initializeprob! = UpdateInitializeprob(
getu(sys, reqd_syms), setu(initprob, reqd_syms))
return initprob, update_initializeprob!, initprobmap, initprobpmap
else
return nothing, nothing, nothing, nothing
end
if t0 === nothing
t0 = 0.0
end
filter_missing_values!(u0map)
filter_missing_values!(pmap)
f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0)
kws = f.kwargs
initprob = get(kws, :initializeprob, nothing)
if initprob === nothing
return nothing
end
return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing),
get(kws, :initializeprobmap, nothing),
get(kws, :initializeprobpmap, nothing))
end

"""
Expand Down
3 changes: 2 additions & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ function OptimizationSystem(objective; constraints = [], kwargs...)
push!(new_ps, p)
end
end
return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
return OptimizationSystem(
objective, collect(allunknowns), collect(new_ps); constraints, kwargs...)
end

function flatten(sys::OptimizationSystem)
Expand Down
62 changes: 57 additions & 5 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ const AnyDict = Dict{Any, Any}
$(TYPEDSIGNATURES)
If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`
and `nothing`.
as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`,
`missing` and `nothing`.
"""
anydict() = AnyDict()
anydict(::SciMLBase.NullParameters) = AnyDict()
anydict(::Nothing) = AnyDict()
anydict(::Missing) = AnyDict()
anydict(x::AnyDict) = x
anydict(x) = AnyDict(x)

Expand Down Expand Up @@ -51,6 +52,42 @@ function add_toterms(varmap::AbstractDict; toterm = default_toterm)
return cp
end

"""
$(TYPEDSIGNATURES)
Turn any `Symbol` keys in `varmap` to the appropriate symbolic variables in `sys`. Any
symbols that cannot be converted are ignored.
"""
function symbols_to_symbolics!(sys::AbstractSystem, varmap::AbstractDict)
if is_split(sys)
ic = get_index_cache(sys)
for k in collect(keys(varmap))
k isa Symbol || continue
newk = get(ic.symbol_to_variable, k, nothing)
newk === nothing && continue
varmap[newk] = varmap[k]
delete!(varmap, k)
end
else
syms = all_symbols(sys)
for k in collect(keys(varmap))
k isa Symbol || continue
idx = findfirst(syms) do sym
hasname(sym) || return false
name = getname(sym)
return name == k
end
idx === nothing && continue
newk = syms[idx]
if iscall(newk) && operation(newk) === getindex
newk = arguments(newk)[1]
end
varmap[newk] = varmap[k]
delete!(varmap, k)
end
end
end

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -388,6 +425,15 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100)
end
end

"""
$(TYPEDSIGNATURES)
Remove keys in `varmap` whose values are `nothing`.
"""
function filter_missing_values!(varmap::AbstractDict)
filter!(kvp -> kvp[2] !== nothing, varmap)
end

struct GetUpdatedMTKParameters{G, S}
# `getu` functor which gets parameters that are unknowns during initialization
getpunknowns::G
Expand Down Expand Up @@ -431,12 +477,16 @@ end
$(TYPEDEF)
A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in
case constructing a SciMLFunction is not required.
case constructing a SciMLFunction is not required. The arguments passed to it are available
in the `args` field, and the keyword arguments in the `kwargs` field.
"""
struct EmptySciMLFunction end
struct EmptySciMLFunction{A, K}
args::A
kwargs::K
end

function EmptySciMLFunction(args...; kwargs...)
return nothing
return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs)
end

"""
Expand Down Expand Up @@ -516,8 +566,10 @@ function process_SciMLProblem(
pType = typeof(pmap)
_u0map = u0map
u0map = to_varmap(u0map, dvs)
symbols_to_symbolics!(sys, u0map)
_pmap = pmap
pmap = to_varmap(pmap, ps)
symbols_to_symbolics!(sys, pmap)
defs = add_toterms(recursive_unwrap(defaults(sys)))
cmap, cs = get_cmap(sys)
kwargs = NamedTuple(kwargs)
Expand Down
1 change: 0 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ
return nothing
end


function collect_var!(unknowns, parameters, var, iv; depth = 0)
isequal(var, iv) && return nothing
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
Expand Down
57 changes: 57 additions & 0 deletions test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -975,3 +975,60 @@ end
@test integ.ps[p] 1.0
@test integ.ps[q]cbrt(2) rtol=1e-6
end

@testset "Guesses provided to `ODEProblem` are used in `remake`" begin
@variables x(t) y(t)=2x
@parameters p q=3x
@mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t)
prob = ODEProblem(
sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0])
@test prob[x] == 0.0
@test prob[y] == 0.0
@test prob.ps[p] == 1.0
@test prob.ps[q] == 0.0
integ = init(prob)
@test integ[x] 1 / cbrt(3)
@test integ[y] 2 / cbrt(3)
@test integ.ps[p] == 1.0
@test integ.ps[q] 3 / cbrt(3)
prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x])
integ2 = init(prob2)
@test integ2[x] cbrt(3 / 28)
@test integ2[y] 3cbrt(3 / 28)
@test integ2.ps[p] == 1.0
@test integ2.ps[q] 2cbrt(3 / 28)
end

@testset "Remake problem with no initializeprob" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p [guess = 1.0] q [guess = 1.0]
@mtkbuild sys = ODESystem(
[D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p])
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0])
@test prob.f.initialization_data === nothing
prob2 = remake(prob; u0 = [x => 2.0])
@test prob2[x] == 2.0
@test prob2.f.initialization_data === nothing
prob3 = remake(prob; u0 = [y => 2.0])
@test prob3.f.initialization_data !== nothing
@test init(prob3)[x] 1.0
prob4 = remake(prob; p = [p => 1.0])
@test prob4.f.initialization_data === nothing
prob5 = remake(prob; p = [p => missing, q => 2.0])
@test prob5.f.initialization_data !== nothing
@test init(prob5).ps[p] 1.0
end

@testset "Variables provided as symbols" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p [guess = 1.0] q [guess = 1.0]
@mtkbuild sys = ODESystem(
[D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p])
prob = ODEProblem(sys, [:x => 1.0], (0.0, 1.0), [p => 1.0])
@test prob.f.initialization_data === nothing
prob2 = remake(prob; u0 = [:x => 2.0])
@test prob2.f.initialization_data === nothing
prob3 = remake(prob; u0 = [:y => 1.0])
@test prob3.f.initialization_data !== nothing
@test init(prob3)[x] 0.5
end

0 comments on commit 4d4ff85

Please sign in to comment.