From 1c40b9e6edae54ebc4ffdf09d470edbcd2899f5e Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 21 Jul 2024 23:17:58 +0200 Subject: [PATCH] tell `SDEProblem` that the system contains scalar noise --- Project.toml | 2 ++ ext/MTKDiffEqNoiseProcess.jl | 8 ++++++ src/systems/abstractsystem.jl | 3 +- src/systems/diffeqs/sdesystem.jl | 48 ++++++++++++++++++++++++++------ src/systems/systems.jl | 17 ++++++----- test/sdesystem.jl | 4 +-- 6 files changed, 61 insertions(+), 21 deletions(-) create mode 100644 ext/MTKDiffEqNoiseProcess.jl diff --git a/Project.toml b/Project.toml index 24632bfaf3..86703f5813 100644 --- a/Project.toml +++ b/Project.toml @@ -56,10 +56,12 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [weakdeps] BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665" DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6" +DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503" [extensions] MTKBifurcationKitExt = "BifurcationKit" MTKDeepDiffsExt = "DeepDiffs" +MTKDiffEqNoiseProcess = "DiffEqNoiseProcess" [compat] AbstractTrees = "0.3, 0.4" diff --git a/ext/MTKDiffEqNoiseProcess.jl b/ext/MTKDiffEqNoiseProcess.jl new file mode 100644 index 0000000000..d2f57db835 --- /dev/null +++ b/ext/MTKDiffEqNoiseProcess.jl @@ -0,0 +1,8 @@ +module MTKDiffEqNoiseProcess + +using ModelingToolkit: ModelingToolkit +using DiffEqNoiseProcess: WienerProcess + +ModelingToolkit.scalar_noise() = WienerProcess(0.0, 0.0, 0.0) + +end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index dcff371ecf..82c7fc0195 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -655,7 +655,8 @@ for prop in [:eqs :solved_unknowns :split_idxs :parent - :index_cache] + :index_cache + :is_scalar_noise] fname_get = Symbol(:get_, prop) fname_has = Symbol(:has_, prop) @eval begin diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 86237ac51c..aa1cc055f7 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -128,13 +128,18 @@ struct SDESystem <: AbstractODESystem The hierarchical parent system before simplification. """ parent::Any - + """ + Signal for whether the noise equations should be treated as a scalar process. This should only + be `true` when `noiseeqs isa Vector`. + """ + is_scalar_noise::Bool + function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, - complete = false, index_cache = nothing, parent = nothing; + complete = false, index_cache = nothing, parent = nothing, is_scalar_noise=false; checks::Union{Bool, Int} = true) if checks == true || (checks & CheckComponents) > 0 check_independent_variables([iv]) @@ -146,6 +151,9 @@ struct SDESystem <: AbstractODESystem throw(ArgumentError("Noise equations ill-formed. Number of rows must match number of drift equations. size(neqs,1) = $(size(neqs,1)) != length(deqs) = $(length(deqs))")) end check_equations(equations(cevents), iv) + if is_scalar_noise && neqs isa AbstractMatrix + throw(ArgumentError("Noise equations ill-formed. Recieved a matrix of noise equations of size $(size(neqs)), but `is_scalar_noise` was set to `true`. Scalar noise is only compatible with an `AbstractVector` of noise equations.")) + end end if checks == true || (checks & CheckUnits) > 0 u = __get_unit_type(dvs, ps, iv) @@ -154,7 +162,7 @@ struct SDESystem <: AbstractODESystem new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, cevents, devents, - parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent) + parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise) end end @@ -173,7 +181,11 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv discrete_events = nothing, parameter_dependencies = nothing, metadata = nothing, - gui_metadata = nothing) + gui_metadata = nothing, + complete = false, + index_cache = nothing, + parent = nothing, + is_scalar_noise=false) name === nothing && throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro")) iv′ = value(iv) @@ -208,9 +220,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv parameter_dependencies, ps′ = process_parameter_dependencies( parameter_dependencies, ps′) SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), - deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, - ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, - cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata; checks = checks) + deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, + ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connector_type, + cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata, + complete, index_cache, parent, is_scalar_noise; checks = checks) end function SDESystem(sys::ODESystem, neqs; kwargs...) @@ -225,6 +238,7 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem) isequal(nameof(sys1), nameof(sys2)) && isequal(get_eqs(sys1), get_eqs(sys2)) && isequal(get_noiseeqs(sys1), get_noiseeqs(sys2)) && + isequal(get_is_scalar_noise(sys1), get_is_scalar_noise(sys2)) && _eq_unordered(get_unknowns(sys1), get_unknowns(sys2)) && _eq_unordered(get_ps(sys1), get_ps(sys2)) && all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) @@ -601,6 +615,9 @@ function SDEFunctionExpr(sys::SDESystem, args...; kwargs...) SDEFunctionExpr{true}(sys, args...; kwargs...) end + +function scalar_noise end # defined in ../ext/MTKDiffEqNoiseProcess.jl + function DiffEqBase.SDEProblem{iip, specialize}( sys::SDESystem, u0map = [], tspan = get_tspan(sys), parammap = DiffEqBase.NullParameters(); @@ -616,16 +633,24 @@ function DiffEqBase.SDEProblem{iip, specialize}( sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false)) noiseeqs = get_noiseeqs(sys) + is_scalar_noise = get_is_scalar_noise(sys) if noiseeqs isa AbstractVector noise_rate_prototype = nothing + if is_scalar_noise + noise = scalar_noise() + else + noise = nothing + end elseif sparsenoise I, J, V = findnz(SparseArrays.sparse(noiseeqs)) noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0))) + noise = nothing else noise_rate_prototype = zeros(eltype(u0), size(noiseeqs)) + noise = nothing end - SDEProblem{iip}(f, u0, tspan, p; callback = cbs, + SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, noise_rate_prototype = noise_rate_prototype, kwargs...) end @@ -693,8 +718,12 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan, sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false)) noiseeqs = get_noiseeqs(sys) + is_scalar_noise = get_is_scalar_noise(sys) if noiseeqs isa AbstractVector noise_rate_prototype = nothing + if is_scalar_noise + noise = scalar_noise() + end elseif sparsenoise I, J, V = findnz(SparseArrays.sparse(noiseeqs)) noise_rate_prototype = SparseArrays.sparse(I, J, zero(eltype(u0))) @@ -708,7 +737,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan, tspan = $tspan p = $p noise_rate_prototype = $noise_rate_prototype - SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype, + noise = $noise + SDEProblem(f, u0, tspan, p; noise_rate_prototype = noise_rate_prototype, noise = noise, $(kwargs...)) end !linenumbers ? Base.remove_linenums!(ex) : ex diff --git a/src/systems/systems.jl b/src/systems/systems.jl index e68e8b8cb6..098bf82749 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -126,21 +126,20 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal @views copyto!(sorted_g_rows[i, :], g[g_row, :]) end # Fix for https://github.com/SciML/ModelingToolkit.jl/issues/2490 - noise_eqs = if isdiag(sorted_g_rows) + if isdiag(sorted_g_rows) # If the noise matrix is diagonal, then we just give solver just takes a vector column of equations # and it interprets that as diagonal noise. - diag(sorted_g_rows) + noise_eqs = diag(sorted_g_rows) + is_scalar_noise = false elseif sorted_g_rows isa AbstractMatrix && size(sorted_g_rows, 2) == 1 - ##------------------------------------------------------------------------------- - ## TODO: re-enable this code once we add a way to signal that the noise is scalar - # sorted_g_rows[:, 1] - ##------------------------------------------------------------------------------- - sorted_g_rows + noise_eqs = sorted_g_rows[:, 1] + is_scalar_noise = true else - sorted_g_rows + noise_eqs = sorted_g_rows + is_scalar_noise = false end return SDESystem(full_equations(ode_sys), noise_eqs, get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys); - name = nameof(ode_sys)) + name = nameof(ode_sys), is_scalar_noise) end end diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 2d8eaff3ca..d961798545 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -681,7 +681,7 @@ let ] prob = SDEProblem(de, u0map, (0.0, 100.0), parammap) # TODO: re-enable this when we support scalar noise - @test_broken solve(prob, SOSRI()).retcode == ReturnCode.Success + @test solve(prob, SOSRI()).retcode == ReturnCode.Success end let # test to make sure that scalar noise always recieve the same kicks @@ -692,7 +692,7 @@ let # test to make sure that scalar noise always recieve the same kicks @mtkbuild de = System(eqs, t) prob = SDEProblem(de, [x => 0, y => 0], (0.0, 10.0), []) - sol = solve(prob, ImplicitEM()) + sol = solve(prob, SOSRI()) @test sol[end][1] == sol[end][2] end