Skip to content

Commit

Permalink
update to new callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 17, 2018
1 parent 0cfafd3 commit b63ca37
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 18 deletions.
41 changes: 29 additions & 12 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

# Base Case: Only one callback
function find_first_continuous_callback(integrator, callback::DiffEqBase.AbstractContinuousCallback)
(find_callback_time(integrator,callback)...,1,1)
(find_callback_time(integrator,callback,1)...,1,1)
end

# Starting Case: Compute on the first callback
function find_first_continuous_callback(integrator, callback::DiffEqBase.AbstractContinuousCallback, args...)
find_first_continuous_callback(integrator,find_callback_time(integrator,callback)...,1,1,args...)
find_first_continuous_callback(integrator,find_callback_time(integrator,callback,1)...,1,1,args...)
end

function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Float64,
event_occured::Bool,idx::Int,counter::Int,
callback2)
counter += 1 # counter is idx for callback2.
tmin2,upcrossing2,event_occurred2 = find_callback_time(integrator,callback2)
tmin2,upcrossing2,event_occurred2 = find_callback_time(integrator,callback2,counter)

if event_occurred2 && (tmin2 < tmin || !event_occured)
return tmin2,upcrossing2,true,counter,counter
Expand All @@ -27,7 +27,7 @@ function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Floa
find_first_continuous_callback(integrator,find_first_continuous_callback(integrator,tmin,upcrossing,event_occured,idx,counter,callback2)...,args...)
end

@inline function determine_event_occurance(integrator,callback)
@inline function determine_event_occurance(integrator,callback,counter)
event_occurred = false
Θs = range(typeof(integrator.t)(0),stop=typeof(integrator.t)(1),length=callback.interp_points)
interp_index = 0
Expand All @@ -41,7 +41,7 @@ end
previous_condition = callback.condition(@view(integrator.uprev[callback.idxs]),integrator.tprev,integrator)
end

if integrator.event_last_time && abs(previous_condition) < callback.abstol
if integrator.event_last_time == counter && abs(previous_condition) < 100callback.abstol

# abs(previous_condition) < callback.abstol is for multiple events: only
# condition this on the correct event
Expand All @@ -52,15 +52,18 @@ end
# treat the value as negative, reguardless of the postiivity/negativity
# of the true value due to it being =0 sans floating point issues.

# Only due this if the discontinuity did not move it far away from an event
# Since near even we use direction instead of location to reset

if !(typeof(callback.idxs) <: Number)
tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64}
tmp = integrator(integrator.tprev+100eps(integrator.tprev))::Vector{Float64}
callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs]
else
_tmp = integrator((integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64})[callback.idxs]
_tmp = integrator((integrator.tprev+100eps(integrator.tprev))::Vector{Float64})[callback.idxs]
end

tmp_condition = callback.condition(_tmp,integrator.tprev +
100eps(typeof(integrator.tprev)),
100eps(integrator.tprev),
integrator)

prev_sign = sign((tmp_condition-previous_condition)/integrator.tdir)
Expand Down Expand Up @@ -102,8 +105,8 @@ end
event_occurred,interp_index,Θs,prev_sign,prev_sign_index
end

function find_callback_time(integrator,callback)
event_occurred,interp_index,Θs,prev_sign,prev_sign_index = determine_event_occurance(integrator,callback)
function find_callback_time(integrator,callback,counter)
event_occurred,interp_index,Θs,prev_sign,prev_sign_index = determine_event_occurance(integrator,callback,counter)
dt = integrator.t - integrator.tprev
if event_occurred
if typeof(callback.condition) <: Nothing
Expand All @@ -127,11 +130,25 @@ function find_callback_time(integrator,callback)
out = callback.condition(tmp,integrator.tprev+Θ*dt,integrator)
out
end

if zero_func(top_Θ) == 0
Θ = top_Θ
else
Θ = prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),Roots.AlefeldPotraShi(),atol = callback.abstol/10))
if integrator.event_last_time == counter &&
abs(zero_func(bottom_θ)) < 100callback.abstol &&
prev_sign_index == 1
# Determined that there is an event by derivative
# But floating point error may make the end point negative
sign_top = sign(zero_func(top_Θ))
bottom_θ += 2eps(typeof(bottom_θ))
iter = 1
while sign(zero_func(bottom_θ)) == sign_top && iter < 12
bottom_θ *= 5
iter += 1
end
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
Θ = prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),Roots.AlefeldPotraShi(),atol = callback.abstol/100))
end
#Θ = prevfloat(...)
# prevfloat guerentees that the new time is either 1 floating point
Expand Down
2 changes: 1 addition & 1 deletion src/integrator_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mutable struct ODEInterfaceIntegrator{uType,uPrevType,oType,SType,solType} <: Di
sizeu::SType
sol::solType
eval_sol_fcn
event_last_time::Bool
event_last_time::Int
end

(integrator::ODEInterfaceIntegrator)(t) = integrator.eval_sol_fcn(t)
4 changes: 2 additions & 2 deletions src/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ function handle_callbacks!(integrator,eval_sol_fcn)
time,upcrossing,event_occured,idx,counter =
find_first_continuous_callback(integrator,continuous_callbacks...)
if event_occured
integrator.event_last_time = true
integrator.event_last_time = idx
continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
else
integrator.event_last_time = false
integrator.event_last_time = 0
end
end
if !(typeof(discrete_callbacks)<:Tuple{})
Expand Down
2 changes: 1 addition & 1 deletion src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ function DiffEqBase.__solve(
opts = DEOptions(saveat_internal,save_everystep,callbacks_internal)
integrator = ODEInterfaceIntegrator(u,uprev,tspan[1],tspan[1],opts,
false,tdir,sizeu,sol,
(t)->[t],false)
(t)->[t],0)

outputfcn = OutputFunction(integrator)
o[:OUTPUTFCN] = outputfcn
Expand Down
4 changes: 2 additions & 2 deletions test/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ prob = ODEProblem(callback_f,u0,tspan)

sol = solve(prob,dopri5(),callback=callback,dtmax=0.5)
@test sol(4.0)[1] > 0
sol = solve(prob,dopri5(),callback=callback)
@test sol(4.0)[1] > 0
sol = solve(prob,dopri5(),callback=callback,save_everystep=true)
@test sol(4.0)[1] > -1e-12

0 comments on commit b63ca37

Please sign in to comment.