Skip to content

Commit

Permalink
Raise an error if inplace plan is executed on out-of-place views
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Nov 5, 2024
1 parent d211123 commit 97bc37f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
5 changes: 5 additions & 0 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,11 @@ class Plan {
KOKKOSFFT_THROW_IF(out_extents != m_out_extents,
"extents of output View for plan and "
"execution are not identical.");

bool is_inplace = KokkosFFT::Impl::are_aliasing(in.data(), out.data());
KOKKOSFFT_THROW_IF(is_inplace != m_is_inplace,
"If the plan is in-place, the input and output Views "
"must be identical.");
}
};
} // namespace KokkosFFT
Expand Down
39 changes: 39 additions & 0 deletions fft/unit_test/Test_Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,45 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) {

EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol));

// Create a plan for inplace transform
Kokkos::deep_copy(a_ref, a);
Kokkos::deep_copy(ar_ref, ar);

Kokkos::fence();

int axis = -1;
KokkosFFT::Plan fft_plan(execution_space(), a, a_hat,
KokkosFFT::Direction::forward, axis);
fft_plan.execute(a, a_hat);

KokkosFFT::Plan ifft_plan(execution_space(), a_hat, inv_a_hat,
KokkosFFT::Direction::backward, axis);
ifft_plan.execute(a_hat, inv_a_hat);

KokkosFFT::Plan rfft_plan(execution_space(), ar, ar_hat,
KokkosFFT::Direction::forward, axis);
rfft_plan.execute(ar, ar_hat);

KokkosFFT::Plan irfft_plan(execution_space(), ar_hat, inv_ar_hat,
KokkosFFT::Direction::backward, axis);
irfft_plan.execute(ar_hat, inv_ar_hat);

EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol));
EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol));

// inplace Plan cannot be reused for out-of-place case
ComplexView1DType a_hat_out("a_hat_out", i),
inv_a_hat_out("inv_a_hat_out", i);

RealView1DType inv_ar_hat_out("inv_ar_hat_out", i);
ComplexView1DType ar_hat_out("ar_hat_out", i / 2 + 1);
EXPECT_THROW(fft_plan.execute(a, a_hat_out), std::runtime_error);
EXPECT_THROW(ifft_plan.execute(a_hat_out, inv_a_hat_out),
std::runtime_error);
EXPECT_THROW(rfft_plan.execute(ar, ar_hat_out), std::runtime_error);
EXPECT_THROW(irfft_plan.execute(ar_hat_out, inv_ar_hat_out),
std::runtime_error);
}
}

Expand Down

0 comments on commit 97bc37f

Please sign in to comment.