Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fft helper functions #17

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions fft/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
#ifndef KOKKOSFFT_HELPERS_HPP
#define KOKKOSFFT_HELPERS_HPP

#include <Kokkos_Core.hpp>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_utils.hpp"

namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM=1>
auto _get_shift(const ViewType& inout, axis_type<DIM> _axes, int direction=1) {
static_assert(DIM > 0,
"KokkosFFT::Impl::_get_shift: Rank of shift axes must be larger than or equal to 1.");

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
for(std::size_t i=0; i<DIM; i++) {
int axis = KokkosFFT::Impl::convert_negative_axis(inout, _axes.at(i));
axes.push_back(axis);
}

// Assert if the elements are overlapped
constexpr int rank = ViewType::rank();
assert( ! KokkosFFT::Impl::has_duplicate_values(axes) );
assert( ! KokkosFFT::Impl::is_out_of_range_value_included(axes, rank) );

axis_type<rank> shift = {0};
for(int i=0; i<DIM; i++) {
int axis = axes.at(i);
shift.at(axis) = inout.extent(axis) / 2 * direction;
}
return shift;
}

template <typename ExecutionSpace, typename ViewType>
void _roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<1> shift, axis_type<1> axes) {
static_assert(ViewType::rank() == 1,
"KokkosFFT::Impl::_roll: Rank of View must be 1.");
std::size_t n0 = inout.extent(0);

ViewType tmp("tmp", n0);
std::size_t len = (n0-1) / 2 + 1;

auto [_shift0, _shift1, _shift2] = KokkosFFT::Impl::convert_negative_shift(inout, shift.at(0), 0);
int shift0 = _shift0, shift1 = _shift1, shift2 = _shift2;

// shift2 == 0 means shift
if(shift2 == 0) {
Kokkos::parallel_for(
Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<std::size_t>>(exec_space, 0, len),
KOKKOS_LAMBDA(const int& i) {
tmp(i+shift0) = inout(i);
if(i+shift1<n0) {
tmp(i) = inout(i+shift1);
}
}
);

inout = tmp;
}
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM1=1>
void _roll(const ExecutionSpace& exec_space, ViewType& inout, axis_type<2> shift, axis_type<DIM1> axes) {
constexpr std::size_t DIM0 = 2;
static_assert(ViewType::rank() == DIM0,
"KokkosFFT::Impl::_roll: Rank of View must be 2.");
int n0 = inout.extent(0), n1 = inout.extent(1);

ViewType tmp("tmp", n0, n1);
int len0 = (n0-1) / 2 + 1;
int len1 = (n1-1) / 2 + 1;

using range_type = Kokkos::MDRangePolicy<ExecutionSpace, Kokkos::Rank<2, Kokkos::Iterate::Default, Kokkos::Iterate::Default> >;
using tile_type = typename range_type::tile_type;
using point_type = typename range_type::point_type;

range_type range(
point_type{{0, 0}},
point_type{{len0, len1}},
tile_type{{4, 4}} // [TO DO] Choose optimal tile sizes for each device
);

axis_type<2> shift0 = {0}, shift1 = {0}, shift2 = {n0/2, n1/2};
for(int i=0; i<DIM1; i++) {
int axis = axes.at(i);

auto [_shift0, _shift1, _shift2] = KokkosFFT::Impl::convert_negative_shift(inout, shift.at(axis), axis);
shift0.at(axis) = _shift0;
shift1.at(axis) = _shift1;
shift2.at(axis) = _shift2;
}

int shift_00 = shift0.at(0), shift_10 = shift0.at(1);
int shift_01 = shift1.at(0), shift_11 = shift1.at(1);
int shift_02 = shift2.at(0), shift_12 = shift2.at(1);

Kokkos::parallel_for(range,
KOKKOS_LAMBDA (int i0, int i1) {
if(i0+shift_00<n0 && i1+shift_10<n1) {
tmp(i0+shift_00, i1+shift_10) = inout(i0, i1);
}

if(i0+shift_01<n0 && i1+shift_11<n1) {
tmp(i0, i1) = inout(i0+shift_01, i1+shift_11);
}

if(i0+shift_01<n0 && i1+shift_10<n1) {
tmp(i0+shift_02, i1+shift_10+shift_12) = inout(i0+shift_01+shift_02, i1+shift_12);
}

if(i0+shift_00<n0 && i1+shift_11<n1) {
tmp(i0+shift_00+shift_02, i1+shift_12) = inout(i0+shift_02, i1+shift_11+shift_12);
}
}
);

inout = tmp;
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM=1>
void _fftshift(const ExecutionSpace& exec_space, ViewType& inout, axis_type<DIM> axes) {
static_assert(ViewType::rank() >= DIM,
"KokkosFFT::Impl::_fftshift: Rank of View must be larger thane or equal to the Rank of shift axes.");
auto shift = _get_shift(inout, axes);
_roll(exec_space, inout, shift, axes);
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM=1>
void _ifftshift(const ExecutionSpace& exec_space, ViewType& inout, axis_type<DIM> axes) {
static_assert(ViewType::rank() >= DIM,
"KokkosFFT::Impl::_ifftshift: Rank of View must be larger thane or equal to the Rank of shift axes.");
auto shift = _get_shift(inout, axes, -1);
_roll(exec_space, inout, shift, axes);
}
} // namespace Impl
} // namespace KokkosFFT

namespace KokkosFFT {
template <typename ExecutionSpace, typename RealType>
auto fftfreq(const ExecutionSpace& exec_space,const std::size_t n, const RealType d = 1.0) {
static_assert(std::is_floating_point<RealType>::value,
"KokkosFFT::fftfreq: d must be real");
using ViewType = Kokkos::View<RealType*, ExecutionSpace>;
ViewType freq("freq", n);

RealType val = 1.0 / ( static_cast<RealType>(n) * d );
int N1 = (n - 1) / 2 + 1;
int N2 = n/2;

auto h_freq = Kokkos::create_mirror_view(freq);

auto p1 = KokkosFFT::Impl::arange(0, N1);
auto p2 = KokkosFFT::Impl::arange(-N2, 0);

for(int i=0; i<N1; i++) { h_freq(i) = static_cast<RealType>( p1.at(i) ) * val; }
for(int i=0; i<N2; i++) { h_freq(i+N1) = static_cast<RealType>( p2.at(i) ) * val; }
Kokkos::deep_copy(freq, h_freq);

return freq;
}

template <typename ExecutionSpace, typename RealType>
auto rfftfreq(const ExecutionSpace& exec_space,const std::size_t n, const RealType d = 1.0) {
static_assert(std::is_floating_point<RealType>::value,
"KokkosFFT::fftfreq: d must be real");
using ViewType = Kokkos::View<RealType*, ExecutionSpace>;
ViewType freq("freq", n);

RealType val = 1.0 / ( static_cast<RealType>(n) * d );
int N = n/2 + 1;

auto h_freq = Kokkos::create_mirror_view(freq);
auto p = KokkosFFT::Impl::arange(0, N);

for(int i=0; i<N; i++) { h_freq(i) = static_cast<RealType>( p.at(i) ) * val; }
Kokkos::deep_copy(freq, h_freq);

return freq;
}

template <typename ExecutionSpace, typename ViewType>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout) {
constexpr std::size_t rank = ViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = KokkosFFT::Impl::index_sequence<rank>(start);
KokkosFFT::Impl::_fftshift(exec_space, inout, axes);
}

template <typename ExecutionSpace, typename ViewType>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout, int axes) {
KokkosFFT::Impl::_fftshift(exec_space, inout, axis_type<1>{axes});
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM=1>
void fftshift(const ExecutionSpace& exec_space, ViewType& inout, axis_type<DIM> axes) {
KokkosFFT::Impl::_fftshift(exec_space, inout, axes);
}

template <typename ExecutionSpace, typename ViewType>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout) {
constexpr std::size_t rank = ViewType::rank();
constexpr int start = -static_cast<int>(rank);
axis_type<rank> axes = KokkosFFT::Impl::index_sequence<rank>(start);
KokkosFFT::Impl::_ifftshift(exec_space, inout, axes);
}

