Skip to content

Commit

Permalink
Add support for line search in Newton Raphson
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 20, 2023
1 parent de8086c commit 7e26d18
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 60 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand All @@ -33,6 +34,7 @@ Enzyme = "0.11"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LinearSolve = "2"
LineSearches = "7"
PrecompileTools = "1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
Expand Down
5 changes: 4 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isi
import StaticArraysCore: StaticArray, SVector, SArray, MArray
import UnPack: @unpack

@reexport using ADTypes, SciMLBase, SimpleNonlinearSolve
@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve

const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
Expand All @@ -35,6 +35,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAl
end

include("utils.jl")
include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
include("levenberg.jl")
Expand Down Expand Up @@ -69,4 +70,6 @@ export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt

export LineSearch

end # module
11 changes: 6 additions & 5 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

sparsity_detection_alg(f, ad) = NoSparsityDetection()
sparsity_detection_alg(_, _) = NoSparsityDetection()
function sparsity_detection_alg(f, ad::AbstractSparseADType)
if f.sparsity === nothing
if f.jac_prototype === nothing
Expand Down Expand Up @@ -49,8 +49,8 @@ end
jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))

# Build Jacobian Caches
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
::Val{iip}) where {iip}
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip};
linsolve_kwargs=(;)) where {iip}
uf = JacobianWrapper{iip}(f, p)

haslinsolve = hasfield(typeof(alg), :linsolve)
Expand Down Expand Up @@ -92,14 +92,15 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,

Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

return uf, linsolve, J, fu, jac_cache, du
end

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false})
::Val{false}; kwargs...)
# NOTE: Scalar `u` assumes scalar output from `f`
uf = JacobianWrapper{false}(f, p)
return uf, nothing, u, nothing, nothing, u
Expand Down
12 changes: 4 additions & 8 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,12 @@ isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
linsolve_kwargs=(;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = f(u, p)
end
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)

λ = convert(eltype(u), alg.damping_initial)
λ_factor = convert(eltype(u), alg.damping_increase_factor)
Expand Down
146 changes: 146 additions & 0 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
"""
LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
Wrapper over algorithms from
[LineSeaches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic
construction of the objective functions for the line search algorithms utilizing automatic
differentiation for fast Vector Jacobian Products.
### Arguments
- `method`: the line search algorithm to use. Defaults to `Static()`, which means that the
step size is fixed to the value of `alpha`.
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
`AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually
installed and loaded
- `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`).
"""
@concrete struct LineSearch
method
autodiff
α
end

function LineSearch(; method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
return LineSearch(method, autodiff, alpha)
end

@concrete mutable struct LineSearchCache
f
ϕ
ϕdϕ
α
ls
end

function LineSearchCache(ls::LineSearch, f, u::Number, p, _, ::Val{false})
eval_f(u, du, α) = eval_f(u - α * du)
eval_f(u) = f(u, p)

ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing,
convert(typeof(u), ls.α), ls)

g(u, fu) = last(value_derivative(Base.Fix2(f, p), u)) * fu

function ϕ(u, du)
function ϕ_internal(α)
u_ = u - α * du
_fu = eval_f(u_)
return dot(_fu, _fu) / 2
end
return ϕ_internal
end

function (u, du)
function dϕ_internal(α)
u_ = u - α * du
_fu = eval_f(u_)
g₀ = g(u_, _fu)
return dot(g₀, -du)
end
return dϕ_internal
end

function ϕdϕ(u, du)
function ϕdϕ_internal(α)
u_ = u - α * du
_fu = eval_f(u_)
g₀ = g(u_, _fu)
return dot(_fu, _fu) / 2, dot(g₀, -du)
end
return ϕdϕ_internal
end

return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
end

function LineSearchCache(ls::LineSearch, f, u, p, fu1, IIP::Val{iip}) where {iip}
fu = iip ? fu1 : nothing
u_ = _mutable_zero(u)

function eval_f(u, du, α)
@. u_ = u - α * du
return eval_f(u_)
end
eval_f(u) = evaluate_f(f, u, p, IIP; fu)

