Skip to content

Commit

Permalink
Merge pull request #282 from avik-pal/ap/krylov
Browse files Browse the repository at this point in the history
Jacobian-Free Krylov Versions for TR/LM/GN
  • Loading branch information
ChrisRackauckas authored Nov 22, 2023
2 parents 0026bc1 + bcfcc16 commit 46912f2
Show file tree
Hide file tree
Showing 15 changed files with 289 additions and 97 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
style = "sciml"
format_markdown = true
annotate_untyped_fields_with_any = false
format_docstrings = true
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.8.2"
version = "2.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -30,11 +30,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2"
Expand All @@ -50,7 +52,7 @@ FiniteDiff = "2"
ForwardDiff = "0.10.3"
LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearAlgebra = "1.9"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2.12"
NonlinearProblemLibrary = "0.1"
PrecompileTools = "1"
Expand All @@ -59,7 +61,7 @@ Reexport = "0.2, 1"
SciMLBase = "2.8.2"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.12"
StaticArraysCore = "1.4"
UnPack = "1.0"
Expand Down
7 changes: 7 additions & 0 deletions ext/NonlinearSolveZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Weak-dependency extension module, loaded automatically only when Zygote.jl is
# present in the user's environment (declared under [extensions] in Project.toml).
module NonlinearSolveZygoteExt

import NonlinearSolve, Zygote

# Overrides the `false` fallback defined in NonlinearSolve so the parent package can
# detect — in a type-stable, dispatch-based way — that Zygote is available for
# reverse-mode AD (e.g. when choosing a vjp backend).
NonlinearSolve.is_extension_loaded(::Val{:Zygote}) = true

end
3 changes: 3 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ import DiffEqBase: AbstractNonlinearTerminationMode,
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

"""
    is_extension_loaded(::Val{ext})

Type-inference friendly check for whether the weak-dependency extension `ext` has
been loaded. This generic fallback reports `false`; each extension module overloads
the function for its own tag (e.g. `Val{:Zygote}`) to report `true`.
"""
function is_extension_loaded(::Val)
    return false
end

