Skip to content

Commit

Permalink
Changed pivot code back to previous logic for performance reasons
Browse files Browse the repository at this point in the history
  • Loading branch information
sterrettm2 committed Sep 7, 2023
1 parent d9a8723 commit 09fce7a
Showing 1 changed file with 126 additions and 3 deletions.
129 changes: 126 additions & 3 deletions src/avx512-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ static inline int64_t partition_avx512(type_t1 *keys,
}

template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr,
const int64_t left,
const int64_t right)
{
Expand All @@ -703,9 +703,132 @@ X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,

auto vec = vtype::loadu(samples);
vec = vtype::sort_vec(vec);
vtype::storeu(samples, vec);
return ((type_t *)&vec)[numSamples / 2];
}

// Forward declarations of the whole-register sorting helpers, one per element
// width. Each takes a register by value and returns a register of the same
// type. They are declared here so the get_pivot_* functions in this header can
// call them; their definitions are not part of this section of the file.
template <typename vtype, typename reg_t>
X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm);

template <typename vtype, typename reg_t>
X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm);

template <typename vtype, typename reg_t>
X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm);

/*
 * Pick a pivot for the range arr[left..right] of 16-bit elements.
 *
 * 32 elements are sampled at a constant stride of (right - left) / 32,
 * starting at arr[left]; the last sample is arr[left + 31 * stride], which
 * stays inside the range whenever right >= left. The samples are loaded into
 * one register, passed through sort_zmm_16bit, and the element at lane 16 of
 * the result is returned (the "median of 32", assuming sort_zmm_16bit orders
 * the lanes -- its definition is not visible here).
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr,
                                            const int64_t left,
                                            const int64_t right)
{
    using reg_t = typename vtype::reg_t;
    constexpr int64_t kSamples = 32;

    // Distance between two consecutive sample positions.
    const int64_t stride = (right - left) / kSamples;

    // Collect the evenly spaced samples into a contiguous scratch buffer so
    // they can be loaded with a single unaligned load.
    type_t samples[kSamples];
    for (int64_t i = 0; i < kSamples; ++i) {
        samples[i] = arr[left + i * stride];
    }

    reg_t unsorted = vtype::loadu(samples);
    reg_t sorted = sort_zmm_16bit<vtype>(unsorted);

    // Read the middle lane straight out of the register.
    return reinterpret_cast<type_t *>(&sorted)[kSamples / 2];
}

/*
 * Pick a pivot for the range arr[left..right] of 32-bit elements.
 *
 * 16 elements are sampled at indices left + k * stride for k = 1..16, where
 * stride = (right - left) / 16; the largest index, left + 16 * stride, does
 * not exceed right whenever right >= left. The samples are fetched with two
 * 8-wide 64-bit-index gathers, merged into one full register, passed through
 * sort_zmm_32bit, and the element at lane 8 of the result is returned (the
 * "median of 16", assuming sort_zmm_32bit orders the lanes -- its definition
 * is not visible here).
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr,
                                            const int64_t left,
                                            const int64_t right)
{
    using full_reg_t = typename vtype::reg_t;
    using half_reg_t = typename vtype::halfreg_t;
    constexpr int64_t kSamples = 16;

    // Distance between two consecutive sample positions.
    const int64_t stride = (right - left) / kSamples;

    // Sample indices 1..8 strides past `left`. _mm512_set_epi64 takes its
    // arguments from the highest lane down to the lowest.
    __m512i idx_first_half = _mm512_set_epi64(left + stride,
                                              left + 2 * stride,
                                              left + 3 * stride,
                                              left + 4 * stride,
                                              left + 5 * stride,
                                              left + 6 * stride,
                                              left + 7 * stride,
                                              left + 8 * stride);
    // Sample indices 9..16 strides past `left`.
    __m512i idx_second_half = _mm512_set_epi64(left + 9 * stride,
                                               left + 10 * stride,
                                               left + 11 * stride,
                                               left + 12 * stride,
                                               left + 13 * stride,
                                               left + 14 * stride,
                                               left + 15 * stride,
                                               left + 16 * stride);

    // Each gather yields a half-width register of eight samples; the scale
    // template argument is the element size in bytes.
    half_reg_t samples_first_half
            = vtype::template i64gather<sizeof(type_t)>(idx_first_half, arr);
    half_reg_t samples_second_half
            = vtype::template i64gather<sizeof(type_t)>(idx_second_half, arr);

    full_reg_t samples = vtype::merge(samples_first_half, samples_second_half);
    full_reg_t sorted = sort_zmm_32bit<vtype>(samples);

    // The pivot is never a NaN because the input is expected to contain none
    // (NOTE(review): relies on the caller having dealt with NaNs beforehand).
    return reinterpret_cast<type_t *>(&sorted)[kSamples / 2];
}

/*
 * Pick a pivot for the range arr[left..right] of 64-bit elements.
 *
 * 8 elements are sampled at indices left + k * stride for k = 1..8, where
 * stride = (right - left) / 8; the largest index, left + 8 * stride, does not
 * exceed right whenever right >= left. The samples are fetched with a single
 * 64-bit-index gather, passed through sort_zmm_64bit, and the element at
 * lane 4 of the result is returned (the "median of 8", assuming
 * sort_zmm_64bit orders the lanes -- its definition is not visible here).
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr,
                                            const int64_t left,
                                            const int64_t right)
{
    using reg_t = typename vtype::reg_t;
    constexpr int64_t kSamples = 8;

    // Distance between two consecutive sample positions.
    const int64_t stride = (right - left) / kSamples;

    // _mm512_set_epi64 takes its arguments from the highest lane down to the
    // lowest.
    __m512i sample_idx = _mm512_set_epi64(left + stride,
                                          left + 2 * stride,
                                          left + 3 * stride,
                                          left + 4 * stride,
                                          left + 5 * stride,
                                          left + 6 * stride,
                                          left + 7 * stride,
                                          left + 8 * stride);

    // The scale template argument is the element size in bytes.
    reg_t samples = vtype::template i64gather<sizeof(type_t)>(sample_idx, arr);

    // The pivot is never a NaN because the input is expected to contain none
    // (NOTE(review): relies on the caller having dealt with NaNs beforehand).
    reg_t sorted = sort_zmm_64bit<vtype>(samples);
    return reinterpret_cast<type_t *>(&sorted)[kSamples / 2];
}

return samples[numSamples / 2];
/*
 * Pivot selection entry point: dispatches at compile time on the number of
 * lanes in vtype's register.
 *
 *   8 lanes  -> get_pivot_64bit
 *   16 lanes -> get_pivot_32bit
 *   32 lanes -> get_pivot_16bit
 *   other    -> get_pivot_scalar
 *
 * `if constexpr` is used so that only the selected helper is instantiated for
 * a given vtype.
 */
template <typename vtype, typename type_t>
X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
                                      const int64_t left,
                                      const int64_t right)
{
    constexpr auto lanes = vtype::numlanes;

    if constexpr (lanes == 8) {
        return get_pivot_64bit<vtype>(arr, left, right);
    }
    else if constexpr (lanes == 16) {
        return get_pivot_32bit<vtype>(arr, left, right);
    }
    else if constexpr (lanes == 32) {
        return get_pivot_16bit<vtype>(arr, left, right);
    }
    else {
        return get_pivot_scalar<vtype>(arr, left, right);
    }
}

template <typename vtype, int64_t maxN>
Expand Down

0 comments on commit 09fce7a

Please sign in to comment.