diff --git a/src/xss-network-qsort.hpp b/src/xss-network-qsort.hpp
index b02c0b30..701ee774 100644
--- a/src/xss-network-qsort.hpp
+++ b/src/xss-network-qsort.hpp
@@ -1,16 +1,18 @@
 #ifndef XSS_NETWORK_QSORT
 #define XSS_NETWORK_QSORT
 
+#include "avx512-common-qsort.h"
+
 template <typename vtype,
           int64_t numVecs,
           typename reg_t = typename vtype::reg_t>
 X86_SIMD_SORT_INLINE void bitonic_clean_n_vec(reg_t *regs)
 {
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int num = numVecs / 2; num >= 2; num /= 2) {
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
         for (int j = 0; j < numVecs; j += num) {
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
             for (int i = 0; i < num / 2; i++) {
                 COEX<vtype>(regs[i + j], regs[i + j + num / 2]);
             }
@@ -30,7 +32,7 @@ X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(reg_t *regs)
     }
     else if constexpr (numVecs > 2) {
         // Reverse upper half
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
         for (int i = 0; i < numVecs / 2; i++) {
             reg_t rev = vtype::reverse(regs[numVecs - i - 1]);
             reg_t maxV = vtype::max(regs[i], rev);
@@ -44,7 +46,7 @@ X86_SIMD_SORT_INLINE void bitonic_merge_n_vec(reg_t *regs)
     bitonic_clean_n_vec<vtype, numVecs>(regs);
 
     // Now do bitonic_merge
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int i = 0; i < numVecs; i++) {
         regs[i] = vtype::bitonic_merge(regs[i]);
     }
@@ -59,7 +61,7 @@ X86_SIMD_SORT_INLINE void bitonic_fullmerge_n_vec(reg_t *regs)
     if constexpr (numPer > numVecs)
         return;
     else {
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
         for (int i = 0; i < numVecs / numPer; i++) {
             bitonic_merge_n_vec<vtype, numPer>(regs + i * numPer);
         }
@@ -79,7 +81,7 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int32_t N)
 
     // Generate masks for loading and storing
     typename vtype::opmask_t ioMasks[numVecs - numVecs / 2];
-    #pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
         int64_t num_to_read
                 = std::min((int64_t)std::max(0, N - i * vtype::numlanes),
@@ -88,19 +90,19 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int32_t N)
     }
 
     // Unmasked part of the load
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int i = 0; i < numVecs / 2; i++) {
         vecs[i] = vtype::loadu(arr + i * vtype::numlanes);
    }
     // Masked part of the load
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
         vecs[i] = vtype::mask_loadu(
                 vtype::zmm_max(), ioMasks[j], arr + i * vtype::numlanes);
     }
 
     // Sort each loaded vector
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int i = 0; i < numVecs; i++) {
         vecs[i] = vtype::sort_vec(vecs[i]);
     }
@@ -109,12 +111,12 @@ X86_SIMD_SORT_INLINE void sort_n_vec(typename vtype::type_t *arr, int32_t N)
     bitonic_fullmerge_n_vec<vtype, numVecs>(&vecs[0]);
 
     // Unmasked part of the store
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int i = 0; i < numVecs / 2; i++) {
         vtype::storeu(arr + i * vtype::numlanes, vecs[i]);
     }
     // Masked part of the store
-#pragma GCC unroll 64
+X86_SIMD_SORT_UNROLL_LOOP(64)
     for (int i = numVecs / 2, j = 0; i < numVecs; i++, j++) {
         vtype::mask_storeu(arr + i * vtype::numlanes, ioMasks[j], vecs[i]);
     }
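
Note: the X86_SIMD_SORT_UNROLL_LOOP(64) macro that replaces the GCC-only pragma is expected to be provided by the newly included avx512-common-qsort.h; its actual definition is not shown in this diff. Below is a minimal sketch of how such a compiler-dispatching unroll-hint macro could look. The helper name XSS_PRAGMA_IMPL and the exact pragma spellings are illustrative assumptions, not the header's real contents.

// Sketch only (assumed, not the header's actual definition): expand the unroll
// hint to whichever pragma the current compiler understands, or to nothing.
#define XSS_PRAGMA_IMPL(x) _Pragma(#x)

#if defined(__clang__)
// Clang's loop pragma; "num" is substituted before the token sequence is stringized.
#define X86_SIMD_SORT_UNROLL_LOOP(num) XSS_PRAGMA_IMPL(clang loop unroll_count(num))
#elif defined(__GNUC__)
// GCC 8+ accepts "#pragma GCC unroll N", the hint this diff is replacing.
#define X86_SIMD_SORT_UNROLL_LOOP(num) XSS_PRAGMA_IMPL(GCC unroll num)
#else
// Other compilers: the hint is only an optimization aid, so expanding to nothing is still correct.
#define X86_SIMD_SORT_UNROLL_LOOP(num)
#endif

With a definition along these lines, X86_SIMD_SORT_UNROLL_LOOP(64) expands to _Pragma("GCC unroll 64") under GCC, so the generated code stays the same while non-GCC compilers no longer see an unknown pragma.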