Skip to content

Commit

Permalink
Merge pull request #2820 from AayushSabharwal/as/defaults-indepvar
Browse files Browse the repository at this point in the history
fix: fix initialization with defaults dependent on indepvar
  • Loading branch information
ChrisRackauckas authored Aug 8, 2024
2 parents ba343fa + 64d2063 commit 1482786
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
18 changes: 15 additions & 3 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -675,13 +675,17 @@ Take dictionaries with initial conditions and parameters and convert them to num
function get_u0_p(sys,
u0map,
parammap = nothing;
t0 = nothing,
use_union = true,
tofloat = true,
symbolic_u0 = false)
dvs = unknowns(sys)
ps = parameters(sys)

defs = defaults(sys)
if t0 !== nothing
defs[get_iv(sys)] = t0
end
if parammap !== nothing
defs = mergedefaults(defs, parammap, ps)
end
Expand Down Expand Up @@ -717,14 +721,19 @@ function get_u0_p(sys,
end
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
p = p === nothing ? SciMLBase.NullParameters() : p
t0 !== nothing && delete!(defs, get_iv(sys))
u0, p, defs
end

function get_u0(
sys, u0map, parammap = nothing; symbolic_u0 = false, toterm = default_toterm)
sys, u0map, parammap = nothing; symbolic_u0 = false,
toterm = default_toterm, t0 = nothing)
dvs = unknowns(sys)
ps = parameters(sys)
defs = defaults(sys)
if t0 !== nothing
defs[get_iv(sys)] = t0
end
if parammap !== nothing
defs = mergedefaults(defs, parammap, ps)
end
Expand All @@ -745,6 +754,7 @@ function get_u0(
else
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, toterm)
end
t0 !== nothing && delete!(defs, get_iv(sys))
return u0, defs
end

Expand Down Expand Up @@ -819,20 +829,22 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
end

if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0)
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0,
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t)
check_eqs_u0(eqs, dvs, u0; kwargs...)
p = if parammap === nothing ||
parammap == SciMLBase.NullParameters() && isempty(defs)
nothing
else
MTKParameters(sys, parammap, trueinit; eval_expression, eval_module)
MTKParameters(sys, parammap, trueinit; t0 = t, eval_expression, eval_module)
end
else
u0, p, defs = get_u0_p(sys,
trueinit,
parammap;
tofloat,
use_union,
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t,
symbolic_u0)
p, split_idxs = split_parameters_by_type(p)
if p isa Tuple
Expand Down
7 changes: 5 additions & 2 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ end

function MTKParameters(
sys::AbstractSystem, p, u0 = Dict(); tofloat = false, use_union = false,
eval_expression = false, eval_module = @__MODULE__)
ic::IndexCache = if has_index_cache(sys) && get_index_cache(sys) !== nothing
t0 = nothing, eval_expression = false, eval_module = @__MODULE__)
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
get_index_cache(sys)
else
error("Cannot create MTKParameters if system does not have index_cache")
Expand All @@ -43,6 +43,9 @@ function MTKParameters(
defs = merge(defs, u0)
defs = merge(Dict(eq.lhs => eq.rhs for eq in observed(sys)), defs)
bigdefs = merge(defs, p)
if t0 !== nothing
bigdefs[get_iv(sys)] = t0
end
p = Dict()
missing_params = Set()
pdeps = has_parameter_dependencies(sys) ? parameter_dependencies(sys) : nothing
Expand Down
9 changes: 9 additions & 0 deletions test/initial_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,12 @@ prob = ODEProblem(sys, [], (0.0, 1.0), [A1 => 0.3])
@test isempty(ModelingToolkit.defaults(sys))
end
end

# Using indepvar in initialization
# Issue#2799
@variables x(t)
@parameters p
@mtkbuild sys = ODESystem([D(x) ~ p], t; defaults = [x => t, p => 2t])
prob = ODEProblem(structural_simplify(sys), [], (1.0, 2.0), [])
@test prob[x] == 1.0
@test prob.ps[p] == 2.0

0 comments on commit 1482786

Please sign in to comment.