diff --git a/src/parameters.jl b/src/parameters.jl index 4339cb7acf..ec2b8a8845 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -100,6 +100,11 @@ function split_parameters_by_type(ps) end tighten_types = x -> identity.(x) split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs)) + + if ps isa StaticArray + parrs = map(x-> SArray{Tuple{size(x)...}}(x), split_ps) + split_ps = SArray{Tuple{size(parrs)...}}(parrs) + end if length(split_ps) == 1 #Tuple not needed, only 1 type return split_ps[1], split_idxs else diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index f79859da57..d9ca7b6552 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -912,6 +912,10 @@ function DiffEqBase.ODEProblem(sys::AbstractODESystem, args...; kwargs...) ODEProblem{true}(sys, args...; kwargs...) end +function DiffEqBase.ODEProblem(sys::AbstractODESystem, u0map::StaticArray, args...; kwargs...) + ODEProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...) +end + function DiffEqBase.ODEProblem{true}(sys::AbstractODESystem, args...; kwargs...) ODEProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) end diff --git a/src/variables.jl b/src/variables.jl index df359354a5..8e30521ec9 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -75,10 +75,16 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, end end - # T = typeof(varmap) - # We respect the input type (feature removed, not needed with Tuple support) + # We respect the input type if it's a static array + # otherwise canonicalize to a normal array # container_type = T <: Union{Dict,Tuple} ? Array : T - container_type = Array + if varmap isa StaticArray + container_type = typeof(varmap) + else + container_type = Array + end + + @show container_type vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs varmap = todict(varmap) diff --git a/test/runtests.jl b/test/runtests.jl index cf0a9652ff..ea179d609b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ end @safetestset "Reduction Test" include("reduction.jl") @safetestset "Split Parameters Test" include("split_parameters.jl") @safetestset "ODAEProblem Test" include("odaeproblem.jl") + @safetestset "StaticArrays Test" include("static_arrays.jl") @safetestset "Components Test" include("components.jl") @safetestset "Model Parsing Test" include("model_parsing.jl") @safetestset "print_tree" include("print_tree.jl") diff --git a/test/static_arrays.jl b/test/static_arrays.jl new file mode 100644 index 0000000000..5638fff7f8 --- /dev/null +++ b/test/static_arrays.jl @@ -0,0 +1,28 @@ +using ModelingToolkit, SciMLBase, StaticArrays, Test + +@parameters σ ρ β +@variables t x(t) y(t) z(t) +D = Differential(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + +@named sys = ODESystem(eqs) +sys = structural_simplify(sys) + +u0 = @SVector [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + +p = @SVector [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + +tspan = (0.0, 100.0) +prob_mtk = ODEProblem(sys, u0, tspan, p) + +@test !SciMLBase.isinplace(prob_mtk) +@test prob_mtk.u0 isa SArray +@test prob_mtk.p isa SArray