Skip to content

Commit

Permalink
refactor: implement internal caches function
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent 712ce57 commit 5ed4fe4
Show file tree
Hide file tree
Showing 20 changed files with 183 additions and 318 deletions.
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Expand All @@ -39,6 +40,7 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
NonlinearSolveBaseBandedMatricesExt = "BandedMatrices"
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
NonlinearSolveBaseLineSearchExt = "LineSearch"
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"
NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings"
Expand All @@ -60,6 +62,7 @@ FastClosures = "0.3"
ForwardDiff = "0.10.36"
FunctionProperties = "0.1.2"
InteractiveUtils = "<0.0.1, 1"
LineSearch = "0.1.4"
LinearAlgebra = "1.10"
LinearSolve = "2.36.1"
Markdown = "1.10"
Expand Down
17 changes: 17 additions & 0 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module NonlinearSolveBaseLineSearchExt

using LineSearch: LineSearch, AbstractLineSearchCache
using NonlinearSolveBase: NonlinearSolveBase, InternalAPI
using SciMLBase: SciMLBase

function NonlinearSolveBase.callback_into_cache!(
topcache, cache::AbstractLineSearchCache, args...
)
return LineSearch.callback_into_cache!(cache, NonlinearSolveBase.get_fu(topcache))
end

function InternalAPI.reinit!(cache::AbstractLineSearchCache; kwargs...)
return SciMLBase.reinit!(cache; kwargs...)
end

end
73 changes: 72 additions & 1 deletion lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,31 @@
module InternalAPI

using SciMLBase: NLStats

function init end
function solve! end
function reinit! end
function step! end

function reinit! end
function reinit_self! end

function reinit!(x::Any; kwargs...)
@debug "`InternalAPI.reinit!` is not implemented for $(typeof(x))."
return
end
function reinit_self!(x::Any; kwargs...)
@debug "`InternalAPI.reinit_self!` is not implemented for $(typeof(x))."
return
end

function reinit_self!(stats::NLStats)
stats.nf = 0
stats.nsteps = 0
stats.nfactors = 0
stats.njacs = 0
stats.nsolve = 0
end

end

abstract type AbstractNonlinearSolveBaseAPI end # Mostly used for pretty-printing
Expand Down Expand Up @@ -512,3 +533,53 @@ accepted then these values should be copied into the toplevel cache.
abstract type AbstractTrustRegionMethodCache <: AbstractNonlinearSolveBaseAPI end

last_step_accepted(cache::AbstractTrustRegionMethodCache) = cache.last_step_accepted

# Additional Interface
"""
callback_into_cache!(cache, internalcache, args...)
Define custom operations on `internalcache` tightly coupled with the calling `cache`.
`args...` contain the sequence of caches calling into `internalcache`.
This unfortunately makes code very tightly coupled and not modular. It is recommended to not
use this functionality unless it can't be avoided (like in [`LevenbergMarquardt`](@ref)).
"""
callback_into_cache!(cache, internalcache, args...) = nothing # By default do nothing

# Helper functions to generate cache callbacks and resetting functions
macro internal_caches(cType, internal_cache_names...)
callback_caches = map(internal_cache_names) do name
return quote
$(callback_into_cache!)(
cache, getproperty(internalcache, $(name)), internalcache, args...
)
end
end
callbacks_self = map(internal_cache_names) do name
return quote
$(callback_into_cache!)(cache, getproperty(cache, $(name)))
end
end
reinit_caches = map(internal_cache_names) do name
return quote
$(InternalAPI.reinit!)(getproperty(cache, $(name)), args...; kwargs...)
end
end
return esc(quote
function NonlinearSolveBase.callback_into_cache!(
cache, internalcache::$(cType), args...
)
$(callback_caches...)
return
end
function NonlinearSolveBase.callback_into_cache!(cache::$(cType))
$(callbacks_self...)
return
end
function NonlinearSolveBase.InternalAPI.reinit!(cache::$(cType), args...; kwargs...)
$(reinit_caches...)
$(InternalAPI.reinit_self!)(cache, args...; kwargs...)
return
end
end)
end
3 changes: 1 addition & 2 deletions lib/NonlinearSolveBase/src/descent/damped_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ supports_trust_region(::DampedNewtonDescent) = true
mode <: Union{Val{:normal_form}, Val{:least_squares}, Val{:simple}}
end

# XXX: Implement
# @internal_caches DampedNewtonDescentCache :lincache :damping_fn_cache
NonlinearSolveBase.@internal_caches DampedNewtonDescentCache :lincache :damping_fn_cache

