Skip to content

Commit

Permalink
Merge branch 'master' into compathelper/new_version/2024-11-22-03-28-…
Browse files Browse the repository at this point in the history
…35-271-01723128443
  • Loading branch information
ChrisRackauckas authored Nov 26, 2024
2 parents b0b00af + c39a4d4 commit 1650ddf
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.63.1"
version = "2.64.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
61 changes: 59 additions & 2 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,64 @@ Note that this example aliases the parameters together for a memory-reduced repr
* `probs`: the collection of problems to solve
* `explictfuns!`: the explicit functions for mutating the parameter set
"""
mutable struct SCCNonlinearProblem{P, E}
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
AbstractNonlinearProblem{uType, iip}
probs::P
explictfuns!::E
explicitfuns!::E
full_index_provider::I
parameter_object::Par
parameters_alias::Bool

function SCCNonlinearProblem{P, E, I, Par}(
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
u0 = mapreduce(state_values, vcat, probs)
uType = typeof(u0)
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
end
end

function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
parameter_object = nothing, parameters_alias = false)
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
typeof(full_index_provider), typeof(parameter_object)}(
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
end

function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
if name == :explictfuns!
return getfield(prob, :explicitfuns!)
elseif name == :ps
return ParameterIndexingProxy(prob)
end
return getfield(prob, name)
end

function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
prob.full_index_provider
end
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
prob.parameter_object
end
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
mapreduce(state_values, vcat, prob.probs)
end

function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx)
for scc in prob.probs
svals = state_values(scc)
checkbounds(Bool, svals, idx) && return set_state!(scc, val, idx)
idx -= length(svals)
end
throw(BoundsError(state_values(prob), idx))
end

function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
if prob.parameter_object !== nothing
set_parameter!(prob.parameter_object, val, idx)
prob.parameters_alias && return
end
for scc in prob.probs
is_parameter(scc, idx) || continue
set_parameter!(scc, val, idx)
end
end
2 changes: 1 addition & 1 deletion src/problems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function Base.show(io::IO, mime::MIME"text/plain", A::AbstractNonlinearProblem)
summary(io, A)
println(io)
print(io, "u0: ")
show(io, mime, A.u0)
show(io, mime, state_values(A))
end

function Base.show(io::IO, mime::MIME"text/plain", A::IntervalNonlinearProblem)
Expand Down
30 changes: 23 additions & 7 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ function remake(prob::ODEProblem; f = missing,

if f === missing
if build_initializeprob
initialization_data = remake_initialization_data(
prob.f.sys, prob.f, u0, tspan[1], p)
initialization_data = remake_initialization_data_compat_wrapper(
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
else
initialization_data = nothing
end
Expand Down Expand Up @@ -203,16 +203,32 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p)
end

"""
remake_initialization_data(sys, scimlfn, u0, t0, p)
$(TYPEDSIGNATURES)
Wrapper around `remake_initialization_data` for backward compatibility when `newu0` and
`newp` were not arguments.
"""
function remake_initialization_data_compat_wrapper(sys, scimlfn, u0, t0, p, newu0, newp)
if hasmethod(remake_initialization_data,
Tuple{typeof(sys), typeof(scimlfn), typeof(u0), typeof(t0), typeof(p)})
remake_initialization_data(sys, scimlfn, u0, t0, p)
else
remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
end
end

"""
remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
Re-create the initialization data present in the function `scimlfn`, using the
associated system `sys` and the user provided new values of `u0`, initial time `t0` and
`p`. By default, this calls `remake_initializeprob` for backward compatibility and
attempts to construct an `OverrideInitData` from the result.
associated system `sys`, the user provided new values of `u0`, initial time `t0`,
user-provided `p`, new u0 vector `newu0` and new parameter object `newp`. By default,
this calls `remake_initializeprob` for backward compatibility and attempts to construct
an `OverrideInitData` from the result.
Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
"""
function remake_initialization_data(sys, scimlfn, u0, t0, p)
function remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
return reconstruct_initialization_data(
nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...)
end
Expand Down
2 changes: 2 additions & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ ModelingToolkitStandardLibrary = "2.7"
NonlinearSolve = "2, 3, 4"
Optimization = "4"
OptimizationOptimJL = "0.4"
OptimizationMOI = "0.5"
OrdinaryDiffEq = "6.33"
PartialFunctions = "1"
Plots = "1.40"
RecursiveArrayTools = "3"
SciMLBase = "2"
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ gs_ts, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
end

@test_broken all(map(x -> x == true_grad_vecsym, gs_ts))
@test all(map(x -> x == true_grad_vecsym, gs_ts))

# BatchedInterface AD
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0
Expand Down
97 changes: 97 additions & 0 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,100 @@ prob = SteadyStateProblem(osys, u0, ps)
getsym(prob, [:X, :X2])(prob) == [0.1, 0.2]
@test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) ==
getsym(prob, (:X, :X2))(prob) == (0.1, 0.2)

@testset "SCCNonlinearProblem" begin
# TODO: Rewrite this example when the MTK codegen is merged

function fullf!(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
du[3] = 2u[4] + u[3] + p[1]
du[4] = u[5]^2 + u[4]
du[5] = u[3]^2 + u[5]
du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8]
du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8]
du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8]
end
@variables u[1:8]=zeros(8) [irreducible = true]
u2 = collect(u)
@parameters p = 1.0
eqs = Any[0 for _ in 1:8]
fullf!(eqs, u, [p])
@named model = NonlinearSystem(0 .~ eqs, [u...], [p])
model = complete(model; split = false)

cache = zeros(4)
cache[1] = 1.0

function f1!(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
end
explicitfun1(cache, sols) = nothing

f1!(eqs, u2[1:2], [p])
@named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p])
subsys1 = complete(subsys1; split = false)
prob1 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1),
zeros(2), copy(cache))

function f2!(du, u, p)
du[1] = 2u[2] + u[1] + p[1]
du[2] = u[3]^2 + u[2]
du[3] = u[1]^2 + u[3]
end
explicitfun2(cache, sols) = nothing

f2!(eqs, u2[3:5], [p])
@named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p])
subsys2 = complete(subsys2; split = false)
prob2 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2),
zeros(3), copy(cache))

function f3!(du, u, p)
du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3]
du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3]
du[3] = p[4] + +u[1] - u[2] - u[3]
end
function explicitfun3(cache, sols)
cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
6.0sols[2][3]
end

@parameters tmpvar[1:3]
f3!(eqs, u2[6:8], [p, tmpvar...])
@named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...])
subsys3 = complete(subsys3; split = false)
prob3 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3),
zeros(3), copy(cache))

prob = NonlinearProblem(model, [])
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
model, copy(cache))

for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
@test prob[sym] sccprob[sym]
end

for sym in [p, 2p + 1]
@test prob.ps[sym] sccprob.ps[sym]
end

for (i, sym) in enumerate([u[1], u[3], u[6]])
sccprob[sym] = 0.5i
@test sccprob[sym] 0.5i
@test sccprob.probs[i].u0[1] 0.5i
end
sccprob.ps[p] = 2.5
@test sccprob.ps[p] 2.5
@test sccprob.parameter_object[1] 2.5
for scc in sccprob.probs
@test parameter_values(scc)[1] 2.5
end
end

0 comments on commit 1650ddf

Please sign in to comment.