Skip to content

Commit

Permalink
Improve termination conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2023
1 parent a83cf9c commit 1a0df4e
Show file tree
Hide file tree
Showing 17 changed files with 273 additions and 411 deletions.
6 changes: 3 additions & 3 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DiffEqBase]]
deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"]
git-tree-sha1 = "e5049e32074cd22f86d74036caf6663637623003"
git-tree-sha1 = "4e661d0beddac31da05e71b79afd769232622de8"
repo-rev = "ap/tstable_termination"
repo-url = "https://github.com/SciML/DiffEqBase.jl"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down Expand Up @@ -689,9 +689,9 @@ version = "0.1.0"

[[deps.SLEEFPirates]]
deps = ["IfElse", "Static", "VectorizationBase"]
git-tree-sha1 = "897b39ec056c0619ea87adc7eeadba0bec0cf931"
git-tree-sha1 = "f5c896d781486f1d67c8492f0e0ead2c3517208c"
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
version = "0.6.40"
version = "0.6.41"

[[deps.SciMLBase]]
deps = ["ADTypes", "ArrayInterface", "ChainRulesCore", "CommonSolve", "ConstructionBase", "Distributed", "DocStringExtensions", "EnumX", "FillArrays", "FunctionWrappersWrappers", "IteratorInterfaceExtensions", "LinearAlgebra", "Logging", "Markdown", "PrecompileTools", "Preferences", "RecipesBase", "RecursiveArrayTools", "Reexport", "RuntimeGeneratedFunctions", "SciMLOperators", "StaticArraysCore", "Statistics", "SymbolicIndexingInterface", "Tables", "TruncatedStacktraces", "ZygoteRules"]
Expand Down
14 changes: 12 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ PrecompileTools.@recompile_invalidations begin
end

@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
import DiffEqBase: AbstractNonlinearTerminationMode
import DiffEqBase: AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, get_termination_mode

const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}
Expand All @@ -53,6 +55,8 @@ function not_terminated(cache::AbstractNonlinearSolveCache)
return !cache.force_stop && cache.stats.nsteps < cache.maxiters
end
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)
get_u(cache::AbstractNonlinearSolveCache) = cache.u

Check warning on line 59 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L58-L59

Added lines #L58 - L59 were not covered by tests

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
while not_terminated(cache)
Expand All @@ -69,7 +73,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
end
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);

Check warning on line 76 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L76

Added line #L76 was not covered by tests
cache.retcode, cache.stats)
end

Expand Down Expand Up @@ -113,4 +117,10 @@ export RobustMultiNewton, FastShortcutNonlinearPolyalg

export LineSearch, LiFukushimaLineSearch

# Export the termination conditions from DiffEqBase
export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode,
AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
RelSafeBestTerminationMode, AbsSafeBestTerminationMode

end # module
38 changes: 13 additions & 25 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ end
prob
stats::NLStats
ls_cache
termination_condition
tc_storage
tc_cache
end

get_fu(cache::GeneralBroydenCache) = cache.fu
set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu)

Check warning on line 60 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L60

Added line #L60 was not covered by tests

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
Expand All @@ -71,34 +71,26 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance

abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
termination_condition, eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,

Check warning on line 74 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L74

Added line #L74 was not covered by tests
termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
storage)
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache)
end

function perform_step!(cache::GeneralBroydenCache{true})
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache

Check warning on line 85 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L85

Added line #L85 was not covered by tests
T = eltype(u)

mul!(_vec(du), J⁻¹, _vec(fu))
α = perform_linesearch!(cache.ls_cache, u, du)
_axpy!(-α, du, u)

Check warning on line 90 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L88-L90

Added lines #L88 - L90 were not covered by tests
f(fu2, u, p)

termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
check_and_update!(cache, fu2, u, u_prev)

Check warning on line 93 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L93

Added line #L93 was not covered by tests
cache.stats.nf += 1

cache.force_stop && return nothing
Expand Down Expand Up @@ -130,9 +122,7 @@ function perform_step!(cache::GeneralBroydenCache{true})
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
@unpack f, p = cache

Check warning on line 125 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L125

Added line #L125 was not covered by tests

T = eltype(cache.u)

Expand All @@ -141,8 +131,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.u = cache.u .- α * cache.du

Check warning on line 131 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L129-L131

Added lines #L129 - L131 were not covered by tests
cache.fu2 = f(cache.u, p)

termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)

Check warning on line 134 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L134

Added line #L134 was not covered by tests
cache.stats.nf += 1

