Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jacobian-Free Krylov Versions for TR/LM/GN #282

Merged
merged 8 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
style = "sciml"
format_markdown = true
annotate_untyped_fields_with_any = false
format_docstrings = true
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.8.2"
version = "2.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -30,11 +30,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2"
Expand All @@ -50,7 +52,7 @@ FiniteDiff = "2"
ForwardDiff = "0.10.3"
LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearAlgebra = "1.9"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2.12"
NonlinearProblemLibrary = "0.1"
PrecompileTools = "1"
Expand All @@ -59,7 +61,7 @@ Reexport = "0.2, 1"
SciMLBase = "2.8.2"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.12"
StaticArraysCore = "1.4"
UnPack = "1.0"
Expand Down
7 changes: 7 additions & 0 deletions ext/NonlinearSolveZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Package extension, loaded automatically when Zygote.jl is present in the
# environment.
module NonlinearSolveZygoteExt

import NonlinearSolve, Zygote

# Flip the `false` fallback defined in NonlinearSolve so the parent package can
# check — in a type-inference friendly way — that Zygote is available for
# reverse-mode AD.
NonlinearSolve.is_extension_loaded(::Val{:Zygote}) = true

end
3 changes: 3 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

# Type-Inference Friendly Check for Extension Loading
is_extension_loaded(::Val) = false

Check warning on line 42 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L42

Added line #L42 was not covered by tests

abstract type AbstractNonlinearSolveLineSearchAlgorithm end

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
Expand Down
21 changes: 13 additions & 8 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ for solving `NonlinearLeastSquaresProblem`.

## Arguments:

- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
on the Jacobian structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`,
then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian
structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or
`:forward`.

!!! note

This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
Expand All @@ -36,21 +38,24 @@ end
"""
FastLevenbergMarquardtJL(linsolve = :cholesky)

Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
`NonlinearLeastSquaresProblem`.
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
for solving `NonlinearLeastSquaresProblem`.

!!! warning

This is not really the fastest solver. It is called that since the original package
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.

!!! warning

This algorithm requires the jacobian function to be provided!

## Arguments:

- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.

!!! note

This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
Expand Down
30 changes: 14 additions & 16 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.

!!! note
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
is a more general extension of `Gauss-Newton` Method.

### Keyword Arguments

- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
Expand All @@ -33,28 +29,30 @@ for large-scale and numerically-difficult nonlinear least squares problems.
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.

!!! warning

Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
construction. This will be fixed in the near future.
- `vjp_autodiff`: Automatic Differentiation Backend used for vector-jacobian products.
This is applicable if the linear solver doesn't require a concrete jacobian, for eg.,
Krylov Methods. Defaults to `nothing`, which means if the problem is out of place and
`Zygote` is loaded then, we use `AutoZygote`. In all other, cases `FiniteDiff` is used.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
linesearch
vjp_autodiff
end

function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch, alg.vjp_autodiff)
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch,
vjp_autodiff)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand Down Expand Up @@ -122,8 +120,8 @@ function perform_step!(cache::GaussNewtonCache{true})
jacobian!!(J, cache)

if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)
__update_JᵀJ!(Val{true}(), cache, :JᵀJ, J)
__update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1)
end

# u = u - JᵀJ \ Jᵀfu
Expand Down Expand Up @@ -160,8 +158,8 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.J = jacobian!!(cache.J, cache)

if cache.JᵀJ !== nothing
cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1
__update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J)
__update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1)
end

# u = u - J \ fu
Expand Down
108 changes: 97 additions & 11 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# Matrix-free representation of the normal system for Krylov methods: holds the
# lazy JᵀJ operator together with the transposed-jacobian (VecJac) operator
# used to form Jᵀf without materializing J.
@concrete struct KrylovJᵀJ
    JᵀJ
    Jᵀ
end

# Delegate the in-place query to the wrapped Jᵀ operator.
SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)

Check warning on line 6 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L6

Added line #L6 was not covered by tests

# Fallback: no sparsity detection unless the AD backend is a sparse one (the
# sparse-AD method is defined below).
sparsity_detection_alg(_, _) = NoSparsityDetection()
function sparsity_detection_alg(f, ad::AbstractSparseADType)
if f.sparsity === nothing
Expand Down Expand Up @@ -54,7 +61,7 @@
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
Expand All @@ -63,12 +70,10 @@
jac_cache = nothing
end

# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
J = if !(linsolve_needs_jac || alg_wants_jac)
if f.jvp === nothing
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
else
if iip
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
Expand All @@ -92,9 +97,9 @@
du = _mutable_zero(u)

if needsJᵀJ
JᵀJ = __init_JᵀJ(J)
# FIXME: This needs to be handled better for JacVec Operator
Jᵀfu = J' * _vec(fu)
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))),
jvp_autodiff = __get_nonsparse_ad(alg.ad))
end

if linsolve_init
Expand All @@ -120,21 +125,68 @@
nothing)..., weight)
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
end
# Matrix-free case: initialize the linear solver with the underlying JᵀJ operator.
__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg)

Check warning on line 128 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L128

Added line #L128 was not covered by tests

# Map a sparse AD backend to its dense counterpart; anything else passes through
# unchanged. Used where sparsity offers no benefit (e.g. operator construction).
__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
__get_nonsparse_ad(ad) = ad

__init_JᵀJ(J::Number) = zero(J)
__init_JᵀJ(J::AbstractArray) = J' * J
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
# Scalar problem: JᵀJ and Jᵀfu are scalars, both initialized to zero.
__init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J)

Check warning on line 135 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L135

Added line #L135 was not covered by tests
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
JᵀJ = J' * J
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is actually using this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of the algorithms by default. But if LM/GN/TR is forced to use a Linear Solve which only works with square matrices then this needs to be triggered.

Jᵀfu = J' * fu
return JᵀJ, Jᵀfu
end
# StaticArray specialization: preallocate a mutable MArray for JᵀJ (it is filled
# later, e.g. by `__update_JᵀJ!`) and eagerly compute Jᵀfu.
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
    JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
    return JᵀJ, J' * fu
end
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing,
vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...)
# FIXME: Proper fix to this requires the FunctionOperator patch
if f !== nothing && f.vjp !== nothing
@warn "Currently we don't make use of user provided `jvp`. This is planned to be \
fixed in the near future."
end
autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
Jᵀ = VecJac(uf, u; fu, autodiff)
JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This operator shouldn't need to be constructed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without doing the cache thing, it complained that for in place operations we need to run set the cache (something along those lines)

JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ)
Jᵀfu = Jᵀ * fu
return JᵀJ, Jᵀfu
end

"""
    __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)

Select a concrete AD backend for vector-jacobian products (used to build the
`VecJac` operator for matrix-free JᵀJ).

If the user did not specify a backend (`vjp_autodiff === nothing`):

  - in-place problems always use `AutoFiniteDiff` (Zygote cannot differentiate
    mutating code);
  - out-of-place problems reuse `AutoFiniteDiff` if that backend was already
    chosen for the jacobian, otherwise use `AutoZygote` when the Zygote
    extension is loaded, falling back to `AutoFiniteDiff`.

A user-specified backend is first mapped to its dense counterpart via
`__get_nonsparse_ad`; requesting `AutoZygote` for an in-place problem emits a
warning and falls back to `AutoFiniteDiff`.
"""
function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
    if vjp_autodiff === nothing
        if isinplace(uf)
            # VecJac can be only FiniteDiff
            return AutoFiniteDiff()
        else
            # Short circuit if we see that FiniteDiff was used for J computation
            jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
            # Check if Zygote is loaded then use Zygote else use FiniteDiff
            is_extension_loaded(Val{:Zygote}()) && return AutoZygote()
            return AutoFiniteDiff()
        end
    else
        ad = __get_nonsparse_ad(vjp_autodiff)
        if isinplace(uf) && ad isa AutoZygote
            # Fixed message: this path selects the vjp backend, not a line
            # search (the previous text was copied from linesearch code).
            @warn "Attempting to use Zygote.jl for vector-jacobian products on an \
                   in-place problem. Falling back to finite differencing."
            return AutoFiniteDiff()
        end
        return ad
    end
end

# Wrap dense matrices as `Symmetric` so the linear solver can exploit the
# symmetry of JᵀJ; scalars and types that don't benefit pass through unchanged.
__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
# LinearSolve with `nothing` doesn't dispatch correctly here
__maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x

Check warning on line 188 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L188

Added line #L188 was not covered by tests
# Matrix-free case: hand the underlying JᵀJ operator to the solver unwrapped.
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
Expand All @@ -145,3 +197,37 @@
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end

# Refresh the normal matrix JᵀJ stored at `cache.<field>` from the current
# jacobian. Dispatch on the stored value: dense storage is recomputed (in place
# via `mul!` for `Val(true)`), while matrix-free `KrylovJᵀJ` operators are lazy
# and need no update.
function __update_JᵀJ!(mode::Val, cache, field::Symbol, jac)
    return __update_JᵀJ!(mode, cache, field, getproperty(cache, field), jac)
end
function __update_JᵀJ!(::Val{false}, cache, field::Symbol, _, jac)
    return setproperty!(cache, field, jac' * jac)
end
function __update_JᵀJ!(::Val{true}, cache, field::Symbol, _, jac)
    return mul!(getproperty(cache, field), jac', jac)
end
__update_JᵀJ!(::Val{false}, cache, field::Symbol, op::KrylovJᵀJ, jac) = op
__update_JᵀJ!(::Val{true}, cache, field::Symbol, op::KrylovJᵀJ, jac) = op

# Refresh the gradient-like vector Jᵀf stored at `cache.<dst>`, dispatching on
# the JᵀJ storage found at `cache.<src>`: dense storage multiplies by the
# jacobian adjoint, matrix-free `KrylovJᵀJ` applies its transposed-jacobian
# operator. `Val(true)` variants write through `mul!`; `Val(false)` variants
# rebuild the field, matching the original shape via `_restructure`.
function __update_Jᵀf!(mode::Val, cache, dst::Symbol, src::Symbol, jac, fu)
    return __update_Jᵀf!(mode, cache, dst, src, getproperty(cache, src), jac, fu)
end
function __update_Jᵀf!(::Val{false}, cache, dst::Symbol, src::Symbol, _, jac, fu)
    return setproperty!(cache, dst, _restructure(getproperty(cache, dst), jac' * fu))
end
function __update_Jᵀf!(::Val{true}, cache, dst::Symbol, src::Symbol, _, jac, fu)
    return mul!(_vec(getproperty(cache, dst)), jac', fu)
end
function __update_Jᵀf!(::Val{false}, cache, dst::Symbol, src::Symbol, op::KrylovJᵀJ,
    jac, fu)
    return setproperty!(cache, dst, _restructure(getproperty(cache, dst), op.Jᵀ * fu))
end
function __update_Jᵀf!(::Val{true}, cache, dst::Symbol, src::Symbol, op::KrylovJᵀJ,
    jac, fu)
    return mul!(_vec(getproperty(cache, dst)), op.Jᵀ, fu)
end

# Left-Right Multiplication: compute the quadratic form gᵀ H g.
__lr_mul(::Val, H, g) = dot(g, H, g)
## TODO: Use a cache here to avoid allocations
__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g)
# In-place operator: apply H explicitly via `mul!` into a temporary, then dot.
# NOTE(review): allocates `similar(g)` on every call — see TODO above.
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
    c = similar(g)
    mul!(c, H.JᵀJ, g)
    return dot(g, c)
end
12 changes: 8 additions & 4 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
_concrete_jac = ifelse(concrete_jac === nothing, true, concrete_jac)
return LevenbergMarquardt{_unwrap_val(_concrete_jac)}(ad, linsolve, precs,
damping_initial, damping_increase_factor, damping_decrease_factor,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end
Expand Down Expand Up @@ -365,9 +366,10 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)
else
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
cache.v .*= -1
end
end

Expand All @@ -383,9 +385,11 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
if linsolve === nothing
cache.a = -cache.mat_tmp \ _vec(J' * rhs_term)
else
linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)),
linu = _vec(cache.a), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _mutable(_vec(J' * rhs_term)), linu = _vec(cache.a), p,
reltol = cache.abstol, reuse_A_if_factorization = true)
cache.linsolve = linres.cache
cache.a .*= -1
end
end
cache.stats.nsolve += 1
Expand Down
Loading