Skip to content

Commit

Permalink
Reuse LU Factorization to check for singular matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 22, 2023
1 parent 00852f0 commit 4ce655d
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 29 deletions.
3 changes: 3 additions & 0 deletions docs/src/api/nonlinearsolve.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ These are the native solvers of NonlinearSolve.jl.
NewtonRaphson
TrustRegion
PseudoTransient
DFSane
GeneralBroyden
GeneralKlement
```

## Polyalgorithms
Expand Down
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ArrayInterface: restructure
import ForwardDiff

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable
import ArrayInterface: undefmatrix, matrix_colors, parameterless_type, ismutable, issingular
import ConcreteStructs: @concrete
import EnumX: @enumx
import ForwardDiff: Dual
Expand Down
1 change: 0 additions & 1 deletion src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ See also the implementation in [SimpleNonlinearSolve.jl](https://github.com/SciM
- `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
algorithm. Defaults to `1000`.
"""

struct DFSane{T, F} <: AbstractNonlinearSolveAlgorithm
σ_min::T
σ_max::T
Expand Down
66 changes: 46 additions & 20 deletions src/klement.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
"""
GeneralKlement(; max_resets = 5, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS)
An implementation of `Klement` with line search, preconditioning and customizable linear
solves.
## Keyword Arguments
- `max_resets`: the maximum number of resets to perform. Defaults to `5`.
- `linsolve`: the [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) used for the
linear solves within the Newton method. Defaults to `nothing`, which means it uses the
LinearSolve.jl default algorithm choice. For more information on available algorithm
choices, see the [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
"""
@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
max_resets::Int
linsolve
precs
linesearch
singular_tolerance
end

function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,

Check warning on line 30 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L30

Added line #L30 was not covered by tests
linesearch = LineSearch(), precs = DEFAULT_PRECS, singular_tolerance = nothing)
linesearch = LineSearch(), precs = DEFAULT_PRECS)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GeneralKlement(max_resets, linsolve, precs, linesearch, singular_tolerance)
return GeneralKlement(max_resets, linsolve, precs, linesearch)

Check warning on line 33 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L32-L33

Added lines #L32 - L33 were not covered by tests
end

@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand All @@ -27,7 +48,6 @@ end
Jᵀ²du
Jdu
resets
singular_tolerance
force_stop
maxiters::Int
internalnorm
Expand All @@ -51,20 +71,20 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlemen
if u isa Number
linsolve = nothing

Check warning on line 72 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L71-L72

Added lines #L71 - L72 were not covered by tests
else
# For General Julia Arrays default to LU Factorization
linsolve_alg = alg.linsolve === nothing && u isa Array ? LUFactorization() :

Check warning on line 75 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L75

Added line #L75 was not covered by tests
nothing
weight = similar(u)
recursivefill!(weight, true)
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,

Check warning on line 79 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L77-L79

Added lines #L77 - L79 were not covered by tests
nothing)..., weight)
linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu))
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve = init(linprob, linsolve_alg; alias_A = true, alias_b = true, Pl, Pr,

Check warning on line 82 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L81-L82

Added lines #L81 - L82 were not covered by tests
linsolve_kwargs...)
end

singular_tolerance = alg.singular_tolerance === nothing ? inv(sqrt(eps(eltype(u)))) :
eltype(u)(alg.singular_tolerance)

return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,

Check warning on line 86 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L86

Added line #L86 was not covered by tests
J, zero(J), zero(J), zero(fu), zero(fu), 0, singular_tolerance, false,
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
end
Expand All @@ -73,21 +93,23 @@ function perform_step!(cache::GeneralKlementCache{true})
@unpack u, fu, f, p, alg, J, linsolve, du = cache
T = eltype(J)

Check warning on line 94 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L92-L94

Added lines #L92 - L94 were not covered by tests

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)

Check warning on line 96 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L96

Added line #L96 was not covered by tests

if singular
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing

Check warning on line 102 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L98-L102

Added lines #L98 - L102 were not covered by tests
end
fact_done = false
fill!(J, zero(T))
J[diagind(J)] .= T(1)
cache.resets += 1

Check warning on line 107 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L104-L107

Added lines #L104 - L107 were not covered by tests
end

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), linu = _vec(du),
p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),

Check warning on line 111 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L111

Added line #L111 was not covered by tests
b = -_vec(fu), linu = _vec(du), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 113 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L113

Added line #L113 was not covered by tests

# Line Search
Expand All @@ -108,7 +130,8 @@ function perform_step!(cache::GeneralKlementCache{true})
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
mul!(cache.Jdu, J, _vec(du))
cache.fu .= cache.fu2 .- cache.fu
cache.fu .= (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
cache.fu .= _restructure(cache.fu,

Check warning on line 133 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L128-L133

Added lines #L128 - L133 were not covered by tests
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(T)))
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
cache.J_cache .*= J
mul!(cache.J_cache2, cache.J_cache, J)
Expand All @@ -123,23 +146,25 @@ function perform_step!(cache::GeneralKlementCache{false})
@unpack fu, f, p, alg, J, linsolve = cache
T = eltype(J)

Check warning on line 147 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L145-L147

Added lines #L145 - L147 were not covered by tests

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)

Check warning on line 149 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L149

Added line #L149 was not covered by tests

if singular
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing

Check warning on line 155 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L151-L155

Added lines #L151 - L155 were not covered by tests
end
cache.J = __init_identity_jacobian(u, fu)
fact_done = false
cache.J = __init_identity_jacobian(cache.u, fu)
cache.resets += 1

Check warning on line 159 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L157-L159

Added lines #L157 - L159 were not covered by tests
end

# u = u - J \ fu
if linsolve === nothing
cache.du = -fu / cache.J

Check warning on line 164 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L163-L164

Added lines #L163 - L164 were not covered by tests
else
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu),
linu = _vec(cache.du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),

Check warning on line 166 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L166

Added line #L166 was not covered by tests
b = -_vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache

Check warning on line 168 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L168

Added line #L168 was not covered by tests
end

Expand All @@ -161,7 +186,8 @@ function perform_step!(cache::GeneralKlementCache{false})
cache.Jᵀ²du = cache.J_cache * cache.Jdu
cache.Jdu = J * _vec(cache.du)
cache.fu = cache.fu2 .- cache.fu
cache.fu = (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
cache.fu = _restructure(cache.fu,

Check warning on line 189 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L184-L189

Added lines #L184 - L189 were not covered by tests
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(T)))
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
cache.J = J .+ cache.J_cache

Check warning on line 192 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L191-L192

Added lines #L191 - L192 were not covered by tests

Expand Down
1 change: 0 additions & 1 deletion src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ states as `RadiusUpdateSchemes.T`. Simply put the desired scheme as follows:
end

"""
```julia
TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
radius_update_scheme::RadiusUpdateSchemes.T = RadiusUpdateSchemes.Simple,
max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1,
Expand Down
25 changes: 25 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,28 @@ function __init_identity_jacobian(u::StaticArray, fu)
return convert(MArray{Tuple{length(fu), length(u)}},

Check warning on line 221 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L220-L221

Added lines #L220 - L221 were not covered by tests
Matrix{eltype(u)}(I, length(fu), length(u)))
end

# Check Singular Matrix
_issingular(x::Number) = iszero(x)
@generated function _issingular(x::T) where {T}
hasmethod(issingular, Tuple{T}) && return :(issingular(x))
return :(__issingular(x))

Check warning on line 229 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L226-L229

Added lines #L226 - L229 were not covered by tests
end
__issingular(x::AbstractMatrix{T}) where {T} = cond(x) > inv(sqrt(eps(T)))
__issingular(x) = false ## If SciMLOperator and such

Check warning on line 232 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L231-L232

Added lines #L231 - L232 were not covered by tests

# If factorization is LU then perform that and update the linsolve cache
# else check if the matrix is singular
function _try_factorize_and_check_singular!(linsolve, X)
if linsolve.cacheval isa LU

Check warning on line 237 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L236-L237

Added lines #L236 - L237 were not covered by tests
# LU Factorization was used
linsolve.A = X
linsolve.cacheval = LinearSolve.do_factorization(linsolve.alg, X, linsolve.b,

Check warning on line 240 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L239-L240

Added lines #L239 - L240 were not covered by tests
linsolve.u)
linsolve.isfresh = false

Check warning on line 242 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L242

Added line #L242 was not covered by tests

return !issuccess(linsolve.cacheval), true

Check warning on line 244 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L244

Added line #L244 was not covered by tests
end
return _issingular(X), false

Check warning on line 246 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L246

Added line #L246 was not covered by tests
end
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false

Check warning on line 248 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L248

Added line #L248 was not covered by tests
10 changes: 6 additions & 4 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,21 @@ end

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 3, 4, 5, 6, 8, 11, 12, 13, 14, 21]
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 3, 4, 5, 6, 9, 11, 13, 16, 21, 22]
broken_tests[alg_ops[3]] = [1, 2, 4, 5, 6, 8, 11, 12, 13, 14, 21]

test_on_library(problems, dicts, alg_ops, broken_tests)
end

@testset "GeneralKlement 23 Test Problems" begin
alg_ops = (GeneralKlement(),
GeneralKlement(; linesearch = BackTracking()))
GeneralKlement(; linesearch = BackTracking()),
GeneralKlement(; linesearch = HagerZhang()))

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 2, 3, 4, 5, 6, 7, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 12, 22]
broken_tests[alg_ops[1]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 13, 22]
broken_tests[alg_ops[3]] = [1, 2, 5, 6, 11, 12, 13, 22]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
4 changes: 2 additions & 2 deletions test/matrix_resizing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ vecprob = NonlinearProblem(ff, vec(u0), p)
prob = NonlinearProblem(ff, u0, p)

for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden())
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden(), GeneralKlement())
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
end

Expand All @@ -18,6 +18,6 @@ vecprob = NonlinearProblem(fiip, vec(u0), p)
prob = NonlinearProblem(fiip, u0, p)

for alg in (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(), PseudoTransient(),
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden())
RobustMultiNewton(), FastShortcutNonlinearPolyalg(), GeneralBroyden(), GeneralKlement())
@test vec(solve(prob, alg).u) == solve(vecprob, alg).u
end

0 comments on commit 4ce655d

Please sign in to comment.