Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add hooks for OverrideInit #517

Merged
merged 13 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ MaybeInplace = "0.1.4"
Preferences = "1.4"
Printf = "1.10"
RecursiveArrayTools = "3"
SciMLBase = "2.58"
SciMLBase = "2.68.1"
SciMLJacobianOperators = "0.1.1"
SciMLOperators = "0.3.10"
SparseArrays = "1.10"
Expand Down
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ include("descent/damped_newton.jl")
include("descent/dogleg.jl")
include("descent/geodesic_acceleration.jl")

include("initialization.jl")
include("solve.jl")

include("forward_diff.jl")
Expand Down
11 changes: 10 additions & 1 deletion lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ Abstract Type for all NonlinearSolveBase Caches.
`u0` and any additional keyword arguments.
- `SciMLBase.isinplace(cache)`: whether or not the solver is inplace.
- `CommonSolve.step!(cache; kwargs...)`: See [`CommonSolve.step!`](@ref) for more details.
- `get_abstol(cache)`: get the `abstol` provided to the cache.
- `get_reltol(cache)`: get the `reltol` provided to the cache.

Additionally implements `SymbolicIndexingInterface` interface Functions.

Expand Down Expand Up @@ -304,9 +306,16 @@ end

SciMLBase.isinplace(cache::AbstractNonlinearSolveCache) = SciMLBase.isinplace(cache.prob)

function get_abstol(cache::AbstractNonlinearSolveCache)
get_abstol(cache.termination_cache)
end
function get_reltol(cache::AbstractNonlinearSolveCache)
get_reltol(cache.termination_cache)
end

## SII Interface
SII.symbolic_container(cache::AbstractNonlinearSolveCache) = cache.prob
SII.parameter_values(cache::AbstractNonlinearSolveCache) = SII.parameter_values(cache.prob)
SII.parameter_values(cache::AbstractNonlinearSolveCache) = cache.p
SII.state_values(cache::AbstractNonlinearSolveCache) = get_u(cache)

function Base.getproperty(cache::AbstractNonlinearSolveCache, sym::Symbol)
Expand Down
7 changes: 7 additions & 0 deletions lib/NonlinearSolveBase/src/forward_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@
values_p
partials_p
end

function NonlinearSolveBase.get_abstol(cache::NonlinearSolveForwardDiffCache)
NonlinearSolveBase.get_abstol(cache.cache)
end
function NonlinearSolveBase.get_reltol(cache::NonlinearSolveForwardDiffCache)
NonlinearSolveBase.get_reltol(cache.cache)
end
60 changes: 60 additions & 0 deletions lib/NonlinearSolveBase/src/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
struct NonlinearSolveDefaultInit <: SciMLBase.DAEInitializationAlgorithm end

function run_initialization!(cache, initializealg = cache.initializealg, prob = cache.prob)
_run_initialization!(cache, initializealg, prob, Val(SciMLBase.isinplace(cache)))
end

function _run_initialization!(
cache, ::NonlinearSolveDefaultInit, prob, isinplace::Union{Val{true}, Val{false}})
if SciMLBase.has_initialization_data(prob.f) &&
prob.f.initialization_data isa SciMLBase.OverrideInitData
return _run_initialization!(cache, SciMLBase.OverrideInit(), prob, isinplace)
end
return cache, true
end

function _run_initialization!(cache, initalg::SciMLBase.OverrideInit, prob,
isinplace::Union{Val{true}, Val{false}})
if cache isa AbstractNonlinearSolveCache && isdefined(cache.alg, :autodiff)
autodiff = cache.alg.autodiff
else
autodiff = ADTypes.AutoForwardDiff()
end
alg = initialization_alg(prob.f.initialization_data.initializeprob, autodiff)
if alg === nothing && cache isa AbstractNonlinearSolveCache
alg = cache.alg
end
u0, p, success = SciMLBase.get_initial_values(
prob, cache, prob.f, initalg, isinplace; nlsolve_alg = alg,
abstol = get_abstol(cache), reltol = get_reltol(cache))
cache = update_initial_values!(cache, u0, p)
if cache isa AbstractNonlinearSolveCache && isdefined(cache, :retcode) && !success
cache.retcode = ReturnCode.InitialFailure
end

return cache, success
end

function get_abstol(prob::AbstractNonlinearProblem)
get_tolerance(get(prob.kwargs, :abstol, nothing), eltype(SII.state_values(prob)))
end
function get_reltol(prob::AbstractNonlinearProblem)
get_tolerance(get(prob.kwargs, :reltol, nothing), eltype(SII.state_values(prob)))
end

initialization_alg(initprob, autodiff) = nothing

function update_initial_values!(cache::AbstractNonlinearSolveCache, u0, p)
InternalAPI.reinit!(cache; u0, p)
cache.prob = SciMLBase.remake(cache.prob; u0, p)
return cache
end