template <typename ExecutionSpace, typename ViewType>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout, int axes) {
KokkosFFT::Impl::_ifftshift(exec_space, inout, axis_type<1>{axes});
}

template <typename ExecutionSpace, typename ViewType, std::size_t DIM=1>
void ifftshift(const ExecutionSpace& exec_space, ViewType& inout, axis_type<DIM> axes) {
KokkosFFT::Impl::_ifftshift(exec_space, inout, axes);
}
} // namespace KokkosFFT

#endif
1 change: 1 addition & 0 deletions fft/unit_test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ add_executable(unit-tests-kokkos-fft-core
Test_Main.cpp
Test_Plans.cpp
Test_Transform.cpp
Test_Helpers.cpp
)

target_compile_features(unit-tests-kokkos-fft-core PUBLIC cxx_std_17)
392 changes: 392 additions & 0 deletions fft/unit_test/Test_Helpers.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,392 @@
#include <gtest/gtest.h>
#include <vector>
#include <Kokkos_Random.hpp>
#include "KokkosFFT_Helpers.hpp"
#include "Test_Types.hpp"
#include "Test_Utils.hpp"

template <std::size_t DIM>
using axes_type = std::array<int, DIM>;

using test_types = ::testing::Types<
std::pair<float, Kokkos::LayoutLeft>,
std::pair<float, Kokkos::LayoutRight>,
std::pair<double, Kokkos::LayoutLeft>,
std::pair<double, Kokkos::LayoutRight>
>;

