Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DFT] Corrections to FWD/BWD_STRIDES API #549

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dft/backends/cufft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
if (a_min - stride_vecs.vec_a.begin() != b_min - stride_vecs.vec_b.begin()) {
throw mkl::unimplemented(
"dft/backends/cufft", __FUNCTION__,
"cufft requires that if ordered by stride length, the order of strides is the same for input and output strides!");
"cufft requires that if ordered by stride length, the order of strides is the same for input/output or fwd/bwd strides!");
}
}
const int a_stride = static_cast<int>(*a_min);
Expand Down
8 changes: 2 additions & 6 deletions src/dft/backends/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "../descriptor.cxx"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(sycl::queue &queue) {
Expand All @@ -41,6 +39,4 @@ template void descriptor<precision::SINGLE, domain::REAL>::commit(sycl::queue &)
template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(sycl::queue &);
template void descriptor<precision::DOUBLE, domain::REAL>::commit(sycl::queue &);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
8 changes: 2 additions & 6 deletions src/dft/backends/mklcpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklcpu> selector) {
Expand All @@ -46,6 +44,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
template void descriptor<precision::DOUBLE, domain::REAL>::commit(
backend_selector<backend::mklcpu>);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
8 changes: 2 additions & 6 deletions src/dft/backends/mklgpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@

#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace oneapi::mkl::dft::detail {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklgpu> selector) {
Expand All @@ -46,6 +44,4 @@ template void descriptor<precision::DOUBLE, domain::COMPLEX>::commit(
template void descriptor<precision::DOUBLE, domain::REAL>::commit(
backend_selector<backend::mklgpu>);

} //namespace dft
} //namespace mkl
} //namespace oneapi
} //namespace oneapi::mkl::dft::detail
44 changes: 26 additions & 18 deletions src/dft/backends/rocfft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,26 +289,34 @@ class rocfft_commit final : public dft::detail::commit_impl<prec, dom> {
if (dom == dft::domain::REAL) {
lengths_cplx[0] = lengths_cplx[0] / 2 + 1;
}
// When creating real-complex descriptions, the strides will always be wrong for one of the directions.
// When creating real-complex descriptions with INPUT/OUTPUT_STRIDES,
// the strides will always be wrong for one of the directions.
// This is because the least significant dimension is symmetric.
// If the strides are invalid (too small to fit) then just don't bother creating the plan.
const bool vec_a_valid_as_reals =
dimensions == 1 ||
(lengths_cplx[stride_a_indices[0]] <= stride_vecs.vec_a[stride_a_indices[1]] &&
(dimensions == 2 ||
lengths_cplx[stride_a_indices[0]] * lengths_cplx[stride_a_indices[1]] <=
stride_vecs.vec_a[stride_a_indices[2]]));
const bool vec_b_valid_as_reals =
dimensions == 1 ||
(lengths_cplx[stride_b_indices[0]] <= stride_vecs.vec_b[stride_b_indices[1]] &&
(dimensions == 2 ||
lengths_cplx[stride_b_indices[0]] * lengths_cplx[stride_b_indices[1]] <=
stride_vecs.vec_b[stride_b_indices[2]]));
// Test if the stride vector being used as the fwd domain for each direction has valid strides for that use.
bool valid_forward =
stride_vecs.fwd_in == stride_vecs.vec_a && vec_a_valid_as_reals || vec_b_valid_as_reals;
bool valid_backward = stride_vecs.bwd_out == stride_vecs.vec_a && vec_a_valid_as_reals ||
vec_b_valid_as_reals;
auto are_strides_smaller_than_lengths = [=](auto& svec, auto& sindices,
auto& domain_lengths) {
return dimensions == 1 ||
(domain_lengths[sindices[0]] <= svec[sindices[1]] &&
(dimensions == 2 ||
svec[sindices[1]] * domain_lengths[sindices[1]] <= svec[sindices[2]]));
};

const bool vec_a_valid_as_fwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_a, stride_a_indices, lengths);
const bool vec_b_valid_as_fwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_b, stride_b_indices, lengths);
const bool vec_a_valid_as_bwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_a, stride_a_indices, lengths_cplx);
const bool vec_b_valid_as_bwd_domain =
are_strides_smaller_than_lengths(stride_vecs.vec_b, stride_b_indices, lengths_cplx);

// Test if the stride vector being used as the fwd/bwd domain for each direction has valid strides for that use.
bool valid_forward = (stride_vecs.fwd_in == stride_vecs.vec_a &&
vec_a_valid_as_fwd_domain && vec_b_valid_as_bwd_domain) ||
(vec_b_valid_as_fwd_domain && vec_a_valid_as_bwd_domain);
bool valid_backward = (stride_vecs.bwd_in == stride_vecs.vec_a &&
vec_a_valid_as_bwd_domain && vec_b_valid_as_fwd_domain) ||
(vec_b_valid_as_bwd_domain && vec_a_valid_as_fwd_domain);

