diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 203d06f14..95d077614 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -2,7 +2,7 @@ module NonlinearSolveBaseForwardDiffExt using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff using ArrayInterface: ArrayInterface -using CommonSolve: CommonSolve, solve +using CommonSolve: CommonSolve, solve, solve!, init using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface using FastClosures: @closure @@ -10,14 +10,14 @@ using ForwardDiff: ForwardDiff, Dual using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, NonlinearProblem, NonlinearLeastSquaresProblem, remake -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, - AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, - AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI, + AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm, + NonlinearSolveForwardDiffCache const DI = DifferentiationInterface const GENERAL_SOLVER_TYPES = [ - Nothing, AbstractNonlinearSolveAlgorithm, NonlinearSolvePolyAlgorithm + Nothing, NonlinearSolvePolyAlgorithm ] const DualNonlinearProblem = NonlinearProblem{ @@ -135,24 +135,16 @@ for algType in GENERAL_SOLVER_TYPES end end -@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache - cache - prob - alg - p - values_p - partials_p -end - function InternalAPI.reinit!( cache::NonlinearSolveForwardDiffCache, args...; p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs... ) InternalAPI.reinit!( - cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs... + cache.cache; p = NonlinearSolveBase.nodual_value(p), + u0 = NonlinearSolveBase.nodual_value(u0), kwargs... ) cache.p = p - cache.values_p = nodual_value(p) + cache.values_p = NonlinearSolveBase.nodual_value(p) cache.partials_p = ForwardDiff.partials(p) return cache end @@ -161,8 +153,8 @@ for algType in GENERAL_SOLVER_TYPES @eval function SciMLBase.__init( prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs... ) - p = nodual_value(prob.p) - newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p) + p = NonlinearSolveBase.nodual_value(prob.p) + newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p) cache = init(newprob, alg, args...; kwargs...) return NonlinearSolveForwardDiffCache( cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p) @@ -196,8 +188,17 @@ function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache) ) end -nodual_value(x) = x -nodual_value(x::Dual) = ForwardDiff.value(x) -nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) +NonlinearSolveBase.nodual_value(x) = x +NonlinearSolveBase.nodual_value(x::Dual) = ForwardDiff.value(x) +NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x) + +""" + pickchunksize(x) = pickchunksize(length(x)) + pickchunksize(x::Int) + +Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. +""" +@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x)) +@inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 8fd4b1947..df65e1fed 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -58,6 +58,8 @@ include("descent/geodesic_acceleration.jl") include("solve.jl") +include("forward_diff.jl") + # Unexported Public API @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance)) @compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution)) diff --git a/lib/NonlinearSolveBase/src/common_defaults.jl b/lib/NonlinearSolveBase/src/common_defaults.jl index 5a5433ee3..4518063a5 100644 --- a/lib/NonlinearSolveBase/src/common_defaults.jl +++ b/lib/NonlinearSolveBase/src/common_defaults.jl @@ -45,12 +45,3 @@ function get_tolerance(::Union{StaticArray, Number}, ::Nothing, ::Type{T}) where # Rational numbers can throw an error if used inside GPU Kernels return T(real(oneunit(T)) * (eps(real(one(T)))^(real(T)(0.8)))) end - -""" - pickchunksize(x) = pickchunksize(length(x)) - pickchunksize(x::Int) - -Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length. -""" -@inline pickchunksize(x) = pickchunksize(length(x)) -@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) diff --git a/lib/NonlinearSolveBase/src/forward_diff.jl b/lib/NonlinearSolveBase/src/forward_diff.jl new file mode 100644 index 000000000..a588aa52d --- /dev/null +++ b/lib/NonlinearSolveBase/src/forward_diff.jl @@ -0,0 +1,8 @@ +@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache + cache + prob + alg + p + values_p + partials_p +end diff --git a/lib/NonlinearSolveBase/src/public.jl b/lib/NonlinearSolveBase/src/public.jl index d076f7873..b68e3806f 100644 --- a/lib/NonlinearSolveBase/src/public.jl +++ b/lib/NonlinearSolveBase/src/public.jl @@ -11,6 +11,8 @@ function nonlinearsolve_dual_solution end function nonlinearsolve_∂f_∂p end function nonlinearsolve_∂f_∂u end function nlls_generate_vjp_function end +function nodual_value end +function pickchunksize end # Nonlinear Solve Termination Conditions abstract type AbstractNonlinearTerminationMode end diff --git a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl index 15b99c5d1..666cc7435 100644 --- a/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl +++ b/lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl @@ -22,7 +22,7 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm, Utils, InternalAPI, get_timer_output, @static_timeit, update_trace!, L2_NORM, NonlinearSolvePolyAlgorithm, NewtonDescent, DampedNewtonDescent, GeodesicAcceleration, - Dogleg + Dogleg, NonlinearSolveForwardDiffCache using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode, NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize diff --git a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl index afba60d43..ca4e7bb94 100644 --- a/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl +++ b/lib/NonlinearSolveQuasiNewton/ext/NonlinearSolveQuasiNewtonForwardDiffExt.jl @@ -1,9 +1,8 @@ module NonlinearSolveQuasiNewtonForwardDiffExt -using CommonSolve: CommonSolve, solve +using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual -using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, - NonlinearProblem, NonlinearLeastSquaresProblem, remake +using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem using NonlinearSolveBase: NonlinearSolveBase diff --git a/lib/NonlinearSolveSpectralMethods/Project.toml b/lib/NonlinearSolveSpectralMethods/Project.toml index 7175c5ea9..a248be107 100644 --- a/lib/NonlinearSolveSpectralMethods/Project.toml +++ b/lib/NonlinearSolveSpectralMethods/Project.toml @@ -27,6 +27,7 @@ CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" DiffEqBase = "6.158.3" ExplicitImports = "1.5" +ForwardDiff = "0.10.36" Hwloc = "3" InteractiveUtils = "<0.0.1, 1" LineSearch = "0.1.4" diff --git a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl index 86604d7e2..5dfc559f6 100644 --- a/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl +++ b/lib/NonlinearSolveSpectralMethods/ext/NonlinearSolveSpectralMethodsForwardDiffExt.jl @@ -1,9 +1,8 @@ module NonlinearSolveSpectralMethodsForwardDiffExt -using CommonSolve: CommonSolve, solve +using CommonSolve: CommonSolve, init using ForwardDiff: ForwardDiff, Dual -using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, - NonlinearProblem, NonlinearLeastSquaresProblem, remake +using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem using NonlinearSolveBase: NonlinearSolveBase