Skip to content

Commit

Permalink
Make LM and GN oop versions work with LinearSolve.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 8, 2023
1 parent 57238ac commit af3e026
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 115 deletions.
7 changes: 3 additions & 4 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg::G
else
fu1 = f(u, p)

Check warning on line 83 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L83

Added line #L83 was not covered by tests
end
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip))

JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
Jᵀf = zero(u)
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip);

Check warning on line 85 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L85

Added line #L85 was not covered by tests
linsolve_with_JᵀJ = Val(true))

return GaussNewtonCache{iip}(f, alg, u, fu1, fu2, zero(fu1), du, p, uf, linsolve, J,

Check warning on line 88 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L88

Added line #L88 was not covered by tests
JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol,
Expand Down Expand Up @@ -120,6 +118,7 @@ function perform_step!(cache::GaussNewtonCache{false})
@unpack u, fu1, f, p, alg, linsolve = cache

Check warning on line 118 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L117-L118

Added lines #L117 - L118 were not covered by tests

cache.J = jacobian!!(cache.J, cache)

Check warning on line 120 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L120

Added line #L120 was not covered by tests

cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1

Check warning on line 123 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L122-L123

Added lines #L122 - L123 were not covered by tests
# u = u - J \ fu
Expand Down
18 changes: 16 additions & 2 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))

# Build Jacobian Caches
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip};
linsolve_kwargs = (;)) where {iip}
linsolve_kwargs = (;),
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ}
uf = JacobianWrapper{iip}(f, p)

haslinsolve = hasfield(typeof(alg), :linsolve)
Expand Down Expand Up @@ -85,7 +86,15 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
end

du = _mutable_zero(u)
linprob = LinearProblem(J, _vec(fu); u0 = _vec(du))

if needsJᵀJ
JᵀJ = __init_JᵀJ(J)
# FIXME: This needs to be handled better for JacVec Operator
Jᵀfu = J' * fu
end

linprob = LinearProblem(needsJᵀJ ? JᵀJ : J, needsJᵀJ ? _vec(Jᵀfu) : _vec(fu);
u0 = _vec(du))

weight = similar(u)
recursivefill!(weight, true)
Expand All @@ -95,6 +104,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

needsJᵀJ && return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu
return uf, linsolve, J, fu, jac_cache, du
end

Expand All @@ -103,6 +113,10 @@ __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
# An `AutoSparseZygote` selector is swapped for a freshly constructed `AutoZygote()`.
function __get_nonsparse_ad(::AutoSparseZygote)
    return AutoZygote()
end
# Fallback: any other selector is handed back unchanged.
function __get_nonsparse_ad(ad)
    return ad
end

# Scalar Jacobian: the initial `JᵀJ` value is the zero of `J`'s own numeric type.
function __init_JᵀJ(J::Number)
    return zero(J)
end

Check warning on line 116 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L116

Added line #L116 was not covered by tests
# Array Jacobian: allocate a zero-filled `size(J, 2) × size(J, 2)` buffer for `JᵀJ`.
# The buffer is created with `similar(J, ...)` instead of `zeros(...)` so that it stays
# in the same array family as `J` (e.g. a sparse `J` yields a sparse buffer) rather than
# always being a dense `Matrix`. For a plain `Matrix` input the result is unchanged.
function __init_JᵀJ(J::AbstractArray)
    n = size(J, 2)
    return fill!(similar(J, n, n), zero(eltype(J)))
end
# Static-array Jacobian: allocate a mutable `size(J, 2) × size(J, 2)` `MArray` for `JᵀJ`.
# The buffer is zero-filled so callers never observe uninitialized memory, in line with
# the zero-initialized values returned by the `Number` and `AbstractArray` methods.
function __init_JᵀJ(J::StaticArray)
    n = size(J, 2)
    JᵀJ = MArray{Tuple{n, n}, eltype(J)}(undef)
    return fill!(JᵀJ, zero(eltype(J)))
end

Check warning on line 118 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L118

