Skip to content

Commit

Permalink
Merge pull request #2826 from AayushSabharwal/as/m-e-g-a
Browse files Browse the repository at this point in the history
feat: complete `eval_expression` and `eval_module` support
  • Loading branch information
ChrisRackauckas authored Jun 30, 2024
2 parents b0c4b2b + add87d5 commit a203b86
Show file tree
Hide file tree
Showing 13 changed files with 265 additions and 194 deletions.
9 changes: 6 additions & 3 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
disturbance_inputs = disturbances(sys);
implicit_dae = false,
simplify = false,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...)
isempty(inputs) && @warn("No unbound inputs were found in system.")

Expand Down Expand Up @@ -240,7 +242,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
end
process = get_postprocess_fbody(sys)
f = build_function(rhss, args...; postprocess_fbody = process,
expression = Val{false}, kwargs...)
expression = Val{true}, kwargs...)
f = eval_or_rgf.(f; eval_expression, eval_module)
(; f, dvs, ps, io_sys = sys)
end

Expand Down Expand Up @@ -395,7 +398,7 @@ model_outputs = [model.inertia1.w, model.inertia2.w, model.inertia1.phi, model.i
`f_oop` will have an extra state corresponding to the integrator in the disturbance model. This state will not be affected by any input, but will affect the dynamics from where it enters, in this case it will affect additively from `model.torque.tau.u`.
"""
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing; kwargs...)
t = get_iv(sys)
@variables d(t)=0 [disturbance = true]
@variables u(t)=0 [input = true] # New system input
Expand All @@ -418,6 +421,6 @@ function add_input_disturbance(sys, dist::DisturbanceModel, inputs = nothing)
augmented_sys = extend(augmented_sys, sys)

(f_oop, f_ip), dvs, p = generate_control_function(augmented_sys, all_inputs,
[d])
[d]; kwargs...)
(f_oop, f_ip), augmented_sys, dvs, p
end
126 changes: 63 additions & 63 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ time-independent systems. If `split=true` (the default) was passed to [`complete
object.
"""
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing, kwargs...)
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing,
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
end
Expand All @@ -177,28 +178,38 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
if states === nothing
states = sol_states
end
if is_time_dependent(sys)
return build_function(exprs,
fnexpr = if is_time_dependent(sys)
build_function(exprs,
dvs,
p...,
get_iv(sys);
kwargs...,
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs)
wrap_array_vars(sys, exprs; dvs),
expression = Val{true}
)
else
return build_function(exprs,
build_function(exprs,
dvs,
p...;
kwargs...,
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs)
wrap_array_vars(sys, exprs; dvs),
expression = Val{true}
)
end
if expression == Val{true}
return fnexpr
end
if fnexpr isa Tuple
return eval_or_rgf.(fnexpr; eval_expression, eval_module)
else
return eval_or_rgf(fnexpr; eval_expression, eval_module)
end
end

function wrap_assignments(isscalar, assignments; let_block = false)
Expand Down Expand Up @@ -509,7 +520,8 @@ function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
!is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
end

function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -531,7 +543,8 @@ function SymbolicIndexingInterface.observed(sys::AbstractSystem, sym)
end
end
end
_fn = build_explicit_observed_function(sys, sym)
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)

if is_time_dependent(sys)
return let _fn = _fn
fn1(u, p, t) = _fn(u, p, t)
Expand Down Expand Up @@ -1210,19 +1223,30 @@ end
struct ObservedFunctionCache{S}
sys::S
dict::Dict{Any, Any}
eval_expression::Bool
eval_module::Module
end

function ObservedFunctionCache(sys)
return ObservedFunctionCache(sys, Dict())
let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
end
end
function ObservedFunctionCache(sys; eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), eval_expression, eval_module)
end

# This is hit because ensemble problems do a deepcopy
function Base.deepcopy_internal(ofc::ObservedFunctionCache, stackdict::IdDict)
sys = deepcopy(ofc.sys)
dict = deepcopy(ofc.dict)
eval_expression = ofc.eval_expression
eval_module = ofc.eval_module
newofc = ObservedFunctionCache(sys, dict, eval_expression, eval_module)
stackdict[ofc] = newofc
return newofc
end

