From 40c6de36cbc2756e6c6d92a30b67aa651c956b05 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Nov 2023 13:03:00 -0500 Subject: [PATCH] Fix matrix resizing --- src/gaussnewton.jl | 8 ++++---- src/jacobian.jl | 12 ++++++------ src/trustRegion.jl | 15 +++++---------- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 2062a3bfd..261a4e1d0 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -113,8 +113,8 @@ function perform_step!(cache::GaussNewtonCache{true}) jacobian!!(J, cache) if JᵀJ !== nothing - __matmul!(JᵀJ, J', J) - __matmul!(Jᵀf, J', fu1) + __update_JᵀJ!(Val{true}(), cache, :JᵀJ, J) + __update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1) end # u = u - JᵀJ \ Jᵀfu @@ -151,8 +151,8 @@ function perform_step!(cache::GaussNewtonCache{false}) cache.J = jacobian!!(cache.J, cache) if cache.JᵀJ !== nothing - cache.JᵀJ = cache.J' * cache.J - cache.Jᵀf = cache.J' * fu1 + __update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J) + __update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1) end # u = u - J \ fu diff --git a/src/jacobian.jl b/src/jacobian.jl index a015835fe..8bd490a8e 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -83,7 +83,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val J = if !(linsolve_needs_jac || alg_wants_jac) # We don't need to construct the Jacobian - JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad)) + JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad)) else if has_analytic_jac f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype @@ -179,7 +179,7 @@ __maybe_symmetric(x::Number) = x __maybe_symmetric(x::StaticArray) = x __maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x __maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x -__maybe_symmetric(x::KrylovJᵀJ) = x +__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ ## Special Handling for Scalars function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p, @@ -203,16 +203,16 @@ function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu) return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu) end function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) - return setproperty!(cache, sym1, J' * fu) + return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu)) end function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu) - return mul!(getproperty(cache, sym1), J', fu) + return mul!(_vec(getproperty(cache, sym1)), J', fu) end function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) - return setproperty!(cache, sym1, H.Jᵀ * fu) + return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu)) end function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu) - return mul!(getproperty(cache, sym1), H.Jᵀ, fu) + return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu) end # Left-Right Multiplication diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 1e47c7403..dbd2e289c 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -239,10 +239,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, fu_prev = zero(fu1) loss = get_loss(fu1) - # uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); - # linsolve_kwargs) uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) + g = _restructure(fu1, g) linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, du, p, alg) u_tmp = zero(u) @@ -250,8 +249,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, u_gauss_newton = _mutable_zero(u) loss_new = loss - # H = zero(J' * J) - # g = _mutable_zero(fu1) shrink_counter = 0 fu_new = zero(fu1) make_new_J = true @@ -351,9 +348,7 @@ function perform_step!(cache::TrustRegionCache{true}) if cache.make_new_J jacobian!!(J, cache) __update_JᵀJ!(Val{true}(), cache, :H, J) - # mul!(cache.H, J', J) - __update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu) - # mul!(_vec(cache.g), J', _vec(fu)) + __update_Jᵀf!(Val{true}(), cache, :g, :H, J, _vec(fu)) cache.stats.njacs += 1 # do not use A = cache.H, b = _vec(cache.g) since it is equivalent @@ -383,7 +378,7 @@ function perform_step!(cache::TrustRegionCache{false}) if make_new_J J = jacobian!!(cache.J, cache) __update_JᵀJ!(Val{false}(), cache, :H, J) - __update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu) + __update_Jᵀf!(Val{false}(), cache, :g, :H, J, _vec(fu)) cache.stats.njacs += 1 if cache.linsolve === nothing @@ -420,8 +415,8 @@ function retrospective_step!(cache::TrustRegionCache) cache.H = J' * J cache.g = J' * fu else - mul!(cache.H, J', J) - mul!(cache.g, J', fu) + __update_JᵀJ!(Val{isinplace(cache)}(), cache, :H, J) + __update_Jᵀf!(Val{isinplace(cache)}(), cache, :g, :H, J, fu) end cache.stats.njacs += 1 @unpack H, g, du = cache