Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix initializeprobpmap call in OverrideInit #866

Merged
4 changes: 2 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct OverrideInitData{IProb, UIProb, IProbMap, IProbPmap}
"""
initializeprobmap::IProbMap
"""
A function which takes the solution of `initializeprob` and returns
A function which takes `value_provider` and the solution of `initializeprob` and returns
the parameter object of the original problem. If absent (`nothing`),
this will not be called and the parameters of the problem being
initialized will be returned as-is.
Expand Down Expand Up @@ -210,7 +210,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
p = initdata.initializeprobpmap(nlsol)
p = initdata.initializeprobpmap(valp, nlsol)
end

return u0, p, SciMLBase.successful_retcode(nlsol)
Expand Down
8 changes: 4 additions & 4 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -22,6 +23,7 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -31,12 +33,10 @@ DelayDiffEq = "5"
DiffEqCallbacks = "3, 4"
ForwardDiff = "0.10"
JumpProcesses = "9.10"
ModelingToolkit = "9"
ModelingToolkit = "9.52"
ModelingToolkitStandardLibrary = "2.7"
NonlinearSolve = "2, 3, 4"
Optimization = "3"
OptimizationMOI = "0.4"
OptimizationOptimJL = "0.1, 0.2, 0.3"
Optimization = "4"
OrdinaryDiffEq = "6.33"
Plots = "1.40"
RecursiveArrayTools = "3"
Expand Down
2 changes: 1 addition & 1 deletion test/downstream/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ gs_ts, = Zygote.gradient(sol) do sol
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
end

@test all(map(x -> x == true_grad_vecsym, gs_ts))
@test_broken all(map(x -> x == true_grad_vecsym, gs_ts))

# BatchedInterface AD
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0
Expand Down
53 changes: 26 additions & 27 deletions test/downstream/initialization.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,9 @@
using OrdinaryDiffEq, Sundials, SciMLBase, Test
using ModelingToolkit, NonlinearSolve, OrdinaryDiffEq, Sundials, SciMLBase, Test
using SymbolicIndexingInterface
using ModelingToolkit: t_nounits as t, D_nounits as D

@testset "CheckInit" begin
abstol = 1e-10
@testset "Sundials + ODEProblem" begin
function rhs(u, p, t)
return [u[1] * t, u[1]^2 - u[2]^2]
end
function rhs!(du, u, p, t)
du[1] = u[1] * t
du[2] = u[1]^2 - u[2]^2
end

oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
integ = init(prob, Sundials.ARKODE())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
end
end

@testset "Sundials + DAEProblem" begin
function daerhs(du, u, p, t)
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
Expand Down Expand Up @@ -59,3 +35,26 @@ using OrdinaryDiffEq, Sundials, SciMLBase, Test
end
end
end

@testset "OverrideInit with MTK" begin
abstol = 1e-10
reltol = 1e-8

@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p=missing [guess = 1.0] q=missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ p * y + q * t, D(y) ~ 5x + q], t;
initialization_eqs = [p^2 + q^2 ~ 3, x^3 + y^3 ~ 5])
prob = ODEProblem(
sys, [x => 1.0], (0.0, 1.0), [p => 1.0]; initializealg = SciMLBase.NoInit())

@test prob.f.initialization_data isa SciMLBase.OverrideInitData
integ = init(prob, Tsit5())
u0, pobj, success = SciMLBase.get_initial_values(
prob, integ, prob.f, SciMLBase.OverrideInit(), Val(true);
nlsolve_alg = NewtonRaphson(), abstol, reltol)

@test getu(sys, x)(u0) ≈ 1.0
@test getu(sys, y)(u0) ≈ cbrt(4)
@test getp(sys, p)(pobj) ≈ 1.0
@test getp(sys, q)(pobj) ≈ sqrt(2)
end
8 changes: 4 additions & 4 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,12 @@ eprob = EnsembleProblem(oprob)
@test eprob.ps[osys.p] == 0.1