cache.force_stop && return nothing
Expand Down Expand Up @@ -172,9 +161,8 @@ function perform_step!(cache::GeneralBroydenCache{false})
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
Expand All @@ -185,12 +173,12 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
cache.fu = cache.f(cache.u, p)
end

termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,

Check warning on line 176 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L176

Added line #L176 was not covered by tests
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.tc_cache = tc_cache

Check warning on line 181 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L181

Added line #L181 was not covered by tests
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
60 changes: 18 additions & 42 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ function DFSane(; σ_min = 1e-10, σ_max = 1e+10, σ_1 = 1.0, M = 10, γ = 1e-4,
n_exp, η_strategy, max_inner_iterations)
end

@concrete mutable struct DFSaneCache{iip}
# FIXME: Someone please make this code conform to the style of the remaining solvers
@concrete mutable struct DFSaneCache{iip} <: AbstractNonlinearSolveCache{iip}
alg
uₙ
uₙ₋₁
Expand Down Expand Up @@ -91,10 +92,13 @@ end
reltol
prob
stats::NLStats
termination_condition
tc_storage
tc_cache
end

get_fu(cache::DFSaneCache) = cache.fuₙ
set_fu!(cache::DFSaneCache, fu) = (cache.fuₙ = fu)
get_u(cache::DFSaneCache) = cache.uₙ

Check warning on line 100 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L98-L100

Added lines #L98 - L100 were not covered by tests

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
Expand Down Expand Up @@ -124,24 +128,18 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.

= fill(f₍ₙₒᵣₘ₎ₙ₋₁, M)

abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol,
termination_condition, T)

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fuₙ₋₁, uₙ₋₁,

Check warning on line 131 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L131

Added line #L131 was not covered by tests
termination_condition)

return DFSaneCache{iip}(alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀,
M, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, p, false, maxiters,
internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
termination_condition, storage)
tc_cache)
end

function perform_step!(cache::DFSaneCache{true})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

Check warning on line 141 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L141

Added line #L141 was not covered by tests

termination_condition = cache.termination_condition(tc_storage)
f = (dx, x) -> cache.prob.f(dx, x, cache.p)

T = eltype(cache.uₙ)
Expand Down Expand Up @@ -184,9 +182,7 @@ function perform_step!(cache::DFSaneCache{true})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end
check_and_update!(cache, cache.fuₙ, cache.uₙ, cache.uₙ₋₁)

Check warning on line 185 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L185

Added line #L185 was not covered by tests

# Update spectral parameter
@. cache.uₙ₋₁ = cache.uₙ - cache.uₙ₋₁
Expand Down Expand Up @@ -215,9 +211,8 @@ function perform_step!(cache::DFSaneCache{true})
end

function perform_step!(cache::DFSaneCache{false})
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M, tc_storage = cache
@unpack alg, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, σₙ, σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ, τₘₐₓ, nₑₓₚ, M = cache

Check warning on line 214 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L214

Added line #L214 was not covered by tests

termination_condition = cache.termination_condition(tc_storage)
f = x -> cache.prob.f(x, cache.p)

T = eltype(cache.uₙ)
Expand Down Expand Up @@ -260,9 +255,7 @@ function perform_step!(cache::DFSaneCache{false})
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

if termination_condition(cache.fuₙ, cache.uₙ, cache.uₙ₋₁, cache.abstol, cache.reltol)
cache.force_stop = true
end
check_and_update!(cache, cache.fuₙ, cache.uₙ, cache.uₙ₋₁)

Check warning on line 258 in src/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/dfsane.jl#L258

Added line #L258 was not covered by tests

# Update spectral parameter
cache.uₙ₋₁ = @. cache.uₙ - cache.uₙ₋₁
Expand Down Expand Up @@ -290,26 +283,9 @@ function perform_step!(cache::DFSaneCache{false})
return nothing
end

function SciMLBase.solve!(cache::DFSaneCache)
while !cache.force_stop && cache.stats.nsteps < cache.maxiters
cache.stats.nsteps += 1
perform_step!(cache)
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.uₙ, cache.fuₙ;
retcode = cache.retcode, stats = cache.stats)
end

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
Expand All @@ -330,12 +306,12 @@ function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p
T = eltype(cache.uₙ)
cache.σₙ = T(cache.alg.σ_1)

termination_condition = _get_reinit_termination_condition(cache, abstol, reltol,
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fuₙ, cache.uₙ,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
Loading

0 comments on commit 1a0df4e

Please sign in to comment.