Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Halley and Householder to SimpleNonlinearSolve #507

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolveDiffEqBaseExt = "DiffEqBase"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveTaylorDiffExt = "TaylorDiff"
SimpleNonlinearSolveTrackerExt = "Tracker"

[compat]
Expand Down Expand Up @@ -66,6 +68,7 @@ SciMLBase = "2.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.3"
TaylorDiff = "0.3"
Test = "1.10"
TestItemRunner = "1"
Tracker = "0.2.35"
Expand All @@ -84,10 +87,11 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TaylorDiff = "b36ab563-344f-407b-a36a-4f200bebf99c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "TaylorDiff", "Test", "TestItemRunner", "Tracker", "Zygote"]
73 changes: 73 additions & 0 deletions lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
module SimpleNonlinearSolveTaylorDiffExt
using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleHouseholder, Utils
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearSolveAlgorithm
using MaybeInplace: @bb
using FastClosures: @closure
import SciMLBase
import TaylorDiff

# Flag the TaylorDiff extension as loaded. `SimpleNonlinearSolve.evaluate_hvvp` checks
# `is_extension_loaded(Val(:TaylorDiff))` before forwarding to this extension's
# `evaluate_hvvp_internal` method.
SimpleNonlinearSolve.is_extension_loaded(::Val{:TaylorDiff}) = true

const NLBUtils = NonlinearSolveBase.Utils

# Evaluate `prob.f` on a Taylor-seeded copy of `x` (unit direction, order `N`) and return
# the pieces needed for one Householder update, together with the plain residual.
#
# Returns `(num, den, fx)`:
#   - `num`: the order `N - 1` term extracted from the reciprocal `1 ./ f` (for `N == 1`
#     this is simply the value of the reciprocal),
#   - `den`: the order `N` term extracted from the same reciprocal,
#   - `fx`:  the primal residual; overwritten in place for in-place problems, freshly
#     built otherwise.
#
# NOTE(review): the exact scaling of `TaylorDiff.extract_derivative` (raw derivative vs.
# Taylor coefficient) is not visible here; the caller multiplies `num ./ den` by `N`.
@inline function __get_higher_order_derivatives(
        ::SimpleHouseholder{N}, prob, x, fx) where {N}
    order = Val(N)
    direction = map(one, x)
    seeded = TaylorDiff.make_seed(x, direction, order)

    if SciMLBase.isinplace(prob)
        # In-place problems write into a Taylor-valued buffer shaped like `fx`.
        bundle = similar(fx, TaylorDiff.TaylorScalar{eltype(fx), N})
        prob.f(bundle, seeded, prob.p)
        map!(TaylorDiff.value, fx, bundle)
    else
        bundle = prob.f(seeded, prob.p)
        fx = map(TaylorDiff.value, bundle)
    end

    reciprocal = inv.(bundle)
    # `N` is a type parameter, so this branch is resolved at compile time.
    if N == 1
        num = map(TaylorDiff.value, reciprocal)
    else
        num = TaylorDiff.extract_derivative(reciprocal, Val(N - 1))
    end
    den = TaylorDiff.extract_derivative(reciprocal, order)
    return num, den, fx
end

# Solve a length-1 nonlinear problem with Householder's method of order `N`.
#
# Keyword arguments:
#   - `abstol`, `reltol`, `termination_condition`: forwarded to
#     `NonlinearSolveBase.init_termination_cache`.
#   - `maxiters`: maximum number of Householder updates (default 1000).
#   - `alias_u0`: forwarded to `NLBUtils.maybe_unaliased` for the initial guess.
#
# Returns the result of `SciMLBase.build_solution`. The retcode is `Success` when the
# initial residual is already zero, whatever `Utils.check_termination` reports on
# convergence, and `MaxIters` when the iteration budget is exhausted.
#
# Throws `ArgumentError` when `length(prob.u0) != 1`.
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHouseholder{N},
        args...; abstol = nothing, reltol = nothing, maxiters = 1000,
        termination_condition = nothing, alias_u0 = false, kwargs...) where {N}
    length(prob.u0) == 1 ||
        throw(ArgumentError("SimpleHouseholder only supports scalar problems"))
    x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
    fx = NLBUtils.evaluate_f(prob, x)

    # `ReturnCode` is not imported into this extension module (only `SciMLBase` itself
    # is), so it has to be referenced through the `SciMLBase` namespace.
    iszero(fx) &&
        return SciMLBase.build_solution(
            prob, alg, x, fx; retcode = SciMLBase.ReturnCode.Success)

    abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
        prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

    @bb xo = similar(x)

    for i in 1:maxiters
        # Keep the previous iterate for the termination check.
        @bb copyto!(xo, x)
        num, den, fx = __get_higher_order_derivatives(alg, prob, x, fx)
        # Householder update built from the order N-1 and order N terms of 1/f.
        @bb x .+= N .* num ./ den
        solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
        solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
    end

    return SciMLBase.build_solution(
        prob, alg, x, fx; retcode = SciMLBase.ReturnCode.MaxIters)