function InternalAPI.init(
prob::AbstractNonlinearProblem, alg::DampedNewtonDescent, J, fu, u; stats,
Expand Down
3 changes: 1 addition & 2 deletions lib/NonlinearSolveBase/src/descent/dogleg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ end
normal_form <: Union{Val{false}, Val{true}}
end

# XXX: Implement
# @internal_caches DoglegCache :newton_cache :cauchy_cache
NonlinearSolveBase.@internal_caches DoglegCache :newton_cache :cauchy_cache

function InternalAPI.init(
prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u;
Expand Down
5 changes: 2 additions & 3 deletions lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,12 @@ get_linear_solver(alg::GeodesicAcceleration) = get_linear_solver(alg.descent)
last_step_accepted::Bool
end

function InternalAPI.reinit!(cache::GeodesicAccelerationCache; p = cache.p, kwargs...)
function InternalAPI.reinit_self!(cache::GeodesicAccelerationCache; p = cache.p, kwargs...)
cache.p = p
cache.last_step_accepted = false
end

# XXX: Implement
# @internal_caches GeodesicAccelerationCache :descent_cache
NonlinearSolveBase.@internal_caches GeodesicAccelerationCache :descent_cache

function get_velocity(cache::GeodesicAccelerationCache)
return SciMLBase.get_du(cache.descent_cache, Val(1))
Expand Down
3 changes: 1 addition & 2 deletions lib/NonlinearSolveBase/src/descent/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ supports_line_search(::NewtonDescent) = true
normal_form <: Union{Val{false}, Val{true}}
end

# XXX: Implement
# @internal_caches NewtonDescentCache :lincache
NonlinearSolveBase.@internal_caches NewtonDescentCache :lincache

function InternalAPI.init(
prob::AbstractNonlinearProblem, alg::NewtonDescent, J, fu, u; stats,
Expand Down
3 changes: 1 addition & 2 deletions lib/NonlinearSolveBase/src/descent/steepest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ supports_line_search(::SteepestDescent) = true
preinverted_jacobian <: Union{Val{false}, Val{true}}
end

# XXX: Implement
# @internal_caches SteepestDescentCache :lincache
NonlinearSolveBase.@internal_caches SteepestDescentCache :lincache

function InternalAPI.init(
prob::AbstractNonlinearProblem, alg::SteepestDescent, J, fu, u;
Expand Down
3 changes: 1 addition & 2 deletions lib/NonlinearSolveBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
)
end

# XXX: Implement this
# update_from_termination_cache!(cache.termination_cache, cache)
update_from_termination_cache!(cache.termination_cache, cache)

update_trace!(
cache.trace, cache.nsteps, get_u(cache), get_fu(cache), nothing, nothing, nothing;
Expand Down
71 changes: 59 additions & 12 deletions lib/NonlinearSolveBase/src/termination_conditions.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
const RelNormModes = Union{
RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode}
RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode
}
const AbsNormModes = Union{
AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode}
AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode
}

# Core Implementation
@concrete mutable struct NonlinearTerminationModeCache{uType, T}
Expand Down Expand Up @@ -32,7 +34,8 @@ end

function CommonSolve.init(
::AbstractNonlinearProblem, mode::AbstractNonlinearTerminationMode, du, u,
saved_value_prototype...; abstol = nothing, reltol = nothing, kwargs...)
saved_value_prototype...; abstol = nothing, reltol = nothing, kwargs...
)
T = promote_type(eltype(du), eltype(u))
abstol = get_tolerance(u, abstol, T)
reltol = get_tolerance(u, reltol, T)
Expand Down Expand Up @@ -77,12 +80,14 @@ function CommonSolve.init(
return NonlinearTerminationModeCache(
u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode,
initial_objective, objectives_trace, 0, saved_value_prototype,
u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache)
u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache
)
end

function SciMLBase.reinit!(
cache::NonlinearTerminationModeCache, du, u, saved_value_prototype...;
abstol = cache.abstol, reltol = cache.reltol, kwargs...)
abstol = cache.abstol, reltol = cache.reltol, kwargs...
)
T = eltype(cache.abstol)
length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)

Expand Down Expand Up @@ -113,7 +118,8 @@ end

## This dispatch is needed based on how Terminating Callback works!
function (cache::NonlinearTerminationModeCache)(
integrator::AbstractODEIntegrator, abstol::Number, reltol::Number, min_t)
integrator::AbstractODEIntegrator, abstol::Number, reltol::Number, min_t
)
if min_t === nothing || integrator.t min_t
return cache(cache.mode, SciMLBase.get_du(integrator),
integrator.u, integrator.uprev, abstol, reltol)
Expand All @@ -125,7 +131,8 @@ function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...)
end

