Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve check functions #125

Merged
merged 6 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions common/src/KokkosFFT_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ inline constexpr bool are_operatable_views_v =

// Other traits

template <typename ContainerType>
struct base_container_value;

template <template <typename, typename...> class ContainerType,
typename ValueType, typename... Args>
struct base_container_value<ContainerType<ValueType, Args...>> {
using value_type = ValueType;
};

// Specialization for std::array
template <typename ValueType, std::size_t N>
struct base_container_value<std::array<ValueType, N>> {
using value_type = ValueType;
};

// Specialization for Kokkos::Array
template <typename ValueType, std::size_t N>
struct base_container_value<Kokkos::Array<ValueType, N>> {
using value_type = ValueType;
};

/// \brief Helper to extract the base value type from a container
template <typename T>
using base_container_value_type = typename base_container_value<T>::value_type;

/// \brief Helper to define a managable View type from the original view type
template <typename T>
struct managable_view_type {
Expand Down
72 changes: 58 additions & 14 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <algorithm>
#include <numeric>
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_common_types.hpp"

#if defined(KOKKOS_ENABLE_CXX17)
#include <cstdlib>
Expand Down Expand Up @@ -85,43 +86,86 @@ auto convert_negative_shift(const ViewType& view, int _shift, int _axis) {
return std::tuple<int, int, int>({shift0, shift1, shift2});
}

template <typename T>
bool is_found(std::vector<T>& values, const T& key) {
return std::find(values.begin(), values.end(), key) != values.end();
template <typename ContainerType, typename ValueType>
bool is_found(ContainerType& values, const ValueType& value) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, ValueType>,
"Container value type must match ValueType");
return std::find(values.begin(), values.end(), value) != values.end();
}

template <typename T>
bool has_duplicate_values(const std::vector<T>& values) {
std::set<T> set_values(values.begin(), values.end());
template <typename ContainerType>
bool has_duplicate_values(const ContainerType& values) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
std::set<value_type> set_values(values.begin(), values.end());
return set_values.size() < values.size();
}

template <typename IntType, std::enable_if_t<std::is_integral_v<IntType>,
std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const std::vector<IntType>& values,
IntType max) {
template <
typename ContainerType, typename IntType,
std::enable_if_t<std::is_integral_v<IntType>, std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const ContainerType& values, IntType max) {
yasahi-hpc marked this conversation as resolved.
Show resolved Hide resolved
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, IntType>,
"Container value type must match IntType");
bool is_included = false;
for (auto value : values) {
is_included = value >= max;
}
return is_included;
}

template <
typename ViewType, template <typename, std::size_t> class ArrayType,
typename IntType, std::size_t DIM = 1,
std::enable_if_t<std::is_integral_v<IntType>, std::nullptr_t> = nullptr>
bool are_valid_axes(const ViewType& view, const ArrayType<IntType, DIM>& axes) {
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"are_valid_axes: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"are_valid_axes: View rank must be larger than or equal to the "
"Rank of FFT axes");

// Convert the input axes to be in the range of [0, rank-1]
yasahi-hpc marked this conversation as resolved.
Show resolved Hide resolved
// int type is choosen for consistency with the rest of the code
// the axes are defined with int type
std::array<int, DIM> non_negative_axes;

// In case axis is out of range, 'convert_negative_axis' will throw an
// runtime_error and we will return false. Without runtime_error, it is
// ensured that the 'non_negative_axes' are in the range of [0, rank-1]
try {
for (std::size_t i = 0; i < DIM; i++) {
int axis = KokkosFFT::Impl::convert_negative_axis(view, axes[i]);
non_negative_axes[i] = axis;
}
yasahi-hpc marked this conversation as resolved.
Show resolved Hide resolved
} catch (std::runtime_error& e) {
return false;
}

bool is_valid = !KokkosFFT::Impl::has_duplicate_values(non_negative_axes);
return is_valid;
}

template <std::size_t DIM = 1>
bool is_transpose_needed(std::array<int, DIM> map) {
std::array<int, DIM> contiguous_map;
std::iota(contiguous_map.begin(), contiguous_map.end(), 0);
return map != contiguous_map;
}

template <typename T>
std::size_t get_index(std::vector<T>& values, const T& key) {
auto it = find(values.begin(), values.end(), key);
template <typename ContainerType, typename ValueType>
std::size_t get_index(ContainerType& values, const ValueType& value) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, ValueType>,
"Container value type must match ValueType");
auto it = std::find(values.begin(), values.end(), value);
std::size_t index = 0;
if (it != values.end()) {
index = it - values.begin();
} else {
throw std::runtime_error("key is not included in values");
throw std::runtime_error("value is not included in values");
}

return index;
Expand Down
49 changes: 49 additions & 0 deletions common/unit_test/Test_Traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
// All the tests in this file are compile time tests, so we skip all the tests
// by GTEST_SKIP(). gtest is used for type parameterization.

// Int like types
using base_int_types = ::testing::Types<int, std::size_t>;

// Define the types to combine
using base_real_types = std::tuple<float, double, long double>;

Expand Down Expand Up @@ -39,6 +42,19 @@ using paired_view_types =
tuple_to_types_t<cartesian_product_t<base_real_types, base_layout_types,
base_real_types, base_layout_types>>;

template <typename T>
struct ContainerTypes : public ::testing::Test {
static constexpr std::size_t rank = 3;
using value_type = T;
using vector_type = std::vector<T>;
using std_array_type = std::array<T, rank>;
using Kokkos_array_type = Kokkos::Array<T, rank>;

virtual void SetUp() {
GTEST_SKIP() << "Skipping all tests for this fixture";
}
};

template <typename T>
struct RealAndComplexTypes : public ::testing::Test {
using real_type = T;
Expand Down Expand Up @@ -91,12 +107,45 @@ struct PairedViewTypes : public ::testing::Test {
}
};

TYPED_TEST_SUITE(ContainerTypes, base_int_types);
TYPED_TEST_SUITE(RealAndComplexTypes, real_types);
TYPED_TEST_SUITE(RealAndComplexViewTypes, view_types);
TYPED_TEST_SUITE(PairedValueTypes, paired_value_types);
TYPED_TEST_SUITE(PairedLayoutTypes, paired_layout_types);
TYPED_TEST_SUITE(PairedViewTypes, paired_view_types);

// Tests for base value type deduction
template <typename ValueType, typename ContainerType>
void test_get_container_value_type() {
using value_type_ContainerType =
KokkosFFT::Impl::base_container_value_type<ContainerType>;

// base value type of ContainerType is ValueType
static_assert(std::is_same_v<value_type_ContainerType, ValueType>,
"Value type not deduced correctly from ContainerType");
}

TYPED_TEST(ContainerTypes, get_value_type_from_vector) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::vector_type;

test_get_container_value_type<value_type, container_type>();
}

TYPED_TEST(ContainerTypes, get_value_type_from_std_array) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::std_array_type;

test_get_container_value_type<value_type, container_type>();
}

TYPED_TEST(ContainerTypes, get_value_type_from_kokkos_array) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::Kokkos_array_type;

test_get_container_value_type<value_type, container_type>();
}

// Tests for real type deduction
template <typename RealType, typename ComplexType>
void test_get_real_type() {
Expand Down
Loading
Loading