function update_initial_values!(prob::AbstractNonlinearProblem, u0, p)
return SciMLBase.remake(prob; u0, p)
end

function _run_initialization!(
cache::AbstractNonlinearSolveCache, ::SciMLBase.NoInit, prob, isinplace)
return cache, true
end
32 changes: 28 additions & 4 deletions lib/NonlinearSolveBase/src/polyalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ end
u0
u0_aliased
alias_u0::Bool

initializealg
end

function update_initial_values!(cache::NonlinearSolvePolyAlgorithmCache, u0, p)
foreach(cache.caches) do subcache
update_initial_values!(subcache, u0, p)
end
cache.prob = SciMLBase.remake(cache.prob; u0, p)
return cache
end

function NonlinearSolveBase.get_abstol(cache::NonlinearSolvePolyAlgorithmCache)
NonlinearSolveBase.get_abstol(cache.caches[cache.current])
end
function NonlinearSolveBase.get_reltol(cache::NonlinearSolvePolyAlgorithmCache)
NonlinearSolveBase.get_reltol(cache.caches[cache.current])
end

function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache)
Expand All @@ -67,6 +84,9 @@ end
function SII.state_values(cache::NonlinearSolvePolyAlgorithmCache)
SII.state_values(SII.symbolic_container(cache))
end
function SII.parameter_values(cache::NonlinearSolvePolyAlgorithmCache)
SII.parameter_values(SII.symbolic_container(cache))
end

function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache)
println(io, "NonlinearSolvePolyAlgorithmCache with \
Expand Down Expand Up @@ -97,7 +117,8 @@ end
function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs...
internalnorm = L2_NORM, alias_u0 = false, verbose = true,
initializealg = NonlinearSolveDefaultInit(), kwargs...
)
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
Expand All @@ -109,18 +130,21 @@ function SciMLBase.__init(
u0_aliased = alias_u0 ? copy(u0) : u0
alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased))

