Skip to content

Commit

Permalink
Merge pull request #27 from kanav99/vectorcb
Browse files Browse the repository at this point in the history
Changes for VectorContinuousCallback
  • Loading branch information
ChrisRackauckas authored Jun 6, 2019
2 parents 548dcee + be7e693 commit ed6025e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/integrator_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mutable struct DEOptions{SType,CType}
callback::CType
end

mutable struct ODEInterfaceIntegrator{algType,uType,uPrevType,oType,SType,solType,P} <: DiffEqBase.AbstractODEIntegrator{algType, true, uType, Float64}
mutable struct ODEInterfaceIntegrator{algType,uType,uPrevType,oType,SType,solType,P,CallbackCacheType} <: DiffEqBase.AbstractODEIntegrator{algType, true, uType, Float64}
u::uType
uprev::uPrevType
t::Float64
Expand All @@ -18,6 +18,8 @@ mutable struct ODEInterfaceIntegrator{algType,uType,uPrevType,oType,SType,solTyp
sol::solType
eval_sol_fcn
event_last_time::Int
vector_event_last_time::Int
callback_cache::CallbackCacheType
alg::algType
last_event_error::Float64
end
Expand Down
6 changes: 4 additions & 2 deletions src/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ function handle_callbacks!(integrator,eval_sol_fcn)
discrete_modified = false
saved_in_cb = false
if !(typeof(continuous_callbacks)<:Tuple{})
time,upcrossing,event_occured,idx,counter =
time,upcrossing,event_occured,event_idx,idx,counter =
DiffEqBase.find_first_continuous_callback(integrator,continuous_callbacks...)
if event_occured
integrator.event_last_time = idx
continuous_modified,saved_in_cb = DiffEqBase.apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
integrator.vector_event_last_time = event_idx
continuous_modified,saved_in_cb = DiffEqBase.apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing,event_idx)
else
integrator.event_last_time = 0
integrator.vector_event_last_time = 1
end
end
if !(typeof(discrete_callbacks)<:Tuple{})
Expand Down
9 changes: 8 additions & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ function DiffEqBase.__solve(

callbacks_internal = CallbackSet(callback,prob.callback)

max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal)
if max_len_cb isa VectorContinuousCallback
callback_cache = DiffEqBase.CallbackCache(max_len_cb.len,uBottomEltype,uBottomEltype)
else
callback_cache = nothing
end

tspan = prob.tspan

o = KW(kwargs)
Expand Down Expand Up @@ -69,7 +76,7 @@ function DiffEqBase.__solve(
opts = DEOptions(saveat_internal,save_on,save_everystep,callbacks_internal)
integrator = ODEInterfaceIntegrator(u,uprev,tspan[1],tspan[1],prob.p,opts,
false,tdir,sizeu,sol,
(t)->[t],0,alg,0.)
(t)->[t],0,1,callback_cache,alg,0.)
initialize_callbacks!(integrator)

if !isinplace && typeof(u)<:AbstractArray
Expand Down

0 comments on commit ed6025e

Please sign in to comment.