From ad257fd2e2fc03112eb9b2886b312624f1dc9775 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Nov 2023 20:32:35 -0500 Subject: [PATCH] Allow specifying custom jvp --- Project.toml | 6 ++++-- src/NonlinearSolve.jl | 1 + src/jacobian.jl | 17 +++++++++++++++-- test/basictests.jl | 38 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ee073a893..fcc210bef 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "2.8.0" +version = "2.8.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -19,6 +19,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" @@ -48,14 +49,15 @@ FastLevenbergMarquardt = "0.1" FiniteDiff = "2" ForwardDiff = "0.10.3" LeastSquaresOptim = "0.8" -LinearAlgebra = "1.9" LineSearches = "7" +LinearAlgebra = "1.9" LinearSolve = "2.12" NonlinearProblemLibrary = "0.1" PrecompileTools = "1" RecursiveArrayTools = "2" Reexport = "0.2, 1" SciMLBase = "2.4" +SciMLOperators = "0.3" SimpleNonlinearSolve = "0.1.23" SparseArrays = "1.9" SparseDiffTools = "2.11" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 5c97b8340..58d10b290 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -23,6 +23,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work import RecursiveArrayTools: ArrayPartition, AbstractVectorOfArray, recursivecopy!, recursivefill! import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace + import SciMLOperators: FunctionOperator import StaticArraysCore: StaticArray, SVector, SArray, MArray import UnPack: @unpack diff --git a/src/jacobian.jl b/src/jacobian.jl index 41df3c092..368e0bb70 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -77,8 +77,21 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val # FIXME: To properly support needsJᵀJ without Jacobian, we need to implement # a reverse diff operation with the seed being `Jx`, this is not yet implemented J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ) - # We don't need to construct the Jacobian - JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad)) + if f.jvp === nothing + # We don't need to construct the Jacobian + JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad)) + else + if iip + jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du) + jvp! = (du, _, u, v) -> f.jvp(du, v, u, p) + else + jvp = (_, u, v) -> f.jvp(v, u, p) + jvp! = (du, _, u, v) -> (du .= f.jvp(v, u, p)) + end + op = SparseDiffTools.FwdModeAutoDiffVecProd(f, u, (), jvp, jvp!) + FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(false), + p, islinear = true) + end else if has_analytic_jac f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype diff --git a/test/basictests.jl b/test/basictests.jl index b1f9d3cb7..2ab059502 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -973,3 +973,41 @@ end termination_condition).u .≈ sqrt(2.0)) end end + +# Miscelleneous Tests +@testset "Custom JVP" begin + function F(u::Vector{Float64}, p::Vector{Float64}) + Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99)) + return u + 0.1 * u .* Δ * u - p + end + + function F!(du::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64}) + Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99)) + du .= u + 0.1 * u .* Δ * u - p + return nothing + end + + function JVP(v::Vector{Float64}, u::Vector{Float64}, p::Vector{Float64}) + Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99)) + return v + 0.1 * (u .* Δ * v + v .* Δ * u) + end + + function JVP!(du::Vector{Float64}, v::Vector{Float64}, u::Vector{Float64}, + p::Vector{Float64}) + Δ = Tridiagonal(-ones(99), 2 * ones(100), -ones(99)) + du .= v + 0.1 * (u .* Δ * v + v .* Δ * u) + return nothing + end + + u0 = rand(100) + + prob = NonlinearProblem(NonlinearFunction{false}(F; jvp = JVP), u0, u0) + sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES())) + + @test norm(F(sol.u, u0)) ≤ 1e-8 + + prob = NonlinearProblem(NonlinearFunction{true}(F!; jvp = JVP!), u0, u0) + sol = solve(prob, NewtonRaphson(; linsolve = KrylovJL_GMRES())) + + @test norm(F(sol.u, u0)) ≤ 1e-8 +end