Skip to content

Commit

Permalink
get v0.7 working
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jul 5, 2018
1 parent 0d3d189 commit 6dd6716
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 16 deletions.
2 changes: 0 additions & 2 deletions src/ODEInterfaceDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ end

@compat const KW = Dict{Symbol,Any}

const InterpFunction = FunctionWrappers.FunctionWrapper{Vector{Float64},Tuple{Float64}}

include("algorithms.jl")
include("integrator_types.jl")
include("integrator_utils.jl")
Expand Down
14 changes: 7 additions & 7 deletions src/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ end
# of the true value due to it being =0 sans floating point issues.

if !(typeof(callback.idxs) <: Number)
tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))
tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64}
callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs]
else
_tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))[callback.idxs]
_tmp = integrator((integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64})[callback.idxs]
end

tmp_condition = callback.condition(_tmp,integrator.tprev +
Expand All @@ -82,10 +82,10 @@ end
elseif callback.interp_points!=0 # Use the interpolants for safety checking
for i in 2:length(Θs)
if !(typeof(callback.idxs) <: Number)
tmp = integrator(integrator.tprev+dt*Θs[i])
tmp = integrator(integrator.tprev+dt*Θs[i])::Vector{Float64}
callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs]
else
_tmp = integrator(integrator.tprev+dt*Θs[i])[callback.idxs]
_tmp = (integrator(integrator.tprev+dt*Θs[i])::Vector{Float64})[callback.idxs]
end
new_sign = callback.condition(_tmp,integrator.tprev+dt*Θs[i],integrator)
if prev_sign == 0
Expand Down Expand Up @@ -119,10 +119,10 @@ function find_callback_time(integrator,callback)
if callback.rootfind
zero_func = (Θ) -> begin
if !(typeof(callback.idxs) <: Number)
tmp = integrator(integrator.tprev+Θ*dt)
tmp = integrator(integrator.tprev+Θ*dt)::Vector{Float64}
callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs]
else
_tmp = integrator(integrator.tprev+Θ*dt)[callback.idxs]
_tmp = (integrator(integrator.tprev+Θ*dt)::Vector{Float64})[callback.idxs]
end
out = callback.condition(tmp,integrator.tprev+Θ*dt,integrator)
out
Expand Down Expand Up @@ -153,7 +153,7 @@ end
function apply_callback!(integrator,callback::ContinuousCallback,cb_time,prev_sign)
if cb_time != zero(typeof(integrator.t))
integrator.t = integrator.tprev+cb_time
tmp = integrator(integrator.t)
tmp = integrator(integrator.t)::Vector{Float64}
if eltype(integrator.sol.u) <: Vector
integrator.u .= tmp
else
Expand Down
2 changes: 1 addition & 1 deletion src/integrator_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mutable struct ODEInterfaceIntegrator{uType,uPrevType,oType,SType,solType} <: Di
tdir::Float64
sizeu::SType
sol::solType
eval_sol_fcn::InterpFunction
eval_sol_fcn
event_last_time::Bool
end

Expand Down
4 changes: 2 additions & 2 deletions src/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function handle_callbacks!(integrator,eval_sol_fcn)
integrator.event_last_time = true
continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
else
integrator.event_last_time = false
integrator.event_last_time = false
end
end
if !(typeof(discrete_callbacks)<:Tuple{})
Expand All @@ -38,7 +38,7 @@ function DiffEqBase.savevalues!(integrator::ODEInterfaceIntegrator,force_save=fa
while !isempty(integrator.opts.saveat) &&
integrator.tdir*top(integrator.opts.saveat) < integrator.tdir*integrator.t
curt = pop!(integrator.opts.saveat)
tmp = integrator(curt)
tmp = integrator(curt)::Vector{Float64}
push!(integrator.sol.t,curt)
save_value!(integrator.sol.u,tmp,uType,integrator.sizeu)
end
Expand Down
4 changes: 2 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function solve{uType,tuptType,isinplace,AlgType<:ODEInterfaceAlgorithm}(
opts = DEOptions(saveat_internal,save_everystep,callbacks_internal)
integrator = ODEInterfaceIntegrator(u,uprev,tspan[1],tspan[1],opts,
false,tdir,sizeu,sol,
InterpFunction((t)->[t]),false)
(t)->[t],false)

outputfcn = OutputFunction(integrator)
o[:OUTPUTFCN] = outputfcn
Expand Down Expand Up @@ -326,7 +326,7 @@ function (f::OutputFunction)(reason::ODEInterface.OUTPUTFCN_CALL_REASON,
end
integrator.t = t
integrator.tprev = tprev
integrator.eval_sol_fcn = InterpFunction(eval_sol_fcn)
integrator.eval_sol_fcn = eval_sol_fcn

handle_callbacks!(integrator,eval_sol_fcn)

Expand Down
3 changes: 1 addition & 2 deletions test/algorithm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ODEInterfaceDiffEq, DiffEqProblemLibrary, DiffEqBase
using Base.Test
using ODEInterfaceDiffEq, DiffEqProblemLibrary, DiffEqBase, Test

prob = prob_ode_linear
sol =solve(prob,dopri5(),dt=1//2^(4))
Expand Down

0 comments on commit 6dd6716

Please sign in to comment.