Skip to content

Commit

Permalink
Start using termination conditions in newton raphson
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh530 committed Sep 13, 2023
1 parent f1703e6 commit 14ced9e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 18 deletions.
66 changes: 49 additions & 17 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,32 @@ for large-scale and numerically-difficult nonlinear systems.
Currently, the linear solver and chunk size choice only applies to in-place defined
`NonlinearProblem`s. That is expected to change in the future.
"""
struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <:
struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ, TC <: NLSolveTerminationCondition} <:
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::L
precs::P
termination_condition::TC
end

function NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS)
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing))
NewtonRaphson{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
_unwrap_val(concrete_jac)}(linsolve,
precs)
_unwrap_val(concrete_jac), typeof(termination_condition)}(linsolve,
precs, termination_condition)
end

mutable struct NewtonRaphsonCache{iip, fType, algType, uType, duType, resType, pType,
INType, tolType,
probType, ufType, L, jType, JC}
probType, ufType, L, jType, JC, TCType}
f::fType
alg::algType
u::uType
uprev::uType
fu::resType
p::pType
uf::ufType
Expand All @@ -81,26 +86,31 @@ mutable struct NewtonRaphsonCache{iip, fType, algType, uType, duType, resType, p
internalnorm::INType
retcode::SciMLBase.ReturnCode.T
abstol::tolType
reltol::tolType
prob::probType
stats::NLStats
tc_storage::TCType

function NewtonRaphsonCache{iip}(f::fType, alg::algType, u::uType, fu::resType,
function NewtonRaphsonCache{iip}(f::fType, alg::algType, u::uType, uprev::uType,
fu::resType,
p::pType, uf::ufType, linsolve::L, J::jType,
du1::duType,
jac_config::JC, force_stop::Bool, maxiters::Int,
internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
reltol::tolType,
prob::probType,
stats::NLStats) where {
stats::NLStats,
tc_storage::TCType) where {
iip, fType, algType, uType,
duType, resType, pType, INType,
tolType,
probType, ufType, L, jType, JC}
probType, ufType, L, jType, JC, TCType}
new{iip, fType, algType, uType, duType, resType, pType, INType, tolType,
probType, ufType, L, jType, JC}(f, alg, u, fu, p,
probType, ufType, L, jType, JC, TCType}(f, alg, u, uprev, fu, p,
uf, linsolve, J, du1, jac_config,
force_stop, maxiters, internalnorm,
retcode, abstol, prob, stats)
retcode, abstol, reltol, prob, stats, tc_storage)
end
end

Expand All @@ -112,9 +122,11 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
args...;
alias_u0 = false,
maxiters = 1000,
abstol = 1e-6,
abstol = nothing,
reltol = nothing,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
uT = eltype(prob.u0)
if alias_u0
u = prob.u0
else
Expand All @@ -130,26 +142,41 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
end
uf, linsolve, J, du1, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

return NewtonRaphsonCache{iip}(f, alg, u, fu, p, uf, linsolve, J, du1, jac_config,
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

atol = _get_tolerance(abstol, tc.abstol, uT)
rtol = _get_tolerance(reltol, tc.reltol, uT)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu, p, uf, linsolve, J, du1,
jac_config,
false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0))
ReturnCode.Default, atol, rtol, prob, NLStats(1, 0, 0, 0, 0), storage)
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, fu, f, p, alg = cache
@unpack u, uprev, fu, f, p, alg = cache
@unpack J, linsolve, du1 = cache
jacobian!(J, cache)

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), linu = _vec(du1),
p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du1
f(fu, u, p)

if cache.internalnorm(cache.fu) < cache.abstol
if termination_condition(fu, u, uprev, cache.abstol, cache.reltol)
cache.force_stop = true
end

@. uprev = u
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
Expand All @@ -158,13 +185,18 @@ function perform_step!(cache::NewtonRaphsonCache{true})
end

function perform_step!(cache::NewtonRaphsonCache{false})
@unpack u, fu, f, p = cache
@unpack u, uprev, fu, f, p = cache
J = jacobian(cache, f)
cache.u = u - J \ fu
cache.fu = f(cache.u, p)
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

if termination_condition(cache.fu, cache.u, uprev, cache.abstol, cache.reltol)
cache.force_stop = true
end
cache.uprev = cache.u
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
Expand Down
5 changes: 5 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,8 @@ function rfunc(r::R, c2::R, M::R, γ1::R, γ2::R, β::R) where {R <: Real} # R-f
return (1 - γ1 - β) * (exp(r - c2) + β / (1 - γ1 - β))
end
end

function _get_tolerance(η, tc_η, ::Type{T}) where {T}
@show fallback_η
return ifelse!== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η))
end
2 changes: 1 addition & 1 deletion test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ u0 = [1.0, 1.0]

precs = [
NonlinearSolve.DEFAULT_PRECS,
(args...) -> (Diagonal(rand!(similar(u0))), nothing)
(args...) -> (Diagonal(rand!(similar(u0))), nothing),
]

for prec in precs, linsolve in (nothing, KrylovJL_GMRES())
Expand Down

0 comments on commit 14ced9e

Please sign in to comment.