Skip to content

Commit

Permalink
Merge pull request #911 from AayushSabharwal/as/remake-eager-init
Browse files Browse the repository at this point in the history
fix: fix eager initialization in `remake`
  • Loading branch information
ChrisRackauckas authored Jan 23, 2025
2 parents d2d5e6f + a08e4dd commit c01ffa5
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 75 deletions.
3 changes: 3 additions & 0 deletions src/problems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,6 @@ function Base.summary(io::IO, prob::AbstractPDEProblem)
end

Base.copy(p::SciMLBase.NullParameters) = p

SymbolicIndexingInterface.is_time_dependent(::AbstractDEProblem) = true
SymbolicIndexingInterface.is_time_dependent(::AbstractNonlinearProblem) = false
109 changes: 37 additions & 72 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,9 @@ function remake(prob::ODEProblem; f = missing,
ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...)
end

if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if initialization_data !== nothing && !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
@reset prob.u0 = u0
@reset prob.p = p
end
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p

return prob
end
Expand Down Expand Up @@ -453,18 +444,10 @@ function remake(prob::SDEProblem;
else
SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...)
end
if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if initialization_data !== nothing && !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
@reset prob.u0 = u0
@reset prob.p = p
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p

return prob
end
Expand Down Expand Up @@ -520,18 +503,10 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
DDEProblem{iip}(f, newu0, h, tspan, newp; constant_lags, dependent_lags,
order_discontinuity_t0, neutral, kwargs...)
end
if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if initialization_data !== nothing && !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
@reset prob.u0 = u0
@reset prob.p = p
end

u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p

return prob
end
Expand Down Expand Up @@ -619,18 +594,9 @@ function remake(prob::SDDEProblem;
dependent_lags, order_discontinuity_t0, neutral, kwargs...)
end

if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if initialization_data !== nothing && !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
@reset prob.u0 = u0
@reset prob.p = p
end
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p

return prob
end
Expand Down Expand Up @@ -741,18 +707,9 @@ function remake(prob::NonlinearProblem;
problem_type = problem_type; kwargs...)
end

if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if initialization_data !== nothing && !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
@reset prob.u0 = u0
@reset prob.p = p
end
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p

return prob
end
Expand Down Expand Up @@ -792,18 +749,9 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
f, u0 = newu0, p = newp, kwargs...)
end

if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if initialization_data !== nothing && !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
@reset prob.u0 = u0
@reset prob.p = p
end
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
@reset prob.u0 = u0
@reset prob.p = p

return prob
end
Expand Down Expand Up @@ -1134,6 +1082,23 @@ function process_p_u0_symbolic(prob, p, u0)
end
end

function maybe_eager_initialize_problem(prob::AbstractSciMLProblem, initialization_data, lazy_initialization::Union{Nothing, Bool})
if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if initialization_data !== nothing && !lazy_initialization && (!is_time_dependent(prob) || current_time(prob) !== nothing)
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
u0 = nothing
end
else
u0 = state_values(prob)
p = parameter_values(prob)
end
return u0, p
end

function remake(thing::AbstractJumpProblem; kwargs...)
parameterless_type(thing)(remake(thing.prob; kwargs...))
end
Expand Down
10 changes: 9 additions & 1 deletion test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,10 @@ end
@testset "Trivial initialization" begin
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
update_initializeprob! = function (iprob, integ)
iprob.p[1] = integ.u[1]
# just to access the current time and use it as a number, so this errors
# if run on a problem with `current_time(prob) === nothing`
iprob.p[1] = current_time(integ) + 1
iprob.p[1] = state_values(integ)[1]
end
initprobmap = function (nlsol)
u1 = parameter_values(nlsol)[1]
Expand All @@ -284,6 +287,11 @@ end
@test u0 [2.0, 2.0]
@test p 0.0
@test success

@testset "Doesn't run in `remake` if `tspan == (nothing, nothing)`" begin
prob = ODEProblem(fn, [2.0, 0.0], (nothing, nothing), 0.0)
@test_nowarn remake(prob)
end
end
end

Expand Down
5 changes: 3 additions & 2 deletions test/remake_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ u0 = [1.0; 2.0; 3.0]
tspan = (0.0, 100.0)
p = [10.0, 20.0, 30.0]
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
indep_sys = SymbolCache([:x, :y, :z], [:a, :b, :c])
fn = ODEFunction(lorenz!; sys)
for T in containerTypes
push!(probs, ODEProblem(fn, u0, tspan, T(p)))
Expand Down Expand Up @@ -64,7 +65,7 @@ function loss(x, p)
return sum(du)
end

fn = OptimizationFunction(loss; sys)
fn = OptimizationFunction(loss; sys = indep_sys)
for T in containerTypes
push!(probs, OptimizationProblem(fn, u0, T(p)))
end
Expand All @@ -73,7 +74,7 @@ function nllorenz!(du, u, p)
lorenz!(du, u, p, 0.0)
end

fn = NonlinearFunction(nllorenz!; sys)
fn = NonlinearFunction(nllorenz!; sys = indep_sys)
for T in containerTypes
push!(probs, NonlinearProblem(fn, u0, T(p)))
end
Expand Down

0 comments on commit c01ffa5

Please sign in to comment.