diff --git a/src/callbacks.jl b/src/callbacks.jl index 4da36e4..8544e8f 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -2,19 +2,19 @@ # Base Case: Only one callback function find_first_continuous_callback(integrator, callback::DiffEqBase.AbstractContinuousCallback) - (find_callback_time(integrator,callback)...,1,1) + (find_callback_time(integrator,callback,1)...,1,1) end # Starting Case: Compute on the first callback function find_first_continuous_callback(integrator, callback::DiffEqBase.AbstractContinuousCallback, args...) - find_first_continuous_callback(integrator,find_callback_time(integrator,callback)...,1,1,args...) + find_first_continuous_callback(integrator,find_callback_time(integrator,callback,1)...,1,1,args...) end function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Float64, event_occured::Bool,idx::Int,counter::Int, callback2) counter += 1 # counter is idx for callback2. - tmin2,upcrossing2,event_occurred2 = find_callback_time(integrator,callback2) + tmin2,upcrossing2,event_occurred2 = find_callback_time(integrator,callback2,counter) if event_occurred2 && (tmin2 < tmin || !event_occured) return tmin2,upcrossing2,true,counter,counter @@ -27,7 +27,7 @@ function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Floa find_first_continuous_callback(integrator,find_first_continuous_callback(integrator,tmin,upcrossing,event_occured,idx,counter,callback2)...,args...) end -@inline function determine_event_occurance(integrator,callback) +@inline function determine_event_occurance(integrator,callback,counter) event_occurred = false Θs = range(typeof(integrator.t)(0),stop=typeof(integrator.t)(1),length=callback.interp_points) interp_index = 0 @@ -41,7 +41,7 @@ end previous_condition = callback.condition(@view(integrator.uprev[callback.idxs]),integrator.tprev,integrator) end - if integrator.event_last_time && abs(previous_condition) < callback.abstol + if integrator.event_last_time == counter && abs(previous_condition) < 100callback.abstol # abs(previous_condition) < callback.abstol is for multiple events: only # condition this on the correct event @@ -52,15 +52,18 @@ end # treat the value as negative, reguardless of the postiivity/negativity # of the true value due to it being =0 sans floating point issues. + # Only due this if the discontinuity did not move it far away from an event + # Since near even we use direction instead of location to reset + if !(typeof(callback.idxs) <: Number) - tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64} + tmp = integrator(integrator.tprev+100eps(integrator.tprev))::Vector{Float64} callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs] else - _tmp = integrator((integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64})[callback.idxs] + _tmp = integrator((integrator.tprev+100eps(integrator.tprev))::Vector{Float64})[callback.idxs] end tmp_condition = callback.condition(_tmp,integrator.tprev + - 100eps(typeof(integrator.tprev)), + 100eps(integrator.tprev), integrator) prev_sign = sign((tmp_condition-previous_condition)/integrator.tdir) @@ -102,8 +105,8 @@ end event_occurred,interp_index,Θs,prev_sign,prev_sign_index end -function find_callback_time(integrator,callback) - event_occurred,interp_index,Θs,prev_sign,prev_sign_index = determine_event_occurance(integrator,callback) +function find_callback_time(integrator,callback,counter) + event_occurred,interp_index,Θs,prev_sign,prev_sign_index = determine_event_occurance(integrator,callback,counter) dt = integrator.t - integrator.tprev if event_occurred if typeof(callback.condition) <: Nothing @@ -127,11 +130,25 @@ function find_callback_time(integrator,callback) out = callback.condition(tmp,integrator.tprev+Θ*dt,integrator) out end - + if zero_func(top_Θ) == 0 Θ = top_Θ else - Θ = prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),Roots.AlefeldPotraShi(),atol = callback.abstol/10)) + if integrator.event_last_time == counter && + abs(zero_func(bottom_θ)) < 100callback.abstol && + prev_sign_index == 1 + # Determined that there is an event by derivative + # But floating point error may make the end point negative + sign_top = sign(zero_func(top_Θ)) + bottom_θ += 2eps(typeof(bottom_θ)) + iter = 1 + while sign(zero_func(bottom_θ)) == sign_top && iter < 12 + bottom_θ *= 5 + iter += 1 + end + iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.") + end + Θ = prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),Roots.AlefeldPotraShi(),atol = callback.abstol/100)) end #Θ = prevfloat(...) # prevfloat guerentees that the new time is either 1 floating point diff --git a/src/integrator_types.jl b/src/integrator_types.jl index 71db006..e8cadfc 100644 --- a/src/integrator_types.jl +++ b/src/integrator_types.jl @@ -15,7 +15,7 @@ mutable struct ODEInterfaceIntegrator{uType,uPrevType,oType,SType,solType} <: Di sizeu::SType sol::solType eval_sol_fcn - event_last_time::Bool + event_last_time::Int end (integrator::ODEInterfaceIntegrator)(t) = integrator.eval_sol_fcn(t) diff --git a/src/integrator_utils.jl b/src/integrator_utils.jl index 3d92d19..dd62730 100644 --- a/src/integrator_utils.jl +++ b/src/integrator_utils.jl @@ -11,10 +11,10 @@ function handle_callbacks!(integrator,eval_sol_fcn) time,upcrossing,event_occured,idx,counter = find_first_continuous_callback(integrator,continuous_callbacks...) if event_occured - integrator.event_last_time = true + integrator.event_last_time = idx continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing) else - integrator.event_last_time = false + integrator.event_last_time = 0 end end if !(typeof(discrete_callbacks)<:Tuple{}) diff --git a/src/solve.jl b/src/solve.jl index 9436217..b058e60 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -73,7 +73,7 @@ function DiffEqBase.__solve( opts = DEOptions(saveat_internal,save_everystep,callbacks_internal) integrator = ODEInterfaceIntegrator(u,uprev,tspan[1],tspan[1],opts, false,tdir,sizeu,sol, - (t)->[t],false) + (t)->[t],0) outputfcn = OutputFunction(integrator) o[:OUTPUTFCN] = outputfcn diff --git a/test/callbacks.jl b/test/callbacks.jl index fb2732d..c1ee531 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -22,5 +22,5 @@ prob = ODEProblem(callback_f,u0,tspan) sol = solve(prob,dopri5(),callback=callback,dtmax=0.5) @test sol(4.0)[1] > 0 -sol = solve(prob,dopri5(),callback=callback) -@test sol(4.0)[1] > 0 +sol = solve(prob,dopri5(),callback=callback,save_everystep=true) +@test sol(4.0)[1] > -1e-12