From 0a881e71d5f1b5fe77e170713e78b5a5ab3ec023 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Dec 2024 16:22:15 +0530 Subject: [PATCH] fix: better handle reconstructing initializeprob with new types --- src/systems/diffeqs/abstractodesystem.jl | 5 ++ src/systems/nonlinear/initializesystem.jl | 69 ++++++++++------------- 2 files changed, 34 insertions(+), 40 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index a6607e0c55..dd4165cb57 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1310,6 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined) end + meta = get_metadata(isys) + if meta isa InitializationSystemMetadata + @set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys) + end + ts = get_tearing_state(isys) unassigned_vars = StructuralTransformations.singular_check(ts) if warn_initialize_determined && !isempty(unassigned_vars) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index cf6c6d4d42..602fc7cac9 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -192,7 +192,8 @@ function generate_initializesystem(sys::AbstractSystem; defs[k] = substitute(defs[k], paramsubs) end meta = InitializationSystemMetadata( - anydict(u0map), anydict(pmap), additional_guesses, additional_initialization_eqs, extra_metadata) + anydict(u0map), anydict(pmap), additional_guesses, + additional_initialization_eqs, extra_metadata, nothing) return NonlinearSystem(eqs_ics, vars, pars; @@ -204,12 +205,30 @@ function generate_initializesystem(sys::AbstractSystem; kwargs...) end +struct ReconstructInitializeprob + getter::Any + setter::Any +end + +function ReconstructInitializeprob(srcsys::AbstractSystem, dstsys::AbstractSystem) + syms = [unknowns(dstsys); + reduce(vcat, reorder_parameters(dstsys, parameters(dstsys)); init = [])] + getter = getu(srcsys, syms) + setter = setsym_oop(dstsys, syms) + return ReconstructInitializeprob(getter, setter) +end + +function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) + rip.setter(dstvalp, rip.getter(srcvalp)) +end + struct InitializationSystemMetadata u0map::Dict{Any, Any} pmap::Dict{Any, Any} additional_guesses::Dict{Any, Any} additional_initialization_eqs::Vector{Equation} extra_metadata::NamedTuple + oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob} end function is_parameter_solvable(p, pmap, defs, guesses) @@ -239,45 +258,15 @@ function SciMLBase.remake_initialization_data( if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem) return oldinitdata end - pidxs = ParameterIndex[] - pvals = [] - u0idxs = Int[] - u0vals = [] - for sym in variable_symbols(oldinitprob) - if is_variable(sys, sym) || has_observed_with_lhs(sys, sym) - u0 !== missing || continue - idx = variable_index(oldinitprob, sym) - push!(u0idxs, idx) - push!(u0vals, eltype(u0)(state_values(oldinitprob, idx))) - else - p !== missing || continue - idx = variable_index(oldinitprob, sym) - push!(u0idxs, idx) - push!(u0vals, typeof(getp(sys, sym)(p))(state_values(oldinitprob, idx))) - end - end - if p !== missing - for sym in parameter_symbols(oldinitprob) - push!(pidxs, parameter_index(oldinitprob, sym)) - if is_time_dependent(sys) && isequal(sym, get_iv(sys)) - push!(pvals, t0) - else - push!(pvals, getp(sys, sym)(p)) - end - end - end - if isempty(u0idxs) - newu0 = state_values(oldinitprob) - else - newu0 = remake_buffer( - oldinitprob.f.sys, state_values(oldinitprob), u0idxs, u0vals) - end - if isempty(pidxs) - newp = parameter_values(oldinitprob) + oldinitsys = oldinitprob.f.sys + meta = get_metadata(oldinitsys) + if meta isa InitializationSystemMetadata && meta.oop_reconstruct_u0_p !== nothing + reconstruct_fn = meta.oop_reconstruct_u0_p else - newp = remake_buffer( - oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals) + reconstruct_fn = ReconstructInitializeprob(sys, oldinitsys) end + new_initu0, new_initp = reconstruct_fn( + ProblemState(; u = newu0, p = newp, t = t0), oldinitprob) if oldinitprob.f.resid_prototype === nothing newf = oldinitprob.f else @@ -285,9 +274,9 @@ function SciMLBase.remake_initialization_data( SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}( oldinitprob.f; resid_prototype = calculate_resid_prototype( - length(oldinitprob.f.resid_prototype), newu0, newp)) + length(oldinitprob.f.resid_prototype), new_initu0, new_initp)) end - initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp) + initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp) return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!, oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap) end