Skip to content

Commit

Permalink
Fast General Klement Implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 21, 2023
1 parent 1e4cfde commit b821eaf
Show file tree
Hide file tree
Showing 7 changed files with 302 additions and 12 deletions.
2 changes: 2 additions & 0 deletions docs/src/solvers/NonlinearSystemSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ features, but have a bit of overhead on very small problems.
robustnes on the hard problems.
- `GeneralBroyden()`: Generalization of Broyden's Quasi-Newton Method with Line Search and
Automatic Jacobian Resetting. This is a fast method but unstable for most problems!
- `GeneralKlement()`: Generalization of Klement's Quasi-Newton Method with Line Search and
Automatic Jacobian Resetting. This is a fast method but unstable for most problems!

### SimpleNonlinearSolve.jl

Expand Down
13 changes: 8 additions & 5 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
cache.stats.nsteps += 1
end

if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
# The solver might have set a different `retcode`
if cache.retcode == ReturnCode.Default
if cache.stats.nsteps == cache.maxiters
cache.retcode = ReturnCode.MaxIters
else
cache.retcode = ReturnCode.Success
end
end

return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, get_fu(cache);
Expand Down Expand Up @@ -85,7 +88,7 @@ import PrecompileTools
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

precompile_algs = (NewtonRaphson(), TrustRegion(), LevenbergMarquardt(),
nothing)
PseudoTransient(), GeneralBroyden(), nothing)

for alg in precompile_algs
solve(prob, alg, abstol = T(1e-2))
Expand Down
5 changes: 1 addition & 4 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.dfu = cache.fu2 .- cache.fu
if cache.resets < cache.max_resets &&
(all(x -> abs(x) 1e-12, cache.du) || all(x -> abs(x) 1e-12, cache.dfu))
J⁻¹ = similar(cache.J⁻¹)
fill!(J⁻¹, 0)
J⁻¹[diagind(J⁻¹)] .= T(1)
cache.J⁻¹ = J⁻¹
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
cache.resets += 1

Check warning on line 117 in src/broyden.jl

View check run for this annotation

Codecov / codecov/patch

src/broyden.jl#L116-L117

Added lines #L116 - L117 were not covered by tests
else
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
Expand Down
4 changes: 1 addition & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,8 @@ end
]
else
[
# FIXME: Broyden and Klement are type unstable
# (upstream SimpleNonlinearSolve.jl issue)
!iip ? :(Klement()) : nothing, # Klement not yet implemented for IIP
:(GeneralBroyden()),
:(GeneralKlement()),
:(NewtonRaphson(; linsolve, precs, adkwargs...)),
:(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)),
:(TrustRegion(; linsolve, precs, adkwargs...)),
Expand Down
190 changes: 190 additions & 0 deletions src/klement.jl
Original file line number Diff line number Diff line change
@@ -1 +1,191 @@
@concrete struct GeneralKlement <: AbstractNewtonAlgorithm{false, Nothing}
max_resets::Int
linsolve
precs
linesearch
singular_tolerance
end

function GeneralKlement(; max_resets::Int = 5, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, singular_tolerance = nothing)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GeneralKlement(max_resets, linsolve, precs, linesearch, singular_tolerance)
end

@concrete mutable struct GeneralKlementCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
fu
fu2
du
p
linsolve
J
J_cache
J_cache2
Jᵀ²du
Jdu
resets
singular_tolerance
force_stop
maxiters::Int
internalnorm
retcode::ReturnCode.T
abstol
prob
stats::NLStats
lscache
end

get_fu(cache::GeneralKlementCache) = cache.fu

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralKlement, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
fu = evaluate_f(prob, u)
J = __init_identity_jacobian(u, fu)

if u isa Number
linsolve = nothing
else
weight = similar(u)
recursivefill!(weight, true)
Pl, Pr = wrapprecs(alg.precs(J, nothing, u, p, nothing, nothing, nothing, nothing,
nothing)..., weight)
linprob = LinearProblem(J, _vec(fu); u0 = _vec(fu))
linsolve = init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr,
linsolve_kwargs...)
end

singular_tolerance = alg.singular_tolerance === nothing ? inv(sqrt(eps(eltype(u)))) :
eltype(u)(alg.singular_tolerance)

return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), _mutable_zero(u), p, linsolve,
J, zero(J), zero(J), zero(fu), zero(fu), 0, singular_tolerance, false,
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
end

function perform_step!(cache::GeneralKlementCache{true})
@unpack u, fu, f, p, alg, J, linsolve, du = cache
T = eltype(J)

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing
end
fill!(J, zero(T))
J[diagind(J)] .= T(1)
cache.resets += 1
end

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu), linu = _vec(du),
p, reltol = cache.abstol)
cache.linsolve = linres.cache

