From f9860436f438fe4edede620291dd5d8f6b35d153 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 4 Aug 2023 11:40:20 -0700 Subject: [PATCH 01/19] Core AVX2 code logic --- lib/meson.build | 10 + lib/x86simdsort-avx2.cpp | 28 + lib/x86simdsort.cpp | 6 +- src/avx2-32bit-common.h | 663 ++++++++++++++++++ src/avx2-32bit-qsort.hpp | 13 + src/avx2-emu-funcs.hpp | 231 ++++++ src/avx512-16bit-common.h | 3 +- src/avx512-32bit-qsort.hpp | 3 +- ...x512-common-qsort.h => xss-common-qsort.h} | 121 +++- src/xss-network-qsort.hpp | 1 + src/xss-pivot-selection.hpp | 52 +- 11 files changed, 1112 insertions(+), 19 deletions(-) create mode 100644 lib/x86simdsort-avx2.cpp create mode 100644 src/avx2-32bit-common.h create mode 100644 src/avx2-32bit-qsort.hpp create mode 100644 src/avx2-emu-funcs.hpp rename src/{avx512-common-qsort.h => xss-common-qsort.h} (83%) diff --git a/lib/meson.build b/lib/meson.build index fc544701..0bcd5da6 100644 --- a/lib/meson.build +++ b/lib/meson.build @@ -1,5 +1,15 @@ libtargets = [] +if cpp.has_argument('-march=haswell') + libtargets += static_library('libavx', + files( + 'x86simdsort-avx2.cpp', + ), + include_directories : [src], + cpp_args : ['-march=haswell', flags_hide_symbols], + ) +endif + if cpp.has_argument('-march=skylake-avx512') libtargets += static_library('libskx', files( diff --git a/lib/x86simdsort-avx2.cpp b/lib/x86simdsort-avx2.cpp new file mode 100644 index 00000000..825b4069 --- /dev/null +++ b/lib/x86simdsort-avx2.cpp @@ -0,0 +1,28 @@ +// AVX2 specific routines: +#include "avx2-32bit-qsort.hpp" +#include "x86simdsort-internal.h" + +#define DEFINE_ALL_METHODS(type) \ + template <> \ + void qsort(type *arr, size_t arrsize) \ + { \ + avx2_qsort(arr, arrsize); \ + } \ + template <> \ + void qselect(type *arr, size_t k, size_t arrsize, bool hasnan) \ + { \ + avx2_qselect(arr, k, arrsize, hasnan); \ + } \ + template <> \ + void partial_qsort(type *arr, size_t k, size_t arrsize, bool hasnan) \ + { \ + avx2_partial_qsort(arr, k, arrsize, hasnan); \ + } + +namespace xss { +namespace avx2 { + DEFINE_ALL_METHODS(uint32_t) + DEFINE_ALL_METHODS(int32_t) + DEFINE_ALL_METHODS(float) +} // namespace avx512 +} // namespace xss diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 6879657a..57cac471 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -140,9 +140,9 @@ DISPATCH(argselect, _Float16, "none") DISPATCH(func, uint64_t, ISA_64BIT) \ DISPATCH(func, double, ISA_64BIT) -DISPATCH_ALL(qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx")) -DISPATCH_ALL(qselect, ("avx512_icl"), ("avx512_skx"), ("avx512_skx")) -DISPATCH_ALL(partial_qsort, ("avx512_icl"), ("avx512_skx"), ("avx512_skx")) +DISPATCH_ALL(qsort, ("avx512_icl"), ("avx512_skx","avx2"), ("avx512_skx")) +DISPATCH_ALL(qselect, ("avx512_icl"), ("avx512_skx","avx2"), ("avx512_skx")) +DISPATCH_ALL(partial_qsort, ("avx512_icl"), ("avx512_skx","avx2"), ("avx512_skx")) DISPATCH_ALL(argsort, "none", "avx512_skx", "avx512_skx") DISPATCH_ALL(argselect, "none", "avx512_skx", "avx512_skx") diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h new file mode 100644 index 00000000..110de03e --- /dev/null +++ b/src/avx2-32bit-common.h @@ -0,0 +1,663 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * Matthew Sterrett + * ****************************************************************/ + +#ifndef AVX2_32BIT_COMMON +#define AVX2_32BIT_COMMON +#include 
"avx2-emu-funcs.hpp" +#include "xss-common-qsort.h" + +/* + * Constants used in sorting 8 elements in a ymm registers. Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ + +// ymm 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 + +namespace xss { +namespace avx2 { + +// Assumes ymm is bitonic and performs a recursive half cleaner +template +X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_32bit(reg_t ymm) +{ + + const typename vtype::opmask_t oxAA = _mm256_set_epi32( + 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const typename vtype::opmask_t oxCC = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); + const typename vtype::opmask_t oxF0 = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); + + // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 + ymm = cmp_merge( + ymm, + vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_4), ymm), + oxF0); + // 2) half_cleaner[4] + ymm = cmp_merge( + ymm, + vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_3), ymm), + oxCC); + // 3) half_cleaner[1] + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + return ymm; +} + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) +{ + const typename vtype::opmask_t oxAA = _mm256_set_epi32( + 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const typename vtype::opmask_t oxCC = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); + const typename vtype::opmask_t oxF0 = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); + + const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge( + ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); + ymm = cmp_merge( + ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + return ymm; +} + +struct avx2_32bit_swizzle_ops; + +template <> +struct ymm_vector { + using type_t = int32_t; + using reg_t = __m256i; + using ymmi_t = __m256i; + using opmask_t = avx2_mask_helper32; + static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static reg_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? 
+ + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _mm256_xor_si256(x, y); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t le(reg_t x, reg_t y) + { + return ~_mm256_cmpgt_epi32(x, y); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + opmask_t greater = _mm256_cmpgt_epi32(x, y); + return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(equal), + _mm256_castsi256_ps(greater))); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi32(x, y); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_epi32((int const *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_epi32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static int32_t double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_epi32((const int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi32((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), + _mm256_castsi256_ps(y), + _mm256_castsi256_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_epi32((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_epi32(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + //return avx2_emu_permutexvar_epi32(idx, ymm); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi32(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_ymm_32bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v){ + return v; + } + static __m256i cast_to(reg_t v){ + return v; + } +}; +template <> +struct ymm_vector { + using type_t = uint32_t; + using reg_t = __m256i; + using ymmi_t = __m256i; + using opmask_t = avx2_mask_helper32; + static const uint8_t 
numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static reg_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } + + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_epi32((int const *)base, index, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t ge(reg_t x, reg_t y) + { + reg_t maxi = max(x, y); + return eq(maxi, x); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi32(x, y); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static int32_t double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi32((const int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), + _mm256_castsi256_ps(y), + _mm256_castsi256_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_epi32((int *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_epu32(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi32(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_ymm_32bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v){ + return v; + } + static __m256i cast_to(reg_t v){ + return v; + } +}; +template <> +struct ymm_vector { + using type_t = float; + using reg_t = __m256; + using ymmi_t = __m256i; + using opmask_t = avx2_mask_helper32; + static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using 
swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static reg_t zmm_max() + { + return _mm256_set1_ps(type_max()); + } + + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_ps((const float *)mem, mask); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t ge(reg_t x, reg_t y) + { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GE_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_ps( + src, base, index, _mm256_castsi256_ps(mask), scale); + ; + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_ps((float *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_ps((float const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static int32_t double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_ps((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_blendv_ps(x, y, _mm256_castsi256_ps(mask)); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_ps((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_ps(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + int32_t x = _mm256_extract_epi32(_mm256_castps_si256(v), index); + float y; + std::memcpy(&y, &x, sizeof(y)); + return y; + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_ps(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_castsi256_ps( + _mm256_shuffle_epi32(_mm256_castps_si256(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_ps((float *)mem, x); + } + static reg_t bitonic_merge(reg_t x) + { + return bitonic_merge_ymm_32bit>(x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v){ + return _mm256_castsi256_ps(v); + } + static __m256i cast_to(reg_t v){ + return _mm256_castps_si256(v); + } +}; + +inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize) +{ + int64_t nan_count = 0; + __mmask8 loadmask = 0xFF; + while (arrsize > 0) { + 
if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } + __m256 in_ymm = ymm_vector::maskz_loadu(loadmask, arr); + __m256i nanmask = _mm256_castps_si256( + _mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ)); + nan_count += _popcnt32(avx2_mask_helper32(nanmask)); + ymm_vector::mask_storeu(arr, nanmask, YMM_MAX_FLOAT); + arr += 8; + arrsize -= 8; + } + return nan_count; +} + +X86_SIMD_SORT_INLINE void +replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) +{ + for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + arr[ii] = std::nan("1"); + nan_count -= 1; + } +} + +template +X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, + const int64_t left, + const int64_t right) +{ + // median of 8 + int64_t size = (right - left) / 8; + using reg_t = typename vtype::reg_t; + __m256i rand_index = _mm256_set_epi32(left + size, + left + 2 * size, + left + 3 * size, + left + 4 * size, + left + 5 * size, + left + 6 * size, + left + 7 * size, + left + 8 * size); + reg_t rand_vec = vtype::template i64gather(rand_index, arr); + // pivot will never be a nan, since there are no nan's! + reg_t sort = sort_ymm_32bit(rand_vec); + return ((type_t *)&sort)[4]; +} + +struct avx2_32bit_swizzle_ops{ + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){ + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2){ + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b10110001); + v = _mm256_castps_si256(vf); + }else if constexpr (scale == 4){ + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b01001110); + v = _mm256_castps_si256(vf); + }else if constexpr (scale == 8){ + v = _mm256_permute2x128_si256(v, v, 0b00000001); + }else{ + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n(typename vtype::reg_t reg){ + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2){ + return swap_n(reg); + }else if constexpr (scale == 4){ + constexpr uint64_t mask = 0b00011011; + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, mask); + v = _mm256_castps_si256(vf); + }else if constexpr (scale == 8){ + return vtype::reverse(reg); + }else{ + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n(typename vtype::reg_t reg, typename vtype::reg_t other){ + __m256i v1 = vtype::cast_to(reg); + __m256i v2 = vtype::cast_to(other); + + if constexpr (scale == 2){ + v1 = _mm256_blend_epi32(v1, v2, 0b01010101); + }else if constexpr (scale == 4){ + v1 = _mm256_blend_epi32(v1, v2, 0b00110011); + }else if constexpr (scale == 8){ + v1 = _mm256_blend_epi32(v1, v2, 0b00001111); + }else{ + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; + +} // namespace avx2 +} // namespace xss +#endif diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp new file mode 100644 index 00000000..c0590d94 --- /dev/null +++ b/src/avx2-32bit-qsort.hpp @@ -0,0 +1,13 @@ +/******************************************************************* + * Copyright (C) 2022 Intel Corporation + * SPDX-License-Identifier: BSD-3-Clause + * Authors: Raghuveer Devulapalli + * ****************************************************************/ + +#ifndef AVX2_QSORT_32BIT +#define AVX2_QSORT_32BIT + +#include "avx2-32bit-common.h" +#include "xss-network-qsort.hpp" + +#endif // AVX2_QSORT_32BIT diff --git 
a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp new file mode 100644 index 00000000..228603fe --- /dev/null +++ b/src/avx2-emu-funcs.hpp @@ -0,0 +1,231 @@ +#ifndef AVX2_EMU_FUNCS +#define AVX2_EMU_FUNCS + +#include +#include +#include "xss-common-qsort.h" + +namespace xss { +namespace avx2 { + +constexpr auto avx2_mask_helper_lut32 = [] { + std::array, 256> lut {}; + for (int64_t i = 0; i <= 0xFF; i++) { + std::array entry {}; + for (int j = 0; j < 8; j++) { + if (((i >> j) & 1) == 1) + entry[j] = 0xFFFFFFFF; + else + entry[j] = 0; + } + lut[i] = entry; + } + return lut; +}(); + +constexpr auto avx2_compressstore_lut32_gen = [] { + std::array, 256>, 2> lutPair {}; + auto &permLut = lutPair[0]; + auto &leftLut = lutPair[1]; + for (int64_t i = 0; i <= 0xFF; i++) { + std::array indices {}; + std::array leftEntry = {0, 0, 0, 0, 0, 0, 0, 0}; + int right = 7; + int left = 0; + for (int j = 0; j < 8; j++) { + bool ge = (i >> j) & 1; + if (ge) { + indices[right] = j; + right--; + } + else { + indices[left] = j; + leftEntry[left] = 0xFFFFFFFF; + left++; + } + } + permLut[i] = indices; + leftLut[i] = leftEntry; + } + return lutPair; +}(); +constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; +constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; + +struct avx2_mask_helper32 { + __m256i mask; + + avx2_mask_helper32() = default; + avx2_mask_helper32(int m) + { + mask = converter(m); + } + avx2_mask_helper32(__m256i m) + { + mask = m; + } + operator __m256i() + { + return mask; + } + operator int32_t() + { + return converter(mask); + } + __m256i operator=(int m) + { + mask = converter(m); + return mask; + } + +private: + __m256i converter(int m) + { + return _mm256_loadu_si256( + (const __m256i *)avx2_mask_helper_lut32[m].data()); + } + + int32_t converter(__m256i m) + { + return _mm256_movemask_ps(_mm256_castsi256_ps(m)); + } +}; +static __m256i operator~(const avx2_mask_helper32 x) +{ + return ~x.mask; +} + +// Emulators for intrinsics missing from AVX2 compared to AVX512 +template +T avx2_emu_reduce_max32(typename ymm_vector::reg_t x) +{ + using vtype = ymm_vector; + typename vtype::reg_t inter1 = vtype::max( + x, vtype::template shuffle(x)); + typename vtype::reg_t inter2 = vtype::permutevar( + inter1, _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0)); + typename vtype::reg_t inter3 = vtype::max( + inter2, vtype::template shuffle(inter2)); + T can1 = vtype::template extract<0>(inter3); + T can2 = vtype::template extract<2>(inter3); + return std::max(can1, can2); +} + +template +T avx2_emu_reduce_min32(typename ymm_vector::reg_t x) +{ + using vtype = ymm_vector; + typename vtype::reg_t inter1 = vtype::min( + x, vtype::template shuffle(x)); + typename vtype::reg_t inter2 = vtype::permutevar( + inter1, _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0)); + typename vtype::reg_t inter3 = vtype::min( + inter2, vtype::template shuffle(inter2)); + T can1 = vtype::template extract<0>(inter3); + T can2 = vtype::template extract<2>(inter3); + return std::min(can1, can2); +} + +template +typename ymm_vector::opmask_t +avx2_emu_fpclassify64(typename ymm_vector::reg_t x, int mask) +{ + using vtype = ymm_vector; + T store[vtype::numlanes]; + vtype::storeu(&store[0], x); + int64_t res[vtype::numlanes]; + + for (int i = 0; i < vtype::numlanes; i++) { + bool flagged = scalar_emu_fpclassify(store[i]); + res[i] = 0xFFFFFFFFFFFFFFFF; + } + return vtype::loadu(res); +} + +template +void avx2_emu_mask_compressstoreu(void *base_addr, + typename ymm_vector::opmask_t k, + typename 
ymm_vector::reg_t reg) +{ + using vtype = ymm_vector; + T *storage = (T *)base_addr; + int32_t mask[vtype::numlanes]; + T data[vtype::numlanes]; + + _mm256_storeu_si256((__m256i *)&mask[0], k); + vtype::storeu(&data[0], reg); + +#pragma GCC unroll 8 + for (int i = 0; i < vtype::numlanes; i++) { + if (mask[i]) { + *storage = data[i]; + storage++; + } + } +} + +template +int32_t avx2_double_compressstore32(void *left_addr, + void *right_addr, + typename ymm_vector::opmask_t k, + typename ymm_vector::reg_t reg) +{ + using vtype = ymm_vector; + + T *leftStore = (T *)left_addr; + T *rightStore = (T *)right_addr; + + int32_t shortMask = avx2_mask_helper32(k); + const __m256i &perm = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); + const __m256i &left = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_left[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::mask_storeu(leftStore, left, temp); + vtype::mask_storeu(rightStore - vtype::numlanes, ~left, temp); + + return _mm_popcnt_u32(shortMask); +} + +template +int64_t avx2_emu_popcnt(__m256i reg) +{ + using vtype = ymm_vector; + + int32_t data[vtype::numlanes]; + _mm256_storeu_si256((__m256i *)&data[0], reg); + + int64_t pop = 0; + for (int i = 0; i < vtype::numlanes; i++) { + pop += _popcnt32(data[i]); + } + return pop; +} + +template +typename ymm_vector::reg_t avx2_emu_max(typename ymm_vector::reg_t x, + typename ymm_vector::reg_t y) +{ + using vtype = ymm_vector; + typename vtype::opmask_t nlt = vtype::ge(x, y); + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y), + _mm256_castsi256_pd(x), + _mm256_castsi256_pd(nlt))); +} + +template +typename ymm_vector::reg_t avx2_emu_min(typename ymm_vector::reg_t x, + typename ymm_vector::reg_t y) +{ + using vtype = ymm_vector; + typename vtype::opmask_t nlt = vtype::ge(x, y); + return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x), + _mm256_castsi256_pd(y), + _mm256_castsi256_pd(nlt))); +} +} // namespace avx2 +} // namespace x86_simd_sort + +#endif \ No newline at end of file diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 288f85d0..41116b33 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -7,7 +7,8 @@ #ifndef AVX512_16BIT_COMMON #define AVX512_16BIT_COMMON -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" +#include "xss-network-qsort.hpp" /* * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index dc56e370..4d244d6e 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -8,7 +8,8 @@ #ifndef AVX512_QSORT_32BIT #define AVX512_QSORT_32BIT -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" +#include "xss-network-qsort.hpp" /* * Constants used in sorting 16 elements in a ZMM registers. 
Based on Bitonic diff --git a/src/avx512-common-qsort.h b/src/xss-common-qsort.h similarity index 83% rename from src/avx512-common-qsort.h rename to src/xss-common-qsort.h index 3a489b7c..e0eadbcf 100644 --- a/src/avx512-common-qsort.h +++ b/src/xss-common-qsort.h @@ -38,6 +38,20 @@ #include "xss-pivot-selection.hpp" #include "xss-network-qsort.hpp" +namespace xss{ +namespace avx2{ +template +struct ymm_vector; + +inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize); +} +} + +// key-value sort routines +template +void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize); + +>>>>>>> d6479e7 (Core AVX2 code logic):src/xss-common-qsort.h template bool is_a_nan(T elem) { @@ -167,7 +181,7 @@ X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask) * number of elements that are greater than or equal to the pivot. */ template -X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, +X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx512(type_t *l_store, type_t *r_store, const reg_t curr_vec, const reg_t pivot_vec, @@ -186,6 +200,50 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, return amount_ge_pivot; } +/* + * Parition one YMM register based on the pivot and returns the + * number of elements that are greater than or equal to the pivot. + */ +template +X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx2(type_t *l_store, + type_t *r_store, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t &smallest_vec, + reg_t &biggest_vec) +{ + /* which elements are larger than or equal to the pivot */ + typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); + + int32_t amount_ge_pivot = vtype::double_compressstore( + arr + left, arr + left + unpartitioned, ge_mask, curr_vec); + + left += (vtype::numlanes - amount_ge_pivot); + unpartitioned -= vtype::numlanes; + + smallest_vec = vtype::min(curr_vec, smallest_vec); + biggest_vec = vtype::max(curr_vec, biggest_vec); +} + +// Generic function dispatches to AVX2 or AVX512 code +template +X86_SIMD_SORT_INLINE void partition_vec(type_t *arr, + arrsize_t &left, + arrsize_t &unpartitioned, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t &smallest_vec, + reg_t &biggest_vec) +{ + if constexpr (sizeof(reg_t) == 64){ + partition_vec_avx512(arr, left, unpartitioned, curr_vec, pivot_vec, smallest_vec, biggest_vec); + }else if constexpr (sizeof(reg_t) == 32){ + partition_vec_avx2(arr, left, unpartitioned, curr_vec, pivot_vec, smallest_vec, biggest_vec); + }else{ + static_assert(sizeof(reg_t) == -1, "should not reach here"); + } +} + /* * Parition an array based on the pivot and returns the index of the * first element that is greater than or equal to the pivot. 
@@ -467,8 +525,7 @@ template void sort_n(typename vtype::type_t *arr, int N); template -static void -qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) +static void qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -556,6 +613,28 @@ X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize) } } +template +void avx2_qsort(T *arr, int64_t arrsize) +{ + using vtype = xss::avx2::ymm_vector; + if (arrsize > 1) { + /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ + if constexpr (std::is_floating_point_v) { + int64_t nan_count + = xss::avx2::replace_nan_with_inf(arr, arrsize); + qsort_( + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + replace_inf_with_nan(arr, arrsize, nan_count); + } + else { + qsort_( + arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + } + } +} + +void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize); + template X86_SIMD_SORT_INLINE void avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) @@ -574,6 +653,27 @@ avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) } } +template +void avx2_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) +{ + int64_t indx_last_elem = arrsize - 1; + /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ + if constexpr (std::is_floating_point_v) { + if (UNLIKELY(hasnan)) { + indx_last_elem = move_nans_to_end_of_array(arr, arrsize); + } + } + if (indx_last_elem >= k) { + qselect_, T>( + arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); + } +} + +void avx512_qselect_fp16(uint16_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan = false); + template X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr, arrsize_t k, @@ -584,4 +684,19 @@ X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr, avx512_qsort(arr, k - 1); } +template +inline void avx2_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) +{ + avx2_qselect(arr, k - 1, arrsize, hasnan); + avx2_qsort(arr, k - 1); +} +inline void avx512_partial_qsort_fp16(uint16_t *arr, + int64_t k, + int64_t arrsize, + bool hasnan = false) +{ + avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); + avx512_qsort_fp16(arr, k - 1); +} + #endif // AVX512_QSORT_COMMON diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp index a768a580..1a2e313d 100644 --- a/src/xss-network-qsort.hpp +++ b/src/xss-network-qsort.hpp @@ -2,6 +2,7 @@ #define XSS_NETWORK_QSORT #include "xss-optimal-networks.hpp" +#include "xss-common-qsort.h" template X86_SIMD_SORT_FINLINE void bitonic_sort_n_vec(reg_t *regs) diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index 15fe36a2..c9f7bdb3 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -2,7 +2,7 @@ template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); template -X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, +X86_SIMD_SORT_INLINE type_t get_pivot_avx512_16bit(type_t *arr, const arrsize_t left, const arrsize_t right) { @@ -46,7 +46,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_16bit(type_t *arr, } template -X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, +X86_SIMD_SORT_INLINE type_t get_pivot_avx512_32bit(type_t *arr, const arrsize_t left, const arrsize_t right) { @@ -76,7 +76,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, } template -X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, +X86_SIMD_SORT_INLINE type_t get_pivot_avx512_64bit(type_t *arr, 
const arrsize_t left, const arrsize_t right) { @@ -96,19 +96,49 @@ X86_SIMD_SORT_INLINE type_t get_pivot_64bit(type_t *arr, return ((type_t *)&sort)[4]; } +template +X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N); + +template +X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr, + const arrsize_t left, + const arrsize_t right) +{ + type_t samples[vtype::numlanes]; + + arrsize_t delta = (right - left) / vtype::numlanes; + + for (int i = 0; i < vtype::numlanes; i++){ + samples[i] = arr[left + i * delta]; + } + + sort_n(samples, vtype::numlanes); + + return samples[vtype::numlanes / 2]; +} + template X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, const arrsize_t left, const arrsize_t right) { - if constexpr (vtype::numlanes == 8) - return get_pivot_64bit(arr, left, right); - else if constexpr (vtype::numlanes == 16) - return get_pivot_32bit(arr, left, right); - else if constexpr (vtype::numlanes == 32) - return get_pivot_16bit(arr, left, right); - else - return arr[right]; + using reg_t = typename vtype::reg_t; + if constexpr (sizeof(reg_t) == 64){ // AVX512 + if constexpr (vtype::numlanes == 8) + return get_pivot_avx512_64bit(arr, left, right); + else if constexpr (vtype::numlanes == 16) + return get_pivot_avx512_32bit(arr, left, right); + else if constexpr (vtype::numlanes == 32) + return get_pivot_avx512_16bit(arr, left, right); + else + static_assert(vtype::numlanes == -1, "should not reach here"); + }else if constexpr (sizeof(reg_t) == 32) { // AVX2 + if constexpr (vtype::numlanes == 8){ + return get_pivot_scalar(arr, left, right); + } + }else{ + static_assert(sizeof(reg_t) == -1, "should not reach here"); + } } template From 68938d49fe4c77477aa13fabe5847abf1fd8cfce Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 12 Oct 2023 11:11:14 -0700 Subject: [PATCH 02/19] Bug fix in runtime dispatch and variadic args --- lib/x86simdsort.cpp | 51 ++++++++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/lib/x86simdsort.cpp b/lib/x86simdsort.cpp index 57cac471..e5803d06 100644 --- a/lib/x86simdsort.cpp +++ b/lib/x86simdsort.cpp @@ -7,7 +7,7 @@ static int check_cpu_feature_support(std::string_view cpufeature) { - const char* disable_avx512 = std::getenv("XSS_DISABLE_AVX512"); + const char *disable_avx512 = std::getenv("XSS_DISABLE_AVX512"); if ((cpufeature == "avx512_spr") && (!disable_avx512)) #ifdef __FLT16_MAX__ @@ -100,20 +100,20 @@ dispatch_requested(std::string_view cpurequested, } /* runtime dispatch mechanism */ -#define DISPATCH(func, TYPE, ...) 
\ +#define DISPATCH(func, TYPE, ISA) \ DECLARE_INTERNAL_##func(TYPE) static __attribute__((constructor)) void \ CAT(CAT(resolve_, func), TYPE)(void) \ { \ CAT(CAT(internal_, func), TYPE) = &xss::scalar::func; \ __builtin_cpu_init(); \ - std::string_view preferred_cpu = find_preferred_cpu({__VA_ARGS__}); \ - if constexpr (dispatch_requested("avx512", {__VA_ARGS__})) { \ + std::string_view preferred_cpu = find_preferred_cpu(ISA); \ + if constexpr (dispatch_requested("avx512", ISA)) { \ if (preferred_cpu.find("avx512") != std::string_view::npos) { \ CAT(CAT(internal_, func), TYPE) = &xss::avx512::func; \ return; \ } \ } \ - else if constexpr (dispatch_requested("avx2", {__VA_ARGS__})) { \ + if constexpr (dispatch_requested("avx2", ISA)) { \ if (preferred_cpu.find("avx2") != std::string_view::npos) { \ CAT(CAT(internal_, func), TYPE) = &xss::avx2::func; \ return; \ @@ -121,13 +121,19 @@ dispatch_requested(std::string_view cpurequested, } \ } +#define ISA_LIST(...) \ + std::initializer_list \ + { \ + __VA_ARGS__ \ + } + namespace x86simdsort { #ifdef __FLT16_MAX__ -DISPATCH(qsort, _Float16, "avx512_spr") -DISPATCH(qselect, _Float16, "avx512_spr") -DISPATCH(partial_qsort, _Float16, "avx512_spr") -DISPATCH(argsort, _Float16, "none") -DISPATCH(argselect, _Float16, "none") +DISPATCH(qsort, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(qselect, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(partial_qsort, _Float16, ISA_LIST("avx512_spr")) +DISPATCH(argsort, _Float16, ISA_LIST("none")) +DISPATCH(argselect, _Float16, ISA_LIST("none")) #endif #define DISPATCH_ALL(func, ISA_16BIT, ISA_32BIT, ISA_64BIT) \ @@ -140,10 +146,25 @@ DISPATCH(argselect, _Float16, "none") DISPATCH(func, uint64_t, ISA_64BIT) \ DISPATCH(func, double, ISA_64BIT) -DISPATCH_ALL(qsort, ("avx512_icl"), ("avx512_skx","avx2"), ("avx512_skx")) -DISPATCH_ALL(qselect, ("avx512_icl"), ("avx512_skx","avx2"), ("avx512_skx")) -DISPATCH_ALL(partial_qsort, ("avx512_icl"), ("avx512_skx","avx2"), ("avx512_skx")) -DISPATCH_ALL(argsort, "none", "avx512_skx", "avx512_skx") -DISPATCH_ALL(argselect, "none", "avx512_skx", "avx512_skx") +DISPATCH_ALL(qsort, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(qselect, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(partial_qsort, + (ISA_LIST("avx512_icl")), + (ISA_LIST("avx512_skx", "avx2")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(argsort, + (ISA_LIST("none")), + (ISA_LIST("avx512_skx")), + (ISA_LIST("avx512_skx"))) +DISPATCH_ALL(argselect, + (ISA_LIST("none")), + (ISA_LIST("avx512_skx")), + (ISA_LIST("avx512_skx"))) } // namespace x86simdsort From 73b9ba502d5eb8b0ba99d40d5cdda337df992ec8 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Mon, 16 Oct 2023 12:24:19 -0700 Subject: [PATCH 03/19] Bugfixes in AVX2 code --- src/avx2-emu-funcs.hpp | 30 +++++++++++++----------------- src/xss-common-qsort.h | 26 ++++++++++++-------------- 2 files changed, 25 insertions(+), 31 deletions(-) diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 228603fe..ae2862f5 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -100,14 +100,12 @@ template T avx2_emu_reduce_max32(typename ymm_vector::reg_t x) { using vtype = ymm_vector; - typename vtype::reg_t inter1 = vtype::max( - x, vtype::template shuffle(x)); - typename vtype::reg_t inter2 = vtype::permutevar( - inter1, _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0)); - typename vtype::reg_t inter3 = vtype::max( - inter2, vtype::template 
shuffle(inter2)); - T can1 = vtype::template extract<0>(inter3); - T can2 = vtype::template extract<2>(inter3); + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::max(x, vtype::template shuffle(x)); + reg_t inter2 = vtype::max(inter1, vtype::template shuffle(inter1)); + T can1 = vtype::template extract<0>(inter2); + T can2 = vtype::template extract<4>(inter2); return std::max(can1, can2); } @@ -115,14 +113,12 @@ template T avx2_emu_reduce_min32(typename ymm_vector::reg_t x) { using vtype = ymm_vector; - typename vtype::reg_t inter1 = vtype::min( - x, vtype::template shuffle(x)); - typename vtype::reg_t inter2 = vtype::permutevar( - inter1, _mm256_set_epi32(3, 2, 1, 0, 3, 2, 1, 0)); - typename vtype::reg_t inter3 = vtype::min( - inter2, vtype::template shuffle(inter2)); - T can1 = vtype::template extract<0>(inter3); - T can2 = vtype::template extract<2>(inter3); + using reg_t = typename vtype::reg_t; + + reg_t inter1 = vtype::min(x, vtype::template shuffle(x)); + reg_t inter2 = vtype::min(inter1, vtype::template shuffle(inter1)); + T can1 = vtype::template extract<0>(inter2); + T can2 = vtype::template extract<4>(inter2); return std::min(can1, can2); } @@ -184,7 +180,7 @@ int32_t avx2_double_compressstore32(void *left_addr, typename vtype::reg_t temp = vtype::permutevar(reg, perm); vtype::mask_storeu(leftStore, left, temp); - vtype::mask_storeu(rightStore - vtype::numlanes, ~left, temp); + vtype::mask_storeu(rightStore, ~left, temp); return _mm_popcnt_u32(shortMask); } diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index e0eadbcf..9137e8ed 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -51,7 +51,6 @@ inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize); template void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize); ->>>>>>> d6479e7 (Core AVX2 code logic):src/xss-common-qsort.h template bool is_a_nan(T elem) { @@ -216,31 +215,30 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx2(type_t *l_store, typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); int32_t amount_ge_pivot = vtype::double_compressstore( - arr + left, arr + left + unpartitioned, ge_mask, curr_vec); - - left += (vtype::numlanes - amount_ge_pivot); - unpartitioned -= vtype::numlanes; + l_store, r_store, ge_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); biggest_vec = vtype::max(curr_vec, biggest_vec); + + return amount_ge_pivot; } // Generic function dispatches to AVX2 or AVX512 code template -X86_SIMD_SORT_INLINE void partition_vec(type_t *arr, - arrsize_t &left, - arrsize_t &unpartitioned, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t &smallest_vec, - reg_t &biggest_vec) +X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, + type_t *r_store, + const reg_t curr_vec, + const reg_t pivot_vec, + reg_t &smallest_vec, + reg_t &biggest_vec) { if constexpr (sizeof(reg_t) == 64){ - partition_vec_avx512(arr, left, unpartitioned, curr_vec, pivot_vec, smallest_vec, biggest_vec); + return partition_vec_avx512(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec); }else if constexpr (sizeof(reg_t) == 32){ - partition_vec_avx2(arr, left, unpartitioned, curr_vec, pivot_vec, smallest_vec, biggest_vec); + return partition_vec_avx2(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec); }else{ static_assert(sizeof(reg_t) == -1, "should not reach here"); + return 0; } } From 0f3c1384cb65647ca5e7a58d0c97d6acc95687e2 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 18 Oct 2023 10:52:55 -0700 
Subject: [PATCH 04/19] Removed some unused code, changed how single compresstore works for AVX2 --- src/avx2-emu-funcs.hpp | 57 +++++++++--------------------------------- 1 file changed, 12 insertions(+), 45 deletions(-) diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index ae2862f5..6f29fcae 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -122,42 +122,24 @@ T avx2_emu_reduce_min32(typename ymm_vector::reg_t x) return std::min(can1, can2); } -template -typename ymm_vector::opmask_t -avx2_emu_fpclassify64(typename ymm_vector::reg_t x, int mask) -{ - using vtype = ymm_vector; - T store[vtype::numlanes]; - vtype::storeu(&store[0], x); - int64_t res[vtype::numlanes]; - - for (int i = 0; i < vtype::numlanes; i++) { - bool flagged = scalar_emu_fpclassify(store[i]); - res[i] = 0xFFFFFFFFFFFFFFFF; - } - return vtype::loadu(res); -} - template void avx2_emu_mask_compressstoreu(void *base_addr, typename ymm_vector::opmask_t k, typename ymm_vector::reg_t reg) { using vtype = ymm_vector; - T *storage = (T *)base_addr; - int32_t mask[vtype::numlanes]; - T data[vtype::numlanes]; - - _mm256_storeu_si256((__m256i *)&mask[0], k); - vtype::storeu(&data[0], reg); - -#pragma GCC unroll 8 - for (int i = 0; i < vtype::numlanes; i++) { - if (mask[i]) { - *storage = data[i]; - storage++; - } - } + + T *leftStore = (T *)base_addr; + + int32_t shortMask = avx2_mask_helper32(k); + const __m256i &perm = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); + const __m256i &left = _mm256_loadu_si256( + (const __m256i *)avx2_compressstore_lut32_left[shortMask].data()); + + typename vtype::reg_t temp = vtype::permutevar(reg, perm); + + vtype::mask_storeu(leftStore, left, temp); } template @@ -185,21 +167,6 @@ int32_t avx2_double_compressstore32(void *left_addr, return _mm_popcnt_u32(shortMask); } -template -int64_t avx2_emu_popcnt(__m256i reg) -{ - using vtype = ymm_vector; - - int32_t data[vtype::numlanes]; - _mm256_storeu_si256((__m256i *)&data[0], reg); - - int64_t pop = 0; - for (int i = 0; i < vtype::numlanes; i++) { - pop += _popcnt32(data[i]); - } - return pop; -} - template typename ymm_vector::reg_t avx2_emu_max(typename ymm_vector::reg_t x, typename ymm_vector::reg_t y) From 55a6b513c150faa913d491432e7abd324caed848 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Wed, 18 Oct 2023 14:58:41 -0700 Subject: [PATCH 05/19] Some small bug fixes --- src/avx2-32bit-common.h | 2 +- src/xss-common-qsort.h | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h index 110de03e..010425aa 100644 --- a/src/avx2-32bit-common.h +++ b/src/avx2-32bit-common.h @@ -558,7 +558,7 @@ inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize) __m256 in_ymm = ymm_vector::maskz_loadu(loadmask, arr); __m256i nanmask = _mm256_castps_si256( _mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ)); - nan_count += _popcnt32(avx2_mask_helper32(nanmask)); + nan_count += _mm_popcnt_u32(avx2_mask_helper32(nanmask)); ymm_vector::mask_storeu(arr, nanmask, YMM_MAX_FLOAT); arr += 8; arrsize -= 8; diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 9137e8ed..316f936d 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -661,6 +661,7 @@ void avx2_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) indx_last_elem = move_nans_to_end_of_array(arr, arrsize); } } + UNUSED(hasnan); if (indx_last_elem >= k) { qselect_, T>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); From 
b02fe0a485e54364d28a3eac61323b65f2d89650 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Thu, 19 Oct 2023 10:37:50 -0700 Subject: [PATCH 06/19] Removes unnecessary forward includes --- src/xss-common-qsort.h | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 316f936d..74ea8948 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -631,8 +631,6 @@ void avx2_qsort(T *arr, int64_t arrsize) } } -void avx512_qsort_fp16(uint16_t *arr, int64_t arrsize); - template X86_SIMD_SORT_INLINE void avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) @@ -668,11 +666,6 @@ void avx2_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) } } -void avx512_qselect_fp16(uint16_t *arr, - int64_t k, - int64_t arrsize, - bool hasnan = false); - template X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr, arrsize_t k, @@ -689,13 +682,5 @@ inline void avx2_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = avx2_qselect(arr, k - 1, arrsize, hasnan); avx2_qsort(arr, k - 1); } -inline void avx512_partial_qsort_fp16(uint16_t *arr, - int64_t k, - int64_t arrsize, - bool hasnan = false) -{ - avx512_qselect_fp16(arr, k - 1, arrsize, hasnan); - avx512_qsort_fp16(arr, k - 1); -} #endif // AVX512_QSORT_COMMON From dec22786e70c98d591446b5126e13f3e1072e0df Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Thu, 19 Oct 2023 10:42:19 -0700 Subject: [PATCH 07/19] Fixed types in a few places --- src/avx2-32bit-common.h | 30 ++++-------------------------- src/xss-common-qsort.h | 12 ++++++------ 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h index 010425aa..d2e4d8b8 100644 --- a/src/avx2-32bit-common.h +++ b/src/avx2-32bit-common.h @@ -549,9 +549,9 @@ struct ymm_vector { } }; -inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize) +inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize) { - int64_t nan_count = 0; + arrsize_t nan_count = 0; __mmask8 loadmask = 0xFF; while (arrsize > 0) { if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } @@ -567,36 +567,14 @@ inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize) } X86_SIMD_SORT_INLINE void -replace_inf_with_nan(float *arr, int64_t arrsize, int64_t nan_count) +replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count) { - for (int64_t ii = arrsize - 1; nan_count > 0; --ii) { + for (arrsize_t ii = arrsize - 1; nan_count > 0; --ii) { arr[ii] = std::nan("1"); nan_count -= 1; } } -template -X86_SIMD_SORT_INLINE type_t get_pivot_32bit(type_t *arr, - const int64_t left, - const int64_t right) -{ - // median of 8 - int64_t size = (right - left) / 8; - using reg_t = typename vtype::reg_t; - __m256i rand_index = _mm256_set_epi32(left + size, - left + 2 * size, - left + 3 * size, - left + 4 * size, - left + 5 * size, - left + 6 * size, - left + 7 * size, - left + 8 * size); - reg_t rand_vec = vtype::template i64gather(rand_index, arr); - // pivot will never be a nan, since there are no nan's! 
- reg_t sort = sort_ymm_32bit(rand_vec); - return ((type_t *)&sort)[4]; -} - struct avx2_32bit_swizzle_ops{ template X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){ diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 74ea8948..709d4ca4 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -43,7 +43,7 @@ namespace avx2{ template struct ymm_vector; -inline int64_t replace_nan_with_inf(float *arr, int64_t arrsize); +inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize); } } @@ -612,13 +612,13 @@ X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize) } template -void avx2_qsort(T *arr, int64_t arrsize) +void avx2_qsort(T *arr, arrsize_t arrsize) { using vtype = xss::avx2::ymm_vector; if (arrsize > 1) { /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { - int64_t nan_count + arrsize_t nan_count = xss::avx2::replace_nan_with_inf(arr, arrsize); qsort_( arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); @@ -650,9 +650,9 @@ avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) } template -void avx2_qselect(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) +void avx2_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) { - int64_t indx_last_elem = arrsize - 1; + arrsize_t indx_last_elem = arrsize - 1; /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { @@ -677,7 +677,7 @@ X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr, } template -inline void avx2_partial_qsort(T *arr, int64_t k, int64_t arrsize, bool hasnan = false) +inline void avx2_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) { avx2_qselect(arr, k - 1, arrsize, hasnan); avx2_qsort(arr, k - 1); From 0e41ebcdf9f978380398118cd602eb645c7c3052 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 20 Oct 2023 09:45:59 -0700 Subject: [PATCH 08/19] Make pivot selection simpler --- src/xss-pivot-selection.hpp | 145 +++--------------------------------- 1 file changed, 12 insertions(+), 133 deletions(-) diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index c9f7bdb3..29394321 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -2,152 +2,31 @@ template X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); template -X86_SIMD_SORT_INLINE type_t get_pivot_avx512_16bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 32 - arrsize_t size = (right - left) / 32; - type_t vec_arr[32] = {arr[left], - arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size], - arr[left + 17 * size], - arr[left + 18 * size], - arr[left + 19 * size], - arr[left + 20 * size], - arr[left + 21 * size], - arr[left + 22 * size], - arr[left + 23 * size], - arr[left + 24 * size], - arr[left + 25 * size], - arr[left + 26 * size], - arr[left + 27 * size], - arr[left + 28 * size], - arr[left + 29 * size], - arr[left + 30 * size], - arr[left + 31 * size]}; - typename vtype::reg_t rand_vec = vtype::loadu(vec_arr); - typename vtype::reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[16]; 
-} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_avx512_32bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - // median of 16 - arrsize_t size = (right - left) / 16; - using reg_t = typename vtype::reg_t; - type_t vec_arr[16] = {arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size], - arr[left + 9 * size], - arr[left + 10 * size], - arr[left + 11 * size], - arr[left + 12 * size], - arr[left + 13 * size], - arr[left + 14 * size], - arr[left + 15 * size], - arr[left + 16 * size]}; - reg_t rand_vec = vtype::loadu(vec_arr); - reg_t sort = vtype::sort_vec(rand_vec); - // pivot will never be a nan, since there are no nan's! - return ((type_t *)&sort)[8]; -} - -template -X86_SIMD_SORT_INLINE type_t get_pivot_avx512_64bit(type_t *arr, - const arrsize_t left, - const arrsize_t right) +X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, + const arrsize_t left, + const arrsize_t right) { - // median of 8 - arrsize_t size = (right - left) / 8; using reg_t = typename vtype::reg_t; - reg_t rand_vec = vtype::set(arr[left + size], - arr[left + 2 * size], - arr[left + 3 * size], - arr[left + 4 * size], - arr[left + 5 * size], - arr[left + 6 * size], - arr[left + 7 * size], - arr[left + 8 * size]); - // pivot will never be a nan, since there are no nan's! - reg_t sort = vtype::sort_vec(rand_vec); - return ((type_t *)&sort)[4]; -} - -template -X86_SIMD_SORT_INLINE void sort_n(typename vtype::type_t *arr, int N); - -template -X86_SIMD_SORT_INLINE type_t get_pivot_scalar(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ type_t samples[vtype::numlanes]; - arrsize_t delta = (right - left) / vtype::numlanes; - - for (int i = 0; i < vtype::numlanes; i++){ + for (int i = 0; i < vtype::numlanes; i++) { samples[i] = arr[left + i * delta]; } - - sort_n(samples, vtype::numlanes); - - return samples[vtype::numlanes / 2]; -} + reg_t rand_vec = vtype::loadu(samples); + reg_t sort = vtype::sort_vec(rand_vec); -template -X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, - const arrsize_t left, - const arrsize_t right) -{ - using reg_t = typename vtype::reg_t; - if constexpr (sizeof(reg_t) == 64){ // AVX512 - if constexpr (vtype::numlanes == 8) - return get_pivot_avx512_64bit(arr, left, right); - else if constexpr (vtype::numlanes == 16) - return get_pivot_avx512_32bit(arr, left, right); - else if constexpr (vtype::numlanes == 32) - return get_pivot_avx512_16bit(arr, left, right); - else - static_assert(vtype::numlanes == -1, "should not reach here"); - }else if constexpr (sizeof(reg_t) == 32) { // AVX2 - if constexpr (vtype::numlanes == 8){ - return get_pivot_scalar(arr, left, right); - } - }else{ - static_assert(sizeof(reg_t) == -1, "should not reach here"); - } + return ((type_t *)&sort)[vtype::numlanes / 2]; } template X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr, - arrsize_t left, - arrsize_t right) + const arrsize_t left, + const arrsize_t right) { - if (right - left <= 1024) { return get_pivot(arr, left, right); } + if (right - left <= 1024) { + return get_pivot(arr, left, right); + } using reg_t = typename vtype::reg_t; constexpr int numVecs = 5; From 15d602523e79360f5339cd7a51620c81524c1f81 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 20 Oct 2023 10:53:26 -0700 Subject: [PATCH 09/19] Fixes/changes many small things --- src/avx2-32bit-common.h | 97 ++++++++------------------------------ src/avx2-emu-funcs.hpp | 37 
+++++++-------- src/avx512-16bit-common.h | 1 - src/avx512-32bit-qsort.hpp | 1 - src/xss-common-qsort.h | 18 ++----- 5 files changed, 39 insertions(+), 115 deletions(-) diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h index d2e4d8b8..ac6d17dd 100644 --- a/src/avx2-32bit-common.h +++ b/src/avx2-32bit-common.h @@ -22,37 +22,6 @@ #define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 #define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 -namespace xss { -namespace avx2 { - -// Assumes ymm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_32bit(reg_t ymm) -{ - - const typename vtype::opmask_t oxAA = _mm256_set_epi32( - 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); - const typename vtype::opmask_t oxCC = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); - const typename vtype::opmask_t oxF0 = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); - - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - ymm = cmp_merge( - ymm, - vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_4), ymm), - oxF0); - // 2) half_cleaner[4] - ymm = cmp_merge( - ymm, - vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_3), ymm), - oxCC); - // 3) half_cleaner[1] - ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); - return ymm; -} - /* * Assumes ymm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg @@ -85,7 +54,7 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) struct avx2_32bit_swizzle_ops; template <> -struct ymm_vector { +struct avx2_vector { using type_t = int32_t; using reg_t = __m256i; using ymmi_t = __m256i; @@ -231,13 +200,9 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_ymm_32bit>(x); - } static reg_t sort_vec(reg_t x) { - return sort_ymm_32bit>(x); + return sort_ymm_32bit>(x); } static reg_t cast_from(__m256i v){ return v; @@ -247,7 +212,7 @@ struct ymm_vector { } }; template <> -struct ymm_vector { +struct avx2_vector { using type_t = uint32_t; using reg_t = __m256i; using ymmi_t = __m256i; @@ -378,13 +343,9 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_ymm_32bit>(x); - } static reg_t sort_vec(reg_t x) { - return sort_ymm_32bit>(x); + return sort_ymm_32bit>(x); } static reg_t cast_from(__m256i v){ return v; @@ -394,7 +355,7 @@ struct ymm_vector { } }; template <> -struct ymm_vector { +struct avx2_vector { using type_t = float; using reg_t = __m256; using ymmi_t = __m256i; @@ -440,6 +401,19 @@ struct ymm_vector { { return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); } + static opmask_t get_partial_loadmask(int size) + { + return (0x0001 << size) - 0x0001; + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)){ + return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q)); + }else{ + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } template static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) @@ -533,13 +507,9 @@ struct ymm_vector { { _mm256_storeu_ps((float *)mem, x); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_ymm_32bit>(x); - } static reg_t sort_vec(reg_t x) { - return sort_ymm_32bit>(x); + return sort_ymm_32bit>(x); } static reg_t cast_from(__m256i v){ return 
_mm256_castsi256_ps(v); @@ -549,32 +519,6 @@ struct ymm_vector { } }; -inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize) -{ - arrsize_t nan_count = 0; - __mmask8 loadmask = 0xFF; - while (arrsize > 0) { - if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } - __m256 in_ymm = ymm_vector::maskz_loadu(loadmask, arr); - __m256i nanmask = _mm256_castps_si256( - _mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ)); - nan_count += _mm_popcnt_u32(avx2_mask_helper32(nanmask)); - ymm_vector::mask_storeu(arr, nanmask, YMM_MAX_FLOAT); - arr += 8; - arrsize -= 8; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count) -{ - for (arrsize_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nan("1"); - nan_count -= 1; - } -} - struct avx2_32bit_swizzle_ops{ template X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){ @@ -635,7 +579,4 @@ struct avx2_32bit_swizzle_ops{ return vtype::cast_from(v1); } }; - -} // namespace avx2 -} // namespace xss #endif diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 6f29fcae..f7c0dfb9 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -5,9 +5,6 @@ #include #include "xss-common-qsort.h" -namespace xss { -namespace avx2 { - constexpr auto avx2_mask_helper_lut32 = [] { std::array, 256> lut {}; for (int64_t i = 0; i <= 0xFF; i++) { @@ -97,9 +94,9 @@ static __m256i operator~(const avx2_mask_helper32 x) // Emulators for intrinsics missing from AVX2 compared to AVX512 template -T avx2_emu_reduce_max32(typename ymm_vector::reg_t x) +T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) { - using vtype = ymm_vector; + using vtype = avx2_vector; using reg_t = typename vtype::reg_t; reg_t inter1 = vtype::max(x, vtype::template shuffle(x)); @@ -110,9 +107,9 @@ T avx2_emu_reduce_max32(typename ymm_vector::reg_t x) } template -T avx2_emu_reduce_min32(typename ymm_vector::reg_t x) +T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { - using vtype = ymm_vector; + using vtype = avx2_vector; using reg_t = typename vtype::reg_t; reg_t inter1 = vtype::min(x, vtype::template shuffle(x)); @@ -124,10 +121,10 @@ T avx2_emu_reduce_min32(typename ymm_vector::reg_t x) template void avx2_emu_mask_compressstoreu(void *base_addr, - typename ymm_vector::opmask_t k, - typename ymm_vector::reg_t reg) + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { - using vtype = ymm_vector; + using vtype = avx2_vector; T *leftStore = (T *)base_addr; @@ -145,10 +142,10 @@ void avx2_emu_mask_compressstoreu(void *base_addr, template int32_t avx2_double_compressstore32(void *left_addr, void *right_addr, - typename ymm_vector::opmask_t k, - typename ymm_vector::reg_t reg) + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { - using vtype = ymm_vector; + using vtype = avx2_vector; T *leftStore = (T *)left_addr; T *rightStore = (T *)right_addr; @@ -168,10 +165,10 @@ int32_t avx2_double_compressstore32(void *left_addr, } template -typename ymm_vector::reg_t avx2_emu_max(typename ymm_vector::reg_t x, - typename ymm_vector::reg_t y) +typename avx2_vector::reg_t avx2_emu_max(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) { - using vtype = ymm_vector; + using vtype = avx2_vector; typename vtype::opmask_t nlt = vtype::ge(x, y); return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y), _mm256_castsi256_pd(x), @@ -179,16 +176,14 @@ typename ymm_vector::reg_t avx2_emu_max(typename ymm_vector::reg_t x, } 
template -typename ymm_vector::reg_t avx2_emu_min(typename ymm_vector::reg_t x, - typename ymm_vector::reg_t y) +typename avx2_vector::reg_t avx2_emu_min(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) { - using vtype = ymm_vector; + using vtype = avx2_vector; typename vtype::opmask_t nlt = vtype::ge(x, y); return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x), _mm256_castsi256_pd(y), _mm256_castsi256_pd(nlt))); } -} // namespace avx2 -} // namespace x86_simd_sort #endif \ No newline at end of file diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 41116b33..28c1c1fe 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -8,7 +8,6 @@ #define AVX512_16BIT_COMMON #include "xss-common-qsort.h" -#include "xss-network-qsort.hpp" /* * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 4d244d6e..e96d4a9a 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -9,7 +9,6 @@ #define AVX512_QSORT_32BIT #include "xss-common-qsort.h" -#include "xss-network-qsort.hpp" /* * Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 709d4ca4..6c440078 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -38,18 +38,8 @@ #include "xss-pivot-selection.hpp" #include "xss-network-qsort.hpp" -namespace xss{ -namespace avx2{ template -struct ymm_vector; - -inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize); -} -} - -// key-value sort routines -template -void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize); +struct avx2_vector; template bool is_a_nan(T elem) @@ -614,12 +604,12 @@ X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize) template void avx2_qsort(T *arr, arrsize_t arrsize) { - using vtype = xss::avx2::ymm_vector; + using vtype = avx2_vector; if (arrsize > 1) { /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { arrsize_t nan_count - = xss::avx2::replace_nan_with_inf(arr, arrsize); + = replace_nan_with_inf(arr, arrsize); qsort_( arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); @@ -661,7 +651,7 @@ void avx2_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) } UNUSED(hasnan); if (indx_last_elem >= k) { - qselect_, T>( + qselect_, T>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } } From 4fa2255f7f7156a1165d8fbe9d35ed22aae296cc Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 20 Oct 2023 13:43:10 -0700 Subject: [PATCH 10/19] Fix rebase merge bugs --- src/avx512-64bit-argsort.hpp | 2 +- src/avx512-64bit-keyvaluesort.hpp | 2 +- src/avx512-64bit-qsort.hpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index 2d2c33f5..f706ded7 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -7,7 +7,7 @@ #ifndef AVX512_ARGSORT_64BIT #define AVX512_ARGSORT_64BIT -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" #include "avx512-64bit-common.h" #include "xss-network-keyvaluesort.hpp" #include diff --git a/src/avx512-64bit-keyvaluesort.hpp b/src/avx512-64bit-keyvaluesort.hpp index c24e575a..b1ec0cd2 100644 --- a/src/avx512-64bit-keyvaluesort.hpp +++ b/src/avx512-64bit-keyvaluesort.hpp @@ -8,7 +8,7 @@ #ifndef AVX512_QSORT_64BIT_KV #define 
AVX512_QSORT_64BIT_KV -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" #include "avx512-64bit-common.h" #include "xss-network-keyvaluesort.hpp" diff --git a/src/avx512-64bit-qsort.hpp b/src/avx512-64bit-qsort.hpp index afafe250..4dcaeafa 100644 --- a/src/avx512-64bit-qsort.hpp +++ b/src/avx512-64bit-qsort.hpp @@ -7,7 +7,7 @@ #ifndef AVX512_QSORT_64BIT #define AVX512_QSORT_64BIT -#include "avx512-common-qsort.h" +#include "xss-common-qsort.h" #include "avx512-64bit-common.h" #endif // AVX512_QSORT_64BIT From 3a3baae6719f6a6ff706e86b63d30a9dc32e3ba2 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 20 Oct 2023 13:45:42 -0700 Subject: [PATCH 11/19] Move avx2_vector declaration to common includes --- src/xss-common-includes.h | 4 ++++ src/xss-common-qsort.h | 3 --- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index f1465977..9fff8d09 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -72,3 +72,7 @@ struct zmm_vector; template struct ymm_vector; + +template +struct avx2_vector; + diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 6c440078..9fadc7b0 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -38,9 +38,6 @@ #include "xss-pivot-selection.hpp" #include "xss-network-qsort.hpp" -template -struct avx2_vector; - template bool is_a_nan(T elem) { From 2f397e40fbd91d4843f032e27b54d199dd7e1455 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 20 Oct 2023 13:47:51 -0700 Subject: [PATCH 12/19] Add avx2 to examples --- examples/Makefile | 5 ++++- examples/avx2-32bit-qsort.cpp | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 examples/avx2-32bit-qsort.cpp diff --git a/examples/Makefile b/examples/Makefile index a66db80c..483c39df 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -1,6 +1,6 @@ CXX ?= g++-12 CFLAGS = -I../src -std=c++17 -O3 -EXE = argsort kvsort qsortfp16 qsort16 qsort32 qsort64 +EXE = qsort32avx2 argsort kvsort qsortfp16 qsort16 qsort32 qsort64 default: all all : $(EXE) @@ -14,6 +14,9 @@ qsort16: avx512-16bit-qsort.cpp qsort32: avx512-32bit-qsort.cpp $(CXX) -o qsort32 -march=skylake-avx512 $(CFLAGS) avx512-32bit-qsort.cpp +qsort32avx2: avx2-32bit-qsort.cpp + $(CXX) -o qsort32avx2 -march=haswell $(CFLAGS) avx2-32bit-qsort.cpp + qsort64: avx512-64bit-qsort.cpp $(CXX) -o qsort64 -march=skylake-avx512 $(CFLAGS) avx512-64bit-qsort.cpp diff --git a/examples/avx2-32bit-qsort.cpp b/examples/avx2-32bit-qsort.cpp new file mode 100644 index 00000000..eddedca8 --- /dev/null +++ b/examples/avx2-32bit-qsort.cpp @@ -0,0 +1,10 @@ +#include "avx2-32bit-qsort.hpp" + +int main() { + const int size = 1000; + float arr[size]; + avx2_qsort(arr, size); + avx2_qselect(arr, 10, size); + avx2_partial_qsort(arr, 10, size); + return 0; +} From 3561db30feccecc1e45c643869c1d5de2d4c3c11 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 20 Oct 2023 13:53:29 -0700 Subject: [PATCH 13/19] Changed partition code --- src/avx2-32bit-common.h | 45 ++++++++++++------------- src/avx2-emu-funcs.hpp | 2 +- src/avx512-16bit-qsort.hpp | 21 ++++++++++++ src/avx512-32bit-qsort.hpp | 21 ++++++++++++ src/avx512-64bit-common.h | 21 ++++++++++++ src/avx512fp16-16bit-qsort.hpp | 7 ++++ src/xss-common-qsort.h | 60 ++++++++-------------------------- 7 files changed, 106 insertions(+), 71 deletions(-) diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h index ac6d17dd..dd82d095 100644 --- a/src/avx2-32bit-common.h 
+++ b/src/avx2-32bit-common.h @@ -129,14 +129,6 @@ struct avx2_vector { { return avx2_emu_mask_compressstoreu(mem, mask, x); } - static int32_t double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32( - left_addr, right_addr, k, reg); - } static reg_t maskz_loadu(opmask_t mask, void const *mem) { return _mm256_maskload_epi32((const int *)mem, mask); @@ -210,6 +202,13 @@ struct avx2_vector { static __m256i cast_to(reg_t v){ return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32(left_addr, right_addr, k, reg); + } }; template <> struct avx2_vector { @@ -277,14 +276,6 @@ struct avx2_vector { { return avx2_emu_mask_compressstoreu(mem, mask, x); } - static int32_t double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32( - left_addr, right_addr, k, reg); - } static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { reg_t dst = _mm256_maskload_epi32((const int *)mem, mask); @@ -353,6 +344,13 @@ struct avx2_vector { static __m256i cast_to(reg_t v){ return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32(left_addr, right_addr, k, reg); + } }; template <> struct avx2_vector { @@ -439,14 +437,6 @@ struct avx2_vector { { return avx2_emu_mask_compressstoreu(mem, mask, x); } - static int32_t double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32( - left_addr, right_addr, k, reg); - } static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) { reg_t dst = _mm256_maskload_ps((type_t *)mem, mask); @@ -517,6 +507,13 @@ struct avx2_vector { static __m256i cast_to(reg_t v){ return _mm256_castps_si256(v); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32(left_addr, right_addr, k, reg); + } }; struct avx2_32bit_swizzle_ops{ diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index f7c0dfb9..ab8ea567 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -140,7 +140,7 @@ void avx2_emu_mask_compressstoreu(void *base_addr, } template -int32_t avx2_double_compressstore32(void *left_addr, +int avx2_double_compressstore32(void *left_addr, void *right_addr, typename avx2_vector::opmask_t k, typename avx2_vector::reg_t reg) diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index fdfba924..99a65a46 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -177,6 +177,13 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> @@ -301,6 +308,13 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -422,6 +436,13 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> diff --git 
a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index e96d4a9a..546dcd0b 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -154,6 +154,13 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -281,6 +288,13 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -422,6 +436,13 @@ struct zmm_vector { { return _mm512_castps_si512(v); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; /* diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index 9c58b494..c029bd9b 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -660,6 +660,13 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -818,6 +825,13 @@ struct zmm_vector { { return v; } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> struct zmm_vector { @@ -982,6 +996,13 @@ struct zmm_vector { { return _mm512_castpd_si512(v); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; /* diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 94e508f0..a2352d0a 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -145,6 +145,13 @@ struct zmm_vector<_Float16> { { return _mm512_castph_si512(v); } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx512_double_compressstore>(left_addr, right_addr, k, reg); + } }; template <> diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 9fadc7b0..d71fb543 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -162,47 +162,34 @@ X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask) reg_t max = vtype::max(in2, in1); return vtype::mask_mov(min, mask, max); // 0 -> min, 1 -> max } -/* - * Parition one ZMM register based on the pivot and returns the - * number of elements that are greater than or equal to the pivot. 
- */ + template -X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx512(type_t *l_store, - type_t *r_store, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t &smallest_vec, - reg_t &biggest_vec) +int avx512_double_compressstore(type_t *left_addr, + type_t *right_addr, + typename vtype::opmask_t k, + reg_t reg) { - typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); - int amount_ge_pivot = _mm_popcnt_u32((int)ge_mask); + int amount_ge_pivot = _mm_popcnt_u32((int)k); - vtype::mask_compressstoreu(l_store, vtype::knot_opmask(ge_mask), curr_vec); + vtype::mask_compressstoreu(left_addr, vtype::knot_opmask(k), reg); vtype::mask_compressstoreu( - r_store + vtype::numlanes - amount_ge_pivot, ge_mask, curr_vec); - - smallest_vec = vtype::min(curr_vec, smallest_vec); - biggest_vec = vtype::max(curr_vec, biggest_vec); - + right_addr + vtype::numlanes - amount_ge_pivot, k, reg); + return amount_ge_pivot; } -/* - * Parition one YMM register based on the pivot and returns the - * number of elements that are greater than or equal to the pivot. - */ + +// Generic function dispatches to AVX2 or AVX512 code template -X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx2(type_t *l_store, +X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, type_t *r_store, const reg_t curr_vec, const reg_t pivot_vec, reg_t &smallest_vec, reg_t &biggest_vec) { - /* which elements are larger than or equal to the pivot */ typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); - - int32_t amount_ge_pivot = vtype::double_compressstore( - l_store, r_store, ge_mask, curr_vec); + + int amount_ge_pivot = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); biggest_vec = vtype::max(curr_vec, biggest_vec); @@ -210,25 +197,6 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec_avx2(type_t *l_store, return amount_ge_pivot; } -// Generic function dispatches to AVX2 or AVX512 code -template -X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, - type_t *r_store, - const reg_t curr_vec, - const reg_t pivot_vec, - reg_t &smallest_vec, - reg_t &biggest_vec) -{ - if constexpr (sizeof(reg_t) == 64){ - return partition_vec_avx512(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec); - }else if constexpr (sizeof(reg_t) == 32){ - return partition_vec_avx2(l_store, r_store, curr_vec, pivot_vec, smallest_vec, biggest_vec); - }else{ - static_assert(sizeof(reg_t) == -1, "should not reach here"); - return 0; - } -} - /* * Parition an array based on the pivot and returns the index of the * first element that is greater than or equal to the pivot. 
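For reference, a minimal scalar sketch of what the double_compressstore() step introduced in the patch above computes; the helper name and signature here are illustrative only and are not part of the library. Elements smaller than the pivot are packed contiguously at the left store, elements greater than or equal to the pivot are packed so that they end flush at right_addr + N, and the count of greater-or-equal elements is returned:

template <typename T, int N>
int scalar_double_compressstore(T *left_addr, T *right_addr, const T (&vec)[N], T pivot)
{
    // Count how many elements are >= pivot (the SIMD code derives this
    // from a popcount of the comparison mask).
    int num_ge = 0;
    for (int i = 0; i < N; ++i) {
        num_ge += (vec[i] >= pivot);
    }
    // Pack < pivot elements at left_addr, and >= pivot elements so that they
    // end at right_addr + N, mirroring the compress-store to
    // right_addr + numlanes - amount_ge_pivot in the vector code.
    int l = 0;
    int r = N - num_ge;
    for (int i = 0; i < N; ++i) {
        if (vec[i] >= pivot) { right_addr[r++] = vec[i]; }
        else { left_addr[l++] = vec[i]; }
    }
    return num_ge;
}

The AVX-512 vector types implement this step with a native masked compress-store, while the AVX2 types fall back to the permutation-LUT emulation in avx2-emu-funcs.hpp; exposing it as a per-type double_compressstore() member is what lets partition_vec() stay generic across both ISAs.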
From 9ad4432e6d441fb2118bb66519e27999eac7f570 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 23 Oct 2023 13:01:27 -0700 Subject: [PATCH 14/19] Get rid of the avx2_mask_helper --- src/avx2-32bit-common.h | 35 +++++++++++++------- src/avx2-emu-funcs.hpp | 59 +++++++++------------------------- src/avx512-16bit-qsort.hpp | 16 +++++++++ src/avx512-32bit-qsort.hpp | 16 +++++++-- src/avx512-64bit-common.h | 28 +++++++++++++--- src/avx512fp16-16bit-qsort.hpp | 8 +++-- src/xss-common-qsort.h | 6 ++-- src/xss-network-qsort.hpp | 8 ++--- 8 files changed, 106 insertions(+), 70 deletions(-) diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h index dd82d095..701aca0e 100644 --- a/src/avx2-32bit-common.h +++ b/src/avx2-32bit-common.h @@ -15,7 +15,7 @@ * sorting network (see * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) */ - + // ymm 7, 6, 5, 4, 3, 2, 1, 0 #define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 #define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 @@ -58,11 +58,11 @@ struct avx2_vector { using type_t = int32_t; using reg_t = __m256i; using ymmi_t = __m256i; - using opmask_t = avx2_mask_helper32; + using opmask_t = __m256i; static const uint8_t numlanes = 8; static constexpr int network_sort_threshold = 256; static constexpr int partition_unroll_factor = 4; - + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() @@ -77,7 +77,11 @@ struct avx2_vector { { return _mm256_set1_epi32(type_max()); } // TODO: this should broadcast bits as is? - + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { @@ -215,11 +219,11 @@ struct avx2_vector { using type_t = uint32_t; using reg_t = __m256i; using ymmi_t = __m256i; - using opmask_t = avx2_mask_helper32; + using opmask_t = __m256i; static const uint8_t numlanes = 8; static constexpr int network_sort_threshold = 256; static constexpr int partition_unroll_factor = 4; - + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() @@ -234,7 +238,11 @@ struct avx2_vector { { return _mm256_set1_epi32(type_max()); } - + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } static ymmi_t seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) { @@ -357,11 +365,11 @@ struct avx2_vector { using type_t = float; using reg_t = __m256; using ymmi_t = __m256i; - using opmask_t = avx2_mask_helper32; + using opmask_t = __m256i; static const uint8_t numlanes = 8; static constexpr int network_sort_threshold = 256; static constexpr int partition_unroll_factor = 4; - + using swizzle_ops = avx2_32bit_swizzle_ops; static type_t type_max() @@ -399,9 +407,14 @@ struct avx2_vector { { return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) { - return (0x0001 << size) - 0x0001; + return convert_avx2_mask_to_int(mask); } template static opmask_t fpclass(reg_t x) diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index ab8ea567..33f6fd63 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -46,50 +46,21 @@ 
constexpr auto avx2_compressstore_lut32_gen = [] { } return lutPair; }(); + constexpr auto avx2_compressstore_lut32_perm = avx2_compressstore_lut32_gen[0]; constexpr auto avx2_compressstore_lut32_left = avx2_compressstore_lut32_gen[1]; -struct avx2_mask_helper32 { - __m256i mask; - - avx2_mask_helper32() = default; - avx2_mask_helper32(int m) - { - mask = converter(m); - } - avx2_mask_helper32(__m256i m) - { - mask = m; - } - operator __m256i() - { - return mask; - } - operator int32_t() - { - return converter(mask); - } - __m256i operator=(int m) - { - mask = converter(m); - return mask; - } - -private: - __m256i converter(int m) - { - return _mm256_loadu_si256( - (const __m256i *)avx2_mask_helper_lut32[m].data()); - } +X86_SIMD_SORT_INLINE +__m256i convert_int_to_avx2_mask(int32_t m) +{ + return _mm256_loadu_si256( + (const __m256i *)avx2_mask_helper_lut32[m].data()); +} - int32_t converter(__m256i m) - { - return _mm256_movemask_ps(_mm256_castsi256_ps(m)); - } -}; -static __m256i operator~(const avx2_mask_helper32 x) +X86_SIMD_SORT_INLINE +int32_t convert_avx2_mask_to_int(__m256i m) { - return ~x.mask; + return _mm256_movemask_ps(_mm256_castsi256_ps(m)); } // Emulators for intrinsics missing from AVX2 compared to AVX512 @@ -98,7 +69,7 @@ T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) { using vtype = avx2_vector; using reg_t = typename vtype::reg_t; - + reg_t inter1 = vtype::max(x, vtype::template shuffle(x)); reg_t inter2 = vtype::max(inter1, vtype::template shuffle(inter1)); T can1 = vtype::template extract<0>(inter2); @@ -111,7 +82,7 @@ T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { using vtype = avx2_vector; using reg_t = typename vtype::reg_t; - + reg_t inter1 = vtype::min(x, vtype::template shuffle(x)); reg_t inter2 = vtype::min(inter1, vtype::template shuffle(inter1)); T can1 = vtype::template extract<0>(inter2); @@ -128,7 +99,7 @@ void avx2_emu_mask_compressstoreu(void *base_addr, T *leftStore = (T *)base_addr; - int32_t shortMask = avx2_mask_helper32(k); + int32_t shortMask = convert_avx2_mask_to_int(k); const __m256i &perm = _mm256_loadu_si256( (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); const __m256i &left = _mm256_loadu_si256( @@ -150,7 +121,7 @@ int avx2_double_compressstore32(void *left_addr, T *leftStore = (T *)left_addr; T *rightStore = (T *)right_addr; - int32_t shortMask = avx2_mask_helper32(k); + int32_t shortMask = convert_avx2_mask_to_int(k); const __m256i &perm = _mm256_loadu_si256( (const __m256i *)avx2_compressstore_lut32_perm[shortMask].data()); const __m256i &left = _mm256_loadu_si256( @@ -186,4 +157,4 @@ typename avx2_vector::reg_t avx2_emu_min(typename avx2_vector::reg_t x, _mm256_castsi256_pd(nlt))); } -#endif \ No newline at end of file +#endif diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 99a65a46..2392ba6f 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -80,6 +80,14 @@ struct zmm_vector { exp_eq, mant_x, mant_y, _MM_CMPINT_NLT); return _kxor_mask32(mask_ge, neg); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; + } static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); @@ -227,6 +235,10 @@ struct zmm_vector { { return _mm512_cmp_epi16_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static reg_t loadu(void const 
*mem) { return _mm512_loadu_si512(mem); @@ -357,6 +369,10 @@ struct zmm_vector { { return _mm512_cmp_epu16_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 546dcd0b..a018dc4d 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -65,6 +65,10 @@ struct zmm_vector { { return _mm512_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } template static halfreg_t i64gather(__m512i index, void const *base) { @@ -209,6 +213,10 @@ struct zmm_vector { { return _mm512_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static reg_t loadu(void const *mem) { return _mm512_loadu_si512(mem); @@ -333,9 +341,13 @@ struct zmm_vector { { return _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) { - return (0x0001 << size) - 0x0001; + return mask; } template static opmask_t fpclass(reg_t x) diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index c029bd9b..de548dd9 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -81,9 +81,9 @@ struct ymm_vector { { return _mm256_cmp_ps_mask(x, y, _CMP_EQ_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) { - return (0x01 << size) - 0x01; + return ((0x1ull << num_to_read) - 0x1ull); } template static opmask_t fpclass(reg_t x) @@ -244,6 +244,10 @@ struct ymm_vector { { return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmp_epu32_mask(x, y, _MM_CMPINT_EQ); @@ -396,6 +400,10 @@ struct ymm_vector { { return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm256_cmp_epi32_mask(x, y, _MM_CMPINT_EQ); @@ -557,6 +565,10 @@ struct zmm_vector { { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm512_cmp_epi64_mask(x, y, _MM_CMPINT_EQ); @@ -745,6 +757,10 @@ struct zmm_vector { { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_NLT); } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } static opmask_t eq(reg_t x, reg_t y) { return _mm512_cmp_epu64_mask(x, y, _MM_CMPINT_EQ); @@ -894,9 +910,13 @@ struct zmm_vector { { return _mm512_cmp_pd_mask(x, y, _CMP_EQ_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) { - return (0x01 << size) - 0x01; + return mask; } template static opmask_t fpclass(reg_t x) diff --git 
a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index a2352d0a..28b8e193 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -54,9 +54,13 @@ struct zmm_vector<_Float16> { { return _mm512_cmp_ph_mask(x, y, _CMP_GE_OQ); } - static opmask_t get_partial_loadmask(int size) + static opmask_t get_partial_loadmask(uint64_t num_to_read) { - return (0x00000001 << size) - 0x00000001; + return ((0x1ull << num_to_read) - 0x1ull); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return mask; } template static opmask_t fpclass(reg_t x) diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index d71fb543..b62d1c44 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -65,7 +65,7 @@ X86_SIMD_SORT_INLINE arrsize_t replace_nan_with_inf(T *arr, arrsize_t size) in = vtype::loadu(arr + ii); } opmask_t nanmask = vtype::template fpclass<0x01 | 0x80>(in); - nan_count += _mm_popcnt_u32((int32_t)nanmask); + nan_count += _mm_popcnt_u32(vtype::convert_mask_to_int(nanmask)); vtype::mask_storeu(arr + ii, nanmask, vtype::zmm_max()); } return nan_count; @@ -174,7 +174,7 @@ int avx512_double_compressstore(type_t *left_addr, vtype::mask_compressstoreu(left_addr, vtype::knot_opmask(k), reg); vtype::mask_compressstoreu( right_addr + vtype::numlanes - amount_ge_pivot, k, reg); - + return amount_ge_pivot; } @@ -188,7 +188,7 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, reg_t &biggest_vec) { typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); - + int amount_ge_pivot = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp index 1a2e313d..56a1aca1 100644 --- a/src/xss-network-qsort.hpp +++ b/src/xss-network-qsort.hpp @@ -157,10 +157,10 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int N) typename vtype::opmask_t ioMasks[numVecs - numVecs / 2]; X86_SIMD_SORT_UNROLL_LOOP(64) for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) { - int64_t num_to_read - = std::min((int64_t)std::max(0, N - i * vtype::numlanes), - (int64_t)vtype::numlanes); - ioMasks[j] = ((0x1ull << num_to_read) - 0x1ull); + uint64_t num_to_read + = std::min((uint64_t)std::max(0, N - i * vtype::numlanes), + (uint64_t)vtype::numlanes); + ioMasks[j] = vtype::get_partial_loadmask(num_to_read); } // Unmasked part of the load From f8f611fbe4db842b9331aba50e8f593ad8b33592 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 23 Oct 2023 13:14:52 -0700 Subject: [PATCH 15/19] format files --- _clang-format | 2 +- src/avx2-32bit-common.h | 86 ++++++++++++++++++++++------------ src/avx2-emu-funcs.hpp | 22 +++++---- src/avx512-16bit-qsort.hpp | 9 ++-- src/avx512-32bit-qsort.hpp | 9 ++-- src/avx512-64bit-argsort.hpp | 1 - src/avx512-64bit-common.h | 9 ++-- src/avx512fp16-16bit-qsort.hpp | 3 +- src/xss-common-includes.h | 1 - src/xss-common-qsort.h | 29 ++++++------ src/xss-pivot-selection.hpp | 8 ++-- 11 files changed, 108 insertions(+), 71 deletions(-) diff --git a/_clang-format b/_clang-format index 98760584..4ef31af4 100644 --- a/_clang-format +++ b/_clang-format @@ -74,7 +74,7 @@ PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 60 PointerAlignment: Right ReflowComments: false -SortIncludes: true +SortIncludes: false SortUsingDeclarations: true SpaceAfterCStyleCast: false SpaceAfterTemplateKeyword: true diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h index 
701aca0e..14705fb1 100644 --- a/src/avx2-32bit-common.h +++ b/src/avx2-32bit-common.h @@ -40,12 +40,16 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) ymm = cmp_merge( ymm, vtype::template shuffle(ymm), oxAA); ymm = cmp_merge( - ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), oxCC); + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), + oxCC); ymm = cmp_merge( ymm, vtype::template shuffle(ymm), oxAA); ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); ymm = cmp_merge( - ymm, vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), oxCC); + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), + oxCC); ymm = cmp_merge( ymm, vtype::template shuffle(ymm), oxAA); return ymm; @@ -200,10 +204,12 @@ struct avx2_vector { { return sort_ymm_32bit>(x); } - static reg_t cast_from(__m256i v){ + static reg_t cast_from(__m256i v) + { return v; } - static __m256i cast_to(reg_t v){ + static __m256i cast_to(reg_t v) + { return v; } static int double_compressstore(type_t *left_addr, @@ -211,7 +217,8 @@ struct avx2_vector { opmask_t k, reg_t reg) { - return avx2_double_compressstore32(left_addr, right_addr, k, reg); + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); } }; template <> @@ -346,10 +353,12 @@ struct avx2_vector { { return sort_ymm_32bit>(x); } - static reg_t cast_from(__m256i v){ + static reg_t cast_from(__m256i v) + { return v; } - static __m256i cast_to(reg_t v){ + static __m256i cast_to(reg_t v) + { return v; } static int double_compressstore(type_t *left_addr, @@ -357,7 +366,8 @@ struct avx2_vector { opmask_t k, reg_t reg) { - return avx2_double_compressstore32(left_addr, right_addr, k, reg); + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); } }; template <> @@ -419,9 +429,10 @@ struct avx2_vector { template static opmask_t fpclass(reg_t x) { - if constexpr (type == (0x01 | 0x80)){ + if constexpr (type == (0x01 | 0x80)) { return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q)); - }else{ + } + else { static_assert(type == (0x01 | 0x80), "should not reach here"); } } @@ -514,10 +525,12 @@ struct avx2_vector { { return sort_ymm_32bit>(x); } - static reg_t cast_from(__m256i v){ + static reg_t cast_from(__m256i v) + { return _mm256_castsi256_ps(v); } - static __m256i cast_to(reg_t v){ + static __m256i cast_to(reg_t v) + { return _mm256_castps_si256(v); } static int double_compressstore(type_t *left_addr, @@ -525,26 +538,31 @@ struct avx2_vector { opmask_t k, reg_t reg) { - return avx2_double_compressstore32(left_addr, right_addr, k, reg); + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); } }; -struct avx2_32bit_swizzle_ops{ +struct avx2_32bit_swizzle_ops { template - X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){ + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { __m256i v = vtype::cast_to(reg); - if constexpr (scale == 2){ + if constexpr (scale == 2) { __m256 vf = _mm256_castsi256_ps(v); vf = _mm256_permute_ps(vf, 0b10110001); v = _mm256_castps_si256(vf); - }else if constexpr (scale == 4){ + } + else if constexpr (scale == 4) { __m256 vf = _mm256_castsi256_ps(v); vf = _mm256_permute_ps(vf, 0b01001110); v = _mm256_castps_si256(vf); - }else if constexpr (scale == 8){ + } + else if constexpr (scale == 8) { v = _mm256_permute2x128_si256(v, v, 0b00000001); - }else{ + } + else { static_assert(scale == -1, "should not be reached"); } @@ -552,19 +570,22 @@ struct avx2_32bit_swizzle_ops{ } template - 
X86_SIMD_SORT_INLINE typename vtype::reg_t reverse_n(typename vtype::reg_t reg){ + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { __m256i v = vtype::cast_to(reg); - if constexpr (scale == 2){ - return swap_n(reg); - }else if constexpr (scale == 4){ + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { constexpr uint64_t mask = 0b00011011; __m256 vf = _mm256_castsi256_ps(v); vf = _mm256_permute_ps(vf, mask); v = _mm256_castps_si256(vf); - }else if constexpr (scale == 8){ + } + else if constexpr (scale == 8) { return vtype::reverse(reg); - }else{ + } + else { static_assert(scale == -1, "should not be reached"); } @@ -572,17 +593,22 @@ struct avx2_32bit_swizzle_ops{ } template - X86_SIMD_SORT_INLINE typename vtype::reg_t merge_n(typename vtype::reg_t reg, typename vtype::reg_t other){ + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { __m256i v1 = vtype::cast_to(reg); __m256i v2 = vtype::cast_to(other); - if constexpr (scale == 2){ + if constexpr (scale == 2) { v1 = _mm256_blend_epi32(v1, v2, 0b01010101); - }else if constexpr (scale == 4){ + } + else if constexpr (scale == 4) { v1 = _mm256_blend_epi32(v1, v2, 0b00110011); - }else if constexpr (scale == 8){ + } + else if constexpr (scale == 8) { v1 = _mm256_blend_epi32(v1, v2, 0b00001111); - }else{ + } + else { static_assert(scale == -1, "should not be reached"); } diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 33f6fd63..43eed316 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -70,8 +70,10 @@ T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) using vtype = avx2_vector; using reg_t = typename vtype::reg_t; - reg_t inter1 = vtype::max(x, vtype::template shuffle(x)); - reg_t inter2 = vtype::max(inter1, vtype::template shuffle(inter1)); + reg_t inter1 = vtype::max( + x, vtype::template shuffle(x)); + reg_t inter2 = vtype::max( + inter1, vtype::template shuffle(inter1)); T can1 = vtype::template extract<0>(inter2); T can2 = vtype::template extract<4>(inter2); return std::max(can1, can2); @@ -83,8 +85,10 @@ T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) using vtype = avx2_vector; using reg_t = typename vtype::reg_t; - reg_t inter1 = vtype::min(x, vtype::template shuffle(x)); - reg_t inter2 = vtype::min(inter1, vtype::template shuffle(inter1)); + reg_t inter1 = vtype::min( + x, vtype::template shuffle(x)); + reg_t inter2 = vtype::min( + inter1, vtype::template shuffle(inter1)); T can1 = vtype::template extract<0>(inter2); T can2 = vtype::template extract<4>(inter2); return std::min(can1, can2); @@ -112,9 +116,9 @@ void avx2_emu_mask_compressstoreu(void *base_addr, template int avx2_double_compressstore32(void *left_addr, - void *right_addr, - typename avx2_vector::opmask_t k, - typename avx2_vector::reg_t reg) + void *right_addr, + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { using vtype = avx2_vector; @@ -137,7 +141,7 @@ int avx2_double_compressstore32(void *left_addr, template typename avx2_vector::reg_t avx2_emu_max(typename avx2_vector::reg_t x, - typename avx2_vector::reg_t y) + typename avx2_vector::reg_t y) { using vtype = avx2_vector; typename vtype::opmask_t nlt = vtype::ge(x, y); @@ -148,7 +152,7 @@ typename avx2_vector::reg_t avx2_emu_max(typename avx2_vector::reg_t x, template typename avx2_vector::reg_t avx2_emu_min(typename avx2_vector::reg_t x, - typename avx2_vector::reg_t y) + typename avx2_vector::reg_t y) { using vtype = 
avx2_vector; typename vtype::opmask_t nlt = vtype::ge(x, y); diff --git a/src/avx512-16bit-qsort.hpp b/src/avx512-16bit-qsort.hpp index 2392ba6f..1278201e 100644 --- a/src/avx512-16bit-qsort.hpp +++ b/src/avx512-16bit-qsort.hpp @@ -190,7 +190,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; @@ -325,7 +326,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; template <> @@ -457,7 +459,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index a018dc4d..2d101b88 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -163,7 +163,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; template <> @@ -301,7 +302,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; template <> @@ -453,7 +455,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; diff --git a/src/avx512-64bit-argsort.hpp b/src/avx512-64bit-argsort.hpp index f706ded7..e5d0db0d 100644 --- a/src/avx512-64bit-argsort.hpp +++ b/src/avx512-64bit-argsort.hpp @@ -65,7 +65,6 @@ std_argsort(T *arr, arrsize_t *arg, arrsize_t left, arrsize_t right) }); } - /* Workaround for NumPy failed build on macOS x86_64: implicit instantiation of * undefined template 'zmm_vector'*/ #ifdef __APPLE__ diff --git a/src/avx512-64bit-common.h b/src/avx512-64bit-common.h index de548dd9..e7f9f44c 100644 --- a/src/avx512-64bit-common.h +++ b/src/avx512-64bit-common.h @@ -677,7 +677,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; template <> @@ -846,7 +847,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; template <> @@ -1021,7 +1023,8 @@ struct zmm_vector { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 28b8e193..081a9939 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -154,7 +154,8 @@ struct zmm_vector<_Float16> { opmask_t k, reg_t reg) { - return avx512_double_compressstore>(left_addr, right_addr, k, reg); + return avx512_double_compressstore>( + left_addr, right_addr, k, reg); } }; diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index 9fff8d09..66cadc12 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -75,4 +75,3 @@ struct ymm_vector; template 
struct avx2_vector; - diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index b62d1c44..b74a7693 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -165,9 +165,9 @@ X86_SIMD_SORT_INLINE reg_t cmp_merge(reg_t in1, reg_t in2, opmask_t mask) template int avx512_double_compressstore(type_t *left_addr, - type_t *right_addr, - typename vtype::opmask_t k, - reg_t reg) + type_t *right_addr, + typename vtype::opmask_t k, + reg_t reg) { int amount_ge_pivot = _mm_popcnt_u32((int)k); @@ -179,7 +179,9 @@ int avx512_double_compressstore(type_t *left_addr, } // Generic function dispatches to AVX2 or AVX512 code -template +template X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, type_t *r_store, const reg_t curr_vec, @@ -189,7 +191,8 @@ X86_SIMD_SORT_INLINE arrsize_t partition_vec(type_t *l_store, { typename vtype::opmask_t ge_mask = vtype::ge(curr_vec, pivot_vec); - int amount_ge_pivot = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec); + int amount_ge_pivot + = vtype::double_compressstore(l_store, r_store, ge_mask, curr_vec); smallest_vec = vtype::min(curr_vec, smallest_vec); biggest_vec = vtype::max(curr_vec, biggest_vec); @@ -473,12 +476,12 @@ X86_SIMD_SORT_INLINE arrsize_t partition_avx512_unrolled(type_t *arr, return l_store; } - template void sort_n(typename vtype::type_t *arr, int N); template -static void qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) +static void +qsort_(type_t *arr, arrsize_t left, arrsize_t right, arrsize_t max_iters) { /* * Resort to std::sort if quicksort isnt making any progress @@ -573,15 +576,12 @@ void avx2_qsort(T *arr, arrsize_t arrsize) if (arrsize > 1) { /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { - arrsize_t nan_count - = replace_nan_with_inf(arr, arrsize); - qsort_( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + arrsize_t nan_count = replace_nan_with_inf(arr, arrsize); + qsort_(arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); } else { - qsort_( - arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + qsort_(arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); } } } @@ -632,7 +632,8 @@ X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr, } template -inline void avx2_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +inline void +avx2_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) { avx2_qselect(arr, k - 1, arrsize, hasnan); avx2_qsort(arr, k - 1); diff --git a/src/xss-pivot-selection.hpp b/src/xss-pivot-selection.hpp index 29394321..2a28b348 100644 --- a/src/xss-pivot-selection.hpp +++ b/src/xss-pivot-selection.hpp @@ -3,8 +3,8 @@ X86_SIMD_SORT_INLINE void COEX(mm_t &a, mm_t &b); template X86_SIMD_SORT_INLINE type_t get_pivot(type_t *arr, - const arrsize_t left, - const arrsize_t right) + const arrsize_t left, + const arrsize_t right) { using reg_t = typename vtype::reg_t; type_t samples[vtype::numlanes]; @@ -24,9 +24,7 @@ X86_SIMD_SORT_INLINE type_t get_pivot_blocks(type_t *arr, const arrsize_t right) { - if (right - left <= 1024) { - return get_pivot(arr, left, right); - } + if (right - left <= 1024) { return get_pivot(arr, left, right); } using reg_t = typename vtype::reg_t; constexpr int numVecs = 5; From d032bc40b136c5107976801b44d54f601312d72b Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 23 Oct 2023 13:39:29 -0700 Subject: [PATCH 16/19] Wrapper function for qsort, qselect and partialsort --- 
src/xss-common-qsort.h | 101 ++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 63 deletions(-) diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index b74a7693..0b76add6 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -8,8 +8,8 @@ * Tang Xi * ****************************************************************/ -#ifndef AVX512_QSORT_COMMON -#define AVX512_QSORT_COMMON +#ifndef XSS_COMMON_QSORT +#define XSS_COMMON_QSORT /* * Quicksort using AVX-512. The ideas and code are based on these two research @@ -549,49 +549,28 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, qselect_(arr, pos, pivot_index, right, max_iters - 1); } -// Regular quicksort routines: -template -X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize) -{ - if (arrsize > 1) { - /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ - if constexpr (std::is_floating_point_v) { - arrsize_t nan_count - = replace_nan_with_inf>(arr, arrsize); - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - replace_inf_with_nan(arr, arrsize, nan_count); - } - else { - qsort_, T>( - arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); - } - } -} - -template -void avx2_qsort(T *arr, arrsize_t arrsize) +// Quicksort routines: +template +X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize) { - using vtype = avx2_vector; if (arrsize > 1) { - /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { arrsize_t nan_count = replace_nan_with_inf(arr, arrsize); - qsort_(arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); } else { - qsort_(arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); + qsort_(arr, 0, arrsize - 1, 2 * (arrsize_t)log2(arrsize)); } } } -template +// Quick select methods +template X86_SIMD_SORT_INLINE void -avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { arrsize_t indx_last_elem = arrsize - 1; - /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { if (UNLIKELY(hasnan)) { indx_last_elem = move_nans_to_end_of_array(arr, arrsize); @@ -599,44 +578,40 @@ avx512_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) } UNUSED(hasnan); if (indx_last_elem >= k) { - qselect_, T>( + qselect_( arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); } } -template -void avx2_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) +// Partial sort methods: +template +X86_SIMD_SORT_INLINE void +xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { - arrsize_t indx_last_elem = arrsize - 1; - /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ - if constexpr (std::is_floating_point_v) { - if (UNLIKELY(hasnan)) { - indx_last_elem = move_nans_to_end_of_array(arr, arrsize); - } - } - UNUSED(hasnan); - if (indx_last_elem >= k) { - qselect_, T>( - arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); - } + xss_qselect(arr, k - 1, arrsize, hasnan); + xss_qsort(arr, k - 1); } -template -X86_SIMD_SORT_INLINE void avx512_partial_qsort(T *arr, - arrsize_t k, - arrsize_t arrsize, - bool hasnan = false) -{ - avx512_qselect(arr, k - 1, arrsize, hasnan); - avx512_qsort(arr, k - 1); -} +#define DEFINE_METHODS(ISA, VTYPE) \ + template \ + X86_SIMD_SORT_INLINE void ISA##_qsort(T *arr, arrsize_t size) 
\ + { \ + xss_qsort(arr, size); \ + } \ + template \ + X86_SIMD_SORT_INLINE void ISA##_qselect( \ + T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + { \ + xss_qselect(arr, k, size, hasnan); \ + } \ + template \ + X86_SIMD_SORT_INLINE void ISA##_partial_qsort( \ + T *arr, arrsize_t k, arrsize_t size, bool hasnan = false) \ + { \ + xss_partial_qsort(arr, k, size, hasnan); \ + } -template -inline void -avx2_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) -{ - avx2_qselect(arr, k - 1, arrsize, hasnan); - avx2_qsort(arr, k - 1); -} +DEFINE_METHODS(avx512, zmm_vector) +DEFINE_METHODS(avx2, avx2_vector) -#endif // AVX512_QSORT_COMMON +#endif // XSS_COMMON_QSORT From 191ef61f46274d104dca6eebcee8ccde6fb05ffd Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 23 Oct 2023 13:42:11 -0700 Subject: [PATCH 17/19] Header include directives --- src/xss-common-includes.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index 66cadc12..23c9f964 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -1,3 +1,5 @@ +#ifndef XSS_COMMON_INCLUDES +#define XSS_COMMON_INCLUDES #include #include #include @@ -75,3 +77,5 @@ struct ymm_vector; template struct avx2_vector; + +#endif // XSS_COMMON_INCLUDES From 3dab097d3b1ce1f8342ae714271ac903fe9d0ca5 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 23 Oct 2023 13:50:58 -0700 Subject: [PATCH 18/19] Remove need for avx2-32bit-common.h --- src/avx2-32bit-common.h | 618 --------------------------------------- src/avx2-32bit-qsort.hpp | 610 +++++++++++++++++++++++++++++++++++++- 2 files changed, 608 insertions(+), 620 deletions(-) delete mode 100644 src/avx2-32bit-common.h diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h deleted file mode 100644 index 14705fb1..00000000 --- a/src/avx2-32bit-common.h +++ /dev/null @@ -1,618 +0,0 @@ -/******************************************************************* - * Copyright (C) 2022 Intel Corporation - * SPDX-License-Identifier: BSD-3-Clause - * Authors: Raghuveer Devulapalli - * Matthew Sterrett - * ****************************************************************/ - -#ifndef AVX2_32BIT_COMMON -#define AVX2_32BIT_COMMON -#include "avx2-emu-funcs.hpp" -#include "xss-common-qsort.h" - -/* - * Constants used in sorting 8 elements in a ymm registers. 
Based on Bitonic - * sorting network (see - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) - */ - -// ymm 7, 6, 5, 4, 3, 2, 1, 0 -#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 -#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 -#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 -#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 - -/* - * Assumes ymm is random and performs a full sorting network defined in - * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg - */ -template -X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) -{ - const typename vtype::opmask_t oxAA = _mm256_set_epi32( - 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); - const typename vtype::opmask_t oxCC = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); - const typename vtype::opmask_t oxF0 = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); - - const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2); - ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); - ymm = cmp_merge( - ymm, - vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), - oxCC); - ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); - ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); - ymm = cmp_merge( - ymm, - vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), - oxCC); - ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); - return ymm; -} - -struct avx2_32bit_swizzle_ops; - -template <> -struct avx2_vector { - using type_t = int32_t; - using reg_t = __m256i; - using ymmi_t = __m256i; - using opmask_t = __m256i; - static const uint8_t numlanes = 8; - static constexpr int network_sort_threshold = 256; - static constexpr int partition_unroll_factor = 4; - - using swizzle_ops = avx2_32bit_swizzle_ops; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_INT32; - } - static type_t type_min() - { - return X86_SIMD_SORT_MIN_INT32; - } - static reg_t zmm_max() - { - return _mm256_set1_epi32(type_max()); - } // TODO: this should broadcast bits as is? 
- static opmask_t get_partial_loadmask(uint64_t num_to_read) - { - auto mask = ((0x1ull << num_to_read) - 0x1ull); - return convert_int_to_avx2_mask(mask); - } - static ymmi_t - seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) - { - return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); - } - static opmask_t kxor_opmask(opmask_t x, opmask_t y) - { - return _mm256_xor_si256(x, y); - } - static opmask_t knot_opmask(opmask_t x) - { - return ~x; - } - static opmask_t le(reg_t x, reg_t y) - { - return ~_mm256_cmpgt_epi32(x, y); - } - static opmask_t ge(reg_t x, reg_t y) - { - opmask_t equal = eq(x, y); - opmask_t greater = _mm256_cmpgt_epi32(x, y); - return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(equal), - _mm256_castsi256_ps(greater))); - } - static opmask_t eq(reg_t x, reg_t y) - { - return _mm256_cmpeq_epi32(x, y); - } - template - static reg_t - mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) - { - return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); - } - template - static reg_t i64gather(__m256i index, void const *base) - { - return _mm256_i32gather_epi32((int const *)base, index, scale); - } - static reg_t loadu(void const *mem) - { - return _mm256_loadu_si256((reg_t const *)mem); - } - static reg_t max(reg_t x, reg_t y) - { - return _mm256_max_epi32(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) - { - return avx2_emu_mask_compressstoreu(mem, mask, x); - } - static reg_t maskz_loadu(opmask_t mask, void const *mem) - { - return _mm256_maskload_epi32((const int *)mem, mask); - } - static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) - { - reg_t dst = _mm256_maskload_epi32((type_t *)mem, mask); - return mask_mov(x, mask, dst); - } - static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) - { - return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), - _mm256_castsi256_ps(y), - _mm256_castsi256_ps(mask))); - } - static void mask_storeu(void *mem, opmask_t mask, reg_t x) - { - return _mm256_maskstore_epi32((type_t *)mem, mask, x); - } - static reg_t min(reg_t x, reg_t y) - { - return _mm256_min_epi32(x, y); - } - static reg_t permutexvar(__m256i idx, reg_t ymm) - { - return _mm256_permutevar8x32_epi32(ymm, idx); - //return avx2_emu_permutexvar_epi32(idx, ymm); - } - static reg_t permutevar(reg_t ymm, __m256i idx) - { - return _mm256_permutevar8x32_epi32(ymm, idx); - } - static reg_t reverse(reg_t ymm) - { - const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); - return permutexvar(rev_index, ymm); - } - template - static type_t extract(reg_t v) - { - return _mm256_extract_epi32(v, index); - } - static type_t reducemax(reg_t v) - { - return avx2_emu_reduce_max32(v); - } - static type_t reducemin(reg_t v) - { - return avx2_emu_reduce_min32(v); - } - static reg_t set1(type_t v) - { - return _mm256_set1_epi32(v); - } - template - static reg_t shuffle(reg_t ymm) - { - return _mm256_shuffle_epi32(ymm, mask); - } - static void storeu(void *mem, reg_t x) - { - _mm256_storeu_si256((__m256i *)mem, x); - } - static reg_t sort_vec(reg_t x) - { - return sort_ymm_32bit>(x); - } - static reg_t cast_from(__m256i v) - { - return v; - } - static __m256i cast_to(reg_t v) - { - return v; - } - static int double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32( - left_addr, right_addr, k, reg); - } -}; -template <> -struct avx2_vector { - using type_t = uint32_t; - using reg_t = __m256i; - using 
ymmi_t = __m256i; - using opmask_t = __m256i; - static const uint8_t numlanes = 8; - static constexpr int network_sort_threshold = 256; - static constexpr int partition_unroll_factor = 4; - - using swizzle_ops = avx2_32bit_swizzle_ops; - - static type_t type_max() - { - return X86_SIMD_SORT_MAX_UINT32; - } - static type_t type_min() - { - return 0; - } - static reg_t zmm_max() - { - return _mm256_set1_epi32(type_max()); - } - static opmask_t get_partial_loadmask(uint64_t num_to_read) - { - auto mask = ((0x1ull << num_to_read) - 0x1ull); - return convert_int_to_avx2_mask(mask); - } - static ymmi_t - seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) - { - return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); - } - template - static reg_t - mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) - { - return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); - } - template - static reg_t i64gather(__m256i index, void const *base) - { - return _mm256_i32gather_epi32((int const *)base, index, scale); - } - static opmask_t knot_opmask(opmask_t x) - { - return ~x; - } - static opmask_t ge(reg_t x, reg_t y) - { - reg_t maxi = max(x, y); - return eq(maxi, x); - } - static opmask_t eq(reg_t x, reg_t y) - { - return _mm256_cmpeq_epi32(x, y); - } - static reg_t loadu(void const *mem) - { - return _mm256_loadu_si256((reg_t const *)mem); - } - static reg_t max(reg_t x, reg_t y) - { - return _mm256_max_epu32(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) - { - return avx2_emu_mask_compressstoreu(mem, mask, x); - } - static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) - { - reg_t dst = _mm256_maskload_epi32((const int *)mem, mask); - return mask_mov(x, mask, dst); - } - static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) - { - return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), - _mm256_castsi256_ps(y), - _mm256_castsi256_ps(mask))); - } - static void mask_storeu(void *mem, opmask_t mask, reg_t x) - { - return _mm256_maskstore_epi32((int *)mem, mask, x); - } - static reg_t min(reg_t x, reg_t y) - { - return _mm256_min_epu32(x, y); - } - static reg_t permutexvar(__m256i idx, reg_t ymm) - { - return _mm256_permutevar8x32_epi32(ymm, idx); - } - static reg_t permutevar(reg_t ymm, __m256i idx) - { - return _mm256_permutevar8x32_epi32(ymm, idx); - } - static reg_t reverse(reg_t ymm) - { - const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); - return permutexvar(rev_index, ymm); - } - template - static type_t extract(reg_t v) - { - return _mm256_extract_epi32(v, index); - } - static type_t reducemax(reg_t v) - { - return avx2_emu_reduce_max32(v); - } - static type_t reducemin(reg_t v) - { - return avx2_emu_reduce_min32(v); - } - static reg_t set1(type_t v) - { - return _mm256_set1_epi32(v); - } - template - static reg_t shuffle(reg_t ymm) - { - return _mm256_shuffle_epi32(ymm, mask); - } - static void storeu(void *mem, reg_t x) - { - _mm256_storeu_si256((__m256i *)mem, x); - } - static reg_t sort_vec(reg_t x) - { - return sort_ymm_32bit>(x); - } - static reg_t cast_from(__m256i v) - { - return v; - } - static __m256i cast_to(reg_t v) - { - return v; - } - static int double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32( - left_addr, right_addr, k, reg); - } -}; -template <> -struct avx2_vector { - using type_t = float; - using reg_t = __m256; - using ymmi_t = __m256i; - using opmask_t = __m256i; - static const 
uint8_t numlanes = 8; - static constexpr int network_sort_threshold = 256; - static constexpr int partition_unroll_factor = 4; - - using swizzle_ops = avx2_32bit_swizzle_ops; - - static type_t type_max() - { - return X86_SIMD_SORT_INFINITYF; - } - static type_t type_min() - { - return -X86_SIMD_SORT_INFINITYF; - } - static reg_t zmm_max() - { - return _mm256_set1_ps(type_max()); - } - - static ymmi_t - seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) - { - return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); - } - - static reg_t maskz_loadu(opmask_t mask, void const *mem) - { - return _mm256_maskload_ps((const float *)mem, mask); - } - static opmask_t knot_opmask(opmask_t x) - { - return ~x; - } - static opmask_t ge(reg_t x, reg_t y) - { - return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GE_OQ)); - } - static opmask_t eq(reg_t x, reg_t y) - { - return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); - } - static opmask_t get_partial_loadmask(uint64_t num_to_read) - { - auto mask = ((0x1ull << num_to_read) - 0x1ull); - return convert_int_to_avx2_mask(mask); - } - static int32_t convert_mask_to_int(opmask_t mask) - { - return convert_avx2_mask_to_int(mask); - } - template - static opmask_t fpclass(reg_t x) - { - if constexpr (type == (0x01 | 0x80)) { - return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q)); - } - else { - static_assert(type == (0x01 | 0x80), "should not reach here"); - } - } - template - static reg_t - mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) - { - return _mm256_mask_i32gather_ps( - src, base, index, _mm256_castsi256_ps(mask), scale); - ; - } - template - static reg_t i64gather(__m256i index, void const *base) - { - return _mm256_i32gather_ps((float *)base, index, scale); - } - static reg_t loadu(void const *mem) - { - return _mm256_loadu_ps((float const *)mem); - } - static reg_t max(reg_t x, reg_t y) - { - return _mm256_max_ps(x, y); - } - static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) - { - return avx2_emu_mask_compressstoreu(mem, mask, x); - } - static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) - { - reg_t dst = _mm256_maskload_ps((type_t *)mem, mask); - return mask_mov(x, mask, dst); - } - static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) - { - return _mm256_blendv_ps(x, y, _mm256_castsi256_ps(mask)); - } - static void mask_storeu(void *mem, opmask_t mask, reg_t x) - { - return _mm256_maskstore_ps((type_t *)mem, mask, x); - } - static reg_t min(reg_t x, reg_t y) - { - return _mm256_min_ps(x, y); - } - static reg_t permutexvar(__m256i idx, reg_t ymm) - { - return _mm256_permutevar8x32_ps(ymm, idx); - } - static reg_t permutevar(reg_t ymm, __m256i idx) - { - return _mm256_permutevar8x32_ps(ymm, idx); - } - static reg_t reverse(reg_t ymm) - { - const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); - return permutexvar(rev_index, ymm); - } - template - static type_t extract(reg_t v) - { - int32_t x = _mm256_extract_epi32(_mm256_castps_si256(v), index); - float y; - std::memcpy(&y, &x, sizeof(y)); - return y; - } - static type_t reducemax(reg_t v) - { - return avx2_emu_reduce_max32(v); - } - static type_t reducemin(reg_t v) - { - return avx2_emu_reduce_min32(v); - } - static reg_t set1(type_t v) - { - return _mm256_set1_ps(v); - } - template - static reg_t shuffle(reg_t ymm) - { - return _mm256_castsi256_ps( - _mm256_shuffle_epi32(_mm256_castps_si256(ymm), mask)); - } - static void storeu(void *mem, reg_t x) - { - _mm256_storeu_ps((float *)mem, x); - } 
- static reg_t sort_vec(reg_t x) - { - return sort_ymm_32bit>(x); - } - static reg_t cast_from(__m256i v) - { - return _mm256_castsi256_ps(v); - } - static __m256i cast_to(reg_t v) - { - return _mm256_castps_si256(v); - } - static int double_compressstore(type_t *left_addr, - type_t *right_addr, - opmask_t k, - reg_t reg) - { - return avx2_double_compressstore32( - left_addr, right_addr, k, reg); - } -}; - -struct avx2_32bit_swizzle_ops { - template - X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) - { - __m256i v = vtype::cast_to(reg); - - if constexpr (scale == 2) { - __m256 vf = _mm256_castsi256_ps(v); - vf = _mm256_permute_ps(vf, 0b10110001); - v = _mm256_castps_si256(vf); - } - else if constexpr (scale == 4) { - __m256 vf = _mm256_castsi256_ps(v); - vf = _mm256_permute_ps(vf, 0b01001110); - v = _mm256_castps_si256(vf); - } - else if constexpr (scale == 8) { - v = _mm256_permute2x128_si256(v, v, 0b00000001); - } - else { - static_assert(scale == -1, "should not be reached"); - } - - return vtype::cast_from(v); - } - - template - X86_SIMD_SORT_INLINE typename vtype::reg_t - reverse_n(typename vtype::reg_t reg) - { - __m256i v = vtype::cast_to(reg); - - if constexpr (scale == 2) { return swap_n(reg); } - else if constexpr (scale == 4) { - constexpr uint64_t mask = 0b00011011; - __m256 vf = _mm256_castsi256_ps(v); - vf = _mm256_permute_ps(vf, mask); - v = _mm256_castps_si256(vf); - } - else if constexpr (scale == 8) { - return vtype::reverse(reg); - } - else { - static_assert(scale == -1, "should not be reached"); - } - - return vtype::cast_from(v); - } - - template - X86_SIMD_SORT_INLINE typename vtype::reg_t - merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) - { - __m256i v1 = vtype::cast_to(reg); - __m256i v2 = vtype::cast_to(other); - - if constexpr (scale == 2) { - v1 = _mm256_blend_epi32(v1, v2, 0b01010101); - } - else if constexpr (scale == 4) { - v1 = _mm256_blend_epi32(v1, v2, 0b00110011); - } - else if constexpr (scale == 8) { - v1 = _mm256_blend_epi32(v1, v2, 0b00001111); - } - else { - static_assert(scale == -1, "should not be reached"); - } - - return vtype::cast_from(v1); - } -}; -#endif diff --git a/src/avx2-32bit-qsort.hpp b/src/avx2-32bit-qsort.hpp index c0590d94..5dd77a27 100644 --- a/src/avx2-32bit-qsort.hpp +++ b/src/avx2-32bit-qsort.hpp @@ -7,7 +7,613 @@ #ifndef AVX2_QSORT_32BIT #define AVX2_QSORT_32BIT -#include "avx2-32bit-common.h" -#include "xss-network-qsort.hpp" +#include "xss-common-qsort.h" +#include "avx2-emu-funcs.hpp" + +/* + * Constants used in sorting 8 elements in a ymm registers. 
Based on Bitonic + * sorting network (see + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg) + */ + +// ymm 7, 6, 5, 4, 3, 2, 1, 0 +#define NETWORK_32BIT_AVX2_1 4, 5, 6, 7, 0, 1, 2, 3 +#define NETWORK_32BIT_AVX2_2 0, 1, 2, 3, 4, 5, 6, 7 +#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 +#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 + +/* + * Assumes ymm is random and performs a full sorting network defined in + * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg + */ +template +X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) +{ + const typename vtype::opmask_t oxAA = _mm256_set_epi32( + 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); + const typename vtype::opmask_t oxCC = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); + const typename vtype::opmask_t oxF0 = _mm256_set_epi32( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); + + const typename vtype::ymmi_t rev_index = vtype::seti(NETWORK_32BIT_AVX2_2); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge( + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_1), ymm), + oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + ymm = cmp_merge(ymm, vtype::permutexvar(rev_index, ymm), oxF0); + ymm = cmp_merge( + ymm, + vtype::permutexvar(vtype::seti(NETWORK_32BIT_AVX2_3), ymm), + oxCC); + ymm = cmp_merge( + ymm, vtype::template shuffle(ymm), oxAA); + return ymm; +} + +struct avx2_32bit_swizzle_ops; + +template <> +struct avx2_vector { + using type_t = int32_t; + using reg_t = __m256i; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_INT32; + } + static type_t type_min() + { + return X86_SIMD_SORT_MIN_INT32; + } + static reg_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } // TODO: this should broadcast bits as is? 
+ static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + static opmask_t kxor_opmask(opmask_t x, opmask_t y) + { + return _mm256_xor_si256(x, y); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t le(reg_t x, reg_t y) + { + return ~_mm256_cmpgt_epi32(x, y); + } + static opmask_t ge(reg_t x, reg_t y) + { + opmask_t equal = eq(x, y); + opmask_t greater = _mm256_cmpgt_epi32(x, y); + return _mm256_castps_si256(_mm256_or_ps(_mm256_castsi256_ps(equal), + _mm256_castsi256_ps(greater))); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi32(x, y); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_epi32((int const *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_epi32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_epi32((const int *)mem, mask); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi32((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), + _mm256_castsi256_ps(y), + _mm256_castsi256_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_epi32((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_epi32(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + //return avx2_emu_permutexvar_epi32(idx, ymm); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi32(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_vector { + using type_t = uint32_t; + using reg_t = __m256i; + using 
ymmi_t = __m256i; + using opmask_t = __m256i; + static const uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_MAX_UINT32; + } + static type_t type_min() + { + return 0; + } + static reg_t zmm_max() + { + return _mm256_set1_epi32(type_max()); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_epi32(src, base, index, mask, scale); + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_epi32((int const *)base, index, scale); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t ge(reg_t x, reg_t y) + { + reg_t maxi = max(x, y); + return eq(maxi, x); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_cmpeq_epi32(x, y); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_si256((reg_t const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_epu32(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_epi32((const int *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_castps_si256(_mm256_blendv_ps(_mm256_castsi256_ps(x), + _mm256_castsi256_ps(y), + _mm256_castsi256_ps(mask))); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_epi32((int *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_epu32(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_epi32(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + return _mm256_extract_epi32(v, index); + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_epi32(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_shuffle_epi32(ymm, mask); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_si256((__m256i *)mem, x); + } + static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) + { + return v; + } + static __m256i cast_to(reg_t v) + { + return v; + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } +}; +template <> +struct avx2_vector { + using type_t = float; + using reg_t = __m256; + using ymmi_t = __m256i; + using opmask_t = __m256i; + static const 
uint8_t numlanes = 8; + static constexpr int network_sort_threshold = 256; + static constexpr int partition_unroll_factor = 4; + + using swizzle_ops = avx2_32bit_swizzle_ops; + + static type_t type_max() + { + return X86_SIMD_SORT_INFINITYF; + } + static type_t type_min() + { + return -X86_SIMD_SORT_INFINITYF; + } + static reg_t zmm_max() + { + return _mm256_set1_ps(type_max()); + } + + static ymmi_t + seti(int v1, int v2, int v3, int v4, int v5, int v6, int v7, int v8) + { + return _mm256_set_epi32(v1, v2, v3, v4, v5, v6, v7, v8); + } + + static reg_t maskz_loadu(opmask_t mask, void const *mem) + { + return _mm256_maskload_ps((const float *)mem, mask); + } + static opmask_t knot_opmask(opmask_t x) + { + return ~x; + } + static opmask_t ge(reg_t x, reg_t y) + { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_GE_OQ)); + } + static opmask_t eq(reg_t x, reg_t y) + { + return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); + } + static opmask_t get_partial_loadmask(uint64_t num_to_read) + { + auto mask = ((0x1ull << num_to_read) - 0x1ull); + return convert_int_to_avx2_mask(mask); + } + static int32_t convert_mask_to_int(opmask_t mask) + { + return convert_avx2_mask_to_int(mask); + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)) { + return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q)); + } + else { + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } + template + static reg_t + mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) + { + return _mm256_mask_i32gather_ps( + src, base, index, _mm256_castsi256_ps(mask), scale); + ; + } + template + static reg_t i64gather(__m256i index, void const *base) + { + return _mm256_i32gather_ps((float *)base, index, scale); + } + static reg_t loadu(void const *mem) + { + return _mm256_loadu_ps((float const *)mem); + } + static reg_t max(reg_t x, reg_t y) + { + return _mm256_max_ps(x, y); + } + static void mask_compressstoreu(void *mem, opmask_t mask, reg_t x) + { + return avx2_emu_mask_compressstoreu(mem, mask, x); + } + static reg_t mask_loadu(reg_t x, opmask_t mask, void const *mem) + { + reg_t dst = _mm256_maskload_ps((type_t *)mem, mask); + return mask_mov(x, mask, dst); + } + static reg_t mask_mov(reg_t x, opmask_t mask, reg_t y) + { + return _mm256_blendv_ps(x, y, _mm256_castsi256_ps(mask)); + } + static void mask_storeu(void *mem, opmask_t mask, reg_t x) + { + return _mm256_maskstore_ps((type_t *)mem, mask, x); + } + static reg_t min(reg_t x, reg_t y) + { + return _mm256_min_ps(x, y); + } + static reg_t permutexvar(__m256i idx, reg_t ymm) + { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t permutevar(reg_t ymm, __m256i idx) + { + return _mm256_permutevar8x32_ps(ymm, idx); + } + static reg_t reverse(reg_t ymm) + { + const __m256i rev_index = _mm256_set_epi32(NETWORK_32BIT_AVX2_2); + return permutexvar(rev_index, ymm); + } + template + static type_t extract(reg_t v) + { + int32_t x = _mm256_extract_epi32(_mm256_castps_si256(v), index); + float y; + std::memcpy(&y, &x, sizeof(y)); + return y; + } + static type_t reducemax(reg_t v) + { + return avx2_emu_reduce_max32(v); + } + static type_t reducemin(reg_t v) + { + return avx2_emu_reduce_min32(v); + } + static reg_t set1(type_t v) + { + return _mm256_set1_ps(v); + } + template + static reg_t shuffle(reg_t ymm) + { + return _mm256_castsi256_ps( + _mm256_shuffle_epi32(_mm256_castps_si256(ymm), mask)); + } + static void storeu(void *mem, reg_t x) + { + _mm256_storeu_ps((float *)mem, x); + } 
+ static reg_t sort_vec(reg_t x) + { + return sort_ymm_32bit>(x); + } + static reg_t cast_from(__m256i v) + { + return _mm256_castsi256_ps(v); + } + static __m256i cast_to(reg_t v) + { + return _mm256_castps_si256(v); + } + static int double_compressstore(type_t *left_addr, + type_t *right_addr, + opmask_t k, + reg_t reg) + { + return avx2_double_compressstore32( + left_addr, right_addr, k, reg); + } +}; + +struct avx2_32bit_swizzle_ops { + template + X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b10110001); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 4) { + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, 0b01001110); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 8) { + v = _mm256_permute2x128_si256(v, v, 0b00000001); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + reverse_n(typename vtype::reg_t reg) + { + __m256i v = vtype::cast_to(reg); + + if constexpr (scale == 2) { return swap_n(reg); } + else if constexpr (scale == 4) { + constexpr uint64_t mask = 0b00011011; + __m256 vf = _mm256_castsi256_ps(v); + vf = _mm256_permute_ps(vf, mask); + v = _mm256_castps_si256(vf); + } + else if constexpr (scale == 8) { + return vtype::reverse(reg); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v); + } + + template + X86_SIMD_SORT_INLINE typename vtype::reg_t + merge_n(typename vtype::reg_t reg, typename vtype::reg_t other) + { + __m256i v1 = vtype::cast_to(reg); + __m256i v2 = vtype::cast_to(other); + + if constexpr (scale == 2) { + v1 = _mm256_blend_epi32(v1, v2, 0b01010101); + } + else if constexpr (scale == 4) { + v1 = _mm256_blend_epi32(v1, v2, 0b00110011); + } + else if constexpr (scale == 8) { + v1 = _mm256_blend_epi32(v1, v2, 0b00001111); + } + else { + static_assert(scale == -1, "should not be reached"); + } + + return vtype::cast_from(v1); + } +}; #endif // AVX2_QSORT_32BIT From 9810e05d725d7a196d2a4a0fad6df4f3f228f713 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 23 Oct 2023 14:22:27 -0700 Subject: [PATCH 19/19] Explicit instantiation of avx512_partial_qsort for _Float16 --- src/avx512fp16-16bit-qsort.hpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/avx512fp16-16bit-qsort.hpp b/src/avx512fp16-16bit-qsort.hpp index 081a9939..b01c367b 100644 --- a/src/avx512fp16-16bit-qsort.hpp +++ b/src/avx512fp16-16bit-qsort.hpp @@ -201,4 +201,13 @@ void avx512_qselect(_Float16 *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) arr, k, 0, indx_last_elem, 2 * (arrsize_t)log2(indx_last_elem)); } } +template <> +void avx512_partial_qsort(_Float16 *arr, + arrsize_t k, + arrsize_t arrsize, + bool hasnan) +{ + avx512_qselect(arr, k - 1, arrsize, hasnan); + avx512_qsort(arr, k - 1); +} #endif // AVX512FP16_QSORT_16BIT
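
Note (not part of the patch series above): after patch 16, DEFINE_METHODS(avx2, avx2_vector) and DEFINE_METHODS(avx512, zmm_vector) give both backends an identical front end that forwards to the shared xss_qsort/xss_qselect/xss_partial_qsort templates in src/xss-common-qsort.h. A minimal, hypothetical caller-side sketch — assuming a translation unit built with AVX2 enabled (e.g. -march=haswell) and the src/ headers on the include path — might look like this:

    // Illustrative sketch only; exercises the avx2_* wrappers generated by
    // DEFINE_METHODS in patch 16. Not part of the patches themselves.
    #include "avx2-32bit-qsort.hpp" // pulls in xss-common-qsort.h transitively

    #include <vector>

    int main()
    {
        std::vector<float> data {3.5f, -1.0f, 7.25f, 0.0f, 2.0f};

        // Full sort: expands through DEFINE_METHODS to the shared xss_qsort
        // template instantiated with the AVX2 float vector type.
        avx2_qsort(data.data(), data.size());

        // Partial sort: leaves the 3 smallest values, in order, at the front
        // of the array; hasnan defaults to false per the wrapper signature.
        avx2_partial_qsort(data.data(), 3, data.size());

        return 0;
    }

The avx512_* wrappers have the same shape, and patch 19 only adds an explicit _Float16 specialization of avx512_partial_qsort so the half-precision path, which already carries its own qsort/qselect specializations, gets a matching partial_qsort entry point.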