diff --git a/src/callbacks.jl b/src/callbacks.jl index 154f1ca..bfe6942 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -10,19 +10,21 @@ function find_first_continuous_callback(integrator, callback::AbstractContinuous find_first_continuous_callback(integrator,find_callback_time(integrator,callback)...,1,1,args...) end -function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Float64,idx::Int,counter::Int,callback2) +function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Float64, + event_occured::Bool,idx::Int,counter::Int, + callback2) counter += 1 # counter is idx for callback2. - tmin2,upcrossing2 = find_callback_time(integrator,callback2) + tmin2,upcrossing2,event_occurred2 = find_callback_time(integrator,callback2) - if (tmin2 < tmin && tmin2 != zero(typeof(tmin))) || tmin == zero(typeof(tmin)) - return tmin2,upcrossing2,counter,counter + if event_occurred2 && (tmin2 < tmin || !event_occured) + return tmin2,upcrossing2,true,counter,counter else - return tmin,upcrossing,idx,counter + return tmin,upcrossing,event_occured,idx,counter end end -function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Float64,idx::Int,counter::Int,callback2,args...) - find_first_continuous_callback(integrator,find_first_continuous_callback(integrator,tmin,upcrossing,idx,counter,callback2)...,args...) +function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Float64,event_occured::Bool,idx::Int,counter::Int,callback2,args...) + find_first_continuous_callback(integrator,find_first_continuous_callback(integrator,tmin,upcrossing,event_occured,idx,counter,callback2)...,args...) end @inline function determine_event_occurance(integrator,callback) @@ -38,7 +40,7 @@ end else previous_condition = callback.condition(@view(integrator.uprev[callback.idxs]),integrator.tprev,integrator) end - if isapprox(previous_condition,0,rtol=callback.reltol,atol=callback.abstol) + if integrator.event_last_time prev_sign = 0.0 else prev_sign = sign(previous_condition) @@ -64,7 +66,7 @@ end end new_sign = callback.condition(_tmp,integrator.tprev+dt*Θs[i],integrator) if prev_sign == 0 - prev_sign = new_sign + prev_sign = sign(new_sign) prev_sign_index = i end if ((prev_sign<0 && !(typeof(callback.affect!)<:Void)) || (prev_sign>0 && !(typeof(callback.affect_neg!)<:Void))) && prev_sign*new_sign<0 @@ -122,7 +124,7 @@ function find_callback_time(integrator,callback) else new_t = zero(typeof(integrator.t)) end - new_t,prev_sign + new_t,prev_sign,event_occurred end function apply_callback!(integrator,callback::ContinuousCallback,cb_time,prev_sign) diff --git a/src/integrator_types.jl b/src/integrator_types.jl index 5e92806..583af79 100644 --- a/src/integrator_types.jl +++ b/src/integrator_types.jl @@ -15,6 +15,7 @@ mutable struct ODEInterfaceIntegrator{uType,uPrevType,oType,SType,solType} <: Ab sizeu::SType sol::solType eval_sol_fcn::InterpFunction + event_last_time::Bool end (integrator::ODEInterfaceIntegrator)(t) = integrator.eval_sol_fcn(t) diff --git a/src/integrator_utils.jl b/src/integrator_utils.jl index bff40c8..24eaf4f 100644 --- a/src/integrator_utils.jl +++ b/src/integrator_utils.jl @@ -8,9 +8,13 @@ function handle_callbacks!(integrator,eval_sol_fcn) discrete_modified = false saved_in_cb = false if !(typeof(continuous_callbacks)<:Tuple{}) - time,upcrossing,idx,counter = find_first_continuous_callback(integrator,continuous_callbacks...) - if time != zero(typeof(integrator.t)) && upcrossing != 0 # if not, then no events + time,upcrossing,event_occured,idx,counter = + find_first_continuous_callback(integrator,continuous_callbacks...) + if event_occured + integrator.event_last_time = true continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing) + else + integrator.event_last_time = false end end if !(typeof(discrete_callbacks)<:Tuple{}) diff --git a/src/solve.jl b/src/solve.jl index 0fe1610..db53aff 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -69,7 +69,8 @@ function solve{uType,tType,isinplace,AlgType<:ODEInterfaceAlgorithm}( opts = DEOptions(saveat_internal,save_everystep,callbacks_internal) integrator = ODEInterfaceIntegrator(u,uprev,tspan[1],tspan[1],opts, - false,tdir,sizeu,sol,InterpFunction((t)->[t])) + false,tdir,sizeu,sol, + InterpFunction((t)->[t]),false) outputfcn = OutputFunction(integrator) o[:OUTPUTFCN] = outputfcn