Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove parameter dependencies from MTKParameters #2934

Merged
merged 3 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ SimpleNonlinearSolve = "0.1.0, 1"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.3.26"
SymbolicIndexingInterface = "0.3.28"
SymbolicUtils = "2.1"
Symbolics = "5.32"
URIs = "1"
Expand Down
4 changes: 3 additions & 1 deletion src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps), kwargs...)
expression = Val{true}, wrap_code = wrap_array_vars(sys, rhss; dvs, ps) .∘
wrap_parameter_dependencies(sys, false),
kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
(; f, dvs, ps, io_sys = sys)
end
Expand Down
112 changes: 58 additions & 54 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function calculate_hessian end

"""
```julia
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_tgrad(sys::AbstractTimeDependentSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; kwargs...)
```

Expand All @@ -93,7 +93,7 @@ function generate_tgrad end

"""
```julia
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_gradient(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; kwargs...)
```

Expand All @@ -104,7 +104,7 @@ function generate_gradient end

"""
```julia
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_jacobian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; sparse = false, kwargs...)
```

Expand All @@ -115,7 +115,7 @@ function generate_jacobian end

"""
```julia
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_factorized_W(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; sparse = false, kwargs...)
```

Expand All @@ -126,7 +126,7 @@ function generate_factorized_W end

"""
```julia
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_hessian(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; sparse = false, kwargs...)
```

Expand All @@ -137,7 +137,7 @@ function generate_hessian end

"""
```julia
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = full_parameters(sys),
generate_function(sys::AbstractSystem, dvs = unknowns(sys), ps = parameters(sys),
expression = Val{true}; kwargs...)
```

