From 666eceffdb602787517556dfe5c719736382035f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Sep 2024 16:59:35 +0530 Subject: [PATCH] feat: save discrete variables in callback init --- src/systems/callbacks.jl | 65 ++++++++++++++++++++++++++++++++++++---- test/symbolic_events.jl | 6 ++-- 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 4f8d8065d9..86cab57634 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -565,8 +565,22 @@ function generate_single_rootfinding_callback( rf_oop(u, parameter_values(integ), t) end end + + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && + (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing + initfn = let save_idxs = save_idxs + function (cb, u, t, integrator) + for idx in save_idxs + SciMLBase.save_discretes!(integrator, idx) + end + end + end + else + initfn = SciMLBase.INITIALIZE_DEFAULT + end return ContinuousCallback( - cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind) + cond, affect_function.affect, affect_function.affect_neg, + rootfind = cb.rootfind, initialize = initfn) end function generate_vector_rootfinding_callback( @@ -618,8 +632,25 @@ function generate_vector_rootfinding_callback( affect_neg(integ) end end + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + save_idxs = mapreduce( + cb -> get(ic.callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[]) + initfn = if isempty(save_idxs) + SciMLBase.INITIALIZE_DEFAULT + else + let save_idxs = save_idxs + function (cb, u, t, integrator) + for idx in save_idxs + SciMLBase.save_discretes!(integrator, idx) + end + end + end + end + else + initfn = SciMLBase.INITIALIZE_DEFAULT + end return VectorContinuousCallback( - cond, affect, affect_neg, length(eqs), rootfind = rootfind) + cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn) end """ @@ -727,12 +758,24 @@ function generate_timed_callback(cb, sys, dvs, ps; postprocess_affect_expr! = no cond = condition(cb) as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false}, postprocess_affect_expr!, kwargs...) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && + (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing + initfn = let save_idxs = save_idxs + function (cb, u, t, integrator) + for idx in save_idxs + SciMLBase.save_discretes!(integrator, idx) + end + end + end + else + initfn = SciMLBase.INITIALIZE_DEFAULT + end if cond isa AbstractVector # Preset Time - return PresetTimeCallback(cond, as) + return PresetTimeCallback(cond, as; initialize = initfn) else # Periodic - return PeriodicCallback(as, cond) + return PeriodicCallback(as, cond; initialize = initfn) end end @@ -745,7 +788,19 @@ function generate_discrete_callback(cb, sys, dvs, ps; postprocess_affect_expr! = c = compile_condition(cb, sys, dvs, ps; expression = Val{false}, kwargs...) as = compile_affect(affects(cb), cb, sys, dvs, ps; expression = Val{false}, postprocess_affect_expr!, kwargs...) - return DiscreteCallback(c, as) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && + (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing + initfn = let save_idxs = save_idxs + function (cb, u, t, integrator) + for idx in save_idxs + SciMLBase.save_discretes!(integrator, idx) + end + end + end + else + initfn = SciMLBase.INITIALIZE_DEFAULT + end + return DiscreteCallback(c, as; initialize = initfn) end end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index ba58a88ad4..c2bac3404d 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -882,7 +882,7 @@ end @test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0] sol = solve(prob, Tsit5()) - @test sol[a] == [-1.0] - @test sol[b] == [5.0, 5.0] - @test sol[c] == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + @test sol[a] == [1.0, -1.0] + @test sol[b] == [2.0, 5.0, 5.0] + @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end