Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create plan with shape arg #96

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "KokkosFFT_common_types.hpp"
#include "KokkosFFT_utils.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_padding.hpp"

namespace KokkosFFT {
namespace Impl {
Expand All @@ -20,14 +21,11 @@ namespace Impl {
*/
template <typename InViewType, typename OutViewType, std::size_t DIM = 1>
auto get_extents(const InViewType& in, const OutViewType& out,
axis_type<DIM> _axes) {
axis_type<DIM> axes, shape_type<DIM> shape = {0}) {
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

// index map after transpose over axis
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, _axes);

static_assert(InViewType::rank() >= DIM,
"KokkosFFT::get_map_axes: Rank of View must be larger thane or "
"equal to the Rank of FFT axes.");
Expand All @@ -41,20 +39,32 @@ auto get_extents(const InViewType& in, const OutViewType& out,
? 0
: (rank - 1);

std::vector<int> _in_extents, _out_extents, _fft_extents;
// index map after transpose over axis
auto [map, map_inv] = KokkosFFT::Impl::get_map_axes(in, axes);

// Get new shape based on shape parameter
// [TO DO] get_modified shape should take out as well and check is_C2R
// internally
bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point<out_value_type>::value;
auto modified_in_shape =
KokkosFFT::Impl::get_modified_shape(in, shape, axes, is_C2R);

// Get extents for the inner most axes in LayoutRight
// If we allow the FFT on the layoutLeft, this part should be modified
std::vector<int> _in_extents, _out_extents, _fft_extents;
for (std::size_t i = 0; i < rank; i++) {
auto _idx = map.at(i);
_in_extents.push_back(in.extent(_idx));
_out_extents.push_back(out.extent(_idx));
auto _idx = map.at(i);
auto in_extent = modified_in_shape.at(_idx);
auto out_extent = out.extent(_idx);
_in_extents.push_back(in_extent);
_out_extents.push_back(out_extent);

// The extent for transform is always equal to the extent
// of the extent of real type (R2C or C2R)
// For C2C, the in and out extents are the same.
// In the end, we can just use the largest extent among in and out extents.
auto fft_extent = std::max(in.extent(_idx), out.extent(_idx));
auto fft_extent = std::max(in_extent, out_extent);
_fft_extents.push_back(fft_extent);
}

Expand Down
5 changes: 5 additions & 0 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ auto get_modified_shape(const ViewType& view, shape_type<DIM> shape,
"larger than or equal to 1");
constexpr int rank = static_cast<int>(ViewType::rank());

shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (shape == zeros) {
return KokkosFFT::Impl::extract_extents(view);
}

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> positive_axes;
for (std::size_t i = 0; i < DIM; i++) {
Expand Down
20 changes: 12 additions & 8 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<1> axes) {
[[maybe_unused]] Direction direction, axis_type<1> axes,
shape_type<1> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -39,7 +40,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -59,7 +60,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<2> axes) {
[[maybe_unused]] Direction direction, axis_type<2> axes,
shape_type<2> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -77,7 +79,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0), ny = fft_extents.at(1);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -97,7 +99,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<3> axes) {
[[maybe_unused]] Direction direction, axis_type<3> axes,
shape_type<3> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -115,7 +118,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);

const int nx = fft_extents.at(0), ny = fft_extents.at(1),
nz = fft_extents.at(2);
Expand All @@ -137,7 +140,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -153,7 +157,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
20 changes: 12 additions & 8 deletions fft/src/KokkosFFT_HIP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<1> axes) {
[[maybe_unused]] Direction direction, axis_type<1> axes,
shape_type<1> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -40,7 +41,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -61,7 +62,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<2> axes) {
[[maybe_unused]] Direction direction, axis_type<2> axes,
shape_type<2> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -80,7 +82,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
const int nx = fft_extents.at(0), ny = fft_extents.at(1);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());
Expand All @@ -101,7 +103,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<3> axes) {
[[maybe_unused]] Direction direction, axis_type<3> axes,
shape_type<3> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -120,7 +123,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);

const int nx = fft_extents.at(0), ny = fft_extents.at(1),
nz = fft_extents.at(2);
Expand All @@ -143,7 +146,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -159,7 +163,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
5 changes: 3 additions & 2 deletions fft/src/KokkosFFT_OpenMP_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -57,7 +58,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
30 changes: 27 additions & 3 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <Kokkos_Core.hpp>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_padding.hpp"
#include "KokkosFFT_utils.hpp"

#if defined(KOKKOS_ENABLE_CUDA)
Expand Down Expand Up @@ -117,11 +118,14 @@ class Plan {
//! whether transpose is needed or not
bool m_is_transpose_needed;

//! whether crop or pad is needed or not
bool m_is_crop_or_pad_needed;

//! axes for fft
axis_type<DIM> m_axes;

//! Shape of the transformed axis of the output
shape_type<DIM> m_shape;
extents_type m_shape;

//! directions of fft
KokkosFFT::Direction m_direction;
Expand Down Expand Up @@ -186,12 +190,24 @@ class Plan {
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");

shape_type<1> s = {0};
if (n) {
std::size_t _n = n.value();
s = shape_type<1>({_n});
}

bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point<out_value_type>::value;

m_in_extents = KokkosFFT::Impl::extract_extents(in);
m_out_extents = KokkosFFT::Impl::extract_extents(out);
std::tie(m_map, m_map_inv) = KokkosFFT::Impl::get_map_axes(in, axis);
m_is_transpose_needed = KokkosFFT::Impl::is_transpose_needed(m_map);
m_shape = KokkosFFT::Impl::get_modified_shape(in, s, m_axes, is_C2R);
m_is_crop_or_pad_needed =
KokkosFFT::Impl::is_crop_or_pad_needed(in, m_shape);
m_fft_size = KokkosFFT::Impl::_create(exec_space, m_plan, in, out, m_buffer,
m_info, direction, m_axes);
m_info, direction, m_axes, s);
}

/// \brief Constructor for multidimensional FFT
Expand Down Expand Up @@ -240,12 +256,18 @@ class Plan {
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");

bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point<out_value_type>::value;

m_in_extents = KokkosFFT::Impl::extract_extents(in);
m_out_extents = KokkosFFT::Impl::extract_extents(out);
std::tie(m_map, m_map_inv) = KokkosFFT::Impl::get_map_axes(in, axes);
m_is_transpose_needed = KokkosFFT::Impl::is_transpose_needed(m_map);
m_shape = KokkosFFT::Impl::get_modified_shape(in, s, m_axes, is_C2R);
m_is_crop_or_pad_needed =
KokkosFFT::Impl::is_crop_or_pad_needed(in, m_shape);
m_fft_size = KokkosFFT::Impl::_create(exec_space, m_plan, in, out, m_buffer,
m_info, direction, axes);
m_info, direction, axes, s);
}

