Skip to content

Commit

Permalink
Add a function to check if square A is needed
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent b4ea93f commit 73315e8
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 18 deletions.
6 changes: 2 additions & 4 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,8 @@ import PrecompileTools
for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

# precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
# PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)
# DON'T MERGE
precompile_algs = ()
precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
PseudoTransient(), GeneralBroyden(), GeneralKlement(), nothing)

for alg in precompile_algs
solve(prob, alg, abstol = T(1e-2))
Expand Down
4 changes: 1 addition & 3 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob

# Use QR if the user did not specify a linear solver
if alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
alg.linsolve isa FastQRFactorization
if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
linsolve_with_JᵀJ = Val(false)

Check warning on line 86 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
else
linsolve_with_JᵀJ = Val(true)

Check warning on line 88 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L88

Added line #L88 was not covered by tests
Expand Down
4 changes: 1 addition & 3 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
u = alias_u0 ? u0 : deepcopy(u0)
fu1 = evaluate_f(prob, u)

# Use QR if the user did not specify a linear solver
if (alg.linsolve === nothing || alg.linsolve isa QRFactorization ||
alg.linsolve isa FastQRFactorization) && !(u isa Number)
if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
linsolve_with_JᵀJ = Val(false)

Check warning on line 168 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L167-L168

Added lines #L167 - L168 were not covered by tests
else
linsolve_with_JᵀJ = Val(true)

Check warning on line 170 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L170

Added line #L170 was not covered by tests
Expand Down
27 changes: 27 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,30 @@ function _try_factorize_and_check_singular!(linsolve, X)
return _issingular(X), false

Check warning on line 256 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L256

Added line #L256 was not covered by tests
end
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false

Check warning on line 258 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L258

Added line #L258 was not covered by tests

# Needs Square Matrix
"""
needs_square_A(alg)
Returns `true` if the algorithm requires a square matrix.
"""
needs_square_A(::Nothing) = false
function needs_square_A(alg)
try
A = [1.0 2.0;

Check warning on line 269 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L266-L269

Added lines #L266 - L269 were not covered by tests
3.0 4.0;
5.0 6.0]
b = ones(Float64, 3)
solve(LinearProblem(A, b), alg)
return false

Check warning on line 274 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L272-L274

Added lines #L272 - L274 were not covered by tests
catch err
return true

Check warning on line 276 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L276

Added line #L276 was not covered by tests
end
end
for alg in (:QRFactorization, :FastQRFactorization, NormalCholeskyFactorization,
NormalBunchKaufmanFactorization)
@eval needs_square_A(::$(alg)) = false

Check warning on line 281 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L281

Added line #L281 was not covered by tests
end
for kralg in (LinearSolve.Krylov.lsmr!, LinearSolve.Krylov.craigmr!)
@eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false

Check warning on line 284 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L284

Added line #L284 was not covered by tests
end
9 changes: 5 additions & 4 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ end
end

@testset "LevenbergMarquardt 23 Test Problems" begin
alg_ops = (LevenbergMarquardt(; linsolve = NormalCholeskyFactorization()),
LevenbergMarquardt(; α_geodesic = 0.1, linsolve = NormalCholeskyFactorization()))
alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.1),
LevenbergMarquardt(; linsolve = CholeskyFactorization()))

# dictionary with indices of test problems where method does not converge to small residual
broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [3, 6, 11, 21]
broken_tests[alg_ops[2]] = [3, 6, 11, 21]
broken_tests[alg_ops[1]] = [3, 6, 17, 21]
broken_tests[alg_ops[2]] = [3, 6, 17, 21]
broken_tests[alg_ops[3]] = [6, 11, 21]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
Expand Down
9 changes: 5 additions & 4 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ end
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(probN, LevenbergMarquardt(; autodiff)).u .≈ sqrt(2.0))
@test all(solve(probN, LevenbergMarquardt(; autodiff); abstol = 1e-9,
reltol = 1e-9).u .≈ sqrt(2.0))
end

# Test that `LevenbergMarquardt` passes a test that `NewtonRaphson` fails on.
Expand All @@ -368,7 +369,7 @@ end
@testset "Keyword Arguments" begin
damping_initial = [0.5, 2.0, 5.0]
damping_increase_factor = [1.5, 3.0, 10.0]
damping_decrease_factor = Float64[2, 5, 10]
damping_decrease_factor = Float64[2, 5, 12]
finite_diff_step_geodesic = [0.02, 0.2, 0.3]
α_geodesic = [0.6, 0.8, 0.9]
b_uphill = Float64[0, 1, 2]
Expand All @@ -379,14 +380,14 @@ end
min_damping_D)
for options in list_of_options
local probN, sol, alg
alg = LevenbergMarquardt(damping_initial = options[1],
alg = LevenbergMarquardt(; damping_initial = options[1],
damping_increase_factor = options[2],
damping_decrease_factor = options[3],
finite_diff_step_geodesic = options[4], α_geodesic = options[5],
b_uphill = options[6], min_damping_D = options[7])

probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
sol = solve(probN, alg, abstol = 1e-10)
sol = solve(probN, alg, abstol = 1e-12)
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
end
end
Expand Down

0 comments on commit 73315e8

Please sign in to comment.