From 15d602523e79360f5339cd7a51620c81524c1f81 Mon Sep 17 00:00:00 2001 From: Matthew Sterrett Date: Fri, 20 Oct 2023 10:53:26 -0700 Subject: [PATCH] Fixes/changes many small things --- src/avx2-32bit-common.h | 97 ++++++++------------------------------ src/avx2-emu-funcs.hpp | 37 +++++++-------- src/avx512-16bit-common.h | 1 - src/avx512-32bit-qsort.hpp | 1 - src/xss-common-qsort.h | 18 ++----- 5 files changed, 39 insertions(+), 115 deletions(-) diff --git a/src/avx2-32bit-common.h b/src/avx2-32bit-common.h index d2e4d8b8..ac6d17dd 100644 --- a/src/avx2-32bit-common.h +++ b/src/avx2-32bit-common.h @@ -22,37 +22,6 @@ #define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2 #define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4 -namespace xss { -namespace avx2 { - -// Assumes ymm is bitonic and performs a recursive half cleaner -template -X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_32bit(reg_t ymm) -{ - - const typename vtype::opmask_t oxAA = _mm256_set_epi32( - 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0); - const typename vtype::opmask_t oxCC = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0); - const typename vtype::opmask_t oxF0 = _mm256_set_epi32( - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0); - - // 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 - ymm = cmp_merge( - ymm, - vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_4), ymm), - oxF0); - // 2) half_cleaner[4] - ymm = cmp_merge( - ymm, - vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_3), ymm), - oxCC); - // 3) half_cleaner[1] - ymm = cmp_merge( - ymm, vtype::template shuffle(ymm), oxAA); - return ymm; -} - /* * Assumes ymm is random and performs a full sorting network defined in * https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg @@ -85,7 +54,7 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm) struct avx2_32bit_swizzle_ops; template <> -struct ymm_vector { +struct avx2_vector { using type_t = int32_t; using reg_t = __m256i; using ymmi_t = __m256i; @@ -231,13 +200,9 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_ymm_32bit>(x); - } static reg_t sort_vec(reg_t x) { - return sort_ymm_32bit>(x); + return sort_ymm_32bit>(x); } static reg_t cast_from(__m256i v){ return v; @@ -247,7 +212,7 @@ struct ymm_vector { } }; template <> -struct ymm_vector { +struct avx2_vector { using type_t = uint32_t; using reg_t = __m256i; using ymmi_t = __m256i; @@ -378,13 +343,9 @@ struct ymm_vector { { _mm256_storeu_si256((__m256i *)mem, x); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_ymm_32bit>(x); - } static reg_t sort_vec(reg_t x) { - return sort_ymm_32bit>(x); + return sort_ymm_32bit>(x); } static reg_t cast_from(__m256i v){ return v; @@ -394,7 +355,7 @@ struct ymm_vector { } }; template <> -struct ymm_vector { +struct avx2_vector { using type_t = float; using reg_t = __m256; using ymmi_t = __m256i; @@ -440,6 +401,19 @@ struct ymm_vector { { return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ)); } + static opmask_t get_partial_loadmask(int size) + { + return (0x0001 << size) - 0x0001; + } + template + static opmask_t fpclass(reg_t x) + { + if constexpr (type == (0x01 | 0x80)){ + return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q)); + }else{ + static_assert(type == (0x01 | 0x80), "should not reach here"); + } + } template static reg_t mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base) @@ -533,13 +507,9 @@ struct ymm_vector { { _mm256_storeu_ps((float *)mem, x); } - static reg_t bitonic_merge(reg_t x) - { - return bitonic_merge_ymm_32bit>(x); - } static reg_t sort_vec(reg_t x) { - return sort_ymm_32bit>(x); + return sort_ymm_32bit>(x); } static reg_t cast_from(__m256i v){ return _mm256_castsi256_ps(v); @@ -549,32 +519,6 @@ struct ymm_vector { } }; -inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize) -{ - arrsize_t nan_count = 0; - __mmask8 loadmask = 0xFF; - while (arrsize > 0) { - if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; } - __m256 in_ymm = ymm_vector::maskz_loadu(loadmask, arr); - __m256i nanmask = _mm256_castps_si256( - _mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ)); - nan_count += _mm_popcnt_u32(avx2_mask_helper32(nanmask)); - ymm_vector::mask_storeu(arr, nanmask, YMM_MAX_FLOAT); - arr += 8; - arrsize -= 8; - } - return nan_count; -} - -X86_SIMD_SORT_INLINE void -replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count) -{ - for (arrsize_t ii = arrsize - 1; nan_count > 0; --ii) { - arr[ii] = std::nan("1"); - nan_count -= 1; - } -} - struct avx2_32bit_swizzle_ops{ template X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){ @@ -635,7 +579,4 @@ struct avx2_32bit_swizzle_ops{ return vtype::cast_from(v1); } }; - -} // namespace avx2 -} // namespace xss #endif diff --git a/src/avx2-emu-funcs.hpp b/src/avx2-emu-funcs.hpp index 6f29fcae..f7c0dfb9 100644 --- a/src/avx2-emu-funcs.hpp +++ b/src/avx2-emu-funcs.hpp @@ -5,9 +5,6 @@ #include #include "xss-common-qsort.h" -namespace xss { -namespace avx2 { - constexpr auto avx2_mask_helper_lut32 = [] { std::array, 256> lut {}; for (int64_t i = 0; i <= 0xFF; i++) { @@ -97,9 +94,9 @@ static __m256i operator~(const avx2_mask_helper32 x) // Emulators for intrinsics missing from AVX2 compared to AVX512 template -T avx2_emu_reduce_max32(typename ymm_vector::reg_t x) +T avx2_emu_reduce_max32(typename avx2_vector::reg_t x) { - using vtype = ymm_vector; + using vtype = avx2_vector; using reg_t = typename vtype::reg_t; reg_t inter1 = vtype::max(x, vtype::template shuffle(x)); @@ -110,9 +107,9 @@ T avx2_emu_reduce_max32(typename ymm_vector::reg_t x) } template -T avx2_emu_reduce_min32(typename ymm_vector::reg_t x) +T avx2_emu_reduce_min32(typename avx2_vector::reg_t x) { - using vtype = ymm_vector; + using vtype = avx2_vector; using reg_t = typename vtype::reg_t; reg_t inter1 = vtype::min(x, vtype::template shuffle(x)); @@ -124,10 +121,10 @@ T avx2_emu_reduce_min32(typename ymm_vector::reg_t x) template void avx2_emu_mask_compressstoreu(void *base_addr, - typename ymm_vector::opmask_t k, - typename ymm_vector::reg_t reg) + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { - using vtype = ymm_vector; + using vtype = avx2_vector; T *leftStore = (T *)base_addr; @@ -145,10 +142,10 @@ void avx2_emu_mask_compressstoreu(void *base_addr, template int32_t avx2_double_compressstore32(void *left_addr, void *right_addr, - typename ymm_vector::opmask_t k, - typename ymm_vector::reg_t reg) + typename avx2_vector::opmask_t k, + typename avx2_vector::reg_t reg) { - using vtype = ymm_vector; + using vtype = avx2_vector; T *leftStore = (T *)left_addr; T *rightStore = (T *)right_addr; @@ -168,10 +165,10 @@ int32_t avx2_double_compressstore32(void *left_addr, } template -typename ymm_vector::reg_t avx2_emu_max(typename ymm_vector::reg_t x, - typename ymm_vector::reg_t y) +typename avx2_vector::reg_t avx2_emu_max(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) { - using vtype = ymm_vector; + using vtype = avx2_vector; typename vtype::opmask_t nlt = vtype::ge(x, y); return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y), _mm256_castsi256_pd(x), @@ -179,16 +176,14 @@ typename ymm_vector::reg_t avx2_emu_max(typename ymm_vector::reg_t x, } template -typename ymm_vector::reg_t avx2_emu_min(typename ymm_vector::reg_t x, - typename ymm_vector::reg_t y) +typename avx2_vector::reg_t avx2_emu_min(typename avx2_vector::reg_t x, + typename avx2_vector::reg_t y) { - using vtype = ymm_vector; + using vtype = avx2_vector; typename vtype::opmask_t nlt = vtype::ge(x, y); return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x), _mm256_castsi256_pd(y), _mm256_castsi256_pd(nlt))); } -} // namespace avx2 -} // namespace x86_simd_sort #endif \ No newline at end of file diff --git a/src/avx512-16bit-common.h b/src/avx512-16bit-common.h index 41116b33..28c1c1fe 100644 --- a/src/avx512-16bit-common.h +++ b/src/avx512-16bit-common.h @@ -8,7 +8,6 @@ #define AVX512_16BIT_COMMON #include "xss-common-qsort.h" -#include "xss-network-qsort.hpp" /* * Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic diff --git a/src/avx512-32bit-qsort.hpp b/src/avx512-32bit-qsort.hpp index 4d244d6e..e96d4a9a 100644 --- a/src/avx512-32bit-qsort.hpp +++ b/src/avx512-32bit-qsort.hpp @@ -9,7 +9,6 @@ #define AVX512_QSORT_32BIT #include "xss-common-qsort.h" -#include "xss-network-qsort.hpp" /* * Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index 709d4ca4..6c440078 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -38,18 +38,8 @@ #include "xss-pivot-selection.hpp" #include "xss-network-qsort.hpp" -namespace xss{ -namespace avx2{ template -struct ymm_vector; - -inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize); -} -} - -// key-value sort routines -template -void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize); +struct avx2_vector; template bool is_a_nan(T elem) @@ -614,12 +604,12 @@ X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize) template void avx2_qsort(T *arr, arrsize_t arrsize) { - using vtype = xss::avx2::ymm_vector; + using vtype = avx2_vector; if (arrsize > 1) { /* std::is_floating_point_v<_Float16> == False, unless c++-23*/ if constexpr (std::is_floating_point_v) { arrsize_t nan_count - = xss::avx2::replace_nan_with_inf(arr, arrsize); + = replace_nan_with_inf(arr, arrsize); qsort_( arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize)); replace_inf_with_nan(arr, arrsize, nan_count); @@ -661,7 +651,7 @@ void avx2_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false) } UNUSED(hasnan); if (indx_last_elem >= k) { - qselect_, T>( + qselect_, T>( arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem)); } }