~Plan() {
Expand Down Expand Up @@ -331,6 +353,8 @@ class Plan {
fft_size_type fft_size() const { return m_fft_size; }
KokkosFFT::Direction direction() const { return m_direction; }
bool is_transpose_needed() const { return m_is_transpose_needed; }
bool is_crop_or_pad_needed() const { return m_is_crop_or_pad_needed; }
extents_type shape() const { return m_shape; }
map_type map() const { return m_map; }
map_type map_inv() const { return m_map_inv; }
nonConstInViewType& in_T() { return m_in_T; }
Expand Down
5 changes: 3 additions & 2 deletions fft/src/KokkosFFT_ROCM_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ template <typename ExecutionSpace, typename PlanType, typename InViewType,
auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
BufferViewType& buffer, InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -105,7 +106,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
5 changes: 3 additions & 2 deletions fft/src/KokkosFFT_SYCL_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
const InViewType& in, const OutViewType& out,
[[maybe_unused]] BufferViewType& buffer,
[[maybe_unused]] InfoType& execution_info,
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes) {
[[maybe_unused]] Direction direction, axis_type<fft_rank> axes,
shape_type<fft_rank> s) {
static_assert(Kokkos::is_view<InViewType>::value,
"KokkosFFT::_create: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
Expand All @@ -69,7 +70,7 @@ auto _create(const ExecutionSpace& exec_space, std::unique_ptr<PlanType>& plan,
KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes);
KokkosFFT::Impl::get_extents(in, out, axes, s);
int idist = std::accumulate(in_extents.begin(), in_extents.end(), 1,
std::multiplies<>());
int odist = std::accumulate(out_extents.begin(), out_extents.end(), 1,
Expand Down
Loading
Loading