Skip to content

Commit

Permalink
NLLS Poly Algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 6, 2023
1 parent f2d3285 commit 549c857
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 15 deletions.
13 changes: 7 additions & 6 deletions ext/NonlinearSolveFastLevenbergMarquardtExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,24 @@ end
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
alg::FastLevenbergMarquardtJL, args...; abstol = 1e-8, reltol = 1e-8,
verbose = false, maxiters = 1000, kwargs...)
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = 1e-8,
reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)
u0 = alias_u0 ? prob.u0 : deepcopy(prob.u0)

Check warning on line 38 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L38

Added line #L38 was not covered by tests

@assert prob.f.jac!==nothing "FastLevenbergMarquardt requires a Jacobian!"

f! = InplaceFunction{iip}(prob.f)
J! = InplaceFunction{iip}(prob.f.jac)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
(!iip ? prob.f(u0, prob.p) : zeros(u0)) :
prob.f.resid_prototype

J = similar(prob.u0, length(resid_prototype), length(prob.u0))
J = similar(u0, length(resid_prototype), length(u0))

Check warning on line 49 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L49

Added line #L49 was not covered by tests

solver = _fast_lm_solver(alg, prob.u0)
LM = FastLM.LMWorkspace(prob.u0, resid_prototype, J)
solver = _fast_lm_solver(alg, u0)
LM = FastLM.LMWorkspace(u0, resid_prototype, J)

Check warning on line 52 in ext/NonlinearSolveFastLevenbergMarquardtExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveFastLevenbergMarquardtExt.jl#L51-L52

Added lines #L51 - L52 were not covered by tests

return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
Expand Down
8 changes: 5 additions & 3 deletions ext/NonlinearSolveLeastSquaresOptimExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,19 @@ end
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL,
args...; abstol = 1e-8, reltol = 1e-8, verbose = false, maxiters = 1000, kwargs...)
args...; alias_u0 = false, abstol = 1e-8, reltol = 1e-8, verbose = false,
maxiters = 1000, kwargs...)
iip = SciMLBase.isinplace(prob)
u = alias_u0 ? prob.u0 : deepcopy(prob.u0)

Check warning on line 39 in ext/NonlinearSolveLeastSquaresOptimExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveLeastSquaresOptimExt.jl#L39

Added line #L39 was not covered by tests

f! = FunctionWrapper{iip}(prob.f, prob.p)
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)

resid_prototype = prob.f.resid_prototype === nothing ?
(!iip ? prob.f(prob.u0, prob.p) : zeros(prob.u0)) :
(!iip ? prob.f(u, prob.p) : zeros(u)) :
prob.f.resid_prototype

lsoprob = LSO.LeastSquaresProblem(; x = prob.u0, f!, y = resid_prototype, g!,
lsoprob = LSO.LeastSquaresProblem(; x = u, f!, y = resid_prototype, g!,

Check warning on line 48 in ext/NonlinearSolveLeastSquaresOptimExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/NonlinearSolveLeastSquaresOptimExt.jl#L48

Added line #L48 was not covered by tests
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))

Expand Down
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ export RadiusUpdateSchemes
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient,
GeneralBroyden, GeneralKlement, LimitedMemoryBroyden
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg
export NonlinearSolvePolyAlgorithm,
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg

export LineSearch, LiFukushimaLineSearch

Expand Down
40 changes: 38 additions & 2 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,42 @@ function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothi
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS))
end

"""
FastShortcutNLLSPolyalg(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
for more performance and then tries more robust techniques if the faster ones fail.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
`AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
- `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is used,
then the Jacobian will not be constructed and instead direct Jacobian-vector products
`J*v` are computed using forward-mode automatic differentiation or finite differencing
tricks (without ever constructing the Jacobian). However, if the Jacobian is still needed,
for example for a preconditioner, `concrete_jac = true` can be passed in order to force
the construction of the Jacobian.
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) used for the
linear solves within the Newton method. Defaults to `nothing`, which means it uses the
LinearSolve.jl default algorithm choice. For more information on available algorithm
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
"""
function FastShortcutNLLSPolyalg(; concrete_jac = nothing, linsolve = nothing,

Check warning on line 280 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L280

Added line #L280 was not covered by tests
precs = DEFAULT_PRECS, adkwargs...)
algs = (GaussNewton(; concrete_jac, linsolve, precs, adkwargs...),

Check warning on line 282 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L282

Added line #L282 was not covered by tests
GaussNewton(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
adkwargs...),
LevenbergMarquardt(; concrete_jac, linsolve, precs, adkwargs...))
return NonlinearSolvePolyAlgorithm(algs, Val(:NLLS))

Check warning on line 286 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L286

Added line #L286 was not covered by tests
end

## Defaults

function SciMLBase.__init(prob::NonlinearProblem, ::Nothing, args...; kwargs...)
Expand All @@ -263,10 +299,10 @@ end
# FIXME: We default to using LM currently. But once we have line searches for GN implemented
# we should default to a polyalgorithm.
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, ::Nothing, args...; kwargs...)
return SciMLBase.__init(prob, LevenbergMarquardt(), args...; kwargs...)
return SciMLBase.__init(prob, FastShortcutNLLSPolyalg(), args...; kwargs...)

Check warning on line 302 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L301-L302

Added lines #L301 - L302 were not covered by tests
end

function SciMLBase.__solve(prob::NonlinearLeastSquaresProblem, ::Nothing, args...;

Check warning on line 305 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L305

Added line #L305 was not covered by tests
kwargs...)
return SciMLBase.__solve(prob, LevenbergMarquardt(), args...; kwargs...)
return SciMLBase.__solve(prob, FastShortcutNLLSPolyalg(), args...; kwargs...)

Check warning on line 307 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L307

Added line #L307 was not covered by tests
end
6 changes: 3 additions & 3 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ append!(solvers,
LevenbergMarquardt(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
nothing,
nothing,
])

for prob in nlls_problems, solver in solvers
Expand All @@ -46,7 +46,8 @@ for prob in nlls_problems, solver in solvers
end

function jac!(J, θ, p)
ForwardDiff.jacobian!(J, resid -> loss_function(resid, θ, p), θ)
resid = zeros(length(p))
ForwardDiff.jacobian!(J, (resid, θ) -> loss_function(resid, θ, p), resid, θ)
return J
end

Expand All @@ -57,6 +58,5 @@ solvers = [FastLevenbergMarquardtJL(:cholesky), FastLevenbergMarquardtJL(:qr)]

for solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6
end

0 comments on commit 549c857

Please sign in to comment.