From cb6c390280de4f7f142dc42e97856d9696797307 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 29 Feb 2024 05:23:51 -0600 Subject: [PATCH] Allow for DAE initialization on ODEs with initializeprob This can come up from cases with ModelingToolkit, see https://github.com/SciML/ModelingToolkit.jl/pull/2512 --- src/alg_utils.jl | 3 +++ src/initialize_dae.jl | 7 ++++++- src/solve.jl | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index cbd7695517..6b6bde2be9 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -284,6 +284,9 @@ function DiffEqBase.prepare_alg(alg::CompositeAlgorithm, u0, p, prob) CompositeAlgorithm(algs, alg.choice_function) end +has_autodiff(alg::OrdinaryDiffEqAlgorithm) = false +has_autodiff(alg::Union{OrdinaryDiffEqAdaptiveImplicitAlgorithm, OrdinaryDiffEqImplicitAlgorithm, CompositeAlgorithm, OrdinaryDiffEqExponentialAlgorithm}) = true + # Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options. function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm) error("This algorithm does not have an autodifferentiation option defined.") diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 3ee8556ac0..5070b2f9ce 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -134,7 +134,12 @@ end function _initialize_dae!(integrator, prob::Union{ODEProblem, DAEProblem}, alg::OverrideInit, isinplace::Union{Val{true}, Val{false}}) initializeprob = prob.f.initializeprob - isAD = alg_autodiff(integrator.alg) isa AutoForwardDiff + + # If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit + # Since then it's the case of not a DAE but has initializeprob + # In which case, it should be differentiable + isAD = has_autodiff(integrator.alg) ? alg_autodiff(integrator.alg) isa AutoForwardDiff : true + alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD) nlsol = solve(initializeprob, alg) if isinplace === Val{true}() diff --git a/src/solve.jl b/src/solve.jl index 04196c4995..e99e636988 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -493,7 +493,7 @@ function DiffEqBase.__init( opts, stats, initializealg, differential_vars) if initialize_integrator - if isdae + if isdae || SciMLBase.has_initializeprob(prob.f) DiffEqBase.initialize_dae!(integrator) end