Skip to content

Commit

Permalink
Apply check functions to common functions (#129)
Browse files Browse the repository at this point in the history
* fix: conflicts

* format

* ix: based on reviews

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Sep 6, 2024
1 parent a48cf5a commit 0dfbae8
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 192 deletions.
41 changes: 15 additions & 26 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,8 @@ auto get_extents(const InViewType& in, const OutViewType& out,
using out_value_type = typename OutViewType::non_const_value_type;
using array_layout_type = typename InViewType::array_layout;

static_assert(InViewType::rank() >= DIM,
"KokkosFFT::get_map_axes: Rank of View must be larger thane or "
"equal to the Rank of FFT axes.");
static_assert(DIM > 0,
"KokkosFFT::get_map_axes: Rank of FFT axes must be larger than "
"or equal to 1.");
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

constexpr std::size_t rank = InViewType::rank;
[[maybe_unused]] int inner_most_axis =
Expand Down Expand Up @@ -64,32 +60,25 @@ auto get_extents(const InViewType& in, const OutViewType& out,
_fft_extents.push_back(fft_extent);
}

static_assert(!(is_real_v<in_value_type> && is_real_v<out_value_type>),
"get_extents: real to real transform is not supported");

if (is_real_v<in_value_type>) {
// Then R2C
if (is_complex_v<out_value_type>) {
KOKKOSFFT_EXPECTS(
_out_extents.at(inner_most_axis) ==
_in_extents.at(inner_most_axis) / 2 + 1,
"For R2C, the 'output extent' of transform must be equal to "
"'input extent'/2 + 1");
} else {
throw std::runtime_error(
"If the input type is real, the output type should be complex");
}
KOKKOSFFT_EXPECTS(
_out_extents.at(inner_most_axis) ==
_in_extents.at(inner_most_axis) / 2 + 1,
"For R2C, the 'output extent' of transform must be equal to "
"'input extent'/2 + 1");
}

if (is_real_v<out_value_type>) {
// Then C2R
if (is_complex_v<in_value_type>) {
KOKKOSFFT_EXPECTS(
_in_extents.at(inner_most_axis) ==
_out_extents.at(inner_most_axis) / 2 + 1,
"For C2R, the 'input extent' of transform must be equal to "
"'output extent' / 2 + 1");
} else {
throw std::runtime_error(
"If the output type is real, the input type should be complex");
}
KOKKOSFFT_EXPECTS(
_in_extents.at(inner_most_axis) ==
_out_extents.at(inner_most_axis) / 2 + 1,
"For C2R, the 'input extent' of transform must be equal to "
"'output extent' / 2 + 1");
}

if (std::is_same_v<array_layout_type, Kokkos::LayoutLeft>) {
Expand Down
5 changes: 5 additions & 0 deletions common/src/KokkosFFT_normalization.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ template <typename ExecutionSpace, typename ViewType>
void normalize(const ExecutionSpace& exec_space, ViewType& inout,
Direction direction, Normalization normalization,
std::size_t fft_size) {
static_assert(KokkosFFT::Impl::is_operatable_view_v<ExecutionSpace, ViewType>,
"normalize: View value type must be float, double, "
"Kokkos::Complex<float>, or Kokkos::Complex<double>. "
"Layout must be either LayoutLeft or LayoutRight. "
"ExecutionSpace must be able to access data in ViewType");
auto [coef, to_normalize] =
get_coefficients(inout, direction, normalization, fft_size);
if (to_normalize) normalize_impl(exec_space, inout, coef);
Expand Down
38 changes: 14 additions & 24 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,11 @@ namespace Impl {
template <typename InViewType, typename OutViewType, std::size_t DIM>
auto get_modified_shape(const InViewType in, const OutViewType /* out */,
shape_type<DIM> shape, axis_type<DIM> axes) {
static_assert(InViewType::rank() >= DIM,
"get_modified_shape: Rank of Input View must be larger "
"than or equal to the Rank of new shape");
static_assert(OutViewType::rank() >= DIM,
"get_modified_shape: Rank of Output View must be larger "
"than or equal to the Rank of new shape");
static_assert(DIM > 0,
"get_modified_shape: Rank of FFT axes must be "
"larger than or equal to 1");
constexpr int rank = static_cast<int>(InViewType::rank());
static_assert(
KokkosFFT::Impl::have_same_rank_v<InViewType, OutViewType>,
"get_modified_shape: Input View and Output View must have the same rank");
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(in, axes),
"input axes are not valid for the view");

shape_type<DIM> zeros = {0}; // default shape means no crop or pad
if (shape == zeros) {
Expand All @@ -50,14 +45,7 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */,
positive_axes.push_back(axis);
}

// Assert if the elements are overlapped
KOKKOSFFT_EXPECTS(!KokkosFFT::Impl::has_duplicate_values(positive_axes),
"Axes overlap");
KOKKOSFFT_EXPECTS(
!KokkosFFT::Impl::is_out_of_range_value_included(positive_axes, rank),
"Axes include an out-of-range index."
"Axes must be in the range of [-rank, rank-1].");

constexpr int rank = static_cast<int>(InViewType::rank());
using full_shape_type = shape_type<rank>;
full_shape_type modified_shape;
for (int i = 0; i < rank; i++) {
Expand Down Expand Up @@ -346,12 +334,14 @@ template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void crop_or_pad(const ExecutionSpace& exec_space, const InViewType& in,
OutViewType& out, shape_type<DIM> s) {
static_assert(InViewType::rank() == DIM,
"crop_or_pad: Rank of View must be equal to Rank "
"of extended shape.");
static_assert(OutViewType::rank() == DIM,
"crop_or_pad: Rank of View must be equal to Rank "
"of extended shape.");
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"crop_or_pad: InViewType and OutViewType must have the same base "
"floating point "
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");
crop_or_pad_impl(exec_space, in, out, s);
}
} // namespace Impl
Expand Down
38 changes: 14 additions & 24 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,8 @@ namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_map_axes(const ViewType& view, axis_type<DIM> _axes) {
static_assert(ViewType::rank() >= DIM,
"get_map_axes: Rank of View must be larger thane or "
"equal to the Rank of FFT axes.");
static_assert(DIM > 0,
"get_map_axes: Rank of FFT axes must be larger than "
"or equal to 1.");

constexpr int rank = static_cast<int>(ViewType::rank());
using array_layout_type = typename ViewType::array_layout;
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::are_valid_axes(view, _axes),
"get_map_axes: input axes are not valid for the view");

// Convert the input axes to be in the range of [0, rank-1]
std::vector<int> axes;
Expand All @@ -31,16 +24,14 @@ auto get_map_axes(const ViewType& view, axis_type<DIM> _axes) {
axes.push_back(axis);
}

// Assert if the elements are overlapped
assert(!KokkosFFT::Impl::has_duplicate_values(axes));

// how indices are map
// For 5D View and axes are (2,3), map would be (0, 1, 4, 2, 3)
constexpr int rank = static_cast<int>(ViewType::rank());
std::vector<int> map, map_inv;
map.reserve(rank);
map_inv.reserve(rank);

if (std::is_same_v<array_layout_type, Kokkos::LayoutRight>) {
if (std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>) {
// Stack axes not specified by axes (0, 1, 4)
for (int i = 0; i < rank; i++) {
if (!is_found(axes, i)) {
Expand Down Expand Up @@ -396,22 +387,21 @@ template <typename ExecutionSpace, typename InViewType, typename OutViewType,
std::size_t DIM = 1>
void transpose(const ExecutionSpace& exec_space, InViewType& in,
OutViewType& out, axis_type<DIM> map) {
static_assert(Kokkos::is_view<InViewType>::value,
"transpose: InViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<InViewType>::value,
"transpose: OutViewType is not a Kokkos::View.");

static_assert(InViewType::rank() == OutViewType::rank(),
"transpose: InViewType and OutViewType must have "
"the same rank.");
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"transpose: InViewType and OutViewType must have the same base floating "
"point "
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");

static_assert(InViewType::rank() == DIM,
"transpose: Rank of View must be equal to Rank of "
"transpose axes.");

if (!KokkosFFT::Impl::is_transpose_needed(map)) {
throw std::runtime_error("transpose: transpose not necessary");
}
KOKKOSFFT_EXPECTS(KokkosFFT::Impl::is_transpose_needed(map),
"transpose: transpose not necessary");

// in order not to call transpose_impl for 1D case
if constexpr (DIM > 1) {
Expand Down
Loading

0 comments on commit 0dfbae8

Please sign in to comment.