-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #236 from avik-pal/ap/lsoptim
Impoving NLS Solvers
- Loading branch information
Showing
10 changed files
with
305 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
module NonlinearSolveFastLevenbergMarquardtExt | ||
|
||
using ArrayInterface, NonlinearSolve, SciMLBase | ||
import ConcreteStructs: @concrete | ||
import FastLevenbergMarquardt as FastLM | ||
|
||
NonlinearSolve.extension_loaded(::Val{:FastLevenbergMarquardt}) = true | ||
|
||
function _fast_lm_solver(::FastLevenbergMarquardtSolver{linsolve}, x) where {linsolve} | ||
if linsolve == :cholesky | ||
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x)) | ||
elseif linsolve == :qr | ||
return FastLM.QRSolver(eltype(x), length(x)) | ||
else | ||
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve")) | ||
end | ||
end | ||
|
||
@concrete struct FastLMCache | ||
f! | ||
J! | ||
prob | ||
alg | ||
lmworkspace | ||
solver | ||
kwargs | ||
end | ||
|
||
@concrete struct InplaceFunction{iip} <: Function | ||
f | ||
end | ||
|
||
(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p) | ||
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p)) | ||
|
||
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, | ||
alg::FastLevenbergMarquardtSolver, args...; abstol = 1e-8, reltol = 1e-8, | ||
verbose = false, maxiters = 1000, kwargs...) | ||
iip = SciMLBase.isinplace(prob) | ||
|
||
@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!" | ||
|
||
f! = InplaceFunction{iip}(prob.f) | ||
J! = InplaceFunction{iip}(prob.f.jac) | ||
|
||
resid_prototype = prob.f.resid_prototype === nothing ? | ||
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) : | ||
prob.f.resid_prototype | ||
|
||
J = similar(prob.u0, length(resid_prototype), length(prob.u0)) | ||
|
||
solver = _fast_lm_solver(alg, prob.u0) | ||
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J) | ||
|
||
return FastLMCache(f!, J!, prob, alg, LM, solver, | ||
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept, | ||
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor, | ||
alg.maxfactor, kwargs...)) | ||
end | ||
|
||
function SciMLBase.solve!(cache::FastLMCache) | ||
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!, | ||
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...) | ||
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter) | ||
retcode = info == 1 ? ReturnCode.Success : | ||
(info == -1 ? ReturnCode.MaxIters : ReturnCode.Default) | ||
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx; | ||
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
module NonlinearSolveLeastSquaresOptimExt | ||
|
||
using NonlinearSolve, SciMLBase | ||
import ConcreteStructs: @concrete | ||
import LeastSquaresOptim as LSO | ||
|
||
NonlinearSolve.extension_loaded(::Val{:LeastSquaresOptim}) = true | ||
|
||
function _lso_solver(::LSOptimSolver{alg, linsolve}) where {alg, linsolve} | ||
ls = linsolve == :qr ? LSO.QR() : | ||
(linsolve == :cholesky ? LSO.Cholesky() : | ||
(linsolve == :lsmr ? LSO.LSMR() : nothing)) | ||
if alg == :lm | ||
return LSO.LevenbergMarquardt(ls) | ||
elseif alg == :dogleg | ||
return LSO.Dogleg(ls) | ||
else | ||
throw(ArgumentError("Unknown LeastSquaresOptim Algorithm: $alg")) | ||
end | ||
end | ||
|
||
@concrete struct LeastSquaresOptimCache | ||
prob | ||
alg | ||
allocated_prob | ||
kwargs | ||
end | ||
|
||
@concrete struct FunctionWrapper{iip} | ||
f | ||
p | ||
end | ||
|
||
(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p) | ||
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p)) | ||
|
||
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LSOptimSolver, args...; | ||
abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...) | ||
iip = SciMLBase.isinplace(prob) | ||
|
||
f! = FunctionWrapper{iip}(prob.f, prob.p) | ||
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p) | ||
|
||
resid_prototype = prob.f.resid_prototype === nothing ? | ||
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) : | ||
prob.f.resid_prototype | ||
|
||
lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!, | ||
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype)) | ||
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg)) | ||
|
||
return LeastSquaresOptimCache(prob, alg, allocated_prob, | ||
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose, | ||
kwargs...)) | ||
end | ||
|
||
function SciMLBase.solve!(cache::LeastSquaresOptimCache) | ||
res = LSO.optimize!(cache.allocated_prob; cache.kwargs...) | ||
maxiters = cache.kwargs[:iterations] | ||
retcode = res.x_converged || res.f_converged || res.g_converged ? ReturnCode.Success : | ||
(res.iterations ≥ maxiters ? ReturnCode.MaxIters : | ||
ReturnCode.ConvergenceFailure) | ||
stats = SciMLBase.NLStats(res.f_calls, res.g_calls, -1, -1, res.iterations) | ||
return SciMLBase.build_solution(cache.prob, cache.alg, res.minimizer, res.ssr / 2; | ||
retcode, original = res, stats) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Define Algorithms extended via extensions | ||
""" | ||
LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) | ||
Wrapper over [LeastSquaresOptim.jl](https://github.com/matthieugomez/LeastSquaresOptim.jl) for solving | ||
`NonlinearLeastSquaresProblem`. | ||
## Arguments: | ||
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`. | ||
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If | ||
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based | ||
on the Jacobian structure. | ||
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`. | ||
!!! note | ||
This algorithm is only available if `LeastSquaresOptim.jl` is installed. | ||
""" | ||
struct LSOptimSolver{alg, linsolve} <: AbstractNonlinearSolveAlgorithm | ||
autodiff::Symbol | ||
end | ||
|
||
function LSOptimSolver(alg = :lm; linsolve = nothing, autodiff::Symbol = :central) | ||
@assert alg in (:lm, :dogleg) | ||
@assert linsolve === nothing || linsolve in (:qr, :cholesky, :lsmr) | ||
@assert autodiff in (:central, :forward) | ||
|
||
if !extension_loaded(Val(:LeastSquaresOptim)) | ||
@warn "LeastSquaresOptim.jl is not loaded! It needs to be explicitly loaded \ | ||
before `solve(prob, LSOptimSolver())` is called." | ||
end | ||
|
||
return LSOptimSolver{alg, linsolve}(autodiff) | ||
end | ||
|
||
""" | ||
FastLevenbergMarquardtSolver(linsolve = :cholesky) | ||
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving | ||
`NonlinearLeastSquaresProblem`. | ||
!!! warning | ||
This is not really the fastest solver. It is called that since the original package | ||
is called "Fast". `LevenbergMarquardt()` is almost always a better choice. | ||
!!! warning | ||
This algorithm requires the jacobian function to be provided! | ||
## Arguments: | ||
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`. | ||
!!! note | ||
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. | ||
""" | ||
@concrete struct FastLevenbergMarquardtSolver{linsolve} <: AbstractNonlinearSolveAlgorithm | ||
factor | ||
factoraccept | ||
factorreject | ||
factorupdate::Symbol | ||
minscale | ||
maxscale | ||
minfactor | ||
maxfactor | ||
end | ||
|
||
function FastLevenbergMarquardtSolver(linsolve::Symbol = :cholesky; factor = 1e-6, | ||
factoraccept = 13.0, factorreject = 3.0, factorupdate = :marquardt, | ||
minscale = 1e-12, maxscale = 1e16, minfactor = 1e-28, maxfactor = 1e32) | ||
@assert linsolve in (:qr, :cholesky) | ||
@assert factorupdate in (:marquardt, :nielson) | ||
|
||
if !extension_loaded(Val(:FastLevenbergMarquardt)) | ||
@warn "FastLevenbergMarquardt.jl is not loaded! It needs to be explicitly loaded \ | ||
before `solve(prob, FastLevenbergMarquardtSolver())` is called." | ||
end | ||
|
||
return FastLevenbergMarquardtSolver{linsolve}(factor, factoraccept, factorreject, | ||
factorupdate, minscale, maxscale, minfactor, maxfactor) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.