Skip to content

Commit

Permalink
fix: better handle reconstructing initializeprob with new types
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 24, 2024
1 parent 96f8d5d commit 0a881e7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 40 deletions.
5 changes: 5 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined)
end

meta = get_metadata(isys)
if meta isa InitializationSystemMetadata
@set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys)
end

ts = get_tearing_state(isys)
unassigned_vars = StructuralTransformations.singular_check(ts)
if warn_initialize_determined && !isempty(unassigned_vars)
Expand Down
69 changes: 29 additions & 40 deletions src/systems/nonlinear/initializesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ function generate_initializesystem(sys::AbstractSystem;
defs[k] = substitute(defs[k], paramsubs)
end
meta = InitializationSystemMetadata(
anydict(u0map), anydict(pmap), additional_guesses, additional_initialization_eqs, extra_metadata)
anydict(u0map), anydict(pmap), additional_guesses,
additional_initialization_eqs, extra_metadata, nothing)
return NonlinearSystem(eqs_ics,
vars,
pars;
Expand All @@ -204,12 +205,30 @@ function generate_initializesystem(sys::AbstractSystem;
kwargs...)
end

struct ReconstructInitializeprob
getter::Any
setter::Any
end

function ReconstructInitializeprob(srcsys::AbstractSystem, dstsys::AbstractSystem)
syms = [unknowns(dstsys);
reduce(vcat, reorder_parameters(dstsys, parameters(dstsys)); init = [])]
getter = getu(srcsys, syms)
setter = setsym_oop(dstsys, syms)
return ReconstructInitializeprob(getter, setter)
end

function (rip::ReconstructInitializeprob)(srcvalp, dstvalp)
rip.setter(dstvalp, rip.getter(srcvalp))
end

struct InitializationSystemMetadata
u0map::Dict{Any, Any}
pmap::Dict{Any, Any}
additional_guesses::Dict{Any, Any}
additional_initialization_eqs::Vector{Equation}
extra_metadata::NamedTuple
oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob}
end

function is_parameter_solvable(p, pmap, defs, guesses)
Expand Down Expand Up @@ -239,55 +258,25 @@ function SciMLBase.remake_initialization_data(
if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem)
return oldinitdata
end
pidxs = ParameterIndex[]
pvals = []
u0idxs = Int[]
u0vals = []
for sym in variable_symbols(oldinitprob)
if is_variable(sys, sym) || has_observed_with_lhs(sys, sym)
u0 !== missing || continue
idx = variable_index(oldinitprob, sym)
push!(u0idxs, idx)
push!(u0vals, eltype(u0)(state_values(oldinitprob, idx)))
else
p !== missing || continue
idx = variable_index(oldinitprob, sym)
push!(u0idxs, idx)
push!(u0vals, typeof(getp(sys, sym)(p))(state_values(oldinitprob, idx)))
end
end
if p !== missing
for sym in parameter_symbols(oldinitprob)
push!(pidxs, parameter_index(oldinitprob, sym))
if is_time_dependent(sys) && isequal(sym, get_iv(sys))
push!(pvals, t0)
else
push!(pvals, getp(sys, sym)(p))
end
end
end
if isempty(u0idxs)
newu0 = state_values(oldinitprob)
else
newu0 = remake_buffer(
oldinitprob.f.sys, state_values(oldinitprob), u0idxs, u0vals)
end
if isempty(pidxs)
newp = parameter_values(oldinitprob)
oldinitsys = oldinitprob.f.sys
meta = get_metadata(oldinitsys)
if meta isa InitializationSystemMetadata && meta.oop_reconstruct_u0_p !== nothing
reconstruct_fn = meta.oop_reconstruct_u0_p
else
newp = remake_buffer(
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
reconstruct_fn = ReconstructInitializeprob(sys, oldinitsys)
end
new_initu0, new_initp = reconstruct_fn(
ProblemState(; u = newu0, p = newp, t = t0), oldinitprob)
if oldinitprob.f.resid_prototype === nothing
newf = oldinitprob.f
else
newf = NonlinearFunction{
SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}(
oldinitprob.f;
resid_prototype = calculate_resid_prototype(
length(oldinitprob.f.resid_prototype), newu0, newp))
length(oldinitprob.f.resid_prototype), new_initu0, new_initp))
end
initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp)
initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp)
return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!,
oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap)
end
Expand Down

0 comments on commit 0a881e7

Please sign in to comment.