// Basically the same fixtures, used for labeling tests
template <typename T>
struct FFTHelper : public ::testing::Test {
using float_type = typename T::first_type;
using layout_type = typename T::second_type;
};

TYPED_TEST_SUITE(FFTHelper, test_types);

// Tests for FFT Freq
template <typename T, typename LayoutType>
void test_fft_freq(T atol=1.0e-12) {
constexpr std::size_t n_odd = 9, n_even = 10;
using RealView1DType = Kokkos::View<T*, LayoutType, execution_space>;
RealView1DType x_odd_ref("x_odd_ref", n_odd), x_even_ref("x_even_ref", n_even);

auto h_x_odd_ref = Kokkos::create_mirror_view(x_odd_ref);
auto h_x_even_ref = Kokkos::create_mirror_view(x_even_ref);

std::vector<int> _x_odd_ref = {0, 1, 2, 3, 4, -4, -3, -2, -1};
std::vector<int> _x_even_ref = {0, 1, 2, 3, 4, -5, -4, -3, -2, -1};

for(std::size_t i=0; i<_x_odd_ref.size(); i++) {
h_x_odd_ref(i) = static_cast<T>( _x_odd_ref.at(i) );
}

for(std::size_t i=0; i<_x_even_ref.size(); i++) {
h_x_even_ref(i) = static_cast<T>( _x_even_ref.at(i) );
}

Kokkos::deep_copy(x_odd_ref, h_x_odd_ref);
Kokkos::deep_copy(x_even_ref, h_x_even_ref);
T pi = static_cast<T>(M_PI);
auto x_odd = KokkosFFT::fftfreq<execution_space, T>(execution_space(), n_odd);
auto x_odd_pi = KokkosFFT::fftfreq<execution_space, T>(execution_space(), n_odd, pi);
multiply(x_odd, static_cast<T>(n_odd));
multiply(x_odd_pi, static_cast<T>(n_odd)*pi);

EXPECT_TRUE( allclose(x_odd, x_odd_ref, 1.e-5, atol) );
EXPECT_TRUE( allclose(x_odd_pi, x_odd_ref, 1.e-5, atol) );

auto x_even = KokkosFFT::fftfreq<execution_space, T>(execution_space(), n_even);
auto x_even_pi = KokkosFFT::fftfreq<execution_space, T>(execution_space(), n_even, pi);
multiply(x_even, static_cast<T>(n_even));
multiply(x_even_pi, static_cast<T>(n_even) * pi);

EXPECT_TRUE( allclose(x_even, x_even_ref, 1.e-5, atol) );
EXPECT_TRUE( allclose(x_even_pi, x_even_ref, 1.e-5, atol) );
}