return NonlinearSolvePolyAlgorithmCache(
cache = NonlinearSolvePolyAlgorithmCache(
alg.static_length, prob,
map(alg.algs) do solver
SciMLBase.__init(
prob, solver, args...;
stats, maxtime, internalnorm, alias_u0, verbose, kwargs...
stats, maxtime, internalnorm, alias_u0, verbose,
initializealg = SciMLBase.NoInit(), kwargs...
)
end,
alg, -1, alg.start_index, 0, stats, 0.0, maxtime,
ReturnCode.Default, false, maxiters, internalnorm,
u0, u0_aliased, alias_u0
u0, u0_aliased, alias_u0, initializealg
)
run_initialization!(cache)
return cache
end

@generated function InternalAPI.step!(
Expand Down
61 changes: 58 additions & 3 deletions lib/NonlinearSolveBase/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ function SciMLBase.__solve(
end

function CommonSolve.solve!(cache::AbstractNonlinearSolveCache)
if cache.retcode == ReturnCode.InitialFailure
return SciMLBase.build_solution(
cache.prob, cache.alg, get_u(cache), get_fu(cache);
cache.retcode, cache.stats, cache.trace
)
end

while not_terminated(cache)
CommonSolve.step!(cache)
end
Expand Down Expand Up @@ -40,6 +47,17 @@ end
sol_syms = [gensym("sol") for i in 1:N]
u_result_syms = [gensym("u_result") for i in 1:N]

push!(calls,
quote
if cache.retcode == ReturnCode.InitialFailure
u = $(SII.state_values)(cache)
return build_solution_less_specialize(
cache.prob, cache.alg, u, $(Utils.evaluate_f)(cache.prob, u);
retcode = cache.retcode
)
end
end)

for i in 1:N
push!(calls,
quote
Expand Down Expand Up @@ -111,7 +129,8 @@ end

@generated function __generated_polysolve(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs...
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true,
initializealg = NonlinearSolveDefaultInit(), kwargs...
) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
Expand All @@ -123,9 +142,23 @@ end
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end
end]

push!(calls,
quote
prob, success = $(run_initialization!)(prob, initializealg, prob)
if !success
u = $(SII.state_values)(prob)
return build_solution_less_specialize(
prob, alg, u, $(Utils.evaluate_f)(prob, u);
retcode = $(ReturnCode.InitialFailure))
end
end)

push!(calls, quote
u0 = prob.u0
u0_aliased = alias_u0 ? zero(u0) : u0
end]
end)
for i in 1:N
cur_sol = sol_syms[i]
push!(calls,
Expand Down Expand Up @@ -246,8 +279,21 @@ end
alg
args
kwargs::Any
initializealg

retcode::ReturnCode.T
end

function get_abstol(cache::NonlinearSolveNoInitCache)
get(cache.kwargs, :abstol, get_tolerance(nothing, eltype(cache.prob.u0)))
end
function get_reltol(cache::NonlinearSolveNoInitCache)
get(cache.kwargs, :reltol, get_tolerance(nothing, eltype(cache.prob.u0)))
end

SII.parameter_values(cache::NonlinearSolveNoInitCache) = SII.parameter_values(cache.prob)
SII.state_values(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob)

get_u(cache::NonlinearSolveNoInitCache) = SII.state_values(cache.prob)

function SciMLBase.reinit!(
Expand All @@ -264,11 +310,20 @@ end

function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::AbstractNonlinearSolveAlgorithm, args...;
initializealg = NonlinearSolveDefaultInit(),
kwargs...
)
return NonlinearSolveNoInitCache(prob, alg, args, kwargs)
cache = NonlinearSolveNoInitCache(
prob, alg, args, kwargs, initializealg, ReturnCode.Default)
run_initialization!(cache)
return cache
end

function CommonSolve.solve!(cache::NonlinearSolveNoInitCache)
if cache.retcode == ReturnCode.InitialFailure
u = SII.state_values(cache)
return SciMLBase.build_solution(
cache.prob, cache.alg, u, Utils.evaluate_f(cache.prob, u); cache.retcode)
end
return CommonSolve.solve(cache.prob, cache.alg, cache.args...; cache.kwargs...)
end
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const AbsNormModes = Union{
u_diff_cache::uType
end

get_abstol(cache::NonlinearTerminationModeCache) = cache.abstol
get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol

function update_u!!(cache::NonlinearTerminationModeCache, u)
cache.u === nothing && return
if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u)
Expand Down
12 changes: 9 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ end
retcode::ReturnCode.T
force_stop::Bool
kwargs

initializealg
end

function InternalAPI.reinit_self!(
Expand Down Expand Up @@ -121,7 +123,7 @@ function SciMLBase.__init(
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxiters = 1000,
abstol = nothing, reltol = nothing, maxtime = nothing,
termination_condition = nothing, internalnorm = L2_NORM,
linsolve_kwargs = (;), kwargs...
linsolve_kwargs = (;), initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(), kwargs...
)
@set! alg.autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
provided_jvp_autodiff = alg.jvp_autodiff !== nothing
Expand Down Expand Up @@ -206,13 +208,17 @@ function SciMLBase.__init(
prob, alg, u, fu, J, du; kwargs...
)

return GeneralizedFirstOrderAlgorithmCache(
cache = GeneralizedFirstOrderAlgorithmCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
jac_cache, descent_cache, linesearch_cache, trustregion_cache,
stats, 0, maxiters, maxtime, alg.max_shrink_times, timer,
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs
0.0, true, termination_cache, trace, ReturnCode.Default, false, kwargs,
initializealg
)
NonlinearSolveBase.run_initialization!(cache)
end

return cache
end

function InternalAPI.step!(
Expand Down
20 changes: 17 additions & 3 deletions lib/NonlinearSolveQuasiNewton/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ end
force_stop::Bool
force_reinit::Bool
kwargs

# Initialization
initializealg
end

function NonlinearSolveBase.get_abstol(cache::QuasiNewtonCache)
NonlinearSolveBase.get_abstol(cache.termination_cache)
end
function NonlinearSolveBase.get_reltol(cache::QuasiNewtonCache)
NonlinearSolveBase.get_reltol(cache.termination_cache)
end

function InternalAPI.reinit_self!(
Expand Down Expand Up @@ -130,7 +140,8 @@ function SciMLBase.__init(
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, maxtime = nothing,
maxiters = 1000, abstol = nothing, reltol = nothing,
linsolve_kwargs = (;), termination_condition = nothing,
internalnorm::F = L2_NORM, kwargs...
internalnorm::F = L2_NORM, initializealg = NonlinearSolveBase.NonlinearSolveDefaultInit(),
kwargs...
) where {F}
timer = get_timer_output()
@static_timeit timer "cache construction" begin
Expand Down Expand Up @@ -204,15 +215,18 @@ function SciMLBase.__init(
uses_jacobian_inverse = inverted_jac, kwargs...
)

return QuasiNewtonCache(
cache = QuasiNewtonCache(
fu, u, u_cache, prob.p, du, J, alg, prob, globalization,
initialization_cache, descent_cache, linesearch_cache,
trustregion_cache, update_rule_cache, reinit_rule_cache,
inv_workspace, stats, 0, 0, alg.max_resets, maxiters, maxtime,
alg.max_shrink_times, 0, timer, 0.0, termination_cache, trace,
ReturnCode.Default, false, false, kwargs
ReturnCode.Default, false, false, kwargs, initializealg
)
NonlinearSolveBase.run_initialization!(cache)
end

return cache
end

function InternalAPI.step!(
Expand Down
Loading
Loading