abstract type AbstractNonlinearSolveLineSearchAlgorithm end

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
Expand Down
21 changes: 13 additions & 8 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ for solving `NonlinearLeastSquaresProblem`.
## Arguments:
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
on the Jacobian structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`,
then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian
structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or
`:forward`.
!!! note
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
Expand All @@ -36,21 +38,24 @@ end
"""
FastLevenbergMarquardtJL(linsolve = :cholesky)
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
`NonlinearLeastSquaresProblem`.
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
for solving `NonlinearLeastSquaresProblem`.
!!! warning
This is not really the fastest solver. It is called that since the original package
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
!!! warning
This algorithm requires the jacobian function to be provided!
## Arguments:
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
Expand Down
30 changes: 14 additions & 16 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.
!!! note
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
is a more general extension of `Gauss-Newton` Method.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
Expand All @@ -33,28 +29,30 @@ for large-scale and numerically-difficult nonlinear least squares problems.
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
!!! warning
Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
construction. This will be fixed in the near future.
- `vjp_autodiff`: Automatic Differentiation Backend used for vector-jacobian products.
This is applicable if the linear solver doesn't require a concrete jacobian, for eg.,
Krylov Methods. Defaults to `nothing`, which means if the problem is out of place and
`Zygote` is loaded then, we use `AutoZygote`. In all other, cases `FiniteDiff` is used.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
linesearch
vjp_autodiff
end

function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch, alg.vjp_autodiff)
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch,
vjp_autodiff)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand Down Expand Up @@ -122,8 +120,8 @@ function perform_step!(cache::GaussNewtonCache{true})
jacobian!!(J, cache)

if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)
__update_JᵀJ!(Val{true}(), cache, :JᵀJ, J)
__update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1)
end

# u = u - JᵀJ \ Jᵀfu
Expand Down Expand Up @@ -160,8 +158,8 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.J = jacobian!!(cache.J, cache)

if cache.JᵀJ !== nothing
cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1
__update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J)
__update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1)
end

# u = u - J \ fu
Expand Down
108 changes: 97 additions & 11 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Lazy normal-equations wrapper for Jacobian-free (Krylov) methods:
# - `JᵀJ`: an operator representing Jᵀ * J without materializing a matrix
#   (built via `SciMLOperators.cache_operator(Jᵀ * J, u)` in `__init_JᵀJ`)
# - `Jᵀ`: the vector-Jacobian-product operator (a `VecJac`), used to form Jᵀ * fu
@concrete struct KrylovJᵀJ
    JᵀJ
    Jᵀ
end

# The wrapper is in-place iff its underlying vjp operator is in-place.
SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)

# Fallback: when the AD backend is not a sparse one, no sparsity detection is needed.
sparsity_detection_alg(_, _) = NoSparsityDetection()
function sparsity_detection_alg(f, ad::AbstractSparseADType)
if f.sparsity === nothing
Expand Down Expand Up @@ -54,7 +61,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
Expand All @@ -63,12 +70,10 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
jac_cache = nothing
end

# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
J = if !(linsolve_needs_jac || alg_wants_jac)
if f.jvp === nothing
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
else
if iip
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
Expand All @@ -92,9 +97,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
du = _mutable_zero(u)

if needsJᵀJ
JᵀJ = __init_JᵀJ(J)
# FIXME: This needs to be handled better for JacVec Operator
Jᵀfu = J' * _vec(fu)
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))),
jvp_autodiff = __get_nonsparse_ad(alg.ad))
end

if linsolve_init
Expand All @@ -120,21 +125,68 @@ function __setup_linsolve(A, b, u, p, alg)
nothing)..., weight)
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
end
# For the Krylov wrapper, set up the linear solve against the lazy JᵀJ operator.
__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg)

# Map a sparse AD backend to its dense counterpart; used where sparsity handling is
# not applicable (e.g. when constructing JacVec/VecJac operators). Non-sparse
# backends pass through unchanged.
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
__get_nonsparse_ad(ad) = ad

__init_JᵀJ(J::Number) = zero(J)
__init_JᵀJ(J::AbstractArray) = J' * J
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
__init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J)
"""
    __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)

Materialize the Gauss-Newton normal-equation quantities for a concrete Jacobian.
Returns the tuple `(JᵀJ, Jᵀfu)` with `JᵀJ = J' * J` and `Jᵀfu = J' * fu`. Extra
positional/keyword arguments are accepted (and ignored) so all `__init_JᵀJ` methods
share one call signature.
"""
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
    Jᵀ = J'
    return Jᵀ * J, Jᵀ * fu
end
# StaticArrays specialization: JᵀJ is preallocated as an uninitialized mutable
# `MArray` of size (n, n) where n = size(J, 2); it holds garbage until the first
# `__update_JᵀJ!(Val{true}(), ...)` call fills it via `mul!`. Jᵀfu is cheap, so it is
# computed eagerly.
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
    JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
    return JᵀJ, J' * fu
end
# Jacobian-free path: `J` here is a lazy jvp operator, so the normal equations are
# built as lazy operators too — no matrix is ever materialized. Returns a
# `KrylovJᵀJ` wrapper plus the initial Jᵀfu vector.
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing,
        vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...)
    # FIXME: Proper fix to this requires the FunctionOperator patch
    if f !== nothing && f.vjp !== nothing
        @warn "Currently we don't make use of user provided `jvp`. This is planned to be \
               fixed in the near future."
    end
    # Resolve which AD backend the vjp operator should use (see __concrete_vjp_autodiff).
    autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
    Jᵀ = VecJac(uf, u; fu, autodiff)
    # Pre-cache the composed operator so repeated applications inside the Krylov
    # solver avoid re-allocating workspace.
    JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u)
    JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ)
    Jᵀfu = Jᵀ * fu
    return JᵀJ, Jᵀfu
end

"""
    __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)

Resolve the AD backend used for vector-Jacobian products. If the user supplied
`vjp_autodiff`, honor it — except that Zygote is downgraded to FiniteDiff for
in-place problems (with a warning), since the in-place `VecJac` cannot use Zygote.
With no user preference the choice is automatic: FiniteDiff for in-place problems;
reuse FiniteDiff if it already built the jvp; Zygote when its extension is loaded;
FiniteDiff otherwise.
"""
function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
    if vjp_autodiff !== nothing
        ad = __get_nonsparse_ad(vjp_autodiff)
        if isinplace(uf) && ad isa AutoZygote
            @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \
                   Falling back to finite differencing."
            return AutoFiniteDiff()
        end
        return ad
    end
    # VecJac can be only FiniteDiff for in-place problems
    isinplace(uf) && return AutoFiniteDiff()
    # Short circuit if we see that FiniteDiff was used for J computation
    jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
    # Check if Zygote is loaded then use Zygote else use FiniteDiff
    return is_extension_loaded(Val{:Zygote}()) ? AutoZygote() : AutoFiniteDiff()
end

# Wrap JᵀJ-like values as `Symmetric` — presumably so downstream linear solvers can
# exploit symmetry (TODO confirm against dolinsolve callers). Specializations pass
# through the types for which wrapping is wrong or unsupported:
__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
# LinearSolve with `nothing` doesn't dispatch correctly here
__maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
# For the Krylov wrapper, hand back the lazy JᵀJ operator itself.
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
Expand All @@ -145,3 +197,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end

# Refresh the cached JᵀJ stored at `cache.<sym>` after a Jacobian update.
# The entry point fetches the stored value so dispatch can pick the right method:
#  - plain storage, out-of-place: reassign with a fresh `J' * J`
#  - plain storage, in-place: overwrite via `mul!`
#  - `KrylovJᵀJ`: no-op — the lazy operator is returned unchanged (nothing to
#    recompute; it applies the current jvp/vjp on demand)
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
    return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
