diff --git a/Src/FFT/AMReX_FFT_Helper.H b/Src/FFT/AMReX_FFT_Helper.H index 6c7f1bb6c1e..af45b00c9ab 100644 --- a/Src/FFT/AMReX_FFT_Helper.H +++ b/Src/FFT/AMReX_FFT_Helper.H @@ -188,9 +188,10 @@ struct Plan int len[2] = {}; if (rank == 1) { len[0] = box.length(0); + len[1] = box.length(0); // Not used except for HIP. Yes it's `(0)`. } else { - len[0] = box.length(1); - len[1] = box.length(0); + len[0] = box.length(1); // Most FFT libraries assume row-major ordering + len[1] = box.length(0); // except for rocfft } int nr = (rank == 1) ? len[0] : len[0]*len[1]; n = nr; @@ -221,7 +222,8 @@ struct Plan #elif defined(AMREX_USE_HIP) auto prec = std::is_same_v ? rocfft_precision_single : rocfft_precision_double; - const std::size_t length[2] = {std::size_t(len[0]), std::size_t(len[1])}; + // switch to column-major ordering + std::size_t length[2] = {std::size_t(len[1]), std::size_t(len[0])}; if constexpr (D == Direction::forward) { AMREX_ROCFFT_SAFE_CALL (rocfft_plan_create(&plan, rocfft_placement_notinplace,