From 0e3efd72170b801c6f50e226435c0816b9c56aff Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 30 Nov 2023 19:32:32 -0500 Subject: [PATCH] Fix GN --- src/gaussnewton.jl | 10 +++++----- src/utils.jl | 9 +++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl index 1b4fc9432..5ff01d79a 100644 --- a/src/gaussnewton.jl +++ b/src/gaussnewton.jl @@ -108,7 +108,7 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_:: return GaussNewtonCache{iip}(f, alg, u, u_cache, fu, fu_cache, du, dfu, p, uf, linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2, - init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)), trace) + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), trace) end function perform_step!(cache::GaussNewtonCache{iip}) where {iip} @@ -117,14 +117,14 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip} # Use normal form to solve the Linear Problem if cache.JᵀJ !== nothing __update_JᵀJ!(Val{iip}(), cache, :JᵀJ, cache.J) - __update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu1) + __update_Jᵀf!(Val{iip}(), cache, :Jᵀf, :JᵀJ, cache.J, cache.fu) A, b = __maybe_symmetric(cache.JᵀJ), _vec(cache.Jᵀf) else A, b = cache.J, _vec(cache.fu) end - linres = dolinsolve(alg.precs, linsolve; A, b, linu = _vec(du), cache.p, - reltol = cache.abstol) + linres = dolinsolve(cache.alg.precs, cache.linsolve; A, b, linu = _vec(cache.du), + cache.p, reltol = cache.abstol) cache.linsolve = linres.cache cache.du = _restructure(cache.du, linres.u) @@ -136,7 +136,7 @@ function perform_step!(cache::GaussNewtonCache{iip}) where {iip} check_and_update!(cache.tc_cache_1, cache, cache.fu, cache.u, cache.u_cache) if !cache.force_stop @bb @. cache.dfu = cache.fu .- cache.dfu - check_and_update!(cache.tc_cache_2, cache, cache.dfu, cache.u, cache.u_prev) + check_and_update!(cache.tc_cache_2, cache, cache.dfu, cache.u, cache.u_cache) end @bb copyto!(cache.u_cache, cache.u) diff --git a/src/utils.jl b/src/utils.jl index 46c5b9295..00b7d3726 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -188,6 +188,15 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip}, return fu end +function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip} + if iip + f(fu, u, p) + return fu + else + return f(u, p) + end +end + function evaluate_f(cache, u, p) if isinplace(cache) cache.prob.f(get_fu(cache), u, p)