From 62266d97292373791810a3129dfd187b0568ba9e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 30 Nov 2023 01:14:46 -0500 Subject: [PATCH] Reuse Klement Code --- Project.toml | 4 +- src/NonlinearSolve.jl | 9 +-- src/klement.jl | 174 +++++++++++++----------------------------- src/utils.jl | 10 +-- 4 files changed, 61 insertions(+), 136 deletions(-) diff --git a/Project.toml b/Project.toml index 8a42f9d21..9385b14a2 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] @@ -75,7 +75,6 @@ SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" StaticArrays = "1" -StaticArraysCore = "1.4" Symbolics = "5" Test = "1" UnPack = "1.0" @@ -99,7 +98,6 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 9096525ee..63987898b 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,9 +8,8 @@ import Reexport: @reexport import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload @recompile_invalidations begin - using DiffEqBase, - LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays, - SparseDiffTools + using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf, + SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays import ADTypes: AbstractFiniteDifferencesMode import ArrayInterface: undefmatrix, restructure, can_setindex, @@ -26,10 +25,8 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work AbstractVectorOfArray, recursivecopy!, recursivefill! import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace import SciMLOperators: FunctionOperator - import StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix + import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix import UnPack: @unpack - - using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve end @reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve diff --git a/src/klement.jl b/src/klement.jl index 8a9640fd4..261348bf6 100644 --- a/src/klement.jl +++ b/src/klement.jl @@ -74,38 +74,44 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme termination_condition = nothing, internalnorm::F = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip, F} @unpack f, u0, p = prob - u = alias_u0 ? u0 : deepcopy(u0) + u = __maybe_unaliased(u0, alias_u0) fu = evaluate_f(prob, u) J = __init_identity_jacobian(u, fu) - du = _mutable_zero(u) + @bb du = similar(u) if u isa Number - linsolve = nothing + linsolve = FakeLinearSolveJLCache(J, fu) alg = alg_ else # For General Julia Arrays default to LU Factorization - linsolve_alg = alg_.linsolve === nothing && u isa Array ? LUFactorization() : - nothing + linsolve_alg = (alg_.linsolve === nothing && (u isa Array || u isa StaticArray)) ? + LUFactorization() : nothing alg = set_linsolve(alg_, linsolve_alg) - linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg) + linsolve = linsolve_caches(J, _vec(fu), _vec(du), p, alg; linsolve_kwargs) end abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u, termination_condition) trace = init_nonlinearsolve_trace(alg, u, fu, J, du; kwargs...) - return GeneralKlementCache{iip}(f, alg, u, zero(u), fu, zero(fu), du, p, linsolve, - J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false, - maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, - NLStats(1, 0, 0, 0, 0), + @bb u_prev = copy(u) + @bb fu2 = similar(fu) + @bb J_cache = similar(J) + @bb J_cache2 = similar(J) + @bb Jᵀ²du = similar(fu) + @bb Jdu = similar(fu) + + return GeneralKlementCache{iip}(f, alg, u, u_prev, fu, fu2, du, p, linsolve, J, J_cache, + J_cache2, Jᵀ²du, Jdu, 0, false, maxiters, internalnorm, ReturnCode.Default, abstol, + reltol, prob, NLStats(1, 0, 0, 0, 0), init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace) end -function perform_step!(cache::GeneralKlementCache{true}) - @unpack u, u_prev, fu, f, p, alg, J, linsolve, du = cache - T = eltype(J) - - singular, fact_done = _try_factorize_and_check_singular!(linsolve, J) +function perform_step!(cache::GeneralKlementCache{iip}) where {iip} + @unpack linsolve, alg = cache + # @unpack fu, f, p, alg, J, linsolve = cache + T = eltype(cache.J) + singular, fact_done = __try_factorize_and_check_singular!(linsolve, cache.J) if singular if cache.resets == alg.max_resets @@ -114,88 +120,33 @@ function perform_step!(cache::GeneralKlementCache{true}) return nothing end fact_done = false - fill!(J, zero(T)) - J[diagind(J)] .= T(1) + cache.J = __reinit_identity_jacobian!!(cache.J) cache.resets += 1 end # u = u - J \ fu - linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J), - b = _vec(fu), linu = _vec(du), p, reltol = cache.abstol) + linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu), + linu = _vec(cache.du), cache.p, reltol = cache.abstol) cache.linsolve = linres.cache - # Line Search - α = perform_linesearch!(cache.ls_cache, u, du) - _axpy!(-α, du, u) - f(cache.fu2, u, p) - - update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, J, - cache.du, α) - - check_and_update!(cache, cache.fu2, cache.u, cache.u_prev) - cache.stats.nf += 1 - cache.stats.nsolve += 1 - cache.stats.nfactors += 1 - - cache.force_stop && return nothing - - # Update the Jacobian - cache.du .*= -1 - cache.J_cache .= cache.J' .^ 2 - cache.Jdu .= _vec(du) .^ 2 - mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu) - mul!(cache.Jdu, J, _vec(du)) - cache.fu .= cache.fu2 .- cache.fu - cache.fu .= _restructure(cache.fu, - (_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(real(T)))) - mul!(cache.J_cache, _vec(cache.fu), _vec(du)') - cache.J_cache .*= J - mul!(cache.J_cache2, cache.J_cache, J) - J .+= cache.J_cache2 - - @. u_prev = u - cache.fu .= cache.fu2 - - return nothing -end - -function perform_step!(cache::GeneralKlementCache{false}) - @unpack fu, f, p, alg, J, linsolve = cache + !iip && (cache.du = linres.u) - T = eltype(J) - - singular, fact_done = _try_factorize_and_check_singular!(linsolve, J) - - if singular - if cache.resets == alg.max_resets - cache.force_stop = true - cache.retcode = ReturnCode.ConvergenceFailure - return nothing - end - fact_done = false - cache.J = __init_identity_jacobian(cache.u, fu) - cache.resets += 1 - end + # Line Search + α = perform_linesearch!(cache.ls_cache, cache.u, cache.du) + @bb axpy!(-α, cache.du, cache.u) - # u = u - J \ fu - if linsolve === nothing - cache.du = fu / cache.J + if iip + cache.f(cache.fu2, cache.u, cache.p) else - linres = dolinsolve(alg.precs, linsolve; A = ifelse(fact_done, nothing, J), - b = _vec(fu), linu = _vec(cache.du), p, reltol = cache.abstol) - cache.linsolve = linres.cache + cache.fu2 = cache.f(cache.u, cache.p) end - # Line Search - α = perform_linesearch!(cache.ls_cache, cache.u, cache.du) - cache.u = @. cache.u - α * cache.du # `u` might not support mutation - cache.fu2 = f(cache.u, p) - - update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, J, + update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, cache.J, cache.du, α) check_and_update!(cache, cache.fu2, cache.u, cache.u_prev) - cache.u_prev = cache.u + @bb copyto!(cache.u_prev, cache.u) + cache.stats.nf += 1 cache.stats.nsolve += 1 cache.stats.nfactors += 1 @@ -203,46 +154,27 @@ function perform_step!(cache::GeneralKlementCache{false}) cache.force_stop && return nothing # Update the Jacobian - cache.du = -cache.du - cache.J_cache = cache.J' .^ 2 - cache.Jdu = _vec(cache.du) .^ 2 - cache.Jᵀ²du = cache.J_cache * cache.Jdu - cache.Jdu = J * _vec(cache.du) - cache.fu = cache.fu2 .- cache.fu - cache.fu = _restructure(cache.fu, - (_vec(cache.fu) .- cache.Jdu) ./ max.(cache.Jᵀ²du, eps(real(T)))) - cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J - cache.J = J .+ cache.J_cache - - cache.fu = cache.fu2 + @bb cache.du .*= -1 + @bb cache.J_cache .= cache.J' .^ 2 + @bb @. cache.Jdu = cache.du ^ 2 + @bb cache.Jᵀ²du = cache.J_cache × vec(cache.Jdu) + @bb cache.Jdu = cache.J × vec(cache.du) + @bb @. cache.fu = cache.fu2 - cache.fu - return nothing -end + @bb @. cache.fu = (cache.fu - cache.Jdu) / max(cache.Jᵀ²du, eps(real(T))) -function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters, - termination_condition = get_termination_mode(cache.tc_cache)) where {iip} - cache.p = p - if iip - recursivecopy!(cache.u, u0) - cache.f(cache.fu, cache.u, p) - else - # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter - cache.u = u0 - cache.fu = cache.f(cache.u, p) - end + @bb cache.J_cache = vec(cache.fu) × transpose(_vec(cache.du)) + @bb @. cache.J_cache *= cache.J + @bb cache.J_cache2 = cache.J_cache × cache.J + @bb cache.J .+= cache.J_cache2 - reset!(cache.trace) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u, - termination_condition) + @bb copyto!(cache.fu, cache.fu2) - cache.abstol = abstol - cache.reltol = reltol - cache.tc_cache = tc_cache - cache.maxiters = maxiters - cache.stats.nf = 1 - cache.stats.nsteps = 1 - cache.force_stop = false - cache.retcode = ReturnCode.Default - return cache + return nothing +end + +function __reinit_internal!(cache::GeneralKlementCache) + cache.J = __reinit_identity_jacobian!!(cache.J) + cache.resets = 0 + return nothing end diff --git a/src/utils.jl b/src/utils.jl index ab2db093f..bc38d9257 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -356,8 +356,8 @@ end # If factorization is LU then perform that and update the linsolve cache # else check if the matrix is singular -function _try_factorize_and_check_singular!(linsolve, X) - if linsolve.cacheval isa LU +function __try_factorize_and_check_singular!(linsolve, X) + if linsolve.cacheval isa LU || linsolve.cacheval isa StaticArrays.LU # LU Factorization was used linsolve.A = X linsolve.cacheval = LinearSolve.do_factorization(linsolve.alg, X, linsolve.b, @@ -368,11 +368,9 @@ function _try_factorize_and_check_singular!(linsolve, X) end return _issingular(X), false end -_try_factorize_and_check_singular!(::Nothing, x) = _issingular(x), false - -@inline _reshape(x, args...) = reshape(x, args...) -@inline _reshape(x::Number, args...) = x +__try_factorize_and_check_singular!(::FakeLinearSolveJLCache, x) = _issingular(x), false +# TODO: Remove. handled in MaybeInplace.jl @generated function _axpy!(α, x, y) hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y)) return :(@. y += α * x)