diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 46f9df83a..3c7d16a1c 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -31,6 +31,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LineSearch = "87fe0de2-c867-4266-b59a-2f0a94fc965b" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" @@ -39,6 +40,7 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" NonlinearSolveBaseBandedMatricesExt = "BandedMatrices" NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase" NonlinearSolveBaseForwardDiffExt = "ForwardDiff" +NonlinearSolveBaseLineSearchExt = "LineSearch" NonlinearSolveBaseLinearSolveExt = "LinearSolve" NonlinearSolveBaseSparseArraysExt = "SparseArrays" NonlinearSolveBaseSparseMatrixColoringsExt = "SparseMatrixColorings" @@ -60,6 +62,7 @@ FastClosures = "0.3" ForwardDiff = "0.10.36" FunctionProperties = "0.1.2" InteractiveUtils = "<0.0.1, 1" +LineSearch = "0.1.4" LinearAlgebra = "1.10" LinearSolve = "2.36.1" Markdown = "1.10" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl new file mode 100644 index 000000000..d68007dc0 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLineSearchExt.jl @@ -0,0 +1,17 @@ +module NonlinearSolveBaseLineSearchExt + +using LineSearch: LineSearch, AbstractLineSearchCache +using NonlinearSolveBase: NonlinearSolveBase, InternalAPI +using SciMLBase: SciMLBase + +function NonlinearSolveBase.callback_into_cache!( + topcache, cache::AbstractLineSearchCache, args... +) + return LineSearch.callback_into_cache!(cache, NonlinearSolveBase.get_fu(topcache)) +end + +function InternalAPI.reinit!(cache::AbstractLineSearchCache; kwargs...) + return SciMLBase.reinit!(cache; kwargs...) +end + +end diff --git a/lib/NonlinearSolveBase/src/abstract_types.jl b/lib/NonlinearSolveBase/src/abstract_types.jl index bc94dce99..a01aa7afb 100644 --- a/lib/NonlinearSolveBase/src/abstract_types.jl +++ b/lib/NonlinearSolveBase/src/abstract_types.jl @@ -1,10 +1,31 @@ module InternalAPI +using SciMLBase: NLStats + function init end function solve! end -function reinit! end function step! end +function reinit! end +function reinit_self! end + +function reinit!(x::Any; kwargs...) + @debug "`InternalAPI.reinit!` is not implemented for $(typeof(x))." + return +end +function reinit_self!(x::Any; kwargs...) + @debug "`InternalAPI.reinit_self!` is not implemented for $(typeof(x))." + return +end + +function reinit_self!(stats::NLStats) + stats.nf = 0 + stats.nsteps = 0 + stats.nfactors = 0 + stats.njacs = 0 + stats.nsolve = 0 +end + end abstract type AbstractNonlinearSolveBaseAPI end # Mostly used for pretty-printing @@ -512,3 +533,53 @@ accepted then these values should be copied into the toplevel cache. abstract type AbstractTrustRegionMethodCache <: AbstractNonlinearSolveBaseAPI end last_step_accepted(cache::AbstractTrustRegionMethodCache) = cache.last_step_accepted + +# Additional Interface +""" + callback_into_cache!(cache, internalcache, args...) + +Define custom operations on `internalcache` tightly coupled with the calling `cache`. +`args...` contain the sequence of caches calling into `internalcache`. + +This unfortunately makes code very tightly coupled and not modular. It is recommended to not +use this functionality unless it can't be avoided (like in [`LevenbergMarquardt`](@ref)). +""" +callback_into_cache!(cache, internalcache, args...) = nothing # By default do nothing + +# Helper functions to generate cache callbacks and resetting functions +macro internal_caches(cType, internal_cache_names...) + callback_caches = map(internal_cache_names) do name + return quote + $(callback_into_cache!)( + cache, getproperty(internalcache, $(name)), internalcache, args... + ) + end + end + callbacks_self = map(internal_cache_names) do name + return quote + $(callback_into_cache!)(cache, getproperty(cache, $(name))) + end + end + reinit_caches = map(internal_cache_names) do name + return quote + $(InternalAPI.reinit!)(getproperty(cache, $(name)), args...; kwargs...) + end + end + return esc(quote + function NonlinearSolveBase.callback_into_cache!( + cache, internalcache::$(cType), args... + ) + $(callback_caches...) + return + end + function NonlinearSolveBase.callback_into_cache!(cache::$(cType)) + $(callbacks_self...) + return + end + function NonlinearSolveBase.InternalAPI.reinit!(cache::$(cType), args...; kwargs...) + $(reinit_caches...) + $(InternalAPI.reinit_self!)(cache, args...; kwargs...) + return + end + end) +end diff --git a/lib/NonlinearSolveBase/src/descent/damped_newton.jl b/lib/NonlinearSolveBase/src/descent/damped_newton.jl index 3ab507065..a9921ba47 100644 --- a/lib/NonlinearSolveBase/src/descent/damped_newton.jl +++ b/lib/NonlinearSolveBase/src/descent/damped_newton.jl @@ -42,8 +42,7 @@ supports_trust_region(::DampedNewtonDescent) = true mode <: Union{Val{:normal_form}, Val{:least_squares}, Val{:simple}} end -# XXX: Implement -# @internal_caches DampedNewtonDescentCache :lincache :damping_fn_cache +NonlinearSolveBase.@internal_caches DampedNewtonDescentCache :lincache :damping_fn_cache function InternalAPI.init( prob::AbstractNonlinearProblem, alg::DampedNewtonDescent, J, fu, u; stats, diff --git a/lib/NonlinearSolveBase/src/descent/dogleg.jl b/lib/NonlinearSolveBase/src/descent/dogleg.jl index c138adbde..133b5fabb 100644 --- a/lib/NonlinearSolveBase/src/descent/dogleg.jl +++ b/lib/NonlinearSolveBase/src/descent/dogleg.jl @@ -44,8 +44,7 @@ end normal_form <: Union{Val{false}, Val{true}} end -# XXX: Implement -# @internal_caches DoglegCache :newton_cache :cauchy_cache +NonlinearSolveBase.@internal_caches DoglegCache :newton_cache :cauchy_cache function InternalAPI.init( prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u; diff --git a/lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl b/lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl index c465cc5cb..f5b686433 100644 --- a/lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl +++ b/lib/NonlinearSolveBase/src/descent/geodesic_acceleration.jl @@ -49,13 +49,12 @@ get_linear_solver(alg::GeodesicAcceleration) = get_linear_solver(alg.descent) last_step_accepted::Bool end -function InternalAPI.reinit!(cache::GeodesicAccelerationCache; p = cache.p, kwargs...) +function InternalAPI.reinit_self!(cache::GeodesicAccelerationCache; p = cache.p, kwargs...) cache.p = p cache.last_step_accepted = false end -# XXX: Implement -# @internal_caches GeodesicAccelerationCache :descent_cache +NonlinearSolveBase.@internal_caches GeodesicAccelerationCache :descent_cache function get_velocity(cache::GeodesicAccelerationCache) return SciMLBase.get_du(cache.descent_cache, Val(1)) diff --git a/lib/NonlinearSolveBase/src/descent/newton.jl b/lib/NonlinearSolveBase/src/descent/newton.jl index 1a7acf177..e453597a1 100644 --- a/lib/NonlinearSolveBase/src/descent/newton.jl +++ b/lib/NonlinearSolveBase/src/descent/newton.jl @@ -24,8 +24,7 @@ supports_line_search(::NewtonDescent) = true normal_form <: Union{Val{false}, Val{true}} end -# XXX: Implement -# @internal_caches NewtonDescentCache :lincache +NonlinearSolveBase.@internal_caches NewtonDescentCache :lincache function InternalAPI.init( prob::AbstractNonlinearProblem, alg::NewtonDescent, J, fu, u; stats, diff --git a/lib/NonlinearSolveBase/src/descent/steepest.jl b/lib/NonlinearSolveBase/src/descent/steepest.jl index b29045bd5..b93c727c5 100644 --- a/lib/NonlinearSolveBase/src/descent/steepest.jl +++ b/lib/NonlinearSolveBase/src/descent/steepest.jl @@ -21,8 +21,7 @@ supports_line_search(::SteepestDescent) = true preinverted_jacobian <: Union{Val{false}, Val{true}} end -# XXX: Implement -# @internal_caches SteepestDescentCache :lincache +NonlinearSolveBase.@internal_caches SteepestDescentCache :lincache function InternalAPI.init( prob::AbstractNonlinearProblem, alg::SteepestDescent, J, fu, u; diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 999f064e9..7316fd168 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -18,8 +18,7 @@ function CommonSolve.solve!(cache::AbstractNonlinearSolveCache) ) end - # XXX: Implement this - # update_from_termination_cache!(cache.termination_cache, cache) + update_from_termination_cache!(cache.termination_cache, cache) update_trace!( cache.trace, cache.nsteps, get_u(cache), get_fu(cache), nothing, nothing, nothing; diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index c2d768c17..cca9134d1 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -1,7 +1,9 @@ const RelNormModes = Union{ - RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode} + RelNormTerminationMode, RelNormSafeTerminationMode, RelNormSafeBestTerminationMode +} const AbsNormModes = Union{ - AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode} + AbsNormTerminationMode, AbsNormSafeTerminationMode, AbsNormSafeBestTerminationMode +} # Core Implementation @concrete mutable struct NonlinearTerminationModeCache{uType, T} @@ -32,7 +34,8 @@ end function CommonSolve.init( ::AbstractNonlinearProblem, mode::AbstractNonlinearTerminationMode, du, u, - saved_value_prototype...; abstol = nothing, reltol = nothing, kwargs...) + saved_value_prototype...; abstol = nothing, reltol = nothing, kwargs... +) T = promote_type(eltype(du), eltype(u)) abstol = get_tolerance(u, abstol, T) reltol = get_tolerance(u, reltol, T) @@ -77,12 +80,14 @@ function CommonSolve.init( return NonlinearTerminationModeCache( u_unaliased, ReturnCode.Default, abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0, saved_value_prototype, - u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache) + u0_norm, step_norm_trace, max_stalled_steps, u_diff_cache + ) end function SciMLBase.reinit!( cache::NonlinearTerminationModeCache, du, u, saved_value_prototype...; - abstol = cache.abstol, reltol = cache.reltol, kwargs...) + abstol = cache.abstol, reltol = cache.reltol, kwargs... +) T = eltype(cache.abstol) length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) @@ -113,7 +118,8 @@ end ## This dispatch is needed based on how Terminating Callback works! function (cache::NonlinearTerminationModeCache)( - integrator::AbstractODEIntegrator, abstol::Number, reltol::Number, min_t) + integrator::AbstractODEIntegrator, abstol::Number, reltol::Number, min_t +) if min_t === nothing || integrator.t ≥ min_t return cache(cache.mode, SciMLBase.get_du(integrator), integrator.u, integrator.uprev, abstol, reltol) @@ -125,7 +131,8 @@ function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...) end function (cache::NonlinearTerminationModeCache)( - mode::AbstractNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...) + mode::AbstractNonlinearTerminationMode, du, u, uprev, abstol, reltol, args... +) if check_convergence(mode, du, u, uprev, abstol, reltol) cache.retcode = ReturnCode.Success return true @@ -134,7 +141,8 @@ function (cache::NonlinearTerminationModeCache)( end function (cache::NonlinearTerminationModeCache)( - mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, abstol, reltol, args...) + mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, abstol, reltol, args... +) if mode isa AbsNormSafeTerminationMode || mode isa AbsNormSafeBestTerminationMode objective = Utils.apply_norm(mode.internalnorm, du) criteria = abstol @@ -251,7 +259,8 @@ end # High-Level API with defaults. ## This is mostly for internal usage in NonlinearSolve and SimpleNonlinearSolve function default_termination_mode( - ::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple}) + ::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple} +) return AbsNormTerminationMode(Base.Fix1(maximum, abs)) end function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:simple}) @@ -259,7 +268,8 @@ function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:simple} end function default_termination_mode( - ::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular}) + ::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular} +) return AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32) end @@ -268,16 +278,53 @@ function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:regular end function init_termination_cache( - prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val) + prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val +) return init_termination_cache( prob, abstol, reltol, du, u, default_termination_mode(prob, callee), callee) end function init_termination_cache(prob::AbstractNonlinearProblem, abstol, reltol, du, - u, tc::AbstractNonlinearTerminationMode, ::Val) + u, tc::AbstractNonlinearTerminationMode, ::Val +) T = promote_type(eltype(du), eltype(u)) abstol = get_tolerance(u, abstol, T) reltol = get_tolerance(u, reltol, T) cache = init(prob, tc, du, u; abstol, reltol) return abstol, reltol, cache end + +function check_and_update!(cache, fu, u, uprev) + return check_and_update!( + cache.termination_cache, cache, fu, u, uprev, cache.termination_cache.mode + ) +end + +function check_and_update!(tc_cache, cache, fu, u, uprev, mode) + if tc_cache(fu, u, uprev) + cache.retcode = tc_cache.retcode + update_from_termination_cache!(tc_cache, cache, mode, u) + cache.force_stop = true + end +end + +function update_from_termination_cache!(tc_cache, cache, u = get_u(cache)) + return update_from_termination_cache!(tc_cache, cache, tc_cache.mode, u) +end + +function update_from_termination_cache!( + tc_cache, cache, ::AbstractNonlinearTerminationMode, u = get_u(cache) +) + Utils.evaluate_f!(cache, u, cache.p) +end + +function update_from_termination_cache!( + tc_cache, cache, ::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache) +) + if SciMLBase.isinplace(cache) + copyto!(get_u(cache), tc_cache.u) + else + SciMLBase.set_u!(cache, tc_cache.u) + end + Utils.evaluate_f!(cache, get_u(cache), cache.p) +end diff --git a/lib/NonlinearSolveBase/src/tracing.jl b/lib/NonlinearSolveBase/src/tracing.jl index 2dd88ecee..e818e05e8 100644 --- a/lib/NonlinearSolveBase/src/tracing.jl +++ b/lib/NonlinearSolveBase/src/tracing.jl @@ -103,8 +103,16 @@ function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, norm_type = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf) fnorm = prob isa NonlinearLeastSquaresProblem ? L2_NORM(fu) : Linf_NORM(fu) condJ = J !== missing ? Utils.condition_number(J) : nothing - storage = u === missing ? nothing : - (; u = copy(u), fu = copy(fu), δu = copy(δu), J = copy(J)) + storage = if u === missing + nothing + else + (; + u = ArrayInterface.ismutable(u) ? copy(u) : u, + fu = ArrayInterface.ismutable(fu) ? copy(fu) : fu, + δu = ArrayInterface.ismutable(δu) ? copy(δu) : δu, + J = ArrayInterface.ismutable(J) ? copy(J) : J + ) + end return NonlinearSolveTraceEntry( iteration, fnorm, L2_NORM(δu), condJ, storage, norm_type ) @@ -149,7 +157,8 @@ function init_nonlinearsolve_trace( ) if show_trace isa Val{true} print("\nAlgorithm: ") - Base.printstyled(alg, "\n\n"; color = :green, bold = true) + str = Utils.clean_sprint_struct(alg, 0) + Base.printstyled(str, "\n\n"; color = :green, bold = true) end J = uses_jac_inverse isa Val{true} ? (trace_level.trace_mode isa Val{:minimal} ? J : LinearAlgebra.pinv(J)) : J diff --git a/lib/NonlinearSolveQuasiNewton/src/initialization.jl b/lib/NonlinearSolveQuasiNewton/src/initialization.jl index cd5fbaa2e..23d9a19cd 100644 --- a/lib/NonlinearSolveQuasiNewton/src/initialization.jl +++ b/lib/NonlinearSolveQuasiNewton/src/initialization.jl @@ -37,12 +37,11 @@ is reinitialized. internalnorm end -function InternalAPI.reinit!(cache::InitializedApproximateJacobianCache; kwargs...) +function InternalAPI.reinit_self!(cache::InitializedApproximateJacobianCache; kwargs...) cache.initialized = false end -# XXX: Implement -# @internal_caches InitializedApproximateJacobianCache :cache +NonlinearSolveBase.@internal_caches InitializedApproximateJacobianCache :cache function (cache::InitializedApproximateJacobianCache)(::Nothing) return NonlinearSolveBase.get_full_jacobian(cache, cache.structure, cache.J) diff --git a/lib/NonlinearSolveQuasiNewton/src/solve.jl b/lib/NonlinearSolveQuasiNewton/src/solve.jl index a37f098df..98ef6d9ee 100644 --- a/lib/NonlinearSolveQuasiNewton/src/solve.jl +++ b/lib/NonlinearSolveQuasiNewton/src/solve.jl @@ -122,7 +122,9 @@ end # reset_timer!(cache.timer) # end -# @internal_caches ApproximateJacobianSolveCache :initialization_cache :descent_cache :linesearch_cache :trustregion_cache :update_rule_cache :reinit_rule_cache +NonlinearSolveBase.@internal_caches(ApproximateJacobianSolveCache, + :initialization_cache, :descent_cache, :linesearch_cache, :trustregion_cache, + :update_rule_cache, :reinit_rule_cache) function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; @@ -361,8 +363,8 @@ function InternalAPI.step!( error("Unknown Globalization Strategy: $(cache.globalization). Allowed values \ are (:LineSearch, :TrustRegion, :None)") end - # XXX: Implement - # check_and_update!(cache, cache.fu, cache.u, cache.u_cache) + + NonlinearSolveBase.check_and_update!(cache, cache.fu, cache.u, cache.u_cache) else α = false cache.force_reinit = true @@ -376,8 +378,7 @@ function InternalAPI.step!( if (cache.force_stop || cache.force_reinit || (recompute_jacobian !== nothing && !recompute_jacobian)) - # XXX: Implement - # callback_into_cache!(cache) + NonlinearSolveBase.callback_into_cache!(cache) return nothing end @@ -385,8 +386,7 @@ function InternalAPI.step!( cache.J = InternalAPI.solve!( cache.update_rule_cache, cache.J, cache.fu, cache.u, δu ) - # XXX: Implement - # callback_into_cache!(cache) + NonlinearSolveBase.callback_into_cache!(cache) end return nothing diff --git a/lib/NonlinearSolveSpectralMethods/src/solve.jl b/lib/NonlinearSolveSpectralMethods/src/solve.jl index 9c68c81b4..297b9bf69 100644 --- a/lib/NonlinearSolveSpectralMethods/src/solve.jl +++ b/lib/NonlinearSolveSpectralMethods/src/solve.jl @@ -107,7 +107,7 @@ end # cache.retcode = ReturnCode.Default # end -# @internal_caches GeneralizedDFSaneCache :linesearch_cache +NonlinearSolveBase.@internal_caches GeneralizedDFSaneCache :linesearch_cache function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...; @@ -186,8 +186,8 @@ function InternalAPI.step!( end update_trace!(cache, α) - # XXX: Implement - # check_and_update!(cache, cache.fu, cache.u, cache.u_cache) + + NonlinearSolveBase.check_and_update!(cache, cache.fu, cache.u, cache.u_cache) # Update Spectral Parameter @static_timeit cache.timer "update spectral parameter" begin @@ -209,8 +209,7 @@ function InternalAPI.step!( @bb copyto!(cache.u_cache, cache.u) @bb copyto!(cache.fu_cache, cache.fu) - # XXX: Implement - # callback_into_cache!(cache, cache.linesearch_cache) + NonlinearSolveBase.callback_into_cache!(cache, cache.linesearch_cache) return end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 477f12815..5e1f39b34 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -76,9 +76,6 @@ include("abstract_types.jl") include("timer_outputs.jl") include("internal/helpers.jl") -include("internal/termination.jl") - -include("globalization/line_search.jl") include("globalization/trust_region.jl") include("core/generic.jl") diff --git a/src/abstract_types.jl b/src/abstract_types.jl deleted file mode 100644 index 65bda73c3..000000000 --- a/src/abstract_types.jl +++ /dev/null @@ -1,134 +0,0 @@ -const __internal_init = InternalAPI.init -const __internal_solve! = InternalAPI.solve! - -""" - AbstractNonlinearSolveAlgorithm{name} <: AbstractNonlinearAlgorithm - -Abstract Type for all NonlinearSolve.jl Algorithms. `name` can be used to define custom -dispatches by wrapped solvers. - -### Interface Functions - - - `concrete_jac(alg)`: whether or not the algorithm uses a concrete Jacobian. Defaults - to `nothing`. - - `get_name(alg)`: get the name of the algorithm. -""" -abstract type AbstractNonlinearSolveAlgorithm{name} <: AbstractNonlinearAlgorithm end - -function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm) - __show_algorithm(io, alg, get_name(alg), 0) -end - -get_name(::AbstractNonlinearSolveAlgorithm{name}) where {name} = name - -""" - AbstractNonlinearSolveExtensionAlgorithm <: AbstractNonlinearSolveAlgorithm{:Extension} - -Abstract Type for all NonlinearSolve.jl Extension Algorithms, i.e. wrappers over 3rd party -solvers. -""" -abstract type AbstractNonlinearSolveExtensionAlgorithm <: - AbstractNonlinearSolveAlgorithm{:Extension} end - -""" - AbstractNonlinearSolveCache{iip, timeit} - -Abstract Type for all NonlinearSolve.jl Caches. - -### Interface Functions - - - `get_fu(cache)`: get the residual. - - `get_u(cache)`: get the current state. - - `set_fu!(cache, fu)`: set the residual. - - `set_u!(cache, u)`: set the current state. - - `reinit!(cache, u0; kwargs...)`: reinitialize the cache with the initial state `u0` and - any additional keyword arguments. - - `step!(cache; kwargs...)`: See [`SciMLBase.step!`](@ref) for more details. - - `not_terminated(cache)`: whether or not the solver has terminated. - - `isinplace(cache)`: whether or not the solver is inplace. -""" -abstract type AbstractNonlinearSolveCache{iip, timeit} end - -function SymbolicIndexingInterface.symbolic_container(cache::AbstractNonlinearSolveCache) - return cache.prob -end -function SymbolicIndexingInterface.parameter_values(cache::AbstractNonlinearSolveCache) - return parameter_values(symbolic_container(cache)) -end -function SymbolicIndexingInterface.state_values(cache::AbstractNonlinearSolveCache) - return state_values(symbolic_container(cache)) -end - -function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol) - sym == :ps && return ParameterIndexingProxy(cache) - return getfield(cache, sym) -end - -function Base.getindex(cache::AbstractNonlinearSolveCache, sym) - return getu(cache, sym)(cache) -end - -function Base.setindex!(cache::AbstractNonlinearSolveCache, val, sym) - return setu(cache, sym)(cache, val) -end - -function Base.show(io::IO, cache::AbstractNonlinearSolveCache) - __show_cache(io, cache, 0) -end - -function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0) - println(io, "$(nameof(typeof(cache)))(") - __show_algorithm(io, cache.alg, - (" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4) - - ustr = sprint(show, get_u(cache); context = (:compact => true, :limit => true)) - println(io, ",\n" * (" "^(indent + 4)) * "u = $(ustr),") - - residstr = sprint(show, get_fu(cache); context = (:compact => true, :limit => true)) - println(io, (" "^(indent + 4)) * "residual = $(residstr),") - - normstr = sprint( - show, norm(get_fu(cache), Inf); context = (:compact => true, :limit => true)) - println(io, (" "^(indent + 4)) * "inf-norm(residual) = $(normstr),") - - println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",") - println(io, " "^(indent + 4) * "retcode = ", cache.retcode) - print(io, " "^(indent) * ")") -end - -SciMLBase.isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip - -get_fu(cache::AbstractNonlinearSolveCache) = cache.fu -get_u(cache::AbstractNonlinearSolveCache) = cache.u -set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu) -SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u) - -function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache; kwargs...) - return reinit_cache!(cache; kwargs...) -end -function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache, u0; kwargs...) - return reinit_cache!(cache; u0, kwargs...) -end - -""" - AbstractTrustRegionMethodCache - -Abstract Type for all Trust Region Method Caches used in NonlinearSolve.jl. - -### Interface Functions - - - `last_step_accepted(cache)`: whether or not the last step was accepted. Defaults to - `cache.last_step_accepted`. Should if overloaded if the field is not present. - -### `__internal_solve!` specification - -```julia -__internal_solve!(cache::AbstractTrustRegionMethodCache, J, fu, u, δu, descent_stats) -``` - -Returns `last_step_accepted`, updated `u_cache` and `fu_cache`. If the last step was -accepted then these values should be copied into the toplevel cache. -""" -abstract type AbstractTrustRegionMethodCache end - -last_step_accepted(cache::AbstractTrustRegionMethodCache) = cache.last_step_accepted diff --git a/src/globalization/line_search.jl b/src/globalization/line_search.jl deleted file mode 100644 index 7549f1f9d..000000000 --- a/src/globalization/line_search.jl +++ /dev/null @@ -1,7 +0,0 @@ -function callback_into_cache!(topcache, cache::AbstractLineSearchCache, args...) - LineSearch.callback_into_cache!(cache, get_fu(topcache)) -end - -function reinit_cache!(cache::AbstractLineSearchCache, args...; kwargs...) - return SciMLBase.reinit!(cache, args...; kwargs...) -end diff --git a/src/internal/helpers.jl b/src/internal/helpers.jl index 735122b8d..a7326bb51 100644 --- a/src/internal/helpers.jl +++ b/src/internal/helpers.jl @@ -1,45 +1,3 @@ -# Evaluate the residual function at a given point -function evaluate_f(prob::AbstractNonlinearProblem{uType, iip}, u) where {uType, iip} - (; f, p) = prob - if iip - fu = f.resid_prototype === nothing ? zero(u) : similar(f.resid_prototype) - f(fu, u, p) - else - fu = f(u, p) - end - return fu -end - -function evaluate_f!(cache, u, p) - cache.stats.nf += 1 - if isinplace(cache) - cache.prob.f(get_fu(cache), u, p) - else - set_fu!(cache, cache.prob.f(u, p)) - end -end - -evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p) = evaluate_f!!(prob.f, fu, u, p) -function evaluate_f!!(f::NonlinearFunction{iip}, fu, u, p) where {iip} - if iip - f(fu, u, p) - return fu - end - return f(u, p) -end - -# Callbacks -""" - callback_into_cache!(cache, internalcache, args...) - -Define custom operations on `internalcache` tightly coupled with the calling `cache`. -`args...` contain the sequence of caches calling into `internalcache`. - -This unfortunately makes code very tightly coupled and not modular. It is recommended to not -use this functionality unless it can't be avoided (like in [`LevenbergMarquardt`](@ref)). -""" -@inline callback_into_cache!(cache, internalcache, args...) = nothing # By default do nothing - # Extension Algorithm Helpers function __test_termination_condition(termination_condition, alg) !(termination_condition isa AbsNormTerminationMode) && @@ -125,42 +83,3 @@ function __construct_extension_jac(prob, alg, u0, fu; can_handle_oop::Val = Fals return 𝐉, Jₚ(nothing) end - -function reinit_cache! end -reinit_cache!(cache::Nothing, args...; kwargs...) = nothing -reinit_cache!(cache, args...; kwargs...) = nothing - -function __reinit_internal! end -__reinit_internal!(::Nothing, args...; kwargs...) = nothing -__reinit_internal!(cache, args...; kwargs...) = nothing - -# Auto-generate some of the helper functions -macro internal_caches(cType, internal_cache_names...) - return __internal_caches(cType, internal_cache_names) -end - -function __internal_caches(cType, internal_cache_names::Tuple) - callback_caches = map( - name -> :($(callback_into_cache!)( - cache, getproperty(internalcache, $(name)), internalcache, args...)), - internal_cache_names) - callbacks_self = map( - name -> :($(callback_into_cache!)( - internalcache, getproperty(internalcache, $(name)))), - internal_cache_names) - reinit_caches = map( - name -> :($(reinit_cache!)(getproperty(cache, $(name)), args...; kwargs...)), - internal_cache_names) - return esc(quote - function callback_into_cache!(cache, internalcache::$(cType), args...) - $(callback_caches...) - end - function callback_into_cache!(internalcache::$(cType)) - $(callbacks_self...) - end - function reinit_cache!(cache::$(cType), args...; kwargs...) - $(reinit_caches...) - $(__reinit_internal!)(cache, args...; kwargs...) - end - end) -end diff --git a/src/internal/termination.jl b/src/internal/termination.jl deleted file mode 100644 index 7728aea69..000000000 --- a/src/internal/termination.jl +++ /dev/null @@ -1,34 +0,0 @@ -function check_and_update!(cache, fu, u, uprev) - return check_and_update!(cache.termination_cache, cache, fu, u, uprev) -end - -function check_and_update!(tc_cache, cache, fu, u, uprev) - return check_and_update!(tc_cache, cache, fu, u, uprev, tc_cache.mode) -end - -function check_and_update!(tc_cache, cache, fu, u, uprev, mode) - if tc_cache(fu, u, uprev) - cache.retcode = tc_cache.retcode - update_from_termination_cache!(tc_cache, cache, mode, u) - cache.force_stop = true - end -end - -function update_from_termination_cache!(tc_cache, cache, u = get_u(cache)) - return update_from_termination_cache!(tc_cache, cache, tc_cache.mode, u) -end - -function update_from_termination_cache!( - tc_cache, cache, ::AbstractNonlinearTerminationMode, u = get_u(cache)) - evaluate_f!(cache, u, cache.p) -end - -function update_from_termination_cache!( - tc_cache, cache, ::AbstractSafeBestNonlinearTerminationMode, u = get_u(cache)) - if isinplace(cache) - copyto!(get_u(cache), tc_cache.u) - else - set_u!(cache, tc_cache.u) - end - evaluate_f!(cache, get_u(cache), cache.p) -end diff --git a/src/utils.jl b/src/utils.jl index 1582234bc..cb346da0a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -103,18 +103,3 @@ function __build_solution_less_specialize(prob::AbstractNonlinearProblem, alg, u Any, typeof(left), typeof(stats), typeof(trace)}( u, resid, prob, alg, retcode, original, left, right, stats, trace) end - -@inline empty_nlstats() = NLStats(0, 0, 0, 0, 0) -function __reinit_internal!(stats::NLStats) - stats.nf = 0 - stats.nsteps = 0 - stats.nfactors = 0 - stats.njacs = 0 - stats.nsolve = 0 -end - -function __similar(x, args...; kwargs...) - y = similar(x, args...; kwargs...) - fill!(y, false) - return y -end