end

# TaylorDiff-backed implementation of the Hessian-vector-vector product used by
# `SimpleHalley` in Taylor mode: the second-order directional derivative of `prob.f`
# at `u` along the direction `a`.
#
# Arguments:
#   - `hvvp`: destination for the result; overwritten for in-place problems, replaced by
#     a fresh value for out-of-place problems.
#   - `prob`: the nonlinear problem supplying `f` and `p`.
#   - `u`:    point at which to differentiate.
#   - `a`:    direction of differentiation.
#
# Returns the (possibly new) `hvvp`.
function SimpleNonlinearSolve.evaluate_hvvp_internal(
        hvvp, prob::ImmutableNonlinearProblem, u, a)
    if SciMLBase.isinplace(prob)
        binary_f = @closure (y, x) -> prob.f(y, x, prob.p)
        # `TaylorDiff.derivative!` needs a buffer for the primal output of `f!`. No
        # solver cache is available in this function, so allocate a scratch array.
        # NOTE(review): assumes `hvvp` has the same shape and eltype as the residual of
        # `prob.f` — confirm against the caller.
        fu = similar(hvvp)
        TaylorDiff.derivative!(hvvp, binary_f, fu, u, a, Val(2))
    else
        unary_f = Base.Fix2(prob.f, prob.p)
        hvvp = TaylorDiff.derivative(unary_f, u, a, Val(2))
    end
    return hvvp
end

end
10 changes: 9 additions & 1 deletion lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
include("broyden.jl")
include("dfsane.jl")
include("halley.jl")
include("householder.jl")
include("klement.jl")
include("lbroyden.jl")
include("raphson.jl")
Expand Down Expand Up @@ -128,6 +129,13 @@

function solve_adjoint_internal end

"""
    evaluate_hvvp(args...; kws...)

Forward all arguments to `evaluate_hvvp_internal` when the TaylorDiff extension reports
itself as loaded via `is_extension_loaded(Val(:TaylorDiff))`. Otherwise throw an error
asking the user to load `TaylorDiff.jl`.
"""
function evaluate_hvvp(args...; kws...)
    is_extension_loaded(Val(:TaylorDiff)) && return evaluate_hvvp_internal(args...; kws...)
    error("Halley's method with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.")
end

"""
    evaluate_hvvp_internal(hvvp, prob, u, a)

Method-less declaration; the implementation is added by the
`SimpleNonlinearSolveTaylorDiffExt` package extension. Call `evaluate_hvvp` rather than
this function directly, since it first checks that the extension is loaded.
"""
function evaluate_hvvp_internal end

@setup_workload begin
for T in (Float64,)
prob_scalar = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
Expand Down Expand Up @@ -161,7 +169,7 @@
export SimpleBroyden, SimpleKlement, SimpleLimitedMemoryBroyden
export SimpleDFSane
export SimpleGaussNewton, SimpleNewtonRaphson, SimpleTrustRegion
export SimpleHalley
export SimpleHalley, SimpleHouseholder

export solve

Expand Down
38 changes: 28 additions & 10 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
SimpleHalley(autodiff)
SimpleHalley(; autodiff = nothing)
SimpleHalley(autodiff, taylor_mode)
SimpleHalley(; autodiff = nothing, taylor_mode = Val(false))

A low-overhead implementation of Halley's Method.

