Skip to content

Commit

Permalink
Merge pull request #2399 from SciML/staticarrays
Browse files Browse the repository at this point in the history
Preserve staticarrays in the problem construction
  • Loading branch information
ChrisRackauckas authored Dec 28, 2023
2 parents 7e0917f + a0adfe9 commit 23323bc
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 3 deletions.
5 changes: 5 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ function split_parameters_by_type(ps)
end
tighten_types = x -> identity.(x)
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))

if ps isa StaticArray
parrs = map(x-> SArray{Tuple{size(x)...}}(x), split_ps)
split_ps = SArray{Tuple{size(parrs)...}}(parrs)
end
if length(split_ps) == 1 #Tuple not needed, only 1 type
return split_ps[1], split_idxs
else
Expand Down
4 changes: 4 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,10 @@ function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...)
ODEProblem{true}(sys, args...; kwargs...)
end

function DiffEqBase.ODEProblem(sys::AbstractODESystem, u0map::StaticArray, args...; kwargs...)
ODEProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function DiffEqBase.ODEProblem{true}(sys::AbstractODESystem, args...; kwargs...)
ODEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end
Expand Down
12 changes: 9 additions & 3 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,16 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
end
end

# T = typeof(varmap)
# We respect the input type (feature removed, not needed with Tuple support)
# We respect the input type if it's a static array
# otherwise canonicalize to a normal array
# container_type = T <: Union{Dict,Tuple} ? Array : T
container_type = Array
if varmap isa StaticArray
container_type = typeof(varmap)
else
container_type = Array
end

@show container_type

vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
varmap = todict(varmap)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ end
@safetestset "Reduction Test" include("reduction.jl")
@safetestset "Split Parameters Test" include("split_parameters.jl")
@safetestset "ODAEProblem Test" include("odaeproblem.jl")
@safetestset "StaticArrays Test" include("static_arrays.jl")
@safetestset "Components Test" include("components.jl")
@safetestset "Model Parsing Test" include("model_parsing.jl")
@safetestset "print_tree" include("print_tree.jl")
Expand Down
28 changes: 28 additions & 0 deletions test/static_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using ModelingToolkit, SciMLBase, StaticArrays, Test

@parameters σ ρ β
@variables t x(t) y(t) z(t)
D = Differential(t)

eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x *- z) - y,
D(z) ~ x * y - β * z]

@named sys = ODESystem(eqs)
sys = structural_simplify(sys)

u0 = @SVector [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]

p = @SVector=> 28.0,
ρ => 10.0,
β => 8 / 3]

tspan = (0.0, 100.0)
prob_mtk = ODEProblem(sys, u0, tspan, p)

@test !SciMLBase.isinplace(prob_mtk)
@test prob_mtk.u0 isa SArray
@test prob_mtk.p isa SArray

0 comments on commit 23323bc

Please sign in to comment.