From d6c5ea498a2c3bf795596b84b634d54f63e70999 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Nov 2023 22:01:55 -0500 Subject: [PATCH] Share reinit code --- src/NonlinearSolve.jl | 36 +++++++++++++++++++++++++++++++++++- src/broyden.jl | 29 +++-------------------------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index f1782b8c1..9096525ee 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,7 +8,8 @@ import Reexport: @reexport import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload @recompile_invalidations begin - using DiffEqBase, LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays, + using DiffEqBase, + LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays, SparseDiffTools import ADTypes: AbstractFiniteDifferencesMode @@ -51,6 +52,39 @@ abstract type AbstractNonlinearSolveCache{iip} end isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip +function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache); + p = cache.p, abstol = cache.abstol, reltol = cache.reltol, + maxiters = cache.maxiters, alias_u0 = false, + termination_condition = get_termination_mode(cache.tc_cache)) where {iip} + cache.p = p + if iip + recursivecopy!(get_u(cache), u0) + cache.f(cache.fu1, get_u(cache), p) + else + cache.u = __maybe_unaliased(u0, alias_u0) + set_fu!(cache, cache.f(cache.u, p)) + end + + reset!(cache.trace) + abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache), + get_u(cache), termination_condition) + + cache.abstol = abstol + cache.reltol = reltol + cache.tc_cache = tc_cache + cache.maxiters = maxiters + cache.stats.nf = 1 + cache.stats.nsteps = 1 + cache.force_stop = false + cache.retcode = ReturnCode.Default + + __reinit_internal!(cache) + + return cache +end + +__reinit_internal!(cache::AbstractNonlinearSolveCache) = nothing + function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm) str = "$(nameof(typeof(alg)))(" modifiers = String[] diff --git a/src/broyden.jl b/src/broyden.jl index e0b69f19c..dbc4f5131 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -137,31 +137,8 @@ function perform_step!(cache::GeneralBroydenCache{iip}) where {iip} return nothing end -function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters, - termination_condition = get_termination_mode(cache.tc_cache)) where {iip} - cache.p = p - if iip - recursivecopy!(cache.u, u0) - cache.f(cache.fu, cache.u, p) - else - # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter - cache.u = u0 - cache.fu = cache.f(cache.u, p) - end - - reset!(cache.trace) - abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u, - termination_condition) - - cache.abstol = abstol - cache.reltol = reltol - cache.tc_cache = tc_cache - cache.maxiters = maxiters - cache.stats.nf = 1 - cache.stats.nsteps = 1 +function __reinit_internal!(cache::GeneralBroydenCache) + cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹) cache.resets = 0 - cache.force_stop = false - cache.retcode = ReturnCode.Default - return cache + return nothing end