Skip to content

Commit

Permalink
Merge pull request #2914 from MasonProtter/another_diag_noise_fix
Browse files Browse the repository at this point in the history
Detect cases where there's fewer brownians than equations, but noise still 'diagonal'
  • Loading branch information
ChrisRackauckas authored Aug 2, 2024
2 parents ec870e2 + fe1a3ee commit e64c479
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 17 deletions.
30 changes: 23 additions & 7 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,30 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
end

function __num_isdiag(mat)
for i in axes(mat, 1), j in axes(mat, 2)
i == j || isequal(mat[i, j], 0) || return false
function __num_isdiag_noise(mat)
for i in axes(mat, 1)
nnz = 0
for j in axes(mat, 2)
if !isequal(mat[i, j], 0)
nnz += 1
end
end
if nnz > 1
return (false)
end
end
true
end
function __get_num_diag_noise(mat)
map(axes(mat, 1)) do i
for j in axes(mat, 2)
mij = mat[i, j]
if !isequal(mij, 0)
return mij
end
end
0
end
return true
end

function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
Expand All @@ -258,9 +277,6 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
if isdde
eqs = delay_to_function(sys, eqs)
end
if eqs isa AbstractMatrix && __num_isdiag(eqs)
eqs = diag(eqs)
end
u = map(x -> time_varying_as_func(value(x), sys), dvs)
p = if has_index_cache(sys) && get_index_cache(sys) !== nothing
reorder_parameters(get_index_cache(sys), ps)
Expand Down
9 changes: 5 additions & 4 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,11 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal
# we get a Nx1 matrix of noise equations, which is a special case known as scalar noise
noise_eqs = sorted_g_rows[:, 1]
is_scalar_noise = true
elseif isdiag(sorted_g_rows)
# If the noise matrix is diagonal, then the solver just takes a vector column of equations
# and it interprets that as diagonal noise.
noise_eqs = diag(sorted_g_rows)
elseif __num_isdiag_noise(sorted_g_rows)
# If each column of the noise matrix has either 0 or 1 non-zero entry, then this is "diagonal noise".
# In this case, the solver just takes a vector column of equations and it interprets that to
# mean that each noise process is independent
noise_eqs = __get_num_diag_noise(sorted_g_rows)
is_scalar_noise = false
else
noise_eqs = sorted_g_rows
Expand Down
40 changes: 34 additions & 6 deletions test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,12 @@ end

@testset "Non-diagonal noise check" begin
@parameters σ ρ β
@variables x(t) y(t) z(t)
@brownian a b c
eqs = [D(x) ~ σ * (y - x) + 0.1a * x + 0.1b * y,
D(y) ~ x *- z) - y + 0.1b * y,
D(z) ~ x * y - β * z + 0.1c * z]
@mtkbuild de = System(eqs, t)
@variables x(tt) y(tt) z(tt)
@brownian a b c d e f
eqs = [D(x) ~ σ * (y - x) + 0.1a * x + d,
D(y) ~ x *- z) - y + 0.1b * y + e,
D(z) ~ x * y - β * z + 0.1c * z + f]
@mtkbuild de = System(eqs, tt)

u0map = [
x => 1.0,
Expand All @@ -746,5 +746,33 @@ end

prob = SDEProblem(de, u0map, (0.0, 100.0), parammap)
# SOSRI only works for diagonal and scalar noise
@test_throws ErrorException solve(prob, SOSRI()).retcode==ReturnCode.Success
# ImplicitEM does work for non-diagonal noise
@test solve(prob, ImplicitEM()).retcode == ReturnCode.Success
@test size(ModelingToolkit.get_noiseeqs(de)) == (3, 6)
end

@testset "Diagonal noise, less brownians than equations" begin
@parameters σ ρ β
@variables x(tt) y(tt) z(tt)
@brownian a b
eqs = [D(x) ~ σ * (y - x) + 0.1a * x, # One brownian
D(y) ~ x *- z) - y + 0.1b * y, # Another brownian
D(z) ~ x * y - β * z] # no brownians -- still diagonal
@mtkbuild de = System(eqs, tt)

u0map = [
x => 1.0,
y => 0.0,
z => 0.0
]

parammap = [
σ => 10.0,
β => 26.0,
ρ => 2.33
]

prob = SDEProblem(de, u0map, (0.0, 100.0), parammap)
@test solve(prob, SOSRI()).retcode == ReturnCode.Success
end

0 comments on commit e64c479

Please sign in to comment.