Skip to content

Commit

Permalink
Make the internal field names more consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 30, 2023
1 parent f147663 commit 4f2dec0
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 102 deletions.
21 changes: 11 additions & 10 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ end
function not_terminated(cache::AbstractNonlinearSolveCache)
return !cache.force_stop && cache.stats.nsteps < cache.maxiters
end
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)

get_fu(cache::AbstractNonlinearSolveCache) = cache.fu
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu)
get_u(cache::AbstractNonlinearSolveCache) = cache.u
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)

Expand Down Expand Up @@ -152,17 +153,17 @@ include("trace.jl")
include("extension_algs.jl")
include("linesearch.jl")
include("raphson.jl")
include("trustRegion.jl")
include("levenberg.jl")
include("gaussnewton.jl")
include("dfsane.jl")
include("pseudotransient.jl")
# include("trustRegion.jl")
# include("levenberg.jl")
# include("gaussnewton.jl")
# include("dfsane.jl")
# include("pseudotransient.jl")
include("broyden.jl")
include("klement.jl")
include("lbroyden.jl")
# include("lbroyden.jl")
include("jacobian.jl")
include("ad.jl")
include("default.jl")
# include("ad.jl")
# include("default.jl")

# @setup_workload begin
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
Expand Down
49 changes: 19 additions & 30 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@ end
f
alg
u
u_prev
u_cache
du
fu
fu2
fu_cache
dfu
p
J⁻¹
J⁻¹₂
J⁻¹df
J⁻¹dfu
force_stop::Bool
resets::Int
max_resets::Int
Expand All @@ -57,9 +56,6 @@ end
trace
end

get_fu(cache::GeneralBroydenCache) = cache.fu
set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu)

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
Expand All @@ -73,19 +69,18 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance

@bb u_prev = copy(u)
@bb fu2 = copy(fu)
@bb u_cache = copy(u)
@bb fu_cache = similar(fu)
@bb dfu = similar(fu)
@bb J⁻¹₂ = similar(u)
@bb J⁻¹df = similar(u)
@bb J⁻¹dfu = similar(u)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
kwargs...)

return GeneralBroydenCache{iip}(f, alg, u, u_prev, du, fu, fu2, dfu, p, J⁻¹,
J⁻¹, J⁻¹df, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
return GeneralBroydenCache{iip}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p,
J⁻¹, J⁻¹dfu, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
end
Expand All @@ -97,22 +92,16 @@ function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
@bb axpy!(-α, cache.du, cache.u)

if iip
cache.f(cache.fu2, cache.u, cache.p)
else
cache.fu2 = cache.f(cache.u, cache.p)
end

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
cache.fu2, cache.J⁻¹, cache.du, α)
evaluate_f(cache, cache.u, cache.p)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
update_trace!(cache, α)
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)
cache.stats.nf += 1

cache.force_stop && return nothing

# Update the inverse jacobian
@bb @. cache.dfu = cache.fu2 - cache.fu
@bb @. cache.dfu = cache.fu - cache.fu_cache

if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
if cache.resets cache.max_resets
Expand All @@ -124,15 +113,15 @@ function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
cache.resets += 1
else
@bb cache.du .*= -1
@bb cache.J⁻¹df = cache.J⁻¹ × vec(cache.dfu)
@bb cache.J⁻¹₂ = cache.J⁻¹ × vec(cache.du)
denom = dot(cache.du, cache.J⁻¹df)
@bb @. cache.du = (cache.du - cache.J⁻¹df) / ifelse(iszero(denom), T(1e-5), denom)
@bb cache.J⁻¹ += vec(cache.du) × transpose(cache.J⁻¹₂)
@bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu)
@bb cache.u_cache = cache.J⁻¹ × vec(cache.du)
denom = dot(cache.du, cache.J⁻¹dfu)
@bb @. cache.du = (cache.du - cache.J⁻¹dfu) / ifelse(iszero(denom), T(1e-5), denom)
@bb cache.J⁻¹ += vec(cache.du) × transpose(cache.u_cache)
end

