
Commit

Fix DFSane
avik-pal committed Dec 1, 2023
1 parent 0e3efd7 commit 031639f
Showing 3 changed files with 63 additions and 183 deletions.
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
@@ -172,7 +172,7 @@ include("raphson.jl")
# include("trustRegion.jl")
# include("levenberg.jl")
include("gaussnewton.jl")
# include("dfsane.jl")
include("dfsane.jl")
include("pseudotransient.jl")
include("broyden.jl")
include("klement.jl")
229 changes: 62 additions & 167 deletions src/dfsane.jl
@@ -1,8 +1,8 @@
"""
DFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
n_exp::Int = 2, η_strategy::Function = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations::Int = 1000)
DFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0, M::Int = 10,
γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5, n_exp::Int = 2,
η_strategy::Function = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations::Int = 100)
A low-overhead and allocation-free implementation of the df-sane method for solving large-scale nonlinear
systems of equations. For in-depth information about all the parameters and the algorithm,
@@ -39,34 +39,27 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_
`f_n` the current residual. Should satisfy ``η > 0`` and ``∑ₖ ηₖ < ∞``. Defaults to
``fn_1 / n^2``.
- `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
algorithm. Defaults to `1000`.
algorithm. Defaults to `100`.
"""
@concrete struct DFSane <: AbstractNonlinearSolveAlgorithm
σ_min
σ_max
σ_1
M::Int
γ
τ_min
τ_max
n_exp::Int
η_strategy
max_inner_iterations::Int
end

function DFSane(; σ_min = 1e-10, σ_max = 1e+10, σ_1 = 1.0, M = 10, γ = 1e-4, τ_min = 0.1,
τ_max = 0.5, n_exp = 2, η_strategy::F = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations = 1000) where {F}
return DFSane(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, n_exp, η_strategy,
max_inner_iterations)
@kwdef @concrete struct DFSane <: AbstractNonlinearSolveAlgorithm
σ_min = 1e-10
σ_max = 1e10
σ_1 = 1.0
M::Int = 10
γ = 1e-4
τ_min = 0.1
τ_max = 0.5
n_exp::Int = 2
η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2
max_inner_iterations::Int = 100
end
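
The `@kwdef`-based definition above folds the defaults into the struct itself, which is why the hand-written keyword constructor removed in this hunk is no longer needed. A minimal standalone illustration of that pattern (toy struct, not the package type):

```julia
# Base.@kwdef synthesizes a keyword constructor with the stated defaults.
Base.@kwdef struct ToySane
    σ_min::Float64 = 1e-10
    σ_max::Float64 = 1e10
    M::Int = 10
end

ToySane()                # ToySane(1.0e-10, 1.0e10, 10)
ToySane(; σ_max = 1e8)   # override a single field by keyword
```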

@concrete mutable struct DFSaneCache{iip} <: AbstractNonlinearSolveCache{iip}
alg
u
uprev
u_cache
fu
fuprev
fu_cache
du
history
f_norm
@@ -93,36 +86,35 @@ end
trace
end

get_fu(cache::DFSaneCache) = cache.fu
set_fu!(cache::DFSaneCache, fu) = (cache.fu = fu)

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
kwargs...) where {uType, iip, F}
u = alias_u0 ? prob.u0 : deepcopy(prob.u0)
u = __maybe_unaliased(prob.u0, alias_u0)
T = eltype(u)

du, uprev = copy(u), copy(u)
@bb du = similar(u)
@bb u_cache = copy(u)

fu = evaluate_f(prob, u)
fuprev = copy(fu)
@bb fu_cache = copy(fu)

f_norm = internalnorm(fu)^alg.n_exp
f_norm_0 = f_norm

history = fill(f_norm, alg.M)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, uprev,
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u_cache,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)

return DFSaneCache{iip}(alg, u, uprev, fu, fuprev, du, history, f_norm, f_norm_0, alg.M,
T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min),
return DFSaneCache{iip}(alg, u, u_cache, fu, fu_cache, du, history, f_norm, f_norm_0,
alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min),
T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
end

function perform_step!(cache::DFSaneCache{true})
function perform_step!(cache::DFSaneCache{iip}) where {iip}
@unpack alg, f_norm, σ_n, σ_min, σ_max, α_1, γ, τ_min, τ_max, n_exp, M, prob = cache
T = eltype(cache.u)
f_norm_old = f_norm
@@ -131,128 +123,64 @@ function perform_step!(cache::DFSaneCache{true})
σ_n = sign(σ_n) * clamp(abs(σ_n), σ_min, σ_max)

# Line search direction
@. cache.du = -σ_n * cache.fuprev
@bb @. cache.du = -σ_n * cache.fu

η = alg.η_strategy(cache.f_norm_0, cache.stats.nsteps, cache.u, cache.fu)

f_bar = maximum(cache.history)
α₊ = α_1
α₋ = α_1
_axpy!(α₊, cache.du, cache.u)

prob.f(cache.fu, cache.u, cache.p)
f_norm = cache.internalnorm(cache.fu)^n_exp

# TODO: Failure mode with inner line search failed?
for _ in 1:(cache.alg.max_inner_iterations)
c = f_bar + η - γ * α₊^2 * f_norm_old

f_norm ≤ c && break

α₊ = α₊ * clamp(α₊ * f_norm_old / (f_norm + (T(2) * α₊ - T(1)) * f_norm_old),
τ_min, τ_max)
@. cache.u = cache.uprev - α₋ * cache.du

prob.f(cache.fu, cache.u, cache.p)
f_norm = cache.internalnorm(cache.fu)^n_exp

f_norm ≤ c && break

α₋ = α₋ * clamp(α₋ * f_norm_old / (f_norm + (T(2) * α₋ - T(1)) * f_norm_old),
τ_min, τ_max)
@. cache.u = cache.uprev + α₊ * cache.du

prob.f(cache.fu, cache.u, cache.p)
f_norm = cache.internalnorm(cache.fu)^n_exp
end

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
cache.du, α₊)

check_and_update!(cache, cache.fu, cache.u, cache.uprev)
@bb axpy!(α₊, cache.du, cache.u)

# Update spectral parameter
@. cache.uprev = cache.u - cache.uprev
@. cache.fuprev = cache.fu - cache.fuprev

α₊ = sum(abs2, cache.uprev)
@. cache.uprev *= cache.fuprev
α₋ = sum(cache.uprev)
cache.σ_n = α₊ / α₋

# Spectral parameter bounds check
if !(σ_min ≤ abs(cache.σ_n) ≤ σ_max)
test_norm = sqrt(sum(abs2, cache.fuprev))
cache.σ_n = clamp(inv(test_norm), T(1), T(1e5))
end

# Take step
@. cache.uprev = cache.u
@. cache.fuprev = cache.fu
cache.f_norm = f_norm

# Update history
cache.history[cache.stats.nsteps % M + 1] = f_norm
cache.stats.nf += 1
return nothing
end

function perform_step!(cache::DFSaneCache{false})
@unpack alg, f_norm, σ_n, σ_min, σ_max, α_1, γ, τ_min, τ_max, n_exp, M, prob = cache
T = eltype(cache.u)
f_norm_old = f_norm

# Spectral parameter range check
σ_n = sign(σ_n) * clamp(abs(σ_n), σ_min, σ_max)

# Line search direction
cache.du = @. -σ_n * cache.fuprev

η = alg.η_strategy(cache.f_norm_0, cache.stats.nsteps, cache.u, cache.fu)

f_bar = maximum(cache.history)
α₊ = α_1
α₋ = α_1
cache.u = @. cache.uprev + α₊ * cache.du

cache.fu = prob.f(cache.u, cache.p)
evaluate_f(cache, cache.u, cache.p)
f_norm = cache.internalnorm(cache.fu)^n_exp
α = α₊

# TODO: Failure mode with inner line search failed?
for _ in 1:(cache.alg.max_inner_iterations)
c = f_bar + η - γ * α₊^2 * f_norm_old

f_norm ≤ c && break
inner_converged = false
for k in 1:(cache.alg.max_inner_iterations)
if f_norm ≤ f_bar + η - γ * α₊^2 * f_norm_old
α = α₊
inner_converged = true
break
end

α₊ = α₊ * clamp(α₊ * f_norm_old / (f_norm + (T(2) * α₊ - T(1)) * f_norm_old),
τ_min, τ_max)
cache.u = @. cache.uprev - α₋ * cache.du
@bb axpy!(-α₋, cache.du, cache.u)

cache.fu = prob.f(cache.u, cache.p)
evaluate_f(cache, cache.u, cache.p)
f_norm = cache.internalnorm(cache.fu)^n_exp

f_norm ≤ c && break
if f_norm ≤ f_bar + η - γ * α₋^2 * f_norm_old
α = α₋
inner_converged = true
break
end

α₋ = α₋ * clamp(α₋ * f_norm_old / (f_norm + (T(2) * α₋ - T(1)) * f_norm_old),
τ_min, τ_max)
cache.u = @. cache.uprev + α₊ * cache.du
@bb axpy!(α₊, cache.du, cache.u)

cache.fu = prob.f(cache.u, cache.p)
evaluate_f(cache, cache.u, cache.p)
f_norm = cache.internalnorm(cache.fu)^n_exp
end

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), nothing,
cache.du, α₊)
if !inner_converged
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
end

check_and_update!(cache, cache.fu, cache.u, cache.uprev)
update_trace!(cache, α)
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)

# Update spectral parameter
cache.uprev = @. cache.u - cache.uprev
cache.fuprev = @. cache.fu - cache.fuprev
@bb @. cache.u_cache = cache.u - cache.u_cache
@bb @. cache.fu_cache = cache.fu - cache.fu_cache

α₊ = sum(abs2, cache.uprev)
cache.uprev = @. cache.uprev * cache.fuprev
α₋ = sum(cache.uprev)
α₊ = sum(abs2, cache.u_cache)
@bb @. cache.u_cache *= cache.fu_cache
α₋ = sum(cache.u_cache)
cache.σ_n = α₊ / α₋

# Spectral parameter bounds check
@@ -262,8 +190,8 @@
end

# Take step
cache.uprev = cache.u
cache.fuprev = cache.fu
@bb copyto!(cache.u_cache, cache.u)
@bb copyto!(cache.fu_cache, cache.fu)
cache.f_norm = f_norm

# Update history
@@ -272,41 +200,8 @@ function perform_step!(cache::DFSaneCache{false})
return nothing
end
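
For reference, the two quantities the rewritten step maintains, written out as standalone formulas (illustrative helper names, not NonlinearSolve internals):

```julia
# Non-monotone acceptance test from the inner loop: accept a step size α when
# ‖f(x + α d)‖^n_exp <= max(recent history) + η - γ * α^2 * ‖f(x)‖^n_exp.
accept_step(f_norm_new, f_bar, η, γ, α, f_norm_old) =
    f_norm_new ≤ f_bar + η - γ * α^2 * f_norm_old

# Barzilai-Borwein-style spectral coefficient computed from the last step,
# σ = ⟨s, s⟩ / ⟨s, y⟩ with s = xₙ₊₁ - xₙ and y = f(xₙ₊₁) - f(xₙ).
spectral_coefficient(s, y) = sum(abs2, s) / sum(s .* y)
```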

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
recursivecopy!(cache.uprev, u0)
cache.prob.f(cache.fu, cache.u, p)
cache.prob.f(cache.fuprev, cache.uprev, p)
else
cache.u = u0
cache.uprev = u0
cache.fu = cache.prob.f(cache.u, p)
cache.fuprev = cache.prob.f(cache.uprev, p)
end

function __reinit_internal!(cache::DFSaneCache; kwargs...)
cache.f_norm = cache.internalnorm(cache.fu)^cache.n_exp
cache.f_norm_0 = cache.f_norm

fill!(cache.history, cache.f_norm)

T = eltype(cache.u)
cache.σ_n = T(cache.alg.σ_1)

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
return
end
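
The new `__reinit_internal!` hooks into the generic cache-reuse workflow; below is a usage sketch of that workflow through the common `init`/`solve!`/`reinit!` interface (the problem is illustrative, and `reinit!` may need to be qualified as `SciMLBase.reinit!` depending on exports):

```julia
using NonlinearSolve

f(u, p) = u .* u .- p
prob = NonlinearProblem(f, [1.0, 1.0], 2.0)

cache = init(prob, DFSane())
sol1 = solve!(cache)

# Reset the same cache to a new initial guess and solve again without rebuilding it.
reinit!(cache, [3.0, 3.0])
sol2 = solve!(cache)
```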
15 changes: 0 additions & 15 deletions src/utils.jl
@@ -206,14 +206,6 @@ function evaluate_f(cache, u, p)
return nothing
end

"""
__matmul!(C, A, B)
Defaults to `mul!(C, A, B)`. However, for sparse matrices uses `C .= A * B`.
"""
__matmul!(C, A, B) = mul!(C, A, B)
__matmul!(C::AbstractSparseMatrix, A, B) = C .= A * B

# Concretize Algorithms
function get_concrete_algorithm(alg, prob)
!hasfield(typeof(alg), :ad) && return alg
@@ -381,15 +373,8 @@ function __try_factorize_and_check_singular!(linsolve, X)
end
__try_factorize_and_check_singular!(::FakeLinearSolveJLCache, x) = _issingular(x), false

# TODO: Remove. handled in MaybeInplace.jl
@generated function _axpy!(α, x, y)
hasmethod(axpy!, Tuple{α, x, y}) && return :(axpy!(α, x, y))
return :(@. y += α * x)
end
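
The removed `_axpy!` helper dispatched to `LinearAlgebra.axpy!` when a method existed and otherwise fell back to a broadcast, which the `@bb` machinery from MaybeInplace.jl now covers per the TODO above. The underlying operation, for reference:

```julia
using LinearAlgebra

y = [1.0, 2.0]
x = [10.0, 20.0]
axpy!(0.5, x, y)       # in place: y .= 0.5 .* x .+ y, so y == [6.0, 12.0]

# Broadcast form the removed helper generated when no axpy! method applied:
# @. y += 0.5 * x
```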

# Non-square matrix
@inline __needs_square_A(_, ::Number) = true
# @inline __needs_square_A(_, ::StaticArray) = true
@inline __needs_square_A(alg, _) = LinearSolve.needs_square_A(alg.linsolve)

# Define special concatenation for certain Array combinations
