From 3b3d217e5054dd49a85c2d86fd868d0704df38f3 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 27 Dec 2023 16:37:09 -0500 Subject: [PATCH] Preserve staticarrays in the problem construction Fixes https://github.com/SciML/ModelingToolkit.jl/issues/2398. This is pretty crucial for giving users a way to target GPUs. --- src/parameters.jl | 6 ++++++ src/variables.jl | 12 +++++++++--- test/runtests.jl | 1 + test/static_arrays.jl | 27 +++++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 test/static_arrays.jl diff --git a/src/parameters.jl b/src/parameters.jl index 4339cb7acf..0d348398d8 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -100,6 +100,12 @@ function split_parameters_by_type(ps) end tighten_types = x -> identity.(x) split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs)) + + if ps isa StaticArray + parrs = map(x-> SArray{Tuple{size(x)...}}(x), split_ps) + split_ps = SArray{Tuple{size(parrs)...}}(parrs) + end + @show typeof(split_ps) if length(split_ps) == 1 #Tuple not needed, only 1 type return split_ps[1], split_idxs else diff --git a/src/variables.jl b/src/variables.jl index df359354a5..8e30521ec9 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -75,10 +75,16 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, end end - # T = typeof(varmap) - # We respect the input type (feature removed, not needed with Tuple support) + # We respect the input type if it's a static array + # otherwise canonicalize to a normal array # container_type = T <: Union{Dict,Tuple} ? Array : T - container_type = Array + if varmap isa StaticArray + container_type = typeof(varmap) + else + container_type = Array + end + + @show container_type vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs varmap = todict(varmap) diff --git a/test/runtests.jl b/test/runtests.jl index cf0a9652ff..ea179d609b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ end @safetestset "Reduction Test" include("reduction.jl") @safetestset "Split Parameters Test" include("split_parameters.jl") @safetestset "ODAEProblem Test" include("odaeproblem.jl") + @safetestset "StaticArrays Test" include("static_arrays.jl") @safetestset "Components Test" include("components.jl") @safetestset "Model Parsing Test" include("model_parsing.jl") @safetestset "print_tree" include("print_tree.jl") diff --git a/test/static_arrays.jl b/test/static_arrays.jl new file mode 100644 index 0000000000..2bc403fb5d --- /dev/null +++ b/test/static_arrays.jl @@ -0,0 +1,27 @@ +using Catalyst, StaticArrays, Test + +@parameters σ ρ β +@variables t x(t) y(t) z(t) +D = Differential(t) + +eqs = [D(D(x)) ~ σ * (y - x), + D(y) ~ x * (ρ - z) - y, + D(z) ~ x * y - β * z] + +@named sys = ODESystem(eqs) +sys = structural_simplify(sys) + +u0 = @SVector [D(x) => 2.0, + x => 1.0, + y => 0.0, + z => 0.0] + +p = @SVector [σ => 28.0, + ρ => 10.0, + β => 8 / 3] + +tspan = (0.0, 100.0) +prob_mtk = ODEProblem{false, SciMLBase.FullSpecialize}(sys, u0, tspan, p) + +@test prob_mtk.u0 isa SArray +@test prob_mtk.p isa SArray