Use needs_square_A from LinearSolve
avik-pal committed Oct 25, 2023
1 parent 1f93c36 commit 4f3c047
Showing 6 changed files with 21 additions and 52 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -44,7 +44,7 @@ FiniteDiff = "2"
 ForwardDiff = "0.10.3"
 LeastSquaresOptim = "0.8"
 LineSearches = "7"
-LinearSolve = "2"
+LinearSolve = "2.12"
 NonlinearProblemLibrary = "0.1"
 PrecompileTools = "1"
 RecursiveArrayTools = "2"
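A short note on the compat bump: under Julia's default compat semantics, `LinearSolve = "2.12"` means any version in `[2.12.0, 3.0.0)`, which is presumably the first release where `LinearSolve.needs_square_A` is available to the code below.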
6 changes: 1 addition & 5 deletions src/gaussnewton.jl
@@ -82,11 +82,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
    alg = get_concrete_algorithm(alg_, prob)
    @unpack f, u0, p = prob

-    if !needs_square_A(alg.linsolve) && !(u0 isa Number) && !(u0 isa StaticArray)
-        linsolve_with_JᵀJ = Val(false)
-    else
-        linsolve_with_JᵀJ = Val(true)
-    end
+    linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0))

    u = alias_u0 ? u0 : deepcopy(u0)
    if iip
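As an aside, the flag is wrapped in `Val` (rather than stored as a plain `Bool`) presumably so that the later `if _unwrap_val(linsolve_with_JᵀJ)` branch can be resolved from type information alone. A minimal, self-contained sketch of that pattern — the function names below are illustrative, not the package's internals:

# Sketch only: compile-time branching on a Val-wrapped Bool.
# `_unwrap_val` is assumed to simply extract the type parameter.
_unwrap_val(::Val{B}) where {B} = B

function setup_caches(needs_JᵀJ::Val)
    if _unwrap_val(needs_JᵀJ)
        # square path: the linear solver wants a square system, so JᵀJ gets built
        return (; with_JᵀJ = true)
    else
        # rectangular path: the solver (e.g. QR) can factorize J directly
        return (; with_JᵀJ = false)
    end
end

setup_caches(Val(true))   # (with_JᵀJ = true,)
setup_caches(Val(false))  # (with_JᵀJ = false,)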
24 changes: 10 additions & 14 deletions src/levenberg.jl
@@ -109,7 +109,7 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
        finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end

-@concrete mutable struct LevenbergMarquardtCache{iip, fastqr} <:
+@concrete mutable struct LevenbergMarquardtCache{iip, fastls} <:
                          AbstractNonlinearSolveCache{iip}
    f
    alg

@@ -164,11 +164,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
    u = alias_u0 ? u0 : deepcopy(u0)
    fu1 = evaluate_f(prob, u)

-    if !needs_square_A(alg.linsolve) && !(u isa Number) && !(u isa StaticArray)
-        linsolve_with_JᵀJ = Val(false)
-    else
-        linsolve_with_JᵀJ = Val(true)
-    end
+    linsolve_with_JᵀJ = Val(_needs_square_A(alg, u0))

    if _unwrap_val(linsolve_with_JᵀJ)
        uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p,

@@ -227,7 +223,7 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
        zero(u), zero(fu1), mat_tmp, rhs_tmp, J², NLStats(1, 0, 0, 0, 0))
end

-function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}
+function perform_step!(cache::LevenbergMarquardtCache{true, fastls}) where {fastls}
    @unpack fu1, f, make_new_J = cache
    if iszero(fu1)
        cache.force_stop = true

@@ -236,7 +232,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}

    if make_new_J
        jacobian!!(cache.J, cache)
-        if fastqr
+        if fastls
            cache.J² .= cache.J .^ 2
            sum!(cache.JᵀJ', cache.J²)
            cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)

@@ -251,7 +247,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}

    # Usual Levenberg-Marquardt step ("velocity").
    # The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
-    if fastqr
+    if fastls
        cache.mat_tmp[1:length(fu1), :] .= cache.J
        cache.mat_tmp[(length(fu1) + 1):end, :] .= λ .* cache.DᵀD
        cache.rhs_tmp[1:length(fu1)] .= _vec(fu1)

@@ -276,7 +272,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}
    # NOTE: Don't pass `A` in again, since we want to reuse the previous solve
    mul!(_vec(cache.Jv), J, _vec(v))
    @. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
-    if fastqr
+    if fastls
        cache.rhs_tmp[1:length(fu1)] .= _vec(cache.fu_tmp)
        linres = dolinsolve(alg.precs, linsolve; b = cache.rhs_tmp, linu = _vec(cache.du),
            p = p, reltol = cache.abstol)

@@ -321,7 +317,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true, fastqr}) where {fastqr}
    return nothing
end

-function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr}
+function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fastls}
    @unpack fu1, f, make_new_J = cache
    if iszero(fu1)
        cache.force_stop = true

@@ -330,7 +326,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr}

    if make_new_J
        cache.J = jacobian!!(cache.J, cache)
-        if fastqr
+        if fastls
            cache.JᵀJ = _vec(sum(cache.J .^ 2; dims = 1))
            cache.DᵀD.diag .= max.(cache.DᵀD.diag, cache.JᵀJ)
        else

@@ -347,7 +343,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr}
    @unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

    # Usual Levenberg-Marquardt step ("velocity").
-    if fastqr
+    if fastls
        cache.mat_tmp = vcat(J, λ .* cache.DᵀD)
        cache.rhs_tmp[1:length(fu1)] .= -_vec(fu1)
        linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp,

@@ -367,7 +363,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastqr}) where {fastqr}
    # Geodesic acceleration (step_size = v + a / 2).
    rhs_term = _vec(((2 / h) .* ((_vec(f(u .+ h .* _restructure(u, v), p)) .-
        _vec(fu1)) ./ h .- J * _vec(v))))
-    if fastqr
+    if fastls
        cache.rhs_tmp[1:length(fu1)] .= -_vec(rhs_term)
        linres = dolinsolve(alg.precs, linsolve;
            b = cache.rhs_tmp, linu = _vec(cache.a), p = p, reltol = cache.abstol)
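For readers unfamiliar with the two branches being renamed here: when `fastls` is true the cache never forms JᵀJ; it stacks `J` on top of `λ .* DᵀD` and hands the tall, rectangular system to a least-squares-capable solver, while the other branch works with a square damped system. A rough standalone illustration of the two shapes follows — arbitrary sizes and values, not the package code, and the two formulations are not numerically identical:

using LinearAlgebra

J   = rand(5, 3)                                  # rectangular Jacobian of the residual
fu  = rand(5)                                     # residual vector
λ   = 1e-2                                        # damping parameter
DᵀD = Diagonal(vec(sum(abs2, J; dims = 1)))       # diagonal scaling, as in the fastls branch

# Square formulation (assumed damped normal-equations form): JᵀJ is built explicitly.
v_square = -(J' * J + λ * DᵀD) \ (J' * fu)

# fastls-style formulation: stack J over λ .* DᵀD, mirroring `vcat(J, λ .* cache.DᵀD)`
# with right-hand side `[-fu; 0]`, and let a least-squares (QR) solve handle it.
A = vcat(J, λ * Matrix(DᵀD))
b = vcat(-fu, zeros(size(J, 2)))
v_tall = A \ b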
30 changes: 3 additions & 27 deletions src/utils.jl
@@ -265,30 +265,6 @@ _reshape(x::Number, args...) = x
    return :(@. y += α * x)
end

-# Needs Square Matrix
-# FIXME: Remove once https://github.com/SciML/LinearSolve.jl/pull/400 is merged and tagged
-"""
-    needs_square_A(alg)
-
-Returns `true` if the algorithm requires a square matrix.
-"""
-needs_square_A(::Nothing) = false
-function needs_square_A(alg)
-    try
-        A = [1.0 2.0;
-            3.0 4.0;
-            5.0 6.0]
-        b = ones(Float64, 3)
-        solve(LinearProblem(A, b), alg)
-        return false
-    catch err
-        return true
-    end
-end
-for alg in (:QRFactorization, :FastQRFactorization, NormalCholeskyFactorization,
-    NormalBunchKaufmanFactorization)
-    @eval needs_square_A(::$(alg)) = false
-end
-for kralg in (LinearSolve.Krylov.lsmr!, LinearSolve.Krylov.craigmr!)
-    @eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false
-end
+_needs_square_A(_, ::Number) = true
+_needs_square_A(_, ::StaticArray) = true
+_needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)
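The net effect of the new helper is a small dispatch table: scalar and StaticArray problems keep the square JᵀJ path, and everything else defers to `LinearSolve.needs_square_A` to ask whether the chosen factorization can handle a rectangular `J`. A hedged usage sketch — `ToyAlg` is a made-up stand-in for an algorithm that stores its linear solver in a `linsolve` field, and the exact return values depend on LinearSolve's definitions:

using LinearSolve, StaticArrays

struct ToyAlg{L}
    linsolve::L
end

_needs_square_A(_, ::Number) = true
_needs_square_A(_, ::StaticArray) = true
_needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)

_needs_square_A(ToyAlg(QRFactorization()), zeros(3))           # expected false: QR handles rectangular A
_needs_square_A(ToyAlg(LUFactorization()), zeros(3))           # expected true: LU needs a square A
_needs_square_A(ToyAlg(QRFactorization()), @SVector(zeros(3))) # true: StaticArray path keeps JᵀJ
_needs_square_A(ToyAlg(QRFactorization()), 1.0)                # true: scalar problems keep JᵀJ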
6 changes: 2 additions & 4 deletions test/nonlinear_least_squares.jl
@@ -30,14 +30,12 @@ nlls_problems = [prob_oop, prob_iip]
solvers = [
    GaussNewton(),
    GaussNewton(; linsolve = LUFactorization()),
+    LevenbergMarquardt(),
+    LevenbergMarquardt(; linsolve = LUFactorization()),
    LeastSquaresOptimJL(:lm),
    LeastSquaresOptimJL(:dogleg),
]

-# Compile time on v"1.9" is too high!
-VERSION ≥ v"1.10-" && append!(solvers,
-    [LevenbergMarquardt(), LevenbergMarquardt(; linsolve = LUFactorization())])
-
for prob in nlls_problems, solver in solvers
    @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
    @test SciMLBase.successful_retcode(sol)
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -17,7 +17,10 @@ end
    @time @safetestset "Sparsity Tests" include("sparse.jl")
    @time @safetestset "Polyalgs" include("polyalgs.jl")
    @time @safetestset "Matrix Resizing" include("matrix_resizing.jl")
-    @time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
+    if VERSION ≥ v"1.10-"
+        # Takes too long to compile on older versions
+        @time @safetestset "Nonlinear Least Squares" include("nonlinear_least_squares.jl")
+    end
end

if GROUP == "All" || GROUP == "23TestProblems"
