Skip to content

Commit

Permalink
refactor: Don't use duplicate solve
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Nov 6, 2024
1 parent 61e97a8 commit b44112d
Showing 1 changed file with 0 additions and 158 deletions.
158 changes: 0 additions & 158 deletions lib/NonlinearSolveBase/src/polyalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,78 +121,6 @@ function SciMLBase.__init(
)
end

@generated function CommonSolve.solve!(cache::NonlinearSolvePolyAlgorithmCache{Val{N}}) where {N}
calls = [quote
1 cache.current $(N) || error("Current choices shouldn't get here!")
end]

cache_syms = [gensym("cache") for i in 1:N]
sol_syms = [gensym("sol") for i in 1:N]
u_result_syms = [gensym("u_result") for i in 1:N]

for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
cache.alias_u0 && copyto!(cache.u0_aliased, cache.u0)
$(sol_syms[i]) = CommonSolve.solve!($(cache_syms[i]))
if SciMLBase.successful_retcode($(sol_syms[i]))
stats = $(sol_syms[i]).stats
if cache.alias_u0
copyto!(cache.u0, $(sol_syms[i]).u)
$(u_result_syms[i]) = cache.u0
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
fu = NonlinearSolveBase.get_fu($(cache_syms[i]))
return build_solution_less_specialize(
cache.prob, cache.alg, $(u_result_syms[i]), fu;
retcode = $(sol_syms[i]).retcode, stats,
original = $(sol_syms[i]), trace = $(sol_syms[i]).trace
)
elseif cache.alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(sol_syms[i]).u)
end
cache.current = $(i + 1)
end
end)
end

resids = map(Base.Fix2(Symbol, :resid), cache_syms)
for (sym, resid) in zip(cache_syms, resids)
push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing))
end
push!(calls, quote
fus = tuple($(Tuple(resids)...))
minfu, idx = findmin_caches(cache.prob, fus)
end)
for i in 1:N
push!(calls,
quote
if idx == $(i)
u = cache.alias_u0 ? $(u_result_syms[i]) :
NonlinearSolveBase.get_u(cache.caches[$(i)])
end
end)
end
push!(calls,
quote
retcode = cache.caches[idx].retcode
if cache.alias_u0
copyto!(cache.u0, u)
u = cache.u0
end
return build_solution_less_specialize(
cache.prob, cache.alg, u, fus[idx];
retcode, cache.stats, cache.caches[idx].trace
)
end)

return Expr(:block, calls...)
end

@generated function InternalAPI.step!(
cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs...
) where {N}
Expand Down Expand Up @@ -232,92 +160,6 @@ end
return Expr(:block, calls...)
end

@generated function SciMLBase.__solve(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm{Val{N}}, args...;
stats = NLStats(0, 0, 0, 0, 0), alias_u0 = false, verbose = true, kwargs...
) where {N}
sol_syms = [gensym("sol") for _ in 1:N]
prob_syms = [gensym("prob") for _ in 1:N]
u_result_syms = [gensym("u_result") for _ in 1:N]
calls = [quote
current = alg.start_index
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end
u0 = prob.u0
u0_aliased = alias_u0 ? zero(u0) : u0
end]
for i in 1:N
cur_sol = sol_syms[i]
push!(calls,
quote
if current == $(i)
if alias_u0
copyto!(u0_aliased, u0)
$(prob_syms[i]) = SciMLBase.remake(prob; u0 = u0_aliased)
else
$(prob_syms[i]) = prob
end
$(cur_sol) = SciMLBase.__solve(
$(prob_syms[i]), alg.algs[$(i)], args...;
stats, alias_u0, verbose, kwargs...
)
if SciMLBase.successful_retcode($(cur_sol))
if alias_u0
copyto!(u0, $(cur_sol).u)
$(u_result_syms[i]) = u0
else
$(u_result_syms[i]) = $(cur_sol).u
end
return build_solution_less_specialize(
prob, alg, $(u_result_syms[i]), $(cur_sol).resid;
$(cur_sol).retcode, $(cur_sol).stats,
$(cur_sol).trace, original = $(cur_sol)
)
elseif alias_u0
# For safety we need to maintain a copy of the solution
$(u_result_syms[i]) = copy($(cur_sol).u)
end
current = $(i + 1)
end
end)
end

resids = map(Base.Fix2(Symbol, :resid), sol_syms)
for (sym, resid) in zip(sol_syms, resids)
push!(calls, :($(resid) = @isdefined($(sym)) ? $(sym).resid : nothing))
end

push!(calls, quote
resids = tuple($(Tuple(resids)...))
minfu, idx = findmin_resids(prob, resids)
end)

for i in 1:N
push!(calls,
quote
if idx == $(i)
if alias_u0
copyto!(u0, $(u_result_syms[i]))
$(u_result_syms[i]) = u0
else
$(u_result_syms[i]) = $(sol_syms[i]).u
end
return build_solution_less_specialize(
prob, alg, $(u_result_syms[i]), $(sol_syms[i]).resid;
$(sol_syms[i]).retcode, $(sol_syms[i]).stats,
$(sol_syms[i]).trace, original = $(sol_syms[i])
)
end
end)
end
push!(calls, :(error("Current choices shouldn't get here!")))

return Expr(:block, calls...)
end

# Original is often determined on runtime information especially for PolyAlgorithms so it
# is best to never specialize on that
function build_solution_less_specialize(
Expand Down

0 comments on commit b44112d

Please sign in to comment.