@test state_values(remake(eprob; u0 = [X => 0.1])) == [0.1]
@test state_values(remake(eprob; u0 = [:X => 0.2])) == [0.2]
@test_broken state_values(remake(eprob; u0 = [:X => 0.2])) == [0.2]
@test state_values(remake(eprob; u0 = [osys.X => 0.3])) == [0.3]

@test remake(eprob; p = [d => 0.4]).ps[d] == 0.4
@test remake(eprob; p = [:d => 0.5]).ps[d] == 0.5
@test remake(eprob; p = [osys.d => 0.6]).ps[d] == 0.6
@test_broken remake(eprob; p = [d => 0.4]).ps[d] == 0.4
@test_broken remake(eprob; p = [:d => 0.5]).ps[d] == 0.5
@test_broken remake(eprob; p = [osys.d => 0.6]).ps[d] == 0.6

# SteadyStateProblem Indexing
# Issue#660
Expand Down
28 changes: 5 additions & 23 deletions test/downstream/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,13 @@ end
ss1.state_map == ss2.state_map
end

ode_sol = solve(prob, Tsit5(); save_idxs = xidx)
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]

# FIXME: hack for save_idxs
SciMLBase.@reset ode_sol.saved_subsystem = subsys
ode_sol = solve(prob, Tsit5(); save_idxs = [x])

@mtkbuild sys = ODESystem([D(x) ~ x + p * y, 1 ~ sin(y) + cos(x)], t)
xidx = variable_index(sys, x)
prob = DAEProblem(sys, [D(x) => x + p * y, D(y) => 1 / sqrt(1 - (1 - cos(x))^2)],
[x => 1.0, y => asin(1 - cos(x))], (0.0, 1.0), [p => 2.0])
dae_sol = solve(prob, DFBDF(); save_idxs = xidx)
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [xidx])
# FIXME: hack for save_idxs
SciMLBase.@reset dae_sol.saved_subsystem = subsys
[x => 1.0, y => asin(1 - cos(x))], (0.0, 1.0), [p => 2.0]; build_initializeprob = false)
dae_sol = solve(prob, DFBDF(); save_idxs = [x])

@brownian a b
@mtkbuild sys = System([D(x) ~ x + p * y + x * a, D(y) ~ 2p + x^2 + y * b], t)
Expand Down Expand Up @@ -256,21 +248,11 @@ end

@test SciMLBase.SavedSubsystem(sys, prob.p, [x, y, q, r, s, u]) === nothing

sol = solve(prob; save_idxs = xidx)
sol = solve(prob; save_idxs = [x, q, r])
xvals = sol[x]
subsys = SciMLBase.SavedSubsystem(sys, prob.p, [x, q, r])
@test SciMLBase.get_saved_state_idxs(subsys) == [xidx]
@test SciMLBase.get_saved_state_idxs(sol.saved_subsystem) == [xidx]
qvals = sol.ps[q]
rvals = sol.ps[r]
# FIXME: hack for save_idxs
SciMLBase.@reset sol.saved_subsystem = subsys
discq = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(qvals))),
sol.discretes[qpidx.timeseries_idx].t, (1, 1))
discr = DiffEqArray(SciMLBase.TupleOfArraysWrapper.(tuple.(Base.vect.(rvals))),
sol.discretes[rpidx.timeseries_idx].t, (1, 1))
SciMLBase.@reset sol.discretes.collection[qpidx.timeseries_idx] = discq
SciMLBase.@reset sol.discretes.collection[rpidx.timeseries_idx] = discr

@test sol[x] == xvals

@test all(Base.Fix1(is_parameter, sol), [p, q, r, s, u])
Expand Down
2 changes: 1 addition & 1 deletion test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ end
initprobmap = function (nlsol)
return [parameter_values(nlsol)[1], nlsol.u[1]]
end
initprobpmap = function (nlsol)
initprobpmap = function (_, nlsol)
return nlsol.u[2]
end
initialization_data = SciMLBase.OverrideInitData(
Expand Down
Loading