// Tests for RFFT Freq
template <typename T, typename LayoutType>
void test_rfft_freq(T atol=1.0e-12) {
constexpr std::size_t n_odd = 9, n_even = 10;
using RealView1DType = Kokkos::View<T*, LayoutType, execution_space>;
RealView1DType x_odd_ref("x_odd_ref", n_odd/2), x_even_ref("x_even_ref", n_even/2);

auto h_x_odd_ref = Kokkos::create_mirror_view(x_odd_ref);
auto h_x_even_ref = Kokkos::create_mirror_view(x_even_ref);

std::vector<int> _x_odd_ref = {0, 1, 2, 3, 4};
std::vector<int> _x_even_ref = {0, 1, 2, 3, 4, 5};

for(std::size_t i=0; i<_x_odd_ref.size(); i++) {
h_x_odd_ref(i) = static_cast<T>( _x_odd_ref.at(i) );
}

for(std::size_t i=0; i<_x_even_ref.size(); i++) {
h_x_even_ref(i) = static_cast<T>( _x_even_ref.at(i) );
}

Kokkos::deep_copy(x_odd_ref, h_x_odd_ref);
Kokkos::deep_copy(x_even_ref, h_x_even_ref);
T pi = static_cast<T>(M_PI);
auto x_odd = KokkosFFT::rfftfreq<execution_space, T>(execution_space(), n_odd);
auto x_odd_pi = KokkosFFT::rfftfreq<execution_space, T>(execution_space(), n_odd, pi);
multiply(x_odd, static_cast<T>(n_odd));
multiply(x_odd_pi, static_cast<T>(n_odd)*pi);

EXPECT_TRUE( allclose(x_odd, x_odd_ref, 1.e-5, atol) );
EXPECT_TRUE( allclose(x_odd_pi, x_odd_ref, 1.e-5, atol) );

auto x_even = KokkosFFT::rfftfreq<execution_space, T>(execution_space(), n_even);
auto x_even_pi = KokkosFFT::rfftfreq<execution_space, T>(execution_space(), n_even, pi);
multiply(x_even, static_cast<T>(n_even));
multiply(x_even_pi, static_cast<T>(n_even) * pi);

EXPECT_TRUE( allclose(x_even, x_even_ref, 1.e-5, atol) );
EXPECT_TRUE( allclose(x_even_pi, x_even_ref, 1.e-5, atol) );
}

// Tests for fftfreq
TYPED_TEST(FFTHelper, fftfreq) {
using float_type = typename TestFixture::float_type;
using layout_type = typename TestFixture::layout_type;

float_type atol = std::is_same_v<float_type, float> ? 1.0e-6 : 1.0e-12;
test_fft_freq<float_type, layout_type>(atol);
}

// Tests for rfftfreq
TYPED_TEST(FFTHelper, rfftfreq) {
using float_type = typename TestFixture::float_type;
using layout_type = typename TestFixture::layout_type;

float_type atol = std::is_same_v<float_type, float> ? 1.0e-6 : 1.0e-12;
test_rfft_freq<float_type, layout_type>(atol);
}

