Skip to content

Commit

Permalink
FFT: Optimization for single process
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Nov 16, 2024
1 parent 0165b67 commit 8ef3bff
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 83 deletions.
25 changes: 25 additions & 0 deletions Src/Base/AMReX_FabArrayCommI.H
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,31 @@ FabArray<FAB>::ParallelCopy_nowait (const FabArray<FAB>& src,

n_filled = dnghost;

if ((ParallelDescriptor::NProcs() == 1) &&
(this->size() == 1) && (src.size() == 1) &&
!period.isAnyPeriodic() && !to_ghost_cells_only)
{
if (this != &src) { // avoid self copy or plus
auto const& da = this->array(0, dcomp);
auto const& sa = src.const_array(0, scomp);
Box box = amrex::grow(src.box(0),snghost)
& amrex::grow(this->box(0),dnghost);
if (op == FabArrayBase::COPY) {
ParallelFor(box, ncomp,
[=] AMREX_GPU_DEVICE (int i, int j, int k, int n) {
da(i,j,k,n) = sa(i,j,k,n);
});
} else {
ParallelFor(box, ncomp,
[=] AMREX_GPU_DEVICE (int i, int j, int k, int n) {
da(i,j,k,n) += sa(i,j,k,n);
});
}
Gpu::streamSynchronize();
}
return;
}

if ((src.boxArray().ixType().cellCentered() || op == FabArrayBase::COPY) &&
(boxarray == src.boxarray && distributionMap == src.distributionMap) &&
snghost == IntVect::TheZeroVector() &&
Expand Down
4 changes: 4 additions & 0 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ template <typename T>
template <class F>
void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
{
BL_PROFILE("OpenBCSolver::setGreensFunction");

auto* infab = m_info.batch_mode ? detail::get_fab(m_r2c_green->m_rx)
: detail::get_fab(m_r2c.m_rx);
auto const& lo = m_domain.smallEnd();
Expand Down Expand Up @@ -158,6 +160,8 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
template <typename T>
void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
{
BL_PROFILE("OpenBCSolver::solve");

auto& inmf = m_r2c.m_rx;
inmf.setVal(T(0));
inmf.ParallelCopy(rho, 0, 0, 1);
Expand Down
216 changes: 133 additions & 83 deletions Src/FFT/AMReX_FFT_R2C.H
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,12 @@ public:

void prepare_openbc ();

void backward_doit (MF& outmf, IntVect const& ngout = IntVect(0));

private:

static std::pair<Plan<T>,Plan<T>> make_c2c_plans (cMF& inout);

void backward_doit (MF& outmf, IntVect const& ngout = IntVect(0));

Plan<T> m_fft_fwd_x{};
Plan<T> m_fft_bwd_x{};
Plan<T> m_fft_fwd_y{};
Expand Down Expand Up @@ -214,6 +214,7 @@ private:

Info m_info;

bool m_do_alld_fft = false;
bool m_slab_decomp = false;
bool m_openbc_half = false;
};
Expand Down Expand Up @@ -259,10 +260,6 @@ R2C<T,D,S>::R2C (Box const& domain, Info const& info)
}
#endif

//
// make data containers
//

auto bax = amrex::decompose(m_real_domain, nprocs,
{AMREX_D_DECL(false,!m_slab_decomp,true)}, true);
DistributionMapping dmx = detail::make_iota_distromap(bax.size());
Expand All @@ -278,109 +275,138 @@ R2C<T,D,S>::R2C (Box const& domain, Info const& info)
m_cx.define(cbax, dmx, 1, 0, MFInfo().SetAlloc(false));
}

m_do_alld_fft = (ParallelDescriptor::NProcs() == 1) && (! m_info.batch_mode);

if (!m_do_alld_fft) // do a series of 1d or 2d ffts
{
//
// make data containers
//

#if (AMREX_SPACEDIM >= 2)
DistributionMapping cdmy;
if ((m_real_domain.length(1) > 1) && !m_slab_decomp) {
auto cbay = amrex::decompose(m_spectral_domain_y, nprocs,
{AMREX_D_DECL(false,true,true)}, true);
if (cbay.size() == dmx.size()) {
cdmy = dmx;
} else {
cdmy = detail::make_iota_distromap(cbay.size());
DistributionMapping cdmy;
if ((m_real_domain.length(1) > 1) && !m_slab_decomp) {
auto cbay = amrex::decompose(m_spectral_domain_y, nprocs,
{AMREX_D_DECL(false,true,true)}, true);
if (cbay.size() == dmx.size()) {
cdmy = dmx;
} else {
cdmy = detail::make_iota_distromap(cbay.size());
}
m_cy.define(cbay, cdmy, 1, 0, MFInfo().SetAlloc(false));
}
m_cy.define(cbay, cdmy, 1, 0, MFInfo().SetAlloc(false));
}
#endif

#if (AMREX_SPACEDIM == 3)
if (m_real_domain.length(1) > 1 &&
(! m_info.batch_mode && m_real_domain.length(2) > 1))
{
auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs,
{false,true,true}, true);
DistributionMapping cdmz;
if (cbaz.size() == dmx.size()) {
cdmz = dmx;
} else if (cbaz.size() == cdmy.size()) {
cdmz = cdmy;
} else {
cdmz = detail::make_iota_distromap(cbaz.size());
if (m_real_domain.length(1) > 1 &&
(! m_info.batch_mode && m_real_domain.length(2) > 1))
{
auto cbaz = amrex::decompose(m_spectral_domain_z, nprocs,
{false,true,true}, true);
DistributionMapping cdmz;
if (cbaz.size() == dmx.size()) {
cdmz = dmx;
} else if (cbaz.size() == cdmy.size()) {
cdmz = cdmy;
} else {
cdmz = detail::make_iota_distromap(cbaz.size());
}
m_cz.define(cbaz, cdmz, 1, 0, MFInfo().SetAlloc(false));
}
m_cz.define(cbaz, cdmz, 1, 0, MFInfo().SetAlloc(false));
}
#endif

if (m_slab_decomp) {
m_data_1 = detail::make_mfs_share(m_rx, m_cz);
m_data_2 = detail::make_mfs_share(m_cx, m_cx);
} else {
m_data_1 = detail::make_mfs_share(m_rx, m_cy);
m_data_2 = detail::make_mfs_share(m_cx, m_cz);
}
if (m_slab_decomp) {
m_data_1 = detail::make_mfs_share(m_rx, m_cz);
m_data_2 = detail::make_mfs_share(m_cx, m_cx);
} else {
m_data_1 = detail::make_mfs_share(m_rx, m_cy);
m_data_2 = detail::make_mfs_share(m_cx, m_cz);
}

//
// make copiers
//
//
// make copiers
//

#if (AMREX_SPACEDIM >= 2)
if (! m_cy.empty()) {
// comm meta-data between x and y phases
m_cmd_x2y = std::make_unique<MultiBlockCommMetaData>
(m_cy, m_spectral_domain_y, m_cx, IntVect(0), m_dtos_x2y);
m_cmd_y2x = std::make_unique<MultiBlockCommMetaData>
(m_cx, m_spectral_domain_x, m_cy, IntVect(0), m_dtos_y2x);
}
if (! m_cy.empty()) {
// comm meta-data between x and y phases
m_cmd_x2y = std::make_unique<MultiBlockCommMetaData>
(m_cy, m_spectral_domain_y, m_cx, IntVect(0), m_dtos_x2y);
m_cmd_y2x = std::make_unique<MultiBlockCommMetaData>
(m_cx, m_spectral_domain_x, m_cy, IntVect(0), m_dtos_y2x);
}
#endif
#if (AMREX_SPACEDIM == 3)
if (! m_cz.empty() ) {
if (m_slab_decomp) {
// comm meta-data between xy and z phases
m_cmd_x2z = std::make_unique<MultiBlockCommMetaData>
(m_cz, m_spectral_domain_z, m_cx, IntVect(0), m_dtos_x2z);
m_cmd_z2x = std::make_unique<MultiBlockCommMetaData>
(m_cx, m_spectral_domain_x, m_cz, IntVect(0), m_dtos_z2x);
} else {
// comm meta-data between y and z phases
m_cmd_y2z = std::make_unique<MultiBlockCommMetaData>
(m_cz, m_spectral_domain_z, m_cy, IntVect(0), m_dtos_y2z);
m_cmd_z2y = std::make_unique<MultiBlockCommMetaData>
(m_cy, m_spectral_domain_y, m_cz, IntVect(0), m_dtos_z2y);
if (! m_cz.empty() ) {
if (m_slab_decomp) {
// comm meta-data between xy and z phases
m_cmd_x2z = std::make_unique<MultiBlockCommMetaData>
(m_cz, m_spectral_domain_z, m_cx, IntVect(0), m_dtos_x2z);
m_cmd_z2x = std::make_unique<MultiBlockCommMetaData>
(m_cx, m_spectral_domain_x, m_cz, IntVect(0), m_dtos_z2x);
} else {
// comm meta-data between y and z phases
m_cmd_y2z = std::make_unique<MultiBlockCommMetaData>
(m_cz, m_spectral_domain_z, m_cy, IntVect(0), m_dtos_y2z);
m_cmd_z2y = std::make_unique<MultiBlockCommMetaData>
(m_cy, m_spectral_domain_y, m_cz, IntVect(0), m_dtos_z2y);
}
}
}
#endif

//
// make plans
//
//
// make plans
//

if (myproc < m_rx.size())
{
Box const& box = m_rx.box(myproc);
auto* pr = m_rx[myproc].dataPtr();
auto* pc = (typename Plan<T>::VendorComplex *)m_cx[myproc].dataPtr();
#ifdef AMREX_USE_SYCL
m_fft_fwd_x.template init_r2c<Direction::forward>(box, pr, pc, m_slab_decomp);
m_fft_bwd_x = m_fft_fwd_x;
#else
if constexpr (D == Direction::both || D == Direction::forward) {
m_fft_fwd_x.template init_r2c<Direction::forward>(box, pr, pc, m_slab_decomp);
}
if constexpr (D == Direction::both || D == Direction::backward) {
m_fft_bwd_x.template init_r2c<Direction::backward>(box, pr, pc, m_slab_decomp);
}
#endif
}

if (myproc < m_rx.size())
#if (AMREX_SPACEDIM >= 2)
if (! m_cy.empty()) {
std::tie(m_fft_fwd_y, m_fft_bwd_y) = make_c2c_plans(m_cy);
}
#endif
#if (AMREX_SPACEDIM == 3)
if (! m_cz.empty()) {
std::tie(m_fft_fwd_z, m_fft_bwd_z) = make_c2c_plans(m_cz);
}
#endif
}
else // do fft in all dimensions at the same time
{
Box const& box = m_rx.box(myproc);
auto* pr = m_rx[myproc].dataPtr();
auto* pc = (typename Plan<T>::VendorComplex *)m_cx[myproc].dataPtr();
m_data_1 = detail::make_mfs_share(m_rx, m_rx);
m_data_2 = detail::make_mfs_share(m_cx, m_cx);

auto const& len = m_real_domain.length();
auto* pr = (void*)m_rx[0].dataPtr();
auto* pc = (void*)m_cx[0].dataPtr();
#ifdef AMREX_USE_SYCL
m_fft_fwd_x.template init_r2c<Direction::forward>(box, pr, pc, m_slab_decomp);
m_fft_fwd_x.template init_r2c<Direction::forward>(len, pr, pc, false);
m_fft_bwd_x = m_fft_fwd_x;
#else
if constexpr (D == Direction::both || D == Direction::forward) {
m_fft_fwd_x.template init_r2c<Direction::forward>(box, pr, pc, m_slab_decomp);
m_fft_fwd_x.template init_r2c<Direction::forward>(len, pr, pc, false);
}
if constexpr (D == Direction::both || D == Direction::backward) {
m_fft_bwd_x.template init_r2c<Direction::backward>(box, pr, pc, m_slab_decomp);
m_fft_bwd_x.template init_r2c<Direction::backward>(len, pr, pc, false);
}
#endif
}

#if (AMREX_SPACEDIM >= 2)
if (! m_cy.empty()) {
std::tie(m_fft_fwd_y, m_fft_bwd_y) = make_c2c_plans(m_cy);
}
#endif
#if (AMREX_SPACEDIM == 3)
if (! m_cz.empty()) {
std::tie(m_fft_fwd_z, m_fft_bwd_z) = make_c2c_plans(m_cz);
}
#endif
}

template <typename T, Direction D, DomainStrategy S>
Expand Down Expand Up @@ -408,6 +434,8 @@ template <typename T, Direction D, DomainStrategy S>
void R2C<T,D,S>::prepare_openbc ()
{
#if (AMREX_SPACEDIM == 3)
if (m_do_alld_fft) { return; }

if (m_slab_decomp) {
auto* fab = detail::get_fab(m_rx);
if (fab) {
Expand Down Expand Up @@ -456,14 +484,20 @@ void R2C<T,D,S>::prepare_openbc ()

template <typename T, Direction D, DomainStrategy S>
template <Direction DIR, std::enable_if_t<DIR == Direction::forward ||
DIR == Direction::both, int> >
DIR == Direction::both, int> FOO>
void R2C<T,D,S>::forward (MF const& inmf)
{
BL_PROFILE("FFT::R2C::forward(in)");

if (&m_rx != &inmf) {
m_rx.ParallelCopy(inmf, 0, 0, 1);
}

if (m_do_alld_fft) {
m_fft_fwd_x.template compute_r2c<Direction::forward>();
return;
}

auto& fft_x = m_openbc_half ? m_fft_fwd_x_half : m_fft_fwd_x;
fft_x.template compute_r2c<Direction::forward>();

Expand Down Expand Up @@ -509,6 +543,22 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
{
BL_PROFILE("FFT::R2C::backward(out)");

if (m_do_alld_fft) {
m_fft_bwd_x.template compute_r2c<Direction::backward>();
// That we get here means there is only a single process.
auto const& src = m_rx[0].const_array();
auto const& dst = outmf.arrays();
ParallelFor(outmf, IntVect(0),
[=] AMREX_GPU_DEVICE (int b, int i, int j, int k)
{
if (src.contains(i,j,k)) {
dst[b](i,j,k) = src(i,j,k);
}
});
Gpu::streamSynchronize();
return;
}

m_fft_bwd_z.template compute_c2c<Direction::backward>();
if ( m_cmd_z2y) {
ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y);
Expand Down

0 comments on commit 8ef3bff

Please sign in to comment.