Skip to content

Commit

Permalink
feat: add SimpleHalley method
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 4, 2024
1 parent 87fad10 commit 5b51678
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 6 deletions.
10 changes: 4 additions & 6 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ using CommonSolve: CommonSolve, solve
using ConcreteStructs: @concrete
using FastClosures: @closure
using LineSearch: LiFukushimaLineSearch
using LinearAlgebra: dot
using MaybeInplace: @bb
using LinearAlgebra: LinearAlgebra, dot
using MaybeInplace: @bb, setindex_trait, CannotSetindex, CanSetindex
using PrecompileTools: @compile_workload, @setup_workload
using Reexport: @reexport
@reexport using SciMLBase # I don't like this but needed to avoid a breaking change
Expand Down Expand Up @@ -82,18 +82,15 @@ function solve_adjoint_internal end
algs = [
SimpleBroyden(),
SimpleKlement(),
SimpleHalley(),
SimpleNewtonRaphson(),
SimpleTrustRegion()
]
algs_no_iip = []

@compile_workload begin
for alg in algs, prob in (prob_scalar, prob_iip, prob_oop)
CommonSolve.solve(prob, alg)
end
for alg in algs_no_iip
CommonSolve.solve(prob_scalar, alg)
end
end
end
end
Expand All @@ -104,5 +101,6 @@ export Alefeld, Bisection, Brent, Falsi, ITP, Ridder

export SimpleBroyden, SimpleKlement
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
export SimpleHalley

end
83 changes: 83 additions & 0 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
@@ -1 +1,84 @@
"""
SimpleHalley(autodiff)
SimpleHalley(; autodiff = nothing)
A low-overhead implementation of Halley's Method.
!!! note
As part of the decreased overhead, this method omits some of the higher level error
catching of the other methods. Thus, to see better error messages, use one of the other
methods like `NewtonRaphson`.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
automatic backend selection). Valid choices include jacobian backends from
`DifferentiationInterface.jl`.
"""
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
autodiff = nothing
end

function SciMLBase.__solve(
prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = Utils.maybe_unaliased(prob.u0, alias_u0)
fx = Utils.get_fx(prob, x)
fx = Utils.eval_f(prob, fx, x)
T = promote_type(eltype(fx), eltype(x))

iszero(fx) &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

autodiff = NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)

@bb xo = copy(x)

strait = setindex_trait(x)

A = strait isa CanSetindex ? similar(x, length(x), length(x)) : x
Aaᵢ = strait isa CanSetindex ? similar(x, length(x)) : x
cᵢ = strait isa CanSetindex ? similar(x) : x

for _ in 1:maxiters
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)

strait isa CannotSetindex && (A = J)

# Factorize Once and Reuse
J_fact = if J isa Number
J
else
fact = LinearAlgebra.lu(J; check = false)
!LinearAlgebra.issuccess(fact) && return SciMLBase.build_solution(
prob, alg, x, fx; retcode = ReturnCode.Unstable)
fact
end

aᵢ = J_fact \ Utils.safe_vec(fx)
A_ = Utils.safe_vec(A)
@bb A_ = H × aᵢ
A = Utils.restructure(A, A_)

@bb Aaᵢ = A × aᵢ
@bb A .*= -1
bᵢ = J_fact \ Utils.safe_vec(Aaᵢ)

cᵢ_ = Utils.safe_vec(cᵢ)
@bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ))
cᵢ = Utils.restructure(cᵢ, cᵢ_)

solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)

@bb @. x += cᵢ
@bb copyto!(xo, x)
end

return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end
22 changes: 22 additions & 0 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,26 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
return J
end

function compute_jacobian_and_hessian(autodiff, prob, _, x::Number)
H = DI.second_derivative(prob.f, autodiff, x, Constant(prob.p))
fx, J = DI.value_and_derivative(prob.f, autodiff, x, Constant(prob.p))
return fx, J, H
end
function compute_jacobian_and_hessian(autodiff, prob, fx, x)
if SciMLBase.isinplace(prob)
jac_fn = @closure (u, p) -> begin
du = similar(fx, promote_type(eltype(fx), eltype(u)))
return DI.jacobian(prob.f, du, autodiff, u, Constant(p))
end
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
fx = Utils.eval_f(prob, fx, x)
return fx, J, H
else
jac_fn = @closure (u, p) -> DI.jacobian(prob.f, autodiff, u, Constant(p))
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
fx = Utils.eval_f(prob, fx, x)
return fx, J, H
end
end

end

0 comments on commit 5b51678

Please sign in to comment.