diff --git a/src/solve.jl b/src/solve.jl index 13a80cf..e657105 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -49,19 +49,18 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, sol = solve(_prob,alg.alg,args...;kwargs..., callback=TerminateSteadyState(alg.abstol,alg.reltol), save_everystep=save_everystep,save_start=save_start) - if sol.t[end] == _prob.tspan[end] - sol = DiffEqBase.solution_new_retcode(sol, :Failure) - elseif sol.retcode == :Terminated - if isinplace(prob) - du = similar(sol.u[end]) - prob.f(du, sol.u[end], prob.p, sol.t[end]) - else - du = prob.f(sol.u[end], prob.p, sol.t[end]) - end - if all(abs(d) <= abstol || abs(d) <= reltol*abs(u) for (d,abstol, reltol, u) = + if isinplace(prob) + du = similar(sol.u[end]) + prob.f(du, sol.u[end], prob.p, sol.t[end]) + else + du = prob.f(sol.u[end], prob.p, sol.t[end]) + end + if sol.retcode == :Terminated && all(abs(d) <= abstol || + abs(d) <= reltol*abs(u) for (d,abstol, reltol, u) in zip(du, Iterators.cycle(alg.abstol), Iterators.cycle(alg.reltol), sol.u[end])) - sol = DiffEqBase.solution_new_retcode(sol, :Success) - end + _sol = DiffEqBase.build_solution(prob,alg,sol.u[end],du;retcode = :Success) + else + _sol = DiffEqBase.build_solution(prob,alg,sol.u[end],du;retcode = :Failure) end - sol + _sol end diff --git a/test/runtests.jl b/test/runtests.jl index d53339b..33dd0bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,7 +34,7 @@ using OrdinaryDiffEq sol = solve(prob,DynamicSS(Rodas5())) @test sol.retcode == :Success -f(du,sol.u[end],p,0) +f(du,sol.u,p,0) @test du ≈ [0,0] atol = 1e-7 sol = solve(prob,DynamicSS(Rodas5(),tspan=1e-3)) @@ -43,5 +43,5 @@ sol = solve(prob,DynamicSS(Rodas5(),tspan=1e-3)) sol = solve(prob,DynamicSS(CVODE_BDF()),dt=1.0) @test sol.retcode == :Success -f(du,sol.u[end],p,0) +f(du,sol.u,p,0) @test du ≈ [0,0] atol = 1e-6