Skip to content

Commit

Permalink
Specialize on functions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 28, 2023
1 parent dcfee27 commit 411a649
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ end
jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))

# Build Jacobian Caches
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{iip};
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val{iip};
linsolve_kwargs = (;), lininit::Val{linsolve_init} = Val(true),
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init}
linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false)) where {iip, needsJᵀJ, linsolve_init, F}
uf = JacobianWrapper{iip}(f, p)

haslinsolve = hasfield(typeof(alg), :linsolve)
Expand Down Expand Up @@ -135,9 +135,9 @@ __maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
::Val{false}; linsolve_with_JᵀJ::Val{needsJᵀJ} = Val(false),
kwargs...) where {needsJᵀJ}
kwargs...) where {needsJᵀJ, F}
# NOTE: Scalar `u` assumes scalar output from `f`
uf = JacobianWrapper{false}(f, p)
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
Expand Down

0 comments on commit 411a649

Please sign in to comment.