Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for the initializealg argument in SciMLBase callbacks
Browse files Browse the repository at this point in the history
BenChung committed Sep 24, 2024
1 parent 725079e commit eb2966e
Showing 2 changed files with 34 additions and 13 deletions.
41 changes: 31 additions & 10 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
@@ -216,6 +216,11 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either:
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
* A [`MutatingFunctionalAffect`](@ref); refer to its documentation for details.
Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`).
This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate
that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of
initialization.
"""
struct SymbolicContinuousCallback
eqs::Vector{Equation}
@@ -224,14 +229,16 @@ struct SymbolicContinuousCallback
affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect}
affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing}
rootfind::SciMLBase.RootfindOpt
reinitializealg::SciMLBase.DAEInitializationAlgorithm
function SymbolicContinuousCallback(;
eqs::Vector{Equation},
affect = NULL_AFFECT,
affect_neg = affect,
rootfind = SciMLBase.LeftRootFind,
initialize=NULL_AFFECT,
finalize=NULL_AFFECT)
new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind)
finalize=NULL_AFFECT,
reinitializealg=SciMLBase.CheckInit())
new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg)
end # Default affect to nothing
end
make_affect(affect) = affect
@@ -373,6 +380,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback})
mapreduce(finalize_affects, vcat, cbs, init = Equation[])
end

reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg
reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) =
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])

namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
namespace_affects(af::MutatingFunctionalAffect, s) = namespace_affect(af, s)
@@ -419,11 +430,12 @@ struct SymbolicDiscreteCallback
# TODO: Iterative
condition::Any
affects::Any
reinitializealg::SciMLBase.DAEInitializationAlgorithm

function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT)
function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT, reinitializealg=SciMLBase.CheckInit())
c = scalarize_condition(condition)
a = scalarize_affects(affects)
new(c, a)
new(c, a, reinitializealg)
end # Default affect to nothing
end

@@ -481,6 +493,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback})
reduce(vcat, affects(cb) for cb in cbs; init = [])
end

reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg
reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) =
mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[])

function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback
af = affects(cb)
af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s)
@@ -776,12 +792,13 @@ function generate_single_rootfinding_callback(
return ContinuousCallback(
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind,
initialize = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i),
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i))
finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i),
initializealg = reinitialization_alg(cb))
end

function generate_vector_rootfinding_callback(
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
ps = parameters(sys); rootfind = SciMLBase.RightRootFind, reinitialization = SciMLBase.CheckInit(), kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
# fuse equations to create VectorContinuousCallback
@@ -847,7 +864,7 @@ function generate_vector_rootfinding_callback(
initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT)
finalize = handle_optional_setup_fn(map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT)
return VectorContinuousCallback(
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize)
cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization)
end

"""
@@ -893,18 +910,22 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
# group the cbs by what rootfind op they use
# groupby would be very useful here, but alas
cb_classes = Dict{
@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
@NamedTuple{
rootfind::SciMLBase.RootfindOpt,
reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}()
for cb in cbs
push!(
get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)),
get!(() -> SymbolicContinuousCallback[], cb_classes, (
rootfind = cb.rootfind,
reinitialization = reinitialization_alg(cb))),
cb)
end

# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
compiled_callbacks = map(collect(pairs(sort!(
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
return generate_vector_rootfinding_callback(
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...)
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, reinitialization=equiv_class.reinitialization, kwargs...)
end
if length(compiled_callbacks) == 1
return compiled_callbacks[]
6 changes: 3 additions & 3 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
@@ -996,8 +996,8 @@ end
@test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0]
sol = solve(prob, Tsit5())

@test sol[a] == [1.0, -1.0]
@test sol[b] == [2.0, 5.0, 5.0]
@test sol[a] == [-1.0]
@test sol[b] == [5.0, 5.0]
@test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
end
@testset "Heater" begin
@@ -1198,5 +1198,5 @@ end
ss = structural_simplify(sys)
prob = ODEProblem(ss, [theta => 0.0], (0.0, pi))
sol = solve(prob, Tsit5(); dtmax = 0.01)
@test sol[cnt] == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state
@test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state
end

0 comments on commit eb2966e

Please sign in to comment.