diff --git a/src/ODEInterfaceDiffEq.jl b/src/ODEInterfaceDiffEq.jl index d88ec28..6f3b56b 100644 --- a/src/ODEInterfaceDiffEq.jl +++ b/src/ODEInterfaceDiffEq.jl @@ -19,8 +19,6 @@ end @compat const KW = Dict{Symbol,Any} -const InterpFunction = FunctionWrappers.FunctionWrapper{Vector{Float64},Tuple{Float64}} - include("algorithms.jl") include("integrator_types.jl") include("integrator_utils.jl") diff --git a/src/callbacks.jl b/src/callbacks.jl index 2f0e7c9..bd0f121 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -53,10 +53,10 @@ end # of the true value due to it being =0 sans floating point issues. if !(typeof(callback.idxs) <: Number) - tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev))) + tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64} callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs] else - _tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))[callback.idxs] + _tmp = integrator((integrator.tprev+100eps(typeof(integrator.tprev)))::Vector{Float64})[callback.idxs] end tmp_condition = callback.condition(_tmp,integrator.tprev + @@ -82,10 +82,10 @@ end elseif callback.interp_points!=0 # Use the interpolants for safety checking for i in 2:length(Θs) if !(typeof(callback.idxs) <: Number) - tmp = integrator(integrator.tprev+dt*Θs[i]) + tmp = integrator(integrator.tprev+dt*Θs[i])::Vector{Float64} callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs] else - _tmp = integrator(integrator.tprev+dt*Θs[i])[callback.idxs] + _tmp = (integrator(integrator.tprev+dt*Θs[i])::Vector{Float64})[callback.idxs] end new_sign = callback.condition(_tmp,integrator.tprev+dt*Θs[i],integrator) if prev_sign == 0 @@ -119,10 +119,10 @@ function find_callback_time(integrator,callback) if callback.rootfind zero_func = (Θ) -> begin if !(typeof(callback.idxs) <: Number) - tmp = integrator(integrator.tprev+Θ*dt) + tmp = integrator(integrator.tprev+Θ*dt)::Vector{Float64} callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs] else - _tmp = integrator(integrator.tprev+Θ*dt)[callback.idxs] + _tmp = (integrator(integrator.tprev+Θ*dt)::Vector{Float64})[callback.idxs] end out = callback.condition(tmp,integrator.tprev+Θ*dt,integrator) out @@ -153,7 +153,7 @@ end function apply_callback!(integrator,callback::ContinuousCallback,cb_time,prev_sign) if cb_time != zero(typeof(integrator.t)) integrator.t = integrator.tprev+cb_time - tmp = integrator(integrator.t) + tmp = integrator(integrator.t)::Vector{Float64} if eltype(integrator.sol.u) <: Vector integrator.u .= tmp else diff --git a/src/integrator_types.jl b/src/integrator_types.jl index 351c06b..71db006 100644 --- a/src/integrator_types.jl +++ b/src/integrator_types.jl @@ -14,7 +14,7 @@ mutable struct ODEInterfaceIntegrator{uType,uPrevType,oType,SType,solType} <: Di tdir::Float64 sizeu::SType sol::solType - eval_sol_fcn::InterpFunction + eval_sol_fcn event_last_time::Bool end diff --git a/src/integrator_utils.jl b/src/integrator_utils.jl index 24eaf4f..3d92d19 100644 --- a/src/integrator_utils.jl +++ b/src/integrator_utils.jl @@ -14,7 +14,7 @@ function handle_callbacks!(integrator,eval_sol_fcn) integrator.event_last_time = true continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing) else - integrator.event_last_time = false + integrator.event_last_time = false end end if !(typeof(discrete_callbacks)<:Tuple{}) @@ -38,7 +38,7 @@ function DiffEqBase.savevalues!(integrator::ODEInterfaceIntegrator,force_save=fa while !isempty(integrator.opts.saveat) && integrator.tdir*top(integrator.opts.saveat) < integrator.tdir*integrator.t curt = pop!(integrator.opts.saveat) - tmp = integrator(curt) + tmp = integrator(curt)::Vector{Float64} push!(integrator.sol.t,curt) save_value!(integrator.sol.u,tmp,uType,integrator.sizeu) end diff --git a/src/solve.jl b/src/solve.jl index 1ce42da..127d155 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -72,7 +72,7 @@ function solve{uType,tuptType,isinplace,AlgType<:ODEInterfaceAlgorithm}( opts = DEOptions(saveat_internal,save_everystep,callbacks_internal) integrator = ODEInterfaceIntegrator(u,uprev,tspan[1],tspan[1],opts, false,tdir,sizeu,sol, - InterpFunction((t)->[t]),false) + (t)->[t],false) outputfcn = OutputFunction(integrator) o[:OUTPUTFCN] = outputfcn @@ -326,7 +326,7 @@ function (f::OutputFunction)(reason::ODEInterface.OUTPUTFCN_CALL_REASON, end integrator.t = t integrator.tprev = tprev - integrator.eval_sol_fcn = InterpFunction(eval_sol_fcn) + integrator.eval_sol_fcn = eval_sol_fcn handle_callbacks!(integrator,eval_sol_fcn) diff --git a/test/algorithm_tests.jl b/test/algorithm_tests.jl index e4d81dd..a40b86f 100644 --- a/test/algorithm_tests.jl +++ b/test/algorithm_tests.jl @@ -1,5 +1,4 @@ -using ODEInterfaceDiffEq, DiffEqProblemLibrary, DiffEqBase -using Base.Test +using ODEInterfaceDiffEq, DiffEqProblemLibrary, DiffEqBase, Test prob = prob_ode_linear sol =solve(prob,dopri5(),dt=1//2^(4))