Skip to content

Commit

Permalink
Use symmetric linear solve if possible
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 15, 2023
1 parent 8f68ef1 commit f151a0a
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 21 deletions.
16 changes: 8 additions & 8 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
adkwargs...)
GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
Expand Down Expand Up @@ -41,8 +41,8 @@ for large-scale and numerically-difficult nonlinear least squares problems.
precs
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
adkwargs...)
function GaussNewton(; concrete_jac = nothing, linsolve = CholeskyFactorization(),
precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
end
Expand Down Expand Up @@ -97,8 +97,8 @@ function perform_step!(cache::GaussNewtonCache{true})
__matmul!(Jᵀf, J', fu1)

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = JᵀJ, b = _vec(Jᵀf), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(JᵀJ), b = _vec(Jᵀf),
linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du
f(cache.fu_new, u, p)
Expand All @@ -125,8 +125,8 @@ function perform_step!(cache::GaussNewtonCache{false})
if linsolve === nothing
cache.du = fu1 / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = cache.JᵀJ, b = _vec(cache.Jᵀf),
linu = _vec(cache.du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.JᵀJ),
b = _vec(cache.Jᵀf), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
Expand Down
14 changes: 10 additions & 4 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
Jᵀfu = J' * fu
end

linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
u0 = _vec(du))
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

weight = similar(u)
recursivefill!(weight, true)

Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
nothing, nothing, nothing, nothing, nothing)..., weight)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

Expand All @@ -119,6 +119,12 @@ __init_JᵀJ(J::Number) = zero(J)
# Build the initial `JᵀJ` storage for a generic array Jacobian by evaluating the product
# `J' * J` directly.
function __init_JᵀJ(J::AbstractArray)
    return J' * J
end

# For a `StaticArray` Jacobian, return an uninitialized mutable static array whose two
# dimensions both equal the number of columns of `J`, with the same element type as `J`.
# The contents are `undef` and must be written before they are read.
function __init_JᵀJ(J::StaticArray)
    return MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
end

# Wrap a normal-equations matrix so the linear solver can exploit its structure.
#
# The matrix is formed with the adjoint (`J' * J`), so for real element types it is
# symmetric, but for non-real element types it is Hermitian. Wrapping a complex Hermitian
# matrix in `Symmetric` would mirror the upper triangle without conjugation and silently
# change the matrix, so `Hermitian` is used in that case.
#
# The wrapper is chosen from `eltype(x)` rather than through extra dispatch signatures so
# that no method ambiguities are introduced with the other `__maybe_symmetric` methods.
__maybe_symmetric(x) = eltype(x) <: Real ? Symmetric(x) : Hermitian(x)
# Scalars are returned unchanged; no wrapper is needed.
__maybe_symmetric(x::Number) = x

Check warning on line 123 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L123

Added line #L123 was not covered by tests
# Static arrays and sparse matrices are passed through without any wrapper.
# LinearSolve with `nothing` doesn't dispatch correctly here
function __maybe_symmetric(x::StaticArray)
    return x
end

function __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix)
    return x
end

Check warning on line 126 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L126

Added line #L126 was not covered by tests

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
Expand Down
14 changes: 7 additions & 7 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ numerically-difficult nonlinear systems.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
`AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
`AutoForwardDiff()`. Valid choices are types from ADTypes.jl.
- `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is used,
then the Jacobian will not be constructed and instead direct Jacobian-vector products
`J*v` are computed using forward-mode automatic differentiation or finite differencing
Expand Down Expand Up @@ -203,8 +203,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
mul!(cache.u_tmp, J', fu1)
@. cache.mat_tmp = JᵀJ + λ * DᵀD
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _vec(cache.u_tmp), linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.v = -cache.du

Expand Down Expand Up @@ -280,8 +280,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)
else
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, b = _vec(J' * fu1),
linu = _vec(cache.v), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
b = _vec(J' * fu1), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end

Expand All @@ -291,7 +291,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.a = -cache.mat_tmp \
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
else
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp,
linres = dolinsolve(alg.precs, linsolve;
b = _mutable(_vec(J' *
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
linu = _vec(cache.a), p, reltol = cache.abstol)
Expand Down
9 changes: 7 additions & 2 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]
solvers = [GaussNewton(), LevenbergMarquardt(), LSOptimSolver(:lm), LSOptimSolver(:dogleg)]
solvers = [
GaussNewton(),
LevenbergMarquardt(),
LSOptimSolver(:lm),
LSOptimSolver(:dogleg),
]

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 1000, abstol = 1e-8)
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid) < 1e-6
end

0 comments on commit f151a0a

Please sign in to comment.