Skip to content

Commit

Permalink
Some progress on LM
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 13, 2023
1 parent c1343b8 commit a942138
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
19 changes: 10 additions & 9 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ end

SciMLBase.isinplace(::JacobianWrapper{iip}) where {iip} = iip

Check warning on line 6 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L6

Added line #L6 was not covered by tests

@concrete struct KrylovJᵀJ
JᵀJ
Jᵀ
end

# Previous Implementation did not hold onto `iip`, but this causes problems in packages
# where we check for the presence of function signatures to check which dispatch to call
(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
Expand Down Expand Up @@ -67,7 +72,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
Expand All @@ -76,9 +81,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
jac_cache = nothing
end

# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac)# || needsJᵀJ)
J = if !(linsolve_needs_jac || alg_wants_jac)
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
else
Expand Down Expand Up @@ -147,11 +150,6 @@ function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...;
return JᵀJ, Jᵀfu

Check warning on line 150 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L145-L150

Added lines #L145 - L150 were not covered by tests
end

@concrete struct KrylovJᵀJ
JᵀJ
Jᵀ
end

SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)

Check warning on line 153 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L153

Added line #L153 was not covered by tests

function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf)
Expand Down Expand Up @@ -226,3 +224,6 @@ function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
mul!(c, H.JᵀJ, g)
return dot(g, c)

Check warning on line 225 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L221-L225

Added lines #L221 - L225 were not covered by tests
end

__zero(JᵀJ) = zero(JᵀJ)
__zero(JᵀJ::KrylovJᵀJ) = JᵀJ
3 changes: 2 additions & 1 deletion src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
_concrete_jac = ifelse(concrete_jac === nothing, true, concrete_jac)
return LevenbergMarquardt{_unwrap_val(_concrete_jac)}(ad, linsolve, precs,

Check warning on line 110 in src/levenberg.jl

View check run for this annotation

Codecov / codecov/patch

src/levenberg.jl#L109-L110

Added lines #L109 - L110 were not covered by tests
damping_initial, damping_increase_factor, damping_decrease_factor,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end
Expand Down

0 comments on commit a942138

Please sign in to comment.