// Tests for get shift
void test_get_shift(int direction) {
constexpr int n_odd = 9, n_even = 10, n2 = 8;
using RealView1DType = Kokkos::View<double*, execution_space>;
using RealView2DType = Kokkos::View<double**, execution_space>;
RealView1DType x1_odd("x1_odd", n_odd), x1_even("x1_even", n_even);
RealView2DType x2_odd("x2_odd", n_odd, n2), x2_even("x2_even", n_even, n2);

KokkosFFT::axis_type<1> shift1_odd_ref = {direction * n_odd/2};
KokkosFFT::axis_type<1> shift1_even_ref = {direction * n_even/2};
KokkosFFT::axis_type<2> shift1_axis0_odd_ref = {direction * n_odd/2, 0};
KokkosFFT::axis_type<2> shift1_axis0_even_ref = {direction * n_even/2, 0};
KokkosFFT::axis_type<2> shift1_axis1_odd_ref = {0, direction * n2/2};
KokkosFFT::axis_type<2> shift1_axis1_even_ref = {0, direction * n2/2};
KokkosFFT::axis_type<2> shift2_odd_ref = {direction * n_odd/2, direction * n2/2};
KokkosFFT::axis_type<2> shift2_even_ref = {direction * n_even/2, direction * n2/2};

auto shift1_odd = KokkosFFT::Impl::_get_shift(x1_odd, KokkosFFT::axis_type<1>({0}), direction);
auto shift1_even = KokkosFFT::Impl::_get_shift(x1_even, KokkosFFT::axis_type<1>({0}), direction);
auto shift1_axis0_odd = KokkosFFT::Impl::_get_shift(x2_odd, KokkosFFT::axis_type<1>({0}), direction);
auto shift1_axis0_even = KokkosFFT::Impl::_get_shift(x2_even, KokkosFFT::axis_type<1>({0}), direction);
auto shift1_axis1_odd = KokkosFFT::Impl::_get_shift(x2_odd, KokkosFFT::axis_type<1>({1}), direction);
auto shift1_axis1_even = KokkosFFT::Impl::_get_shift(x2_even, KokkosFFT::axis_type<1>({1}), direction);
auto shift2_odd = KokkosFFT::Impl::_get_shift(x2_odd, KokkosFFT::axis_type<2>({0, 1}), direction);
auto shift2_even = KokkosFFT::Impl::_get_shift(x2_even, KokkosFFT::axis_type<2>({0, 1}), direction);

EXPECT_TRUE( shift1_odd == shift1_odd_ref );
EXPECT_TRUE( shift1_even == shift1_even_ref );
EXPECT_TRUE( shift1_axis0_odd == shift1_axis0_odd_ref );
EXPECT_TRUE( shift1_axis0_even == shift1_axis0_even_ref );
EXPECT_TRUE( shift1_axis1_odd == shift1_axis1_odd_ref );
EXPECT_TRUE( shift1_axis1_even == shift1_axis1_even_ref );
EXPECT_TRUE( shift2_odd == shift2_odd_ref );
EXPECT_TRUE( shift2_even == shift2_even_ref );
}

class GetShiftParamTests: public ::testing::TestWithParam<int> {};

TEST_P(GetShiftParamTests, ForwardAndInverse) {
int direction = GetParam();
test_get_shift(direction);
}

INSTANTIATE_TEST_SUITE_P(
GetShift,
GetShiftParamTests,
::testing::Values(
1, -1
)
);

// Identity Tests for fftshift1D on 1D View
void test_fftshift1D_1DView_identity(int n0) {
using RealView1DType = Kokkos::View<double*, execution_space>;

RealView1DType x("x", n0), x_ref("x_ref", n0);

Kokkos::Random_XorShift64_Pool<> random_pool(/*seed=*/12345);
Kokkos::fill_random(x, random_pool, 1.0);
Kokkos::deep_copy(x_ref, x);

Kokkos::fence();

KokkosFFT::fftshift(execution_space(), x);
KokkosFFT::ifftshift(execution_space(), x);

EXPECT_TRUE( allclose(x, x_ref, 1.e-5, 1.e-12) );
}

