From cb105fcbbbd3f172d7d6ce2620dc751b55d34cee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 5 Oct 2023 21:00:39 -0400 Subject: [PATCH] Fix Jacobian Construction --- src/NonlinearSolve.jl | 2 +- src/jacobian.jl | 26 +++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 90839dea5..168658bfe 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && isdefined(Base.Experimental, Symbol("@max_m @eval Base.Experimental.@max_methods 1 end -using DiffEqBase, LinearAlgebra, LinearSolve, SparseDiffTools +using DiffEqBase, LinearAlgebra, LinearSolve, SparseArrays, SparseDiffTools import ForwardDiff import ADTypes: AbstractFiniteDifferencesMode diff --git a/src/jacobian.jl b/src/jacobian.jl index 82f2ef2bb..d9327e701 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -80,7 +80,11 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii if has_analytic_jac f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype else - f.jac_prototype === nothing ? init_jacobian(jac_cache) : f.jac_prototype + if f.jac_prototype === nothing + __safe_init_jacobian(jac_cache) + else + f.jac_prototype + end end end @@ -98,6 +102,26 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii return uf, linsolve, J, fu, jac_cache, du end +@generated function __getfield(c::T, ::Val{S}) where {T, S} + hasfield(T, S) && return :(c.$(S)) + return :(nothing) +end + +function __safe_init_jacobian(c::SparseDiffTools.AbstractMaybeSparseJacobianCache) + T = promote_type(eltype(c.fx), eltype(c.x)) + return __safe_init_jacobian(__getfield(c, Val(:jac_prototype)), T, c.fx, c.x) +end +function __safe_init_jacobian(::Nothing, ::Type{T}, fx, x) where {T} + return similar(fx, T, length(fx), length(x)) +end +function __safe_init_jacobian(J::SparseMatrixCSC, ::Type{T}, fx, x) where {T} + @assert size(J, 1) == length(fx) && size(J, 2) == length(x) + return T.(J) +end +function __safe_init_jacobian(J, ::Type{T}, fx, x) where {T} + return similar(fx, T, length(fx), length(x)) # This is not safe for sparse jacobians +end + __get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff() __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff() __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()