end
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J) = setproperty!(cache, sym, J' * J)
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J) = mul!(getproperty(cache, sym), J', J)
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H

# Refresh the cached Jᵀf stored at `cache.<sym1>`. The JᵀJ slot `cache.<sym2>` is
# fetched only so dispatch can distinguish the dense case from the Krylov case,
# where the dedicated vjp operator `H.Jᵀ` is applied instead of the adjoint `J'`.
function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
    return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)
end
# Out-of-place dense: recompute and `_restructure` to match the cached layout.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu))
end
# In-place dense: overwrite the cached vector via `mul!`.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
    return mul!(_vec(getproperty(cache, sym1)), J', fu)
end
# Out-of-place Jacobian-free: apply the lazy vjp operator.
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
    return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu))
end
# In-place Jacobian-free: apply the lazy vjp operator into the cached vector.
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
    return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu)
end

# Left-Right Multiplication
# Computes the quadratic form gᵀ * H * g.
__lr_mul(::Val, H, g) = dot(g, H, g)
## TODO: Use a cache here to avoid allocations
# Out-of-place Krylov wrapper: 3-arg dot against the lazy JᵀJ operator.
__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g)
# In-place Krylov wrapper: apply the operator via `mul!` first. `similar(g)`
# allocates a fresh temporary on every call — see the TODO above.
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
    c = similar(g)
    mul!(c, H.JᵀJ, g)
    return dot(g, c)
end
12 changes: 8 additions & 4 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
_concrete_jac = ifelse(concrete_jac === nothing, true, concrete_jac)
return LevenbergMarquardt{_unwrap_val(_concrete_jac)}(ad, linsolve, precs,
damping_initial, damping_increase_factor, damping_decrease_factor,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end
Expand Down Expand Up @@ -365,9 +366,10 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)
else
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
cache.v .*= -1
end
end

Expand All @@ -383,9 +385,11 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
if linsolve === nothing
cache.a = -cache.mat_tmp \ _vec(J' * rhs_term)
else
linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)),
linu = _vec(cache.a), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _mutable(_vec(J' * rhs_term)), linu = _vec(cache.a), p,
reltol = cache.abstol, reuse_A_if_factorization = true)
cache.linsolve = linres.cache
cache.a .*= -1
end
end
cache.stats.nsolve += 1
Expand Down
Loading

0 comments on commit 46912f2

Please sign in to comment.