Skip to content

Commit

Permalink
Add assertions for completely mismatched extents
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Nov 5, 2024
1 parent 367e7cc commit d211123
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
16 changes: 16 additions & 0 deletions common/src/KokkosFFT_Extents.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ auto get_extents(const InViewType& in, const OutViewType& out,
static_assert(!(is_real_v<in_value_type> && is_real_v<out_value_type>),
"get_extents: real to real transform is not supported");

for (std::size_t i = 0; i < rank; i++) {
// The requirement for inner_most_axis is different for transform type
if (static_cast<int>(i) == inner_most_axis) continue;
KOKKOSFFT_THROW_IF(in_extents_full.at(i) != out_extents_full.at(i),
"input and output extents must be the same except for "
"the transform axis");
}

if constexpr (is_complex_v<in_value_type> && is_complex_v<out_value_type>) {
// Then C2C
KOKKOSFFT_THROW_IF(
in_extents_full.at(inner_most_axis) !=
out_extents_full.at(inner_most_axis),
"input and output extents must be the same for C2C transform");
}

if constexpr (is_real_v<in_value_type>) {
// Then R2C
if (is_inplace) {
Expand Down
43 changes: 43 additions & 0 deletions common/unit_test/Test_Extents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ void test_extents_1d_batched_FFT_2d() {
EXPECT_TRUE(fft_extents_c2c_axis1 == ref_fft_extents_r2c_axis1);
EXPECT_TRUE(out_extents_c2c_axis1 == ref_in_extents_r2c_axis1);
EXPECT_EQ(howmany_c2c_axis1, ref_howmany_r2c_axis1);

// Check if errors are correctly raised aginst invalid extents
ComplexView2Dtype xcout2_wrong("xcout2_wrong", n0 + 3, n1);
for (int i = 0; i < 2; i++) {
EXPECT_THROW(
{ KokkosFFT::Impl::get_extents(xcin2, xcout2_wrong, axes_type({i})); },
std::runtime_error);
}
}

template <typename LayoutType>
Expand Down Expand Up @@ -306,6 +314,14 @@ void test_extents_1d_batched_FFT_3d() {
EXPECT_TRUE(fft_extents_c2c_axis2 == ref_fft_extents_r2c_axis2);
EXPECT_TRUE(out_extents_c2c_axis2 == ref_in_extents_r2c_axis2);
EXPECT_EQ(howmany_c2c_axis2, ref_howmany_r2c_axis2);

// Check if errors are correctly raised aginst invalid extents
ComplexView3Dtype xcout3_wrong("xcout3_wrong", n0 + 3, n1, n2);
for (int i = 0; i < 3; i++) {
EXPECT_THROW(
{ KokkosFFT::Impl::get_extents(xcin3, xcout3_wrong, axes_type({i})); },
std::runtime_error);
}
}

TYPED_TEST(Extents1D, 1DFFT_1DView) {
Expand Down Expand Up @@ -429,6 +445,20 @@ void test_extents_2d() {

EXPECT_EQ(howmany_c2c_axis01, 1);
EXPECT_EQ(howmany_c2c_axis10, 1);

// Check if errors are correctly raised aginst invalid extents
ComplexView2Dtype xcout2_wrong("xcout2_wrong", n0 + 3, n1);
for (int axis0 = 0; axis0 < 2; axis0++) {
for (int axis1 = 0; axis1 < 2; axis1++) {
if (axis0 == axis1) continue;
EXPECT_THROW(
{
KokkosFFT::Impl::get_extents(xcin2, xcout2_wrong,
axes_type({axis0, axis1}));
},
std::runtime_error);
}
}
}

template <typename LayoutType>
Expand Down Expand Up @@ -709,6 +739,19 @@ void test_extents_2d_batched_FFT_3d() {
EXPECT_TRUE(fft_extents_c2c_axis_21 == ref_fft_extents_r2c_axis_21);
EXPECT_TRUE(out_extents_c2c_axis_21 == ref_in_extents_r2c_axis_21);
EXPECT_EQ(howmany_c2c_axis_21, ref_howmany_r2c_axis_21);

ComplexView3Dtype xcout3_wrong("xcout3_wrong", n0 + 3, n1, n2 + 2);
for (int axis0 = 0; axis0 < 3; axis0++) {
for (int axis1 = 0; axis1 < 3; axis1++) {
if (axis0 == axis1) continue;
EXPECT_THROW(
{
KokkosFFT::Impl::get_extents(xcin3, xcout3_wrong,
axes_type({axis0, axis1}));
},
std::runtime_error);
}
}
}

TYPED_TEST(Extents2D, 2DFFT_2DView) {
Expand Down

0 comments on commit d211123

Please sign in to comment.