// Tests for fftshift1D on 1D View
void test_fftshift1D_1DView(int n0) {
using RealView1DType = Kokkos::View<double*, execution_space>;
RealView1DType x("x", n0), y("y", n0);
RealView1DType x_ref("x_ref", n0), y_ref("y_ref", n0);

auto h_x_ref = Kokkos::create_mirror_view(x_ref);
auto h_y_ref = Kokkos::create_mirror_view(y_ref);

std::vector<int> _x_ref;
std::vector<int> _y_ref;

if(n0 % 2 == 0) {
_x_ref = {0, 1, 2, 3, 4, -5, -4, -3, -2, -1};
_y_ref = {-5, -4, -3, -2, -1, 0, 1, 2, 3, 4};
} else {
_x_ref = {0, 1, 2, 3, 4, -4, -3, -2, -1};
_y_ref = {-4, -3, -2, -1, 0, 1, 2, 3, 4};
}

for(std::size_t i=0; i<n0; i++) {
h_x_ref(i) = static_cast<double>( _x_ref.at(i) );
h_y_ref(i) = static_cast<double>( _y_ref.at(i) );
}

Kokkos::deep_copy(x_ref, h_x_ref);
Kokkos::deep_copy(y_ref, h_y_ref);
Kokkos::deep_copy(x, h_x_ref);
Kokkos::deep_copy(y, h_y_ref);

KokkosFFT::fftshift(execution_space(), x);
KokkosFFT::ifftshift(execution_space(), y);

EXPECT_TRUE( allclose(x, y_ref) );
EXPECT_TRUE( allclose(y, x_ref) );
}

// Tests for fftshift1D on 2D View
void test_fftshift1D_2DView(int n0) {
using RealView2DType = Kokkos::View<double**, Kokkos::LayoutLeft, execution_space>;
constexpr int n1 = 3;
RealView2DType x("x", n0, n1), y_axis0("y_axis0", n0, n1), y_axis1("y_axis1", n0, n1);
RealView2DType x_ref("x_ref", n0, n1);
RealView2DType y_axis0_ref("y_axis0_ref", n0, n1), y_axis1_ref("y_axis1_ref", n0, n1);

auto h_x_ref = Kokkos::create_mirror_view(x_ref);
auto h_y_axis0_ref = Kokkos::create_mirror_view(y_axis0_ref);
auto h_y_axis1_ref = Kokkos::create_mirror_view(y_axis1_ref);

std::vector<int> _x_ref;
std::vector<int> _y0_ref, _y1_ref;

if(n0 % 2 == 0) {
_x_ref = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, -15, -14, -13, -12, -11,
-10, -9, -8, -7, -6, -5, -4, -3, -2, -1
};
_y0_ref = {5, 6, 7, 8, 9, 0, 1, 2, 3, 4,
-15, -14, -13, -12, -11, 10, 11, 12, 13, 14,
-5, -4, -3, -2, -1, -10, -9, -8, -7, -6,
};
_y1_ref = {-10, -9, -8, -7, -6, -5, -4, -3, -2, -1,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, -15, -14, -13, -12, -11
};
} else {
_x_ref = {0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, -13, -12, -11, -10,
-9, -8, -7, -6, -5, -4, -3, -2, -1
};
_y0_ref = {5, 6, 7, 8, 0, 1, 2, 3, 4,
-13, -12, -11, -10, 9, 10, 11, 12, 13,
-4, -3, -2, -1, -9, -8, -7, -6, -5
};
_y1_ref = {-9, -8, -7, -6, -5, -4, -3, -2, -1,
0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, -13, -12, -11, -10
};
}

for(std::size_t i1=0; i1<n1; i1++) {
for(std::size_t i0=0; i0<n0; i0++) {
std::size_t i = i0 + i1 * n0;
h_x_ref(i0, i1) = static_cast<double>( _x_ref.at(i) );
h_y_axis0_ref(i0, i1) = static_cast<double>( _y0_ref.at(i) );
h_y_axis1_ref(i0, i1) = static_cast<double>( _y1_ref.at(i) );
}
}

Kokkos::deep_copy(x_ref, h_x_ref);
Kokkos::deep_copy(y_axis0_ref, h_y_axis0_ref);
Kokkos::deep_copy(y_axis1_ref, h_y_axis1_ref);
Kokkos::deep_copy(x, h_x_ref);
Kokkos::deep_copy(y_axis0, h_y_axis0_ref);
Kokkos::deep_copy(y_axis1, h_y_axis1_ref);

