Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support initialize_dae! for SDEIntegrator #594

Merged
merged 10 commits into from
Dec 11, 2024
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>"]
version = "6.71.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -20,6 +21,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -32,6 +34,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
ADTypes = "1"
Adapt = "3, 4"
ArrayInterface = "6, 7"
DataStructures = "0.18"
Expand All @@ -48,11 +51,12 @@ Logging = "1.6"
MuladdMacro = "0.2.1"
NLsolve = "4"
OrdinaryDiffEq = "6.87"
OrdinaryDiffEqCore = "1.12.1"
Random = "1.6"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
SciMLBase = "2.59.2"
SciMLBase = "2.65"
SciMLOperators = "0.2.9, 0.3"
SparseArrays = "1.6"
SparseDiffTools = "2"
Expand Down
7 changes: 6 additions & 1 deletion src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using DocStringExtensions
using Reexport
@reexport using DiffEqBase

import ADTypes

import OrdinaryDiffEq
import OrdinaryDiffEq: default_controller, isstandard, ispredictive,
beta2_default, beta1_default, gamma_default,
Expand Down Expand Up @@ -41,7 +43,7 @@ using DocStringExtensions
import DiffEqBase: step!, initialize!, DEAlgorithm,
AbstractSDEAlgorithm, AbstractRODEAlgorithm, DEIntegrator, AbstractDiffEqInterpolation,
DECache, AbstractSDEIntegrator, AbstractRODEIntegrator, AbstractContinuousCallback,
Tableau
Tableau, AbstractSDDEIntegrator

# Integrator Interface
import DiffEqBase: resize!,deleteat!,addat!,full_cache,user_cache,u_cache,du_cache,
Expand All @@ -58,6 +60,8 @@ using OrdinaryDiffEq: nlsolvefail, isnewton, set_new_W!, get_W, _vec, _reshape

using OrdinaryDiffEq: NLSolver

import OrdinaryDiffEqCore

if isdefined(OrdinaryDiffEq,:FastConvergence)
using OrdinaryDiffEq:
FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence
Expand Down Expand Up @@ -119,6 +123,7 @@ end
include("cache_utils.jl")
include("integrators/integrator_interface.jl")
include("iterator_interface.jl")
include("initialize_dae.jl")
include("solve.jl")
include("initdt.jl")
include("perform_step/low_order.jl")
Expand Down
25 changes: 25 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@ SciMLBase.forwarddiffs_model(alg::Union{StochasticDiffEqNewtonAlgorithm,
StochasticDiffEqNewtonAdaptiveAlgorithm,StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm}) = OrdinaryDiffEq.alg_autodiff(alg)

# Required for initialization, because ODECore._initialize_dae! calls it during
# OverrideInit
OrdinaryDiffEqCore.has_autodiff(::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm,StochasticDiffEqJumpAlgorithm}) = false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't true though, some of the implicit methods have the same autodiff args as the ode solver

for T in [
StochasticDiffEqNewtonAlgorithm, StochasticDiffEqNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonAdaptiveAlgorithm,
StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm]
@eval OrdinaryDiffEqCore.has_autodiff(::$T) = true
end

_alg_autodiff(::StochasticDiffEqNewtonAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()
_alg_autodiff(::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{T, AD}) where {T, AD} = Val{AD}()

function OrdinaryDiffEqCore.alg_autodiff(alg)
ad = _alg_autodiff(alg)
if ad == Val(false)
return ADTypes.AutoFiniteDiff()
elseif ad == Val(true)
return ADTypes.AutoForwardDiff()
else
return SciMLBase._unwrap_val(ad)
end
end

isadaptive(alg::Union{StochasticDiffEqAlgorithm,StochasticDiffEqRODEAlgorithm}) = false
isadaptive(alg::Union{StochasticDiffEqAdaptiveAlgorithm,StochasticDiffEqRODEAdaptiveAlgorithm,StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpDiffusionAdaptiveAlgorithm}) = true
Expand Down
13 changes: 13 additions & 0 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct SDEDefaultInit <: DiffEqBase.DAEInitializationAlgorithm end

function DiffEqBase.initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, initializealg = integrator.initializealg)
OrdinaryDiffEqCore._initialize_dae!(integrator, integrator.sol.prob, initializealg, Val(DiffEqBase.isinplace(integrator.sol.prob)))
end

function OrdinaryDiffEqCore._initialize_dae!(integrator::Union{AbstractSDEIntegrator, AbstractSDDEIntegrator}, prob, ::SDEDefaultInit, isinplace)
if SciMLBase.has_initializeprob(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.OverrideInit(), isinplace)
elseif SciMLBase.__has_mass_matrix(prob.f)
OrdinaryDiffEqCore._initialize_dae!(integrator, prob, SciMLBase.CheckInit(), isinplace)
end
end
3 changes: 2 additions & 1 deletion src/integrators/type.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs} <: AbstractSDEIntegrator{algType,IIP,uType,tType}
mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenType,tTypeNoUnits,uEltypeNoUnits,randType,randType2,rateType,solType,cacheType,F4,F5,F6,OType,noiseType,EventErrorType,CallbackCacheType,RCs,IA} <: AbstractSDEIntegrator{algType,IIP,uType,tType}
f::F4
g::F5
c::F6
Expand Down Expand Up @@ -43,4 +43,5 @@ mutable struct SDEIntegrator{algType,IIP,uType,uEltype,tType,tdirType,P2,eigenTy
qold::tTypeNoUnits
q11::tTypeNoUnits
stats::DiffEqBase.Stats
initializealg::IA
end
7 changes: 5 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ function DiffEqBase.__init(
userdata=nothing,
initialize_integrator=true,
seed = UInt64(0), alias_u0=false, alias_jumps = Threads.threadid()==1,
initializealg = SDEDefaultInit(),
kwargs...) where recompile_flag

prob = concrete_prob(_prob)
Expand Down Expand Up @@ -587,7 +588,8 @@ function DiffEqBase.__init(
uBottomEltype,tType,typeof(tdir),typeof(p),
typeof(eigen_est),QT,
uEltypeNoUnits,typeof(W),typeof(P),rateType,typeof(sol),typeof(cache),
FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants)}(
FType,GType,CType,typeof(opts),typeof(noise),typeof(last_event_error),typeof(callback_cache),typeof(rate_constants),
typeof(initializealg)}(
f,g,c,noise,uprev,tprev,t,u,p,tType(dt),tType(dt),tType(dt),dtcache,tspan[2],tdir,
just_hit_tstop,do_error_check,isout,event_last_time,
vector_event_last_time,last_event_error,accept_step,
Expand All @@ -597,9 +599,10 @@ function DiffEqBase.__init(
alg,sol,
cache,callback_cache,tType(dt),W,P,rate_constants,
opts,iter,success_iter,eigen_est,EEst,q,
QT(qoldinit),q11,stats)
QT(qoldinit),q11,stats,initializealg)

if initialize_integrator
DiffEqBase.initialize_dae!(integrator)
initialize_callbacks!(integrator, initialize_save)
initialize!(integrator,integrator.cache)
save_start && alg isa Union{StochasticDiffEqCompositeAlgorithm,
Expand Down
Loading