Skip to content

Commit

Permalink
Reuse Klement Code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 30, 2023
1 parent a5c6195 commit f147663
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 136 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand Down Expand Up @@ -75,7 +75,6 @@ SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.14"
StaticArrays = "1"
StaticArraysCore = "1.4"
Symbolics = "5"
Test = "1"
UnPack = "1.0"
Expand All @@ -99,7 +98,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down
9 changes: 3 additions & 6 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ import Reexport: @reexport
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

@recompile_invalidations begin
using DiffEqBase,
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
SparseDiffTools
using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf,
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix, restructure, can_setindex,
Expand All @@ -26,10 +25,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
AbstractVectorOfArray, recursivecopy!, recursivefill!
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import SciMLOperators: FunctionOperator
import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
import UnPack: @unpack

using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
end

@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
Expand Down
173 changes: 52 additions & 121 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,38 +74,43 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip, F}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
u = __maybe_unaliased(u0, alias_u0)
fu = evaluate_f(prob, u)
J = __init_identity_jacobian(u, fu)
du = _mutable_zero(u)
@bb du = similar(u)

if u isa Number
linsolve = nothing
linsolve = FakeLinearSolveJLCache(J, fu)
alg = alg_
else
# For General Julia Arrays default to LU Factorization
linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() :
nothing
linsolve_alg = (alg_.linsolve === nothing && (u isa Array || u isa StaticArray)) ?
LUFactorization() : nothing
alg = set_linsolve(alg_, linsolve_alg)
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg)
linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg; linsolve_kwargs)
end

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, J, du; kwargs...)

return GeneralKlementCache{iip}(f, alg, u, zero(u), fu, zero(fu), du, p, linsolve,
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
NLStats(1, 0, 0, 0, 0),
@bb u_prev = copy(u)
@bb fu2 = similar(fu)
@bb J_cache = similar(J)
@bb J_cache2 = similar(J)
@bb Jᵀ²du = similar(fu)
@bb Jdu = similar(fu)

return GeneralKlementCache{iip}(f, alg, u, u_prev, fu, fu2, du, p, linsolve, J, J_cache,
J_cache2, Jᵀ²du, Jdu, 0, false, maxiters, internalnorm, ReturnCode.Default, abstol,
reltol, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
end

function perform_step!(cache::GeneralKlementCache{true})
@unpack u, u_prev, fu, f, p, alg, J, linsolve, du = cache
T = eltype(J)

singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
@unpack linsolve, alg = cache
T = eltype(cache.J)
singular, fact_done = __try_factorize_and_check_singular!(linsolve, cache.J)

if singular
if cache.resets == alg.max_resets
Expand All @@ -114,135 +119,61 @@ function perform_step!(cache::GeneralKlementCache{true})
return nothing
end
fact_done = false
fill!(J, zero(T))
J[diagind(J)] .= T(1)
cache.J = __reinit_identity_jacobian!!(cache.J)
cache.resets += 1
end

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
b = _vec(fu), linu = _vec(du), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu),
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
cache.linsolve = linres.cache

# Line Search
α = perform_linesearch!(cache.ls_cache, u, du)
_axpy!(-α, du, u)
f(cache.fu2, u, p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, J,
cache.du, α)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1

cache.force_stop && return nothing

# Update the Jacobian
cache.du .*= -1
cache.J_cache .= cache.J' .^ 2
cache.Jdu .= _vec(du) .^ 2
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
mul!(cache.Jdu, J, _vec(du))
cache.fu .= cache.fu2 .- cache.fu
cache.fu .= _restructure(cache.fu,
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(real(T))))
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
cache.J_cache .*= J
mul!(cache.J_cache2, cache.J_cache, J)
J .+= cache.J_cache2

@. u_prev = u
cache.fu .= cache.fu2

return nothing
end

function perform_step!(cache::GeneralKlementCache{false})
@unpack fu, f, p, alg, J, linsolve = cache
!iip && (cache.du = linres.u)

T = eltype(J)

singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)

if singular
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.ConvergenceFailure
return nothing
end
fact_done = false
cache.J = __init_identity_jacobian(cache.u, fu)
cache.resets += 1
end
# Line Search
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
@bb axpy!(-α, cache.du, cache.u)

# u = u - J \ fu
if linsolve === nothing
cache.du = fu / cache.J
if iip
cache.f(cache.fu2, cache.u, cache.p)
else
linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J),
b = _vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
cache.fu2 = cache.f(cache.u, cache.p)
end

# Line Search
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
cache.u = @. cache.u - α * cache.du # `u` might not support mutation
cache.fu2 = f(cache.u, p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, J,
update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, cache.J,
cache.du, α)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
cache.u_prev = cache.u
@bb copyto!(cache.u_prev, cache.u)

cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1

cache.force_stop && return nothing

# Update the Jacobian
cache.du = -cache.du
cache.J_cache = cache.J' .^ 2
cache.Jdu = _vec(cache.du) .^ 2
cache.Jᵀ²du = cache.J_cache * cache.Jdu
cache.Jdu = J * _vec(cache.du)
cache.fu = cache.fu2 .- cache.fu
cache.fu = _restructure(cache.fu,
(_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(real(T))))
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
cache.J = J .+ cache.J_cache

cache.fu = cache.fu2
@bb cache.du .*= -1
@bb cache.J_cache .= cache.J' .^ 2
@bb @. cache.Jdu = cache.du ^ 2
@bb cache.Jᵀ²du = cache.J_cache × vec(cache.Jdu)
@bb cache.Jdu = cache.J × vec(cache.du)
@bb @. cache.fu = cache.fu2 - cache.fu

return nothing
end
@bb @. cache.fu = (cache.fu - cache.Jdu) / max(cache.Jᵀ²du, eps(real(T)))

function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
@bb cache.J_cache = vec(cache.fu) × transpose(_vec(cache.du))
@bb @. cache.J_cache *= cache.J
@bb cache.J_cache2 = cache.J_cache × cache.J
@bb cache.J .+= cache.J_cache2

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)
@bb copyto!(cache.fu, cache.fu2)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
return nothing
end

function __reinit_internal!(cache::GeneralKlementCache)
cache.J = __reinit_identity_jacobian!!(cache.J)
cache.resets = 0
return nothing
end
10 changes: 4 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,8 @@ end

# If factorization is LU then perform that and update the linsolve cache
# else check if the matrix is singular
function _try_factorize_and_check_singular!(linsolve, X)
if linsolve.cacheval isa LU
function __try_factorize_and_check_singular!(linsolve, X)
if linsolve.cacheval isa LU || linsolve.cacheval isa StaticArrays.LU
# LU Factorization was used
linsolve.A = X
linsolve.cacheval = LinearSolve.do_factorization(linsolve.alg, X, linsolve.b,
Expand All @@ -368,11 +368,9 @@ function _try_factorize_and_check_singular!(linsolve, X)
end
return _issingular(X), false
end
_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false

@inline _reshape(x, args...) = reshape(x, args...)
@inline _reshape(x::Number, args...) = x
__try_factorize_and_check_singular!(::FakeLinearSolveJLCache, x) = _issingular(x), false

# TODO: Remove. handled in MaybeInplace.jl
@generated function _axpy!(α, x, y)
hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y))
return :(@. y += α * x)
Expand Down

0 comments on commit f147663

Please sign in to comment.