Skip to content

Commit

Permalink
Fix matrix resizing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 13, 2023
1 parent e3ecfd1 commit 40c6de3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 20 deletions.
8 changes: 4 additions & 4 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ function perform_step!(cache::GaussNewtonCache{true})
jacobian!!(J, cache)

if JᵀJ !== nothing
__matmul!(JᵀJ, J', J)
__matmul!(Jᵀf, J', fu1)
__update_JᵀJ!(Val{true}(), cache, :JᵀJ, J)
__update_Jᵀf!(Val{true}(), cache, :Jᵀf, :JᵀJ, J, fu1)

Check warning on line 117 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L116-L117

Added lines #L116 - L117 were not covered by tests
end

# u = u - JᵀJ \ Jᵀfu
Expand Down Expand Up @@ -151,8 +151,8 @@ function perform_step!(cache::GaussNewtonCache{false})
cache.J = jacobian!!(cache.J, cache)

if cache.JᵀJ !== nothing
cache.JᵀJ = cache.J' * cache.J
cache.Jᵀf = cache.J' * fu1
__update_JᵀJ!(Val{false}(), cache, :JᵀJ, cache.J)
__update_Jᵀf!(Val{false}(), cache, :Jᵀf, :JᵀJ, cache.J, fu1)

Check warning on line 155 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L154-L155

Added lines #L154 - L155 were not covered by tests
end

# u = u - J \ fu
Expand Down
12 changes: 6 additions & 6 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u, p, ::Val

J = if !(linsolve_needs_jac || alg_wants_jac)

Check warning on line 84 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L84

Added line #L84 was not covered by tests
# We don't need to construct the Jacobian
JacVec(uf, u; autodiff = __get_nonsparse_ad(alg.ad))
JacVec(uf, u; fu, autodiff = __get_nonsparse_ad(alg.ad))

Check warning on line 86 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L86

Added line #L86 was not covered by tests
else
if has_analytic_jac
f.jac_prototype === nothing ? undefmatrix(u) : f.jac_prototype
Expand Down Expand Up @@ -179,7 +179,7 @@ __maybe_symmetric(x::Number) = x
__maybe_symmetric(x::StaticArray) = x
__maybe_symmetric(x::SparseArrays.AbstractSparseMatrix) = x
__maybe_symmetric(x::SciMLOperators.AbstractSciMLOperator) = x
__maybe_symmetric(x::KrylovJᵀJ) = x
__maybe_symmetric(x::KrylovJᵀJ) = x.JᵀJ

Check warning on line 182 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L181-L182

Added lines #L181 - L182 were not covered by tests

## Special Handling for Scalars
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, p,
Expand All @@ -203,16 +203,16 @@ function __update_Jᵀf!(iip::Val, cache, sym1::Symbol, sym2::Symbol, J, fu)
return __update_Jᵀf!(iip, cache, sym1, sym2, getproperty(cache, sym2), J, fu)

Check warning on line 203 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L202-L203

Added lines #L202 - L203 were not covered by tests
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
return setproperty!(cache, sym1, J' * fu)
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), J' * fu))

Check warning on line 206 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L205-L206

Added lines #L205 - L206 were not covered by tests
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, _, J, fu)
return mul!(getproperty(cache, sym1), J', fu)
return mul!(_vec(getproperty(cache, sym1)), J', fu)

Check warning on line 209 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L208-L209

Added lines #L208 - L209 were not covered by tests
end
function __update_Jᵀf!(::Val{false}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
return setproperty!(cache, sym1, H.Jᵀ * fu)
return setproperty!(cache, sym1, _restructure(getproperty(cache, sym1), H.Jᵀ * fu))

Check warning on line 212 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L211-L212

Added lines #L211 - L212 were not covered by tests
end
function __update_Jᵀf!(::Val{true}, cache, sym1::Symbol, sym2::Symbol, H::KrylovJᵀJ, J, fu)
return mul!(getproperty(cache, sym1), H.Jᵀ, fu)
return mul!(_vec(getproperty(cache, sym1)), H.Jᵀ, fu)

Check warning on line 215 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L214-L215

Added lines #L214 - L215 were not covered by tests
end

# Left-Right Multiplication
Expand Down
15 changes: 5 additions & 10 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,19 +239,16 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
fu_prev = zero(fu1)

loss = get_loss(fu1)
# uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
# linsolve_kwargs)
uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip);

Check warning on line 242 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L242

Added line #L242 was not covered by tests
linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false))
g = _restructure(fu1, g)
linsolve = u isa Number ? nothing : __setup_linsolve(J, fu2, du, p, alg)

Check warning on line 245 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L244-L245

Added lines #L244 - L245 were not covered by tests

u_tmp = zero(u)
u_cauchy = zero(u)
u_gauss_newton = _mutable_zero(u)

loss_new = loss
# H = zero(J' * J)
# g = _mutable_zero(fu1)
shrink_counter = 0
fu_new = zero(fu1)
make_new_J = true
Expand Down Expand Up @@ -351,9 +348,7 @@ function perform_step!(cache::TrustRegionCache{true})
if cache.make_new_J
jacobian!!(J, cache)
__update_JᵀJ!(Val{true}(), cache, :H, J)
# mul!(cache.H, J', J)
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, fu)
# mul!(_vec(cache.g), J', _vec(fu))
__update_Jᵀf!(Val{true}(), cache, :g, :H, J, _vec(fu))

Check warning on line 351 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L350-L351

Added lines #L350 - L351 were not covered by tests
cache.stats.njacs += 1

# do not use A = cache.H, b = _vec(cache.g) since it is equivalent
Expand Down Expand Up @@ -383,7 +378,7 @@ function perform_step!(cache::TrustRegionCache{false})
if make_new_J
J = jacobian!!(cache.J, cache)
__update_JᵀJ!(Val{false}(), cache, :H, J)
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, fu)
__update_Jᵀf!(Val{false}(), cache, :g, :H, J, _vec(fu))

Check warning on line 381 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L380-L381

Added lines #L380 - L381 were not covered by tests
cache.stats.njacs += 1

if cache.linsolve === nothing
Expand Down Expand Up @@ -420,8 +415,8 @@ function retrospective_step!(cache::TrustRegionCache)
cache.H = J' * J
cache.g = J' * fu
else
mul!(cache.H, J', J)
mul!(cache.g, J', fu)
__update_JᵀJ!(Val{isinplace(cache)}(), cache, :H, J)
__update_Jᵀf!(Val{isinplace(cache)}(), cache, :g, :H, J, fu)

Check warning on line 419 in src/trustRegion.jl

View check run for this annotation

Codecov / codecov/patch

src/trustRegion.jl#L418-L419

Added lines #L418 - L419 were not covered by tests
end
cache.stats.njacs += 1
@unpack H, g, du = cache
Expand Down

0 comments on commit 40c6de3

Please sign in to comment.