From c1343b80af5a1457b829c9432cc786ccda951e3e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 12 Nov 2023 19:47:01 -0500 Subject: [PATCH] Krylov Version for Trust Region --- Project.toml | 2 +- src/gaussnewton.jl | 9 ----- src/jacobian.jl | 99 ++++++++++++++++++++++++++++++++++++++++++---- src/trustRegion.jl | 28 +++++++------ 4 files changed, 109 insertions(+), 29 deletions(-) diff --git a/Project.toml b/Project.toml index ee073a893..b675f1126 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "2.8.0" +version = "2.9.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 012767dcf..2062a3bfd 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -6,10 +6,6 @@ An advanced GaussNewton implementation with support for efficient handling of sp matrices via colored automatic differentiation and preconditioned linear solvers. Designed for large-scale and numerically-difficult nonlinear least squares problems. -!!! note - In most practical situations, users should prefer using `LevenbergMarquardt` instead! It - is a more general extension of `Gauss-Newton` Method. - ### Keyword Arguments - `autodiff`: determines the backend used for the Jacobian. Note that this argument is @@ -33,11 +29,6 @@ for large-scale and numerically-difficult nonlinear least squares problems. - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref), which means that no line search is performed. Algorithms from `LineSearches.jl` can be used here directly, and they will be converted to the correct `LineSearch`. - -!!! warning - - Jacobian-Free version of `GaussNewton` doesn't work yet, and it forces jacobian - construction. This will be fixed in the near future. """ @concrete struct GaussNewton{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} ad::AD diff --git a/src/jacobian.jl b/src/jacobian.jl index 41df3c092..cfea15239 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -3,6 +3,8 @@ p end +SciMLBase.isinplace(::JacobianWrapper{iip}) where {iip} = iip + # Previous Implementation did not hold onto `iip`, but this causes problems in packages # where we check for the presence of function signatures to check which dispatch to call (uf::JacobianWrapper{false})(u) = uf.f(u, uf.p) @@ -65,7 +67,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val # NOTE: The deepcopy is needed here since we are using the resid_prototype elsewhere fu = f.resid_prototype === nothing ? (iip ? _mutable_zero(u) : _mutable(f(u, p))) : (iip ? deepcopy(f.resid_prototype) : f.resid_prototype) - if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac || needsJᵀJ) + if !has_analytic_jac && (linsolve_needs_jac || alg_wants_jac) # || needsJᵀJ) sd = sparsity_detection_alg(f, alg.ad) ad = alg.ad jac_cache = iip ? sparse_jacobian_cache(ad, sd, uf, fu, _maybe_mutable(u, ad)) : @@ -76,7 +78,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val # FIXME: To properly support needsJᵀJ without Jacobian, we need to implement # a reverse diff operation with the seed being `Jx`, this is not yet implemented - J = if !(linsolve_needs_jac || alg_wants_jac || needsJᵀJ) + J = if !(linsolve_needs_jac || alg_wants_jac)# || needsJᵀJ) # We don't need to construct the Jacobian JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad)) else @@ -90,9 +92,9 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val du = _mutable_zero(u) if needsJᵀJ - JᵀJ = __init_JᵀJ(J) - # FIXME: This needs to be handled better for JacVec Operator - Jᵀfu = J' * _vec(fu) + # TODO: Pass in `jac_transpose_autodiff` + JᵀJ, Jᵀfu = __init_JᵀJ(J, _vec(fu), uf, u; + jac_autodiff = __get_nonsparse_ad(alg.ad)) end if linsolve_init @@ -118,21 +120,68 @@ function __setup_linsolve(A, b, u, p, alg) nothing)..., weight) return init(linprob, alg.linsolve; alias_A = true, alias_b = true, Pl, Pr) end +__setup_linsolve(A::KrylovJᵀJ, b, u, p, alg) = __setup_linsolve(A.JᵀJ, b, u, p, alg) __get_nonsparse_ad(::AutoSparseForwardDiff) = AutoForwardDiff() __get_nonsparse_ad(::AutoSparseFiniteDiff) = AutoFiniteDiff() __get_nonsparse_ad(::AutoSparseZygote) = AutoZygote() __get_nonsparse_ad(ad) = ad -__init_JᵀJ(J::Number) = zero(J) -__init_JᵀJ(J::AbstractArray) = J' * J -__init_JᵀJ(J::StaticArray) = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef) +__init_JᵀJ(J::Number, args...; kwargs...) = zero(J), zero(J) +function __init_JᵀJ(J::AbstractArray, fu, args...; kwargs...) + JᵀJ = J' * J + Jᵀfu = J' * fu + return JᵀJ, Jᵀfu +end +function __init_JᵀJ(J::StaticArray, fu, args...; kwargs...) + JᵀJ = MArray{Tuple{size(J, 2), size(J, 2)}, eltype(J)}(undef) + return JᵀJ, J' * fu +end +function __init_JᵀJ(J::FunctionOperator, fu, uf, u, args...; + jac_transpose_autodiff = nothing, jac_autodiff = nothing, kwargs...) + autodiff = __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) + Jᵀ = VecJac(uf, u; autodiff) + JᵀJ_op = SciMLOperators.cache_operator(Jᵀ * J, u) + JᵀJ = KrylovJᵀJ(JᵀJ_op, Jᵀ) + Jᵀfu = Jᵀ * fu + return JᵀJ, Jᵀfu +end + +@concrete struct KrylovJᵀJ + JᵀJ + Jᵀ +end + +SciMLBase.isinplace(JᵀJ::KrylovJᵀJ) = isinplace(JᵀJ.Jᵀ) + +function __concrete_jac_transpose_autodiff(jac_transpose_autodiff, jac_autodiff, uf) + if jac_transpose_autodiff === nothing + if isinplace(uf) + # VecJac can be only FiniteDiff + return AutoFiniteDiff() + else + # Short circuit if we see that FiniteDiff was used for J computation + jac_autodiff isa AutoFiniteDiff && return jac_autodiff + # Check if Zygote is loaded then use Zygote else use FiniteDiff + if haskey(Base.loaded_modules, + Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) + return AutoZygote() + else + return AutoFiniteDiff() + end + end + else + return __get_nonsparse_ad(jac_transpose_autodiff) + end +end __maybe_symmetric(x) = Symmetric(x) __maybe_symmetric(x::Number) = x # LinearSolve with `nothing` doesn't dispatch correctly here __maybe_symmetric(x::StaticArray) = x __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x +__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x +__maybe_symmetric(x::KrylovJᵀJ) = x ## Special Handling for Scalars function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p, @@ -143,3 +192,37 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u return uf, nothing, u, nothing, nothing, u end + +function __update_JᵀJ!(iip::Val, cache, sym::Symbol, J) + return __update_JᵀJ!(iip, cache, sym, getproperty(cache, sym), J) +end +__update_JᵀJ!(::Val{false}, cache, sym::Symbol, _, J) = setproperty!(cache, sym, J' * J) +__update_JᵀJ!(::Val{true}, cache, sym::Symbol, _, J) = mul!(getproperty(cache, sym), J', J) +__update_JᵀJ!(::Val{false}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H +__update_JᵀJ!(::Val{true}, cache, sym::Symbol, H::KrylovJᵀJ, J) = H + +function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu) + return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu) +end +function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) + return setproperty!(cache, sym1, J' * fu) +end +function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) + return mul!(getproperty(cache, sym1), J', fu) +end +function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) + return setproperty!(cache, sym1, H.Jᵀ * fu) +end +function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) + return mul!(getproperty(cache, sym1), H.Jᵀ, fu) +end + +# Left-Right Multiplication +__lr_mul(::Val, H, g) = dot(g, H, g) +## TODO: Use a cache here to avoid allocations +__lr_mul(::Val{false}, H::KrylovJᵀJ, g) = dot(g, H.JᵀJ, g) +function __lr_mul(::Val{true}, H::KrylovJᵀJ, g) + c = similar(g) + mul!(c, H.JᵀJ, g) + return dot(g, c) +end diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 0070694ff..8fde39d6f 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -239,15 +239,19 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, fu_prev = zero(fu1) loss = get_loss(fu1) - uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); - linsolve_kwargs) + # uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); + # linsolve_kwargs) + uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); + linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) + linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, u, p, alg) + u_tmp = zero(u) u_cauchy = zero(u) u_gauss_newton = _mutable_zero(u) loss_new = loss - H = zero(J' * J) - g = _mutable_zero(fu1) + # H = zero(J' * J) + # g = _mutable_zero(fu1) shrink_counter = 0 fu_new = zero(fu1) make_new_J = true @@ -346,8 +350,10 @@ function perform_step!(cache::TrustRegionCache{true}) @unpack make_new_J, J, fu, f, u, p, u_gauss_newton, alg, linsolve = cache if cache.make_new_J jacobian!!(J, cache) - mul!(cache.H, J', J) - mul!(_vec(cache.g), J', _vec(fu)) + __update_JᵀJ!(Val{true}(), cache, :H, J) + # mul!(cache.H, J', J) + __update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu) + # mul!(_vec(cache.g), J', _vec(fu)) cache.stats.njacs += 1 # do not use A = cache.H, b = _vec(cache.g) since it is equivalent @@ -376,8 +382,8 @@ function perform_step!(cache::TrustRegionCache{false}) if make_new_J J = jacobian!!(cache.J, cache) - cache.H = J' * J - cache.g = _restructure(fu, J' * _vec(fu)) + __update_JᵀJ!(Val{false}(), cache, :H, J) + __update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu) cache.stats.njacs += 1 if cache.linsolve === nothing @@ -431,7 +437,7 @@ function trust_region_step!(cache::TrustRegionCache) # Compute the ratio of the actual reduction to the predicted reduction. cache.r = -(loss - cache.loss_new) / - (dot(_vec(du), _vec(g)) + dot(_vec(du), H, _vec(du)) / 2) + (dot(_vec(du), _vec(g)) + __lr_mul(Val(isinplace(cache)), H, _vec(du)) / 2) @unpack r = cache if radius_update_scheme === RadiusUpdateSchemes.Simple @@ -594,7 +600,7 @@ function dogleg!(cache::TrustRegionCache{true}) # Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region l_grad = norm(cache.g) # length of the gradient - d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate + d_cauchy = l_grad^3 / __lr_mul(Val{true}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate if d_cauchy >= trust_r @. cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region return @@ -624,7 +630,7 @@ function dogleg!(cache::TrustRegionCache{false}) ## Take intersection of steepest descent direction and trust region if Cauchy point lies outside of trust region l_grad = norm(cache.g) - d_cauchy = l_grad^3 / dot(_vec(cache.g), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate + d_cauchy = l_grad^3 / __lr_mul(Val{false}(), cache.H, _vec(cache.g)) # distance of the cauchy point from the current iterate if d_cauchy > trust_r # cauchy point lies outside of trust region cache.du = -(trust_r / l_grad) * cache.g # step to the end of the trust region return