From 72fd69170a591fd95f9cf2dd46ee0e52bc8152d9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 25 Oct 2023 09:36:24 -0400 Subject: [PATCH] Use needs_square_A from LinearSolve --- Project.toml | 2 +- src/gaussnewton.jl | 6 +----- src/levenberg.jl | 24 ++++++++++-------------- src/utils.jl | 30 +++--------------------------- test/nonlinear_least_squares.jl | 6 ++---- test/runtests.jl | 5 ++++- 6 files changed, 21 insertions(+), 52 deletions(-) diff --git a/Project.toml b/Project.toml index 89d584fb5..57698c7bf 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ FiniteDiff = "2" ForwardDiff = "0.10.3" LeastSquaresOptim = "0.8" LineSearches = "7" -LinearSolve = "2" +LinearSolve = "2.12" NonlinearProblemLibrary = "0.1" PrecompileTools = "1" RecursiveArrayTools = "2" diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 21adaeccd..a6ec1ae9b 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -82,11 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - if !needs_square_A(alg.linsolve) && !(u0 isa Number) && !(u0 isa StaticArray) - linsolve_with_JᵀJ = Val(false) - else - linsolve_with_JᵀJ = Val(true) - end + linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0)) u = alias_u0 ? u0 : deepcopy(u0) if iip diff --git a/src/levenberg.jl b/src/levenberg.jl index 16867033a..cd0763fe1 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -109,7 +109,7 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing, finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D) end -@concrete mutable struct LevenbergMarquardtCache{iip, fastqr} <: +@concrete mutable struct LevenbergMarquardtCache{iip, fastls} <: AbstractNonlinearSolveCache{iip} f alg @@ -164,11 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, u = alias_u0 ? u0 : deepcopy(u0) fu1 = evaluate_f(prob, u) - if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray) - linsolve_with_JᵀJ = Val(false) - else - linsolve_with_JᵀJ = Val(true) - end + linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0)) if _unwrap_val(linsolve_with_JᵀJ) uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, @@ -227,7 +223,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0)) end -function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr} +function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fastls} @unpack fu1, f, make_new_J = cache if iszero(fu1) cache.force_stop = true @@ -236,7 +232,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast if make_new_J jacobian!!(cache.J, cache) - if fastqr + if fastls cache.J² .= cache.J .^ 2 sum!(cache.JᵀJ', cache.J²) cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ) @@ -251,7 +247,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast # Usual Levenberg-Marquardt step ("velocity"). # The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp - if fastqr + if fastls cache.mat_tmp[1:length(fu1), :] .= cache.J cache.mat_tmp[(length(fu1) + 1):end, :] .= λ .* cache.DᵀD cache.rhs_tmp[1:length(fu1)] .= _vec(fu1) @@ -276,7 +272,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast # NOTE: Don't pass `A` in again, since we want to reuse the previous solve mul!(_vec(cache.Jv), J, _vec(v)) @. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv) - if fastqr + if fastls cache.rhs_tmp[1:length(fu1)] .= _vec(cache.fu_tmp) linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.du), p = p, reltol = cache.abstol) @@ -321,7 +317,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fast return nothing end -function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr} +function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fastls} @unpack fu1, f, make_new_J = cache if iszero(fu1) cache.force_stop = true @@ -330,7 +326,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas if make_new_J cache.J = jacobian!!(cache.J, cache) - if fastqr + if fastls cache.JᵀJ = _vec(sum(cache.J .^ 2; dims = 1)) cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ) else @@ -347,7 +343,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas @unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache # Usual Levenberg-Marquardt step ("velocity"). - if fastqr + if fastls cache.mat_tmp = vcat(J, λ .* cache.DᵀD) cache.rhs_tmp[1:length(fu1)] .= -_vec(fu1) linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, @@ -367,7 +363,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fas # Geodesic acceleration (step_size = v + a / 2). rhs_term = _vec(((2 / h) .* ((_vec(f(u .+ h .* _restructure(u, v), p)) .- _vec(fu1)) ./ h .- J * _vec(v)))) - if fastqr + if fastls cache.rhs_tmp[1:length(fu1)] .= -_vec(rhs_term) linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.a), p = p, reltol = cache.abstol) diff --git a/src/utils.jl b/src/utils.jl index afcb68922..6398b058b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -265,30 +265,6 @@ _reshape(x::Number, args...) = x return :(@. y += α * x) end -# Needs Square Matrix -# FIXME: Remove once https://github.com/SciML/LinearSolve.jl/pull/400 is merged and tagged -""" - needs_square_A(alg) - -Returns `true` if the algorithm requires a square matrix. -""" -needs_square_A(::Nothing) = false -function needs_square_A(alg) - try - A = [1.0 2.0; - 3.0 4.0; - 5.0 6.0] - b = ones(Float64, 3) - solve(LinearProblem(A, b), alg) - return false - catch err - return true - end -end -for alg in (:QRFactorization, :FastQRFactorization, NormalCholeskyFactorization, - NormalBunchKaufmanFactorization) - @eval needs_square_A(::$(alg)) = false -end -for kralg in (LinearSolve.Krylov.lsmr!, LinearSolve.Krylov.craigmr!) - @eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false -end +_needs_square_A(_, ::Number) = true +_needs_square_A(_, ::StaticArray) = true +_needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve) diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl index eb7d6966e..c7a02dc58 100644 --- a/test/nonlinear_least_squares.jl +++ b/test/nonlinear_least_squares.jl @@ -30,14 +30,12 @@ nlls_problems = [prob_oop, prob_iip] solvers = [ GaussNewton(), GaussNewton(; linsolve = LUFactorization()), + LevenbergMarquardt(), + LevenbergMarquardt(; linsolve = LUFactorization()), LeastSquaresOptimJL(:lm), LeastSquaresOptimJL(:dogleg), ] -# Compile time on v"1.9" is too high! -VERSION ≥ v"1.10-" && append!(solvers, - [LevenbergMarquardt(), LevenbergMarquardt(; linsolve = LUFactorization())]) - for prob in nlls_problems, solver in solvers @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8) @test SciMLBase.successful_retcode(sol) diff --git a/test/runtests.jl b/test/runtests.jl index d4f817d0a..248de16b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,7 +17,10 @@ end @time @safetestset "Sparsity Tests" include("sparse.jl") @time @safetestset "Polyalgs" include("polyalgs.jl") @time @safetestset "Matrix Resizing" include("matrix_resizing.jl") - @time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl") + if VERSION ≥ v"1.10-" + # Takes too long to compile on older versions + @time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl") + end end if GROUP == "All" || GROUP == "23TestProblems"