diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 2f5404dc7..cc60c308e 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -68,7 +68,7 @@ MaybeInplace = "0.1.4" Preferences = "1.4" Printf = "1.10" RecursiveArrayTools = "3" -SciMLBase = "2.58" +SciMLBase = "2.68.1" SciMLJacobianOperators = "0.1.1" SciMLOperators = "0.3.10" SparseArrays = "1.10" diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index f45ba9242..9087a5c98 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -55,6 +55,7 @@ include("descent/damped_newton.jl") include("descent/dogleg.jl") include("descent/geodesic_acceleration.jl") +include("initialization.jl") include("solve.jl") include("forward_diff.jl") diff --git a/lib/NonlinearSolveBase/src/abstract_types.jl b/lib/NonlinearSolveBase/src/abstract_types.jl index 6829e19c3..f43d59012 100644 --- a/lib/NonlinearSolveBase/src/abstract_types.jl +++ b/lib/NonlinearSolveBase/src/abstract_types.jl @@ -259,6 +259,8 @@ Abstract Type for all NonlinearSolveBase Caches. `u0` and any additional keyword arguments. - `SciMLBase.isinplace(cache)`: whether or not the solver is inplace. - `CommonSolve.step!(cache; kwargs...)`: See [`CommonSolve.step!`](@ref) for more details. + - `get_abstol(cache)`: get the `abstol` provided to the cache. + - `get_reltol(cache)`: get the `reltol` provided to the cache. Additionally implements `SymbolicIndexingInterface` interface Functions. @@ -304,9 +306,16 @@ end SciMLBase.isinplace(cache::AbstractNonlinearSolveCache) = SciMLBase.isinplace(cache.prob) +function get_abstol(cache::AbstractNonlinearSolveCache) + get_abstol(cache.termination_cache) +end +function get_reltol(cache::AbstractNonlinearSolveCache) + get_reltol(cache.termination_cache) +end + ## SII Interface SII.symbolic_container(cache::AbstractNonlinearSolveCache) = cache.prob -SII.parameter_values(cache::AbstractNonlinearSolveCache) = SII.parameter_values(cache.prob) +SII.parameter_values(cache::AbstractNonlinearSolveCache) = cache.p SII.state_values(cache::AbstractNonlinearSolveCache) = get_u(cache) function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol) diff --git a/lib/NonlinearSolveBase/src/forward_diff.jl b/lib/NonlinearSolveBase/src/forward_diff.jl index a588aa52d..e780bf554 100644 --- a/lib/NonlinearSolveBase/src/forward_diff.jl +++ b/lib/NonlinearSolveBase/src/forward_diff.jl @@ -6,3 +6,10 @@ values_p partials_p end + +function NonlinearSolveBase.get_abstol(cache::NonlinearSolveForwardDiffCache) + NonlinearSolveBase.get_abstol(cache.cache) +end +function NonlinearSolveBase.get_reltol(cache::NonlinearSolveForwardDiffCache) + NonlinearSolveBase.get_reltol(cache.cache) +end diff --git a/lib/NonlinearSolveBase/src/initialization.jl b/lib/NonlinearSolveBase/src/initialization.jl new file mode 100644 index 000000000..e3612e2cb --- /dev/null +++ b/lib/NonlinearSolveBase/src/initialization.jl @@ -0,0 +1,60 @@ +struct NonlinearSolveDefaultInit <: SciMLBase.DAEInitializationAlgorithm end + +function run_initialization!(cache, initializealg = cache.initializealg, prob = cache.prob) + _run_initialization!(cache, initializealg, prob, Val(SciMLBase.isinplace(cache))) +end + +function _run_initialization!( + cache, ::NonlinearSolveDefaultInit, prob, isinplace::Union{Val{true}, Val{false}}) + if SciMLBase.has_initialization_data(prob.f) && + prob.f.initialization_data isa SciMLBase.OverrideInitData + return _run_initialization!(cache, SciMLBase.OverrideInit(), prob, isinplace) + end + return cache, true +end + +function _run_initialization!(cache, initalg::SciMLBase.OverrideInit, prob, + isinplace::Union{Val{true}, Val{false}}) + if cache isa AbstractNonlinearSolveCache && isdefined(cache.alg, :autodiff) + autodiff = cache.alg.autodiff + else + autodiff = ADTypes.AutoForwardDiff() + end + alg = initialization_alg(prob.f.initialization_data.initializeprob, autodiff) + if alg === nothing && cache isa AbstractNonlinearSolveCache + alg = cache.alg + end + u0, p, success = SciMLBase.get_initial_values( + prob, cache, prob.f, initalg, isinplace; nlsolve_alg = alg, + abstol = get_abstol(cache), reltol = get_reltol(cache)) + cache = update_initial_values!(cache, u0, p) + if cache isa AbstractNonlinearSolveCache && isdefined(cache, :retcode) && !success + cache.retcode = ReturnCode.InitialFailure + end + + return cache, success +end + +function get_abstol(prob::AbstractNonlinearProblem) + get_tolerance(get(prob.kwargs, :abstol, nothing), eltype(SII.state_values(prob))) +end +function get_reltol(prob::AbstractNonlinearProblem) + get_tolerance(get(prob.kwargs, :reltol, nothing), eltype(SII.state_values(prob))) +end + +initialization_alg(initprob, autodiff) = nothing + +function update_initial_values!(cache::AbstractNonlinearSolveCache, u0, p) + InternalAPI.reinit!(cache; u0, p) + cache.prob = SciMLBase.remake(cache.prob; u0, p) + return cache +end + +function update_initial_values!(prob::AbstractNonlinearProblem, u0, p) + return SciMLBase.remake(prob; u0, p) +end + +function _run_initialization!( + cache::AbstractNonlinearSolveCache, ::SciMLBase.NoInit, prob, isinplace) + return cache, true +end diff --git a/lib/NonlinearSolveBase/src/polyalg.jl b/lib/NonlinearSolveBase/src/polyalg.jl index c2101af0e..935019e97 100644 --- a/lib/NonlinearSolveBase/src/polyalg.jl +++ b/lib/NonlinearSolveBase/src/polyalg.jl @@ -59,6 +59,23 @@ end u0 u0_aliased alias_u0::Bool + + initializealg +end + +function update_initial_values!(cache::NonlinearSolvePolyAlgorithmCache, u0, p) + foreach(cache.caches) do subcache + update_initial_values!(subcache, u0, p) + end + cache.prob = SciMLBase.remake(cache.prob; u0, p) + return cache +end + +function NonlinearSolveBase.get_abstol(cache::NonlinearSolvePolyAlgorithmCache) + NonlinearSolveBase.get_abstol(cache.caches[cache.current]) +end +function NonlinearSolveBase.get_reltol(cache::NonlinearSolvePolyAlgorithmCache) + NonlinearSolveBase.get_reltol(cache.caches[cache.current]) end function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache) @@ -67,6 +84,9 @@ end function SII.state_values(cache::NonlinearSolvePolyAlgorithmCache) SII.state_values(SII.symbolic_container(cache)) end +function SII.parameter_values(cache::NonlinearSolvePolyAlgorithmCache) + SII.parameter_values(SII.symbolic_container(cache)) +end function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache) println(io, "NonlinearSolvePolyAlgorithmCache with \ @@ -97,7 +117,8 @@ end function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...; stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000, - internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs... + internalnorm = L2_NORM, alias_u0 = false, verbose = true, + initializealg = NonlinearSolveDefaultInit(), kwargs... ) if alias_u0 && !ArrayInterface.ismutable(prob.u0) verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \ @@ -109,18 +130,21 @@ function SciMLBase.__init( u0_aliased = alias_u0 ? copy(u0) : u0 alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased)) - return NonlinearSolvePolyAlgorithmCache( + cache = NonlinearSolvePolyAlgorithmCache( alg.static_length, prob, map(alg.algs) do solver SciMLBase.__init( prob, solver, args...; - stats, maxtime, internalnorm, alias_u0, verbose, kwargs... + stats, maxtime, internalnorm, alias_u0, verbose, + initializealg = SciMLBase.NoInit(), kwargs... ) end, alg, -1, alg.start_index, 0, stats, 0.0, maxtime, ReturnCode.Default, false, maxiters, internalnorm, - u0, u0_aliased, alias_u0 + u0, u0_aliased, alias_u0, initializealg ) + run_initialization!(cache) + return cache end @generated function InternalAPI.step!( diff --git a/lib/NonlinearSolveBase/src/solve.jl b/lib/NonlinearSolveBase/src/solve.jl index 66ad6a0e4..41dc1d5fb 100644 --- a/lib/NonlinearSolveBase/src/solve.jl +++ b/lib/NonlinearSolveBase/src/solve.jl @@ -7,6 +7,13 @@ function SciMLBase.__solve( end function CommonSolve.solve!(cache::AbstractNonlinearSolveCache) + if cache.retcode == ReturnCode.InitialFailure + return SciMLBase.build_solution( + cache.prob, cache.alg, get_u(cache), get_fu(cache); + cache.retcode, cache.stats, cache.trace + ) + end + while not_terminated(cache) CommonSolve.step!(cache) end @@ -40,6 +47,17 @@ end sol_syms = [gensym("sol") for i in 1:N] u_result_syms = [gensym("u_result") for i in 1:N] + push!(calls, + quote + if cache.retcode == ReturnCode.InitialFailure + u = $(SII.state_values)(cache) + return build_solution_less_specialize( + cache.prob, cache.alg, u, $(Utils.evaluate_f)(cache.prob, u); + retcode = cache.retcode + ) + end + end) + for i in 1:N push!(calls, quote @@ -111,7 +129,8 @@ end @generated function __generated_polysolve( prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...; - stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs... + stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, + initializealg = NonlinearSolveDefaultInit(), kwargs... ) where {N} sol_syms = [gensym("sol") for _ in 1:N] prob_syms = [gensym("prob") for _ in 1:N] @@ -123,9 +142,23 @@ end immutable (checked using `ArrayInterface.ismutable`)." alias_u0 = false # If immutable don't care about aliasing end + end] + + push!(calls, + quote + prob, success = $(run_initialization!)(prob, initializealg, prob) + if !success + u = $(SII.state_values)(prob) + return build_solution_less_specialize( + prob, alg, u, $(Utils.evaluate_f)(prob, u); + retcode = $(ReturnCode.InitialFailure)) + end + end) + + push!(calls, quote u0 = prob.u0 u0_aliased = alias_u0 ? zero(u0) : u0 - end] + end) for i in 1:N cur_sol = sol_syms[i] push!(calls, @@ -246,8 +279,21 @@ end alg args kwargs::Any + initializealg + + retcode::ReturnCode.T end +function get_abstol(cache::NonlinearSolveNoInitCache) + get(cache.kwargs, :abstol, get_tolerance(nothing, eltype(cache.prob.u0))) +end +function get_reltol(cache::NonlinearSolveNoInitCache) + get(cache.kwargs, :reltol, get_tolerance(nothing, eltype(cache.prob.u0))) +end + +SII.parameter_values(cache::NonlinearSolveNoInitCache) = SII.parameter_values(cache.prob) +SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) + get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob) function SciMLBase.reinit!( @@ -264,11 +310,20 @@ end function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...; + initializealg = NonlinearSolveDefaultInit(), kwargs... ) - return NonlinearSolveNoInitCache(prob, alg, args, kwargs) + cache = NonlinearSolveNoInitCache( + prob, alg, args, kwargs, initializealg, ReturnCode.Default) + run_initialization!(cache) + return cache end function CommonSolve.solve!(cache::NonlinearSolveNoInitCache) + if cache.retcode == ReturnCode.InitialFailure + u = SII.state_values(cache) + return SciMLBase.build_solution( + cache.prob, cache.alg, u, Utils.evaluate_f(cache.prob, u); cache.retcode) + end return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...) end diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index cca9134d1..e6ab4a579 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -23,6 +23,9 @@ const AbsNormModes = Union{ u_diff_cache::uType end +get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol +get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol + function update_u!!(cache::NonlinearTerminationModeCache, u) cache.u === nothing && return if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u) diff --git a/lib/NonlinearSolveFirstOrder/src/solve.jl b/lib/NonlinearSolveFirstOrder/src/solve.jl index c9c8c77a8..ec0d54da6 100644 --- a/lib/NonlinearSolveFirstOrder/src/solve.jl +++ b/lib/NonlinearSolveFirstOrder/src/solve.jl @@ -87,6 +87,8 @@ end retcode::ReturnCode.T force_stop::Bool kwargs + + initializealg end function InternalAPI.reinit_self!( @@ -121,7 +123,7 @@ function SciMLBase.__init( stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, maxtime = nothing, termination_condition = nothing, internalnorm = L2_NORM, - linsolve_kwargs = (;), kwargs... + linsolve_kwargs = (;), initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs... ) @set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff) provided_jvp_autodiff = alg.jvp_autodiff !== nothing @@ -206,13 +208,17 @@ function SciMLBase.__init( prob, alg, u, fu, J, du; kwargs... ) - return GeneralizedFirstOrderAlgorithmCache( + cache = GeneralizedFirstOrderAlgorithmCache( fu, u, u_cache, prob.p, du, J, alg, prob, globalization, jac_cache, descent_cache, linesearch_cache, trustregion_cache, stats, 0, maxiters, maxtime, alg.max_shrink_times, timer, - 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs + 0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs, + initializealg ) + NonlinearSolveBase.run_initialization!(cache) end + + return cache end function InternalAPI.step!( diff --git a/lib/NonlinearSolveQuasiNewton/src/solve.jl b/lib/NonlinearSolveQuasiNewton/src/solve.jl index c52a425ae..53289c117 100644 --- a/lib/NonlinearSolveQuasiNewton/src/solve.jl +++ b/lib/NonlinearSolveQuasiNewton/src/solve.jl @@ -93,6 +93,16 @@ end force_stop::Bool force_reinit::Bool kwargs + + # Initialization + initializealg +end + +function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache) + NonlinearSolveBase.get_abstol(cache.termination_cache) +end +function NonlinearSolveBase.get_reltol(cache::QuasiNewtonCache) + NonlinearSolveBase.get_reltol(cache.termination_cache) end function InternalAPI.reinit_self!( @@ -130,7 +140,8 @@ function SciMLBase.__init( stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxtime = nothing, maxiters = 1000, abstol = nothing, reltol = nothing, linsolve_kwargs = (;), termination_condition = nothing, - internalnorm::F = L2_NORM, kwargs... + internalnorm::F = L2_NORM, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), + kwargs... ) where {F} timer = get_timer_output() @static_timeit timer "cache construction" begin @@ -204,15 +215,18 @@ function SciMLBase.__init( uses_jacobian_inverse = inverted_jac, kwargs... ) - return QuasiNewtonCache( + cache = QuasiNewtonCache( fu, u, u_cache, prob.p, du, J, alg, prob, globalization, initialization_cache, descent_cache, linesearch_cache, trustregion_cache, update_rule_cache, reinit_rule_cache, inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime, alg.max_shrink_times, 0, timer, 0.0, termination_cache, trace, - ReturnCode.Default, false, false, kwargs + ReturnCode.Default, false, false, kwargs, initializealg ) + NonlinearSolveBase.run_initialization!(cache) end + + return cache end function InternalAPI.step!( diff --git a/lib/NonlinearSolveSpectralMethods/src/solve.jl b/lib/NonlinearSolveSpectralMethods/src/solve.jl index b3a7d216e..fd71527c0 100644 --- a/lib/NonlinearSolveSpectralMethods/src/solve.jl +++ b/lib/NonlinearSolveSpectralMethods/src/solve.jl @@ -68,6 +68,8 @@ end retcode::ReturnCode.T force_stop::Bool kwargs + + initializealg end function InternalAPI.reinit_self!( @@ -75,6 +77,7 @@ function InternalAPI.reinit_self!( alias_u0::Bool = false, maxiters = 1000, maxtime = nothing, kwargs... ) Utils.reinit_common!(cache, u0, p, alias_u0) + T = eltype(u0) if cache.alg.σ_1 === nothing σ_n = Utils.safe_dot(cache.u, cache.u) / Utils.safe_dot(cache.u, cache.fu) @@ -112,7 +115,7 @@ function SciMLBase.__init( prob::AbstractNonlinearProblem, alg::GeneralizedDFSane, args...; stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, termination_condition = nothing, - maxtime = nothing, kwargs... + maxtime = nothing, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs... ) timer = get_timer_output() @@ -145,13 +148,16 @@ function SciMLBase.__init( σ_n = T(alg.σ_1) end - return GeneralizedDFSaneCache( + cache = GeneralizedDFSaneCache( fu, fu_cache, u, u_cache, prob.p, du, alg, prob, σ_n, T(alg.σ_min), T(alg.σ_max), linesearch_cache, stats, 0, maxiters, maxtime, timer, 0.0, - tc_cache, trace, ReturnCode.Default, false, kwargs + tc_cache, trace, ReturnCode.Default, false, kwargs, initializealg ) + NonlinearSolveBase.run_initialization!(cache) end + + return cache end function InternalAPI.step!( diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 3d8258f7e..2a87eb54e 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -59,6 +59,12 @@ function CommonSolve.solve( prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs... ) + cache = SciMLBase.__init(prob, alg, args...; kwargs...) + prob = cache.prob + if cache.retcode == ReturnCode.InitialFailure + return SciMLBase.build_solution(prob, alg, prob.u0, + NonlinearSolveBase.Utils.evaluate_f(prob, prob.u0); cache.retcode) + end prob = convert(ImmutableNonlinearProblem, prob) return solve(prob, alg, args...; kwargs...) end @@ -97,6 +103,12 @@ function CommonSolve.solve( alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs... ) + cache = SciMLBase.__init(prob, alg, args...; kwargs...) + prob = cache.prob + if cache.retcode == ReturnCode.InitialFailure + return SciMLBase.build_solution(prob, alg, prob.u0, + NonlinearSolveBase.Utils.evaluate_f(prob, prob.u0); cache.retcode) + end if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end diff --git a/src/default.jl b/src/default.jl index 6021a98b1..734305ef6 100644 --- a/src/default.jl +++ b/src/default.jl @@ -50,3 +50,10 @@ function SciMLBase.__solve( prob, FastShortcutNLLSPolyalg(eltype(prob.u0)), args...; kwargs... ) end + +function NonlinearSolveBase.initialization_alg(::AbstractNonlinearProblem, autodiff) + FastShortcutNonlinearPolyalg(; autodiff) +end +function NonlinearSolveBase.initialization_alg(::NonlinearLeastSquaresProblem, autodiff) + FastShortcutNLLSPolyalg(; autodiff) +end