Skip to content

Commit

Permalink
Fixed types in a few places
Browse files Browse the repository at this point in the history
  • Loading branch information
sterrettm2 authored and r-devulap committed Oct 20, 2023
1 parent b02fe0a commit dec2278
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 32 deletions.
30 changes: 4 additions & 26 deletions src/avx2-32bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,9 @@ struct ymm_vector<float> {
}
};

inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize)
{
int64_t nan_count = 0;
arrsize_t nan_count = 0;
__mmask8 loadmask = 0xFF;
while (arrsize > 0) {
if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; }
Expand All @@ -567,36 +567,14 @@ inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize)
}

/*
 * Overwrites the trailing `nan_count` elements of `arr` with NaN.
 *
 * Diff view note: the signature and the `for` header each appear twice below;
 * in every adjacent pair the first line is the pre-commit (int64_t) form and
 * the second is the post-commit (arrsize_t) form.
 *
 * arr       : float array to modify in place.
 * arrsize   : element count; the first element written is arr[arrsize - 1].
 * nan_count : number of trailing elements to overwrite.
 *
 * Despite the name, the existing values are never inspected: whatever is
 * stored in the last `nan_count` slots is replaced unconditionally.
 * The loop does not bounds-check the index, so the caller must guarantee
 * nan_count <= arrsize.
 */
X86_SIMD_SORT_INLINE void
replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count)
replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count)
{
// Walk backwards from the last element; termination depends only on
// nan_count reaching zero, not on the index ii.
// NOTE(review): if arrsize_t is an unsigned type (its definition is not
// visible here), ii wraps around after index 0 or when arrsize == 0, but the
// wrapped value is never used to index arr as long as nan_count <= arrsize.
for (int64_t ii = arrsize - 1; nan_count > 0; --ii) {
for (arrsize_t ii = arrsize - 1; nan_count > 0; --ii) {
arr[ii] = std::nan("1");
nan_count -= 1;
}
}

/*
 * Chooses a partition pivot for arr[left..right] from eight sampled elements.
 *
 * The range is divided into eight equal strides of (right - left) / 8 and the
 * elements at left + k * stride, for k = 1..8, are gathered into one vector
 * register. That register is passed through sort_ymm_32bit and the value in
 * lane 4 of the result is returned as the pivot.
 *
 * The original comment states the pivot can never be NaN because the input
 * contains no NaNs at this point; that precondition is not checked here.
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
                                            const int64_t left,
                                            const int64_t right)
{
    using reg_t = typename vtype::reg_t;

    // Distance between consecutive sample positions.
    const int64_t stride = (right - left) / 8;

    // Position of the k-th sample, computed in 64-bit arithmetic exactly as
    // before; narrowing to int happens at the _mm256_set_epi32 call.
    const auto sample_at
            = [left, stride](int64_t k) { return left + k * stride; };

    // Arguments run from the highest lane down, so the nearest sample is
    // passed first and the farthest last.
    const __m256i sample_idx = _mm256_set_epi32(sample_at(1),
                                                sample_at(2),
                                                sample_at(3),
                                                sample_at(4),
                                                sample_at(5),
                                                sample_at(6),
                                                sample_at(7),
                                                sample_at(8));

    const reg_t samples
            = vtype::template i64gather<sizeof(type_t)>(sample_idx, arr);

    // Order the eight samples and read lane 4 as the median-of-8 estimate.
    reg_t sorted = sort_ymm_32bit<vtype>(samples);
    return reinterpret_cast<type_t *>(&sorted)[4];
}

struct avx2_32bit_swizzle_ops{
template <typename vtype, int scale>
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
Expand Down
12 changes: 6 additions & 6 deletions src/xss-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace avx2{
template <typename type>
struct ymm_vector;

inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize);
inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize);
}
}

Expand Down Expand Up @@ -612,13 +612,13 @@ X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize)
}

template <typename T>
void avx2_qsort(T *arr, int64_t arrsize)
void avx2_qsort(T *arr, arrsize_t arrsize)
{
using vtype = xss::avx2::ymm_vector<T>;
if (arrsize > 1) {
/* std::is_floating_point_v<_Float16> == False, unless c++-23*/
if constexpr (std::is_floating_point_v<T>) {
int64_t nan_count
arrsize_t nan_count
= xss::avx2::replace_nan_with_inf(arr, arrsize);
qsort_<vtype, T>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
Expand Down Expand Up @@ -650,9 +650,9 @@ avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
}

template <typename T>
void avx2_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
void avx2_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
{
int64_t indx_last_elem = arrsize - 1;
arrsize_t indx_last_elem = arrsize - 1;
/* std::is_floating_point_v<_Float16> == False, unless c++-23*/
if constexpr (std::is_floating_point_v<T>) {
if (UNLIKELY(hasnan)) {
Expand All @@ -677,7 +677,7 @@ X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr,
}

template <typename T>
inline void avx2_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false)
inline void avx2_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
{
avx2_qselect<T>(arr, k - 1, arrsize, hasnan);
avx2_qsort<T>(arr, k - 1);
Expand Down

0 comments on commit dec2278

Please sign in to comment.