Skip to content

Commit

Permalink
fix: fftw plan creation
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 4, 2024
1 parent 1d22abd commit 4806cfb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 42 deletions.
43 changes: 5 additions & 38 deletions fft/src/KokkosFFT_FFTW_Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ struct ScopedFFTWPlanType {
plan_type m_plan;
bool m_is_created = false;

ScopedFFTWPlanType() {}
ScopedFFTWPlanType() = delete;
ScopedFFTWPlanType(const ExecutionSpace &exec_space) {
init_threads<floating_point_type>(exec_space);
}
~ScopedFFTWPlanType() {
cleanup_threads<floating_point_type>();
if constexpr (std::is_same_v<floating_point_type, float>) {
Expand All @@ -80,43 +83,7 @@ struct ScopedFFTWPlanType {
}
}

plan_type &plan() { return m_plan; }

template <typename InScalarType, typename OutScalarType>
void create(const ExecutionSpace &exec_space, int rank, const int *n,
int howmany, InScalarType *in, const int *inembed, int istride,
int idist, OutScalarType *out, const int *onembed, int ostride,
int odist, [[maybe_unused]] int sign, unsigned flags) {
init_threads<floating_point_type>(exec_space);

constexpr auto type = fftw_transform_type<ExecutionSpace, T1, T2>::type();

if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
m_plan =
fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
m_plan =
fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
m_plan =
fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
m_plan =
fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
m_plan =
fftwf_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, sign, flags);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
m_plan = fftw_plan_many_dft(rank, n, howmany, in, inembed, istride, idist,
out, onembed, ostride, odist, sign, flags);
}
m_is_created = true;
}
const plan_type &plan() const { return m_plan; }

private:
template <typename T>
Expand Down
34 changes: 30 additions & 4 deletions fft/src/KokkosFFT_Host_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,36 @@ auto create_plan(const ExecutionSpace& exec_space,
[[maybe_unused]] auto sign =
KokkosFFT::Impl::direction_type<ExecutionSpace>(direction);

plan = std::make_unique<PlanType>();
plan->create(exec_space, rank, fft_extents.data(), howmany, idata,
in_extents.data(), istride, idist, odata, out_extents.data(),
ostride, odist, sign, FFTW_ESTIMATE);
plan = std::make_unique<PlanType>(exec_space);
constexpr auto type =
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
plan->m_plan = fftwf_plan_many_dft_r2c(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::D2Z) {
plan->m_plan = fftw_plan_many_dft_r2c(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2R) {
plan->m_plan = fftwf_plan_many_dft_c2r(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2D) {
plan->m_plan = fftw_plan_many_dft_c2r(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::C2C) {
plan->m_plan = fftwf_plan_many_dft(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE);
} else if constexpr (type == KokkosFFT::Impl::FFTWTransformType::Z2Z) {
plan->m_plan = fftw_plan_many_dft(
rank, fft_extents.data(), howmany, idata, in_extents.data(), istride,
idist, odata, out_extents.data(), ostride, odist, sign, FFTW_ESTIMATE);
}
plan->m_is_created;

return fft_size;
}
Expand Down

0 comments on commit 4806cfb

Please sign in to comment.