Skip to content

Commit

Permalink
[DFT][rocFFt] Address rocFFT failing tests (#563)
Browse files Browse the repository at this point in the history
Co-authored-by: Hugh Bird <[email protected]>
  • Loading branch information
s-Nick and hjabird authored Sep 30, 2024
1 parent 4ed3e97 commit 00a9af7
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 24 deletions.
8 changes: 2 additions & 6 deletions src/dft/backends/cufft/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::cufft> selector) {
Expand All @@ -44,6 +42,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
backend_selector<backend::cufft>);
template void descriptor<precision::DOUBLE, domain::REAL>::commit(backend_selector<backend::cufft>);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
4 changes: 2 additions & 2 deletions src/dft/backends/portfft/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

#include "oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp"

namespace oneapi::mkl::dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::portfft> selector) {
Expand All @@ -44,4 +44,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
template void descriptor<precision::DOUBLE, domain::REAL>::commit(
backend_selector<backend::portfft>);

} // namespace oneapi::mkl::dft
} // namespace oneapi::mkl::dft::detail
43 changes: 33 additions & 10 deletions src/dft/backends/rocfft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "rocfft_handle.hpp"

#include <rocfft.h>
#include <rocfft-version.h>
#include <hip/hip_runtime_api.h>

namespace oneapi::mkl::dft::rocfft {
Expand Down Expand Up @@ -259,12 +260,32 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
std::reverse(stride_vecs.vec_b.begin(), stride_vecs.vec_b.end());
stride_vecs.vec_b.pop_back(); // Offset is not included.

rocfft_plan_description plan_desc;
if (rocfft_plan_description_create(&plan_desc) != rocfft_status_success) {
// This workaround is needed due to a confirmed issue in rocFFT from version
// 1.0.23 to 1.0.30. Those rocFFT version correspond to rocm version from
// 5.6.0 to 6.3.0.
// Link to rocFFT issue: https://github.com/ROCm/rocFFT/issues/507
if constexpr (rocfft_version_major == 1 && rocfft_version_minor == 0 &&
(rocfft_version_patch > 22 && rocfft_version_patch < 31)) {
// rocFFT's functional status for problems like cfoA:B:1xB:1:A is unknown as
// of 4ed3e97bb7c11531684168665d5a980fde0284c9 (due to project's implementation preventing testing thereof)
if (dom == dft::domain::COMPLEX &&
config_values.placement == dft::config_value::NOT_INPLACE && dimensions > 2) {
if (stride_vecs.vec_a != stride_vecs.vec_b)
throw oneapi::mkl::unimplemented(
"DFT", func,
"due to a bug in rocfft version in use, it requires fwd and bwd stride to be the same for COMPLEX out_of_place computations");
}
}

rocfft_plan_description plan_desc_fwd, plan_desc_bwd; // Can't reuse with ROCm 6 due to bug.
if (rocfft_plan_description_create(&plan_desc_fwd) != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to create plan description.");
}
if (rocfft_plan_description_create(&plan_desc_bwd) != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to create plan description.");
}

// plan_description can be destroyed afted plan_create
auto description_destroy = [](rocfft_plan_description p) {
if (rocfft_plan_description_destroy(p) != rocfft_status_success) {
Expand All @@ -273,7 +294,9 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
}
};
std::unique_ptr<rocfft_plan_description_t, decltype(description_destroy)>
description_destroyer(plan_desc, description_destroy);
description_destroyer_fwd(plan_desc_fwd, description_destroy);
std::unique_ptr<rocfft_plan_description_t, decltype(description_destroy)>
description_destroyer_bwd(plan_desc_bwd, description_destroy);

std::array<std::size_t, 3> stride_a_indices{ 0, 1, 2 };
std::sort(&stride_a_indices[0], &stride_a_indices[dimensions],
Expand Down Expand Up @@ -324,7 +347,7 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {

if (valid_forward) {
auto res =
rocfft_plan_description_set_data_layout(plan_desc, fwd_array_ty, bwd_array_ty,
rocfft_plan_description_set_data_layout(plan_desc_fwd, fwd_array_ty, bwd_array_ty,
nullptr, // in offsets
nullptr, // out offsets
dimensions,
Expand All @@ -339,15 +362,15 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
"Failed to set forward data layout.");
}

if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.fwd_scale) !=
if (rocfft_plan_description_set_scale_factor(plan_desc_fwd, config_values.fwd_scale) !=
rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to set forward scale factor.");
}

rocfft_plan fwd_plan;
res = rocfft_plan_create(&fwd_plan, placement, fwd_type, precision, dimensions,
lengths.data(), number_of_transforms, plan_desc);
lengths.data(), number_of_transforms, plan_desc_fwd);

if (res != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
Expand Down Expand Up @@ -380,7 +403,7 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {

if (valid_backward) {
auto res =
rocfft_plan_description_set_data_layout(plan_desc, bwd_array_ty, fwd_array_ty,
rocfft_plan_description_set_data_layout(plan_desc_bwd, bwd_array_ty, fwd_array_ty,
nullptr, // in offsets
nullptr, // out offsets
dimensions,
Expand All @@ -395,15 +418,15 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
"Failed to set backward data layout.");
}

if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.bwd_scale) !=
if (rocfft_plan_description_set_scale_factor(plan_desc_bwd, config_values.bwd_scale) !=
rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to set backward scale factor.");
}

rocfft_plan bwd_plan;
res = rocfft_plan_create(&bwd_plan, placement, bwd_type, precision, dimensions,
lengths.data(), number_of_transforms, plan_desc);
lengths.data(), number_of_transforms, plan_desc_bwd);
if (res != rocfft_status_success) {
throw mkl::exception("dft/backends/rocfft", __FUNCTION__,
"Failed to create backward rocFFT plan.");
Expand Down
8 changes: 2 additions & 6 deletions src/dft/backends/rocfft/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::rocfft> selector) {
Expand All @@ -46,6 +44,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
template void descriptor<precision::DOUBLE, domain::REAL>::commit(
backend_selector<backend::rocfft>);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail

0 comments on commit 00a9af7

Please sign in to comment.