Skip to content

Commit

Permalink
Merge pull request #178 from sterrettm2/kvsort-nantests
Browse files Browse the repository at this point in the history
Fix kvsort/kvselect nan behavior and added tests for mixed nan/inf arrays
  • Loading branch information
r-devulap authored Feb 20, 2025
2 parents 9427923 + 89de98d commit b27f82b
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 44 deletions.
104 changes: 85 additions & 19 deletions src/xss-common-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,63 @@
#include <omp.h>
#endif

/*
 * Move every NaN key, together with its paired value, to the end of the array
 * and return the index of the last element in the array which is not a NaN.
 *
 * keys: key array; every element for which is_a_nan() is true ends up in the
 *       tail of the array
 * vals: value array; it is permuted with exactly the same swaps as keys so
 *       that each key keeps its value
 * size: number of elements in keys and vals
 *
 * The partitioning is done with swaps, so the relative order of the elements
 * is not preserved. The result is computed as size - count - 1, where count
 * is the number of NaN keys. When the array is empty or every key is a NaN
 * there is no non-NaN element, and the result is -1 expressed in arrsize_t
 * (this wraps around to the maximum value if arrsize_t is unsigned).
 */
template <typename T1, typename T2, typename vtype>
X86_SIMD_SORT_INLINE arrsize_t move_nans_to_end_of_array(T1 *keys,
                                                         T2 *vals,
                                                         arrsize_t size)
{
    using reg_t = typename vtype::reg_t;

    // An empty array has nothing to inspect. Without this guard jj would be
    // initialised from size - 1 and keys[0] would be read out of bounds.
    if (size == 0) { return size - 1; }

    arrsize_t jj = size - 1; // last position not yet known to hold a NaN
    arrsize_t ii = 0; // next position to examine from the front
    arrsize_t count = 0; // number of NaNs moved to the tail so far

    // Examine keys[ii]: a NaN is swapped with the element at jj and the tail
    // boundary shrinks. ii is deliberately not advanced in that case, because
    // the element brought in from the tail may itself be a NaN and has to be
    // examined too.
    const auto examine_front = [&]() {
        if (is_a_nan(keys[ii])) {
            std::swap(keys[ii], keys[jj]);
            std::swap(vals[ii], vals[jj]);
            jj -= 1;
            count++;
        }
        else {
            ii += 1;
        }
    };

    // Vectorised scan: the loop condition keeps the whole vector load strictly
    // below jj, so it never touches the tail region.
    while (ii + vtype::numlanes < jj) {
        reg_t in = vtype::loadu(keys + ii);
        // 0x01 | 0x80 selects the NaN categories of the fpclass test
        // (presumably quiet and signalling NaN -- confirm against the ISA docs)
        auto nanmask = vtype::convert_mask_to_int(
                vtype::template fpclass<0x01 | 0x80>(in));

        // Check if there are any nans in this vector, and process them if so
        if (nanmask != 0x00) {
            // Exactly numlanes scalar steps; each one either advances ii or
            // retreats jj, so ii stays strictly below jj afterwards.
            for (size_t offset = 0; offset < vtype::numlanes; offset++) {
                examine_front();
            }
        }
        else {
            // No NaN in this vector: skip all of its lanes at once
            ii += vtype::numlanes;
        }
    }

    // Handle the remainders once we have less than 1 vector worth
    while (ii < jj) {
        examine_front();
    }

    /* Haven't checked for nan when ii == jj */
    if (is_a_nan(keys[ii])) { count++; }
    return size - count - 1;
}

