diff --git a/src/jacobian.jl b/src/jacobian.jl index cfea15239..a6d71f159 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -5,6 +5,11 @@ end SciMLBase.isinplace(::JacobianWrapper{iip}) where {iip} = iip +@concrete struct KrylovJᵀJ + JᵀJ + Jᵀ +end + # Previous Implementation did not hold onto `iip`, but this causes problems in packages # where we check for the presence of function signatures to check which dispatch to call (uf::JacobianWrapper{false})(u) = uf.f(u, uf.p) @@ -67,7 +72,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val # NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) : (iip ? deepcopy(f.resid_prototype) : f.resid_prototype) - if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ) + if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) sd = sparsity_detection_alg(f, alg.ad) ad = alg.ad jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) : @@ -76,9 +81,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val jac_cache = nothing end - # FIXME: To properly support needsJᵀJ without Jacobian, we need to implement - # a reverse diff operation with the seed being `Jx`, this is not yet implemented - J = if !(linsolve_needs_jac || alg_wants_jac)# || needsJᵀJ) + J = if !(linsolve_needs_jac || alg_wants_jac) # We don't need to construct the Jacobian JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad)) else @@ -147,11 +150,6 @@ function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; return JᵀJ, Jᵀfu end -@concrete struct KrylovJᵀJ - JᵀJ - Jᵀ -end - SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) @@ -226,3 +224,6 @@ function __lr_mul(::Val{true}, H::KrylovJᵀJ, g) mul!(c, H.JᵀJ, g) return dot(g, c) end + +__zero(JᵀJ) = zero(JᵀJ) +__zero(JᵀJ::KrylovJᵀJ) = JᵀJ diff --git a/src/levenberg.jl b/src/levenberg.jl index fa3189332..a9b0bf89f 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -106,7 +106,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing, α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8, adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) - return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs, + _concrete_jac = ifelse(concrete_jac === nothing, true, concrete_jac) + return LevenbergMarquardt{_unwrap_val(_concrete_jac)}(ad, linsolve, precs, damping_initial, damping_increase_factor, damping_decrease_factor, finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D) end