# Line Search
α = perform_linesearch!(cache.lscache, u, du)
axpy!(α, du, u)
f(cache.fu2, u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1

cache.force_stop && return nothing

# Update the Jacobian
cache.J_cache .= cache.J' .^ 2
cache.Jdu .= _vec(du) .^ 2
mul!(cache.Jᵀ²du, cache.J_cache, cache.Jdu)
mul!(cache.Jdu, J, _vec(du))
cache.fu .= cache.fu2 .- cache.fu
cache.fu .= (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
mul!(cache.J_cache, _vec(cache.fu), _vec(du)')
cache.J_cache .*= J
mul!(cache.J_cache2, cache.J_cache, J)
J .+= cache.J_cache2

cache.fu .= cache.fu2

return nothing
end

function perform_step!(cache::GeneralKlementCache{false})
@unpack fu, f, p, alg, J, linsolve = cache
T = eltype(J)

# FIXME: How can we do this faster?
if cond(J) > cache.singular_tolerance
if cache.resets == alg.max_resets
cache.force_stop = true
cache.retcode = ReturnCode.Unstable
return nothing

Check warning on line 131 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L128-L131

Added lines #L128 - L131 were not covered by tests
end
cache.J = __init_identity_jacobian(u, fu)
cache.resets += 1

Check warning on line 134 in src/klement.jl

View check run for this annotation

Codecov / codecov/patch

src/klement.jl#L133-L134

Added lines #L133 - L134 were not covered by tests
end

# u = u - J \ fu
if linsolve === nothing
cache.du = -fu / cache.J
else
linres = dolinsolve(alg.precs, linsolve; A = J, b = -_vec(fu),
linu = _vec(cache.du), p, reltol = cache.abstol)
cache.linsolve = linres.cache
end

# Line Search
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
cache.u = @. cache.u + α * cache.du # `u` might not support mutation
cache.fu2 = f(cache.u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1

cache.force_stop && return nothing

# Update the Jacobian
cache.J_cache = cache.J' .^ 2
cache.Jdu = _vec(cache.du) .^ 2
cache.Jᵀ²du = cache.J_cache * cache.Jdu
cache.Jdu = J * _vec(cache.du)
cache.fu = cache.fu2 .- cache.fu
cache.fu = (cache.fu .- _restructure(cache.fu, cache.Jdu)) ./ max.(cache.Jᵀ²du, eps(T))
cache.J_cache = ((_vec(cache.fu) * _vec(cache.du)') .* J) * J
cache.J = J .+ cache.J_cache

cache.fu = cache.fu2

return nothing
end

function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
11 changes: 11 additions & 0 deletions test/23_test_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,14 @@ end

test_on_library(problems, dicts, alg_ops, broken_tests)
end

@testset "GeneralKlement 23 Test Problems" begin
alg_ops = (GeneralKlement(),
GeneralKlement(; linesearch = BackTracking()))

broken_tests = Dict(alg => Int[] for alg in alg_ops)
broken_tests[alg_ops[1]] = [1, 2, 3, 4, 5, 6, 7, 13, 22]
broken_tests[alg_ops[2]] = [1, 2, 4, 5, 6, 7, 11, 12, 22]

test_on_library(problems, dicts, alg_ops, broken_tests)
end
89 changes: 89 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -754,3 +754,92 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
end

# --- GeneralKlement tests ---

@testset "GeneralKlement" begin
function benchmark_nlsolve_oop(f, u0, p = 2.0; linesearch = LineSearch())
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9)
end

function benchmark_nlsolve_iip(f, u0, p = 2.0; linesearch = LineSearch())
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, GeneralKlement(; linesearch), abstol = 1e-9)
end

@testset "LineSearch: $(_nameof(lsmethod)) LineSearch AD: $(_nameof(ad))" for lsmethod in (Static(),
StrongWolfe(), BackTracking(), HagerZhang(), MoreThuente()),
ad in (AutoFiniteDiff(), AutoZygote())

linesearch = LineSearch(; method = lsmethod, autodiff = ad)
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)

@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
sol = benchmark_nlsolve_oop(quadratic_f, u0; linesearch)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
GeneralKlement(; linesearch), abstol = 1e-9)
@test (@ballocated solve!($cache)) < 200
end

@testset "[IIP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0],)
ad isa AutoZygote && continue
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linesearch)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
GeneralKlement(; linesearch), abstol = 1e-9)
@test (@ballocated solve!($cache)) 64
end
end

@testset "[OOP] [Immutable AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
end

@testset "[OOP] [Scalar AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p) 1 / (2 * sqrt(p))
end
end

t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p) ForwardDiff.jacobian(t, p)

# Iterator interface
function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
cache = init(probN, GeneralKlement(); maxiters = 100, abstol = 1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p)
sol = solve!(cache)
sols[i] = iip ? sol.u[1] : sol.u
end
return sols
end
p = range(0.01, 2, length = 200)
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)
end

0 comments on commit b821eaf

Please sign in to comment.