/*
* Partition one ZMM register based on the pivot and returns the index of the
* last element that is less than or equal to the pivot.
Expand Down Expand Up @@ -529,6 +586,9 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
half_vector<T2>,
full_vector<T2>>::type;

// Exit early if no work would be done
if (arrsize <= 1) return;

#ifdef XSS_TEST_KEYVALUE_BASE_CASE
int maxiters = -1;
bool minarrsize = true;
Expand All @@ -538,11 +598,12 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
#endif // XSS_TEST_KEYVALUE_BASE_CASE

if (minarrsize) {
arrsize_t nan_count = 0;
arrsize_t index_last_elem = arrsize - 1;
if constexpr (xss::fp::is_floating_point_v<T1>) {
if (UNLIKELY(hasnan)) {
nan_count
= replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
index_last_elem
= move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
keys, indexes, arrsize);
}
}
else {
Expand All @@ -565,24 +626,27 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
// Note that we do not use the if(...) clause built into OpenMP, because it causes a performance regression for small arrays
#pragma omp parallel num_threads(thread_count)
#pragma omp single
kvsort_<keytype, valtype>(
keys, indexes, 0, arrsize - 1, maxiters, task_threshold);
kvsort_<keytype, valtype>(keys,
indexes,
0,
index_last_elem,
maxiters,
task_threshold);
}
else {
kvsort_<keytype, valtype>(keys,
indexes,
0,
arrsize - 1,
index_last_elem,
maxiters,
std::numeric_limits<arrsize_t>::max());
}
#pragma omp taskwait
#else
kvsort_<keytype, valtype>(keys, indexes, 0, arrsize - 1, maxiters, 0);
kvsort_<keytype, valtype>(
keys, indexes, 0, index_last_elem, maxiters, 0);
#endif

replace_inf_with_nan(keys, arrsize, nan_count);

if (descending) {
std::reverse(keys, keys + arrsize);
std::reverse(indexes, indexes + arrsize);
Expand Down Expand Up @@ -614,6 +678,9 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
half_vector<T2>,
full_vector<T2>>::type;

// Exit early if no work would be done
if (arrsize <= 1) return;

#ifdef XSS_TEST_KEYVALUE_BASE_CASE
int maxiters = -1;
bool minarrsize = true;
Expand All @@ -625,20 +692,19 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
if (minarrsize) {
if (descending) { k = arrsize - 1 - k; }

if constexpr (std::is_floating_point_v<T1>) {
arrsize_t nan_count = 0;
arrsize_t index_last_elem = arrsize - 1;
if constexpr (xss::fp::is_floating_point_v<T1>) {
if (UNLIKELY(hasnan)) {
nan_count
= replace_nan_with_inf<full_vector<T1>>(keys, arrsize);
index_last_elem
= move_nans_to_end_of_array<T1, T2, full_vector<T1>>(
keys, indexes, arrsize);
}
kvselect_<keytype, valtype>(
keys, indexes, k, 0, arrsize - 1, maxiters);
replace_inf_with_nan(keys, arrsize, nan_count);
}
else {
UNUSED(hasnan);

UNUSED(hasnan);
if (index_last_elem >= k) {
kvselect_<keytype, valtype>(
keys, indexes, k, 0, arrsize - 1, maxiters);
keys, indexes, k, 0, index_last_elem, maxiters);
}

if (descending) {
Expand Down
43 changes: 28 additions & 15 deletions tests/test-keyvalue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ class simdkvsort : public ::testing::Test {
"smallrange",
"max_at_the_end",
"random_5d",
"rand_max"};
"rand_max",
"rand_with_nan",
"rand_with_max_and_nan"};
}
std::vector<std::string> arrtype;
std::vector<size_t> arrsize = std::vector<size_t>(1024);
Expand Down Expand Up @@ -123,27 +125,36 @@ bool is_kv_partialsorted(T1 *keys_comp,
}

// Now, we need to do some more work to handle keys exactly equal to the true kth
// There may be more values after the kth element with the same key,
// and thus we can find that the values of the kth elements do not match,
// even though the sort is correct.

// First, fully kvsort both arrays
xss::scalar::keyvalue_qsort<T1, T2>(keys_ref, vals_ref, size, true, false);
xss::scalar::keyvalue_qsort<T1, T2>(
keys_comp, vals_comp, size, true, false);

auto trueKth = keys_ref[k];
bool notFoundFirst = true;
auto trueKthKey = keys_ref[k];
bool foundFirstKthKey = false;
size_t i = 0;

// Search forwards until we find the block of keys that match the kth key,
// then find where it ends
for (; i < size; i++) {
if (notFoundFirst && cmp_eq(keys_ref[i], trueKth)) {
notFoundFirst = false;
if (!foundFirstKthKey && cmp_eq(keys_ref[i], trueKthKey)) {
foundFirstKthKey = true;
i_start = i;
}
else if (!notFoundFirst && !cmp_eq(keys_ref[i], trueKth)) {
else if (foundFirstKthKey && !cmp_eq(keys_ref[i], trueKthKey)) {
break;
}
}

if (notFoundFirst) return false;
// kth key is somehow missing? Since we got that value from keys_ref, should be impossible
if (!foundFirstKthKey) { return false; }

// Check that the values in the kth key block match, so they are equivalent
// up to permutation, which is allowed since the sort is not stable
if (!same_values(vals_ref + i_start, vals_comp + i_start, i - i_start)) {
return false;
}
Expand All @@ -156,7 +167,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_ascending)
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
std::vector<T1> key = get_array<T1>(type, size);
std::vector<T2> val = get_array<T2>(type, size);
Expand Down Expand Up @@ -187,7 +198,7 @@ TYPED_TEST_P(simdkvsort, test_kvsort_descending)
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
std::vector<T1> key = get_array<T1>(type, size);
std::vector<T2> val = get_array<T2>(type, size);
Expand Down Expand Up @@ -217,8 +228,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
{
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
auto cmp_eq = compare<T1, std::equal_to<T1>>();
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;

Expand All @@ -237,7 +249,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_ascending)
xss::scalar::keyvalue_qsort(
key.data(), val.data(), k, hasnan, false);

ASSERT_EQ(key[k], key_bckp[k]);
ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true);

bool is_kv_partialsorted_
= is_kv_partialsorted<T1, T2>(key.data(),
Expand All @@ -260,8 +272,9 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
{
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
auto cmp_eq = compare<T1, std::equal_to<T1>>();
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;

Expand All @@ -280,7 +293,7 @@ TYPED_TEST_P(simdkvsort, test_kvselect_descending)
xss::scalar::keyvalue_qsort(
key.data(), val.data(), k, hasnan, true);

ASSERT_EQ(key[k], key_bckp[k]);
ASSERT_EQ(cmp_eq(key[k], key_bckp[k]), true);

bool is_kv_partialsorted_
= is_kv_partialsorted<T1, T2>(key.data(),
Expand All @@ -304,7 +317,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_ascending)
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;

Expand Down Expand Up @@ -341,7 +354,7 @@ TYPED_TEST_P(simdkvsort, test_kvpartial_sort_descending)
using T1 = typename std::tuple_element<0, decltype(TypeParam())>::type;
using T2 = typename std::tuple_element<1, decltype(TypeParam())>::type;
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;

Expand Down
6 changes: 6 additions & 0 deletions tests/test-qsort-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
ASSERT_TRUE(false) << msg << ". arr size = " << size \
<< ", type = " << type << ", k = " << k;

/*
 * Returns true when the array type name identifies a test that uses NaNs.
 *
 * type: name of the array type, e.g. "rand_with_nan" or "rand_max"
 *
 * The name is taken by const reference so that no copy of the string is made
 * on each call; callers passing a std::string or a string literal are
 * unaffected.
 */