@bb copyto!(cache.fu, cache.fu2)
@bb copyto!(cache.u_prev, cache.u)
@bb copyto!(cache.fu_cache, cache.fu)
@bb copyto!(cache.u_cache, cache.u)

return nothing
end
Expand Down
11 changes: 6 additions & 5 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ jacobian!!(J, _) = J
# `!!` notation is from BangBang.jl since J might be jacobian in case of oop `f.jac`
# and we don't want wasteful `copyto!`
function jacobian!!(J::Union{AbstractMatrix{<:Number}, Nothing}, cache)
@unpack f, uf, u, p, jac_cache, alg, fu2 = cache
@unpack f, uf, u, p, jac_cache, alg, fu_cache = cache
iip = isinplace(cache)
if iip
if has_jac(f)
f.jac(J, u, p)
else
sparse_jacobian!(J, alg.ad, jac_cache, uf, fu2, u)
sparse_jacobian!(J, alg.ad, jac_cache, uf, fu_cache, u)
end
return J
else
Expand Down Expand Up @@ -116,9 +116,10 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
end

if linsolve_init
linprob_A = alg isa PseudoTransient ?
(J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) :
(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J)
linprob_A = needsJᵀJ ? __maybe_symmetric(JᵀJ) : J
# linprob_A = alg isa PseudoTransient ?
# (J - (1 / (convert(eltype(u), alg.alpha_initial))) * I) :
# (needsJᵀJ ? __maybe_symmetric(JᵀJ) : J)
linsolve = linsolve_caches(linprob_A, needsJᵀJ ? Jᵀfu : fu, du, p, alg;
linsolve_kwargs)
else
Expand Down
55 changes: 22 additions & 33 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ end
f
alg
u
u_prev
u_cache
fu
fu2
fu_cache
du
p
linsolve
J
J_cache
J_cache2
Jᵀ²du
J_cache_2
Jdu
Jdu_cache
resets
force_stop
maxiters::Int
Expand All @@ -66,9 +66,6 @@ end
trace
end

get_fu(cache::GeneralKlementCache) = cache.fu
set_fu!(cache::GeneralKlementCache, fu) = (cache.fu = fu)

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
Expand All @@ -94,16 +91,16 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, J, du; kwargs...)

@bb u_prev = copy(u)
@bb fu2 = similar(fu)
@bb u_cache = similar(u)
@bb fu_cache = similar(fu)
@bb J_cache = similar(J)
@bb J_cache2 = similar(J)
@bb Jᵀ²du = similar(fu)
@bb J_cache_2 = similar(J)
@bb Jdu = similar(fu)
@bb Jdu_cache = similar(fu)

return GeneralKlementCache{iip}(f, alg, u, u_prev, fu, fu2, du, p, linsolve, J, J_cache,
J_cache2, Jᵀ²du, Jdu, 0, false, maxiters, internalnorm, ReturnCode.Default, abstol,
reltol, prob, NLStats(1, 0, 0, 0, 0),
return GeneralKlementCache{iip}(f, alg, u, u_cache, fu, fu_cache, du, p, linsolve,
J, J_cache, J_cache_2, Jdu, Jdu_cache, 0, false, maxiters, internalnorm,
ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
end

Expand All @@ -127,24 +124,18 @@ function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu),
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
cache.linsolve = linres.cache

!iip && (cache.du = linres.u)

# Line Search
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
@bb axpy!(-α, cache.du, cache.u)

if iip
cache.f(cache.fu2, cache.u, cache.p)
else
cache.fu2 = cache.f(cache.u, cache.p)
end
evaluate_f(cache, cache.u, cache.p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), cache.fu2, cache.J,
cache.du, α)
update_trace!(cache, α)
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
@bb copyto!(cache.u_prev, cache.u)
@bb copyto!(cache.u_cache, cache.u)

