diff --git a/_clang-format b/_clang-format
index 98760584..4ef31af4 100644
--- a/_clang-format
+++ b/_clang-format
@@ -74,7 +74,7 @@ PenaltyExcessCharacter: 1000000
 PenaltyReturnTypeOnItsOwnLine: 60
 PointerAlignment: Right
 ReflowComments: false
-SortIncludes: true
+SortIncludes: false
 SortUsingDeclarations: true
 SpaceAfterCStyleCast: false
 SpaceAfterTemplateKeyword: true
diff --git a/examples/Makefile b/examples/Makefile
index a66db80c..483c39df 100644
--- a/examples/Makefile
+++ b/examples/Makefile
@@ -1,6 +1,6 @@
 CXX ?= g++-12
 CFLAGS = -I../src -std=c++17 -O3
-EXE = argsort kvsort qsortfp16 qsort16 qsort32 qsort64
+EXE = qsort32avx2 argsort kvsort qsortfp16 qsort16 qsort32 qsort64
 
 default: all
 all : $(EXE)
@@ -14,6 +14,9 @@ qsort16: avx512-16bit-qsort.cpp
 qsort32: avx512-32bit-qsort.cpp
 	$(CXX) -o qsort32 -march=skylake-avx512 $(CFLAGS) avx512-32bit-qsort.cpp
 
+qsort32avx2: avx2-32bit-qsort.cpp
+	$(CXX) -o qsort32avx2 -march=haswell $(CFLAGS) avx2-32bit-qsort.cpp
+
 qsort64: avx512-64bit-qsort.cpp
 	$(CXX) -o qsort64 -march=skylake-avx512 $(CFLAGS) avx512-64bit-qsort.cpp
 
diff --git a/examples/avx2-32bit-qsort.cpp b/examples/avx2-32bit-qsort.cpp
new file mode 100644
index 00000000..eddedca8
--- /dev/null
+++ b/examples/avx2-32bit-qsort.cpp
@@ -0,0 +1,10 @@
+#include "avx2-32bit-qsort.hpp"
+
+int main() {
+    const int size = 1000;
+    float arr[size];
+    avx2_qsort(arr, size);
+    avx2_qselect(arr, 10, size);
+    avx2_partial_qsort(arr, 10, size);
+    return 0;
+}
diff --git a/lib/meson.build b/lib/meson.build
index fc544701..0bcd5da6 100644
--- a/lib/meson.build
+++ b/lib/meson.build
@@ -1,5 +1,15 @@
 libtargets = []
 
+if cpp.has_argument('-march=haswell')
+  libtargets += static_library('libavx',
+    files(
+      'x86simdsort-avx2.cpp',
+    ),
+    include_directories : [src],
+    cpp_args : ['-march=haswell', flags_hide_symbols],
+  )
+endif
+
 if cpp.has_argument('-march=skylake-avx512')
   libtargets += static_library('libskx',
     files(
diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp
new file mode 100644
index 00000000..825b4069
--- /dev/null
+++ b/lib/x86simdsort-avx2.cpp
@@ -0,0 +1,28 @@
+// AVX2 specific routines:
+#include "avx2-32bit-qsort.hpp"
+#include "x86simdsort-internal.h"
+
+#define DEFINE_ALL_METHODS(type) \
+    template <> \
+    void qsort(type *arr, size_t arrsize) \
+    { \
+        avx2_qsort(arr, arrsize); \
+    } \
+    template <> \
+    void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \
+    { \
+        avx2_qselect(arr, k, arrsize, hasnan); \
+    } \
+    template <> \
+    void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \
+    { \
+        avx2_partial_qsort(arr, k, arrsize, hasnan); \
+    }
+
+namespace xss {
+namespace avx2 {
+    DEFINE_ALL_METHODS(uint32_t)
+    DEFINE_ALL_METHODS(int32_t)
+    DEFINE_ALL_METHODS(float)
+} // namespace avx2
+} // namespace xss
diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp
index 6879657a..e5803d06 100644
--- a/lib/x86simdsort.cpp
+++ b/lib/x86simdsort.cpp
@@ -7,7 +7,7 @@
 
 static int check_cpu_feature_support(std::string_view cpufeature)
 {
-    const char* disable_avx512 = std::getenv("XSS_DISABLE_AVX512");
+    const char *disable_avx512 = std::getenv("XSS_DISABLE_AVX512");
 
     if ((cpufeature == "avx512_spr") && (!disable_avx512))
 #ifdef __FLT16_MAX__
@@ -100,20 +100,20 @@ dispatch_requested(std::string_view cpurequested,
 }
 
 /* runtime dispatch mechanism */
-#define DISPATCH(func, TYPE, ...)
\ +#define DISPATCH(func, TYPE, ISA) \ DECLARE_INTERNAL_##func(TYPE) static __attribute__((constructor)) void \ CAT(CAT(resolve_, func), TYPE)(void) \ { \ CAT(CAT(internal_, func), TYPE) = &xss::scalar::func; \ __builtin_cpu_init(); \ - std::string_view preferred_cpu = find_preferred_cpu({__VA_ARGS__}); \ - if constexpr (dispatch_requested("avx512", {__VA_ARGS__})) { \ + std::string_view preferred_cpu = find_preferred_cpu(ISA); \ + if constexpr (dispatch_requested("avx512", ISA)) { \ if (preferred_cpu.find("avx512") != std::string_view::npos) { \ CAT(CAT(internal_, func), TYPE) = &xss::avx512::func; \ return; \ } \ } \ - else if constexpr (dispatch_requested("avx2", {__VA_ARGS__})) { \ + if constexpr (dispatch_requested("avx2", ISA)) { \ if (preferred_cpu.find("avx2") != std::string_view::npos) { \ CAT(CAT(internal_, func), TYPE) = &xss::avx2::func; \ return; \ @@ -121,13 +121,19 @@ dispatch_requested(std::string_view cpurequested, } \ } +#define ISA_LIST(...) \ + std::initializer_list \ + { \ + __VA_ARGS__ \ + } + namespace x86simdsort { #ifdef __FLT16_MAX__ -DISPATCH(qsort, _Float16, "avx512_spr") -DISPATCH(qselect, _Float16, "avx512_spr") -DISPATCH(partial_qsort, _Float16, "avx512_spr") -DISPATCH(argsort, _Float16, "none") -DISPATCH(argselect, _Float16, "none") +DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(argsort, _Float16, ISA_LIST("none")) +DISPATCH(argselect, _Float16, ISA_LIST("none")) #endif #define DISPATCH_ALL(func, ISA_16BIT, ISA_32BIT, ISA_64BIT) \ @@ -140,10 +146,25 @@ DISPATCH(argselect, _Float16, "none") DISPATCH(func, uint64_t, ISA_64BIT) \ DISPATCH(func, double, ISA_64BIT) -DISPATCH_ALL(qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx")) -DISPATCH_ALL(qselect, ("avx512_icl"), ("avx512_skx"), ("avx512_skx")) -DISPATCH_ALL(partial_qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx")) -DISPATCH_ALL(argsort, "none", "avx512_skx", "avx512_skx") -DISPATCH_ALL(argselect, "none", "avx512_skx", "avx512_skx") +DISPATCH_ALL(qsort, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(qselect, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(partial_qsort, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(argsort, + (ISA_LIST("none")), + (ISA_LIST("avx512_skx")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(argselect, + (ISA_LIST("none")), + (ISA_LIST("avx512_skx")), + (ISA_LIST("avx512_skx"))) } // namespace x86simdsort diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp new file mode 100644 index 00000000..5dd77a27 --- /dev/null +++ b/src/avx2-32bit-qsort.hpp @@ -0,0 +1,619 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef AVX2_QSORT_32BIT +#define AVX2_QSORT_32BIT + +#include "xss-common-qsort.h" +#include "avx2-emu-funcs.hpp" + +/* + * Constants used in sorting 8 elements in a ymm registers. 
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ + +// ymm 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) +{ + const typename vtype::opmask_t oxAA = _mm256_set_epi32( + 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const typename vtype::opmask_t oxCC = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); + const typename vtype::opmask_t oxF0 = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); + + const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge( + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), + oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); + ymm = cmp_merge( + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), + oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + return ymm; +} + +struct avx2_32bit_swizzle_ops; + +template <> +struct avx2_vector { + using type_t = int32_t; + using reg_t = __m256i; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static reg_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? 
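+    // AVX2 has no dedicated mask (k) registers, so opmask_t here is a __m256i
+    // whose lanes are all-ones wherever the corresponding mask bit is set.
+    // get_partial_loadmask() builds such a vector covering the first
+    // num_to_read lanes (see convert_int_to_avx2_mask and its lookup table).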
+ static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _mm256_xor_si256(x, y); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t le(reg_t x, reg_t y) + { + return ~_mm256_cmpgt_epi32(x, y); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + opmask_t greater = _mm256_cmpgt_epi32(x, y); + return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(equal), + _mm256_castsi256_ps(greater))); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi32(x, y); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_epi32((int const *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_epi32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_epi32((const int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi32((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), + _mm256_castsi256_ps(y), + _mm256_castsi256_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_epi32((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_epi32(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + //return avx2_emu_permutexvar_epi32(idx, ymm); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi32(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_vector { + using type_t = uint32_t; + using reg_t = __m256i; + using 
ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static reg_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_epi32((int const *)base, index, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t ge(reg_t x, reg_t y) + { + reg_t maxi = max(x, y); + return eq(maxi, x); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi32(x, y); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi32((const int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), + _mm256_castsi256_ps(y), + _mm256_castsi256_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_epi32((int *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_epu32(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi32(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_vector { + using type_t = float; + using reg_t = __m256; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const 
uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static reg_t zmm_max() + { + return _mm256_set1_ps(type_max()); + } + + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_ps((const float *)mem, mask); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t ge(reg_t x, reg_t y) + { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GE_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return convert_avx2_mask_to_int(mask); + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)) { + return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q)); + } + else { + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_ps( + src, base, index, _mm256_castsi256_ps(mask), scale); + ; + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_ps((float *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_ps((float const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_ps((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_blendv_ps(x, y, _mm256_castsi256_ps(mask)); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_ps((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_ps(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + int32_t x = _mm256_extract_epi32(_mm256_castps_si256(v), index); + float y; + std::memcpy(&y, &x, sizeof(y)); + return y; + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_ps(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_castsi256_ps( + _mm256_shuffle_epi32(_mm256_castps_si256(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_ps((float *)mem, x); + } 
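+    // sort_vec() sorts all 8 float lanes in-register using the bitonic
+    // network implemented in sort_ymm_32bit() above.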
+ static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) + { + return _mm256_castsi256_ps(v); + } + static __m256i cast_to(reg_t v) + { + return _mm256_castps_si256(v); + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } +}; + +struct avx2_32bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b10110001); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 4) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b01001110); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 8) { + v = _mm256_permute2x128_si256(v, v, 0b00000001); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + constexpr uint64_t mask = 0b00011011; + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, mask); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 8) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m256i v1 = vtype::cast_to(reg); + __m256i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm256_blend_epi32(v1, v2, 0b01010101); + } + else if constexpr (scale == 4) { + v1 = _mm256_blend_epi32(v1, v2, 0b00110011); + } + else if constexpr (scale == 8) { + v1 = _mm256_blend_epi32(v1, v2, 0b00001111); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; + +#endif // AVX2_QSORT_32BIT diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp new file mode 100644 index 00000000..43eed316 --- /dev/null +++ b/src/avx2-emu-funcs.hpp @@ -0,0 +1,164 @@ +#ifndef AVX2_EMU_FUNCS +#define AVX2_EMU_FUNCS + +#include +#include +#include "xss-common-qsort.h" + +constexpr auto avx2_mask_helper_lut32 = [] { + std::array, 256> lut {}; + for (int64_t i = 0; i <= 0xFF; i++) { + std::array entry {}; + for (int j = 0; j < 8; j++) { + if (((i >> j) & 1) == 1) + entry[j] = 0xFFFFFFFF; + else + entry[j] = 0; + } + lut[i] = entry; + } + return lut; +}(); + +constexpr auto avx2_compressstore_lut32_gen = [] { + std::array, 256>, 2> lutPair {}; + auto &permLut = lutPair[0]; + auto &leftLut = lutPair[1]; + for (int64_t i = 0; i <= 0xFF; i++) { + std::array indices {}; + std::array leftEntry = {0, 0, 0, 0, 0, 0, 0, 0}; + int right = 7; + int left = 0; + for (int j = 0; j < 8; j++) { + bool ge = (i >> j) & 1; + if (ge) { + indices[right] = j; + right--; + } + else { + indices[left] = j; + leftEntry[left] = 0xFFFFFFFF; + left++; + } + } + permLut[i] = indices; + leftLut[i] = leftEntry; + } + return lutPair; +}(); + +constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; +constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; + +X86_SIMD_SORT_INLINE +__m256i 
convert_int_to_avx2_mask(int32_t m) +{ + return _mm256_loadu_si256( + (const __m256i *)avx2_mask_helper_lut32[m].data()); +} + +X86_SIMD_SORT_INLINE +int32_t convert_avx2_mask_to_int(__m256i m) +{ + return _mm256_movemask_ps(_mm256_castsi256_ps(m)); +} + +// Emulators for intrinsics missing from AVX2 compared to AVX512 +template +T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) +{ + using vtype = avx2_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::max( + x, vtype::template shuffle(x)); + reg_t inter2 = vtype::max( + inter1, vtype::template shuffle(inter1)); + T can1 = vtype::template extract<0>(inter2); + T can2 = vtype::template extract<4>(inter2); + return std::max(can1, can2); +} + +template +T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) +{ + using vtype = avx2_vector; + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::min( + x, vtype::template shuffle(x)); + reg_t inter2 = vtype::min( + inter1, vtype::template shuffle(inter1)); + T can1 = vtype::template extract<0>(inter2); + T can2 = vtype::template extract<4>(inter2); + return std::min(can1, can2); +} + +template +void avx2_emu_mask_compressstoreu(void *base_addr, + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) +{ + using vtype = avx2_vector; + + T *leftStore = (T *)base_addr; + + int32_t shortMask = convert_avx2_mask_to_int(k); + const __m256i &perm = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); + const __m256i &left = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_left[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::mask_storeu(leftStore, left, temp); +} + +template +int avx2_double_compressstore32(void *left_addr, + void *right_addr, + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) +{ + using vtype = avx2_vector; + + T *leftStore = (T *)left_addr; + T *rightStore = (T *)right_addr; + + int32_t shortMask = convert_avx2_mask_to_int(k); + const __m256i &perm = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); + const __m256i &left = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_left[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::mask_storeu(leftStore, left, temp); + vtype::mask_storeu(rightStore, ~left, temp); + + return _mm_popcnt_u32(shortMask); +} + +template +typename avx2_vector::reg_t avx2_emu_max(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) +{ + using vtype = avx2_vector; + typename vtype::opmask_t nlt = vtype::ge(x, y); + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y), + _mm256_castsi256_pd(x), + _mm256_castsi256_pd(nlt))); +} + +template +typename avx2_vector::reg_t avx2_emu_min(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) +{ + using vtype = avx2_vector; + typename vtype::opmask_t nlt = vtype::ge(x, y); + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x), + _mm256_castsi256_pd(y), + _mm256_castsi256_pd(nlt))); +} + +#endif diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 288f85d0..28c1c1fe 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -7,7 +7,7 @@ #ifndef AVX512_16BIT_COMMON #define AVX512_16BIT_COMMON -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" /* * Constants used in sorting 32 elements in a ZMM registers. 
Based on Bitonic diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index fdfba924..1278201e 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -80,6 +80,14 @@ struct zmm_vector { exp_eq, mant_x, mant_y, _MM_CMPINT_NLT); return _kxor_mask32(mask_ge, neg); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; + } static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); @@ -177,6 +185,14 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> @@ -220,6 +236,10 @@ struct zmm_vector { { return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); @@ -301,6 +321,14 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -343,6 +371,10 @@ struct zmm_vector { { return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); @@ -422,6 +454,14 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index dc56e370..2d101b88 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -8,7 +8,7 @@ #ifndef AVX512_QSORT_32BIT #define AVX512_QSORT_32BIT -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" /* * Constants used in sorting 16 elements in a ZMM registers. 
Based on Bitonic @@ -65,6 +65,10 @@ struct zmm_vector { { return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } template static halfreg_t i64gather(__m512i index, void const *base) { @@ -154,6 +158,14 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -202,6 +214,10 @@ struct zmm_vector { { return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); @@ -281,6 +297,14 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -319,9 +343,13 @@ struct zmm_vector { { return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) { - return (0x0001 << size) - 0x0001; + return mask; } template static opmask_t fpclass(reg_t x) @@ -422,6 +450,14 @@ struct zmm_vector { { return _mm512_castps_si512(v); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; /* diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 2d2c33f5..e5d0db0d 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,7 +7,7 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" #include "avx512-64bit-common.h" #include "xss-network-keyvaluesort.hpp" #include @@ -65,7 +65,6 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) }); } - /* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of * undefined template 'zmm_vector'*/ #ifdef __APPLE__ diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 9c58b494..e7f9f44c 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -81,9 +81,9 @@ struct ymm_vector { { return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) { - return (0x01 << size) - 0x01; + return ((0x1ull << num_to_read) - 0x1ull); } template static opmask_t fpclass(reg_t x) @@ -244,6 +244,10 @@ struct ymm_vector { { return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_EQ); @@ -396,6 +400,10 @@ struct ymm_vector { { return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_EQ); @@ -557,6 +565,10 @@ struct zmm_vector { { return 
_mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); @@ -660,6 +672,14 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -738,6 +758,10 @@ struct zmm_vector { { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); @@ -818,6 +842,14 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -880,9 +912,13 @@ struct zmm_vector { { return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) { - return (0x01 << size) - 0x01; + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; } template static opmask_t fpclass(reg_t x) @@ -982,6 +1018,14 @@ struct zmm_vector { { return _mm512_castpd_si512(v); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; /* diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index c24e575a..b1ec0cd2 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -8,7 +8,7 @@ #ifndef AVX512_QSORT_64BIT_KV #define AVX512_QSORT_64BIT_KV -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" #include "avx512-64bit-common.h" #include "xss-network-keyvaluesort.hpp" diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index afafe250..4dcaeafa 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -7,7 +7,7 @@ #ifndef AVX512_QSORT_64BIT #define AVX512_QSORT_64BIT -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" #include "avx512-64bit-common.h" #endif // AVX512_QSORT_64BIT diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 94e508f0..b01c367b 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -54,9 +54,13 @@ struct zmm_vector<_Float16> { { return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) { - return (0x00000001 << size) - 0x00000001; + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; } template static opmask_t fpclass(reg_t x) @@ -145,6 +149,14 @@ struct zmm_vector<_Float16> { { return _mm512_castph_si512(v); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); + } }; template <> @@ -189,4 +201,13 @@ void avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) arr, k, 0, indx_last_elem, 2 * 
(arrsize_t)log2(indx_last_elem)); } } +template <> +void avx512_partial_qsort(_Float16 *arr, + arrsize_t k, + arrsize_t arrsize, + bool hasnan) +{ + avx512_qselect(arr, k - 1, arrsize, hasnan); + avx512_qsort(arr, k - 1); +} #endif // AVX512FP16_QSORT_16BIT diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index f1465977..23c9f964 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -1,3 +1,5 @@ +#ifndef XSS_COMMON_INCLUDES +#define XSS_COMMON_INCLUDES #include #include #include @@ -72,3 +74,8 @@ struct zmm_vector; template struct ymm_vector; + +template +struct avx2_vector; + +#endif // XSS_COMMON_INCLUDES diff --git a/src/avx512-common-qsort.h b/src/xss-common-qsort.h similarity index 89% rename from src/avx512-common-qsort.h rename to src/xss-common-qsort.h index 3a489b7c..0b76add6 100644 --- a/src/avx512-common-qsort.h +++ b/src/xss-common-qsort.h @@ -8,8 +8,8 @@ * Tang Xi * ****************************************************************/ -#ifndef AVX512_QSORT_COMMON -#define AVX512_QSORT_COMMON +#ifndef XSS_COMMON_QSORT +#define XSS_COMMON_QSORT /* * Quicksort using AVX-512. The ideas and code are based on these two research @@ -65,7 +65,7 @@ X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size) in = vtype::loadu(arr + ii); } opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); - nan_count += _mm_popcnt_u32((int32_t)nanmask); + nan_count += _mm_popcnt_u32(vtype::convert_mask_to_int(nanmask)); vtype::mask_storeu(arr + ii, nanmask, vtype::zmm_max()); } return nan_count; @@ -162,11 +162,26 @@ X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask) reg_t max = vtype::max(in2, in1); return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max } -/* - * Parition one ZMM register based on the pivot and returns the - * number of elements that are greater than or equal to the pivot. - */ + template +int avx512_double_compressstore(type_t *left_addr, + type_t *right_addr, + typename vtype::opmask_t k, + reg_t reg) +{ + int amount_ge_pivot = _mm_popcnt_u32((int)k); + + vtype::mask_compressstoreu(left_addr, vtype::knot_opmask(k), reg); + vtype::mask_compressstoreu( + right_addr + vtype::numlanes - amount_ge_pivot, k, reg); + + return amount_ge_pivot; +} + +// Generic function dispatches to AVX2 or AVX512 code +template X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, type_t *r_store, const reg_t curr_vec, @@ -175,17 +190,16 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, reg_t &biggest_vec) { typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); - int amount_ge_pivot = _mm_popcnt_u32((int)ge_mask); - vtype::mask_compressstoreu(l_store, vtype::knot_opmask(ge_mask), curr_vec); - vtype::mask_compressstoreu( - r_store + vtype::numlanes - amount_ge_pivot, ge_mask, curr_vec); + int amount_ge_pivot + = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); biggest_vec = vtype::max(curr_vec, biggest_vec); return amount_ge_pivot; } + /* * Parition an array based on the pivot and returns the index of the * first element that is greater than or equal to the pivot. 
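// Illustrative sketch (not part of the patch): a scalar model of what
// vtype::double_compressstore does for a single vector during partitioning,
// mirroring avx512_double_compressstore() above and avx2_double_compressstore32()
// in avx2-emu-funcs.hpp. Lanes below the pivot are packed, in order, to the
// left store; lanes greater than or equal to the pivot are packed against the
// right edge; the count of ">= pivot" lanes is returned.
template <typename T, int N>
int double_compressstore_scalar_model(T *left, T *right, const T (&v)[N], T pivot)
{
    int num_ge = 0;
    for (int i = 0; i < N; i++) {
        if (v[i] >= pivot) num_ge++;
    }
    int l = 0, r = N - num_ge;
    for (int i = 0; i < N; i++) {
        if (v[i] >= pivot) { right[r++] = v[i]; }
        else { left[l++] = v[i]; }
    }
    return num_ge;
}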
@@ -462,7 +476,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, return l_store; } - template void sort_n(typename vtype::type_t *arr, int N); @@ -536,32 +549,28 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, qselect_(arr, pos, pivot_index, right, max_iters - 1); } -// Regular quicksort routines: -template -X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize) +// Quicksort routines: +template +X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize) { if (arrsize > 1) { - /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { - arrsize_t nan_count - = replace_nan_with_inf>(arr, arrsize); - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + arrsize_t nan_count = replace_nan_with_inf(arr, arrsize); + qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); } else { - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); + qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } } -template +// Quick select methods +template X86_SIMD_SORT_INLINE void -avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { arrsize_t indx_last_elem = arrsize - 1; - /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { indx_last_elem = move_nans_to_end_of_array(arr, arrsize); @@ -569,19 +578,40 @@ avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) } UNUSED(hasnan); if (indx_last_elem >= k) { - qselect_, T>( + qselect_( arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); } } -template -X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) +// Partial sort methods: +template +X86_SIMD_SORT_INLINE void +xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - avx512_qselect(arr, k - 1, arrsize, hasnan); - avx512_qsort(arr, k - 1); + xss_qselect(arr, k - 1, arrsize, hasnan); + xss_qsort(arr, k - 1); } -#endif // AVX512_QSORT_COMMON +#define DEFINE_METHODS(ISA, VTYPE) \ + template \ + X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, arrsize_t size) \ + { \ + xss_qsort(arr, size); \ + } \ + template \ + X86_SIMD_SORT_INLINE void ISA##_qselect( \ + T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + { \ + xss_qselect(arr, k, size, hasnan); \ + } \ + template \ + X86_SIMD_SORT_INLINE void ISA##_partial_qsort( \ + T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + { \ + xss_partial_qsort(arr, k, size, hasnan); \ + } + +DEFINE_METHODS(avx512, zmm_vector) +DEFINE_METHODS(avx2, avx2_vector) + +#endif // XSS_COMMON_QSORT diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp index a768a580..56a1aca1 100644 --- a/src/xss-network-qsort.hpp +++ b/src/xss-network-qsort.hpp @@ -2,6 +2,7 @@ #define XSS_NETWORK_QSORT #include "xss-optimal-networks.hpp" +#include "xss-common-qsort.h" template X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) @@ -156,10 +157,10 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) typename vtype::opmask_t ioMasks[numVecs - numVecs / 2]; X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - int64_t num_to_read - = std::min((int64_t)std::max(0, N - i * vtype::numlanes), - (int64_t)vtype::numlanes); - ioMasks[j] = ((0x1ull << num_to_read) - 0x1ull); + 
uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * vtype::numlanes), + (uint64_t)vtype::numlanes); + ioMasks[j] = vtype::get_partial_loadmask(num_to_read); } // Unmasked part of the load diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index 15fe36a2..2a28b348 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -1,120 +1,27 @@ template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); -template -X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 32 - arrsize_t size = (right - left) / 32; - type_t vec_arr[32] = {arr[left], - arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size], - arr[left + 17 * size], - arr[left + 18 * size], - arr[left + 19 * size], - arr[left + 20 * size], - arr[left + 21 * size], - arr[left + 22 * size], - arr[left + 23 * size], - arr[left + 24 * size], - arr[left + 25 * size], - arr[left + 26 * size], - arr[left + 27 * size], - arr[left + 28 * size], - arr[left + 29 * size], - arr[left + 30 * size], - arr[left + 31 * size]}; - typename vtype::reg_t rand_vec = vtype::loadu(vec_arr); - typename vtype::reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[16]; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 16 - arrsize_t size = (right - left) / 16; - using reg_t = typename vtype::reg_t; - type_t vec_arr[16] = {arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size]}; - reg_t rand_vec = vtype::loadu(vec_arr); - reg_t sort = vtype::sort_vec(rand_vec); - // pivot will never be a nan, since there are no nan's! - return ((type_t *)&sort)[8]; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 8 - arrsize_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size]); - // pivot will never be a nan, since there are no nan's! 
-    reg_t sort = vtype::sort_vec(rand_vec);
-    return ((type_t *)&sort)[4];
-}
-
 template <typename vtype, typename type_t>
 X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr,
                                       const arrsize_t left,
                                       const arrsize_t right)
 {
-    if constexpr (vtype::numlanes == 8)
-        return get_pivot_64bit<vtype>(arr, left, right);
-    else if constexpr (vtype::numlanes == 16)
-        return get_pivot_32bit<vtype>(arr, left, right);
-    else if constexpr (vtype::numlanes == 32)
-        return get_pivot_16bit<vtype>(arr, left, right);
-    else
-        return arr[right];
+    using reg_t = typename vtype::reg_t;
+    type_t samples[vtype::numlanes];
+    arrsize_t delta = (right - left) / vtype::numlanes;
+    for (int i = 0; i < vtype::numlanes; i++) {
+        samples[i] = arr[left + i * delta];
+    }
+    reg_t rand_vec = vtype::loadu(samples);
+    reg_t sort = vtype::sort_vec(rand_vec);
+
+    return ((type_t *)&sort)[vtype::numlanes / 2];
 }
 
 template <typename vtype, typename type_t>
 X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr,
-                                             arrsize_t left,
-                                             arrsize_t right)
+                                             const arrsize_t left,
+                                             const arrsize_t right)
 {
     if (right - left <= 1024) { return get_pivot<vtype>(arr, left, right); }
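For reference, a minimal sketch of calling the new AVX2 entry points directly, mirroring examples/avx2-32bit-qsort.cpp added above (assumes compilation with -march=haswell as in the example Makefile; the data values are purely illustrative):

#include "avx2-32bit-qsort.hpp"
#include <vector>

int main()
{
    std::vector<float> data {3.5f, -1.0f, 7.25f, 0.5f, 9.0f, 2.0f, -4.0f, 1.5f};
    avx2_qsort(data.data(), data.size());            // full sort
    avx2_qselect(data.data(), 3, data.size());       // element 3 placed as in a full sort
    avx2_partial_qsort(data.data(), 3, data.size()); // smallest 3 sorted at the front
    return 0;
}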