Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT OpenBC Solver: Communication optimization #4230

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Src/Base/AMReX_NonLocalBC.H
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ unpack_recv_buffer_gpu (FabArray<FAB>& mf, int scomp, int ncomp,
struct PackComponents {
int dest_component{0};
int src_component{0};
int n_components{0};
int n_components{1};
};

//! \brief Dispatch local copies to the default behaviour that knows no DTOS nor projection.
Expand Down
4 changes: 4 additions & 0 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
inmf.setVal(T(0));
inmf.ParallelCopy(rho, 0, 0, 1);

m_r2c.m_openbc_half = true;
m_r2c.forward(inmf);
m_r2c.m_openbc_half = false;

auto scaling_factor = m_r2c.scalingFactor();

Expand Down Expand Up @@ -199,7 +201,9 @@ void OpenBCSolver<T>::solve (MF& phi, MF const& rho)
}
}

m_r2c.m_openbc_half = true;
m_r2c.backward_doit(phi, phi.nGrowVect());
m_r2c.m_openbc_half = false;
}

}
Expand Down
39 changes: 37 additions & 2 deletions Src/FFT/AMReX_FFT_R2C.H
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,8 @@ private:
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2y; // (z,x,y) -> (y,x,z)
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2z; // (x,y,z) -> (z,x,y)
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2x; // (z,x,y) -> (x,y,z)
std::unique_ptr<MultiBlockCommMetaData> m_cmd_x2z_half; // for openbc
std::unique_ptr<MultiBlockCommMetaData> m_cmd_z2x_half; // for openbc
Swap01 m_dtos_x2y{};
Swap01 m_dtos_y2x{};
Swap02 m_dtos_y2z{};
Expand All @@ -209,6 +211,7 @@ private:
Info m_info;

bool m_slab_decomp = false;
bool m_openbc_half = false;
};

template <typename T, Direction D, DomainStrategy S>
Expand Down Expand Up @@ -415,7 +418,24 @@ void R2C<T,D,S>::forward (MF const& inmf)
}
#if (AMREX_SPACEDIM == 3)
else if ( m_cmd_x2z) {
ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z);
if (m_openbc_half) {
Box upper_half = m_spectral_domain_z;
// Note that the z-direction's index is 0 because z is the unit-stride direction here.
upper_half.growLo (0,-m_spectral_domain_z.length(0)/2);
if (! m_cmd_x2z_half) {
Box bottom_half = m_spectral_domain_z;
bottom_half.growHi(0,-m_spectral_domain_z.length(0)/2);
m_cmd_x2z_half = std::make_unique<MultiBlockCommMetaData>
(m_cz, bottom_half, m_cx, IntVect(0), m_dtos_x2z);
}
NonLocalBC::ApplyDtosAndProjectionOnReciever packing
{NonLocalBC::PackComponents{}, m_dtos_x2z};
auto handler = ParallelCopy_nowait(m_cz, m_cx, *m_cmd_x2z_half, packing);
m_cz.setVal(0, upper_half, 0, 1);
ParallelCopy_finish(m_cz, std::move(handler), *m_cmd_x2z_half, packing);
} else {
ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z);
}
}
#endif
m_fft_fwd_z.template compute_c2c<Direction::forward>();
Expand All @@ -439,7 +459,22 @@ void R2C<T,D,S>::backward_doit (MF& outmf, IntVect const& ngout)
}
#if (AMREX_SPACEDIM == 3)
else if ( m_cmd_z2x) {
ParallelCopy(m_cx, m_cz, *m_cmd_z2x, 0, 0, 1, m_dtos_z2x);
if (m_openbc_half) {
Box upper_half = m_spectral_domain_x;
upper_half.growLo (2,-m_spectral_domain_x.length(2)/2);
if (! m_cmd_z2x_half) {
Box bottom_half = m_spectral_domain_x;
bottom_half.growHi(2,-m_spectral_domain_x.length(2)/2);
m_cmd_z2x_half = std::make_unique<MultiBlockCommMetaData>
(m_cx, bottom_half, m_cz, IntVect(0), m_dtos_z2x);
}
NonLocalBC::ApplyDtosAndProjectionOnReciever packing
{NonLocalBC::PackComponents{}, m_dtos_z2x};
auto handler = ParallelCopy_nowait(m_cx, m_cz, *m_cmd_z2x_half, packing);
ParallelCopy_finish(m_cx, std::move(handler), *m_cmd_z2x_half, packing);
} else {
ParallelCopy(m_cx, m_cz, *m_cmd_z2x, 0, 0, 1, m_dtos_z2x);
}
}
#endif

Expand Down
Loading