Skip to content

Commit

Permalink
Merge pull request #3266 from AayushSabharwal/as/sde-ss
Browse files Browse the repository at this point in the history
feat: enable `structural_simplify(::SDESystem)`
  • Loading branch information
ChrisRackauckas authored Dec 12, 2024
2 parents d01496f + ea9b6bd commit 51aea4a
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 1 deletion.
41 changes: 41 additions & 0 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,47 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem)
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
end

"""
function ODESystem(sys::SDESystem)
Convert an `SDESystem` to the equivalent `ODESystem` using `@brownian` variables instead
of noise equations. The returned system will not be `iscomplete` and will not have an
index cache, regardless of `iscomplete(sys)`.
"""
function ODESystem(sys::SDESystem)
neqs = get_noiseeqs(sys)
eqs = equations(sys)
is_scalar_noise = get_is_scalar_noise(sys)
nbrownian = if is_scalar_noise
length(neqs)
else
size(neqs, 2)
end
brownvars = map(1:nbrownian) do i
name = gensym(Symbol(:brown_, i))
only(@brownian $name)
end
if is_scalar_noise
brownterms = reduce(+, neqs .* brownvars; init = 0)
neweqs = map(eqs) do eq
eq.lhs ~ eq.rhs + brownterms
end
else
if neqs isa AbstractVector
neqs = reshape(neqs, (length(neqs), 1))
end
brownterms = neqs * brownvars
neweqs = map(eqs, brownterms) do eq, brown
eq.lhs ~ eq.rhs + brown
end
end
newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys);
parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys),
continuous_events = continuous_events(sys), discrete_events = discrete_events(sys),
name = nameof(sys), description = description(sys), metadata = get_metadata(sys))
@set newsys.parent = sys
end

function __num_isdiag_noise(mat)
for i in axes(mat, 1)
nnz = 0
Expand Down
4 changes: 4 additions & 0 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ function __structural_simplify(sys::JumpSystem, args...; kwargs...)
return sys
end

function __structural_simplify(sys::SDESystem, args...; kwargs...)
return __structural_simplify(ODESystem(sys), args...; kwargs...)
end

function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
kwargs...)
sys = expand_connections(sys)
Expand Down
4 changes: 3 additions & 1 deletion src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,9 @@ $(SIGNATURES)
Define one or more Brownian variables.
"""
macro brownian(xs...)
all(x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$, xs) ||
all(
x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$ || Meta.isexpr(x, :$),
xs) ||
error("@brownian only takes scalar expressions!")
Symbolics._parse_vars(:brownian,
Real,
Expand Down
59 changes: 59 additions & 0 deletions test/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -809,3 +809,62 @@ end
prob = SDEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
@test prob[z] 2.0
end

@testset "SDESystem to ODESystem" begin
@variables x(t) y(t) z(t)
@testset "Scalar noise" begin
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 3],
t, [x, y, z], [], is_scalar_noise = true)
odesys = ODESystem(sys)
@test odesys isa ODESystem
vs = ModelingToolkit.vars(equations(odesys))
nbrownian = count(
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
@test nbrownian == 3
for eq in equations(odesys)
ModelingToolkit.isdiffeq(eq) || continue
@test length(arguments(eq.rhs)) == 4
end
end

@testset "Non-scalar vector noise" begin
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 0],
t, [x, y, z], [], is_scalar_noise = false)
odesys = ODESystem(sys)
@test odesys isa ODESystem
vs = ModelingToolkit.vars(equations(odesys))
nbrownian = count(
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
@test nbrownian == 1
for eq in equations(odesys)
ModelingToolkit.isdiffeq(eq) || continue
@test length(arguments(eq.rhs)) == 2
end
end

@testset "Matrix noise" begin
noiseeqs = [x+y y+z z+x
2y 2z 2x
z+1 x+1 y+1]
@named sys = SDESystem([D(x) ~ x, D(y) ~ y, D(z) ~ z], noiseeqs, t, [x, y, z], [])
odesys = ODESystem(sys)
@test odesys isa ODESystem
vs = ModelingToolkit.vars(equations(odesys))
nbrownian = count(
v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs)
@test nbrownian == 3
for eq in equations(odesys)
@test length(arguments(eq.rhs)) == 4
end
end
end

@testset "`structural_simplify(::SDESystem)`" begin
@variables x(t) y(t)
@mtkbuild sys = SDESystem(
[D(x) ~ x, y ~ 2x], [x, 0], t, [x, y], []; is_scalar_noise = true)
@test sys isa SDESystem
@test length(equations(sys)) == 1
@test length(ModelingToolkit.get_noiseeqs(sys)) == 1
@test length(observed(sys)) == 1
end

0 comments on commit 51aea4a

Please sign in to comment.