From 4806cfbf0eeb67e3e233a195cc61b66985b2704d Mon Sep 17 00:00:00 2001
From: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
Date: Wed, 4 Dec 2024 18:20:41 +0900
Subject: [PATCH] fix: fftw plan creation

---
 fft/src/KokkosFFT_FFTW_Types.hpp | 43 ++++----------------------------
 fft/src/KokkosFFT_Host_plans.hpp | 34 ++++++++++++++++++++++---
 2 files changed, 35 insertions(+), 42 deletions(-)

diff --git a/fft/src/KokkosFFT_FFTW_Types.hpp b/fft/src/KokkosFFT_FFTW_Types.hpp
index 16d2e8c7..16a0c0f6 100644
--- a/fft/src/KokkosFFT_FFTW_Types.hpp
+++ b/fft/src/KokkosFFT_FFTW_Types.hpp
@@ -70,7 +70,10 @@ struct ScopedFFTWPlanType {
   plan_type m_plan;
   bool m_is_created = false;
 
-  ScopedFFTWPlanType() {}
+  ScopedFFTWPlanType() = delete;
+  ScopedFFTWPlanType(const ExecutionSpace &exec_space) {
+    init_threads<floating_point_type>(exec_space);
+  }
   ~ScopedFFTWPlanType() {
     cleanup_threads<floating_point_type>();
     if constexpr (std::is_same_v<floating_point_type, float>) {
@@ -80,43 +83,7 @@ struct ScopedFFTWPlanType {
     }
   }
 
-  plan_type &plan() { return m_plan; }
-
-  template <typename InScalarType, typename OutScalarType>
-  void create(const ExecutionSpace &exec_space, int rank, const int *n,
-              int howmany, InScalarType *in, const int *inembed, int istride,
-              int idist, OutScalarType *out, const int *onembed, int ostride,
-              int odist, [[maybe_unused]] int sign, unsigned flags) {
-    init_threads<floating_point_type>(exec_space);
-
-    constexpr auto type = fftw_transform_type<ExecutionSpace, T1, T2>::type();
-
-    if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
-      m_plan =
-          fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
-                                  out, onembed, ostride, odist, flags);
-    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
-      m_plan =
-          fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
-                                 out, onembed, ostride, odist, flags);
-    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
-      m_plan =
-          fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
-                                  out, onembed, ostride, odist, flags);
-    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
-      m_plan =
-          fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
-                                 out, onembed, ostride, odist, flags);
-    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
-      m_plan =
-          fftwf_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
-                              out, onembed, ostride, odist, sign, flags);
-    } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
-      m_plan = fftw_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
-                                  out, onembed, ostride, odist, sign, flags);
-    }
-    m_is_created = true;
-  }
+  const plan_type &plan() const { return m_plan; }
 
  private:
   template <typename T>
diff --git a/fft/src/KokkosFFT_Host_plans.hpp b/fft/src/KokkosFFT_Host_plans.hpp
index dff9d15e..47dd78cc 100644
--- a/fft/src/KokkosFFT_Host_plans.hpp
+++ b/fft/src/KokkosFFT_Host_plans.hpp
@@ -58,10 +58,36 @@ auto create_plan(const ExecutionSpace& exec_space,
   [[maybe_unused]] auto sign =
       KokkosFFT::Impl::direction_type<ExecutionSpace>(direction);
 
-  plan = std::make_unique<PlanType>();
-  plan->create(exec_space, rank, fft_extents.data(), howmany, idata,
-               in_extents.data(), istride, idist, odata, out_extents.data(),
-               ostride, odist, sign, FFTW_ESTIMATE);
+  plan = std::make_unique<PlanType>(exec_space);
+  constexpr auto type =
+      KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
+                                      out_value_type>::type();
+  if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
+    plan->m_plan = fftwf_plan_many_dft_r2c(
+        rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
+        idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
+  } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
+    plan->m_plan = fftw_plan_many_dft_r2c(
+        rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
+        idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
+  } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
+    plan->m_plan = fftwf_plan_many_dft_c2r(
+        rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
+        idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
+  } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
+    plan->m_plan = fftw_plan_many_dft_c2r(
+        rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
+        idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
+  } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
+    plan->m_plan = fftwf_plan_many_dft(
+        rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
+        idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE);
+  } else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
+    plan->m_plan = fftw_plan_many_dft(
+        rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
+        idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE);
+  }
+  plan->m_is_created;
 
   return fft_size;
 }