Skip to content

Commit

Permalink
Refactor ImperativeAffect into its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
BenChung committed Dec 9, 2024
1 parent 5fcf864 commit 16d3f5c
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 220 deletions.
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ include("systems/parameter_buffer.jl")
include("systems/abstractsystem.jl")
include("systems/model_parsing.jl")
include("systems/connectors.jl")
include("systems/imperative_affect.jl")
include("systems/callbacks.jl")
include("systems/problem_utils.jl")

Expand Down
221 changes: 3 additions & 218 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,111 +73,6 @@ function namespace_affect(affect::FunctionalAffect, s)
context(affect))
end

"""
ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx)
`ImperativeAffect` is a helper for writing affect functions that will compute observed values and
ensure that modified values are correctly written back into the system. The affect function `f` needs to have
the signature
```
f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple
```
The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions.
Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x`
so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form
`(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed.
The argument NamedTuples (for instance `(;x=y)`) will be populated with the declared values on function entry; if we require `(;x=y)` in `observed` and `y=2`, for example,
then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function `f`.
The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write
ImperativeAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o
@set! m.x = o.x_plus_y
end
Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in
`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it
in the returned tuple, in which case the associated field will not be updated.
"""
@kwdef struct ImperativeAffect
f::Any
obs::Vector
obs_syms::Vector{Symbol}
modified::Vector
mod_syms::Vector{Symbol}
ctx::Any
skip_checks::Bool
end

function ImperativeAffect(f::Function;
observed::NamedTuple = NamedTuple{()}(()),
modified::NamedTuple = NamedTuple{()}(()),
ctx = nothing,
skip_checks = false)
ImperativeAffect(f,
collect(values(observed)), collect(keys(observed)),
collect(values(modified)), collect(keys(modified)),
ctx, skip_checks)
end
function ImperativeAffect(f::Function, modified::NamedTuple;
observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false)
ImperativeAffect(
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
end
function ImperativeAffect(
f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false)
ImperativeAffect(
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
end
function ImperativeAffect(
f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false)
ImperativeAffect(
f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks)
end

function Base.show(io::IO, mfa::ImperativeAffect)
obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ")
mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ")
affect = mfa.f
print(io,
"ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)")
end
func(f::ImperativeAffect) = f.f
context(a::ImperativeAffect) = a.ctx
observed(a::ImperativeAffect) = a.obs
observed_syms(a::ImperativeAffect) = a.obs_syms
discretes(a::ImperativeAffect) = filter(ModelingToolkit.isparameter, a.modified)
modified(a::ImperativeAffect) = a.modified
modified_syms(a::ImperativeAffect) = a.mod_syms

function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect)
isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) &&
isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) &&
isequal(a1.ctx, a2.ctx)
end

function Base.hash(a::ImperativeAffect, s::UInt)
s = hash(a.f, s)
s = hash(a.obs, s)
s = hash(a.obs_syms, s)
s = hash(a.modified, s)
s = hash(a.mod_syms, s)
hash(a.ctx, s)
end

function namespace_affect(affect::ImperativeAffect, s)
ImperativeAffect(func(affect),
namespace_expr.(observed(affect), (s,)),
observed_syms(affect),
renamespace.((s,), modified(affect)),
modified_syms(affect),
context(affect),
affect.skip_checks)
end

