Skip to content

Commit

Permalink
Add Gauss Newton and make LM work for NLS Problems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 13, 2023
1 parent 5b46c2d commit 22a8fbe
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 52 deletions.
23 changes: 22 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,37 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end

abstract type AbstractNonlinearSolveCache{iip} end

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...)
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters

Check warning on line 48 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L48

Added line #L48 was not covered by tests
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end

include("utils.jl")
include("raphson.jl")
include("trustRegion.jl")
include("levenberg.jl")
include("gaussnewton.jl")
include("jacobian.jl")
include("ad.jl")

Expand Down Expand Up @@ -67,6 +88,6 @@ end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt
export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton

end # module
160 changes: 160 additions & 0 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.
!!! note
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
is a more general extension of `Gauss-Newton` Method.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
`AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
- `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is used,
then the Jacobian will not be constructed and instead direct Jacobian-vector products
`J*v` are computed using forward-mode automatic differentiation or finite differencing
tricks (without ever constructing the Jacobian). However, if the Jacobian is still needed,
for example for a preconditioner, `concrete_jac = true` can be passed in order to force
the construction of the Jacobian.
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) used for the
linear solves within the Newton method. Defaults to `nothing`, which means it uses the
LinearSolve.jl default algorithm choice. For more information on available algorithm
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
end

function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
fu1
fu2
fu_new
du
p
uf
linsolve
J
JᵀJ
Jᵀf
jac_cache
force_stop
maxiters::Int
internalnorm
retcode::ReturnCode.T
abstol
prob
stats::NLStats
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GaussNewton,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
if iip
fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
f(fu1, u, p)
else
fu1 = f(u, p)
end
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))

JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
Jᵀf = zero(u)

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
prob, NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::GaussNewtonCache{true})
@unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
jacobian!!(J, cache)
mul!(JᵀJ, J', J)
mul!(Jᵀf, J', fu1)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(cache.force_stop = true)
cache.fu1 .= cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache

cache.J = jacobian!!(cache.J, cache)
cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1
# u = u - J \ fu
if linsolve === nothing
cache.du = fu1 / cache.J

Check warning on line 122 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L122

Added line #L122 was not covered by tests
else
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

(cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
cache.internalnorm(cache.fu_new) < cache.abstol) &&
(cache.force_stop = true)
cache.fu1 = cache.fu_new
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
return nothing
end

function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,

Check warning on line 142 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L142

Added line #L142 was not covered by tests
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu1, cache.u, p)

Check warning on line 147 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L144-L147

Added lines #L144 - L147 were not covered by tests
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu1 = cache.f(cache.u, p)

Check warning on line 151 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L150-L151

Added lines #L150 - L151 were not covered by tests
end
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache

Check warning on line 159 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L153-L159

Added lines #L153 - L159 were not covered by tests
end
39 changes: 12 additions & 27 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end

@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType}
@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} <:
AbstractNonlinearSolveCache{iip}
f
alg
u::uType
Expand Down Expand Up @@ -134,12 +135,12 @@ end
loss_old::lossType
make_new_J::Bool
fu_tmp
u_tmp
Jv
mat_tmp::jType
stats::NLStats
end

isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
Expand Down Expand Up @@ -171,21 +172,21 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
end

loss = internalnorm(fu1)
JᵀJ = zero(J)
JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
v = zero(u)
a = zero(u)
tmp_vec = zero(u)
v_old = zero(u)
δ = zero(u)
make_new_J = true
fu_tmp = zero(fu1)
mat_tmp = zero(J)
mat_tmp = zero(JᵀJ)

return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
mat_tmp, NLStats(1, 0, 0, 0, 0))
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
end

function perform_step!(cache::LevenbergMarquardtCache{true})
Expand All @@ -205,10 +206,10 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache

# Usual Levenberg-Marquardt step ("velocity").
# The following lines do: cache.v = -cache.mat_tmp \ cache.fu_tmp
mul!(cache.fu_tmp, J', fu1)
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
mul!(cache.u_tmp, J', fu1)
@. cache.mat_tmp = JᵀJ + λ * DᵀD
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.v = -cache.du
Expand All @@ -218,8 +219,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
f(cache.fu_tmp, u .+ h .* v, p)

# The following lines do: cache.a = -J \ cache.fu_tmp
mul!(cache.du, J, v)
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.du)
mul!(cache.Jv, J, v)
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(cache.fu_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
Expand Down Expand Up @@ -317,19 +318,3 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.λ_factor = cache.damping_increase_factor
return nothing
end

function SciMLBase.solve!(cache::LevenbergMarquardtCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end
22 changes: 1 addition & 21 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ for large-scale and numerically-difficult nonlinear systems.
precs
end

concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ

function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
end

@concrete mutable struct NewtonRaphsonCache{iip}
@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
Expand All @@ -61,8 +59,6 @@ end
stats::NLStats
end

isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
Expand Down Expand Up @@ -123,22 +119,6 @@ function perform_step!(cache::NewtonRaphsonCache{false})
return nothing
end

function SciMLBase.solve!(cache::NewtonRaphsonCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
perform_step!(cache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
cache.retcode, cache.stats)
end

function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
Expand Down
5 changes: 2 additions & 3 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU
expand_threshold, shrink_factor, expand_factor, max_shrink_times)
end

@concrete mutable struct TrustRegionCache{iip, trustType, floatType}
@concrete mutable struct TrustRegionCache{iip, trustType, floatType} <:
AbstractNonlinearSolveCache{iip}
f
alg
u_prev
Expand Down Expand Up @@ -303,8 +304,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
NLStats(1, 0, 0, 0, 0))
end

isinplace(::TrustRegionCache{iip}) where {iip} = iip

function perform_step!(cache::TrustRegionCache{true})
@unpack make_new_J, J, fu, f, u, p, u_tmp, alg, linsolve = cache
if cache.make_new_J
Expand Down
Loading

0 comments on commit 22a8fbe

Please sign in to comment.