Skip to content

Commit

Permalink
Switch MutatingFunctionalAffect from using ComponentArrays to using N…
Browse files Browse the repository at this point in the history
…amedTuples for heterotyped operation support.
  • Loading branch information
BenChung committed Sep 24, 2024
1 parent eb2966e commit 3e1637d
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 75 deletions.
78 changes: 39 additions & 39 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,25 +77,29 @@ end
`MutatingFunctionalAffect` is a helper for writing affect functions that will compute observed values and
ensure that modified values are correctly written back into the system. The affect function `f` needs to have
one of four signatures:
* `f(modified::ComponentArray)` if the function only writes values (unknowns or parameters) to the system,
* `f(modified::ComponentArray, observed::ComponentArray)` if the function also reads observed values from the system,
* `f(modified::ComponentArray, observed::ComponentArray, ctx)` if the function needs the user-defined context,
* `f(modified::ComponentArray, observed::ComponentArray, ctx, integrator)` if the function needs the low-level integrator.
* `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system,
* `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system,
* `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context,
* `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator.
These will be checked in reverse order (that is, the four-argument version first, than the 3, etc).
The function `f` will be called with `observed` and `modified` `ComponentArray`s that are derived from their respective `NamedTuple` definitions.
Each `NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x`
The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions.
Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x`
so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form
`(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed.
Both `observed` and `modified` will be automatically populated with the current values of their corresponding expressions on function entry.
The values in `modified` will be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write
The argument NamedTuples (for instance `(;x=y)`) will be populated with the declared values on function entry; if we require `(;x=y)` in `observed` and `y=2`, for example,
then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function `f`.
The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write
MutatingFunctionalAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o
m.x = o.x_plus_y
@set! m.x = o.x_plus_y
end
The affect function updates the value at `x` in `modified` to be the result of evaluating `x + y` as passed in the observed values.
Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in
`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it
in the returned tuple, in which case the associated field will not be updated.
"""
@kwdef struct MutatingFunctionalAffect
f::Any
Expand Down Expand Up @@ -983,6 +987,18 @@ function unassignable_variables(sys, expr)
x -> !any(isequal(x), assignable_syms), written)
end

@generated function _generated_writeback(integ, setters::NamedTuple{NS1,<:Tuple}, values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2}
setter_exprs = []
for name in NS2
if !(name in NS1)
missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to."
error(missing_name)
end
push!(setter_exprs, :(setters.$name(integ, values.$name)))
end
return :(begin $(setter_exprs...) end)
end

function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; kwargs...)
#=
Implementation sketch:
Expand Down Expand Up @@ -1016,7 +1032,6 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
end
obs_syms = observed_syms(affect)
obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs)
obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size

mod_exprs = modified(affect)
for mexpr in mod_exprs
Expand All @@ -1033,8 +1048,6 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
end
mod_syms = modified_syms(affect)
mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs)
_, mod_og_val_fun = build_explicit_observed_function(
sys, mod_exprs; return_inplace = true)

overlapping_syms = intersect(mod_syms, obs_syms)
if length(overlapping_syms) > 0 && !affect.skip_checks
Expand All @@ -1048,31 +1061,20 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
else
zeros(sz)
end
_, obs_fun = build_explicit_observed_function(
obs_fun = build_explicit_observed_function(
sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []);
return_inplace = true)
obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms...,)}(mkzero.(obs_size)))
array_type = :tuple)
obs_sym_tuple = (obs_syms...,)

# okay so now to generate the stuff to assign it back into the system
# note that we reorder the componentarray to make the views coherent wrt the base array
mod_pairs = mod_exprs .=> mod_syms
mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs)
mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs)
_, mod_og_val_fun = build_explicit_observed_function(
sys, reduce(vcat, Symbolics.scalarize.([first.(mod_param_pairs); first.(mod_unk_pairs)]); init = []);
return_inplace = true)
upd_params_fun = setu(
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = []))
upd_unk_fun = setu(
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = []))

upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs);
last.(mod_unk_pairs)]...,)}(
[collect(mkzero(size(e)) for e in first.(mod_param_pairs));
collect(mkzero(size(e)) for e in first.(mod_unk_pairs))]))
upd_params_view = view(upd_component_array, last.(mod_param_pairs))
upd_unks_view = view(upd_component_array, last.(mod_unk_pairs))
mod_names = (mod_syms..., )
mod_og_val_fun = build_explicit_observed_function(
sys, reduce(vcat, Symbolics.scalarize.(first.(mod_pairs)); init = []);
array_type = :tuple)

upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))

if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
save_idxs = get(ic.callback_to_clocks, cb, Int[])
else
Expand All @@ -1082,13 +1084,13 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
let user_affect = func(affect), ctx = context(affect)
function (integ)
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t)
upd_component_array = NamedTuple{mod_names}(mod_og_val_fun(integ.u, integ.p..., integ.t))

# update the observed values
obs_fun(obs_component_array, integ.u, integ.p..., integ.t)
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p..., integ.t))

# let the user do their thing
if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ)
modvals = if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ)
user_affect(upd_component_array, obs_component_array, ctx, integ)
elseif applicable(user_affect, upd_component_array, obs_component_array, ctx)
user_affect(upd_component_array, obs_component_array, ctx)
Expand All @@ -1102,9 +1104,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps;
end

# write the new values back to the integrator
upd_params_fun(integ, upd_params_view)
upd_unk_fun(integ, upd_unks_view)

_generated_writeback(integ, upd_funs, modvals)

for idx in save_idxs
SciMLBase.save_discretes!(integ, idx)
Expand Down
11 changes: 6 additions & 5 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ Options not otherwise specified are:
* `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail.
* `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist
* `drop_expr` is deprecated.
* `array_type`; only used if the output is an array (that is, `!isscalar(ts)`). If `:array`, then it will generate an array, if `:tuple` then it will generate a tuple.
"""
function build_explicit_observed_function(sys, ts;
inputs = nothing,
Expand All @@ -421,7 +422,8 @@ function build_explicit_observed_function(sys, ts;
return_inplace = false,
param_only = false,
op = Operator,
throw = true)
throw = true,
array_type=:array)
if (isscalar = symbolic_type(ts) !== NotSymbolic())
ts = [ts]
end
Expand Down Expand Up @@ -536,12 +538,11 @@ function build_explicit_observed_function(sys, ts;
wrap_array_vars(sys, ts; ps = _ps, inputs) .∘
wrap_parameter_dependencies(sys, isscalar)
end

output_expr = isscalar ? ts[1] : (array_type == :array ? MakeArray(ts, output_type) : MakeTuple(ts))
# Need to keep old method of building the function since it uses `output_type`,
# which can't be provided to `build_function`
oop_fn = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
false))) |> array_wrapper[1] |> toexpr
oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |> toexpr
oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module)

if !isscalar
Expand Down
81 changes: 50 additions & 31 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ModelingToolkit: SymbolicContinuousCallback,
using StableRNGs
import SciMLBase
using SymbolicIndexingInterface
using Setfield
rng = StableRNG(12345)

@variables x(t) = 0
Expand Down Expand Up @@ -1010,12 +1011,12 @@ end
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c
x.furnace_on = false
@set! x.furnace_on = false
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c
x.furnace_on = true
@set! x.furnace_on = true
end)
@named sys = ODESystem(
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
Expand All @@ -1027,12 +1028,12 @@ end
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i
x.furnace_on = false
@set! x.furnace_on = false
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i
x.furnace_on = true
@set! x.furnace_on = true
end)
@named sys = ODESystem(
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
Expand All @@ -1044,12 +1045,12 @@ end
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o
x.furnace_on = false
@set! x.furnace_on = false
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o
x.furnace_on = true
@set! x.furnace_on = true
end)
@named sys = ODESystem(
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
Expand All @@ -1061,12 +1062,12 @@ end
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x
x.furnace_on = false
@set! x.furnace_on = false
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x
x.furnace_on = true
@set! x.furnace_on = true
end)
@named sys = ODESystem(
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
Expand All @@ -1078,14 +1079,14 @@ end
furnace_off = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x
x.furnace_on = false
@set! x.furnace_on = false
end; initialize = ModelingToolkit.MutatingFunctionalAffect(modified = (; temp)) do x
x.temp = 0.2
@set! x.temp = 0.2
end)
furnace_enable = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_on_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, c, i
x.furnace_on = true
@set! x.furnace_on = true
end)
@named sys = ODESystem(
eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable])
Expand All @@ -1107,7 +1108,7 @@ end
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(
modified = (; furnace_on), observed = (; furnace_on)) do x, o, c, i
x.furnace_on = false
@set! x.furnace_on = false
end)
@named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off])
ss = structural_simplify(sys)
Expand All @@ -1123,7 +1124,7 @@ end
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(
modified = (; furnace_on, tempsq), observed = (; furnace_on)) do x, o, c, i
x.furnace_on = false
@set! x.furnace_on = false
end)
@named sys = ODESystem(
eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
Expand All @@ -1136,18 +1137,32 @@ end
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on),
observed = (; furnace_on, not_actually_here)) do x, o, c, i
x.furnace_on = false
@set! x.furnace_on = false
end)
@named sys = ODESystem(
eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
ss = structural_simplify(sys)
@test_throws "refers to missing variable(s)" prob=ODEProblem(
ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))


furnace_off = ModelingToolkit.SymbolicContinuousCallback(
[temp ~ furnace_off_threshold],
ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on),
observed = (; furnace_on)) do x, o, c, i
return (;fictional2 = false)
end)
@named sys = ODESystem(
eqs, t, [temp, tempsq], params; continuous_events = [furnace_off])
ss = structural_simplify(sys)
prob=ODEProblem(
ss, [temp => 0.0, furnace_on => true], (0.0, 100.0))
@test_throws "Tried to write back to" solve(prob, Tsit5())
end

@testset "Quadrature" begin
@variables theta(t) omega(t)
params = @parameters qA=0 qB=0 hA=0 hB=0 cnt=0
params = @parameters qA=0 qB=0 hA=0 hB=0 cnt::Int=0
eqs = [D(theta) ~ omega
omega ~ 1.0]
function decoder(oldA, oldB, newA, newB)
Expand All @@ -1167,31 +1182,35 @@ end
end
qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0],
ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c
x.hA = x.qA
x.hB = o.qB
x.qA = 1
x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 1
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end,
affect_neg = ModelingToolkit.MutatingFunctionalAffect(
(; qA, hA, hB, cnt), (; qB)) do x, o, c, i
x.hA = x.qA
x.hB = o.qB
x.qA = 0
x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
@set! x.hA = x.qA
@set! x.hB = o.qB
@set! x.qA = 0
@set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB)
x
end; rootfind = SciMLBase.RightRootFind)
qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0],
ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c
x.hA = o.qA
x.hB = x.qB
x.qB = 1
x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
@set! x.hA = o.qA
@set! x.hB = x.qB
@set! x.qB = 1
@set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
x
end,
affect_neg = ModelingToolkit.MutatingFunctionalAffect(
(; qB, hA, hB, cnt), (; qA)) do x, o, c, i
x.hA = o.qA
x.hB = x.qB
x.qB = 0
x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
@set! x.hA = o.qA
@set! x.hB = x.qB
@set! x.qB = 0
@set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB)
x
end; rootfind = SciMLBase.RightRootFind)
@named sys = ODESystem(
eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt])
Expand Down

0 comments on commit 3e1637d

Please sign in to comment.