Skip to content

Commit

Permalink
Proper handling of complex numbers and failures
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 27, 2023
1 parent 3565824 commit 9c50e5f
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ function perform_step!(cache::GeneralBroydenCache{true})

if all(cache.reset_check, du) || all(cache.reset_check, dfu)
if cache.resets cache.max_resets
cache.retcode = ReturnCode.Unstable
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
Expand Down Expand Up @@ -153,7 +153,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.dfu = cache.fu2 .- cache.fu
if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
if cache.resets cache.max_resets
cache.retcode = ReturnCode.Unstable
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
Expand Down
10 changes: 5 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip},
@unpack adkwargs, linsolve, precs = alg

algs = (
# Klement(),
# Broyden(),
GeneralKlement(; linsolve, precs),
GeneralBroyden(),
NewtonRaphson(; linsolve, precs, adkwargs...),
NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...),
TrustRegion(; linsolve, precs, adkwargs...),
Expand Down Expand Up @@ -159,7 +159,7 @@ end
]
else
[
:(GeneralKlement()),
:(GeneralKlement(; linsolve, precs)),
:(GeneralBroyden()),
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
Expand Down Expand Up @@ -191,7 +191,7 @@ end
push!(calls,
quote
resids = tuple($(Tuple(resids)...))
minfu, idx = findmin(DEFAULT_NORM, resids)
minfu, idx = __findmin(DEFAULT_NORM, resids)
end)

for i in 1:length(algs)
Expand Down Expand Up @@ -249,7 +249,7 @@ end
retcode = ReturnCode.MaxIters

fus = tuple($(Tuple(resids)...))
minfu, idx = findmin(cache.caches[1].internalnorm, fus)
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
stats = cache.caches[idx].stats
u = cache.caches[idx].u

Expand Down
4 changes: 2 additions & 2 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ function perform_step!(cache::GeneralKlementCache{true})
if singular
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
cache.retcode = ReturnCode.ConvergenceFailure
return nothing
end
fact_done = false
Expand Down Expand Up @@ -176,7 +176,7 @@ function perform_step!(cache::GeneralKlementCache{false})
if singular
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
cache.retcode = ReturnCode.ConvergenceFailure
return nothing
end
fact_done = false
Expand Down
4 changes: 2 additions & 2 deletions src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
if cache.iterations_since_reset > size(cache.U, 1) &&
(all(cache.reset_check, du) || all(cache.reset_check, cache.dfu))
if cache.resets cache.max_resets
cache.retcode = ReturnCode.Unstable
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
Expand Down Expand Up @@ -188,7 +188,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
if cache.iterations_since_reset > size(cache.U, 1) &&
(all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu))
if cache.resets cache.max_resets
cache.retcode = ReturnCode.Unstable
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
Expand Down
7 changes: 2 additions & 5 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing,
internalnorm = DEFAULT_NORM,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
Expand All @@ -91,9 +90,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
linsolve_kwargs)

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))
reltol, termination_condition, eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

Expand Down
35 changes: 20 additions & 15 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,6 @@ for large-scale and numerically-difficult nonlinear systems.
`expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`.
- `max_shrink_times`: the maximum number of times to shrink the trust region radius in a
row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`.
!!! warning
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
Support for the OOP version is planned!
"""
@concrete struct TrustRegion{CJ, AD, MTR} <:
AbstractNewtonAlgorithm{CJ, AD}
Expand Down Expand Up @@ -250,7 +245,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
linsolve_kwargs)
u_tmp = zero(u)
u_cauchy = zero(u)
u_gauss_newton = zero(u)
u_gauss_newton = _mutable_zero(u)

loss_new = loss
H = zero(J' * J)
Expand Down Expand Up @@ -338,10 +333,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
initial_trust_radius = convert(trustType, 1.0)
end

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))
abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
termination_condition, eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

Expand All @@ -368,8 +361,7 @@ function perform_step!(cache::TrustRegionCache{true})
# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu),
linu = _vec(u_gauss_newton),
p = p, reltol = cache.abstol)
linu = _vec(u_gauss_newton), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.u_gauss_newton = -1 * u_gauss_newton
end
Expand All @@ -395,7 +387,12 @@ function perform_step!(cache::TrustRegionCache{false})
cache.H = J' * J
cache.g = _restructure(fu, J' * _vec(fu))
cache.stats.njacs += 1
cache.u_gauss_newton = -1 .* _restructure(cache.g, cache.H \ _vec(cache.g))

# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
# to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular
linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J, b = -_vec(fu),
linu = _vec(cache.u_gauss_newton), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
end

# Compute the Newton step.
Expand Down Expand Up @@ -718,8 +715,16 @@ function jvp!(cache::TrustRegionCache{true})
end

function not_terminated(cache::TrustRegionCache)
return !cache.force_stop && cache.stats.nsteps < cache.maxiters &&
cache.shrink_counter < cache.alg.max_shrink_times
non_shrink_terminated = cache.force_stop || cache.stats.nsteps cache.maxiters
# Terminated due to convergence or maxiters
non_shrink_terminated && return false
# Terminated due to too many shrink
shrink_terminated = cache.shrink_counter cache.alg.max_shrink_times
if shrink_terminated
cache.retcode = ReturnCode.ConvergenceFailure
return false
end
return true
end
get_fu(cache::TrustRegionCache) = cache.fu

Expand Down
14 changes: 13 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ end
@inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u))
@inline DEFAULT_NORM(u) = norm(u)

# Ignores NaN
function __findmin(f, x)
return findmin(x) do xᵢ
fx = f(xᵢ)
return isnan(fx) ? Inf : fx
end
end

"""
default_adargs_to_adtype(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), diff_type = Val{:forward})
Expand Down Expand Up @@ -210,9 +218,13 @@ function __get_concrete_algorithm(alg, prob)
return set_ad(alg, ad)
end

__cvt_real(::Type{T}, ::Nothing) where {T} = nothing
__cvt_real(::Type{T}, x) where {T} = real(T(x))

function _get_tolerance(η, tc_η, ::Type{T}) where {T}
fallback_η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)
return T(ifelse!== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η)))
return ifelse!== nothing, __cvt_real(T, η),
ifelse(tc_η !== nothing, __cvt_real(T, tc_η), fallback_η))
end

function _init_termination_elements(abstol, reltol, termination_condition,
Expand Down

0 comments on commit 9c50e5f

Please sign in to comment.