ls.method isa Static && return LineSearchCache(eval_f, nothing, nothing, nothing,
convert(eltype(u), ls.α), ls)

g₀ = _mutable_zero(u)

function g!(u, fu)
op = VecJac((args...) -> f(args..., p), u)
if iip
mul!(g₀, op, fu)
return g₀
else
return op * fu
end
end

function ϕ(u, du)
function ϕ_internal(α)
@. u_ = u - α * du
_fu = eval_f(u_)
return dot(_fu, _fu) / 2
end
return ϕ_internal
end

function (u, du)
function dϕ_internal(α)
@. u_ = u - α * du
_fu = eval_f(u_)
g₀ = g!(u_, _fu)
return dot(g₀, -du)
end
return dϕ_internal
end

function ϕdϕ(u, du)
function ϕdϕ_internal(α)
@. u_ = u - α * du
_fu = eval_f(u_)
g₀ = g!(u_, _fu)
return dot(_fu, _fu) / 2, dot(g₀, -du)
end
return ϕdϕ_internal
end

return LineSearchCache(eval_f, ϕ, dϕ, ϕdϕ, convert(eltype(u), ls.α), ls)
end

function perform_linesearch!(cache::LineSearchCache, u, du)
cache.ls.method isa Static && return (cache.α, cache.f(u, du, cache.α))

ϕ = cache.ϕ(u, du)
= cache.(u, du)
ϕdϕ = cache.ϕdϕ(u, du)

ϕ₀, dϕ₀ = ϕdϕ(zero(eltype(u)))

return cache.ls.method(ϕ, cache.(u, du), cache.ϕdϕ(u, du), cache.α, ϕ₀, dϕ₀)
end
35 changes: 21 additions & 14 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,24 @@ for large-scale and numerically-difficult nonlinear systems.
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
"""
@concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
linesearch
end

concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ

function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method=linesearch)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
end

@concrete mutable struct NewtonRaphsonCache{iip}
Expand All @@ -59,26 +64,23 @@ end
abstol
prob
stats::NLStats
lscache
end

isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
linsolve_kwargs=(;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = _mutable(f(u, p))
end
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)

return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob,
NLStats(1, 0, 0, 0, 0))
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)))
end

function perform_step!(cache::NewtonRaphsonCache{true})
Expand All @@ -89,8 +91,10 @@ function perform_step!(cache::NewtonRaphsonCache{true})
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du),
p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(fu1, u, p)

# Line Search
α, _ = perform_linesearch!(cache.lscache, u, du)
@. u = u - α * du

cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
Expand All @@ -112,7 +116,10 @@ function perform_step!(cache::NewtonRaphsonCache{false})
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation

# Line Search
α, _fu = perform_linesearch!(cache.lscache, u, cache.du)
cache.u = @. u - α * cache.du # `u` might not support mutation
cache.fu1 = f(cache.u, p)

cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true)
Expand Down
11 changes: 3 additions & 8 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,20 +202,15 @@ end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-8, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
linsolve_kwargs=(;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
u_prev = zero(u)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = f(u, p)
end
fu1 = evaluate_f(prob, u)
fu_prev = zero(fu1)

loss = get_loss(fu1)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs)

radius_update_scheme = alg.radius_update_scheme
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
Expand Down
23 changes: 23 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,26 @@ _maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x)
# The shadow allocated for Enzyme needs to be mutable
_maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x)
_maybe_mutable(x, _) = x

# Helper function to get value of `f(u, p)`
function evaluate_f(prob::NonlinearProblem{uType, iip}, u) where {uType, iip}
@unpack f, u0, p = prob
if iip
fu = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu, u, p)
else
fu = _mutable(f(u, p))
end
return fu
end

evaluate_f(cache, u; fu = nothing) = evaluate_f(cache.f, u, cache.p, Val(cache.iip); fu)

function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip}
if iip
f(fu, u, p)
return fu
else
return f(u, p)
end
end
Loading

0 comments on commit 7e26d18

Please sign in to comment.