KokkosFFT::fftshift(execution_space(), x, axes_type<1>({0}));
KokkosFFT::ifftshift(execution_space(), y_axis0, axes_type<1>({0}));

EXPECT_TRUE( allclose(x, y_axis0_ref) );
EXPECT_TRUE( allclose(y_axis0, x_ref) );

Kokkos::deep_copy(x, h_x_ref);

KokkosFFT::fftshift(execution_space(), x, axes_type<1>({1}));
KokkosFFT::ifftshift(execution_space(), y_axis1, axes_type<1>({1}));

EXPECT_TRUE( allclose(x, y_axis1_ref) );
EXPECT_TRUE( allclose(y_axis1, x_ref) );
}

// Tests for fftshift2D on 2D View
void test_fftshift2D_2DView(int n0) {
using RealView2DType = Kokkos::View<double**, Kokkos::LayoutLeft, execution_space>;
constexpr int n1 = 3;
RealView2DType x("x", n0, n1), y("y", n0, n1);
RealView2DType x_ref("x_ref", n0, n1), y_ref("y_ref", n0, n1);

auto h_x_ref = Kokkos::create_mirror_view(x_ref);
auto h_y_ref = Kokkos::create_mirror_view(y_ref);

std::vector<int> _x_ref;
std::vector<int> _y_ref;

if(n0 % 2 == 0) {
_x_ref = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, -15, -14, -13, -12, -11,
-10, -9, -8, -7, -6, -5, -4, -3, -2, -1
};
_y_ref = {-5, -4, -3, -2, -1, -10, -9, -8, -7, -6,
5, 6, 7, 8, 9, 0, 1, 2, 3, 4,
-15, -14, -13, -12, -11, 10, 11, 12, 13, 14
};
} else {
_x_ref = {0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, -13, -12, -11, -10,
-9, -8, -7, -6, -5, -4, -3, -2, -1
};
_y_ref = {-4, -3, -2, -1, -9, -8, -7, -6, -5,
5, 6, 7, 8, 0, 1, 2, 3, 4,
-13, -12, -11, -10, 9, 10, 11, 12, 13
};
}

for(std::size_t i1=0; i1<n1; i1++) {
for(std::size_t i0=0; i0<n0; i0++) {
std::size_t i = i0 + i1 * n0;
h_x_ref(i0, i1) = static_cast<double>( _x_ref.at(i) );
h_y_ref(i0, i1) = static_cast<double>( _y_ref.at(i) );
}
}

Kokkos::deep_copy(x_ref, h_x_ref);
Kokkos::deep_copy(y_ref, h_y_ref);
Kokkos::deep_copy(x, h_x_ref);
Kokkos::deep_copy(y, h_y_ref);

KokkosFFT::fftshift(execution_space(), x, axes_type<2>({0, 1}));
KokkosFFT::ifftshift(execution_space(), y, axes_type<2>({0, 1}));

EXPECT_TRUE( allclose(x, y_ref) );
EXPECT_TRUE( allclose(y, x_ref) );
}

class FFTShiftParamTests: public ::testing::TestWithParam<int> {};

// Identity Tests for fftshift1D on 1D View
TEST_P(FFTShiftParamTests, Identity) {
int n0 = GetParam();
test_fftshift1D_1DView_identity(n0);
}

// Tests for fftshift1D on 1D View
TEST_P(FFTShiftParamTests, 1DShift1DView) {
int n0 = GetParam();
test_fftshift1D_1DView(n0);
}

// Tests for fftshift1D on 2D View
TEST_P(FFTShiftParamTests, 1DShift2DView) {
int n0 = GetParam();
test_fftshift1D_2DView(n0);
}

// Tests for fftshift2D on 2D View
TEST_P(FFTShiftParamTests, 2DShift2DView) {
int n0 = GetParam();
test_fftshift2D_2DView(n0);
}

INSTANTIATE_TEST_SUITE_P(
FFTShift,
FFTShiftParamTests,
::testing::Values(
9, 10
)
);