function (cache::NonlinearTerminationModeCache)(
mode::AbstractNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...)
mode::AbstractNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...
)
if check_convergence(mode, du, u, uprev, abstol, reltol)
cache.retcode = ReturnCode.Success
return true
Expand All @@ -134,7 +141,8 @@ function (cache::NonlinearTerminationModeCache)(
end

function (cache::NonlinearTerminationModeCache)(
mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...)
mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...
)
if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode
objective = Utils.apply_norm(mode.internalnorm, du)
criteria = abstol
Expand Down Expand Up @@ -251,15 +259,17 @@ end
# High-Level API with defaults.
## This is mostly for internal usage in NonlinearSolve and SimpleNonlinearSolve
function default_termination_mode(
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple})
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple}
)
return AbsNormTerminationMode(Base.Fix1(maximum, abs))
end
function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:simple})
return AbsNormTerminationMode(Base.Fix2(norm, 2))
end

function default_termination_mode(
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular})
::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular}
)
return AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32)
end

Expand All @@ -268,16 +278,53 @@ function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:regular
end

function init_termination_cache(
prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val)
prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val
)
return init_termination_cache(
prob, abstol, reltol, du, u, default_termination_mode(prob, callee), callee)
end

function init_termination_cache(prob::AbstractNonlinearProblem, abstol, reltol, du,
u, tc::AbstractNonlinearTerminationMode, ::Val)
u, tc::AbstractNonlinearTerminationMode, ::Val
)
T = promote_type(eltype(du), eltype(u))
abstol = get_tolerance(u, abstol, T)
reltol = get_tolerance(u, reltol, T)
cache = init(prob, tc, du, u; abstol, reltol)
return abstol, reltol, cache
end

function check_and_update!(cache, fu, u, uprev)
return check_and_update!(
cache.termination_cache, cache, fu, u, uprev, cache.termination_cache.mode
)
end

function check_and_update!(tc_cache, cache, fu, u, uprev, mode)
if tc_cache(fu, u, uprev)
cache.retcode = tc_cache.retcode
update_from_termination_cache!(tc_cache, cache, mode, u)
cache.force_stop = true
end
end

function update_from_termination_cache!(tc_cache, cache, u = get_u(cache))
return update_from_termination_cache!(tc_cache, cache, tc_cache.mode, u)
end

function update_from_termination_cache!(
tc_cache, cache, ::AbstractNonlinearTerminationMode, u = get_u(cache)
)
Utils.evaluate_f!(cache, u, cache.p)
end

function update_from_termination_cache!(
tc_cache, cache, ::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache)
)
if SciMLBase.isinplace(cache)
copyto!(get_u(cache), tc_cache.u)
else
SciMLBase.set_u!(cache, tc_cache.u)
end
Utils.evaluate_f!(cache, get_u(cache), cache.p)
end
15 changes: 12 additions & 3 deletions lib/NonlinearSolveBase/src/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,16 @@ function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu,
norm_type = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
fnorm = prob isa NonlinearLeastSquaresProblem ? L2_NORM(fu) : Linf_NORM(fu)
condJ = J !== missing ? Utils.condition_number(J) : nothing
storage = u === missing ? nothing :
(; u = copy(u), fu = copy(fu), δu = copy(δu), J = copy(J))
storage = if u === missing
nothing
else
(;
u = ArrayInterface.ismutable(u) ? copy(u) : u,
fu = ArrayInterface.ismutable(fu) ? copy(fu) : fu,
δu = ArrayInterface.ismutable(δu) ? copy(δu) : δu,
J = ArrayInterface.ismutable(J) ? copy(J) : J
)
end
return NonlinearSolveTraceEntry(
iteration, fnorm, L2_NORM(δu), condJ, storage, norm_type
)
Expand Down Expand Up @@ -149,7 +157,8 @@ function init_nonlinearsolve_trace(
)
if show_trace isa Val{true}
print("\nAlgorithm: ")
Base.printstyled(alg, "\n\n"; color = :green, bold = true)
str = Utils.clean_sprint_struct(alg, 0)
Base.printstyled(str, "\n\n"; color = :green, bold = true)
end
J = uses_jac_inverse isa Val{true} ?
(trace_level.trace_mode isa Val{:minimal} ? J : LinearAlgebra.pinv(J)) : J
Expand Down
5 changes: 2 additions & 3 deletions lib/NonlinearSolveQuasiNewton/src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,11 @@ is reinitialized.
internalnorm
end

function InternalAPI.reinit!(cache::InitializedApproximateJacobianCache; kwargs...)
function InternalAPI.reinit_self!(cache::InitializedApproximateJacobianCache; kwargs...)
cache.initialized = false
end

# XXX: Implement
# @internal_caches InitializedApproximateJacobianCache :cache
NonlinearSolveBase.@internal_caches InitializedApproximateJacobianCache :cache

function (cache::InitializedApproximateJacobianCache)(::Nothing)
return NonlinearSolveBase.get_full_jacobian(cache, cache.structure, cache.J)
Expand Down
Loading

0 comments on commit 5ed4fe4

Please sign in to comment.