From af3e026986f11b835ce4feb67a3d46a77be1365f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 8 Oct 2023 17:29:53 -0400 Subject: [PATCH] Make LM and GN oop versions work with linearSolve.jl --- src/gaussnewton.jl | 7 +- src/jacobian.jl | 18 +++++- src/levenberg.jl | 84 +++++++++++++----------- src/utils.jl | 1 + test/23_test_problems.jl | 13 ++-- test/basictests.jl | 109 ++++++++++++++------------------ test/nonlinear_least_squares.jl | 9 ++- 7 files changed, 126 insertions(+), 115 deletions(-) diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 5c9557516..973be9288 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -82,10 +82,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::G else fu1 = f(u, p) end - uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip)) - - JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2)) - Jᵀf = zero(u) + uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip); + linsolve_with_JᵀJ = Val(true)) return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, @@ -120,6 +118,7 @@ function perform_step!(cache::GaussNewtonCache{false}) @unpack u, fu1, f, p, alg, linsolve = cache cache.J = jacobian!!(cache.J, cache) + cache.JᵀJ = cache.J' * cache.J cache.Jᵀf = cache.J' * fu1 # u = u - J \ fu diff --git a/src/jacobian.jl b/src/jacobian.jl index 82f2ef2bb..bd4575fcc 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -50,7 +50,8 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u)) # Build Jacobian Caches function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip}; - linsolve_kwargs = (;)) where {iip} + linsolve_kwargs = (;), + linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ} uf = JacobianWrapper{iip}(f, p) haslinsolve = hasfield(typeof(alg), :linsolve) @@ -85,7 +86,15 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii end du = _mutable_zero(u) - linprob = LinearProblem(J, _vec(fu); u0 = _vec(du)) + + if needsJᵀJ + JᵀJ = __init_JᵀJ(J) + # FIXME: This needs to be handled better for JacVec Operator + Jᵀfu = J' * fu + end + + linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); + u0 = _vec(du)) weight = similar(u) recursivefill!(weight, true) @@ -95,6 +104,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr, linsolve_kwargs...) + needsJᵀJ && return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu return uf, linsolve, J, fu, jac_cache, du end @@ -103,6 +113,10 @@ __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff() __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote() __get_nonsparse_ad(ad) = ad +__init_JᵀJ(J::Number) = zero(J) +__init_JᵀJ(J::AbstractArray) = zeros(eltype(J), size(J, 2), size(J, 2)) +__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef) + ## Special Handling for Scalars function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p, ::Val{false}; kwargs...) diff --git a/src/levenberg.jl b/src/levenberg.jl index bea4c84a3..f43eff0a1 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -72,11 +72,6 @@ numerically-difficult nonlinear systems. where `J` is the Jacobian. It is suggested by [this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in `DᵀD` to prevent the damping from being too small. Defaults to `1e-8`. - -!!! warning - - `linsolve` and `precs` are used exclusively for the inplace version of the algorithm. - Support for the OOP version is planned! """ @concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD} ad::AD @@ -102,18 +97,17 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing, finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D) end -@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} <: - AbstractNonlinearSolveCache{iip} +@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip} f alg - u::uType + u fu1 fu2 du p uf linsolve - J::jType + J jac_cache force_stop::Bool maxiters::Int @@ -122,27 +116,27 @@ end abstol prob DᵀD - JᵀJ::jType - λ::λType - λ_factor::λType - damping_increase_factor::λType - damping_decrease_factor::λType - h::λType - α_geodesic::λType - b_uphill::λType - min_damping_D::λType - v::uType - a::uType - tmp_vec::uType - v_old::uType - norm_v_old::lossType - δ::uType - loss_old::lossType + JᵀJ + λ + λ_factor + damping_increase_factor + damping_decrease_factor + h + α_geodesic + b_uphill + min_damping_D + v + a + tmp_vec + v_old + norm_v_old + δ + loss_old make_new_J::Bool fu_tmp u_tmp Jv - mat_tmp::jType + mat_tmp stats::NLStats end @@ -153,8 +147,8 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) fu1 = evaluate_f(prob, u) - uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); - linsolve_kwargs) + uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, Val(iip); + linsolve_kwargs, linsolve_with_JᵀJ=Val(true)) λ = convert(eltype(u), alg.damping_initial) λ_factor = convert(eltype(u), alg.damping_increase_factor) @@ -174,12 +168,10 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, end loss = internalnorm(fu1) - JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2)) - v = zero(u) - a = zero(u) - tmp_vec = zero(u) - v_old = zero(u) - δ = zero(u) + a = _mutable_zero(u) + tmp_vec = _mutable_zero(u) + v_old = _mutable_zero(u) + δ = _mutable_zero(u) make_new_J = true fu_tmp = zero(fu1) mat_tmp = zero(JᵀJ) @@ -223,7 +215,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true}) # The following lines do: cache.a = -J \ cache.fu_tmp mul!(cache.Jv, J, v) @. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv) - linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp), + mul!(cache.u_tmp, J', cache.fu_tmp) + linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol) cache.linsolve = linres.cache @. cache.a = -cache.du @@ -279,15 +272,30 @@ function perform_step!(cache::LevenbergMarquardtCache{false}) cache.make_new_J = false cache.stats.njacs += 1 end - @unpack u, p, λ, JᵀJ, DᵀD, J = cache + @unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache cache.mat_tmp = JᵀJ + λ * DᵀD # Usual Levenberg-Marquardt step ("velocity"). - cache.v = -cache.mat_tmp \ (J' * fu1) + if linsolve === nothing + cache.v = -cache.mat_tmp \ (J' * fu1) + else + linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, b = _vec(J' * fu1), + linu = _vec(cache.v), p, reltol = cache.abstol) + cache.linsolve = linres.cache + end @unpack v, h, α_geodesic = cache # Geodesic acceleration (step_size = v + a / 2). - cache.a = -cache.mat_tmp \ ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)) + if linsolve === nothing + cache.a = -cache.mat_tmp \ + _vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v))) + else + linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, + b = _mutable(_vec(J' * + ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))), + linu = _vec(cache.a), p, reltol = cache.abstol) + cache.linsolve = linres.cache + end cache.stats.nsolve += 1 cache.stats.nfactors += 1 diff --git a/src/utils.jl b/src/utils.jl index bc3cb9819..9aa2e71bc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -138,6 +138,7 @@ _mutable_zero(x::SArray) = MArray(x) _mutable(x) = x _mutable(x::SArray) = MArray(x) + _maybe_mutable(x, ::AbstractFiniteDifferencesMode) = _mutable(x) # The shadow allocated for Enzyme needs to be mutable _maybe_mutable(x, ::AutoSparseEnzyme) = _mutable(x) diff --git a/test/23_test_problems.jl b/test/23_test_problems.jl index f13182119..20fe388ac 100644 --- a/test/23_test_problems.jl +++ b/test/23_test_problems.jl @@ -1,16 +1,16 @@ -using NonlinearSolve, LinearAlgebra, NonlinearProblemLibrary, Test +using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test problems = NonlinearProblemLibrary.problems dicts = NonlinearProblemLibrary.dicts -function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-5) +function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-4) for (idx, (problem, dict)) in enumerate(zip(problems, dicts)) x = dict["start"] res = similar(x) nlprob = NonlinearProblem(problem, x) @testset "$(dict["title"])" begin for alg in alg_ops - sol = solve(nlprob, alg, abstol = 1e-15, reltol = 1e-15) + sol = solve(nlprob, alg, abstol = 1e-18, reltol = 1e-18) problem(res, sol.u, nothing) broken = idx in broken_tests[alg] ? true : false @test norm(res)≤ϵ broken=broken @@ -43,7 +43,7 @@ end broken_tests[alg_ops[1]] = [6, 11, 21] broken_tests[alg_ops[2]] = [6, 11, 21] broken_tests[alg_ops[3]] = [1, 6, 11, 12, 15, 16, 21] - broken_tests[alg_ops[4]] = [1, 6, 8, 11, 15, 16, 21, 22] + broken_tests[alg_ops[4]] = [1, 6, 8, 11, 16, 21, 22] broken_tests[alg_ops[5]] = [6, 21] broken_tests[alg_ops[6]] = [6, 21] @@ -51,11 +51,12 @@ end end @testset "TrustRegion test problem library" begin - alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.5)) + alg_ops = (LevenbergMarquardt(; linsolve=NormalCholeskyFactorization()), + LevenbergMarquardt(; α_geodesic = 0.1, linsolve=NormalCholeskyFactorization())) # dictionary with indices of test problems where method does not converge to small residual broken_tests = Dict(alg => Int[] for alg in alg_ops) - broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21] + broken_tests[alg_ops[1]] = [3, 6, 11, 21] broken_tests[alg_ops[2]] = [3, 6, 11, 21] test_on_library(problems, dicts, alg_ops, broken_tests) diff --git a/test/basictests.jl b/test/basictests.jl index 819f97b98..acfeff6c8 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -37,7 +37,7 @@ end ad in (AutoFiniteDiff(), AutoZygote()) linesearch = LineSearch(; method = lsmethod, autodiff = ad) - u0s = VERSION ≥ v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0) + u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) @testset "[OOP] u0: $(typeof(u0))" for u0 in u0s sol = benchmark_nlsolve_oop(quadratic_f, u0; linesearch) @@ -49,7 +49,6 @@ end @test (@ballocated solve!($cache)) < 200 end - precs = [ NonlinearSolve.DEFAULT_PRECS, (args...) -> (Diagonal(rand!(similar(u0))), nothing), @@ -72,17 +71,15 @@ end end end - if VERSION ≥ v"1.9" - @testset "[OOP] [Immutable AD]" begin - for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) - res_true = sqrt(p) - all(res.u .≈ res_true) - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) + @testset "[OOP] [Immutable AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) + res_true = sqrt(p) + all(res.u .≈ res_true) end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) end end @@ -99,14 +96,11 @@ end end end - if VERSION ≥ v"1.9" - t = (p) -> [sqrt(p[2] / p[1])] - p = [0.9, 50.0] - @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) - @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], - p) ≈ - ForwardDiff.jacobian(t, p) - end + t = (p) -> [sqrt(p[2] / p[1])] + p = [0.9, 50.0] + @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) + @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], + p) ≈ ForwardDiff.jacobian(t, p) # Iterator interface function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip} @@ -148,7 +142,7 @@ end radius_update_schemes = [RadiusUpdateSchemes.Simple, RadiusUpdateSchemes.NocedalWright, RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan, RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin] - u0s = VERSION ≥ v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0) + u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) @testset "[OOP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme)" for u0 in u0s, radius_update_scheme in radius_update_schemes @@ -173,18 +167,16 @@ end @test (@ballocated solve!($cache)) ≤ 64 end - if VERSION ≥ v"1.9" - @testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes - for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p; - radius_update_scheme) - res_true = sqrt(p) - all(res.u .≈ res_true) - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) ≈ 1 / (2 * sqrt(p)) + @testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p; + radius_update_scheme) + res_true = sqrt(p) + all(res.u .≈ res_true) end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) ≈ 1 / (2 * sqrt(p)) end end @@ -202,17 +194,15 @@ end end end - if VERSION ≥ v"1.9" - t = (p) -> [sqrt(p[2] / p[1])] - p = [0.9, 50.0] - @testset "[OOP] [Jacobian] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes - @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p; radius_update_scheme).u ≈ - sqrt(p[2] / p[1]) - @test ForwardDiff.jacobian(p -> [ - benchmark_nlsolve_oop(quadratic_f2, 0.5, p; - radius_update_scheme).u, - ], p) ≈ ForwardDiff.jacobian(t, p) - end + t = (p) -> [sqrt(p[2] / p[1])] + p = [0.9, 50.0] + @testset "[OOP] [Jacobian] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes + @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p; radius_update_scheme).u ≈ + sqrt(p[2] / p[1]) + @test ForwardDiff.jacobian(p -> [ + benchmark_nlsolve_oop(quadratic_f2, 0.5, p; + radius_update_scheme).u, + ], p) ≈ ForwardDiff.jacobian(t, p) end # Iterator interface @@ -307,7 +297,7 @@ end return solve(prob, LevenbergMarquardt(), abstol = 1e-9) end - u0s = VERSION ≥ v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0) + u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) @testset "[OOP] u0: $(typeof(u0))" for u0 in u0s sol = benchmark_nlsolve_oop(quadratic_f, u0) @test SciMLBase.successful_retcode(sol) @@ -328,17 +318,15 @@ end @test (@ballocated solve!($cache)) ≤ 64 end - if VERSION ≥ v"1.9" - @testset "[OOP] [Immutable AD]" begin - for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) - res_true = sqrt(p) - all(res.u .≈ res_true) - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) + @testset "[OOP] [Immutable AD]" begin + for p in 1.0:0.1:100.0 + @test begin + res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) + res_true = sqrt(p) + all(res.u .≈ res_true) end + @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, + @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) end end @@ -355,14 +343,11 @@ end end end - if VERSION ≥ v"1.9" - t = (p) -> [sqrt(p[2] / p[1])] - p = [0.9, 50.0] - @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) - @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], - p) ≈ - ForwardDiff.jacobian(t, p) - end + t = (p) -> [sqrt(p[2] / p[1])] + p = [0.9, 50.0] + @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) + @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], + p) ≈ ForwardDiff.jacobian(t, p) @testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true, AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(), diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index 5ca621313..27775bc40 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -25,15 +25,18 @@ prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x) prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function; resid_prototype = zero(y_target)), θ_init, x) -sol = solve(prob_oop, GaussNewton(); maxiters = 1000, abstol = 1e-8) +sol = solve(prob_oop, GaussNewton(; linsolve = NormalCholeskyFactorization()); + maxiters = 1000, abstol = 1e-8) @test SciMLBase.successful_retcode(sol) @test norm(sol.resid) < 1e-6 -sol = solve(prob_iip, GaussNewton(); maxiters = 1000, abstol = 1e-8) +sol = solve(prob_iip, GaussNewton(; linsolve = NormalCholeskyFactorization()); + maxiters = 1000, abstol = 1e-8) @test SciMLBase.successful_retcode(sol) @test norm(sol.resid) < 1e-6 -sol = solve(prob_oop, LevenbergMarquardt(); maxiters = 1000, abstol = 1e-8) +sol = solve(prob_oop, LevenbergMarquardt(; linsolve = NormalCholeskyFactorization()); + maxiters = 1000, abstol = 1e-8) @test SciMLBase.successful_retcode(sol) @test norm(sol.resid) < 1e-6