Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use exec space from Plan in _fft function #94

Merged
merged 1 commit into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ namespace Impl {
template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
class Plan {
public:
//! The type of Kokkos execution pace
using execSpace = ExecutionSpace;

Expand Down Expand Up @@ -97,6 +98,10 @@ class Plan {
//! The type of extents of input/output views
using extents_type = shape_type<InViewType::rank()>;

private:
//! Execution space
execSpace m_exec_space;

//! Dynamically allocatable fft plan.
std::unique_ptr<fft_plan_type> m_plan;

Expand Down Expand Up @@ -148,7 +153,8 @@ class Plan {
explicit Plan(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, KokkosFFT::Direction direction, int axis,
std::optional<std::size_t> n = std::nullopt)
: m_fft_size(1),
: m_exec_space(exec_space),
m_fft_size(1),
m_is_transpose_needed(false),
m_direction(direction),
m_axes({axis}) {
Expand Down Expand Up @@ -200,7 +206,8 @@ class Plan {
explicit Plan(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, KokkosFFT::Direction direction,
axis_type<DIM> axes, shape_type<DIM> s = {0})
: m_fft_size(1),
: m_exec_space(exec_space),
m_fft_size(1),
m_is_transpose_needed(false),
m_direction(direction),
m_axes(axes) {
Expand Down Expand Up @@ -311,14 +318,18 @@ class Plan {
}
}

/// \brief Return the execution space
execSpace const& exec_space() const noexcept { return m_exec_space; }

/// \brief Return the FFT plan
fft_plan_type& plan() const { return *m_plan; }

/// \brief Return the FFT info
const fft_info_type& info() const { return m_info; }
fft_info_type const& info() const { return m_info; }

/// \brief Return the FFT size
fft_size_type fft_size() const { return m_fft_size; }
KokkosFFT::Direction direction() const { return m_direction; }
bool is_transpose_needed() const { return m_is_transpose_needed; }
map_type map() const { return m_map; }
map_type map_inv() const { return m_map_inv; }
Expand Down
117 changes: 36 additions & 81 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@
// General Transform Interface
namespace KokkosFFT {
namespace Impl {
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType>
void _fft(const ExecutionSpace& exec_space, PlanType& plan,
const InViewType& in, OutViewType& out,
template <typename PlanType, typename InViewType, typename OutViewType>
void _fft(const PlanType& plan, const InViewType& in, OutViewType& out,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward) {
static_assert(Kokkos::is_view<InViewType>::value,
"_fft: InViewType is not a Kokkos::View.");
Expand All @@ -65,6 +63,7 @@ void _fft(const ExecutionSpace& exec_space, PlanType& plan,
"_fft: InViewType and OutViewType must have "
"the same Layout.");

using ExecutionSpace = typename PlanType::execSpace;
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename InViewType::memory_space>::accessible,
Expand All @@ -82,57 +81,13 @@ void _fft(const ExecutionSpace& exec_space, PlanType& plan,
auto* odata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, out_value_type>::type*>(out.data());

auto forward = direction_type<ExecutionSpace>(KokkosFFT::Direction::forward);
KokkosFFT::Impl::_exec(plan.plan(), idata, odata, forward, plan.info());
KokkosFFT::Impl::normalize(exec_space, out, KokkosFFT::Direction::forward,
norm, plan.fft_size());
auto const exec_space = plan.exec_space();
auto const fft_direction = direction_type<ExecutionSpace>(plan.direction());
KokkosFFT::Impl::_exec(plan.plan(), idata, odata, fft_direction, plan.info());
KokkosFFT::Impl::normalize(exec_space, out, plan.direction(), norm,
plan.fft_size());
}

template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType>
void _ifft(const ExecutionSpace& exec_space, PlanType& plan,
const InViewType& in, OutViewType& out,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward) {
static_assert(Kokkos::is_view<InViewType>::value,
"_ifft: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<OutViewType>::value,
"_ifft: OutViewType is not a Kokkos::View.");
static_assert(KokkosFFT::Impl::is_layout_left_or_right_v<InViewType>,
"_ifft: InViewType must be either LayoutLeft or LayoutRight.");
static_assert(KokkosFFT::Impl::is_layout_left_or_right_v<OutViewType>,
"_ifft: OutViewType must be either LayoutLeft or LayoutRight.");

static_assert(InViewType::rank() == OutViewType::rank(),
"_ifft: InViewType and OutViewType must have "
"the same rank.");
static_assert(std::is_same_v<typename InViewType::array_layout,
typename OutViewType::array_layout>,
"_ifft: InViewType and OutViewType must have "
"the same Layout.");

static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename InViewType::memory_space>::accessible,
"_ifft: execution_space cannot access data in InViewType");
static_assert(
Kokkos::SpaceAccessibility<
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"_ifft: execution_space cannot access data in OutViewType");

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto* idata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, out_value_type>::type*>(out.data());

auto backward =
direction_type<ExecutionSpace>(KokkosFFT::Direction::backward);
KokkosFFT::Impl::_exec(plan.plan(), idata, odata, backward, plan.info());
KokkosFFT::Impl::normalize(exec_space, out, KokkosFFT::Direction::backward,
norm, plan.fft_size());
}
} // namespace Impl
} // namespace KokkosFFT

Expand Down Expand Up @@ -198,12 +153,12 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -271,12 +226,12 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -348,12 +303,12 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -427,12 +382,12 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -914,11 +869,11 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -985,11 +940,11 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1057,11 +1012,11 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1132,11 +1087,11 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1409,11 +1364,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1478,11 +1433,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1549,11 +1504,11 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_fft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_fft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1622,11 +1577,11 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1695,11 +1650,11 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down Expand Up @@ -1770,11 +1725,11 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::_ifft(exec_space, plan, in_T, out_T, norm);
KokkosFFT::Impl::_fft(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());
} else {
KokkosFFT::Impl::_ifft(exec_space, plan, _in, out, norm);
KokkosFFT::Impl::_fft(plan, _in, out, norm);
}
}

Expand Down
Loading