diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index 31550da96..7f9aec5a5 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -1,15 +1,19 @@ module NonlinearSolveBaseForwardDiffExt using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff +using ArrayInterface: ArrayInterface using CommonSolve: solve +using DifferentiationInterface: DifferentiationInterface, Constant using FastClosures: @closure using ForwardDiff: ForwardDiff, Dual +using LinearAlgebra: mul! using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem, - NonlinearProblem, - NonlinearLeastSquaresProblem, remake + NonlinearProblem, NonlinearLeastSquaresProblem, remake using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils +const DI = DifferentiationInterface + function NonlinearSolveBase.additional_incompatible_backend_check( prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff}) return !ForwardDiff.can_dual(eltype(prob.u0)) @@ -50,22 +54,108 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( return sol, partials end +function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( + prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...) + p = Utils.value(prob.p) + newprob = remake(prob; p, u0 = Utils.value(prob.u0)) + sol = solve(newprob, alg, args...; kwargs...) + uu = sol.u + + # First check for custom `vjp` then custom `Jacobian` and if nothing is provided use + # nested autodiff as the last resort + if SciMLBase.has_vjp(prob.f) + if SciMLBase.isinplace(prob) + vjp_fn = @closure (du, u, p) -> begin + resid = Utils.safe_similar(du, length(sol.resid)) + prob.f(resid, u, p) + prob.f.vjp(du, resid, u, p) + du .*= 2 + return nothing + end + else + vjp_fn = @closure (u, p) -> begin + resid = prob.f(u, p) + return reshape(2 .* prob.f.vjp(resid, u, p), size(u)) + end + end + elseif SciMLBase.has_jac(prob.f) + if SciMLBase.isinplace(prob) + vjp_fn = @closure (du, u, p) -> begin + J = Utils.safe_similar(du, length(sol.resid), length(u)) + prob.f.jac(J, u, p) + resid = Utils.safe_similar(du, length(sol.resid)) + prob.f(resid, u, p) + mul!(reshape(du, 1, :), vec(resid)', J, 2, false) + return nothing + end + else + vjp_fn = @closure (u, p) -> begin + return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u)) + end + end + else + # For small problems, nesting ForwardDiff is actually quite fast + autodiff = length(uu) + length(sol.resid) ≥ 50 ? + NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) : + AutoForwardDiff() + + if SciMLBase.isinplace(prob) + vjp_fn = @closure (du, u, p) -> begin + resid = Utils.safe_similar(du, length(sol.resid)) + prob.f(resid, u, p) + # Using `Constant` lead to dual ordering issues + ff = @closure (du, u) -> prob.f(du, u, p) + resid2 = copy(resid) + DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,)) + @. du *= 2 + return nothing + end + else + vjp_fn = @closure (u, p) -> begin + v = prob.f(u, p) + # Using `Constant` lead to dual ordering issues + ff = Base.Fix2(prob.f, p) + res = only(DI.pullback(ff, autodiff, u, (v,))) + ArrayInterface.can_setindex(res) || return 2 .* res + @. res *= 2 + return res + end + end + end + + Jₚ = nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p) + Jᵤ = nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p) + z = -Jᵤ \ Jₚ + pp = prob.p + sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z) + + if uu isa Number + partials = sum(sumfun, zip(z, pp)) + elseif p isa Number + partials = sumfun((z, pp)) + else + partials = sum(sumfun, zip(eachcol(z), pp)) + end + + return sol, partials +end + function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F} if SciMLBase.isinplace(prob) - f = @closure p -> begin + f2 = @closure p -> begin du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p))) f(du, u, p) return du end else - f = Base.Fix1(f, u) + f2 = Base.Fix1(f, u) end if p isa Number - return Utils.safe_reshape(ForwardDiff.derivative(f, p), :, 1) + return Utils.safe_reshape(ForwardDiff.derivative(f2, p), :, 1) elseif u isa Number - return Utils.safe_reshape(ForwardDiff.gradient(f, p), 1, :) + return Utils.safe_reshape(ForwardDiff.gradient(f2, p), 1, :) else - return ForwardDiff.jacobian(f, p) + return ForwardDiff.jacobian(f2, p) end end diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 081201a24..e83de5e39 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -61,6 +61,19 @@ function CommonSolve.solve( prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) end +function CommonSolve.solve( + prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip, + <:Union{ + <:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}}, + alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; + kwargs...) where {T, V, P, iip} + sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) + dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +end + function CommonSolve.solve( prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl index 763fcf32a..a18a1b6be 100644 --- a/lib/SimpleNonlinearSolve/src/raphson.jl +++ b/lib/SimpleNonlinearSolve/src/raphson.jl @@ -24,7 +24,8 @@ end const SimpleGaussNewton = SimpleNewtonRaphson function SciMLBase.__solve( - prob::ImmutableNonlinearProblem, alg::SimpleNewtonRaphson, args...; + prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, + alg::SimpleNewtonRaphson, args...; abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false, termination_condition = nothing, kwargs...) x = Utils.maybe_unaliased(prob.u0, alias_u0)