diff --git a/src/dft/backends/cufft/descriptor.cpp b/src/dft/backends/cufft/descriptor.cpp index d102164c2..bf26b600f 100644 --- a/src/dft/backends/cufft/descriptor.cpp +++ b/src/dft/backends/cufft/descriptor.cpp @@ -22,9 +22,7 @@ #include "oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp" -namespace oneapi { -namespace mkl { -namespace dft { +namespace oneapi::mkl::dft::detail { template void descriptor::commit(backend_selector selector) { @@ -44,6 +42,4 @@ template void descriptor::commit( backend_selector); template void descriptor::commit(backend_selector); -} //namespace dft -} //namespace mkl -} //namespace oneapi +} //namespace oneapi::mkl::dft::detail diff --git a/src/dft/backends/portfft/descriptor.cpp b/src/dft/backends/portfft/descriptor.cpp index d72d23bb5..c45b9f2c5 100644 --- a/src/dft/backends/portfft/descriptor.cpp +++ b/src/dft/backends/portfft/descriptor.cpp @@ -22,7 +22,7 @@ #include "oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp" -namespace oneapi::mkl::dft { +namespace oneapi::mkl::dft::detail { template void descriptor::commit(backend_selector selector) { @@ -44,4 +44,4 @@ template void descriptor::commit( template void descriptor::commit( backend_selector); -} // namespace oneapi::mkl::dft +} // namespace oneapi::mkl::dft::detail diff --git a/src/dft/backends/rocfft/commit.cpp b/src/dft/backends/rocfft/commit.cpp index 74f480720..991b0c471 100644 --- a/src/dft/backends/rocfft/commit.cpp +++ b/src/dft/backends/rocfft/commit.cpp @@ -39,6 +39,7 @@ #include "rocfft_handle.hpp" #include +#include #include namespace oneapi::mkl::dft::rocfft { @@ -259,12 +260,32 @@ class rocfft_commit final : public dft::detail::commit_impl { std::reverse(stride_vecs.vec_b.begin(), stride_vecs.vec_b.end()); stride_vecs.vec_b.pop_back(); // Offset is not included. - rocfft_plan_description plan_desc; - if (rocfft_plan_description_create(&plan_desc) != rocfft_status_success) { + // This workaround is needed due to a confirmed issue in rocFFT from version + // 1.0.23 to 1.0.30. Those rocFFT version correspond to rocm version from + // 5.6.0 to 6.3.0. + // Link to rocFFT issue: https://github.com/ROCm/rocFFT/issues/507 + if constexpr (rocfft_version_major == 1 && rocfft_version_minor == 0 && + (rocfft_version_patch > 22 && rocfft_version_patch < 31)) { + // rocFFT's functional status for problems like cfoA:B:1xB:1:A is unknown as + // of 4ed3e97bb7c11531684168665d5a980fde0284c9 (due to project's implementation preventing testing thereof) + if (dom == dft::domain::COMPLEX && + config_values.placement == dft::config_value::NOT_INPLACE && dimensions > 2) { + if (stride_vecs.vec_a != stride_vecs.vec_b) + throw oneapi::mkl::unimplemented( + "DFT", func, + "due to a bug in rocfft version in use, it requires fwd and bwd stride to be the same for COMPLEX out_of_place computations"); + } + } + + rocfft_plan_description plan_desc_fwd, plan_desc_bwd; // Can't reuse with ROCm 6 due to bug. + if (rocfft_plan_description_create(&plan_desc_fwd) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create plan description."); + } + if (rocfft_plan_description_create(&plan_desc_bwd) != rocfft_status_success) { throw mkl::exception("dft/backends/rocfft", __FUNCTION__, "Failed to create plan description."); } - // plan_description can be destroyed afted plan_create auto description_destroy = [](rocfft_plan_description p) { if (rocfft_plan_description_destroy(p) != rocfft_status_success) { @@ -273,7 +294,9 @@ class rocfft_commit final : public dft::detail::commit_impl { } }; std::unique_ptr - description_destroyer(plan_desc, description_destroy); + description_destroyer_fwd(plan_desc_fwd, description_destroy); + std::unique_ptr + description_destroyer_bwd(plan_desc_bwd, description_destroy); std::array stride_a_indices{ 0, 1, 2 }; std::sort(&stride_a_indices[0], &stride_a_indices[dimensions], @@ -324,7 +347,7 @@ class rocfft_commit final : public dft::detail::commit_impl { if (valid_forward) { auto res = - rocfft_plan_description_set_data_layout(plan_desc, fwd_array_ty, bwd_array_ty, + rocfft_plan_description_set_data_layout(plan_desc_fwd, fwd_array_ty, bwd_array_ty, nullptr, // in offsets nullptr, // out offsets dimensions, @@ -339,7 +362,7 @@ class rocfft_commit final : public dft::detail::commit_impl { "Failed to set forward data layout."); } - if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.fwd_scale) != + if (rocfft_plan_description_set_scale_factor(plan_desc_fwd, config_values.fwd_scale) != rocfft_status_success) { throw mkl::exception("dft/backends/rocfft", __FUNCTION__, "Failed to set forward scale factor."); @@ -347,7 +370,7 @@ class rocfft_commit final : public dft::detail::commit_impl { rocfft_plan fwd_plan; res = rocfft_plan_create(&fwd_plan, placement, fwd_type, precision, dimensions, - lengths.data(), number_of_transforms, plan_desc); + lengths.data(), number_of_transforms, plan_desc_fwd); if (res != rocfft_status_success) { throw mkl::exception("dft/backends/rocfft", __FUNCTION__, @@ -380,7 +403,7 @@ class rocfft_commit final : public dft::detail::commit_impl { if (valid_backward) { auto res = - rocfft_plan_description_set_data_layout(plan_desc, bwd_array_ty, fwd_array_ty, + rocfft_plan_description_set_data_layout(plan_desc_bwd, bwd_array_ty, fwd_array_ty, nullptr, // in offsets nullptr, // out offsets dimensions, @@ -395,7 +418,7 @@ class rocfft_commit final : public dft::detail::commit_impl { "Failed to set backward data layout."); } - if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.bwd_scale) != + if (rocfft_plan_description_set_scale_factor(plan_desc_bwd, config_values.bwd_scale) != rocfft_status_success) { throw mkl::exception("dft/backends/rocfft", __FUNCTION__, "Failed to set backward scale factor."); @@ -403,7 +426,7 @@ class rocfft_commit final : public dft::detail::commit_impl { rocfft_plan bwd_plan; res = rocfft_plan_create(&bwd_plan, placement, bwd_type, precision, dimensions, - lengths.data(), number_of_transforms, plan_desc); + lengths.data(), number_of_transforms, plan_desc_bwd); if (res != rocfft_status_success) { throw mkl::exception("dft/backends/rocfft", __FUNCTION__, "Failed to create backward rocFFT plan."); diff --git a/src/dft/backends/rocfft/descriptor.cpp b/src/dft/backends/rocfft/descriptor.cpp index 83fdbe1dc..22f21590a 100644 --- a/src/dft/backends/rocfft/descriptor.cpp +++ b/src/dft/backends/rocfft/descriptor.cpp @@ -22,9 +22,7 @@ #include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" -namespace oneapi { -namespace mkl { -namespace dft { +namespace oneapi::mkl::dft::detail { template void descriptor::commit(backend_selector selector) { @@ -46,6 +44,4 @@ template void descriptor::commit( template void descriptor::commit( backend_selector); -} //namespace dft -} //namespace mkl -} //namespace oneapi +} //namespace oneapi::mkl::dft::detail