Skip to content

Commit

Permalink
Merge pull request #2973 from AayushSabharwal/as/fix-everything
Browse files Browse the repository at this point in the history
fix: fix several bugs, get MTK to precompile
  • Loading branch information
ChrisRackauckas authored Aug 19, 2024
2 parents dc98d54 + df80490 commit 24ca6e9
Show file tree
Hide file tree
Showing 10 changed files with 19 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.28"
SymbolicUtils = "3.1.2"
SymbolicUtils = "3.2"
Symbolics = "6"
URIs = "1"
UnPack = "0.1, 1.0"
Expand Down
5 changes: 2 additions & 3 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2266,14 +2266,13 @@ function linearization_function(sys::AbstractSystem, inputs,
end
x0 = merge(defaults_and_guesses(sys), op)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
sys_ps = MTKParameters(sys, p, x0; eval_expression, eval_module)
sys_ps = MTKParameters(sys, p, x0)
else
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
end
p[get_iv(sys)] = NaN
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op);
eval_expression, eval_module)
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
initsys_ps = parameters(initsys)
p_getter = build_explicit_observed_function(
sys, initsys_ps; eval_expression, eval_module)
Expand Down
10 changes: 5 additions & 5 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ function get_u0_p(sys,
if symbolic_u0
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false, use_union = false)
else
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union)
end
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
p = p === nothing ? SciMLBase.NullParameters() : p
Expand All @@ -732,7 +732,7 @@ end

function get_u0(
sys, u0map, parammap = nothing; symbolic_u0 = false,
toterm = default_toterm, t0 = nothing)
toterm = default_toterm, t0 = nothing, use_union = true)
dvs = unknowns(sys)
ps = parameters(sys)
defs = defaults(sys)
Expand All @@ -757,7 +757,7 @@ function get_u0(
u0 = varmap_to_vars(
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)
else
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, toterm)
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true, use_union, toterm)
end
t0 !== nothing && delete!(defs, get_iv(sys))
return u0, defs
Expand Down Expand Up @@ -836,13 +836,13 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;

if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0,
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t)
t0 = constructor <: Union{DDEFunction, SDDEFunction} ? nothing : t, use_union)
check_eqs_u0(eqs, dvs, u0; kwargs...)
p = if parammap === nothing ||
parammap == SciMLBase.NullParameters() && isempty(defs)
nothing
else
MTKParameters(sys, parammap, trueinit; t0 = t, eval_expression, eval_module)
MTKParameters(sys, parammap, trueinit; t0 = t)
end
else
u0, p, defs = get_u0_p(sys,
Expand Down
2 changes: 1 addition & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, paramm
end
if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, trueu0map, parammap)
p = MTKParameters(sys, parammap, trueu0map; eval_expression, eval_module)
p = MTKParameters(sys, parammap, trueu0map)
else
u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union)
end
Expand Down
4 changes: 2 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
Expand Down Expand Up @@ -458,7 +458,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
Expand Down
2 changes: 1 addition & 1 deletion src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
if has_index_cache(sys) && get_index_cache(sys) !== nothing
u0, defs = get_u0(sys, u0map, parammap)
check_eqs_u0(eqs, dvs, u0; kwargs...)
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
p = MTKParameters(sys, parammap, u0map)
else
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
check_eqs_u0(eqs, dvs, u0; kwargs...)
Expand Down
4 changes: 2 additions & 2 deletions src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
if parammap isa MTKParameters
p = parammap
elseif has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
Expand Down Expand Up @@ -524,7 +524,7 @@ function OptimizationProblemExpr{iip}(sys::OptimizationSystem, u0map,

u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
p = MTKParameters(sys, parammap, u0map; eval_expression, eval_module)
p = MTKParameters(sys, parammap, u0map)
else
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
end
Expand Down
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -689,8 +689,9 @@ function promote_to_concrete(vs; tofloat = true, use_union = true)
if use_union
C = Union{C, E}
else
@assert C==E "`promote_to_concrete` can't make type $E uniform with $C"
C = E
C2 = promote_type(C, E)
@assert C2==E || C2==C "`promote_to_concrete` can't make type $E uniform with $C"
C = C2
end
end

Expand Down
3 changes: 1 addition & 2 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,12 +208,11 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
val = unwrap(fixpoint_sub(var, varmap; operator = Symbolics.Operator))
if !isequal(val, var)
values[var] = val
T = promote_type(T, typeof(val))
end
end
missingvars = setdiff(varlist, collect(keys(values)))
check && (isempty(missingvars) || throw(MissingVariablesError(missingvars)))
return [T(values[unwrap(var)]) for var in varlist]
return [values[unwrap(var)] for var in varlist]
end

function varmap_with_toterm(varmap; toterm = Symbolics.diff2term)
Expand Down
2 changes: 1 addition & 1 deletion test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ end
sys = structural_simplify(ODESystem([D(x) ~ P], t, [x], [P]; name = :sys))

function x_at_1(P)
prob = ODEProblem(sys, [x => P], (0.0, 1.0), [sys.P => P])
prob = ODEProblem(sys, [x => P], (0.0, 1.0), [sys.P => P], use_union = false)
return solve(prob, Tsit5())(1.0)
end

Expand Down

0 comments on commit 24ca6e9

Please sign in to comment.