inline bool is_nan_test(const std::string &type)
{
    // Currently, determine whether the test uses nan just by checking if "nan" is in its name
    return type.find("nan") != std::string::npos;
}

template <typename T>
void IS_SORTED(std::vector<T> sorted, std::vector<T> arr, std::string type)
{
Expand Down
21 changes: 11 additions & 10 deletions tests/test-qsort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class simdsort : public ::testing::Test {
"max_at_the_end",
"random_5d",
"rand_max",
"rand_with_nan"};
"rand_with_nan",
"rand_with_max_and_nan"};
}
std::vector<std::string> arrtype;
std::vector<size_t> arrsize = std::vector<size_t>(1024);
Expand All @@ -30,7 +31,7 @@ TYPED_TEST_SUITE_P(simdsort);
TYPED_TEST_P(simdsort, test_qsort_ascending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);

Expand All @@ -52,7 +53,7 @@ TYPED_TEST_P(simdsort, test_qsort_ascending)
TYPED_TEST_P(simdsort, test_qsort_descending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);

Expand All @@ -74,7 +75,7 @@ TYPED_TEST_P(simdsort, test_qsort_descending)
TYPED_TEST_P(simdsort, test_argsort_ascending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
std::vector<TypeParam> sortedarr = arr;
Expand All @@ -92,7 +93,7 @@ TYPED_TEST_P(simdsort, test_argsort_ascending)
TYPED_TEST_P(simdsort, test_argsort_descending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
std::vector<TypeParam> sortedarr = arr;
Expand All @@ -111,7 +112,7 @@ TYPED_TEST_P(simdsort, test_argsort_descending)
TYPED_TEST_P(simdsort, test_qselect_ascending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
Expand All @@ -135,7 +136,7 @@ TYPED_TEST_P(simdsort, test_qselect_ascending)
TYPED_TEST_P(simdsort, test_qselect_descending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
Expand All @@ -159,7 +160,7 @@ TYPED_TEST_P(simdsort, test_qselect_descending)
TYPED_TEST_P(simdsort, test_argselect)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;
std::vector<TypeParam> arr = get_array<TypeParam>(type, size);
Expand All @@ -179,7 +180,7 @@ TYPED_TEST_P(simdsort, test_argselect)
TYPED_TEST_P(simdsort, test_partial_qsort_ascending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
size_t k = rand() % size;
std::vector<TypeParam> basearr = get_array<TypeParam>(type, size);
Expand All @@ -202,7 +203,7 @@ TYPED_TEST_P(simdsort, test_partial_qsort_ascending)
TYPED_TEST_P(simdsort, test_partial_qsort_descending)
{
for (auto type : this->arrtype) {
bool hasnan = (type == "rand_with_nan") ? true : false;
bool hasnan = is_nan_test(type);
for (auto size : this->arrsize) {
// k should be at least 1
size_t k = std::max((size_t)1, rand() % size);
Expand Down
Loading

0 comments on commit b27f82b

Please sign in to comment.