diff --git a/src/avx512-common-qsort.h b/src/avx512-common-qsort.h index a5353445..9991f74a 100644 --- a/src/avx512-common-qsort.h +++ b/src/avx512-common-qsort.h @@ -688,7 +688,7 @@ static inline int64_t partition_avx512(type_t1 *keys, } template -X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, +X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr, const int64_t left, const int64_t right) { @@ -703,9 +703,132 @@ X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, auto vec = vtype::loadu(samples); vec = vtype::sort_vec(vec); - vtype::storeu(samples, vec); + return ((type_t *)&vec)[numSamples / 2]; +} + +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_16bit(reg_t zmm); + +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_32bit(reg_t zmm); + +template +X86_SIMD_SORT_INLINE reg_t sort_zmm_64bit(reg_t zmm); + +template +X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 32 + int64_t size = (right - left) / 32; + type_t vec_arr[32] = {arr[left], + arr[left + size], + arr[left + 2 * size], + arr[left + 3 * size], + arr[left + 4 * size], + arr[left + 5 * size], + arr[left + 6 * size], + arr[left + 7 * size], + arr[left + 8 * size], + arr[left + 9 * size], + arr[left + 10 * size], + arr[left + 11 * size], + arr[left + 12 * size], + arr[left + 13 * size], + arr[left + 14 * size], + arr[left + 15 * size], + arr[left + 16 * size], + arr[left + 17 * size], + arr[left + 18 * size], + arr[left + 19 * size], + arr[left + 20 * size], + arr[left + 21 * size], + arr[left + 22 * size], + arr[left + 23 * size], + arr[left + 24 * size], + arr[left + 25 * size], + arr[left + 26 * size], + arr[left + 27 * size], + arr[left + 28 * size], + arr[left + 29 * size], + arr[left + 30 * size], + arr[left + 31 * size]}; + typename vtype::reg_t rand_vec = vtype::loadu(vec_arr); + typename vtype::reg_t sort = sort_zmm_16bit(rand_vec); + return ((type_t *)&sort)[16]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 16 + int64_t size = (right - left) / 16; + using zmm_t = typename vtype::reg_t; + using ymm_t = typename vtype::halfreg_t; + __m512i rand_index1 = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + __m512i rand_index2 = _mm512_set_epi64(left + 9 * size, + left + 10 * size, + left + 11 * size, + left + 12 * size, + left + 13 * size, + left + 14 * size, + left + 15 * size, + left + 16 * size); + ymm_t rand_vec1 + = vtype::template i64gather(rand_index1, arr); + ymm_t rand_vec2 + = vtype::template i64gather(rand_index2, arr); + zmm_t rand_vec = vtype::merge(rand_vec1, rand_vec2); + zmm_t sort = sort_zmm_32bit(rand_vec); + // pivot will never be a nan, since there are no nan's! + return ((type_t *)&sort)[8]; +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 8 + int64_t size = (right - left) / 8; + using zmm_t = typename vtype::reg_t; + __m512i rand_index = _mm512_set_epi64(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + zmm_t rand_vec = vtype::template i64gather(rand_index, arr); + // pivot will never be a nan, since there are no nan's! + zmm_t sort = sort_zmm_64bit(rand_vec); + return ((type_t *)&sort)[4]; +} - return samples[numSamples / 2]; +template +X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, + const int64_t left, + const int64_t right) +{ + if constexpr (vtype::numlanes == 8) + return get_pivot_64bit(arr, left, right); + else if constexpr (vtype::numlanes == 16) + return get_pivot_32bit(arr, left, right); + else if constexpr (vtype::numlanes == 32) + return get_pivot_16bit(arr, left, right); + else + return get_pivot_scalar(arr, left, right); } template