From 34fefbc72bb82a4c324a0665ff5b91badc166375 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 17:54:27 +0530 Subject: [PATCH 01/10] feat: add `initializealg` to `SDEIntegrator` --- Project.toml | 2 +- src/integrators/type.jl | 3 ++- src/solve.jl | 6 ++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 1e47a2f8f..c65441f2b 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ Random = "1.6" RandomNumbers = "1.5.3" RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" -SciMLBase = "2.59.2" +SciMLBase = "2.62" SciMLOperators = "0.2.9, 0.3" SparseArrays = "1.6" SparseDiffTools = "2" diff --git a/src/integrators/type.jl b/src/integrators/type.jl index b77805971..979fdf602 100644 --- a/src/integrators/type.jl +++ b/src/integrators/type.jl @@ -1,4 +1,4 @@ -mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs} <: AbstractSDEIntegrator{algType,IIP,uType,tType} +mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs,IA} <: AbstractSDEIntegrator{algType,IIP,uType,tType} f::F4 g::F5 c::F6 @@ -43,4 +43,5 @@ mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenTy qold::tTypeNoUnits q11::tTypeNoUnits stats::DiffEqBase.Stats + initializealg::IA end diff --git a/src/solve.jl b/src/solve.jl index d7bc2f8c8..aa6ead1f5 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -66,6 +66,7 @@ function DiffEqBase.__init( userdata=nothing, initialize_integrator=true, seed = UInt64(0), alias_u0=false, alias_jumps = Threads.threadid()==1, + initializealg = SDEDefaultInit(), kwargs...) where recompile_flag prob = concrete_prob(_prob) @@ -587,7 +588,8 @@ function DiffEqBase.__init( uBottomEltype,tType,typeof(tdir),typeof(p), typeof(eigen_est),QT, uEltypeNoUnits,typeof(W),typeof(P),rateType,typeof(sol),typeof(cache), - FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants)}( + FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants), + typeof(initializealg)}( f,g,c,noise,uprev,tprev,t,u,p,tType(dt),tType(dt),tType(dt),dtcache,tspan[2],tdir, just_hit_tstop,do_error_check,isout,event_last_time, vector_event_last_time,last_event_error,accept_step, @@ -597,7 +599,7 @@ function DiffEqBase.__init( alg,sol, cache,callback_cache,tType(dt),W,P,rate_constants, opts,iter,success_iter,eigen_est,EEst,q, - QT(qoldinit),q11,stats) + QT(qoldinit),q11,stats,initializealg) if initialize_integrator initialize_callbacks!(integrator, initialize_save) From 9a0733be9050dab52038133794063238238fe3f2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 17:54:41 +0530 Subject: [PATCH 02/10] feat: implement `initialize_dae!` for `SDEIntegrator` --- src/StochasticDiffEq.jl | 1 + src/initialize_dae.jl | 14 ++++++++++++++ src/solve.jl | 1 + 3 files changed, 16 insertions(+) create mode 100644 src/initialize_dae.jl diff --git a/src/StochasticDiffEq.jl b/src/StochasticDiffEq.jl index caf8c11d4..b946d2009 100644 --- a/src/StochasticDiffEq.jl +++ b/src/StochasticDiffEq.jl @@ -119,6 +119,7 @@ end include("cache_utils.jl") include("integrators/integrator_interface.jl") include("iterator_interface.jl") + include("initialize_dae.jl") include("solve.jl") include("initdt.jl") include("perform_step/low_order.jl") diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl new file mode 100644 index 000000000..a1cbbc29e --- /dev/null +++ b/src/initialize_dae.jl @@ -0,0 +1,14 @@ +struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end + +function DiffEqBase.initialize_dae!(integrator::SDEIntegrator, initializealg = integrator.initializealg) + OrdinaryDiffEq._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) +end + +function OrdinaryDiffEq._initialize_dae!(integrator::SDEIntegrator, prob, ::SDEDefaultInit, isinplace) + if SciMLBase.has_initializeprob(prob.f) + OrdinaryDiffEq._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) + else + OrdinaryDiffEq._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) + end +end + diff --git a/src/solve.jl b/src/solve.jl index aa6ead1f5..8858694da 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -602,6 +602,7 @@ function DiffEqBase.__init( QT(qoldinit),q11,stats,initializealg) if initialize_integrator + DiffEqBase.initialize_dae!(integrator) initialize_callbacks!(integrator, initialize_save) initialize!(integrator,integrator.cache) save_start && alg isa Union{StochasticDiffEqCompositeAlgorithm, From d8cfa884955f0415aa7ebfc68c7c531bf6aecdf1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 27 Nov 2024 12:58:46 +0530 Subject: [PATCH 03/10] refactor: depend on `OrdinaryDiffEqCore` for `_initialize_dae!` --- Project.toml | 2 ++ src/StochasticDiffEq.jl | 2 ++ src/initialize_dae.jl | 8 ++++---- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index c65441f2b..8184b14e3 100644 --- a/Project.toml +++ b/Project.toml @@ -20,6 +20,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" @@ -48,6 +49,7 @@ Logging = "1.6" MuladdMacro = "0.2.1" NLsolve = "4" OrdinaryDiffEq = "6.87" +OrdinaryDiffEqCore = "1" Random = "1.6" RandomNumbers = "1.5.3" RecursiveArrayTools = "2, 3" diff --git a/src/StochasticDiffEq.jl b/src/StochasticDiffEq.jl index b946d2009..91f2ed296 100644 --- a/src/StochasticDiffEq.jl +++ b/src/StochasticDiffEq.jl @@ -58,6 +58,8 @@ using OrdinaryDiffEq: nlsolvefail, isnewton, set_new_W!, get_W, _vec, _reshape using OrdinaryDiffEq: NLSolver +import OrdinaryDiffEqCore + if isdefined(OrdinaryDiffEq,:FastConvergence) using OrdinaryDiffEq: FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index a1cbbc29e..947e3aa51 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -1,14 +1,14 @@ struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end function DiffEqBase.initialize_dae!(integrator::SDEIntegrator, initializealg = integrator.initializealg) - OrdinaryDiffEq._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) + OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) end -function OrdinaryDiffEq._initialize_dae!(integrator::SDEIntegrator, prob, ::SDEDefaultInit, isinplace) +function OrdinaryDiffEqCore._initialize_dae!(integrator::SDEIntegrator, prob, ::SDEDefaultInit, isinplace) if SciMLBase.has_initializeprob(prob.f) - OrdinaryDiffEq._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) else - OrdinaryDiffEq._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) + OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) end end From fd2ee9267202df1fc2c0bbd9a072bc986180a877 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 27 Nov 2024 17:42:13 +0530 Subject: [PATCH 04/10] build: bump `OrdinaryDiffEqCore` compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8184b14e3..d18e27dc2 100644 --- a/Project.toml +++ b/Project.toml @@ -49,7 +49,7 @@ Logging = "1.6" MuladdMacro = "0.2.1" NLsolve = "4" OrdinaryDiffEq = "6.87" -OrdinaryDiffEqCore = "1" +OrdinaryDiffEqCore = "1.12.1" Random = "1.6" RandomNumbers = "1.5.3" RecursiveArrayTools = "2, 3" From 0e4ac2328f572fbb1a47bb8adfbc93dd2537e2bc Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 29 Nov 2024 13:30:54 +0530 Subject: [PATCH 05/10] refactor: relax `initialize_dae!` type restrictions --- src/initialize_dae.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 947e3aa51..87a3fd034 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -1,13 +1,13 @@ struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end -function DiffEqBase.initialize_dae!(integrator::SDEIntegrator, initializealg = integrator.initializealg) +function DiffEqBase.initialize_dae!(integrator::AbstractSDEIntegrator, initializealg = integrator.initializealg) OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) end -function OrdinaryDiffEqCore._initialize_dae!(integrator::SDEIntegrator, prob, ::SDEDefaultInit, isinplace) +function OrdinaryDiffEqCore._initialize_dae!(integrator::AbstractSDEIntegrator, prob, ::SDEDefaultInit, isinplace) if SciMLBase.has_initializeprob(prob.f) OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) - else + elseif SciMLBase.__has_mass_matrix(prob.f) OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) end end From e4cc7b3ddcd7c2023486e446b0322bc1eb68c5c1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 29 Nov 2024 16:32:43 +0530 Subject: [PATCH 06/10] build: bump `SciMLBase` compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d18e27dc2..fd958cc09 100644 --- a/Project.toml +++ b/Project.toml @@ -54,7 +54,7 @@ Random = "1.6" RandomNumbers = "1.5.3" RecursiveArrayTools = "2, 3" Reexport = "0.2, 1.0" -SciMLBase = "2.62" +SciMLBase = "2.65" SciMLOperators = "0.2.9, 0.3" SparseArrays = "1.6" SparseDiffTools = "2" From 2fc2ae77446b26f6e0f99cc327309b093a5ccc9b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:06:54 +0530 Subject: [PATCH 07/10] feat: also support `AbstractSDDEIntegrator` in `initialize_dae!` --- src/StochasticDiffEq.jl | 2 +- src/initialize_dae.jl | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/StochasticDiffEq.jl b/src/StochasticDiffEq.jl index 91f2ed296..a188a86ad 100644 --- a/src/StochasticDiffEq.jl +++ b/src/StochasticDiffEq.jl @@ -41,7 +41,7 @@ using DocStringExtensions import DiffEqBase: step!, initialize!, DEAlgorithm, AbstractSDEAlgorithm, AbstractRODEAlgorithm, DEIntegrator, AbstractDiffEqInterpolation, DECache, AbstractSDEIntegrator, AbstractRODEIntegrator, AbstractContinuousCallback, - Tableau + Tableau, AbstractSDDEIntegrator # Integrator Interface import DiffEqBase: resize!,deleteat!,addat!,full_cache,user_cache,u_cache,du_cache, diff --git a/src/initialize_dae.jl b/src/initialize_dae.jl index 87a3fd034..f9bfbfc07 100644 --- a/src/initialize_dae.jl +++ b/src/initialize_dae.jl @@ -1,14 +1,13 @@ struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end -function DiffEqBase.initialize_dae!(integrator::AbstractSDEIntegrator, initializealg = integrator.initializealg) +function DiffEqBase.initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, initializealg = integrator.initializealg) OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob))) end -function OrdinaryDiffEqCore._initialize_dae!(integrator::AbstractSDEIntegrator, prob, ::SDEDefaultInit, isinplace) +function OrdinaryDiffEqCore._initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, prob, ::SDEDefaultInit, isinplace) if SciMLBase.has_initializeprob(prob.f) OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace) elseif SciMLBase.__has_mass_matrix(prob.f) OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace) end end - From bef084d516b30fa8c76e3d91102572759412fb08 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:07:11 +0530 Subject: [PATCH 08/10] feat: implement `OrdinaryDiffEqCore.has_autodiff` --- src/alg_utils.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 6ddf4083a..4b7e4765d 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -9,6 +9,9 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm,StochasticDiffEqJumpNewtonAdaptiveAlgorithm, StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm}) = OrdinaryDiffEq.alg_autodiff(alg) +# Required for initialization, because ODECore._initialize_dae! calls it during +# OverrideInit +OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true From baa5adfbfa40a5cb0269f76f97fb9f4c0ad8081f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 18:09:41 +0530 Subject: [PATCH 09/10] build: add `ADTypes` --- Project.toml | 2 ++ src/StochasticDiffEq.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index fd958cc09..d5274406c 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas "] version = "6.71.1" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" @@ -33,6 +34,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] +ADTypes = "1" Adapt = "3, 4" ArrayInterface = "6, 7" DataStructures = "0.18" diff --git a/src/StochasticDiffEq.jl b/src/StochasticDiffEq.jl index a188a86ad..b1da7e579 100644 --- a/src/StochasticDiffEq.jl +++ b/src/StochasticDiffEq.jl @@ -9,6 +9,8 @@ using DocStringExtensions using Reexport @reexport using DiffEqBase + import ADTypes + import OrdinaryDiffEq import OrdinaryDiffEq: default_controller, isstandard, ispredictive, beta2_default, beta1_default, gamma_default, From 75087875a5b0c47b7225cf52b8e65b6bb0996ce7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 18:09:57 +0530 Subject: [PATCH 10/10] feat: properly implement `has_autodiff` and `alg_autodiff` --- src/alg_utils.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/alg_utils.jl b/src/alg_utils.jl index 4b7e4765d..4d22b89e9 100644 --- a/src/alg_utils.jl +++ b/src/alg_utils.jl @@ -12,6 +12,28 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm, # Required for initialization, because ODECore._initialize_dae! calls it during # OverrideInit OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false +for T in [ + StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm, + StochasticDiffEqJumpNewtonAdaptiveAlgorithm, + StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm] + @eval OrdinaryDiffEqCore.has_autodiff(::$T) = true +end + +_alg_autodiff(::StochasticDiffEqNewtonAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() +_alg_autodiff(::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}() + +function OrdinaryDiffEqCore.alg_autodiff(alg) + ad = _alg_autodiff(alg) + if ad == Val(false) + return ADTypes.AutoFiniteDiff() + elseif ad == Val(true) + return ADTypes.AutoForwardDiff() + else + return SciMLBase._unwrap_val(ad) + end +end isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true