Skip to content

Commit

Permalink
Krylov Version for Trust Region
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 13, 2023
1 parent fb2f19a commit c1343b8
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.8.0"
version = "2.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
9 changes: 0 additions & 9 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.
!!! note
In most practical situations, users should prefer using `LevenbergMarquardt` instead! It
is a more general extension of the Gauss–Newton method.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
Expand All @@ -33,11 +29,6 @@ for large-scale and numerically-difficult nonlinear least squares problems.
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
!!! warning
The Jacobian-free version of `GaussNewton` does not work yet; it forces Jacobian
construction. This will be fixed in the near future.
"""
@concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD}
ad::AD
Expand Down
99 changes: 91 additions & 8 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
p
end

# Recover the in-place trait stored in the wrapper's type parameter.
function SciMLBase.isinplace(::JacobianWrapper{iip}) where {iip}
    return iip
end

# Earlier implementations dropped `iip`; downstream packages inspect function
# signatures to decide which dispatch to call, so the wrapper must preserve it.
function (uf::JacobianWrapper{false})(u)
    return uf.f(u, uf.p)
end
Expand Down Expand Up @@ -65,7 +67,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
# NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere
fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) :
(iip ? deepcopy(f.resid_prototype) : f.resid_prototype)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ)
sd = sparsity_detection_alg(f, alg.ad)
ad = alg.ad
jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) :
Expand All @@ -76,7 +78,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val

# FIXME: To properly support needsJᵀJ without Jacobian, we need to implement
# a reverse diff operation with the seed being `Jx`, this is not yet implemented
J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ)
J = if !(linsolve_needs_jac || alg_wants_jac)# || needsJᵀJ)
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
else
Expand All @@ -90,9 +92,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val
du = _mutable_zero(u)

if needsJᵀJ
JᵀJ = __init_JᵀJ(J)
# FIXME: This needs to be handled better for JacVec Operator
Jᵀfu = J' * _vec(fu)
# TODO: Pass in `jac_transpose_autodiff`
JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u;
jac_autodiff = __get_nonsparse_ad(alg.ad))
end

if linsolve_init
Expand All @@ -118,21 +120,68 @@ function __setup_linsolve(A, b, u, p, alg)
nothing)..., weight)
return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr)
end
# Unwrap the Krylov normal-equations wrapper and set up the linear solve on
# the underlying composed operator.
function __setup_linsolve(A::KrylovJᵀJ, b, u, p, alg)
    return __setup_linsolve(A.JᵀJ, b, u, p, alg)
end

# Sparse AD backends cannot drive lazy operators (JacVec / VecJac); map each
# sparse choice onto its dense counterpart.
function __get_nonsparse_ad(::AutoSparseForwardDiff)
    return AutoForwardDiff()
end
function __get_nonsparse_ad(::AutoSparseFiniteDiff)
    return AutoFiniteDiff()
end
function __get_nonsparse_ad(::AutoSparseZygote)
    return AutoZygote()
end
# Anything already dense passes through unchanged.
__get_nonsparse_ad(ad) = ad

__init_JᵀJ(J::Number) = zero(J)
__init_JᵀJ(J::AbstractArray) = J' * J
__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef)
# Scalar problem: both JᵀJ and Jᵀfu start as zeros of the scalar's type.
function __init_JᵀJ(J::Number, args...; kwargs...)
    return zero(J), zero(J)
end
# Dense/sparse array: materialize the Gram matrix JᵀJ and the gradient Jᵀfu.
function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...)
    Jt = J'
    return Jt * J, Jt * fu
end
# Static array: allocate an uninitialized mutable static matrix for JᵀJ;
# it is filled in-place on the first Jacobian update.
function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...)
    n = size(J, 2)
    gram = MArray{Tuple{n, n}, eltype(J)}(undef)
    return gram, J' * fu
end
# Matrix-free path: never materialize JᵀJ. Build a vector-Jacobian-product
# operator for Jᵀ, compose it lazily with J, and wrap both in `KrylovJᵀJ`.
function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...;
        jac_transpose_autodiff = nothing, jac_autodiff = nothing, kwargs...)
    ad = __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf)
    vjp_op = VecJac(uf, u; autodiff = ad)
    composed = SciMLOperators.cache_operator(vjp_op * J, u)
    normal_op = KrylovJᵀJ(composed, vjp_op)
    rhs = vjp_op * fu
    return normal_op, rhs
end

"""
    KrylovJᵀJ

