Skip to content

Commit

Permalink
Recalculate the preconditioner if number of Newton iterations gets large
Browse files Browse the repository at this point in the history
When number of Newton iterations reaches (multiples of) the
`preconditioner_update_interval`, assume convergence is getting slow and
recalculate the preconditioner.
  • Loading branch information
johnomotani committed Dec 6, 2024
1 parent 7ea6907 commit a3fe1b7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 12 deletions.
29 changes: 20 additions & 9 deletions moment_kinetics/src/electron_kinetic_equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1880,16 +1880,12 @@ to allow the outer r-loop to be parallelised.

newton_success = false
for ir 1:r.n
if nl_solver_params.preconditioner_type === Val(:electron_lu)

if ion_dt > 1.5 * nl_solver_params.precon_dt[] ||
ion_dt < 2.0/3.0 * nl_solver_params.precon_dt[]

# dt has changed significantly, so update the preconditioner
nl_solver_params.solves_since_precon_update[] = nl_solver_params.preconditioner_update_interval
end
f_electron = @view pdf_electron_out[:,:,:,ir]
ppar = @view electron_ppar_out[:,ir]
phi = @view fields.phi[:,ir]

if nl_solver_params.solves_since_precon_update[] nl_solver_params.preconditioner_update_interval
function recalculate_preconditioner!()
if nl_solver_params.preconditioner_type === Val(:electron_lu)
global_rank[] == 0 && println("recalculating precon")
nl_solver_params.solves_since_precon_update[] = 0
nl_solver_params.precon_dt[] = ion_dt
Expand Down Expand Up @@ -1932,6 +1928,21 @@ global_rank[] == 0 && println("recalculating precon")
nl_solver_params.preconditioners[ir] =
(orig_lu, precon_matrix, input_buffer, output_buffer)
end

return nothing
end
end

if nl_solver_params.preconditioner_type === Val(:electron_lu)
if ion_dt > 1.5 * nl_solver_params.precon_dt[] ||
ion_dt < 2.0/3.0 * nl_solver_params.precon_dt[]

# dt has changed significantly, so update the preconditioner
nl_solver_params.solves_since_precon_update[] = nl_solver_params.preconditioner_update_interval
end

if nl_solver_params.solves_since_precon_update[] nl_solver_params.preconditioner_update_interval
recalculate_preconditioner!()
end

@timeit_debug global_timer lu_precon!(x) = begin
Expand Down
16 changes: 13 additions & 3 deletions moment_kinetics/src/nonlinear_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,17 +393,21 @@ is not necessary to have a very tight `linear_rtol` for the GMRES solve.
@timeit global_timer newton_solve!(
x, residual_func!, residual, delta_x, rhs_delta, v, w,
nl_solver_params; left_preconditioner=nothing,
right_preconditioner=nothing, coords) = begin
right_preconditioner=nothing, recalculate_preconditioner=nothing,
coords) = begin
# This wrapper function constructs the `solver_type` from coords, so that the body of
# the inner `newton_solve!()` can be fully type-stable
solver_type = Val(Symbol((c for c keys(coords))...))
return newton_solve!(x, residual_func!, residual, delta_x, rhs_delta, v, w,
nl_solver_params, solver_type; left_preconditioner=left_preconditioner,
right_preconditioner=right_preconditioner, coords=coords)
right_preconditioner=right_preconditioner,
recalculate_preconditioner=recalculate_preconditioner,
coords=coords)
end
function newton_solve!(x, residual_func!, residual, delta_x, rhs_delta, v, w,
nl_solver_params, solver_type::Val; left_preconditioner=nothing,
right_preconditioner=nothing, coords)
right_preconditioner=nothing, recalculate_preconditioner=nothing,
coords)

rtol = nl_solver_params.rtol
atol = nl_solver_params.atol
Expand Down Expand Up @@ -512,6 +516,12 @@ old_precon_iterations = nl_solver_params.precon_iterations[]
parallel_map(solver_type, (w) -> w, x, w)
previous_residual_norm = residual_norm

if recalculate_preconditioner !== nothing && counter % nl_solver_params.preconditioner_update_interval == 0
# Have taken a large number of Newton iterations already - convergence must be
# slow, so try updating the preconditioner.
recalculate_preconditioner()
end

#println("Newton residual ", residual_norm, " ", linear_its, " $rtol $atol")

if residual_norm < 0.1/rtol && close_counter < 0 && close_linear_counter < 0
Expand Down

0 comments on commit a3fe1b7

Please sign in to comment.