Skip to content

Commit

Permalink
Allow for DAE initialization on ODEs with initializeprob
Browse files Browse the repository at this point in the history
This can come up from cases with ModelingToolkit, see SciML/ModelingToolkit.jl#2512
  • Loading branch information
ChrisRackauckas committed Feb 29, 2024
1 parent 019c186 commit cb6c390
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob)
CompositeAlgorithm(algs, alg.choice_function)
end

has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false
has_autodiff(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm, OrdinaryDiffEqImplicitAlgorithm, CompositeAlgorithm, OrdinaryDiffEqExponentialAlgorithm}) = true

# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
error("This algorithm does not have an autodifferentiation option defined.")
Expand Down
7 changes: 6 additions & 1 deletion src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@ end
function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem},
alg::OverrideInit, isinplace::Union{Val{true}, Val{false}})
initializeprob = prob.f.initializeprob
isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff

# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
# Since then it's the case of not a DAE but has initializeprob
# In which case, it should be differentiable
isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff : true

alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
nlsol = solve(initializeprob, alg)
if isinplace === Val{true}()
Expand Down
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ function DiffEqBase.__init(
opts, stats, initializealg, differential_vars)

if initialize_integrator
if isdae
if isdae || SciMLBase.has_initializeprob(prob.f)
DiffEqBase.initialize_dae!(integrator)
end

Expand Down

0 comments on commit cb6c390

Please sign in to comment.