Skip to content

Commit

Permalink
Use async sends in particle communication (#4257)
Browse files Browse the repository at this point in the history
This changes a few instances of `ParallelDescriptor::Send` to
`ParallelDescriptor::Asend`. This improves the overall runtime of the
redistribute benchmark by a few percent for multi-node runs:

```
Gpus      Send        Asend
4	  4.125       4.123
8	  4.29 	      4.198
16	  4.472       4.297
32	  4.62 	      4.456
```

The proposed changes:
- [ ] fix a bug or incorrect behavior in AMReX
- [ ] add new capabilities to AMReX
- [ ] changes answers in the test suite to more than roundoff level
- [ ] are likely to significantly affect the results of downstream AMReX
users
- [ ] include documentation in the code and/or rst files, if appropriate
  • Loading branch information
atmyers authored Dec 10, 2024
1 parent 8024e3a commit 9643da4
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 20 deletions.
6 changes: 3 additions & 3 deletions Src/Base/AMReX_ParallelDescriptor.H
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,9 @@ while ( false )
*/
inline int SeqNum () noexcept { return ParallelContext::get_inc_mpi_tag(); }

template <class T> Message Asend(const T*, size_t n, int pid, int tag);
template <class T> Message Asend(const T*, size_t n, int pid, int tag, MPI_Comm comm);
template <class T> Message Asend(const std::vector<T>& buf, int pid, int tag);
template <class T> [[nodiscard]] Message Asend(const T*, size_t n, int pid, int tag);
template <class T> [[nodiscard]] Message Asend(const T*, size_t n, int pid, int tag, MPI_Comm comm);
template <class T> [[nodiscard]] Message Asend(const std::vector<T>& buf, int pid, int tag);

template <class T> Message Arecv(T*, size_t n, int pid, int tag);
template <class T> Message Arecv(T*, size_t n, int pid, int tag, MPI_Comm comm);
Expand Down
18 changes: 13 additions & 5 deletions Src/Particle/AMReX_ParticleCommunication.H
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,12 @@ struct ParticleCopyPlan
mutable Vector<MPI_Status> m_build_stats;
mutable Vector<MPI_Request> m_build_rreqs;

mutable Vector<MPI_Status> m_particle_stats;
mutable Vector<MPI_Status> m_particle_rstats;
mutable Vector<MPI_Request> m_particle_rreqs;

mutable Vector<MPI_Status> m_particle_sstats;
mutable Vector<MPI_Request> m_particle_sreqs;

Vector<Long> m_snd_num_particles;
Vector<Long> m_rcv_num_particles;

Expand Down Expand Up @@ -533,12 +536,15 @@ void communicateParticlesStart (const PC& pc, ParticleCopyPlan& plan, const SndB

plan.m_nrcvs = int(RcvProc.size());

plan.m_particle_stats.resize(0);
plan.m_particle_stats.resize(plan.m_nrcvs);
plan.m_particle_rstats.resize(0);
plan.m_particle_rstats.resize(plan.m_nrcvs);

plan.m_particle_rreqs.resize(0);
plan.m_particle_rreqs.resize(plan.m_nrcvs);

plan.m_particle_sstats.resize(0);
plan.m_particle_sreqs.resize(0);

const int SeqNum = ParallelDescriptor::SeqNum();

// Post receives.
Expand Down Expand Up @@ -571,10 +577,12 @@ void communicateParticlesStart (const PC& pc, ParticleCopyPlan& plan, const SndB
AMREX_ASSERT(plan.m_snd_counts[i] % ParallelDescriptor::sizeof_selected_comm_data_type(plan.m_snd_num_particles[i]*psize) == 0);
AMREX_ASSERT(Who >= 0 && Who < NProcs);

ParallelDescriptor::Send((char const*)(snd_buffer.dataPtr()+snd_offset), Cnt, Who, SeqNum,
ParallelContext::CommunicatorSub());
plan.m_particle_sreqs.push_back(ParallelDescriptor::Asend((char const*)(snd_buffer.dataPtr()+snd_offset), Cnt, Who, SeqNum,
ParallelContext::CommunicatorSub()).req());
}

plan.m_particle_sstats.resize(plan.m_particle_sreqs.size());

amrex::ignore_unused(pc);
#else
amrex::ignore_unused(pc,plan,snd_buffer,rcv_buffer);
Expand Down
41 changes: 29 additions & 12 deletions Src/Particle/AMReX_ParticleCommunication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ void ParticleCopyPlan::buildMPIStart (const ParticleBufferMap& map, Long psize)
m_build_rreqs[i] = ParallelDescriptor::Arecv((char*) (m_rcv_data.dataPtr() + offset), Cnt, Who, SeqNum, ParallelContext::CommunicatorSub()).req();
}

Vector<MPI_Request> snd_reqs;
Vector<MPI_Status> snd_stats;
for (auto i : m_neighbor_procs)
{
if (i == MyProc) { continue; }
Expand All @@ -169,8 +171,8 @@ void ParticleCopyPlan::buildMPIStart (const ParticleBufferMap& map, Long psize)
AMREX_ASSERT(Who >= 0 && Who < NProcs);
AMREX_ASSERT(Cnt < std::numeric_limits<int>::max());

ParallelDescriptor::Send((char*) snd_data[i].data(), Cnt, Who, SeqNum,
ParallelContext::CommunicatorSub());
snd_reqs.push_back(ParallelDescriptor::Asend((char*) snd_data[i].data(), Cnt, Who, SeqNum,
ParallelContext::CommunicatorSub()).req());
}

m_snd_counts.resize(0);
Expand Down Expand Up @@ -199,6 +201,10 @@ void ParticleCopyPlan::buildMPIStart (const ParticleBufferMap& map, Long psize)
m_snd_pad_correction_d.resize(m_snd_pad_correction_h.size());
Gpu::copy(Gpu::hostToDevice, m_snd_pad_correction_h.begin(), m_snd_pad_correction_h.end(),
m_snd_pad_correction_d.begin());

snd_stats.resize(0);
snd_stats.resize(snd_reqs.size());
ParallelDescriptor::Waitall(snd_reqs, snd_stats);
#else
amrex::ignore_unused(map,psize);
#endif
Expand Down Expand Up @@ -265,8 +271,10 @@ void ParticleCopyPlan::doHandShakeLocal (const Vector<Long>& Snds, Vector<Long>&
#ifdef AMREX_USE_MPI
const int SeqNum = ParallelDescriptor::SeqNum();
const auto num_rcvs = static_cast<int>(m_neighbor_procs.size());
Vector<MPI_Status> stats(num_rcvs);
Vector<MPI_Status> rstats(num_rcvs);
Vector<MPI_Request> rreqs(num_rcvs);
Vector<MPI_Status> sstats(num_rcvs);
Vector<MPI_Request> sreqs(num_rcvs);

// Post receives
for (int i = 0; i < num_rcvs; ++i)
Expand All @@ -288,13 +296,14 @@ void ParticleCopyPlan::doHandShakeLocal (const Vector<Long>& Snds, Vector<Long>&

AMREX_ASSERT(Who >= 0 && Who < ParallelContext::NProcsSub());

ParallelDescriptor::Send(&Snds[Who], Cnt, Who, SeqNum,
ParallelContext::CommunicatorSub());
sreqs[i] = ParallelDescriptor::Asend(&Snds[Who], Cnt, Who, SeqNum,
ParallelContext::CommunicatorSub()).req();
}

if (num_rcvs > 0)
{
ParallelDescriptor::Waitall(rreqs, stats);
ParallelDescriptor::Waitall(sreqs, sstats);
ParallelDescriptor::Waitall(rreqs, rstats);
}
#else
amrex::ignore_unused(Snds,Rcvs);
Expand Down Expand Up @@ -339,8 +348,10 @@ void ParticleCopyPlan::doHandShakeGlobal (const Vector<Long>& Snds, Vector<Long>
ParallelDescriptor::Mpi_typemap<Long>::type(), MPI_SUM,
ParallelContext::CommunicatorSub());

Vector<MPI_Status> stats(num_rcvs);
Vector<MPI_Status> rstats(num_rcvs);
Vector<MPI_Request> rreqs(num_rcvs);
Vector<MPI_Status> sstats;
Vector<MPI_Request> sreqs;

Vector<Long> num_bytes_rcv(num_rcvs);
for (int i = 0; i < static_cast<int>(num_rcvs); ++i)
Expand All @@ -352,15 +363,17 @@ void ParticleCopyPlan::doHandShakeGlobal (const Vector<Long>& Snds, Vector<Long>
{
if (Snds[i] == 0) { continue; }
const Long Cnt = 1;
MPI_Send( &Snds[i], Cnt, ParallelDescriptor::Mpi_typemap<Long>::type(), i, SeqNum,
ParallelContext::CommunicatorSub());
sreqs.push_back(ParallelDescriptor::Asend( &Snds[i], Cnt, i, SeqNum, ParallelContext::CommunicatorSub()).req());
}

MPI_Waitall(static_cast<int>(num_rcvs), rreqs.data(), stats.data());
sstats.resize(0);
sstats.resize(sreqs.size());
ParallelDescriptor::Waitall(sreqs, sstats);
ParallelDescriptor::Waitall(rreqs, rstats);

for (int i = 0; i < num_rcvs; ++i)
{
const auto Who = stats[i].MPI_SOURCE;
const auto Who = rstats[i].MPI_SOURCE;
Rcvs[Who] = num_bytes_rcv[i];
}
#else
Expand All @@ -372,9 +385,13 @@ void amrex::communicateParticlesFinish (const ParticleCopyPlan& plan)
{
BL_PROFILE("amrex::communicateParticlesFinish");
#ifdef AMREX_USE_MPI
if (plan.m_NumSnds > 0)
{
ParallelDescriptor::Waitall(plan.m_particle_sreqs, plan.m_particle_sstats);
}
if (plan.m_nrcvs > 0)
{
ParallelDescriptor::Waitall(plan.m_particle_rreqs, plan.m_particle_stats);
ParallelDescriptor::Waitall(plan.m_particle_rreqs, plan.m_particle_rstats);
}
#else
amrex::ignore_unused(plan);
Expand Down

0 comments on commit 9643da4

Please sign in to comment.