diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl index 20d7beece..209e475ef 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTaylorDiffExt.jl @@ -59,7 +59,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHousehold return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) end -function SimpleNonlinearSolve.evaluate_hvvp_internal(hvvp, prob::ImmutableNonlinearProblem, u, a) +function SimpleNonlinearSolve.evaluate_hvvp_internal( + hvvp, prob::ImmutableNonlinearProblem, u, a) if SciMLBase.isinplace(prob) binary_f = @closure (y, x) -> prob.f(y, x, prob.p) TaylorDiff.derivative!(hvvp, binary_f, cache.fu, u, a, Val(2)) diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 83e81d618..281ffef42 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -131,7 +131,7 @@ function solve_adjoint_internal end function evaluate_hvvp(args...; kws...) is_extension_loaded(Val(:TaylorDiff)) && return evaluate_hvvp_internal(args...; kws...) - error("Halley's mathod with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.") + error("Halley's method with Taylor mode requires `TaylorDiff.jl` to be explicitly loaded.") end function evaluate_hvvp_internal end diff --git a/lib/SimpleNonlinearSolve/src/halley.jl b/lib/SimpleNonlinearSolve/src/halley.jl index f568ffa5e..5078925c3 100644 --- a/lib/SimpleNonlinearSolve/src/halley.jl +++ b/lib/SimpleNonlinearSolve/src/halley.jl @@ -1,6 +1,6 @@ """ - SimpleHalley(autodiff, taylor_mode) - SimpleHalley(; autodiff = nothing, taylor_mode = Val(false)) + SimpleHalley(autodiff) + SimpleHalley(; autodiff = nothing) A low-overhead implementation of Halley's Method. 
@@ -15,18 +15,17 @@ A low-overhead implementation of Halley's Method. - `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e. automatic backend selection). Valid choices include jacobian backends from `DifferentiationInterface.jl`. - - `taylor_mode`: whether to use Taylor mode automatic differentiation to compute the Hessian-vector-vector product. Defaults to `Val(false)`. If `Val(true)`, you must have `TaylorDiff.jl` loaded. + In addition, `AutoTaylorDiff` can be used to enable Taylor mode for computing the Hessian-vector-vector product more efficiently; in this case, the Jacobian is still computed with `AutoForwardDiff`. You need to have `TaylorDiff.jl` loaded to use this option. """ @kwdef @concrete struct SimpleHalley <: AbstractSimpleNonlinearSolveAlgorithm autodiff = nothing - taylor_mode = Val(false) end function SciMLBase.__solve( - prob::ImmutableNonlinearProblem, alg::SimpleHalley{ad, Val{taylor_mode}}, args...; + prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs... -) where {ad, taylor_mode} +) x = NLBUtils.maybe_unaliased(prob.u0, alias_u0) fx = NLBUtils.evaluate_f(prob, x) T = promote_type(eltype(fx), eltype(x)) @@ -40,6 +39,7 @@ function SciMLBase.__solve( # The way we write the 2nd order derivatives, we know Enzyme won't work there autodiff = alg.autodiff === nothing ? AutoForwardDiff() : alg.autodiff + jac_autodiff = autodiff isa AutoTaylorDiff ? AutoForwardDiff() : autodiff @set! alg.autodiff = autodiff @bb xo = copy(x) @@ -54,16 +54,16 @@ function SciMLBase.__solve( fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? 
NLBUtils.safe_similar(fx) : fx - jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x) - J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache) + jac_cache = Utils.prepare_jacobian(prob, jac_autodiff, fx_cache, x) + J = Utils.compute_jacobian!!(nothing, prob, jac_autodiff, fx_cache, x, jac_cache) for _ in 1:maxiters - if taylor_mode + if autodiff isa AutoTaylorDiff fx = NLBUtils.evaluate_f!!(prob, fx, x) - J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache) + J = Utils.compute_jacobian!!(J, prob, jac_autodiff, fx_cache, x, jac_cache) H = nothing else - fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x) + fx, J, H = Utils.compute_jacobian_and_hessian(jac_autodiff, prob, fx, x) end NLBUtils.can_setindex(x) || (A = J) @@ -81,7 +81,7 @@ function SciMLBase.__solve( aᵢ = J_fact \ NLBUtils.safe_vec(fx) - if taylor_mode + if autodiff isa AutoTaylorDiff Aaᵢ = evaluate_hvvp(Aaᵢ, prob, x, typeof(x)(aᵢ)) else A_ = NLBUtils.safe_vec(A) diff --git a/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl b/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl index af1847fc6..e459702d6 100644 --- a/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl @@ -82,25 +82,26 @@ end AutoForwardDiff(), AutoFiniteDiff(), AutoReverseDiff(), + AutoTaylorDiff(), nothing - ), taylor_mode in (Val(false), Val(true)) + ) @testset "[OOP] u0: $(typeof(u0))" for u0 in ( [1.0, 1.0], @SVector[1.0, 1.0], 1.0) - sol = run_nlsolve_oop(quadratic_f, u0; solver = alg(; autodiff, taylor_mode)) + sol = run_nlsolve_oop( + quadratic_f, u0; solver = alg(; autodiff)) @test SciMLBase.successful_retcode(sol) @test maximum(abs, quadratic_f(sol.u, 2.0)) < 1e-9 end end - @testset for taylor_mode in (Val(false), Val(true)) - @testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, - u0 in (1.0, [1.0, 
1.0], @SVector[1.0, 1.0]) + @testset "Termination Condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) - probN = NonlinearProblem(quadratic_f, u0, 2.0) - @test all(solve( - probN, alg(; autodiff = AutoForwardDiff(), taylor_mode); termination_condition).u .≈ - sqrt(2.0)) - end + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve( + probN, alg(; autodiff = AutoTaylorDiff()); + termination_condition).u .≈ + sqrt(2.0)) end end end