Matrix-free stand-in for the normal-equations matrix. `JᵀJ` holds the lazy
composed operator (built as `Jᵀ * J` and cached via
`SciMLOperators.cache_operator`), while `Jᵀ` holds the transpose (`VecJac`)
operator used to form `Jᵀ * fu` products.
"""
@concrete struct KrylovJᵀJ
    JᵀJ
    Jᵀ
end

# A Krylov JᵀJ wrapper is in-place iff its underlying Jᵀ operator is.
function SciMLBase.isinplace(JᵀJ::KrylovJᵀJ)
    return isinplace(JᵀJ.Jᵀ)
end

# Pick the AD backend for the VecJac (Jᵀ) operator.
function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf)
    # An explicit user choice always wins (mapped to its dense equivalent).
    jac_transpose_autodiff !== nothing &&
        return __get_nonsparse_ad(jac_transpose_autodiff)
    # In-place functions: VecJac can only use FiniteDiff.
    isinplace(uf) && return AutoFiniteDiff()
    # Short circuit if FiniteDiff was already used for the J computation.
    jac_autodiff isa AutoFiniteDiff && return jac_autodiff
    # Prefer Zygote when it is already loaded, otherwise fall back to FiniteDiff.
    zygote_id = Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")
    return haskey(Base.loaded_modules, zygote_id) ? AutoZygote() : AutoFiniteDiff()
end

# Wrap dense matrices as `Symmetric` so symmetric linear solvers can be used.
__maybe_symmetric(x) = Symmetric(x)
__maybe_symmetric(x::Number) = x
# LinearSolve with `nothing` doesn't dispatch correctly for these types, so
# they are passed through unchanged.
for T in (StaticArray, SparseArrays.AbstractSparseMatrix,
        SciMLOperators.AbstractSciMLOperator, KrylovJᵀJ)
    @eval __maybe_symmetric(x::$T) = x
end

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
Expand All @@ -143,3 +192,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number,
needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u
return uf, nothing, u, nothing, nothing, u
end

# Refresh the cached JᵀJ stored at `cache.sym` from a fresh Jacobian `J`.
# Dispatch on the cached value's type so matrix-free (Krylov) caches are a no-op.
function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J)
    return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J)
end
function __update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J)
    return setproperty!(cache, sym, J' * J)
end
function __update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J)
    return mul!(getproperty(cache, sym), J', J)
end
# The lazy operator recomputes JᵀJ products on the fly — nothing to update.
__update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H
__update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H

# Refresh the cached gradient Jᵀfu stored at `cache.sym1`, dispatching on the
# JᵀJ cache at `cache.sym2` to pick the dense or matrix-free transpose.
function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
    return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)
end
__update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) =
    setproperty!(cache, sym1, J' * fu)
__update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) =
    mul!(getproperty(cache, sym1), J', fu)
# Matrix-free: use the stored Jᵀ (VecJac) operator instead of `J'`.
__update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) =
    setproperty!(cache, sym1, H.Jᵀ * fu)
__update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) =
    mul!(getproperty(cache, sym1), H.Jᵀ, fu)

# Left-right multiplication: the quadratic form gᵀ H g.
function __lr_mul(::Val, H, g)
    return dot(g, H, g)
end
## TODO: Use a cache here to avoid allocations
__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g)
function __lr_mul(::Val{true}, H::KrylovJᵀJ, g)
    # In-place operators cannot go through 3-arg `dot`; materialize H*g first.
    return dot(g, mul!(similar(g), H.JᵀJ, g))
end
28 changes: 17 additions & 11 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
fu_prev = zero(fu1)

loss = get_loss(fu1)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)
# uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
# linsolve_kwargs)
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, u, p, alg)

u_tmp = zero(u)
u_cauchy = zero(u)
u_gauss_newton = _mutable_zero(u)

loss_new = loss
H = zero(J' * J)
g = _mutable_zero(fu1)
# H = zero(J' * J)
# g = _mutable_zero(fu1)
shrink_counter = 0
fu_new = zero(fu1)
make_new_J = true
Expand Down Expand Up @@ -346,8 +350,10 @@ function perform_step!(cache::TrustRegionCache{true})
@unpack make_new_J, J, fu, f, u, p, u_gauss_newton, alg, linsolve = cache
if cache.make_new_J
jacobian!!(J, cache)
mul!(cache.H, J', J)
mul!(_vec(cache.g), J', _vec(fu))
__update_JᵀJ!(Val{true}(), cache, :H, J)
# mul!(cache.H, J', J)
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu)
# mul!(_vec(cache.g), J', _vec(fu))
cache.stats.njacs += 1

# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
Expand Down Expand Up @@ -376,8 +382,8 @@ function perform_step!(cache::TrustRegionCache{false})

if make_new_J
J = jacobian!!(cache.J, cache)
cache.H = J' * J
cache.g = _restructure(fu, J' * _vec(fu))
__update_JᵀJ!(Val{false}(), cache, :H, J)
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu)
cache.stats.njacs += 1

if cache.linsolve === nothing
Expand Down Expand Up @@ -431,7 +437,7 @@ function trust_region_step!(cache::TrustRegionCache)

# Compute the ratio of the actual reduction to the predicted reduction.
cache.r = -(loss - cache.loss_new) /
(dot(_vec(du), _vec(g)) + dot(_vec(du), H, _vec(du)) / 2)
(dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2)
@unpack r = cache

if radius_update_scheme === RadiusUpdateSchemes.Simple
Expand Down Expand Up @@ -594,7 +600,7 @@ function dogleg!(cache::TrustRegionCache{true})

# Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
l_grad = norm(cache.g) # length of the gradient
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
d_cauchy = l_grad^3 / __lr_mul(Val{true}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
if d_cauchy >= trust_r
@. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
return
Expand Down Expand Up @@ -624,7 +630,7 @@ function dogleg!(cache::TrustRegionCache{false})

## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region
l_grad = norm(cache.g)
d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
d_cauchy = l_grad^3 / __lr_mul(Val{false}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate
if d_cauchy > trust_r # cauchy point lies outside of trust region
cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region
return
Expand Down

0 comments on commit c1343b8

Please sign in to comment.