if (!valid_forward && !valid_backward) {
throw mkl::exception("dft/backends/cufft", __FUNCTION__, "Invalid strides.");
Expand Down
42 changes: 27 additions & 15 deletions src/dft/descriptor.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,28 @@ namespace mkl {
namespace dft {
namespace detail {

// Compute the default strides. Modifies real_strides and complex_strides arguments.
// Compute the default strides. Modifies real_strides and complex_strides arguments
template <domain dom>
inline void compute_default_strides(const std::vector<std::int64_t>& dimensions,
std::vector<std::int64_t>& fwd_strides,
std::vector<std::int64_t>& bwd_strides) {
auto rank = dimensions.size();
std::vector<std::int64_t> strides(rank + 1, 1);
for (auto i = rank - 1; i > 0; --i) {
strides[i] = strides[i + 1] * dimensions[i];
const auto rank = dimensions.size();
fwd_strides = std::vector<std::int64_t>(rank + 1, 1);
fwd_strides[0] = 0;
bwd_strides = fwd_strides;
if (rank == 1) {
return;
}

bwd_strides[rank - 1] =
dom == domain::COMPLEX ? dimensions[rank - 1] : (dimensions[rank - 1] / 2) + 1;
fwd_strides[rank - 1] =
dom == domain::COMPLEX ? dimensions[rank - 1] : 2 * bwd_strides[rank - 1];
for (auto i = rank - 1; i > 1; --i) {
// Can't start at rank - 2 with unsigned type and minimum value of rank being 1.
bwd_strides[i - 1] = bwd_strides[i] * dimensions[i - 1];
fwd_strides[i - 1] = fwd_strides[i] * dimensions[i - 1];
}
strides[0] = 0;
// Fwd/Bwd strides and Input/Output strides being the same by default means
// that we don't have to specify if we default to using fwd/bwd strides or
// input/output strides.
bwd_strides = strides;
fwd_strides = std::move(strides);
}

template <precision prec, domain dom>
Expand All @@ -69,12 +76,15 @@ void descriptor<prec, dom>::set_value(config_param param, ...) {
case config_param::PRECISION:
throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter.");
break;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
case config_param::INPUT_STRIDES:
detail::set_value<config_param::INPUT_STRIDES>(values_, va_arg(vl, std::int64_t*));
break;
case config_param::OUTPUT_STRIDES:
detail::set_value<config_param::OUTPUT_STRIDES>(values_, va_arg(vl, std::int64_t*));
break;
#pragma clang diagnostic pop
case config_param::FWD_STRIDES:
detail::set_value<config_param::FWD_STRIDES>(values_, va_arg(vl, std::int64_t*));
break;
Expand Down Expand Up @@ -150,10 +160,9 @@ descriptor<prec, dom>::descriptor(std::vector<std::int64_t> dimensions) {
"Invalid dimension value (negative or 0).");
}
}
compute_default_strides(dimensions, values_.fwd_strides, values_.bwd_strides);
// Assume forward transform.
values_.input_strides = values_.fwd_strides;
values_.output_strides = values_.bwd_strides;
compute_default_strides<dom>(dimensions, values_.fwd_strides, values_.bwd_strides);
values_.input_strides = std::vector<std::int64_t>(dimensions.size() + 1, 0);
values_.output_strides = std::vector<std::int64_t>(dimensions.size() + 1, 0);
values_.bwd_scale = real_t(1.0);
values_.fwd_scale = real_t(1.0);
values_.number_of_transforms = 1;
Expand Down Expand Up @@ -220,6 +229,8 @@ void descriptor<prec, dom>::get_value(config_param param, ...) const {
*va_arg(vl, config_value*) = values_.conj_even_storage;
break;
case config_param::PLACEMENT: *va_arg(vl, config_value*) = values_.placement; break;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
case config_param::INPUT_STRIDES:
std::copy(values_.input_strides.begin(), values_.input_strides.end(),
va_arg(vl, std::int64_t*));
Expand All @@ -228,6 +239,7 @@ void descriptor<prec, dom>::get_value(config_param param, ...) const {
std::copy(values_.output_strides.begin(), values_.output_strides.end(),
va_arg(vl, std::int64_t*));
break;
#pragma clang diagnostic pop
case config_param::FWD_STRIDES:
std::copy(values_.fwd_strides.begin(), values_.fwd_strides.end(),
va_arg(vl, std::int64_t*));
Expand Down
97 changes: 58 additions & 39 deletions tests/unit_tests/dft/include/compute_out_of_place.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,26 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
return test_skipped;
}

descriptor_t descriptor{ sizes };
auto strides_fwd_cpy = strides_fwd;
auto strides_bwd_cpy = strides_bwd;
if (strides_fwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
else {
strides_fwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
if (strides_bwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
else {
strides_bwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
auto [forward_distance, backward_distance] =
get_default_distances<domain>(sizes, strides_fwd, strides_bwd);
get_default_distances<domain>(sizes, strides_fwd_cpy, strides_bwd_cpy);
auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());

descriptor_t descriptor{ sizes };
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
Expand All @@ -45,22 +60,12 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if (strides_fwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data());
}
if (strides_bwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data());
}
else if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data());
}
commit_descriptor(descriptor, sycl_queue);
std::vector<FwdInputType> fwd_data(
strided_copy(input, sizes, strides_fwd, batches, forward_distance));
strided_copy(input, sizes, strides_fwd_cpy, batches, forward_distance));

auto tmp = std::vector<FwdOutputType>(
cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), 0);
cast_unsigned(backward_distance * batches + get_default(strides_bwd_cpy, 0, 0L)), 0);
{
sycl::buffer<FwdInputType, 1> fwd_buf{ fwd_data };
sycl::buffer<FwdOutputType, 1> bwd_buf{ tmp };
Expand All @@ -74,7 +79,7 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<domain == oneapi::mkl::dft::domain::REAL>(
bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes,
strides_bwd, abs_error_margin, rel_error_margin, std::cout));
strides_bwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}
}

