diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index be3c76ea4c5..4ffec6add60 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -43,6 +43,8 @@ namespace amrex::FFT enum struct Direction { forward, backward, both, none }; +enum struct DomainStrategy { slab, pencil }; + AMREX_ENUM( Boundary, periodic, even, odd ); enum struct Kind { none, r2c_f, r2c_b, c2c_f, c2c_b, r2r_ee_f, r2r_ee_b, @@ -172,18 +174,29 @@ struct Plan } template - void init_r2c (Box const& box, T* pr, VendorComplex* pc) + void init_r2c (Box const& box, T* pr, VendorComplex* pc, bool is_2d_transform = false) { static_assert(D == Direction::forward || D == Direction::backward); + int rank = is_2d_transform ? 2 : 1; + kind = (D == Direction::forward) ? Kind::r2c_f : Kind::r2c_b; defined = true; pf = (void*)pr; pb = (void*)pc; - n = box.length(0); - int nc = (n/2) + 1; - howmany = AMREX_D_TERM(1, *box.length(1), *box.length(2)); + int len[2] = {}; + if (rank == 1) { + len[0] = box.length(0); + } else { + len[0] = box.length(1); + len[1] = box.length(0); + } + int nr = (rank == 1) ? len[0] : len[0]*len[1]; + n = nr; + int nc = (rank == 1) ? (len[0]/2+1) : (len[1]/2+1)*len[0]; + howmany = (rank == 1) ? AMREX_D_TERM(1, *box.length(1), *box.length(2)) + : AMREX_D_TERM(1, *1 , *box.length(2)); amrex::ignore_unused(nc); @@ -193,33 +206,38 @@ struct Plan if constexpr (D == Direction::forward) { cufftType fwd_type = std::is_same_v ? CUFFT_R2C : CUFFT_D2Z; AMREX_CUFFT_SAFE_CALL - (cufftMakePlanMany(plan, 1, &n, nullptr, 1, n, nullptr, 1, nc, fwd_type, howmany, &work_size)); + (cufftMakePlanMany(plan, rank, len, nullptr, 1, nr, nullptr, 1, nc, fwd_type, howmany, &work_size)); AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); } else { cufftType bwd_type = std::is_same_v ? CUFFT_C2R : CUFFT_Z2D; AMREX_CUFFT_SAFE_CALL - (cufftMakePlanMany(plan, 1, &n, nullptr, 1, nc, nullptr, 1, n, bwd_type, howmany, &work_size)); + (cufftMakePlanMany(plan, rank, len, nullptr, 1, nc, nullptr, 1, nr, bwd_type, howmany, &work_size)); AMREX_CUFFT_SAFE_CALL(cufftSetStream(plan, Gpu::gpuStream())); } #elif defined(AMREX_USE_HIP) auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; - const std::size_t length = n; + const std::size_t length[2] = {std::size_t(len[0]), std::size_t(len[1])}; if constexpr (D == Direction::forward) { AMREX_ROCFFT_SAFE_CALL (rocfft_plan_create(&plan, rocfft_placement_notinplace, - rocfft_transform_type_real_forward, prec, 1, - &length, howmany, nullptr)); + rocfft_transform_type_real_forward, prec, rank, + length, howmany, nullptr)); } else { AMREX_ROCFFT_SAFE_CALL (rocfft_plan_create(&plan, rocfft_placement_notinplace, - rocfft_transform_type_real_inverse, prec, 1, - &length, howmany, nullptr)); + rocfft_transform_type_real_inverse, prec, rank, + length, howmany, nullptr)); } #elif defined(AMREX_USE_SYCL) - auto* pp = new mkl_desc_r(n); + mkl_desc_c* pp; + if (rank == 1) { + pp = new mkl_desc_r(len[0]); + } else { + pp = new mkl_desc_r({std::int64_t(len[0]), std::int64_t(len[1])}); + } #ifndef AMREX_USE_MKL_DFTI_2024 pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); @@ -227,9 +245,12 @@ struct Plan pp->set_value(oneapi::mkl::dft::config_param::PLACEMENT, DFTI_NOT_INPLACE); #endif pp->set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, howmany); - pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, n); + pp->set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, nr); pp->set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, nc); - std::vector strides = {0,1}; + std::vector strides; + strides.push_back(0); + if (rank == 2) { strides.push_back(len[1]); } + rank.push_back(1); #ifndef AMREX_USE_MKL_DFTI_2024 pp->set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides); pp->set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides); @@ -247,21 +268,21 @@ struct Plan if constexpr (std::is_same_v) { if constexpr (D == Direction::forward) { plan = fftwf_plan_many_dft_r2c - (1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc, + (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } else { plan = fftwf_plan_many_dft_c2r - (1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n, + (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } } else { if constexpr (D == Direction::forward) { plan = fftw_plan_many_dft_r2c - (1, &n, howmany, pr, nullptr, 1, n, pc, nullptr, 1, nc, + (rank, len, howmany, pr, nullptr, 1, nr, pc, nullptr, 1, nc, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } else { plan = fftw_plan_many_dft_c2r - (1, &n, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, n, + (rank, len, howmany, pc, nullptr, 1, nc, pr, nullptr, 1, nr, FFTW_ESTIMATE | FFTW_DESTROY_INPUT); } } @@ -1087,13 +1108,17 @@ namespace detail template std::unique_ptr make_mfs_share (FA1& fa1, FA2& fa2) { + bool not_same_fa = true; + if constexpr (std::is_same_v) { + not_same_fa = (&fa1 != &fa2); + } using FAB1 = typename FA1::FABType::value_type; using FAB2 = typename FA2::FABType::value_type; using T1 = typename FAB1::value_type; using T2 = typename FAB2::value_type; auto myproc = ParallelContext::MyProcSub(); bool alloc_1 = (myproc < fa1.size()); - bool alloc_2 = (myproc < fa2.size()); + bool alloc_2 = (myproc < fa2.size()) && not_same_fa; void* p = nullptr; if (alloc_1 && alloc_2) { Box const& box1 = fa1.fabbox(myproc); diff --git a/Src/FFT/AMReX_FFT_Poisson.H b/Src/FFT/AMReX_FFT_Poisson.H index c50f26f66eb..bc521535cb8 100644 --- a/Src/FFT/AMReX_FFT_Poisson.H +++ b/Src/FFT/AMReX_FFT_Poisson.H @@ -19,18 +19,33 @@ public: template ,int> = 0> Poisson (Geometry const& geom, Array,AMREX_SPACEDIM> const& bc) - : m_geom(geom), m_bc(bc), m_r2x(geom.Domain(),bc) - {} + : m_geom(geom), m_bc(bc) + { + bool all_periodic = true; + for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) { + all_periodic = all_periodic + && (bc[idim].first == Boundary::periodic) + && (bc[idim].second == Boundary::periodic); + } + if (all_periodic) { + m_r2c = std::make_unique>(m_geom.Domain()); + } else { + m_r2x = std::make_unique> (m_geom.Domain(), m_bc); + } + } template ,int> = 0> explicit Poisson (Geometry const& geom) : m_geom(geom), m_bc{AMREX_D_DECL(std::make_pair(Boundary::periodic,Boundary::periodic), std::make_pair(Boundary::periodic,Boundary::periodic), - std::make_pair(Boundary::periodic,Boundary::periodic))}, - m_r2x(geom.Domain(),m_bc) + std::make_pair(Boundary::periodic,Boundary::periodic))} { - AMREX_ALWAYS_ASSERT(m_geom.isAllPeriodic()); + if (m_geom.isAllPeriodic()) { + m_r2c = std::make_unique>(m_geom.Domain()); + } else { + amrex::Abort("FFT::Poisson: wrong BC"); + } } void solve (MF& soln, MF const& rhs); @@ -38,7 +53,8 @@ public: private: Geometry m_geom; Array,AMREX_SPACEDIM> m_bc; - R2X m_r2x; + std::unique_ptr> m_r2x; + std::unique_ptr> m_r2c; }; #if (AMREX_SPACEDIM == 3) @@ -114,7 +130,7 @@ void Poisson::solve (MF& soln, MF const& rhs) {AMREX_D_DECL(T(2)/T(m_geom.CellSize(0)*m_geom.CellSize(0)), T(2)/T(m_geom.CellSize(1)*m_geom.CellSize(1)), T(2)/T(m_geom.CellSize(2)*m_geom.CellSize(2)))}; - auto scale = m_r2x.scalingFactor(); + auto scale = (m_r2x) ? m_r2x->scalingFactor() : T(1)/T(m_geom.Domain().numPts()); GpuArray offset{AMREX_D_DECL(T(0),T(0),T(0))}; // Not sure about odd-even and even-odd yet @@ -133,8 +149,7 @@ void Poisson::solve (MF& soln, MF const& rhs) } } - m_r2x.forwardThenBackward(rhs, soln, - [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data) + auto f = [=] AMREX_GPU_DEVICE (int i, int j, int k, auto& spectral_data) { amrex::ignore_unused(j,k); AMREX_D_TERM(T a = fac[0]*(i+offset[0]);, @@ -147,7 +162,13 @@ void Poisson::solve (MF& soln, MF const& rhs) spectral_data /= k2; } spectral_data *= scale; - }); + }; + + if (m_r2x) { + m_r2x->forwardThenBackward(rhs, soln, f); + } else { + m_r2c->forwardThenBackward(rhs, soln, f); + } } #if (AMREX_SPACEDIM == 3) diff --git a/Src/FFT/AMReX_FFT_R2C.H b/Src/FFT/AMReX_FFT_R2C.H index 5969c6a789b..cdc9605a264 100644 --- a/Src/FFT/AMReX_FFT_R2C.H +++ b/Src/FFT/AMReX_FFT_R2C.H @@ -26,7 +26,8 @@ template class OpenBCSolver; * For more details, we refer the users to * https://amrex-codes.github.io/amrex/docs_html/FFT_Chapter.html. */ -template +template class R2C { public: @@ -176,18 +177,22 @@ private: std::unique_ptr m_cmd_y2x; // (y,x,z) -> (x,y,z) std::unique_ptr m_cmd_y2z; // (y,x,z) -> (z,x,y) std::unique_ptr m_cmd_z2y; // (z,x,y) -> (y,x,z) + std::unique_ptr m_cmd_x2z; // (x,y,z) -> (z,x,y) + std::unique_ptr m_cmd_z2x; // (z,x,y) -> (x,y,z) Swap01 m_dtos_x2y{}; Swap01 m_dtos_y2x{}; Swap02 m_dtos_y2z{}; Swap02 m_dtos_z2y{}; + RotateFwd m_dtos_x2z{}; + RotateBwd m_dtos_z2x{}; MF m_rx; cMF m_cx; cMF m_cy; cMF m_cz; - std::unique_ptr m_data_rx_cy; - std::unique_ptr m_data_cx_cz; + std::unique_ptr m_data_1; + std::unique_ptr m_data_2; Box m_real_domain; Box m_spectral_domain_x; @@ -195,10 +200,12 @@ private: Box m_spectral_domain_z; Info m_info; + + bool m_slab_decomp = false; }; -template -R2C::R2C (Box const& domain, Info const& info) +template +R2C::R2C (Box const& domain, Info const& info) : m_real_domain(domain), m_spectral_domain_x(IntVect(0), IntVect(AMREX_D_DECL(domain.length(0)/2, domain.length(1)-1, @@ -232,12 +239,18 @@ R2C::R2C (Box const& domain, Info const& info) int myproc = ParallelContext::MyProcSub(); int nprocs = ParallelContext::NProcsSub(); +#if (AMREX_SPACEDIM == 3) + if (S == DomainStrategy::slab && (m_real_domain.length(1) > 1)) { + m_slab_decomp = true; + } +#endif + // // make data containers // auto bax = amrex::decompose(m_real_domain, nprocs, - {AMREX_D_DECL(false,true,true)}, true); + {AMREX_D_DECL(false,!m_slab_decomp,true)}, true); DistributionMapping dmx = detail::make_iota_distromap(bax.size()); m_rx.define(bax, dmx, 1, 0, MFInfo().SetAlloc(false)); @@ -253,7 +266,7 @@ R2C::R2C (Box const& domain, Info const& info) #if (AMREX_SPACEDIM >= 2) DistributionMapping cdmy; - if (m_real_domain.length(1) > 1) { + if ((m_real_domain.length(1) > 1) && !m_slab_decomp) { auto cbay = amrex::decompose(m_spectral_domain_y, nprocs, {AMREX_D_DECL(false,true,true)}, true); if (cbay.size() == dmx.size()) { @@ -283,8 +296,13 @@ R2C::R2C (Box const& domain, Info const& info) } #endif - m_data_rx_cy = detail::make_mfs_share(m_rx, m_cy); - m_data_cx_cz = detail::make_mfs_share(m_cx, m_cz); + if (m_slab_decomp) { + m_data_1 = detail::make_mfs_share(m_rx, m_cz); + m_data_2 = detail::make_mfs_share(m_cx, m_cx); + } else { + m_data_1 = detail::make_mfs_share(m_rx, m_cy); + m_data_2 = detail::make_mfs_share(m_cx, m_cz); + } // // make copiers @@ -300,12 +318,20 @@ R2C::R2C (Box const& domain, Info const& info) } #endif #if (AMREX_SPACEDIM == 3) - if (! m_cz.empty()) { - // comm meta-data between y and z phases - m_cmd_y2z = std::make_unique - (m_cz, m_spectral_domain_z, m_cy, IntVect(0), m_dtos_y2z); - m_cmd_z2y = std::make_unique - (m_cy, m_spectral_domain_y, m_cz, IntVect(0), m_dtos_z2y); + if (! m_cz.empty() ) { + if (m_slab_decomp) { + // comm meta-data between xy and z phases + m_cmd_x2z = std::make_unique + (m_cz, m_spectral_domain_z, m_cx, IntVect(0), m_dtos_x2z); + m_cmd_z2x = std::make_unique + (m_cx, m_spectral_domain_x, m_cz, IntVect(0), m_dtos_z2x); + } else { + // comm meta-data between y and z phases + m_cmd_y2z = std::make_unique + (m_cz, m_spectral_domain_z, m_cy, IntVect(0), m_dtos_y2z); + m_cmd_z2y = std::make_unique + (m_cy, m_spectral_domain_y, m_cz, IntVect(0), m_dtos_z2y); + } } #endif @@ -319,14 +345,14 @@ R2C::R2C (Box const& domain, Info const& info) auto* pr = m_rx[myproc].dataPtr(); auto* pc = (typename Plan::VendorComplex *)m_cx[myproc].dataPtr(); #ifdef AMREX_USE_SYCL - m_fft_fwd_x.template init_r2c(box, pr, pc); + m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp); m_fft_bwd_x = m_fft_fwd_x; #else if constexpr (D == Direction::both || D == Direction::forward) { - m_fft_fwd_x.template init_r2c(box, pr, pc); + m_fft_fwd_x.template init_r2c(box, pr, pc, m_slab_decomp); } if constexpr (D == Direction::both || D == Direction::backward) { - m_fft_bwd_x.template init_r2c(box, pr, pc); + m_fft_bwd_x.template init_r2c(box, pr, pc, m_slab_decomp); } #endif } @@ -343,8 +369,8 @@ R2C::R2C (Box const& domain, Info const& info) #endif } -template -R2C::~R2C () +template +R2C::~R2C () { if (m_fft_bwd_x.plan != m_fft_fwd_x.plan) { m_fft_bwd_x.destroy(); @@ -360,10 +386,10 @@ R2C::~R2C () m_fft_fwd_z.destroy(); } -template +template template > -void R2C::forward (MF const& inmf) +void R2C::forward (MF const& inmf) { BL_PROFILE("FFT::R2C::forward(in)"); @@ -380,18 +406,23 @@ void R2C::forward (MF const& inmf) if ( m_cmd_y2z) { ParallelCopy(m_cz, m_cy, *m_cmd_y2z, 0, 0, 1, m_dtos_y2z); } +#if (AMREX_SPACEDIM == 3) + else if ( m_cmd_x2z) { + ParallelCopy(m_cz, m_cx, *m_cmd_x2z, 0, 0, 1, m_dtos_x2z); + } +#endif m_fft_fwd_z.template compute_c2c(); } -template +template template > -void R2C::backward (MF& outmf) +void R2C::backward (MF& outmf) { backward_doit(outmf); } -template -void R2C::backward_doit (MF& outmf, IntVect const& ngout) +template +void R2C::backward_doit (MF& outmf, IntVect const& ngout) { BL_PROFILE("FFT::R2C::backward(out)"); @@ -399,6 +430,11 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout) if ( m_cmd_z2y) { ParallelCopy(m_cy, m_cz, *m_cmd_z2y, 0, 0, 1, m_dtos_z2y); } +#if (AMREX_SPACEDIM == 3) + else if ( m_cmd_z2x) { + ParallelCopy(m_cx, m_cz, *m_cmd_z2x, 0, 0, 1, m_dtos_z2x); + } +#endif m_fft_bwd_y.template compute_c2c(); if ( m_cmd_y2x) { @@ -409,9 +445,9 @@ void R2C::backward_doit (MF& outmf, IntVect const& ngout) outmf.ParallelCopy(m_rx, 0, 0, 1, IntVect(0), ngout); } -template +template std::pair, Plan> -R2C::make_c2c_plans (cMF& inout) +R2C::make_c2c_plans (cMF& inout) { Plan fwd; Plan bwd; @@ -437,9 +473,9 @@ R2C::make_c2c_plans (cMF& inout) return {fwd, bwd}; } -template +template template -void R2C::post_forward_doit (F const& post_forward) +void R2C::post_forward_doit (F const& post_forward) { if (m_info.batch_mode) { amrex::Abort("xxxxx todo: post_forward"); @@ -478,11 +514,11 @@ void R2C::post_forward_doit (F const& post_forward) } } -template +template template > -std::pair::cMF *, IntVect> -R2C::getSpectralData () +std::pair::cMF *, IntVect> +R2C::getSpectralData () { if (!m_cz.empty()) { return std::make_pair(&m_cz, IntVect{AMREX_D_DECL(2,0,1)}); @@ -493,10 +529,10 @@ R2C::getSpectralData () } } -template +template template > -void R2C::forward (MF const& inmf, cMF& outmf) +void R2C::forward (MF const& inmf, cMF& outmf) { BL_PROFILE("FFT::R2C::forward(inout)"); @@ -515,10 +551,10 @@ void R2C::forward (MF const& inmf, cMF& outmf) } } -template +template template > -void R2C::backward (cMF const& inmf, MF& outmf) +void R2C::backward (cMF const& inmf, MF& outmf) { BL_PROFILE("FFT::R2C::backward(inout)"); @@ -537,9 +573,9 @@ void R2C::backward (cMF const& inmf, MF& outmf) backward_doit(outmf); } -template +template std::pair -R2C::getSpectralDataLayout () const +R2C::getSpectralDataLayout () const { #if (AMREX_SPACEDIM == 3) if (!m_cz.empty()) { diff --git a/Tests/FFT/R2C/main.cpp b/Tests/FFT/R2C/main.cpp index caa457a7f56..594a9ec760d 100644 --- a/Tests/FFT/R2C/main.cpp +++ b/Tests/FFT/R2C/main.cpp @@ -74,7 +74,7 @@ int main (int argc, char* argv[]) // forward { - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c(geom.Domain()); auto const& [cba, cdm] = r2c.getSpectralDataLayout(); cmf.define(cba, cdm, 1, 0); r2c.forward(mf,cmf); @@ -82,7 +82,7 @@ int main (int argc, char* argv[]) // backward { - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c(geom.Domain()); r2c.backward(cmf,mf2); } @@ -105,7 +105,7 @@ int main (int argc, char* argv[]) mf2.setVal(std::numeric_limits::max()); { // forward and backward - FFT::R2C r2c(geom.Domain()); + FFT::R2C r2c(geom.Domain()); r2c.forwardThenBackward(mf, mf2, [=] AMREX_GPU_DEVICE (int, int, int, auto& sp) {