Expand All @@ -15,16 +15,18 @@ A low-overhead implementation of Halley's Method.
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
automatic backend selection). Valid choices include jacobian backends from
`DifferentiationInterface.jl`.
- `taylor_mode`: whether to use Taylor mode automatic differentiation to compute the Hessian-vector-vector product. Defaults to `Val(false)`. If `Val(true)`, you must have `TaylorDiff.jl` loaded.
tansongchen marked this conversation as resolved.
Show resolved Hide resolved
"""
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
# Backend used for the Jacobian; `nothing` requests automatic backend selection.
autodiff = nothing
# `Val(true)` computes the Hessian-vector-vector product with Taylor-mode AD, which
# requires `TaylorDiff.jl` to be loaded; `Val(false)` uses the Hessian-based path.
taylor_mode = Val(false)
end

function SciMLBase.__solve(
prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
prob::ImmutableNonlinearProblem, alg::SimpleHalley{ad, Val{taylor_mode}}, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...
)
) where {ad, taylor_mode}
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
fx = NLBUtils.evaluate_f(prob, x)
T = promote_type(eltype(fx), eltype(x))
Expand All @@ -50,8 +52,19 @@ function SciMLBase.__solve(
A, Aaᵢ, cᵢ = x, x, x
end

fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
NLBUtils.safe_similar(fx) : fx
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

for _ in 1:maxiters
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
if taylor_mode
fx = NLBUtils.evaluate_f!!(prob, fx, x)
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
H = nothing
else
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
end

NLBUtils.can_setindex(x) || (A = J)

Expand All @@ -67,12 +80,17 @@ function SciMLBase.__solve(
end

aᵢ = J_fact \ NLBUtils.safe_vec(fx)
A_ = NLBUtils.safe_vec(A)
@bb A_ = H × aᵢ
A = NLBUtils.restructure(A, A_)

@bb Aaᵢ = A × aᵢ
@bb A .*= -1
if taylor_mode
Aaᵢ = evaluate_hvvp(Aaᵢ, prob, x, typeof(x)(aᵢ))
else
A_ = NLBUtils.safe_vec(A)
@bb A_ = H × aᵢ
A = NLBUtils.restructure(A, A_)

@bb Aaᵢ = A × aᵢ
@bb A .*= -1
end
bᵢ = J_fact \ NLBUtils.safe_vec(Aaᵢ)

cᵢ_ = NLBUtils.safe_vec(cᵢ)
Expand Down
16 changes: 16 additions & 0 deletions lib/SimpleNonlinearSolve/src/householder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
    SimpleHouseholder{order}()

A low-overhead implementation of Householder's method to arbitrary order.
This method is non-allocating on scalar and static array problems.

!!! warning

    Needs `TaylorDiff.jl` to be explicitly loaded before using this functionality.
    Internally, this uses TaylorDiff.jl for automatic differentiation.

!!! note

    Only problems whose initial guess has length 1 are supported; the solver throws an
    `ArgumentError` otherwise.

### Type Parameters

  - `order`: the order of the Householder method. `order = 1` is the same as Newton's
    method, `order = 2` is the same as Halley's method, etc. For example,
    `SimpleHouseholder{2}()` constructs the second-order variant.
"""
struct SimpleHouseholder{order} <: AbstractSimpleNonlinearSolveAlgorithm end
35 changes: 31 additions & 4 deletions lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl
tansongchen marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@testsnippet RootfindTestSnippet begin
using StaticArrays, Random, LinearAlgebra, ForwardDiff, NonlinearSolveBase, SciMLBase
using ADTypes, PolyesterForwardDiff, Enzyme, ReverseDiff
import TaylorDiff

quadratic_f(u, p) = u .* u .- p
quadratic_f!(du, u, p) = (du .= u .* u .- p)
Expand Down Expand Up @@ -82,21 +83,47 @@ end
AutoFiniteDiff(),
AutoReverseDiff(),
nothing
)
), taylor_mode in (Val(false), Val(true))
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff))
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff, taylor_mode))
tansongchen marked this conversation as resolved.
Show resolved Hide resolved
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
end
end

@testset for taylor_mode in (Val(false), Val(true))
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(
probN, alg(; autodiff = AutoForwardDiff(), taylor_mode); termination_condition).u .≈
tansongchen marked this conversation as resolved.
Show resolved Hide resolved
sqrt(2.0))
end
end
end
end

@testitem "Higher Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin
@testset for alg in (
SimpleHouseholder,
)
@testset for order in (1, 2, 3, 4)
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
[1.0], @SVector[1.0], 1.0)
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg{order}())
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
end
end

@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
u0 in (1.0, [1.0], @SVector[1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(
probN, alg(; autodiff = AutoForwardDiff()); termination_condition).u .≈
probN, alg{2}(); termination_condition).u .≈
sqrt(2.0))
end
end
Expand Down
Loading