Skip to content

Commit

Permalink
fix: convert Symbol to symbolic variables in remake, fix split_idxs
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 2, 2024
1 parent df4b7e5 commit 88530ce
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
if p isa Tuple
ps = Base.Fix1(getindex, parameters(sys)).(split_idxs)
ps = (ps...,) #if p is Tuple, ps should be Tuple
else
# if there is only one type, we don't need split_idxs
split_idxs = nothing
end

if implicit_dae && du0map !== nothing
Expand Down
13 changes: 12 additions & 1 deletion src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,19 @@ function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :ps) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `p` keyword argument as a vector of values, paying attention to parameter order."))
p = [

Check warning on line 145 in src/variables.jl

View check run for this annotation

Codecov / codecov/patch

src/variables.jl#L145

Added line #L145 was not covered by tests
(sym isa Symbol ? parameter_symbols(prob)[parameter_index(prob, sym)] : sym) => val
for (sym, val) in p
]
end
if eltype(u0) <: Pair
hasproperty(prob.f, :sys) && hasfield(typeof(prob.f.sys), :states) ||
throw(ArgumentError("This problem does not support symbolic maps with `remake`, i.e. it does not have a symbolic origin." *
" Please use `remake` with the `u0` keyword argument as a vector of values, paying attention to state order."))
u0 = [

Check warning on line 154 in src/variables.jl

View check run for this annotation

Codecov / codecov/patch

src/variables.jl#L154

Added line #L154 was not covered by tests
(sym isa Symbol ? variable_symbols(prob)[variable_index(prob, sym)] : sym) => val
for (sym, val) in u0
]
end

sys = prob.f.sys
Expand All @@ -165,7 +173,10 @@ function SciMLBase.process_p_u0_symbolic(prob::Union{SciMLBase.AbstractDEProblem
sts = states(sys)
defs = mergedefaults(defs, prob.u0, sts)
defs = mergedefaults(defs, u0, sts)
u0, p, defs = get_u0_p(sys, defs)

u0_defs = Dict(sym => val for (sym, val) in defs if is_variable(prob, sym))
ps_defs = Dict(sym => val for (sym, val) in defs if is_parameter(prob, sym))
u0, p, defs = get_u0_p(sys, u0_defs, ps_defs)

Check warning on line 179 in src/variables.jl

View check run for this annotation

Codecov / codecov/patch

src/variables.jl#L177-L179

Added lines #L177 - L179 were not covered by tests

return p, u0
end
Expand Down

0 comments on commit 88530ce

Please sign in to comment.