From 22a8fbe7fd81f5721b3d6f48748bbb0fcb16c200 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 13 Sep 2023 11:54:38 -0400
Subject: [PATCH] Add Gauss Newton and make LM work for NLS Problems

---
 src/NonlinearSolve.jl |  23 +++++-
 src/gaussnewton.jl    | 160 ++++++++++++++++++++++++++++++++++++++++++
 src/levenberg.jl      |  39 ++++------
 src/raphson.jl        |  22 +-----
 src/trustRegion.jl    |   5 +-
 test/nonlinearls.jl   |  43 ++++++++++++
 test/runtests.jl      |   1 +
 7 files changed, 241 insertions(+), 52 deletions(-)
 create mode 100644 src/gaussnewton.jl
 create mode 100644 test/nonlinearls.jl

diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl
index 2f851faa3..943a06378 100644
--- a/src/NonlinearSolve.jl
+++ b/src/NonlinearSolve.jl
@@ -28,16 +28,37 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
 abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
 abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
 
+abstract type AbstractNonlinearSolveCache{iip} end
+
+isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
+
 function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
     args...; kwargs...)
     cache = init(prob, alg, args...; kwargs...)
     return solve!(cache)
 end
 
+function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
+    while !cache.force_stop && cache.stats.nsteps < cache.maxiters
+        perform_step!(cache)
+        cache.stats.nsteps += 1
+    end
+
+    if cache.stats.nsteps == cache.maxiters
+        cache.retcode = ReturnCode.MaxIters
+    else
+        cache.retcode = ReturnCode.Success
+    end
+
+    return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
+        cache.retcode, cache.stats)
+end
+
 include("utils.jl")
 include("raphson.jl")
 include("trustRegion.jl")
 include("levenberg.jl")
+include("gaussnewton.jl")
 include("jacobian.jl")
 include("ad.jl")
 
@@ -67,6 +88,6 @@ end
 
 export RadiusUpdateSchemes
 
-export NewtonRaphson, TrustRegion, LevenbergMarquardt
+export NewtonRaphson, TrustRegion, LevenbergMarquardt, GaussNewton
 
 end # module
diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl
new file mode 100644
index 000000000..88d0bb2a4
--- /dev/null
+++ b/src/gaussnewton.jl
@@ -0,0 +1,160 @@
+"""
+    GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
+        precs = DEFAULT_PRECS, adkwargs...)
+
+An advanced Gauss-Newton implementation with support for efficient handling of sparse
+matrices via colored automatic differentiation and preconditioned linear solvers. Designed
+for large-scale and numerically-difficult nonlinear least squares problems.
+
+!!! note
+    In most practical situations, users should prefer `LevenbergMarquardt` instead, as it
+    is a more general (damped) extension of the Gauss-Newton method.
+
+### Keyword Arguments
+
+  - `autodiff`: determines the backend used for the Jacobian. Note that this argument is
+    ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
+    `AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
+  - `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is used,
+    then the Jacobian will not be constructed and instead direct Jacobian-vector products
+    `J*v` are computed using forward-mode automatic differentiation or finite differencing
+    tricks (without ever constructing the Jacobian). However, if the Jacobian is still needed,
+    for example for a preconditioner, `concrete_jac = true` can be passed in order to force
+    the construction of the Jacobian.
+  - `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) algorithm used
+    for the linear solves within the Gauss-Newton method. Defaults to
+    `NormalCholeskyFactorization()`. For more information on available algorithm choices,
+    see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
+  - `precs`: the choice of preconditioners for the linear solver. Defaults to using no
+    preconditioners. For more information on specifying preconditioners for LinearSolve
+    algorithms, consult the
+    [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
+"""
+@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
+    ad::AD
+    linsolve
+    precs
+end
+
+function GaussNewton(; concrete_jac = nothing, linsolve = NormalCholeskyFactorization(),
+    precs = DEFAULT_PRECS, adkwargs...)
+    ad = default_adargs_to_adtype(; adkwargs...)
+    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
+end
+
+@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
+    f
+    alg
+    u
+    fu1
+    fu2
+    fu_new
+    du
+    p
+    uf
+    linsolve
+    J
+    JᵀJ
+    Jᵀf
+    jac_cache
+    force_stop
+    maxiters::Int
+    internalnorm
+    retcode::ReturnCode.T
+    abstol
+    prob
+    stats::NLStats
+end
+
+function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GaussNewton,
+    args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
+    kwargs...) where {uType, iip}
+    @unpack f, u0, p = prob
+    u = alias_u0 ? u0 : deepcopy(u0)
+    if iip
+        fu1 = f.resid_prototype === nothing ? zero(u) : f.resid_prototype
+        f(fu1, u, p)
+    else
+        fu1 = f(u, p)
+    end
+    uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))
+
+    JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
+    Jᵀf = zero(u)
+
+    return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,
+        JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
+        prob, NLStats(1, 0, 0, 0, 0))
+end
+
+function perform_step!(cache::GaussNewtonCache{true})
+    @unpack u, fu1, f, p, alg, J, JᵀJ, Jᵀf, linsolve, du = cache
+    jacobian!!(J, cache)
+    mul!(JᵀJ, J', J)
+    mul!(Jᵀf, J', fu1)
+
+    # Gauss-Newton step via the normal equations: (JᵀJ) du = Jᵀ fu1, then u = u - du
+    linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
+        p, reltol = cache.abstol)
+    cache.linsolve = linres.cache
+    @. u = u - du
+    f(cache.fu_new, u, p)
+
+    (cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
+     cache.internalnorm(cache.fu_new) < cache.abstol) &&
+        (cache.force_stop = true)
+    cache.fu1 .= cache.fu_new
+    cache.stats.nf += 1
+    cache.stats.njacs += 1
+    cache.stats.nsolve += 1
+    cache.stats.nfactors += 1
+    return nothing
+end
+
+function perform_step!(cache::GaussNewtonCache{false})
+    @unpack u, fu1, f, p, alg, linsolve = cache
+
+    cache.J = jacobian!!(cache.J, cache)
+    cache.JᵀJ = cache.J' * cache.J
+    cache.Jᵀf = cache.J' * fu1
+    # Gauss-Newton step: solve (JᵀJ) du = Jᵀ fu1 (or du = fu1 / J in the scalar case)
+    if linsolve === nothing
+        cache.du = fu1 / cache.J
+    else
+        linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
+            linu = _vec(cache.du), p, reltol = cache.abstol)
+        cache.linsolve = linres.cache
+    end
+    cache.u = @. u - cache.du  # `u` might not support mutation
+    cache.fu_new = f(cache.u, p)
+
+    (cache.internalnorm(cache.fu_new .- cache.fu1) < cache.abstol ||
+     cache.internalnorm(cache.fu_new) < cache.abstol) &&
+        (cache.force_stop = true)
+    cache.fu1 = cache.fu_new
+    cache.stats.nf += 1
+    cache.stats.njacs += 1
+    cache.stats.nsolve += 1
+    cache.stats.nfactors += 1
+    return nothing
+end
+
+function SciMLBase.reinit!(cache::GaussNewtonCache{iip}, u0 = cache.u; p = cache.p,
+    abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
+    cache.p = p
+    if iip
+        recursivecopy!(cache.u, u0)
+        cache.f(cache.fu1, cache.u, p)
+    else
+        # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
+        cache.u = u0
+        cache.fu1 = cache.f(cache.u, p)
+    end
+    cache.abstol = abstol
+    cache.maxiters = maxiters
+    cache.stats.nf = 1
+    cache.stats.nsteps = 1
+    cache.force_stop = false
+    cache.retcode = ReturnCode.Default
+    return cache
+end
diff --git a/src/levenberg.jl b/src/levenberg.jl
index 6265eba3f..6845f459c 100644
--- a/src/levenberg.jl
+++ b/src/levenberg.jl
@@ -97,7 +97,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
         finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
 end
 
-@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType}
+@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} <:
+                         AbstractNonlinearSolveCache{iip}
     f
     alg
     u::uType
@@ -134,12 +135,12 @@ end
     loss_old::lossType
     make_new_J::Bool
     fu_tmp
+    u_tmp
+    Jv
     mat_tmp::jType
     stats::NLStats
 end
 
-isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip
-
 function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
     args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
     kwargs...) where {uType, iip}
@@ -171,7 +172,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
     end
     loss = internalnorm(fu1)
 
-    JᵀJ = zero(J)
+    JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
     v = zero(u)
     a = zero(u)
     tmp_vec = zero(u)
@@ -179,13 +180,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
     δ = zero(u)
     make_new_J = true
     fu_tmp = zero(fu1)
-    mat_tmp = zero(J)
+    mat_tmp = zero(JᵀJ)
 
     return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
         jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
         JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
         b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
-        mat_tmp, NLStats(1, 0, 0, 0, 0))
+        zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
 end
 
 function perform_step!(cache::LevenbergMarquardtCache{true})
@@ -205,10 +206,10 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
     @unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
 
     # Usual Levenberg-Marquardt step ("velocity").
-    # The following lines do: cache.v = -cache.mat_tmp \ cache.fu_tmp
-    mul!(cache.fu_tmp, J', fu1)
+    # The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
+    mul!(cache.u_tmp, J', fu1)
     @. cache.mat_tmp = JᵀJ + λ * DᵀD
-    linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
+    linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
         linu = _vec(cache.du), p = p, reltol = cache.abstol)
     cache.linsolve = linres.cache
     @. cache.v = -cache.du
@@ -218,8 +219,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
     f(cache.fu_tmp, u .+ h .* v, p)
 
     # The following lines do: cache.a = -J \ cache.fu_tmp
-    mul!(cache.du, J, v)
-    @. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.du)
+    mul!(cache.Jv, J, v)
+    @. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
     linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(cache.fu_tmp),
         linu = _vec(cache.du), p = p, reltol = cache.abstol)
     cache.linsolve = linres.cache
@@ -317,19 +318,3 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
     cache.λ_factor = cache.damping_increase_factor
     return nothing
 end
-
-function SciMLBase.solve!(cache::LevenbergMarquardtCache)
-    while !cache.force_stop && cache.stats.nsteps < cache.maxiters
-        perform_step!(cache)
-        cache.stats.nsteps += 1
-    end
-
-    if cache.stats.nsteps == cache.maxiters
-        cache.retcode = ReturnCode.MaxIters
-    else
-        cache.retcode = ReturnCode.Success
-    end
-
-    return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
-        cache.retcode, cache.stats)
-end
diff --git a/src/raphson.jl b/src/raphson.jl
index 33d12c4ba..266564e59 100644
--- a/src/raphson.jl
+++ b/src/raphson.jl
@@ -32,15 +32,13 @@ for large-scale and numerically-difficult nonlinear systems.
     precs
 end
 
-concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ
-
 function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
     adkwargs...)
     ad = default_adargs_to_adtype(; adkwargs...)
     return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
 end
 
-@concrete mutable struct NewtonRaphsonCache{iip}
+@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
     f
     alg
     u
@@ -61,8 +59,6 @@ end
     stats::NLStats
 end
 
-isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip
-
 function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...;
     alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
     kwargs...) where {uType, iip}
@@ -123,22 +119,6 @@ function perform_step!(cache::NewtonRaphsonCache{false})
     return nothing
 end
 
-function SciMLBase.solve!(cache::NewtonRaphsonCache)
-    while !cache.force_stop && cache.stats.nsteps < cache.maxiters
-        perform_step!(cache)
-        cache.stats.nsteps += 1
-    end
-
-    if cache.stats.nsteps == cache.maxiters
-        cache.retcode = ReturnCode.MaxIters
-    else
-        cache.retcode = ReturnCode.Success
-    end
-
-    return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
-        cache.retcode, cache.stats)
-end
-
 function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
     abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
     cache.p = p
diff --git a/src/trustRegion.jl b/src/trustRegion.jl
index 41ccb994e..6124e1f3b 100644
--- a/src/trustRegion.jl
+++ b/src/trustRegion.jl
@@ -155,7 +155,8 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU
         expand_threshold, shrink_factor, expand_factor, max_shrink_times)
 end
 
-@concrete mutable struct TrustRegionCache{iip, trustType, floatType}
+@concrete mutable struct TrustRegionCache{iip, trustType, floatType} <:
+                         AbstractNonlinearSolveCache{iip}
     f
     alg
     u_prev
@@ -303,8 +304,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
         NLStats(1, 0, 0, 0, 0))
 end
 
-isinplace(::TrustRegionCache{iip}) where {iip} = iip
-
 function perform_step!(cache::TrustRegionCache{true})
     @unpack make_new_J, J, fu, f, u, p, u_tmp, alg, linsolve = cache
     if cache.make_new_J
diff --git a/test/nonlinearls.jl b/test/nonlinearls.jl
new file mode 100644
index 000000000..d6f80367a
--- /dev/null
+++ b/test/nonlinearls.jl
@@ -0,0 +1,43 @@
+using NonlinearSolve, LinearSolve, LinearAlgebra, Test, Random
+
+true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
+true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))
+
+θ_true = [1.0, 0.1, 2.0, 0.5]
+
+x = [-1.0, -0.5, 0.0, 0.5, 1.0]
+
+y_target = true_function(x, θ_true)
+
+function loss_function(θ, p)
+    ŷ = true_function(p, θ)
+    return abs2.(ŷ .- y_target)
+end
+
+function loss_function(resid, θ, p)
+    true_function(resid, p, θ)
+    resid .= abs2.(resid .- y_target)
+    return resid
+end
+
+θ_init = θ_true .+ randn!(similar(θ_true)) * 0.1
+prob_oop = NonlinearProblem{false}(loss_function, θ_init, x)
+prob_iip = NonlinearProblem(NonlinearFunction(loss_function;
+        resid_prototype = zero(y_target)), θ_init, x)
+
+sol = solve(prob_oop, GaussNewton(); maxiters = 1000, abstol = 1e-8)
+@test SciMLBase.successful_retcode(sol)
+@test norm(sol.resid) < 1e-6
+
+sol = solve(prob_iip, GaussNewton(); maxiters = 1000, abstol = 1e-8)
+@test SciMLBase.successful_retcode(sol)
+@test norm(sol.resid) < 1e-6
+
+sol = solve(prob_oop, LevenbergMarquardt(); maxiters = 1000, abstol = 1e-8)
+@test SciMLBase.successful_retcode(sol)
+@test norm(sol.resid) < 1e-6
+
+sol = solve(prob_iip, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization());
+    maxiters = 1000, abstol = 1e-8)
+@test SciMLBase.successful_retcode(sol)
+@test norm(sol.resid) < 1e-6
diff --git a/test/runtests.jl b/test/runtests.jl
index a84fc3cb1..a3061dcc4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,6 +15,7 @@ end
 if GROUP == "All" || GROUP == "Core"
     @time @safetestset "Basic Tests + Some AD" include("basictests.jl")
     @time @safetestset "Sparsity Tests" include("sparse.jl")
+    @time @safetestset "Nonlinear Least Squares" include("nonlinearls.jl")
 end
 
 if GROUP == "GPU"
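
Usage sketch for reviewers: the snippet below exercises the new `GaussNewton` solver and the extended `LevenbergMarquardt` on a small curve-fit residual, following the same pattern as test/nonlinearls.jl in this patch. The model, data, and variable names (`model`, `x_data`, `y_obs`, `residual`, `residual!`) are illustrative assumptions for this sketch, not part of the patch.

    using NonlinearSolve

    # Illustrative model and synthetic data (assumed names, not from the patch).
    model(x, θ) = @. θ[1] * exp(θ[2] * x)
    x_data = collect(0.0:0.25:2.0)
    y_obs = model(x_data, [2.0, -0.8])

    # Out-of-place residual; the solvers drive it toward zero in the least-squares sense.
    residual(θ, p) = model(p, θ) .- y_obs
    prob = NonlinearProblem{false}(residual, [1.0, 0.0], x_data)

    # GaussNewton solves the normal equations (JᵀJ) du = Jᵀ fu at every iteration.
    sol_gn = solve(prob, GaussNewton(); maxiters = 100, abstol = 1e-8)

    # LevenbergMarquardt adds damping, so it should be the safer default for hard fits.
    sol_lm = solve(prob, LevenbergMarquardt(); maxiters = 100, abstol = 1e-8)

    # In-place form: the residual length differs from length(u), so a resid_prototype is
    # required, exactly as in test/nonlinearls.jl above.
    residual!(r, θ, p) = (r .= model(p, θ) .- y_obs)
    prob_iip = NonlinearProblem(NonlinearFunction(residual!; resid_prototype = zero(y_obs)),
        [1.0, 0.0], x_data)
    sol_iip = solve(prob_iip, GaussNewton(); abstol = 1e-8)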
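
The shared `SciMLBase.solve!(cache::AbstractNonlinearSolveCache)` loop added in src/NonlinearSolve.jl also gives every subtyped cache the same `init`/`solve!` driver, with `reinit!` additionally defined for the Gauss-Newton and Newton-Raphson caches in this patch. A minimal sketch, reusing the hypothetical `prob` from the previous example; the restart guess and tolerances are arbitrary:

    # Build the cache once, then run the shared perform_step!/solve! loop.
    cache = init(prob, GaussNewton(); abstol = 1e-8, maxiters = 100)
    sol = solve!(cache)

    # SciMLBase.reinit! (defined per cache, e.g. in src/gaussnewton.jl) resets the stats and
    # stopping flags and restarts from a new guess without rebuilding the Jacobian caches.
    SciMLBase.reinit!(cache, [0.5, -0.5]; abstol = 1e-10)
    sol_restart = solve!(cache)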