Skip to content

Commit

Permalink
fix: copy MTKParameters over setp_oop for initializeprobpmap
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 6, 2024
1 parent 083240d commit a427126
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -768,12 +768,14 @@ end
struct GetUpdatedMTKParameters{G, S}
# `getu` functor which gets parameters that are unknowns during initialization
getpunknowns::G
# `setu_oop` functor which returns a modified MTKParameters using those parameters
# `setu` functor which returns a modified MTKParameters using those parameters
setpunknowns::S
end

function (f::GetUpdatedMTKParameters)(prob, initializesol)
f.setpunknowns(prob, f.getpunknowns(initializesol))
mtkp = copy(parameter_values(prob))
f.setpunknowns(mtkp, f.getpunknowns(initializesol))
mtkp
end

function get_temporary_value(p)
Expand Down Expand Up @@ -836,16 +838,13 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
end
end
defs = defaults(sys)
missingpars = [p
guesses = merge(ModelingToolkit.guesses(sys), isempty(guesses) ? Dict() : todict(guesses))
solvablepars = [p
for p in parameters(sys)
if (parammap !== SciMLBase.NullParameters() &&
get(parammap, p, nothing) === missing) ||
((parammap isa SciMLBase.NullParameters ||
get(parammap, p, nothing) !== missing) &&
get(defs, p, nothing) === missing)]
if is_parameter_solvable(p, parammap, defs, guesses)]
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
if sys isa ODESystem && build_initializeprob &&
(((implicit_dae || !isempty(missingvars) || !isempty(missingpars)) &&
(((implicit_dae || !isempty(missingvars) || !isempty(solvablepars)) &&
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
!isempty(initialization_equations(sys))) && t !== nothing
if eltype(u0map) <: Number
Expand All @@ -861,9 +860,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
punknowns = [p
for p in all_variable_symbols(initializeprob) if is_parameter(sys, p)]
getpunknowns = getu(initializeprob, punknowns)
setpunknowns = setp_oop(sys, punknowns)
setpunknowns = setp(sys, punknowns)
initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns)
# TODO: Initializeprobpmap when setp_oop is a thing

zerovars = Dict(setdiff(unknowns(sys), keys(defaults(sys))) .=> 0.0)
if parammap isa SciMLBase.NullParameters
Expand Down

0 comments on commit a427126

Please sign in to comment.