Skip to content

Commit

Permalink
preserve LM linearsolve structure
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2023
1 parent fabd33c commit 9d81e46
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 26 deletions.
2 changes: 2 additions & 0 deletions docs/src/solvers/NonlinearSystemSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ features, but have a bit of overhead on very small problems.
Automatic Jacobian Resetting. This is a fast method but unstable for most problems!
- `GeneralKlement()`: Generalization of Klement's Quasi-Newton Method with Line Search and
Automatic Jacobian Resetting. This is a fast method but unstable for most problems!
- `LimitedMemoryBroyden()`: An advanced version of `LBroyden` which uses a limited memory
Broyden method. This is a fast method but unstable for most problems!

### SimpleNonlinearSolve.jl

Expand Down
35 changes: 19 additions & 16 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))

# Build Jacobian Caches
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip};
linsolve_kwargs = (;),
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ}
linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true),
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init}
uf = JacobianWrapper{iip}(f, p)

haslinsolve = hasfield(typeof(alg), :linsolve)
Expand Down Expand Up @@ -95,25 +95,28 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
Jᵀfu = J' * _vec(fu)
end

linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

if alg isa PseudoTransient
alpha = convert(eltype(u), alg.alpha_initial)
J_new = J - (1 / alpha) * I
linprob = LinearProblem(J_new, _vec(fu); u0 = _vec(du))
if linsolve_init
linprob_A = alg isa PseudoTransient ?

Check warning on line 99 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L98-L99

Added lines #L98 - L99 were not covered by tests
(J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) :
(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J)
linsolve = __setup_linsolve(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg)

Check warning on line 102 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L102

Added line #L102 was not covered by tests
else
linsolve = nothing

Check warning on line 104 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L104

Added line #L104 was not covered by tests
end

needsJᵀJ && return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu
return uf, linsolve, J, fu, jac_cache, du

Check warning on line 108 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end

function __setup_linsolve(A, b, u, p, alg)
linprob = LinearProblem(A, _vec(b); u0 = _vec(u))

Check warning on line 112 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L111-L112

Added lines #L111 - L112 were not covered by tests

weight = similar(u)
recursivefill!(weight, true)

Pl, Pr = wrapprecs(alg.precs(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J, nothing, u, p,
nothing, nothing, nothing, nothing, nothing)..., weight)
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)

needsJᵀJ && return uf, linsolve, J, fu, jac_cache, du, JᵀJ, Jᵀfu
return uf, linsolve, J, fu, jac_cache, du
Pl, Pr = wrapprecs(alg.precs(A, nothing, u, p, nothing, nothing, nothing, nothing,

Check warning on line 117 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L117

Added line #L117 was not covered by tests
nothing)..., weight)
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)

Check warning on line 119 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L119

Added line #L119 was not covered by tests
end

__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
Expand Down
16 changes: 8 additions & 8 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ solves.
linesearch
end

function set_linsolve(alg::GeneralKlement, linsolve)
return GeneralKlement(alg.max_resets, linsolve, alg.precs, alg.linesearch)

Check warning on line 31 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L30-L31

Added lines #L30 - L31 were not covered by tests
end

function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,

Check warning on line 34 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L34

Added line #L34 was not covered by tests
linesearch = LineSearch(), precs = DEFAULT_PRECS)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
Expand Down Expand Up @@ -60,7 +64,7 @@ end

get_fu(cache::GeneralKlementCache) = cache.fu

Check warning on line 65 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L65

Added line #L65 was not covered by tests

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlement, args...;
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...;

Check warning on line 67 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L67

Added line #L67 was not covered by tests
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
Expand All @@ -70,17 +74,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlemen

if u isa Number
linsolve = nothing
alg = alg_

Check warning on line 77 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L75-L77

Added lines #L75 - L77 were not covered by tests
else
# For General Julia Arrays default to LU Factorization
linsolve_alg = alg.linsolve === nothing && u isa Array ? LUFactorization() :

Check warning on line 80 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L80

Added line #L80 was not covered by tests
nothing
weight = similar(u)
recursivefill!(weight, true)
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu))
linsolve = init(linprob, linsolve_alg; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)
alg = set_linsolve(alg_, linsolve_alg)
linsolve = __setup_linsolve(J, _vec(fu), _vec(u), p, alg)

Check warning on line 83 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L82-L83

Added lines #L82 - L83 were not covered by tests
end

return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,

Check warning on line 86 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L86

Added line #L86 was not covered by tests
Expand Down
6 changes: 4 additions & 2 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,12 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
mat_tmp = zero(JᵀJ)
rhs_tmp = nothing

Check warning on line 214 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L212-L214

Added lines #L212 - L214 were not covered by tests
else
mat_tmp = similar(JᵀJ, length(fu1) + length(u), length(u))
# Preserve Types
mat_tmp = vcat(J, DᵀD)
fill!(mat_tmp, zero(eltype(u)))
rhs_tmp = similar(mat_tmp, length(fu1) + length(u))
rhs_tmp = vcat(fu1, u)
fill!(rhs_tmp, zero(eltype(u)))
linsolve = __setup_linsolve(mat_tmp, rhs_tmp, u, p, alg)

Check warning on line 221 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L217-L221

Added lines #L217 - L221 were not covered by tests
end

return LevenbergMarquardtCache{iip, !_unwrap_val(linsolve_with_JᵀJ)}(f, alg, u, fu1,

Check warning on line 224 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L224

Added line #L224 was not covered by tests
Expand Down

0 comments on commit 9d81e46

Please sign in to comment.