diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index 1025e006..5ec1e744 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -394,6 +394,11 @@ class Plan { KOKKOSFFT_THROW_IF(out_extents != m_out_extents, "extents of output View for plan and " "execution are not identical."); + + bool is_inplace = KokkosFFT::Impl::are_aliasing(in.data(), out.data()); + KOKKOSFFT_THROW_IF(is_inplace != m_is_inplace, + "If the plan is in-place, the input and output Views " + "must be identical."); } }; } // namespace KokkosFFT diff --git a/fft/unit_test/Test_Transform.cpp b/fft/unit_test/Test_Transform.cpp index 1a793533..f268f5e1 100644 --- a/fft/unit_test/Test_Transform.cpp +++ b/fft/unit_test/Test_Transform.cpp @@ -115,6 +115,45 @@ void test_fft1_identity_inplace(T atol = 1.0e-12) { EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol)); EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol)); + + // Create a plan for inplace transform + Kokkos::deep_copy(a_ref, a); + Kokkos::deep_copy(ar_ref, ar); + + Kokkos::fence(); + + int axis = -1; + KokkosFFT::Plan fft_plan(execution_space(), a, a_hat, + KokkosFFT::Direction::forward, axis); + fft_plan.execute(a, a_hat); + + KokkosFFT::Plan ifft_plan(execution_space(), a_hat, inv_a_hat, + KokkosFFT::Direction::backward, axis); + ifft_plan.execute(a_hat, inv_a_hat); + + KokkosFFT::Plan rfft_plan(execution_space(), ar, ar_hat, + KokkosFFT::Direction::forward, axis); + rfft_plan.execute(ar, ar_hat); + + KokkosFFT::Plan irfft_plan(execution_space(), ar_hat, inv_ar_hat, + KokkosFFT::Direction::backward, axis); + irfft_plan.execute(ar_hat, inv_ar_hat); + + EXPECT_TRUE(allclose(inv_a_hat, a_ref, 1.e-5, atol)); + EXPECT_TRUE(allclose(inv_ar_hat, ar_ref, 1.e-5, atol)); + + // inplace Plan cannot be reused for out-of-place case + ComplexView1DType a_hat_out("a_hat_out", i), + inv_a_hat_out("inv_a_hat_out", i); + + RealView1DType inv_ar_hat_out("inv_ar_hat_out", i); + ComplexView1DType ar_hat_out("ar_hat_out", i / 2 + 1); + EXPECT_THROW(fft_plan.execute(a, a_hat_out), std::runtime_error); + EXPECT_THROW(ifft_plan.execute(a_hat_out, inv_a_hat_out), + std::runtime_error); + EXPECT_THROW(rfft_plan.execute(ar, ar_hat_out), std::runtime_error); + EXPECT_THROW(irfft_plan.execute(ar_hat_out, inv_ar_hat_out), + std::runtime_error); } }