Added line #L118 was not covered by tests

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
::Val{false}; kwargs...)
Expand Down
84 changes: 46 additions & 38 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ numerically-difficult nonlinear systems.
where `J` is the Jacobian. It is suggested by
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
!!! warning
`linsolve` and `precs` are used exclusively for the inplace version of the algorithm.
Support for the OOP version is planned!
"""
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
Expand All @@ -102,18 +97,17 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end

@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} <:
AbstractNonlinearSolveCache{iip}
@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u::uType
u
fu1
fu2
du
p
uf
linsolve
J::jType
J
jac_cache
force_stop::Bool
maxiters::Int
Expand All @@ -122,27 +116,27 @@ end
abstol
prob
DᵀD
JᵀJ::jType
λ::λType
λ_factor::λType
damping_increase_factor::λType
damping_decrease_factor::λType
h::λType
α_geodesic::λType
b_uphill::λType
min_damping_D::λType
v::uType
a::uType
tmp_vec::uType
v_old::uType
norm_v_old::lossType
δ::uType
loss_old::lossType
JᵀJ
λ
λ_factor
damping_increase_factor
damping_decrease_factor
h
α_geodesic
b_uphill
min_damping_D
v
a
tmp_vec
v_old
norm_v_old
δ
loss_old
make_new_J::Bool
fu_tmp
u_tmp
Jv
mat_tmp::jType
mat_tmp
stats::NLStats
end

Expand All @@ -153,8 +147,8 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)
uf, linsolve, J, fu2, jac_cache, du, JᵀJ, v = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs, linsolve_with_JᵀJ=Val(true))

λ = convert(eltype(u), alg.damping_initial)
λ_factor = convert(eltype(u), alg.damping_increase_factor)
Expand All @@ -174,12 +168,10 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
end

loss = internalnorm(fu1)
JᵀJ = J isa Number ? zero(J) : similar(J, size(J, 2), size(J, 2))
v = zero(u)
a = zero(u)
tmp_vec = zero(u)
v_old = zero(u)
δ = zero(u)
a = _mutable_zero(u)
tmp_vec = _mutable_zero(u)
v_old = _mutable_zero(u)
δ = _mutable_zero(u)
make_new_J = true
fu_tmp = zero(fu1)
mat_tmp = zero(JᵀJ)
Expand Down Expand Up @@ -223,7 +215,8 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
# The following lines do: cache.a = -cache.mat_tmp \ (J' * cache.fu_tmp)
mul!(cache.Jv, J, v)
@. cache.fu_tmp = (2 / h) * ((cache.fu_tmp - fu1) / h - cache.Jv)
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.fu_tmp),
mul!(cache.u_tmp, J', cache.fu_tmp)
linres = dolinsolve(alg.precs, linsolve; A = cache.mat_tmp, b = _vec(cache.u_tmp),
linu = _vec(cache.du), p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. cache.a = -cache.du
Expand Down Expand Up @@ -279,15 +272,30 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J = cache
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

Check warning on line 275 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L275

Added line #L275 was not covered by tests

cache.mat_tmp = JᵀJ + λ * DᵀD
# Usual Levenberg-Marquardt step ("velocity").
cache.v = -cache.mat_tmp \ (J' * fu1)
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)

Check warning on line 280 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L279-L280

Added lines #L279 - L280 were not covered by tests
else
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp, b = _vec(J' * fu1),

Check warning on line 282 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L282

Added line #L282 was not covered by tests
linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 284 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L284

Added line #L284 was not covered by tests
end

@unpack v, h, α_geodesic = cache
# Geodesic acceleration (step_size = v + a / 2).
cache.a = -cache.mat_tmp \ ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v))
if linsolve === nothing
cache.a = -cache.mat_tmp \

Check warning on line 290 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L289-L290

Added lines #L289 - L290 were not covered by tests
_vec(J' * ((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))
else
linres = dolinsolve(alg.precs, linsolve; A = -cache.mat_tmp,

Check warning on line 293 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L293

Added line #L293 was not covered by tests
b = _mutable(_vec(J' *
((2 / h) .* ((f(u .+ h .* v, p) .- fu1) ./ h .- J * v)))),
linu = _vec(cache.a), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 297 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L297

Added line #L297 was not covered by tests
end
cache.stats.nsolve += 1
cache.stats.nfactors += 1

Expand Down
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ _mutable_zero(x::SArray) = MArray(x)

# Fallback: a value with no more specific `_mutable` method is returned as-is.
function _mutable(x)
    return x
end
# An `SArray` is passed through the `MArray` constructor to obtain a mutable array.
function _mutable(x::SArray)
    return MArray(x)
end

# Finite-differencing modes receive `_mutable(x)`.
function _maybe_mutable(x, ::AbstractFiniteDifferencesMode)
    return _mutable(x)
end
# The shadow allocated for Enzyme needs to be mutable
function _maybe_mutable(x, ::AutoSparseEnzyme)
    return _mutable(x)
end
Expand Down
13 changes: 7 additions & 6 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
using NonlinearSolve, LinearAlgebra, NonlinearProblemLibrary, Test
using NonlinearSolve, LinearAlgebra, LinearSolve, NonlinearProblemLibrary, Test

problems = NonlinearProblemLibrary.problems
dicts = NonlinearProblemLibrary.dicts

function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-5)
function test_on_library(problems, dicts, alg_ops, broken_tests, ϵ = 1e-4)
for (idx, (problem, dict)) in enumerate(zip(problems, dicts))
x = dict["start"]
res = similar(x)
nlprob = NonlinearProblem(problem, x)
@testset "$(dict["title"])" begin
for alg in alg_ops
sol = solve(nlprob, alg, abstol = 1e-15, reltol = 1e-15)
sol = solve(nlprob, alg, abstol = 1e-18, reltol = 1e-18)
problem(res, sol.u, nothing)
broken = idx in broken_tests[alg] ? true : false
@test norm(res) ≤ ϵ broken=broken
Expand Down Expand Up @@ -43,19 +43,20 @@ end
broken_tests[alg_ops[1]] = [6, 11, 21]
broken_tests[alg_ops[2]] = [6, 11, 21]
broken_tests[alg_ops[3]] = [1, 6, 11, 12, 15, 16, 21]
broken_tests[alg_ops[4]] = [1, 6, 8, 11, 15, 16, 21, 22]
broken_tests[alg_ops[4]] = [1, 6, 8, 11, 16, 21, 22]
broken_tests[alg_ops[5]] = [6, 21]
broken_tests[alg_ops[6]] = [6, 21]

test_on_library(problems, dicts, alg_ops, broken_tests)
end

@testset "TrustRegion test problem library" begin
alg_ops = (LevenbergMarquardt(), LevenbergMarquardt(; α_geodesic = 0.5))
alg_ops = (LevenbergMarquardt(; linsolve=NormalCholeskyFactorization()),
LevenbergMarquardt(; α_geodesic = 0.1, linsolve=NormalCholeskyFactorization()))

# dictionary with indices of test problems where method does not converge to small residual
broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [3, 6, 11, 17, 21]
broken_tests[alg_ops[1]] = [3, 6, 11, 21]
broken_tests[alg_ops[2]] = [3, 6, 11, 21]

test_on_library(problems, dicts, alg_ops, broken_tests)
Expand Down
Loading

0 comments on commit af3e026

Please sign in to comment.