Expand All @@ -148,7 +148,7 @@ function generate_function end
"""
```julia
generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
```

Generate a function to evaluate `exprs`. `exprs` is a symbolic expression or
Expand Down Expand Up @@ -187,7 +187,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs),
wrap_array_vars(sys, exprs; dvs) .∘
wrap_parameter_dependencies(sys, isscalar),
expression = Val{true}
)
else
Expand All @@ -198,7 +199,8 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs),
wrap_array_vars(sys, exprs; dvs) .∘
wrap_parameter_dependencies(sys, isscalar),
expression = Val{true}
)
end
Expand All @@ -223,6 +225,10 @@ function wrap_assignments(isscalar, assignments; let_block = false)
end
end

function wrap_parameter_dependencies(sys::AbstractSystem, isscalar)
wrap_assignments(isscalar, [eq.lhs ← eq.rhs for eq in parameter_dependencies(sys)])
end

function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), inputs = nothing)
isscalar = !(exprs isa AbstractArray)
Expand Down Expand Up @@ -757,7 +763,7 @@ function SymbolicIndexingInterface.get_all_timeseries_indexes(sys::AbstractSyste
end

function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
return full_parameters(sys)
return parameters(sys)
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
Expand Down Expand Up @@ -1214,11 +1220,6 @@ function namespace_guesses(sys)
Dict(unknowns(sys, k) => namespace_expr(v, sys) for (k, v) in guess)
end

function namespace_parameter_dependencies(sys)
pdeps = parameter_dependencies(sys)
Dict(parameters(sys, k) => namespace_expr(v, sys) for (k, v) in pdeps)
end

function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
eqs = equations(sys)
isempty(eqs) && return Equation[]
Expand Down Expand Up @@ -1325,25 +1326,11 @@ function parameters(sys::AbstractSystem)
ps = first.(ps)
end
systems = get_systems(sys)
result = unique(isempty(systems) ? ps :
[ps; reduce(vcat, namespace_parameters.(systems))])
if has_parameter_dependencies(sys) &&
(pdeps = parameter_dependencies(sys)) !== nothing
filter(result) do sym
!haskey(pdeps, sym)
end
else
result
end
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
end

function dependent_parameters(sys::AbstractSystem)
if has_parameter_dependencies(sys) &&
!isempty(parameter_dependencies(sys))
collect(keys(parameter_dependencies(sys)))
else
[]
end
return map(eq -> eq.lhs, parameter_dependencies(sys))
end

"""
Expand All @@ -1353,17 +1340,19 @@ Get the parameter dependencies of the system `sys` and its subsystems.
See also [`defaults`](@ref) and [`ModelingToolkit.get_parameter_dependencies`](@ref).
"""
function parameter_dependencies(sys::AbstractSystem)
pdeps = get_parameter_dependencies(sys)
if isnothing(pdeps)
pdeps = Dict()
if !has_parameter_dependencies(sys)
return Equation[]
end
pdeps = get_parameter_dependencies(sys)
systems = get_systems(sys)
isempty(systems) && return pdeps
for subsys in systems
pdeps = merge(pdeps, namespace_parameter_dependencies(subsys))
end
# @info pdeps
return pdeps
# put pdeps after those of subsystems to maintain topological sorted order
return vcat(
reduce(vcat,
[map(eq -> namespace_equation(eq, s), parameter_dependencies(s))
for s in systems];
init = Equation[]),
pdeps
)
end

function full_parameters(sys::AbstractSystem)
Expand Down Expand Up @@ -2317,7 +2306,7 @@ function linearization_function(sys::AbstractSystem, inputs,
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)
ps = full_parameters(sys)
ps = parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
Expand Down Expand Up @@ -2420,7 +2409,7 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
kwargs...)
sts = unknowns(sys)
t = get_iv(sys)
ps = full_parameters(sys)
ps = parameters(sys)
p = reorder_parameters(sys, ps)

fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]
Expand Down Expand Up @@ -2852,7 +2841,7 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam
eqs = union(get_eqs(basesys), get_eqs(sys))
sts = union(get_unknowns(basesys), get_unknowns(sys))
ps = union(get_ps(basesys), get_ps(sys))
dep_ps = union_nothing(parameter_dependencies(basesys), parameter_dependencies(sys))
dep_ps = union(parameter_dependencies(basesys), parameter_dependencies(sys))
obs = union(get_observed(basesys), get_observed(sys))
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
Expand Down Expand Up @@ -2956,15 +2945,28 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
end

function process_parameter_dependencies(pdeps, ps)
pdeps === nothing && return pdeps, ps
if pdeps isa Vector && eltype(pdeps) <: Pair
pdeps = Dict(pdeps)
elseif !(pdeps isa Dict)
error("parameter_dependencies must be a `Dict` or `Vector{<:Pair}`")
if pdeps === nothing || isempty(pdeps)
return Equation[], ps
elseif eltype(pdeps) <: Pair
pdeps = [lhs ~ rhs for (lhs, rhs) in pdeps]
end

if !(eltype(pdeps) <: Equation)
error("Parameter dependencies must be a `Dict`, `Vector{Pair}` or `Vector{Equation}`")
end
lhss = BasicSymbolic[]
for p in pdeps
if !isparameter(p.lhs)
error("LHS of parameter dependency must be a single parameter. Found $(p.lhs).")
end
syms = vars(p.rhs)
if !all(isparameter, syms)
error("RHS of parameter dependency must only include parameters. Found $(p.rhs)")
end
push!(lhss, p.lhs)
end
pdeps = topsort_equations(pdeps, union(ps, lhss))
ps = filter(ps) do p
!haskey(pdeps, p)
!any(isequal(p), lhss)
end
return pdeps, ps
end
Expand Down Expand Up @@ -2997,12 +2999,14 @@ function dump_parameters(sys::AbstractSystem)
end
meta
end
pdep_metas = map(collect(keys(pdeps))) do sym
val = pdeps[sym]
pdep_metas = map(pdeps) do eq
sym = eq.lhs
val = eq.rhs
meta = dump_variable_metadata(sym)
defs[eq.lhs] = eq.rhs
meta = merge(meta,
(; dependency = pdeps[sym],
default = symbolic_evaluate(pdeps[sym], merge(defs, pdeps))))
(; dependency = val,
default = symbolic_evaluate(val, defs)))
return meta
end
return vcat(metas, pdep_metas)
Expand Down
25 changes: 16 additions & 9 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
end
expr = build_function(
condit, u, t, p...; expression = Val{true},
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps),
wrap_code = condition_header(sys) .∘ wrap_array_vars(sys, condit; dvs, ps) .∘
wrap_parameter_dependencies(sys, !(condit isa AbstractArray)),
kwargs...)
if expression == Val{true}
return expr
Expand Down Expand Up @@ -497,7 +498,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar) .∘
wrap_array_vars(sys, rhss; dvs, ps = _ps),
wrap_array_vars(sys, rhss; dvs, ps = _ps) .∘
wrap_parameter_dependencies(sys, false),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
Expand All @@ -513,7 +515,7 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
end

function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
cbs = continuous_events(sys)
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
Expand All @@ -524,7 +526,7 @@ generate_rootfinding_callback and thus we can produce a ContinuousCallback inste
"""
function generate_single_rootfinding_callback(
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
if !isequal(eq.lhs, 0)
eq = 0 ~ eq.lhs - eq.rhs
end
Expand All @@ -547,7 +549,7 @@ end

function generate_vector_rootfinding_callback(
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
# fuse equations to create VectorContinuousCallback
Expand Down Expand Up @@ -617,7 +619,7 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
total_eqs = sum(num_eqs)
Expand Down Expand Up @@ -660,10 +662,15 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
v_inds = map(sym -> dvs_ind[sym], unknowns(affect))

if has_index_cache(sys) && get_index_cache(sys) !== nothing
p_inds = [parameter_index(sys, sym) for sym in parameters(affect)]
p_inds = [if (pind = parameter_index(sys, sym)) === nothing
sym
else
pind
end
for sym in parameters(affect)]
else
ps_ind = Dict(reverse(en) for en in enumerate(ps))
p_inds = map(sym -> ps_ind[sym], parameters(affect))
p_inds = map(sym -> get(ps_ind, sym, sym), parameters(affect))
end
# HACK: filter out eliminated symbols. Not clear this is the right thing to do
# (MTK should keep these symbols)
Expand Down Expand Up @@ -711,7 +718,7 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! =
end

function generate_discrete_callbacks(sys::AbstractSystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
ps = parameters(sys); kwargs...)
has_discrete_events(sys) || return nothing
symcbs = discrete_events(sys)
isempty(symcbs) && return nothing
Expand Down
Loading
Loading