From a3f26eccab54ab15aa1b55f7a9005ace812feb44 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Sep 2024 17:03:58 +0530 Subject: [PATCH] fixup! feat: allow parameters to be unknowns in the initialization system --- src/systems/nonlinear/initializesystem.jl | 73 +++++++++++++++++------ 1 file changed, 54 insertions(+), 19 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index e5e8f4a6a9..8ddd51c21f 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -96,36 +96,59 @@ function generate_initializesystem(sys::ODESystem; if pmap isa SciMLBase.NullParameters pmap = Dict() end + pmap = todict(pmap) for p in parameters(sys) - # If either of them are `missing` the parameter is an unknown - # But if the parameter is passed a value, use that as an additional - # equation in the system - if (_val1 = get(pmap, p, nothing)) === missing || get(defs, p, nothing) === missing + if is_parameter_solvable(p, pmap, defs, guesses) + # If either of them are `missing` the parameter is an unknown + # But if the parameter is passed a value, use that as an additional + # equation in the system + _val1 = get(pmap, p, nothing) + _val2 = get(defs, p, nothing) + _val3 = get(guesses, p, nothing) varp = tovar(p) paramsubs[p] = varp - if _val1 !== nothing && _val1 !== missing - push!(eqs_ics, varp ~ _val1) - end - if !haskey(guesses, p) - error("Invalid setup: parameter $(p) has no default value or initial guess") + # Has a default of `missing`, and (either an equation using the value passed to `ODEProblem` or a guess) + if _val2 === missing + if _val1 !== nothing && _val1 !== missing + push!(eqs_ics, varp ~ _val1) + push!(u0, varp => _val1) + elseif _val3 !== nothing + # assuming an equation exists (either via algebraic equations or initialization_eqs) + push!(u0, varp => _val1) + elseif check_defguess + error("Invalid setup: parameter $(p) has no default value, initial value, or guess") + end + # `missing` passed to `ODEProblem`, and (either an equation using default or a guess) + elseif _val1 === missing + if _val2 !== nothing && _val2 !== missing + push!(eqs_ics, varp ~ _val2) + push!(u0, varp => _val2) + elseif _val3 !== nothing + push!(u0, varp => _val1) + elseif check_defguess + error("Invalid setup: parameter $(p) has no default value, initial value, or guess") + end + # No value passed to `ODEProblem`, but a default and a guess are present + # _val2 !== missing is implied by it falling this far in the elseif chain + elseif _val1 === nothing && _val2 !== nothing && _val3 !== nothing + push!(eqs_ics, varp ~ _val2) + push!(u0, varp => _val3) + else + # _val1 !== missing and _val1 !== nothing, so a value was provided to ODEProblem + # This would mean `is_parameter_solvable` returned `false`, so we never end up + # here + error("This should never be reached") end - push!(u0, varp => guesses[p]) end end pars = vcat( [get_iv(sys)], [p for p in parameters(sys) if !haskey(paramsubs, p)] ) - pdeps = parameter_dependencies(sys) - if !isempty(pdeps) - pdep_eqs = [k ~ v for (k, v) in pdeps] - else - pdep_eqs = Equation[] - end nleqs = if algebraic_only - [eqs_ics; observed(sys); pdep_eqs] + [eqs_ics; observed(sys)] else - [eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys); pdep_eqs] + [eqs_ics; get_initialization_eqs(sys); initialization_eqs; observed(sys)] end nleqs = Symbolics.substitute.(nleqs, (paramsubs,)) unks = [full_states; collect(values(paramsubs))] @@ -142,6 +165,15 @@ function generate_initializesystem(sys::ODESystem; return sys_nl end +function is_parameter_solvable(p, pmap, defs, guesses) + _val1 = pmap isa AbstractDict ? get(pmap, p, nothing) : nothing + _val2 = get(defs, p, nothing) + _val3 = get(guesses, p, nothing) + # either (missing is a default or was passed to the ODEProblem) or (nothing was passed to + # the ODEProblem and it has a default and a guess) + return (_val1 === missing || _val2 === missing) || (_val1 === nothing && _val2 !== nothing && _val3 !== nothing) +end + function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) && (p === missing || !(eltype(p) <: Pair) || isempty(p)) @@ -153,13 +185,16 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) if p === missing p = Dict() end + if t0 === nothing + t0 = 0.0 + end u0 = todict(u0) p = todict(p) initprob = InitializationProblem(sys, t0, u0, p) initprobmap = getu(initprob, unknowns(sys)) punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)] getpunknowns = getu(initprob, punknowns) - setpunknowns = setp_oop(sys, punknowns) + setpunknowns = setp(sys, punknowns) initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) return initprob, initprobmap, initprobpmap end