cache.stats.nf += 1
cache.stats.nsolve += 1
Expand All @@ -155,19 +146,17 @@ function perform_step!(cache::GeneralKlementCache{iip}) where {iip}
# Update the Jacobian
@bb cache.du .*= -1
@bb cache.J_cache .= cache.J' .^ 2
@bb @. cache.Jdu = cache.du ^ 2
@bb cache.Jᵀ²du = cache.J_cache × vec(cache.Jdu)
@bb @. cache.Jdu = cache.du^2
@bb cache.Jdu_cache = cache.J_cache × vec(cache.Jdu)
@bb cache.Jdu = cache.J × vec(cache.du)
@bb @. cache.fu = cache.fu2 - cache.fu

@bb @. cache.fu = (cache.fu - cache.Jdu) / max(cache.Jᵀ²du, eps(real(T)))

@bb @. cache.fu_cache = (cache.fu - cache.fu_cache - cache.Jdu) /
max(cache.Jdu_cache, eps(real(T)))
@bb cache.J_cache = vec(cache.fu) × transpose(_vec(cache.du))
@bb @. cache.J_cache *= cache.J
@bb cache.J_cache2 = cache.J_cache × cache.J
@bb cache.J .+= cache.J_cache2
@bb cache.J_cache_2 = cache.J_cache × cache.J
@bb cache.J .+= cache.J_cache_2

@bb copyto!(cache.fu, cache.fu2)
@bb copyto!(cache.fu_cache, cache.fu)

return nothing
end
Expand Down
31 changes: 14 additions & 17 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ end
f
alg
u
u_prev
fu1
fu2
fu
u_cache
fu_cache
du
p
uf
Expand All @@ -81,19 +81,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
u = __maybe_unaliased(u0, alias_u0)
fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
fu = evaluate_f(prob, u)
uf, linsolve, J, fu_cache, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu1, u,
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)

ls_cache = init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip))
trace = init_nonlinearsolve_trace(alg, u, fu1, ApplyArray(__zero, J), du; kwargs...)
ls_cache = init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))
trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...)

@bb u_prev = copy(u)
@bb u_cache = copy(u)

return NewtonRaphsonCache{iip}(f, alg, u, u_prev, fu1, fu2, du, p, uf, linsolve, J,
return NewtonRaphsonCache{iip}(f, alg, u, fu, u_cache, fu_cache, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
NLStats(1, 0, 0, 0, 0), ls_cache, tc_cache, trace)
end
Expand All @@ -104,10 +104,9 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}
cache.J = jacobian!!(cache.J, cache)

# u = u - J \ fu
linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu1),
linres = dolinsolve(alg.precs, cache.linsolve; A = cache.J, b = _vec(cache.fu),
linu = _vec(cache.du), cache.p, reltol = cache.abstol)
cache.linsolve = linres.cache

!iip && (cache.du = linres.u)

# Line Search
Expand All @@ -116,12 +115,10 @@ function perform_step!(cache::NewtonRaphsonCache{iip}) where {iip}

evaluate_f(cache, cache.u, cache.p)

update_trace!(cache.trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), cache.J,
cache.du, α)

check_and_update!(cache, cache.fu1, cache.u, cache.u_prev)
update_trace!(cache, α)
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)

@bb copyto!(cache.u_prev, cache.u)
@bb copyto!(cache.u_cache, cache.u)
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
Expand Down
20 changes: 20 additions & 0 deletions src/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,23 @@ function update_trace_with_invJ!(trace::NonlinearSolveTrace{ShT, StT}, iter, u,
show_now && show(entry)
return trace
end

function update_trace!(cache::AbstractNonlinearSolveCache, α = true)
trace = __getproperty(cache, Val(:trace))
trace === nothing && return nothing

J = __getproperty(cache, Val(:J))
if J === nothing
J_inv = __getproperty(cache, Val(:J⁻¹))
if J_inv === nothing
update_trace!(trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache),
nothing, cache.du, α)
else
update_trace_with_invJ!(trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), J_inv, cache.du, α)
end
else
update_trace!(trace, cache.stats.nsteps + 1, get_u(cache), get_fu(cache), J,
cache.du, α)
end
end
Loading

0 comments on commit 4f2dec0

Please sign in to comment.