function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(ofc.sys, obsvar)
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
end
if args === ()
return obs
Expand Down Expand Up @@ -1871,6 +1895,7 @@ function linearization_function(sys::AbstractSystem, inputs,
p = DiffEqBase.NullParameters(),
zero_dummy_der = false,
initialization_solver_alg = TrustRegion(),
eval_expression = false, eval_module = @__MODULE__,
kwargs...)
inputs isa AbstractVector || (inputs = [inputs])
outputs isa AbstractVector || (outputs = [outputs])
Expand All @@ -1895,85 +1920,58 @@ function linearization_function(sys::AbstractSystem, inputs,
end
x0 = merge(defaults_and_guesses(sys), op)
if has_index_cache(sys) && get_index_cache(sys) !== nothing
sys_ps = MTKParameters(sys, p, x0)
sys_ps = MTKParameters(sys, p, x0; eval_expression, eval_module)
else
sys_ps = varmap_to_vars(p, parameters(sys); defaults = x0)
end
p[get_iv(sys)] = NaN
if has_index_cache(initsys) && get_index_cache(initsys) !== nothing
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op))
oldps = MTKParameters(initsys, p, merge(guesses(sys), defaults(sys), op);
eval_expression, eval_module)
initsys_ps = parameters(initsys)
initsys_idxs = [parameter_index(initsys, param) for param in initsys_ps]
tunable_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Tunable()]
tunable_getter = isempty(tunable_ps) ? nothing : getu(sys, tunable_ps)
discrete_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Discrete()]
disc_getter = isempty(discrete_ps) ? nothing : getu(sys, discrete_ps)
constant_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == SciMLStructures.Constants()]
const_getter = isempty(constant_ps) ? nothing : getu(sys, constant_ps)
nonnum_ps = [initsys_ps[i]
for i in eachindex(initsys_ps)
if initsys_idxs[i].portion == NONNUMERIC_PORTION]
nonnum_getter = isempty(nonnum_ps) ? nothing : getu(sys, nonnum_ps)
p_getter = build_explicit_observed_function(
sys, initsys_ps; eval_expression, eval_module)

u_getter = isempty(unknowns(initsys)) ? (_...) -> nothing :
getu(sys, unknowns(initsys))
get_initprob_u_p = let tunable_getter = tunable_getter,
disc_getter = disc_getter,
const_getter = const_getter,
nonnum_getter = nonnum_getter,
oldps = oldps,
build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)
get_initprob_u_p = let p_getter,
p_setter! = setp(initsys, initsys_ps),
u_getter = u_getter

function (u, p, t)
state = ProblemState(; u, p, t)
if tunable_getter !== nothing
SciMLStructures.replace!(
SciMLStructures.Tunable(), oldps, tunable_getter(state))
end
if disc_getter !== nothing
SciMLStructures.replace!(
SciMLStructures.Discrete(), oldps, disc_getter(state))
end
if const_getter !== nothing
SciMLStructures.replace!(
SciMLStructures.Constants(), oldps, const_getter(state))
end
if nonnum_getter !== nothing
SciMLStructures.replace!(
NONNUMERIC_PORTION, oldps, nonnum_getter(state))
end
p_setter!(oldps, p_getter(state))
newu = u_getter(state)
return newu, oldps
end
end
else
get_initprob_u_p = let p_getter = getu(sys, parameters(initsys)),
u_getter = getu(sys, unknowns(initsys))
u_getter = build_explicit_observed_function(
sys, unknowns(initsys); eval_expression, eval_module)

