From 75be6e49e8acd5d50d46297d5d6b22a94c80ffb8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 31 Oct 2024 12:29:21 -0400 Subject: [PATCH] fix: jacobian caching --- lib/NonlinearSolveBase/src/jacobian.jl | 25 ++++++++++++++++++------- test/core_tests.jl | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/lib/NonlinearSolveBase/src/jacobian.jl b/lib/NonlinearSolveBase/src/jacobian.jl index 73a544b8b..6ab31efd2 100644 --- a/lib/NonlinearSolveBase/src/jacobian.jl +++ b/lib/NonlinearSolveBase/src/jacobian.jl @@ -142,10 +142,16 @@ end ## Numbers function (cache::JacobianCache{<:Number})(::Number, u, p = cache.p) cache.stats.njacs += 1 - SciMLBase.has_jac(cache.f) && return cache.f.jac(u, p) - SciMLBase.has_vjp(cache.f) && return cache.f.vjp(one(u), u, p) - SciMLBase.has_jvp(cache.f) && return cache.f.jvp(one(u), u, p) - return DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) + cache.J = if SciMLBase.has_jac(cache.f) + cache.f.jac(u, p) + elseif SciMLBase.has_vjp(cache.f) + cache.f.vjp(one(u), u, p) + elseif SciMLBase.has_jvp(cache.f) + cache.f.jvp(one(u), u, p) + else + DI.derivative(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) + end + return cache.J end ## Actually Compute the Jacobian @@ -156,12 +162,17 @@ function (cache::JacobianCache)(J::Union{AbstractMatrix, Nothing}, u, p = cache. cache.f.jac(J, u, p) else DI.jacobian!( - cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p)) + cache.f, cache.fu, J, cache.di_extras, cache.autodiff, u, Constant(p) + ) end return J else - SciMLBase.has_jac(cache.f) && return cache.f.jac(u, p) - return DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) + if SciMLBase.has_jac(cache.f) + cache.J = cache.f.jac(u, p) + else + cache.J = DI.jacobian(cache.f, cache.di_extras, cache.autodiff, u, Constant(p)) + end + return cache.J end end diff --git a/test/core_tests.jl b/test/core_tests.jl index 301b0b389..65ed43ac8 100644 --- a/test/core_tests.jl +++ b/test/core_tests.jl @@ -4,7 +4,7 @@ dataOut = f([1, 2, 3], nothing) + 0.1 * randn(10, 1) resid(x, p) = f(x, p) - dataOut - jac(x, p) = [dataIn .^ 2 dataIn ones(10, 1)] + jac(x, p) = [1:10 .^ 2 1:10 ones(10, 1)] x0 = [1, 1, 1] prob = NonlinearLeastSquaresProblem(resid, x0)