Skip to content

Commit

Permalink
Merge pull request #236 from avik-pal/ap/lsoptim
Browse files Browse the repository at this point in the history
Impoving NLS Solvers
  • Loading branch information
ChrisRackauckas authored Oct 16, 2023
2 parents a6af39c + 1c19fa7 commit 1a0e5ee
Show file tree
Hide file tree
Showing 10 changed files with 305 additions and 47 deletions.
16 changes: 14 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.2.1"
version = "2.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -24,15 +24,25 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"

[extensions]
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"

[compat]
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
ConcreteStructs = "0.2"
DiffEqBase = "6.130"
EnumX = "1"
Enzyme = "0.11"
FastLevenbergMarquardt = "0.1"
FiniteDiff = "2"
ForwardDiff = "0.10.3"
LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearSolve = "2"
NonlinearProblemLibrary = "0.1"
Expand All @@ -50,7 +60,9 @@ julia = "1.9"
[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
Expand All @@ -64,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary"]
test = ["Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt"]
71 changes: 71 additions & 0 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
module NonlinearSolveFastLevenbergMarquardtExt

using ArrayInterface, NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import FastLevenbergMarquardt as FastLM

NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true

function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve}
if linsolve == :cholesky
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
elseif linsolve == :qr
return FastLM.QRSolver(eltype(x), length(x))
else
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
end
end

@concrete struct FastLMCache
f!
J!
prob
alg
lmworkspace
solver
kwargs
end

@concrete struct InplaceFunction{iip} <: Function
f
end

(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8,
verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"

f! = InplaceFunction{iip}(prob.f)
J! = InplaceFunction{iip}(prob.f.jac)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
prob.f.resid_prototype

J = similar(prob.u0, length(resid_prototype), length(prob.u0))

solver = _fast_lm_solver(alg, prob.u0)
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)

return FastLMCache(f!, J!, prob, alg, LM, solver,
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
alg.maxfactor, kwargs...))
end

function SciMLBase.solve!(cache::FastLMCache)
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
retcode = info == 1 ? ReturnCode.Success :
(info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
end

end
68 changes: 68 additions & 0 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
module NonlinearSolveLeastSquaresOptimExt

using NonlinearSolve, SciMLBase
import ConcreteStructs: @concrete
import LeastSquaresOptim as LSO

NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true

function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve}
ls = linsolve == :qr ? LSO.QR() :
(linsolve == :cholesky ? LSO.Cholesky() :
(linsolve == :lsmr ? LSO.LSMR() : nothing))
if alg == :lm
return LSO.LevenbergMarquardt(ls)
elseif alg == :dogleg
return LSO.Dogleg(ls)
else
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg"))
end
end

@concrete struct LeastSquaresOptimCache
prob
alg
allocated_prob
kwargs
end

@concrete struct FunctionWrapper{iip}
f
p
end

(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...;
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)

f! = FunctionWrapper{iip}(prob.f, prob.p)
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
prob.f.resid_prototype

lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!,
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

return LeastSquaresOptimCache(prob, alg, allocated_prob,
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
kwargs...))
end

function SciMLBase.solve!(cache::LeastSquaresOptimCache)
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...)
maxiters = cache.kwargs[:iterations]
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success :
(res.iterations maxiters ? ReturnCode.MaxIters :
ReturnCode.ConvergenceFailure)
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations)
return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2;
retcode, original = res, stats)
end

end
4 changes: 4 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm

abstract type AbstractNonlinearSolveCache{iip} end

extension_loaded(::Val) = false

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
Expand Down Expand Up @@ -60,6 +62,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
end

include("utils.jl")
include("algorithms.jl")
include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
Expand Down Expand Up @@ -93,6 +96,7 @@ end
export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
export LSOptimSolver, FastLevenbergMarquardtSolver

export LineSearch

Expand Down
80 changes: 80 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Define Algorithms extended via extensions
"""
LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving
`NonlinearLeastSquaresProblem`.
## Arguments:
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
on the Jacobian structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
!!! note
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
autodiff::Symbol
end

function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central)
@assert alg in (:lm, :dogleg)
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr)
@assert autodiff in (:central, :forward)

if !extension_loaded(Val(:LeastSquaresOptim))
@warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \
before `solve(prob, LSOptimSolver())` is called."
end

return LSOptimSolver{alg, linsolve}(autodiff)
end

"""
FastLevenbergMarquardtSolver(linsolve = :cholesky)
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
`NonlinearLeastSquaresProblem`.
!!! warning
This is not really the fastest solver. It is called that since the original package
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
!!! warning
This algorithm requires the jacobian function to be provided!
## Arguments:
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm
factor
factoraccept
factorreject
factorupdate::Symbol
minscale
maxscale
minfactor
maxfactor
end

function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6,
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt,
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32)
@assert linsolve in (:qr, :cholesky)
@assert factorupdate in (:marquardt, :nielson)

if !extension_loaded(Val(:FastLevenbergMarquardt))
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \
before `solve(prob, FastLevenbergMarquardtSolver())` is called."
end

return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject,
factorupdate, minscale, maxscale, minfactor, maxfactor)
end
18 changes: 9 additions & 9 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
adkwargs...)
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
Expand Down Expand Up @@ -41,7 +41,7 @@ for large-scale and numerically-difficult nonlinear least squares problems.
precs
end

function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
Expand Down Expand Up @@ -93,12 +93,12 @@ end
function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)
mul!(JᵀJ, J', J)
mul!(Jᵀf, J', fu1)
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)
Expand All @@ -125,8 +125,8 @@ function perform_step!(cache::GaussNewtonCache{false})
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
linu = _vec(cache.du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
Expand Down
22 changes: 15 additions & 7 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
Expand All @@ -74,7 +74,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
jac_cache = nothing
end

J = if !(linsolve_needs_jac || alg_wants_jac)
# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
else
Expand All @@ -93,14 +95,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
Jᵀfu = J' * fu
end

linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
u0 = _vec(du))
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

weight = similar(u)
recursivefill!(weight, true)

Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
nothing, nothing, nothing, nothing, nothing)..., weight)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

Expand All @@ -114,9 +116,15 @@ __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
__get_nonsparse_ad(ad) = ad

__init_JᵀJ(J::Number) = zero(J)
__init_JᵀJ(J::AbstractArray) = zeros(eltype(J), size(J, 2), size(J, 2))
__init_JᵀJ(J::AbstractArray) = J' * J
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)

__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
# LinearSolve with `nothing` doesn't dispatch correctly here
__maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
Expand Down
Loading

0 comments on commit 1a0e5ee

Please sign in to comment.