Expand All @@ -88,9 +93,9 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
[this](auto &x) { x *= static_cast<PrecisionType>(forward_elements); });

for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<false>(fwd_data.data() + forward_distance * i,
input.data() + ref_distance * i, sizes, strides_fwd,
abs_error_margin, rel_error_margin, std::cout));
EXPECT_TRUE(check_equal_strided<false>(
fwd_data.data() + forward_distance * i, input.data() + ref_distance * i, sizes,
strides_fwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}

return !::testing::Test::HasFailure();
Expand All @@ -103,10 +108,34 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
}
const std::vector<sycl::event> no_dependencies;

auto [forward_distance, backward_distance] =
get_default_distances<domain>(sizes, strides_fwd, strides_bwd);

descriptor_t descriptor{ sizes };
auto strides_fwd_cpy = strides_fwd;
auto strides_bwd_cpy = strides_bwd;
if (strides_fwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
else {
strides_fwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd_cpy.data());
}
if (strides_bwd_cpy.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
else {
strides_bwd_cpy.resize(sizes.size() + 1);
descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd_cpy.data());
}
auto [forward_distance, backward_distance] =
get_default_distances<domain>(sizes, strides_fwd_cpy, strides_bwd_cpy);
auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
descriptor.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE,
oneapi::mkl::dft::config_value::COMPLEX_COMPLEX);
descriptor.set_value(oneapi::mkl::dft::config_param::PACKED_FORMAT,
oneapi::mkl::dft::config_value::CCE_FORMAT);
}
descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
oneapi::mkl::dft::config_value::NOT_INPLACE);
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
Expand All @@ -118,36 +147,26 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
if (strides_fwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data());
}
if (strides_bwd.size()) {
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data());
}
else if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data());
}
commit_descriptor(descriptor, sycl_queue);

auto ua_input = usm_allocator_t<FwdInputType>(cxt, *dev);
auto ua_output = usm_allocator_t<FwdOutputType>(cxt, *dev);

std::vector<FwdInputType, decltype(ua_input)> fwd(
strided_copy(input, sizes, strides_fwd, batches, forward_distance, ua_input), ua_input);
strided_copy(input, sizes, strides_fwd_cpy, batches, forward_distance, ua_input), ua_input);
std::vector<FwdOutputType, decltype(ua_output)> bwd(
cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), ua_output);
cast_unsigned(backward_distance * batches + get_default(strides_bwd_cpy, 0, 0L)),
ua_output);

oneapi::mkl::dft::compute_forward<descriptor_t, FwdInputType, FwdOutputType>(
descriptor, fwd.data(), bwd.data(), no_dependencies)
.wait_and_throw();

auto bwd_ptr = &bwd[0];
auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());
for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<domain == oneapi::mkl::dft::domain::REAL>(
bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes,
strides_bwd, abs_error_margin, rel_error_margin, std::cout));
strides_bwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}

oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>, FwdOutputType,
Expand All @@ -160,9 +179,9 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
[this](auto &x) { x *= static_cast<PrecisionType>(forward_elements); });

for (std::int64_t i = 0; i < batches; i++) {
EXPECT_TRUE(check_equal_strided<false>(fwd.data() + forward_distance * i,
input.data() + ref_distance * i, sizes, strides_fwd,
abs_error_margin, rel_error_margin, std::cout));
EXPECT_TRUE(check_equal_strided<false>(
fwd.data() + forward_distance * i, input.data() + ref_distance * i, sizes,
strides_fwd_cpy, abs_error_margin, rel_error_margin, std::cout));
}

return !::testing::Test::HasFailure();
Expand Down
Loading