diff --git a/src/integrator_types.jl b/src/integrator_types.jl index 2b5ce25..06870bd 100644 --- a/src/integrator_types.jl +++ b/src/integrator_types.jl @@ -5,7 +5,7 @@ mutable struct DEOptions{SType,CType} callback::CType end -mutable struct ODEInterfaceIntegrator{algType,uType,uPrevType,oType,SType,solType,P} <: DiffEqBase.AbstractODEIntegrator{algType, true, uType, Float64} +mutable struct ODEInterfaceIntegrator{algType,uType,uPrevType,oType,SType,solType,P,CallbackCacheType} <: DiffEqBase.AbstractODEIntegrator{algType, true, uType, Float64} u::uType uprev::uPrevType t::Float64 @@ -18,6 +18,8 @@ mutable struct ODEInterfaceIntegrator{algType,uType,uPrevType,oType,SType,solTyp sol::solType eval_sol_fcn event_last_time::Int + vector_event_last_time::Int + callback_cache::CallbackCacheType alg::algType last_event_error::Float64 end diff --git a/src/integrator_utils.jl b/src/integrator_utils.jl index 5ecd7bb..9ab9c86 100644 --- a/src/integrator_utils.jl +++ b/src/integrator_utils.jl @@ -8,13 +8,15 @@ function handle_callbacks!(integrator,eval_sol_fcn) discrete_modified = false saved_in_cb = false if !(typeof(continuous_callbacks)<:Tuple{}) - time,upcrossing,event_occured,idx,counter = + time,upcrossing,event_occured,event_idx,idx,counter = DiffEqBase.find_first_continuous_callback(integrator,continuous_callbacks...) if event_occured integrator.event_last_time = idx - continuous_modified,saved_in_cb = DiffEqBase.apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing) + integrator.vector_event_last_time = event_idx + continuous_modified,saved_in_cb = DiffEqBase.apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing,event_idx) else integrator.event_last_time = 0 + integrator.vector_event_last_time = 1 end end if !(typeof(discrete_callbacks)<:Tuple{}) diff --git a/src/solve.jl b/src/solve.jl index 5606eb3..dfdbe09 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -26,6 +26,13 @@ function DiffEqBase.__solve( callbacks_internal = CallbackSet(callback,prob.callback) + max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal) + if max_len_cb isa VectorContinuousCallback + callback_cache = DiffEqBase.CallbackCache(max_len_cb.len,uBottomEltype,uBottomEltype) + else + callback_cache = nothing + end + tspan = prob.tspan o = KW(kwargs) @@ -69,7 +76,7 @@ function DiffEqBase.__solve( opts = DEOptions(saveat_internal,save_on,save_everystep,callbacks_internal) integrator = ODEInterfaceIntegrator(u,uprev,tspan[1],tspan[1],prob.p,opts, false,tdir,sizeu,sol, - (t)->[t],0,alg,0.) + (t)->[t],0,1,callback_cache,alg,0.) initialize_callbacks!(integrator) if !isinplace && typeof(u)<:AbstractArray