Skip to content

Commit

Permalink
Feed shape arg to plan constructor inside FFT APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Apr 11, 2024
1 parent 6cda7d5 commit 4f29558
Showing 1 changed file with 64 additions and 153 deletions.
217 changes: 64 additions & 153 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,16 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"fft: execution_space cannot access data in OutViewType");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axis, n);
InViewType _in;
if (n) {
std::size_t _n = n.value();
auto modified_shape = KokkosFFT::Impl::get_modified_shape(
in, shape_type<1>({_n}), axis_type<1>({axis}));
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

KokkosFFT::Impl::Plan plan(exec_space, _in, out,
KokkosFFT::Direction::forward, axis);
if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -202,23 +196,17 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"fft: execution_space cannot access data in OutViewType");

plan.template good<ExecutionSpace, InViewType, OutViewType>(
in, out, KokkosFFT::Direction::forward, axis_type<1>{axis});

InViewType _in;
if (n) {
std::size_t _n = n.value();
auto modified_shape = KokkosFFT::Impl::get_modified_shape(
in, shape_type<1>({_n}), axis_type<1>({axis}));
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

plan.template good<ExecutionSpace, InViewType, OutViewType>(
_in, out, KokkosFFT::Direction::forward, axis_type<1>{axis});

if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -273,29 +261,17 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"ifft: execution_space cannot access data in OutViewType");

using out_value_type = typename OutViewType::non_const_value_type;
KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axis, n);

InViewType _in;
if (n) {
std::size_t _n = n.value();
bool is_C2R = std::is_floating_point<out_value_type>::value;
auto modified_shape = KokkosFFT::Impl::get_modified_shape(
in, shape_type<1>({_n}), axis_type<1>({axis}), is_C2R);

/* [FIX THIS] Shallow copy should be sufficient
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
*/
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

KokkosFFT::Impl::Plan plan(exec_space, _in, out,
KokkosFFT::Direction::backward, axis);
if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -352,29 +328,17 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"ifft: execution_space cannot access data in OutViewType");

using out_value_type = typename OutViewType::non_const_value_type;
plan.template good<ExecutionSpace, InViewType, OutViewType>(
in, out, KokkosFFT::Direction::backward, axis_type<1>{axis});

InViewType _in;
if (n) {
std::size_t _n = n.value();
bool is_C2R = std::is_floating_point<out_value_type>::value;
auto modified_shape = KokkosFFT::Impl::get_modified_shape(
in, shape_type<1>({_n}), axis_type<1>({axis}), is_C2R);
/* [FIX THIS] Shallow copy should be sufficient
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
*/
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

plan.template good<ExecutionSpace, InViewType, OutViewType>(
_in, out, KokkosFFT::Direction::backward, axis_type<1>{axis});

if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -847,21 +811,17 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"fft2: execution_space cannot access data in OutViewType");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axes, s);

InViewType _in;
shape_type<2> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
auto modified_shape = KokkosFFT::Impl::get_modified_shape(in, s, axes);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

KokkosFFT::Impl::Plan plan(exec_space, _in, out,
KokkosFFT::Direction::forward, axes);
if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -917,22 +877,17 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"fft2: execution_space cannot access data in OutViewType");

plan.template good<ExecutionSpace, InViewType, OutViewType>(
in, out, KokkosFFT::Direction::forward, axes);

InViewType _in;
shape_type<2> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
auto modified_shape = KokkosFFT::Impl::get_modified_shape(in, s, axes);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

plan.template good<ExecutionSpace, InViewType, OutViewType>(
_in, out, KokkosFFT::Direction::forward, axes);

if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -986,25 +941,17 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"ifft2: execution_space cannot access data in OutViewType");

using out_value_type = typename OutViewType::non_const_value_type;
KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axes, s);

InViewType _in;
shape_type<2> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
bool is_C2R = std::is_floating_point<out_value_type>::value;
auto modified_shape =
KokkosFFT::Impl::get_modified_shape(in, s, axes, is_C2R);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

KokkosFFT::Impl::Plan plan(exec_space, _in, out,
KokkosFFT::Direction::backward, axes);
if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -1060,26 +1007,17 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"ifft2: execution_space cannot access data in OutViewType");

using out_value_type = typename OutViewType::non_const_value_type;
plan.template good<ExecutionSpace, InViewType, OutViewType>(
in, out, KokkosFFT::Direction::backward, axes);

InViewType _in;
shape_type<2> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
bool is_C2R = std::is_floating_point<out_value_type>::value;
auto modified_shape =
KokkosFFT::Impl::get_modified_shape(in, s, axes, is_C2R);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

plan.template good<ExecutionSpace, InViewType, OutViewType>(
_in, out, KokkosFFT::Direction::backward, axes);

if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -1411,21 +1349,16 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"fftn: execution_space cannot access data in OutViewType");

KokkosFFT::Impl::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward,
axes, s);
InViewType _in;
shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
auto modified_shape = KokkosFFT::Impl::get_modified_shape(in, s, axes);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

KokkosFFT::Impl::Plan plan(exec_space, _in, out,
KokkosFFT::Direction::forward, axes);
if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -1481,22 +1414,17 @@ void fftn(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"fftn: execution_space cannot access data in OutViewType");

plan.template good<ExecutionSpace, InViewType, OutViewType>(
in, out, KokkosFFT::Direction::forward, axes);

InViewType _in;
shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
auto modified_shape = KokkosFFT::Impl::get_modified_shape(in, s, axes);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

plan.template good<ExecutionSpace, InViewType, OutViewType>(
_in, out, KokkosFFT::Direction::forward, axes);

if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -1624,25 +1552,17 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"ifftn: execution_space cannot access data in OutViewType");

using out_value_type = typename OutViewType::non_const_value_type;
KokkosFFT::Impl::Plan plan(exec_space, in, out,
KokkosFFT::Direction::backward, axes, s);

InViewType _in;
shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
bool is_C2R = std::is_floating_point<out_value_type>::value;
auto modified_shape =
KokkosFFT::Impl::get_modified_shape(in, s, axes, is_C2R);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

KokkosFFT::Impl::Plan plan(exec_space, _in, out,
KokkosFFT::Direction::backward, axes);
if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down Expand Up @@ -1698,26 +1618,17 @@ void ifftn(const ExecutionSpace& exec_space, const InViewType& in,
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"ifftn: execution_space cannot access data in OutViewType");

using out_value_type = typename OutViewType::non_const_value_type;
plan.template good<ExecutionSpace, InViewType, OutViewType>(
in, out, KokkosFFT::Direction::backward, axes);

InViewType _in;
shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (s != zeros) {
bool is_C2R = std::is_floating_point<out_value_type>::value;
auto modified_shape =
KokkosFFT::Impl::get_modified_shape(in, s, axes, is_C2R);
if (KokkosFFT::Impl::is_crop_or_pad_needed(in, modified_shape)) {
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, modified_shape);
} else {
_in = in;
}
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in, new_shape);
} else {
_in = in;
}

plan.template good<ExecutionSpace, InViewType, OutViewType>(
_in, out, KokkosFFT::Direction::backward, axes);

if (plan.is_transpose_needed()) {
InViewType in_T;
OutViewType out_T;
Expand Down

0 comments on commit 4f29558

Please sign in to comment.