function (u, p, t)
state = ProblemState(; u, p, t)
return u_getter(state), p_getter(state)
end
end
end
initfn = NonlinearFunction(initsys)
initprobmap = getu(initsys, unknowns(sys))
initfn = NonlinearFunction(initsys; eval_expression, eval_module)
initprobmap = build_explicit_observed_function(
initsys, unknowns(sys); eval_expression, eval_module)
ps = full_parameters(sys)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
lin_fun = let diff_idxs = diff_idxs,
alge_idxs = alge_idxs,
input_idxs = input_idxs,
sts = unknowns(sys),
get_initprob_u_p = get_initprob_u_p,
fun = ODEFunction{true, SciMLBase.FullSpecialize}(
sys, unknowns(sys), ps),
sys, unknowns(sys), ps; eval_expression, eval_module),
initfn = initfn,
initprobmap = initprobmap,
h = build_explicit_observed_function(sys, outputs),
h = h,
chunk = ForwardDiff.Chunk(input_idxs),
sys_ps = sys_ps,
initialize = initialize,
Expand Down Expand Up @@ -2056,6 +2054,7 @@ where `x` are differential unknown variables, `z` algebraic variables, `u` input
"""
function linearize_symbolic(sys::AbstractSystem, inputs,
outputs; simplify = false, allow_input_derivatives = false,
eval_expression = false, eval_module = @__MODULE__,
kwargs...)
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(
sys, inputs, outputs; simplify,
Expand All @@ -2065,10 +2064,11 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
ps = full_parameters(sys)
p = reorder_parameters(sys, ps)

fun = generate_function(sys, sts, ps; expression = Val{false})[1]
fun_expr = generate_function(sys, sts, ps; expression = Val{true})[1]
fun = eval_or_rgf(fun_expr; eval_expression, eval_module)
dx = fun(sts, p..., t)

h = build_explicit_observed_function(sys, outputs)
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
y = h(sts, p..., t)

fg_xz = Symbolics.jacobian(dx, sts)
Expand Down
22 changes: 14 additions & 8 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ Notes
- `kwargs` are passed through to `Symbolics.build_function`.
"""
function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
expression = Val{true}, kwargs...)
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
u = map(x -> time_varying_as_func(value(x), sys), dvs)
p = map.(x -> time_varying_as_func(value(x), sys), reorder_parameters(sys, ps))
t = get_iv(sys)
Expand All @@ -353,8 +353,13 @@ function compile_condition(cb::SymbolicDiscreteCallback, sys, dvs, ps;
cmap = map(x -> x => getdefault(x), cs)
condit = substitute(condit, cmap)
end
build_function(condit, u, t, p...; expression, wrap_code = condition_header(sys),
expr = build_function(
condit, u, t, p...; expression = Val{true}, wrap_code = condition_header(sys),
kwargs...)
if expression == Val{true}
return expr
end
return eval_or_rgf(expr; eval_expression, eval_module)
end

function compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)
Expand All @@ -379,7 +384,8 @@ Notes
- `kwargs` are passed through to `Symbolics.build_function`.
"""
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothing,
expression = Val{true}, checkvars = true,
expression = Val{true}, checkvars = true, eval_expression = false,
eval_module = @__MODULE__,
postprocess_affect_expr! = nothing, kwargs...)
if isempty(eqs)
if expression == Val{true}
Expand Down Expand Up @@ -432,20 +438,20 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
end
t = get_iv(sys)
integ = gensym(:MTKIntegrator)
getexpr = (postprocess_affect_expr! === nothing) ? expression : Val{true}
pre = get_preprocess_constants(rhss)
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = getexpr,
rf_oop, rf_ip = build_function(rhss, u, p..., t; expression = Val{true},
wrap_code = add_integrator_header(sys, integ, outvar),
outputidxs = update_inds,
postprocess_fbody = pre,
kwargs...)
# applied user-provided function to the generated expression
if postprocess_affect_expr! !== nothing
postprocess_affect_expr!(rf_ip, integ)
(expression == Val{false}) &&
(return drop_expr(@RuntimeGeneratedFunction(rf_ip)))
end
rf_ip
if expression == Val{false}
return eval_or_rgf(rf_ip; eval_expression, eval_module)
end
return rf_ip
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ function Base.merge(csets::AbstractVector{<:ConnectionSet}, allouter = false)
id2set = Dict{Int, Int}()
merged_set = ConnectionSet[]
for (id, ele) in enumerate(idx2ele)
rid = find_root(union_find, id)
rid = find_root!(union_find, id)
set_idx = get!(id2set, rid) do
set = ConnectionSet()
push!(merged_set, set)
Expand Down
Loading

0 comments on commit a203b86

Please sign in to comment.