Avoid using MPI.bcast() to prevent type instability
MPI.bcast() can communicate (almost?) any type of object, but that means
that the type of its result is not necessarily known before
communication happens, leading to type instability. Therefore prefer to
use other MPI.jl functions that are type-stable.
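An illustrative sketch of the difference (not code from this commit; MPI.COMM_WORLD and a Float64 payload stand in for the communicators and mk_float values used in the diff below):

    using MPI

    MPI.Init()
    comm = MPI.COMM_WORLD

    # Type-unstable: MPI.bcast() serialises an arbitrary object, so the result
    # type cannot be inferred before the communication happens - here it is at
    # best Union{Float64, Nothing}.
    x = MPI.Comm_rank(comm) == 0 ? 1.5 : nothing
    x = MPI.bcast(x, comm; root=0)

    # Type-stable alternative: every rank allocates a buffer of known element
    # type and MPI.Bcast! fills it in place, so y[] is always a Float64.
    y = Ref(0.0)
    if MPI.Comm_rank(comm) == 0
        y[] = 1.5
    end
    MPI.Bcast!(y, comm; root=0)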

Also use in-place MPI operations in a few more places to avoid the
possibility of allocating extra buffers.
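As a rough sketch of the in-place pattern (again not taken from the diff below; the communicator, the Float64 value and the final normalisation are placeholders), one Ref buffer is reused as both send and receive buffer so the reduction allocates nothing extra:

    using MPI

    MPI.Init()
    comm = MPI.COMM_WORLD

    local_norm_square = 2.0   # placeholder per-rank partial sum

    # Out-of-place calls such as MPI.Reduce() allocate a fresh result and return
    # `nothing` on non-root ranks, hence the switch to the in-place variants.
    total = Ref(local_norm_square)
    MPI.Allreduce!(total, +, comm)                      # total[] is the global sum on every rank
    global_norm = sqrt(total[] / MPI.Comm_size(comm))   # still a plain Float64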
johnomotani committed Sep 28, 2024
1 parent 31384c5 commit 8b89e2d
Showing 3 changed files with 69 additions and 77 deletions.
26 changes: 12 additions & 14 deletions moment_kinetics/src/initial_conditions.jl
@@ -364,23 +364,21 @@ function initialize_electrons!(pdf, moments, fields, geometry, composition, r, z
         #
         # q at the boundaries tells us dTe/dz for Braginskii electrons
         nu_ei = collisions.electron_fluid.nu_ei
+        dTe_dz_lower = Ref{mk_float}(0.0)
         if z.irank == 0
-            dTe_dz_lower = @. -moments.electron.qpar[1,:] * 2.0 / 3.16 /
-                               moments.electron.ppar[1,:] *
-                               composition.me_over_mi * nu_ei
-        else
-            dTe_dz_lower = nothing
+            dTe_dz_lower[] = @. -moments.electron.qpar[1,:] * 2.0 / 3.16 /
+                                moments.electron.ppar[1,:] *
+                                composition.me_over_mi * nu_ei
         end
-        dTe_dz_lower = MPI.bcast(dTe_dz_lower, z.comm; root=0)
+        MPI.Bcast!(dTe_dz_lower, z.comm; root=0)
 
+        dTe_dz_upper = Ref{mk_float}(0.0)
         if z.irank == z.nrank - 1
-            dTe_dz_upper = @. -moments.electron.qpar[end,:] * 2.0 / 3.16 /
-                               moments.electron.ppar[end,:] *
-                               composition.me_over_mi * nu_ei
-        else
-            dTe_dz_upper = nothing
+            dTe_dz_upper[] = @. -moments.electron.qpar[end,:] * 2.0 / 3.16 /
+                                moments.electron.ppar[end,:] *
+                                composition.me_over_mi * nu_ei
         end
-        dTe_dz_upper = MPI.bcast(dTe_dz_upper, z.comm; root=(z.nrank - 1))
+        MPI.Bcast!(dTe_dz_upper, z.comm; root=(z.nrank - 1))
 
         # The temperature should already be equal to the 'Boltzmann electron'
         # Te, so we just need to add a cubic that vanishes at ±Lz/2
@@ -401,9 +399,9 @@ function initialize_electrons!(pdf, moments, fields, geometry, composition, r, z
         # 2*B - 3*2*B = -4*B = dTe/dz_upper + dTe/dz_lower
         Lz = z.L
         zg = z.grid
-        C = @. (dTe_dz_upper - dTe_dz_lower) / 2.0 / Lz
+        C = @. (dTe_dz_upper[] - dTe_dz_lower[]) / 2.0 / Lz
         A = @. -C * Lz^2 / 4
-        B = @. -(dTe_dz_lower + dTe_dz_upper) / 4.0
+        B = @. -(dTe_dz_lower[] + dTe_dz_upper[]) / 4.0
         D = @. -4.0 * B / Lz^2
         @loop_r ir begin
             @. moments.electron.temp[:,ir] += A[ir] + B[ir]*zg + C[ir]*zg^2 +
92 changes: 43 additions & 49 deletions moment_kinetics/src/nonlinear_solvers.jl
@@ -497,17 +497,16 @@ function distributed_norm_z(residual::AbstractArray{mk_float, 1}; coords, rtol,
     end
 
     _block_synchronize()
-    block_norm = MPI.Reduce(local_norm, +, comm_block[])
+    global_norm = Ref(local_norm)
+    block_norm = MPI.Reduce!(global_norm, +, comm_block[]) # global_norm is the norm_square for the block
 
     if block_rank[] == 0
-        global_norm = MPI.Allreduce(block_norm, +, comm_inter_block[])
-        global_norm = sqrt(global_norm / z.n_global)
-    else
-        global_norm = nothing
+        MPI.Allreduce!(global_norm, +, comm_inter_block[]) # global_norm is the norm_square for the whole grid
+        global_norm[] = sqrt(global_norm[] / z.n_global)
     end
-    global_norm = MPI.bcast(global_norm, comm_block[]; root=0)
+    MPI.Bcast!(global_norm, comm_block[]; root=0)
 
-    return global_norm
+    return global_norm[]
 end
 
 function distributed_norm_vpa(residual::AbstractArray{mk_float, 1}; coords, rtol, atol, x)
@@ -548,13 +547,12 @@ function distributed_norm_z_vperp_vpa(residual::Tuple{AbstractArray{mk_float, 1}
     end
 
     _block_synchronize()
-    ppar_block_norm_square = MPI.Reduce(ppar_local_norm_square, +, comm_block[])
+    global_norm_ppar = Ref(ppar_local_norm_square) # global_norm_ppar is the norm_square for ppar in the block
+    ppar_block_norm_square = MPI.Reduce!(global_norm_ppar, +, comm_block[])
 
     if block_rank[] == 0
-        ppar_global_norm_square = MPI.Allreduce(ppar_block_norm_square, +, comm_inter_block[])
-        ppar_global_norm_square = ppar_global_norm_square / z.n_global
-    else
-        ppar_global_norm_square = nothing
+        MPI.Allreduce!(global_norm_ppar, +, comm_inter_block[]) # global_norm_ppar is the norm_square for ppar in the whole grid
+        global_norm_ppar[] = global_norm_ppar[] / z.n_global
     end
 
     begin_z_vperp_vpa_region()
@@ -570,20 +568,21 @@
     end
 
     _block_synchronize()
-    pdf_block_norm_square = MPI.Reduce(pdf_local_norm_square, +, comm_block[])
+    global_norm = Ref(pdf_local_norm_square)
+    MPI.Reduce!(global_norm, +, comm_block[]) # global_norm is the norm_square for the block
 
     if block_rank[] == 0
-        pdf_global_norm_square = MPI.Allreduce(pdf_block_norm_square, +, comm_inter_block[])
-        pdf_global_norm_square = pdf_global_norm_square / (z.n_global * vperp.n_global * vpa.n_global)
+        MPI.Allreduce!(global_norm, +, comm_inter_block[]) # global_norm is the norm_square for the whole grid
+        global_norm[] = global_norm[] / (z.n_global * vperp.n_global * vpa.n_global)
 
-        global_norm = sqrt(mean((ppar_global_norm_square, pdf_global_norm_square)))
-    else
-        global_norm = nothing
+        global_norm[] = sqrt(mean((global_norm_ppar[], global_norm[])))
     end
 
-    global_norm = MPI.bcast(global_norm, comm_block[]; root=0)
+    MPI.Bcast!(global_norm, comm_block[]; root=0)
 
-    return global_norm
+    return global_norm[]
 end
 
 function distributed_norm_s_r_z_vperp_vpa(residual::AbstractArray{mk_float, 5};
@@ -617,17 +616,16 @@ function distributed_norm_s_r_z_vperp_vpa(residual::AbstractArray{mk_float, 5};
     end
 
     _block_synchronize()
-    block_norm = MPI.Reduce(local_norm, +, comm_block[])
+    global_norm = Ref(local_norm)
+    MPI.Reduce!(global_norm, +, comm_block[]) # global_norm is the norm_square for the block
 
     if block_rank[] == 0
-        global_norm = MPI.Allreduce(block_norm, +, comm_inter_block[])
-        global_norm = sqrt(global_norm / (n_ion_species * r.n_global * z.n_global * vperp.n_global * vpa.n_global))
-    else
-        global_norm = nothing
+        MPI.Allreduce!(global_norm, +, comm_inter_block[]) # global_norm is the norm_square for the whole grid
+        global_norm[] = sqrt(global_norm[] / (n_ion_species * r.n_global * z.n_global * vperp.n_global * vpa.n_global))
     end
-    global_norm = MPI.bcast(global_norm, comm_block[]; root=0)
+    MPI.Bcast!(global_norm, comm_block[]; root=0)
 
-    return global_norm
+    return global_norm[]
 end
 
 """
@@ -683,16 +681,15 @@ function distributed_dot_z(v::AbstractArray{mk_float, 1}, w::AbstractArray{mk_fl
     end
 
     _block_synchronize()
-    block_dot = MPI.Reduce(local_dot, +, comm_block[])
+    global_dot = Ref(local_dot)
+    MPI.Reduce!(global_dot, +, comm_block[]) # global_dot is the dot for the block
 
     if block_rank[] == 0
-        global_dot = MPI.Allreduce(block_dot, +, comm_inter_block[])
-        global_dot = global_dot / z.n_global
-    else
-        global_dot = nothing
+        MPI.Allreduce!(global_dot, +, comm_inter_block[]) # global_dot is the dot for the whole grid
+        global_dot[] = global_dot[] / z.n_global
     end
 
-    return global_dot
+    return global_dot[]
 end
 
 function distributed_dot_vpa(v::AbstractArray{mk_float, 1}, w::AbstractArray{mk_float, 1};
@@ -735,13 +732,12 @@ function distributed_dot_z_vperp_vpa(v::Tuple{AbstractArray{mk_float, 1},Abstrac
     end
 
     _block_synchronize()
-    ppar_block_dot = MPI.Reduce(ppar_local_dot, +, comm_block[])
+    ppar_global_dot = Ref(ppar_local_dot)
+    MPI.Reduce!(ppar_global_dot, +, comm_block[]) # ppar_global_dot is the ppar_dot for the block
 
     if block_rank[] == 0
-        ppar_global_dot = MPI.Allreduce(ppar_block_dot, +, comm_inter_block[])
-        ppar_global_dot = ppar_global_dot / z.n_global
-    else
-        ppar_global_dot = nothing
+        MPI.Allreduce!(ppar_global_dot, +, comm_inter_block[]) # ppar_global_dot is the ppar_dot for the whole grid
+        ppar_global_dot[] = ppar_global_dot[] / z.n_global
     end
 
     begin_z_vperp_vpa_region()
@@ -755,18 +751,17 @@
     end
 
     _block_synchronize()
-    pdf_block_dot = MPI.Reduce(pdf_local_dot, +, comm_block[])
+    global_dot = Ref(pdf_local_dot)
+    MPI.Reduce!(global_dot, +, comm_block[]) # global_dot is the dot for the block
 
     if block_rank[] == 0
-        pdf_global_dot = MPI.Allreduce(pdf_block_dot, +, comm_inter_block[])
-        pdf_global_dot = pdf_global_dot / (z.n_global * vperp.n_global * vpa.n_global)
+        MPI.Allreduce!(global_dot, +, comm_inter_block[]) # global_dot is the dot for the whole grid
+        global_dot[] = global_dot[] / (z.n_global * vperp.n_global * vpa.n_global)
 
-        global_dot = mean((ppar_global_dot, pdf_global_dot))
-    else
-        global_dot = nothing
+        global_dot[] = mean((ppar_global_dot[], global_dot[]))
     end
 
-    return global_dot
+    return global_dot[]
 end
 
 function distributed_dot_s_r_z_vperp_vpa(v::AbstractArray{mk_float, 5},
@@ -800,16 +795,15 @@ function distributed_dot_s_r_z_vperp_vpa(v::AbstractArray{mk_float, 5},
     end
 
     _block_synchronize()
-    block_dot = MPI.Reduce(local_dot, +, comm_block[])
+    global_dot = Ref(local_dot)
+    MPI.Reduce!(global_dot, +, comm_block[]) # global_dot is the dot for the block
 
     if block_rank[] == 0
-        global_dot = MPI.Allreduce(block_dot, +, comm_inter_block[])
-        global_dot = global_dot / (n_ion_species * r.n_global * z.n_global * vperp.n_global * vpa.n_global)
-    else
-        global_dot = nothing
+        MPI.Allreduce!(global_dot, +, comm_inter_block[]) # global_dot is the dot for the whole grid
+        global_dot[] = global_dot[] / (n_ion_species * r.n_global * z.n_global * vperp.n_global * vpa.n_global)
     end
 
-    return global_dot
+    return global_dot[]
 end
 
 """
28 changes: 14 additions & 14 deletions moment_kinetics/src/runge_kutta.jl
@@ -1082,26 +1082,26 @@ function adaptive_timestep_update_t_params!(t_params, CFL_limits, error_norms,
 
         if error_norm_method == "Linf"
             # Get overall maximum error on the shared-memory block
-            error_norms = MPI.Reduce(error_norms, max, comm_block[]; root=0)
+            MPI.Reduce!(error_norms, max, comm_block[]; root=0)
 
-            error_norm = nothing
+            error_norm = Ref{mk_float}(0.0)
             max_error_variable_index = -1
             @serial_region begin
                 # Get maximum error over all blocks
-                error_norms = MPI.Allreduce(error_norms, max, comm_inter_block[])
+                MPI.Allreduce!(error_norms, max, comm_inter_block[])
                 max_error_variable_index = argmax(error_norms)
-                error_norm = error_norms[max_error_variable_index]
+                error_norm[] = error_norms[max_error_variable_index]
             end
-            error_norm = MPI.bcast(error_norm, 0, comm_block[])
+            MPI.Bcast!(error_norm, 0, comm_block[])
         elseif error_norm_method == "L2"
             # Get overall maximum error on the shared-memory block
-            error_norms = MPI.Reduce(error_norms, +, comm_block[]; root=0)
+            MPI.Reduce!(error_norms, +, comm_block[]; root=0)
 
-            error_norm = nothing
+            error_norm = Ref{mk_float}(0.0)
             max_error_variable_index = -1
             @serial_region begin
                 # Get maximum error over all blocks
-                error_norms = MPI.Allreduce(error_norms, +, comm_inter_block[])
+                MPI.Allreduce!(error_norms, +, comm_inter_block[])
 
                 # So far `error_norms` is the sum of squares of the errors. Now that summation
                 # is finished, need to divide by total number of points and take square-root.
@@ -1110,13 +1110,13 @@
                 # Weight the error from each variable equally by taking the mean, so the
                 # larger number of points in the distribution functions does not mean that
                 # error on the moments is ignored.
-                error_norm = mean(error_norms)
+                error_norm[] = mean(error_norms)
 
                 # Record which variable had the maximum error
                 max_error_variable_index = argmax(error_norms)
             end
 
-            error_norm = MPI.bcast(error_norm, 0, comm_block[])
+            MPI.Bcast!(error_norm, 0, comm_block[])
         else
             error("Unrecognized error_norm_method '$method'")
         end
@@ -1170,7 +1170,7 @@ function adaptive_timestep_update_t_params!(t_params, CFL_limits, error_norms,
                 t_params.step_to_moments_output[] = false
                 t_params.step_to_dfns_output[] = false
             end
-        elseif (error_norm > 1.0 || isnan(error_norm)) && current_dt > t_params.minimum_dt * (1.0 + 1.0e-13)
+        elseif (error_norm[] > 1.0 || isnan(error_norm[])) && current_dt > t_params.minimum_dt * (1.0 + 1.0e-13)
             # (1.0 + 1.0e-13) fudge factor accounts for possible rounding errors when
             # t+dt=next_output_time.
             # Use current_dt instead of t_params.dt[] here because we are about to write to
@@ -1191,7 +1191,7 @@
            # Get new timestep estimate using same formula as for a successful step, but
            # limit decrease to factor 1/2 - this factor should probably be settable!
            t_params.dt[] = max(t_params.dt[] / 2.0,
-                               t_params.dt[] * t_params.step_update_prefactor * error_norm^(-1.0/t_params.rk_order))
+                               t_params.dt[] * t_params.step_update_prefactor * error_norm[]^(-1.0/t_params.rk_order))
            t_params.dt[] = max(t_params.dt[], t_params.minimum_dt)
 
            # Don't update the simulation time, as this step failed
@@ -1206,7 +1206,7 @@
            t_params.step_to_moments_output[] = false
            t_params.step_to_dfns_output[] = false
 
-           #println("t=$t, timestep failed, error_norm=$error_norm, error_norms=$error_norms, decreasing timestep to ", t_params.dt[])
+           #println("t=$t, timestep failed, error_norm=$(error_norm[]), error_norms=$error_norms, decreasing timestep to ", t_params.dt[])
         end
     else
         @serial_region begin
@@ -1237,7 +1237,7 @@ function adaptive_timestep_update_t_params!(t_params, CFL_limits, error_norms,
            # `step_update_prefactor` is a constant numerical factor to make the estimate
            # of a good value for the next timestep slightly conservative. It defaults to
            # 0.9.
-           t_params.dt[] *= t_params.step_update_prefactor * error_norm^(-1.0/t_params.rk_order)
+           t_params.dt[] *= t_params.step_update_prefactor * error_norm[]^(-1.0/t_params.rk_order)
 
            if t_params.dt[] > CFL_limit
                t_params.dt[] = CFL_limit
