Skip to content

Commit

Permalink
feat: add remake for SCCNonlinearProblem
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 2, 2024
1 parent 7d4a687 commit 844ebfd
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 2 deletions.
51 changes: 49 additions & 2 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,52 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
end
end

"""
remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
parameters_alias = prob.parameters_alias, sys = missing, explicitfuns! = missing)
Remake the given `SCCNonlinearProblem`. `u0` is the state vector for the entire problem,
which will be chunked appropriately and used to `remake` the individual subproblems. `p`
is the parameter object for `prob`. If `parameters_alias`, the same parameter object will be
used to `remake` the individual subproblems. Otherwise if `p !== missing`, this function will
error and require that `probs` be specified. `probs` is the collection of subproblems. Even if
`probs` is explicitly specified, the value of `u0` provided to `remake` will be used to
override the values in `probs`. `sys` is the index provider for the full system.
"""
function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
parameters_alias = prob.parameters_alias, sys = missing,
interpret_symbolicmap = true, use_defaults = false, explicitfuns! = missing)
if p !== missing && !parameters_alias && probs === missing
throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each."))
end
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults,
indp = sys === missing ? prob.full_index_provider : sys)
if probs === missing
probs = prob.probs
end
offset = 0
if u0 !== missing || p !== missing && parameters_alias
probs = map(probs) do subprob
subprob = if parameters_alias
remake(subprob;
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))],
p = newp)
else
remake(subprob;
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))])
end
offset += length(state_values(subprob))
return subprob
end
end
if sys === missing
sys = prob.full_index_provider
end
return SCCNonlinearProblem{
typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}(
probs, explicitfuns!, sys, newp, parameters_alias)
end

function varmap_has_var(varmap, var)
haskey(varmap, var) || hasname(var) && haskey(varmap, getname(var))
end
Expand Down Expand Up @@ -737,11 +783,12 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
end

function updated_u0_p(
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
use_defaults = false, indp = has_sys(prob.f) ? prob.f.sys : nothing)
if u0 === missing && p === missing
return state_values(prob), parameter_values(prob)
end
if !has_sys(prob.f)
if indp === nothing
if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
Expand Down
62 changes: 62 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0])
push!(syss, discsys)
push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps))

# TODO: Rewrite this example when the MTK codegen is merged
@named sys1 = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
sys1 = complete(sys1)
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
sys2 = complete(sys2)
@named fullsys = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
[x, y, z], [σ, β, ρ])
fullsys = complete(fullsys)

prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
push!(syss, fullsys)
push!(probs, sccprob)

for (sys, prob) in zip(syss, probs)
@test parameter_values(prob) isa ModelingToolkit.MTKParameters

Expand Down Expand Up @@ -274,3 +292,47 @@ end
@test_throws SciMLBase.CyclicDependencyError remake(
prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3])
end

@testset "SCCNonlinearProblem" begin
@named sys1 = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
sys1 = complete(sys1)
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
sys2 = complete(sys2)
@named fullsys = NonlinearSystem(
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
[x, y, z], [σ, β, ρ])
fullsys = complete(fullsys)

u0 = [x => 1.0,
y => 0.0,
z => 0.0]

p ==> 28.0,
ρ => 10.0,
β => 8 / 3]

prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)

sccprob2 = remake(sccprob; u0 = 2ones(3))
@test state_values(sccprob2) 2ones(3)
@test sccprob2.probs[1].u0 2ones(2)
@test sccprob2.probs[2].u0 2ones(1)

sccprob3 = remake(sccprob; p ==> 2.0])
@test sccprob3.parameter_object === sccprob3.probs[1].p
@test sccprob3.parameter_object === sccprob3.probs[2].p

@test_throws ["parameters_alias", "SCCNonlinearProblem"] remake(
sccprob; parameters_alias = false, p ==> 2.0])

newp = remake_buffer(sys1, prob1.p, [σ], [3.0])
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
probs = [remake(prob1; p ==> 3.0]), prob2])
@test !sccprob4.parameters_alias
@test sccprob4.parameter_object !== sccprob4.probs[1].p
@test sccprob4.parameter_object !== sccprob4.probs[2].p
end

0 comments on commit 844ebfd

Please sign in to comment.