From 73315e84f1066a9fd560c4fa568ffc5e0146eae7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 23 Oct 2023 16:53:23 -0400 Subject: [PATCH] Add a function to check if square A is needed --- src/NonlinearSolve.jl | 6 ++---- src/gaussnewton.jl | 4 +--- src/levenberg.jl | 4 +--- src/utils.jl | 27 +++++++++++++++++++++++++++ test/23_test_problems.jl | 9 +++++---- test/basictests.jl | 9 +++++---- 6 files changed, 41 insertions(+), 18 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 6e3a6f804..2b26b3721 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -88,10 +88,8 @@ import PrecompileTools for T in (Float32, Float64) prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2)) - # precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), - # PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing) - # DON'T MERGE - precompile_algs = () + precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), + PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing) for alg in precompile_algs solve(prob, alg, abstol = T(1e-2)) diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 307119af3..bce757e3a 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -82,9 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - # Use QR if the user did not specify a linear solver - if alg.linsolve === nothing || alg.linsolve isa QRFactorization || - alg.linsolve isa FastQRFactorization + if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray) linsolve_with_JᵀJ = Val(false) else linsolve_with_JᵀJ = Val(true) diff --git a/src/levenberg.jl b/src/levenberg.jl index 3480e6c63..047ca16c2 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -164,9 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, u = alias_u0 ? u0 : deepcopy(u0) fu1 = evaluate_f(prob, u) - # Use QR if the user did not specify a linear solver - if (alg.linsolve === nothing || alg.linsolve isa QRFactorization || - alg.linsolve isa FastQRFactorization) && !(u isa Number) + if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray) linsolve_with_JᵀJ = Val(false) else linsolve_with_JᵀJ = Val(true) diff --git a/src/utils.jl b/src/utils.jl index 688322329..b594a83c9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -256,3 +256,30 @@ function _try_factorize_and_check_singular!(linsolve, X) return _issingular(X), false end _try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false + +# Needs Square Matrix +""" + needs_square_A(alg) + +Returns `true` if the algorithm requires a square matrix. +""" +needs_square_A(::Nothing) = false +function needs_square_A(alg) + try + A = [1.0 2.0; + 3.0 4.0; + 5.0 6.0] + b = ones(Float64, 3) + solve(LinearProblem(A, b), alg) + return false + catch err + return true + end +end +for alg in (:QRFactorization, :FastQRFactorization, NormalCholeskyFactorization, + NormalBunchKaufmanFactorization) + @eval needs_square_A(::$(alg)) = false +end +for kralg in (LinearSolve.Krylov.lsmr!, LinearSolve.Krylov.craigmr!) + @eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false +end diff --git a/test/23_test_problems.jl b/test/23_test_problems.jl index 2bdbecffb..d0ab52d7f 100644 --- a/test/23_test_problems.jl +++ b/test/23_test_problems.jl @@ -59,13 +59,14 @@ end end @testset "LevenbergMarquardt 23 Test Problems" begin - alg_ops = (LevenbergMarquardt(; linsolve = NormalCholeskyFactorization()), - LevenbergMarquardt(; α_geodesic = 0.1, linsolve = NormalCholeskyFactorization())) + alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.1), + LevenbergMarquardt(; linsolve = CholeskyFactorization())) # dictionary with indices of test problems where method does not converge to small residual broken_tests = Dict(alg => Int[] for alg in alg_ops) - broken_tests[alg_ops[1]] = [3, 6, 11, 21] - broken_tests[alg_ops[2]] = [3, 6, 11, 21] + broken_tests[alg_ops[1]] = [3, 6, 17, 21] + broken_tests[alg_ops[2]] = [3, 6, 17, 21] + broken_tests[alg_ops[3]] = [6, 11, 21] test_on_library(problems, dicts, alg_ops, broken_tests) end diff --git a/test/basictests.jl b/test/basictests.jl index 681d506de..1606270fc 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -352,7 +352,8 @@ end AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(), AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0) - @test all(solve(probN, LevenbergMarquardt(; autodiff)).u .≈ sqrt(2.0)) + @test all(solve(probN, LevenbergMarquardt(; autodiff); abstol = 1e-9, + reltol = 1e-9).u .≈ sqrt(2.0)) end # Test that `LevenbergMarquardt` passes a test that `NewtonRaphson` fails on. @@ -368,7 +369,7 @@ end @testset "Keyword Arguments" begin damping_initial = [0.5, 2.0, 5.0] damping_increase_factor = [1.5, 3.0, 10.0] - damping_decrease_factor = Float64[2, 5, 10] + damping_decrease_factor = Float64[2, 5, 12] finite_diff_step_geodesic = [0.02, 0.2, 0.3] α_geodesic = [0.6, 0.8, 0.9] b_uphill = Float64[0, 1, 2] @@ -379,14 +380,14 @@ end min_damping_D) for options in list_of_options local probN, sol, alg - alg = LevenbergMarquardt(damping_initial = options[1], + alg = LevenbergMarquardt(; damping_initial = options[1], damping_increase_factor = options[2], damping_decrease_factor = options[3], finite_diff_step_geodesic = options[4], α_geodesic = options[5], b_uphill = options[6], min_damping_D = options[7]) probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0) - sol = solve(probN, alg, abstol = 1e-10) + sol = solve(probN, alg, abstol = 1e-12) @test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10) end end