Skip to content

Commit

Permalink
Fix API issue
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Nov 21, 2024
1 parent d746f22 commit 5f8d644
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHousehold
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

function SimpleNonlinearSolve.evaluate_hvvp_internal(hvvp, prob::ImmutableNonlinearProblem, u, a)
function SimpleNonlinearSolve.evaluate_hvvp_internal(
hvvp, prob::ImmutableNonlinearProblem, u, a)
if SciMLBase.isinplace(prob)
binary_f = @closure (y, x) -> prob.f(y, x, prob.p)
TaylorDiff.derivative!(hvvp, binary_f, cache.fu, u, a, Val(2))
Expand Down
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ function solve_adjoint_internal end

function evaluate_hvvp(args...; kws...)
is_extension_loaded(Val(:TaylorDiff)) && return evaluate_hvvp_internal(args...; kws...)
error("Halley's mathod with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.")
error("Halley's method with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.")
end

function evaluate_hvvp_internal end
Expand Down
24 changes: 12 additions & 12 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
SimpleHalley(autodiff, taylor_mode)
SimpleHalley(; autodiff = nothing, taylor_mode = Val(false))
SimpleHalley(autodiff)
SimpleHalley(; autodiff = nothing)
A low-overhead implementation of Halley's Method.
Expand All @@ -15,18 +15,17 @@ A low-overhead implementation of Halley's Method.
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
automatic backend selection). Valid choices include jacobian backends from
`DifferentiationInterface.jl`.
- `taylor_mode`: whether to use Taylor mode automatic differentiation to compute the Hessian-vector-vector product. Defaults to `Val(false)`. If `Val(true)`, you must have `TaylorDiff.jl` loaded.
In addition, `AutoTaylorDiff` can be used to enable Taylor mode for computing the Hessian-vector-vector product more efficiently; in this case, the Jacobian would still be calculated using the default backend. You need to have `TaylorDiff.jl` loaded to use this option.
"""
@kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm
autodiff = nothing
taylor_mode = Val(false)
end

function SciMLBase.__solve(
prob::ImmutableNonlinearProblem, alg::SimpleHalley{ad, Val{taylor_mode}}, args...;
prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...
) where {ad, taylor_mode}
)
x = NLBUtils.maybe_unaliased(prob.u0, alias_u0)
fx = NLBUtils.evaluate_f(prob, x)
T = promote_type(eltype(fx), eltype(x))
Expand All @@ -40,6 +39,7 @@ function SciMLBase.__solve(

# The way we write the 2nd order derivatives, we know Enzyme won't work there
autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff
jac_autodiff = autodiff === AutoTaylorDiff() ? AutoForwardDiff() : autodiff
@set! alg.autodiff = autodiff

@bb xo = copy(x)
Expand All @@ -54,16 +54,16 @@ function SciMLBase.__solve(

fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
NLBUtils.safe_similar(fx) : fx
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
jac_cache = Utils.prepare_jacobian(prob, jac_autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, jac_autodiff, fx_cache, x, jac_cache)

for _ in 1:maxiters
if taylor_mode
if autodiff isa AutoTaylorDiff
fx = NLBUtils.evaluate_f!!(prob, fx, x)
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
J = Utils.compute_jacobian!!(J, prob, jac_autodiff, fx_cache, x, jac_cache)
H = nothing
else
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
fx, J, H = Utils.compute_jacobian_and_hessian(jac_autodiff, prob, fx, x)
end

NLBUtils.can_setindex(x) || (A = J)
Expand All @@ -81,7 +81,7 @@ function SciMLBase.__solve(

aᵢ = J_fact \ NLBUtils.safe_vec(fx)

if taylor_mode
if autodiff isa AutoTaylorDiff
Aaᵢ = evaluate_hvvp(Aaᵢ, prob, x, typeof(x)(aᵢ))
else
A_ = NLBUtils.safe_vec(A)
Expand Down
21 changes: 11 additions & 10 deletions lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,25 +82,26 @@ end
AutoForwardDiff(),
AutoFiniteDiff(),
AutoReverseDiff(),
AutoTaylorDiff(),
nothing
), taylor_mode in (Val(false), Val(true))
)
@testset "[OOP] u0: $(typeof(u0))" for u0 in (
[1.0, 1.0], @SVector[1.0, 1.0], 1.0)
sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff, taylor_mode))
sol = run_nlsolve_oop(
quadratic_f, u0; solver = alg(; autodiff))
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9
end
end

@testset for taylor_mode in (Val(false), Val(true))
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
@testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(
probN, alg(; autodiff = AutoForwardDiff(), taylor_mode); termination_condition).u .≈
sqrt(2.0))
end
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(
probN, alg(; autodiff = AutoTaylorDiff());
termination_condition).u .≈
sqrt(2.0))
end
end
end
Expand Down

0 comments on commit 5f8d644

Please sign in to comment.