function has_functional_affect(cb)
(affects(cb) isa FunctionalAffect || affects(cb) isa ImperativeAffect)
end
Expand All @@ -203,13 +98,13 @@ sharp discontinuity between integrator steps (which in this example would not no
guaranteed to be triggered.
Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active at tc,
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active as `tc``,
the value in the integrator after windback will be:
* `u[tc-epsilon], p[tc-epsilon], tc` if `LeftRootFind` is used,
* `u[tc+epsilon], p[tc+epsilon], tc` if `RightRootFind` is used,
* or `u[t], p[t], t` if `NoRootFind` is used.
For example, if we want to detect when an unknown variable `x` satisfies `x > 0` using the condition `x ~ 0` on a positive edge (that is, `D(x) > 0`),
then left root finding will get us `x=-epsilon`, right root finding `x=epsilon` and no root finding whatever the next step of the integrator was after
then left root finding will get us `x=-epsilon`, right root finding `x=epsilon` and no root finding will produce whatever the next step of the integrator was after
it passed through 0.
Multiple callbacks in the same system with different `rootfind` operations will be grouped
Expand Down Expand Up @@ -405,7 +300,6 @@ end

namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s)
namespace_affects(::Nothing, s) = nothing

function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
Expand Down Expand Up @@ -480,7 +374,6 @@ scalarize_affects(affects) = scalarize(affects)
scalarize_affects(affects::Tuple) = FunctionalAffect(affects...)
scalarize_affects(affects::NamedTuple) = FunctionalAffect(; affects...)
scalarize_affects(affects::FunctionalAffect) = affects
scalarize_affects(affects::ImperativeAffect) = affects

SymbolicDiscreteCallback(p::Pair) = SymbolicDiscreteCallback(p[1], p[2])
SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough
Expand Down Expand Up @@ -1099,117 +992,9 @@ function check_assignable(sys, sym)
end
end

function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...)
#=
Implementation sketch:
generate observed function (oop), should save to a component array under obs_syms
do the same stuff as the normal FA for pars_syms
call the affect method
unpack and apply the resulting values
=#
function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup)
seen = Set{Symbol}()
syms_dedup = []
exprs_dedup = []
for (sym, exp) in Iterators.zip(syms, exprs)
if !in(sym, seen)
push!(syms_dedup, sym)
push!(exprs_dedup, exp)
push!(seen, sym)
elseif !affect.skip_checks
@warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used."
end
end
return (syms_dedup, exprs_dedup)
end

obs_exprs = observed(affect)
if !affect.skip_checks
for oexpr in obs_exprs
invalid_vars = invalid_variables(sys, oexpr)
if length(invalid_vars) > 0
error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).")
end
end
end
obs_syms = observed_syms(affect)
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)

mod_exprs = modified(affect)
if !affect.skip_checks
for mexpr in mod_exprs
if !check_assignable(sys, mexpr)
@warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.")
end
invalid_vars = unassignable_variables(sys, mexpr)
if length(invalid_vars) > 0
error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.")
end
end
end
mod_syms = modified_syms(affect)
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)

overlapping_syms = intersect(mod_syms, obs_syms)
if length(overlapping_syms) > 0 && !affect.skip_checks
@warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
end

# sanity checks done! now build the data and update function for observed values
mkzero(sz) =
if sz === ()
0.0
else
zeros(sz)
end
obs_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(obs_exprs);
array_type = Tuple)
obs_sym_tuple = (obs_syms...,)

# okay so now to generate the stuff to assign it back into the system
mod_pairs = mod_exprs .=> mod_syms
mod_names = (mod_syms...,)
mod_og_val_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(first.(mod_pairs));
array_type = Tuple)

upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))

if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
save_idxs = get(ic.callback_to_clocks, cb, Int[])
else
save_idxs = Int[]
end

let user_affect = func(affect), ctx = context(affect)
function (integ)
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
upd_component_array = NamedTuple{mod_names}(modvals)

# update the observed values
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
integ.u, integ.p, integ.t))

# let the user do their thing
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)

# write the new values back to the integrator
_generated_writeback(integ, upd_funs, upd_vals)

for idx in save_idxs
SciMLBase.save_discretes!(integ, idx)
end
end
end
end

function compile_affect(
affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...)
function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...)
compile_user_affect(affect, cb, sys, dvs, ps; kwargs...)
end

function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...)
if isnothing(aff) || aff == default
return nothing
Expand Down
2 changes: 0 additions & 2 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,6 @@ function build_explicit_observed_function(sys, ts;
oop_mtkp_wrapper = mtkparams_wrapper
end

output_expr = isscalar ? ts[1] :
(array_type <: Vector ? MakeArray(ts, output_type) : MakeTuple(ts))
# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
return_value = if isscalar
Expand Down
Loading

0 comments on commit 16d3f5c

Please sign in to comment.