Skip to content

Commit

Permalink
fixup! feat: allow parameters to be unknowns in the initialization sy…
Browse files Browse the repository at this point in the history
…stem
  • Loading branch information
AayushSabharwal committed Sep 6, 2024
1 parent a427126 commit a3f26ec
Showing 1 changed file with 54 additions and 19 deletions.
73 changes: 54 additions & 19 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,36 +96,59 @@ function generate_initializesystem(sys::ODESystem;
if pmap isa SciMLBase.NullParameters
pmap = Dict()
end
pmap = todict(pmap)
for p in parameters(sys)
# If either of them are `missing` the parameter is an unknown
# But if the parameter is passed a value, use that as an additional
# equation in the system
if (_val1 = get(pmap, p, nothing)) === missing || get(defs, p, nothing) === missing
if is_parameter_solvable(p, pmap, defs, guesses)
# If either of them are `missing` the parameter is an unknown
# But if the parameter is passed a value, use that as an additional
# equation in the system
_val1 = get(pmap, p, nothing)
_val2 = get(defs, p, nothing)
_val3 = get(guesses, p, nothing)
varp = tovar(p)
paramsubs[p] = varp
if _val1 !== nothing && _val1 !== missing
push!(eqs_ics, varp ~ _val1)
end
if !haskey(guesses, p)
error("Invalid setup: parameter $(p) has no default value or initial guess")
# Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess)
if _val2 === missing
if _val1 !== nothing && _val1 !== missing
push!(eqs_ics, varp ~ _val1)
push!(u0, varp => _val1)
elseif _val3 !== nothing
# assuming an equation exists (either via algebraic equations or initialization_eqs)
push!(u0, varp => _val1)
elseif check_defguess
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
end
# `missing` passed to `ODEProblem`, and (either an equation using default or a guess)
elseif _val1 === missing
if _val2 !== nothing && _val2 !== missing
push!(eqs_ics, varp ~ _val2)
push!(u0, varp => _val2)
elseif _val3 !== nothing
push!(u0, varp => _val1)
elseif check_defguess
error("Invalid setup: parameter $(p) has no default value, initial value, or guess")
end
# No value passed to `ODEProblem`, but a default and a guess are present
# _val2 !== missing is implied by it falling this far in the elseif chain
elseif _val1 === nothing && _val2 !== nothing && _val3 !== nothing
push!(eqs_ics, varp ~ _val2)
push!(u0, varp => _val3)
else
# _val1 !== missing and _val1 !== nothing, so a value was provided to ODEProblem
# This would mean `is_parameter_solvable` returned `false`, so we never end up
# here
error("This should never be reached")
end
push!(u0, varp => guesses[p])
end
end
pars = vcat(
[get_iv(sys)],
[p for p in parameters(sys) if !haskey(paramsubs, p)]
)
pdeps = parameter_dependencies(sys)
if !isempty(pdeps)
pdep_eqs = [k ~ v for (k, v) in pdeps]
else
pdep_eqs = Equation[]
end
nleqs = if algebraic_only
[eqs_ics; observed(sys); pdep_eqs]
[eqs_ics; observed(sys)]
else
[eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys); pdep_eqs]
[eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys)]
end
nleqs = Symbolics.substitute.(nleqs, (paramsubs,))
unks = [full_states; collect(values(paramsubs))]
Expand All @@ -142,6 +165,15 @@ function generate_initializesystem(sys::ODESystem;
return sys_nl
end

function is_parameter_solvable(p, pmap, defs, guesses)
_val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing
_val2 = get(defs, p, nothing)
_val3 = get(guesses, p, nothing)
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
# the ODEProblem and it has a default and a guess)
return (_val1 === missing || _val2 === missing) || (_val1 === nothing && _val2 !== nothing && _val3 !== nothing)
end

function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) &&
(p === missing || !(eltype(p) <: Pair) || isempty(p))
Expand All @@ -153,13 +185,16 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
if p === missing
p = Dict()
end
if t0 === nothing
t0 = 0.0
end
u0 = todict(u0)
p = todict(p)
initprob = InitializationProblem(sys, t0, u0, p)
initprobmap = getu(initprob, unknowns(sys))
punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)]
getpunknowns = getu(initprob, punknowns)
setpunknowns = setp_oop(sys, punknowns)
setpunknowns = setp(sys, punknowns)
initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
return initprob, initprobmap, initprobpmap
end

0 comments on commit a3f26ec

Please sign in to comment.