Skip to content

Commit

Permalink
Preserve staticarrays in the problem construction
Browse files Browse the repository at this point in the history
Fixes #2398. This is pretty crucial for giving users a way to target GPUs.
  • Loading branch information
ChrisRackauckas committed Dec 27, 2023
1 parent f61f83f commit 3b3d217
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ function split_parameters_by_type(ps)
end
tighten_types = x -> identity.(x)
split_ps = tighten_types.(Base.Fix1(getindex, ps).(split_idxs))

if ps isa StaticArray
parrs = map(x-> SArray{Tuple{size(x)...}}(x), split_ps)
split_ps = SArray{Tuple{size(parrs)...}}(parrs)
end
@show typeof(split_ps)
if length(split_ps) == 1 #Tuple not needed, only 1 type
return split_ps[1], split_idxs
else
Expand Down
12 changes: 9 additions & 3 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,16 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
end
end

# T = typeof(varmap)
# We respect the input type (feature removed, not needed with Tuple support)
# We respect the input type if it's a static array
# otherwise canonicalize to a normal array
# container_type = T <: Union{Dict,Tuple} ? Array : T
container_type = Array
if varmap isa StaticArray
container_type = typeof(varmap)
else
container_type = Array
end

@show container_type

vals = if eltype(varmap) <: Pair # `varmap` is a dict or an array of pairs
varmap = todict(varmap)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ end
@safetestset "Reduction Test" include("reduction.jl")
@safetestset "Split Parameters Test" include("split_parameters.jl")
@safetestset "ODAEProblem Test" include("odaeproblem.jl")
@safetestset "StaticArrays Test" include("static_arrays.jl")
@safetestset "Components Test" include("components.jl")
@safetestset "Model Parsing Test" include("model_parsing.jl")
@safetestset "print_tree" include("print_tree.jl")
Expand Down
27 changes: 27 additions & 0 deletions test/static_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Catalyst, StaticArrays, Test

@parameters σ ρ β
@variables t x(t) y(t) z(t)
D = Differential(t)

eqs = [D(D(x)) ~ σ * (y - x),
D(y) ~ x *- z) - y,
D(z) ~ x * y - β * z]

@named sys = ODESystem(eqs)
sys = structural_simplify(sys)

u0 = @SVector [D(x) => 2.0,
x => 1.0,
y => 0.0,
z => 0.0]

p = @SVector=> 28.0,
ρ => 10.0,
β => 8 / 3]

tspan = (0.0, 100.0)
prob_mtk = ODEProblem{false, SciMLBase.FullSpecialize}(sys, u0, tspan, p)

@test prob_mtk.u0 isa SArray
@test prob_mtk.p isa SArray

0 comments on commit 3b3d217

Please sign in to comment.