Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jacobian-Free Krylov Versions for TR/LM/GN #282

Merged
merged 8 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.8.2"
version = "2.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -50,7 +50,7 @@ FiniteDiff = "2"
ForwardDiff = "0.10.3"
LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearAlgebra = "1.9"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2.12"
NonlinearProblemLibrary = "0.1"
PrecompileTools = "1"
Expand All @@ -59,7 +59,7 @@ Reexport = "0.2, 1"
SciMLBase = "2.8.2"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.12"
StaticArraysCore = "1.4"
UnPack = "1.0"
Expand Down
30 changes: 14 additions & 16 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.

!!! note
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
is a more general extension of `Gauss-Newton` Method.

### Keyword Arguments

- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
Expand All @@ -33,28 +29,30 @@ for large-scale and numerically-difficult nonlinear least squares problems.
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.

!!! warning

Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian
construction. This will be fixed in the near future.
- `vjp_autodiff`: Automatic Differentiation Backend used for vector-jacobian products.
This is applicable if the linear solver doesn't require a concrete jacobian, for eg.,
Krylov Methods. Defaults to `nothing`, which means if the problem is out of place and
`Zygote` is loaded then, we use `AutoZygote`. In all other, cases `FiniteDiff` is used.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
linsolve
precs
linesearch
vjp_autodiff
end

function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch, alg.vjp_autodiff)
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, vjp_autodiff = nothing,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch,
vjp_autodiff)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand Down Expand Up @@ -122,8 +120,8 @@ function perform_step!(cache::GaussNewtonCache{true})
jacobian!!(J, cache)

if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)
__update_JᵀJ!(Val{true}(), cache, :JᵀJ, J)
__update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1)
end

# u = u - JᵀJ \ Jᵀfu
Expand Down Expand Up @@ -160,8 +158,8 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.J = jacobian!!(cache.J, cache)

if cache.JᵀJ !== nothing
cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1
__update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J)
__update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1)
end

# u = u - J \ fu
Expand Down
112 changes: 101 additions & 11 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
@concrete struct KrylovJᵀJ
JᵀJ
Jᵀ
end

SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ)

Check warning on line 6 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L6

Added line #L6 was not covered by tests

sparsity_detection_alg(_, _) = NoSparsityDetection()
function sparsity_detection_alg(f, ad::AbstractSparseADType)
if f.sparsity === nothing
Expand Down Expand Up @@ -54,7 +61,7 @@
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
Expand All @@ -63,12 +70,10 @@
jac_cache = nothing
end

# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
J = if !(linsolve_needs_jac || alg_wants_jac)
if f.jvp === nothing
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))
else
if iip
jvp = (_, u, v) -> (du = similar(fu); f.jvp(du, v, u, p); du)
Expand All @@ -92,9 +97,9 @@
du = _mutable_zero(u)

if needsJᵀJ
JᵀJ = __init_JᵀJ(J)
# FIXME: This needs to be handled better for JacVec Operator
Jᵀfu = J' * _vec(fu)
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; f,
vjp_autodiff = __get_nonsparse_ad(_getproperty(alg, Val(:vjp_autodiff))),
jvp_autodiff = __get_nonsparse_ad(alg.ad))
end

if linsolve_init
Expand All @@ -120,21 +125,72 @@
nothing)..., weight)
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
end
__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg)

Check warning on line 128 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L128

Added line #L128 was not covered by tests

__get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff()
__get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff()
__get_nonsparse_ad(::AutoSparseZygote) = AutoZygote()
__get_nonsparse_ad(ad) = ad

__init_JᵀJ(J::Number) = zero(J)
__init_JᵀJ(J::AbstractArray) = J' * J
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
__init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J)

Check warning on line 135 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L135

Added line #L135 was not covered by tests
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
JᵀJ = J' * J
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is actually using this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of the algorithms by default. But if LM/GN/TR is forced to use a Linear Solve which only works with square matrices then this needs to be triggered.

Jᵀfu = J' * fu
return JᵀJ, Jᵀfu
end
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
return JᵀJ, J' * fu
end
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; f = nothing,
vjp_autodiff = nothing, jvp_autodiff = nothing, kwargs...)
# FIXME: Proper fix to this requires the FunctionOperator patch
if f !== nothing && f.vjp !== nothing
@warn "Currently we don't make use of user provided `jvp`. This is planned to be \
fixed in the near future."
end
autodiff = __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
Jᵀ = VecJac(uf, u; fu, autodiff)
JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This operator shouldn't need to be constructed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without doing the cache thing, it complained that for in place operations we need to run set the cache (something along those lines)

JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ)
Jᵀfu = Jᵀ * fu
return JᵀJ, Jᵀfu
end

function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
if vjp_autodiff === nothing
if isinplace(uf)
# VecJac can be only FiniteDiff
return AutoFiniteDiff()
else
# Short circuit if we see that FiniteDiff was used for J computation
jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
# Check if Zygote is loaded then use Zygote else use FiniteDiff
if haskey(Base.loaded_modules,
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
return AutoZygote()
else
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
return AutoFiniteDiff()

Check warning on line 173 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L173

Added line #L173 was not covered by tests
end
end
else
ad = __get_nonsparse_ad(vjp_autodiff)
if isinplace(uf) && ad isa AutoZygote
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \
Falling back to finite differencing."
return AutoFiniteDiff()
end
return ad
end
end

__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
# LinearSolve with `nothing` doesn't dispatch correctly here
__maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x

Check warning on line 192 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L192

Added line #L192 was not covered by tests
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
Expand All @@ -145,3 +201,37 @@
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end

function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
end
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J) = setproperty!(cache, sym, J' * J)
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J) = mul!(getproperty(cache, sym), J', J)
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H

function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu))
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
return mul!(_vec(getproperty(cache, sym1)), J', fu)
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu))
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu)
end

# Left-Right Multiplication
__lr_mul(::Val, H, g) = dot(g, H, g)
## TODO: Use a cache here to avoid allocations
__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g)
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
c = similar(g)
mul!(c, H.JᵀJ, g)
return dot(g, c)
end
12 changes: 8 additions & 4 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
_concrete_jac = ifelse(concrete_jac === nothing, true, concrete_jac)
return LevenbergMarquardt{_unwrap_val(_concrete_jac)}(ad, linsolve, precs,
damping_initial, damping_increase_factor, damping_decrease_factor,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
end
Expand Down Expand Up @@ -365,9 +366,10 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
if linsolve === nothing
cache.v = -cache.mat_tmp \ (J' * fu1)
else
linres = dolinsolve(alg.precs, linsolve; A = -__maybe_symmetric(cache.mat_tmp),
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _vec(J' * _vec(fu1)), linu = _vec(cache.v), p, reltol = cache.abstol)
cache.linsolve = linres.cache
cache.v .*= -1
end
end

Expand All @@ -383,9 +385,11 @@ function perform_step!(cache::LevenbergMarquardtCache{false, fastls}) where {fas
if linsolve === nothing
cache.a = -cache.mat_tmp \ _vec(J' * rhs_term)
else
linres = dolinsolve(alg.precs, linsolve; b = _mutable(_vec(J' * rhs_term)),
linu = _vec(cache.a), p, reltol = cache.abstol)
linres = dolinsolve(alg.precs, linsolve; A = __maybe_symmetric(cache.mat_tmp),
b = _mutable(_vec(J' * rhs_term)), linu = _vec(cache.a), p,
reltol = cache.abstol, reuse_A_if_factorization = true)
cache.linsolve = linres.cache
cache.a .*= -1
end
end
cache.stats.nsolve += 1
Expand Down
29 changes: 21 additions & 8 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
LineSearch(method = nothing, autodiff = nothing, alpha = true)

Wrapper over algorithms from
[LineSeaches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic
Expand All @@ -13,7 +13,7 @@
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
`AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually
installed and loaded
installed and loaded.
- `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`).
"""
@concrete struct LineSearch
Expand All @@ -22,7 +22,7 @@
α
end

function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true)
function LineSearch(; method = nothing, autodiff = nothing, alpha = true)
return LineSearch(method, autodiff, alpha)
end

Expand Down Expand Up @@ -113,15 +113,28 @@

g₀ = _mutable_zero(u)

autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling \
back to finite differencing."
AutoFiniteDiff()
autodiff = if ls.autodiff === nothing
if !iip && haskey(Base.loaded_modules,
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
AutoZygote()
else
AutoFiniteDiff()
end
else
ls.autodiff
if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \

Check warning on line 125 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L125

Added line #L125 was not covered by tests
Falling back to finite differencing."
AutoFiniteDiff()

Check warning on line 127 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L127

Added line #L127 was not covered by tests
else
ls.autodiff
end
end

function g!(u, fu)
if f.jvp !== nothing
@warn "Currently we don't make use of user provided `jvp` in linesearch. This \

Check warning on line 135 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L135

Added line #L135 was not covered by tests
is planned to be fixed in the near future." maxlog=1
end
op = VecJac(SciMLBase.JacobianWrapper(f, p), u; fu = fu1, autodiff)
if iip
mul!(g₀, op, fu)
Expand Down
Loading