From b44112d7e01d931802b2871b7620aa91b948a73a Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 6 Nov 2024 17:52:46 +0800 Subject: [PATCH] refactor: Don't use duplicate solve --- lib/NonlinearSolveBase/src/polyalg.jl | 158 -------------------------- 1 file changed, 158 deletions(-) diff --git a/lib/NonlinearSolveBase/src/polyalg.jl b/lib/NonlinearSolveBase/src/polyalg.jl index 601ef3326..54b61998f 100644 --- a/lib/NonlinearSolveBase/src/polyalg.jl +++ b/lib/NonlinearSolveBase/src/polyalg.jl @@ -121,78 +121,6 @@ function SciMLBase.__init( ) end -@generated function CommonSolve.solve!(cache::NonlinearSolvePolyAlgorithmCache{Val{N}}) where {N} - calls = [quote - 1 ≤ cache.current ≤ $(N) || error("Current choices shouldn't get here!") - end] - - cache_syms = [gensym("cache") for i in 1:N] - sol_syms = [gensym("sol") for i in 1:N] - u_result_syms = [gensym("u_result") for i in 1:N] - - for i in 1:N - push!(calls, - quote - $(cache_syms[i]) = cache.caches[$(i)] - if $(i) == cache.current - cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0) - $(sol_syms[i]) = CommonSolve.solve!($(cache_syms[i])) - if SciMLBase.successful_retcode($(sol_syms[i])) - stats = $(sol_syms[i]).stats - if cache.alias_u0 - copyto!(cache.u0, $(sol_syms[i]).u) - $(u_result_syms[i]) = cache.u0 - else - $(u_result_syms[i]) = $(sol_syms[i]).u - end - fu = NonlinearSolveBase.get_fu($(cache_syms[i])) - return build_solution_less_specialize( - cache.prob, cache.alg, $(u_result_syms[i]), fu; - retcode = $(sol_syms[i]).retcode, stats, - original = $(sol_syms[i]), trace = $(sol_syms[i]).trace - ) - elseif cache.alias_u0 - # For safety we need to maintain a copy of the solution - $(u_result_syms[i]) = copy($(sol_syms[i]).u) - end - cache.current = $(i + 1) - end - end) - end - - resids = map(Base.Fix2(Symbol, :resid), cache_syms) - for (sym, resid) in zip(cache_syms, resids) - push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing)) - end - push!(calls, quote - fus = tuple($(Tuple(resids)...)) - minfu, idx = findmin_caches(cache.prob, fus) - end) - for i in 1:N - push!(calls, - quote - if idx == $(i) - u = cache.alias_u0 ? $(u_result_syms[i]) : - NonlinearSolveBase.get_u(cache.caches[$(i)]) - end - end) - end - push!(calls, - quote - retcode = cache.caches[idx].retcode - if cache.alias_u0 - copyto!(cache.u0, u) - u = cache.u0 - end - return build_solution_less_specialize( - cache.prob, cache.alg, u, fus[idx]; - retcode, cache.stats, cache.caches[idx].trace - ) - end) - - return Expr(:block, calls...) -end - @generated function InternalAPI.step!( cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs... ) where {N} @@ -232,92 +160,6 @@ end return Expr(:block, calls...) end -@generated function SciMLBase.__solve( - prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...; - stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs... -) where {N} - sol_syms = [gensym("sol") for _ in 1:N] - prob_syms = [gensym("prob") for _ in 1:N] - u_result_syms = [gensym("u_result") for _ in 1:N] - calls = [quote - current = alg.start_index - if alias_u0 && !ArrayInterface.ismutable(prob.u0) - verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \ - immutable (checked using `ArrayInterface.ismutable`)." - alias_u0 = false # If immutable don't care about aliasing - end - u0 = prob.u0 - u0_aliased = alias_u0 ? zero(u0) : u0 - end] - for i in 1:N - cur_sol = sol_syms[i] - push!(calls, - quote - if current == $(i) - if alias_u0 - copyto!(u0_aliased, u0) - $(prob_syms[i]) = SciMLBase.remake(prob; u0 = u0_aliased) - else - $(prob_syms[i]) = prob - end - $(cur_sol) = SciMLBase.__solve( - $(prob_syms[i]), alg.algs[$(i)], args...; - stats, alias_u0, verbose, kwargs... - ) - if SciMLBase.successful_retcode($(cur_sol)) - if alias_u0 - copyto!(u0, $(cur_sol).u) - $(u_result_syms[i]) = u0 - else - $(u_result_syms[i]) = $(cur_sol).u - end - return build_solution_less_specialize( - prob, alg, $(u_result_syms[i]), $(cur_sol).resid; - $(cur_sol).retcode, $(cur_sol).stats, - $(cur_sol).trace, original = $(cur_sol) - ) - elseif alias_u0 - # For safety we need to maintain a copy of the solution - $(u_result_syms[i]) = copy($(cur_sol).u) - end - current = $(i + 1) - end - end) - end - - resids = map(Base.Fix2(Symbol, :resid), sol_syms) - for (sym, resid) in zip(sol_syms, resids) - push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing)) - end - - push!(calls, quote - resids = tuple($(Tuple(resids)...)) - minfu, idx = findmin_resids(prob, resids) - end) - - for i in 1:N - push!(calls, - quote - if idx == $(i) - if alias_u0 - copyto!(u0, $(u_result_syms[i])) - $(u_result_syms[i]) = u0 - else - $(u_result_syms[i]) = $(sol_syms[i]).u - end - return build_solution_less_specialize( - prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid; - $(sol_syms[i]).retcode, $(sol_syms[i]).stats, - $(sol_syms[i]).trace, original = $(sol_syms[i]) - ) - end - end) - end - push!(calls, :(error("Current choices shouldn't get here!"))) - - return Expr(:block, calls...) -end - # Original is often determined on runtime information especially for PolyAlgorithms so it # is best to never specialize on that function build_solution_less_specialize(