Skip to content

Commit

Permalink
Fixes/changes many small things
Browse files Browse the repository at this point in the history
  • Loading branch information
sterrettm2 authored and r-devulap committed Oct 20, 2023
1 parent 0e41ebc commit 15d6025
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 115 deletions.
97 changes: 19 additions & 78 deletions src/avx2-32bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,6 @@
#define NETWORK_32BIT_AVX2_3 5, 4, 7, 6, 1, 0, 3, 2
#define NETWORK_32BIT_AVX2_4 3, 2, 1, 0, 7, 6, 5, 4

namespace xss {
namespace avx2 {

// Assumes ymm is bitonic and performs a recursive half cleaner.
// Applying half-cleaners of stride 4, 2, 1 to a bitonic 8-lane vector
// yields a fully sorted vector (final merge stage of a bitonic sorter).
template <typename vtype, typename reg_t = typename vtype::reg_t>
X86_SIMD_SORT_INLINE reg_t bitonic_merge_ymm_32bit(reg_t ymm)
{

// Lane-select masks emulating AVX-512 k-masks with full-width vectors
// (_mm256_set_epi32 lists elements high-to-low):
// oxAA -> odd lanes {1,3,5,7}; oxCC -> lanes {2,3,6,7}; oxF0 -> upper half {4..7}.
const typename vtype::opmask_t oxAA = _mm256_set_epi32(
0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0, 0xFFFFFFFF, 0);
const typename vtype::opmask_t oxCC = _mm256_set_epi32(
0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0);
const typename vtype::opmask_t oxF0 = _mm256_set_epi32(
0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0, 0);

// 1) half_cleaner[8]: compare 0-4, 1-5, 2-6, 3-7 (stride-4 exchange)
ymm = cmp_merge<vtype>(
ymm,
vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_4), ymm),
oxF0);
// 2) half_cleaner[4]: stride-2 exchange within each half
ymm = cmp_merge<vtype>(
ymm,
vtype::permutexvar(_mm256_set_epi32(NETWORK_32BIT_AVX2_3), ymm),
oxCC);
// 3) half_cleaner[1]: stride-1 exchange of adjacent lanes
ymm = cmp_merge<vtype>(
ymm, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(ymm), oxAA);
return ymm;
}

/*
* Assumes ymm is random and performs a full sorting network defined in
* https://en.wikipedia.org/wiki/Bitonic_sorter#/media/File:BitonicSort.svg
Expand Down Expand Up @@ -85,7 +54,7 @@ X86_SIMD_SORT_INLINE reg_t sort_ymm_32bit(reg_t ymm)
struct avx2_32bit_swizzle_ops;

template <>
struct ymm_vector<int32_t> {
struct avx2_vector<int32_t> {
using type_t = int32_t;
using reg_t = __m256i;
using ymmi_t = __m256i;
Expand Down Expand Up @@ -231,13 +200,9 @@ struct ymm_vector<int32_t> {
{
_mm256_storeu_si256((__m256i *)mem, x);
}
static reg_t bitonic_merge(reg_t v)
{
    // Delegate to the shared 8-lane bitonic-merge network for 32-bit keys.
    return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(v);
}
static reg_t sort_vec(reg_t x)
{
return sort_ymm_32bit<ymm_vector<type_t>>(x);
return sort_ymm_32bit<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v){
return v;
Expand All @@ -247,7 +212,7 @@ struct ymm_vector<int32_t> {
}
};
template <>
struct ymm_vector<uint32_t> {
struct avx2_vector<uint32_t> {
using type_t = uint32_t;
using reg_t = __m256i;
using ymmi_t = __m256i;
Expand Down Expand Up @@ -378,13 +343,9 @@ struct ymm_vector<uint32_t> {
{
_mm256_storeu_si256((__m256i *)mem, x);
}
static reg_t bitonic_merge(reg_t v)
{
    // Forward to the common 32-bit bitonic-merge network.
    return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(v);
}
static reg_t sort_vec(reg_t x)
{
return sort_ymm_32bit<ymm_vector<type_t>>(x);
return sort_ymm_32bit<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v){
return v;
Expand All @@ -394,7 +355,7 @@ struct ymm_vector<uint32_t> {
}
};
template <>
struct ymm_vector<float> {
struct avx2_vector<float> {
using type_t = float;
using reg_t = __m256;
using ymmi_t = __m256i;
Expand Down Expand Up @@ -440,6 +401,19 @@ struct ymm_vector<float> {
{
return _mm256_castps_si256(_mm256_cmp_ps(x, y, _CMP_EQ_OQ));
}
static opmask_t get_partial_loadmask(int size)
{
    // Bitmask with the lowest `size` bits set: enables lanes [0, size)
    // for a partial (tail) vector load.
    return (1 << size) - 1;
}
template <int type>
static opmask_t fpclass(reg_t x)
{
    // Only the QNaN | SNaN category is supported (0x01 | 0x80, matching
    // the AVX-512 fpclass imm8 encoding); any other request is a bug.
    static_assert(type == (0x01 | 0x80), "should not reach here");
    // An unordered self-compare (x != x) is true exactly for NaN lanes.
    return _mm256_castps_si256(_mm256_cmp_ps(x, x, _CMP_UNORD_Q));
}
template <int scale>
static reg_t
mask_i64gather(reg_t src, opmask_t mask, __m256i index, void const *base)
Expand Down Expand Up @@ -533,13 +507,9 @@ struct ymm_vector<float> {
{
_mm256_storeu_ps((float *)mem, x);
}
static reg_t bitonic_merge(reg_t v)
{
    // Reuse the shared 32-bit bitonic-merge network for float lanes.
    return bitonic_merge_ymm_32bit<ymm_vector<type_t>>(v);
}
static reg_t sort_vec(reg_t x)
{
return sort_ymm_32bit<ymm_vector<type_t>>(x);
return sort_ymm_32bit<avx2_vector<type_t>>(x);
}
static reg_t cast_from(__m256i v){
return _mm256_castsi256_ps(v);
Expand All @@ -549,32 +519,6 @@ struct ymm_vector<float> {
}
};

// Overwrites every NaN in arr[0, arrsize) with the YMM_MAX_FLOAT sentinel so
// the ordered-comparison sort networks can handle the data, and returns the
// number of lanes replaced (so replace_inf_with_nan() can restore them).
// NOTE(review): the name says "inf" but the stored sentinel is
// YMM_MAX_FLOAT — presumably a max-float broadcast; confirm against its
// definition.  Also, arrsize is int64_t here while the return type is
// arrsize_t — TODO confirm the intended width.
inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize)
{
arrsize_t nan_count = 0;
// Full 8-lane load mask; narrowed below for the final partial chunk.
__mmask8 loadmask = 0xFF;
while (arrsize > 0) {
if (arrsize < 8) { loadmask = (0x01 << arrsize) - 0x01; }
// Masked-off tail lanes are zero-filled, so they never compare as NaN.
__m256 in_ymm = ymm_vector<float>::maskz_loadu(loadmask, arr);
// x != x under an unordered compare is true only for NaN lanes.
__m256i nanmask = _mm256_castps_si256(
_mm256_cmp_ps(in_ymm, in_ymm, _CMP_NEQ_UQ));
// Convert the vector mask to a bitmask and count the NaN lanes.
nan_count += _mm_popcnt_u32(avx2_mask_helper32(nanmask));
// Store the sentinel into NaN lanes only; other lanes are untouched.
ymm_vector<float>::mask_storeu(arr, nanmask, YMM_MAX_FLOAT);
arr += 8;
arrsize -= 8;
}
return nan_count;
}

X86_SIMD_SORT_INLINE void
replace_inf_with_nan(float *arr, arrsize_t arrsize, arrsize_t nan_count)
{
    // Undo replace_nan_with_inf(): after sorting, the sentinel values sit at
    // the end of the array, so rewrite the last nan_count slots (walking
    // backwards from the end) with quiet NaNs.
    arrsize_t pos = arrsize;
    while (nan_count > 0) {
        --pos;
        arr[pos] = std::nan("1");
        --nan_count;
    }
}

struct avx2_32bit_swizzle_ops{
template <typename vtype, int scale>
X86_SIMD_SORT_INLINE typename vtype::reg_t swap_n(typename vtype::reg_t reg){
Expand Down Expand Up @@ -635,7 +579,4 @@ struct avx2_32bit_swizzle_ops{
return vtype::cast_from(v1);
}
};

} // namespace avx2
} // namespace xss
#endif
37 changes: 16 additions & 21 deletions src/avx2-emu-funcs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
#include <utility>
#include "xss-common-qsort.h"

namespace xss {
namespace avx2 {

constexpr auto avx2_mask_helper_lut32 = [] {
std::array<std::array<int32_t, 8>, 256> lut {};
for (int64_t i = 0; i <= 0xFF; i++) {
Expand Down Expand Up @@ -97,9 +94,9 @@ static __m256i operator~(const avx2_mask_helper32 x)

// Emulators for intrinsics missing from AVX2 compared to AVX512
template <typename T>
T avx2_emu_reduce_max32(typename ymm_vector<T>::reg_t x)
T avx2_emu_reduce_max32(typename avx2_vector<T>::reg_t x)
{
using vtype = ymm_vector<T>;
using vtype = avx2_vector<T>;
using reg_t = typename vtype::reg_t;

reg_t inter1 = vtype::max(x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
Expand All @@ -110,9 +107,9 @@ T avx2_emu_reduce_max32(typename ymm_vector<T>::reg_t x)
}

template <typename T>
T avx2_emu_reduce_min32(typename ymm_vector<T>::reg_t x)
T avx2_emu_reduce_min32(typename avx2_vector<T>::reg_t x)
{
using vtype = ymm_vector<T>;
using vtype = avx2_vector<T>;
using reg_t = typename vtype::reg_t;

reg_t inter1 = vtype::min(x, vtype::template shuffle<SHUFFLE_MASK(2, 3, 0, 1)>(x));
Expand All @@ -124,10 +121,10 @@ T avx2_emu_reduce_min32(typename ymm_vector<T>::reg_t x)

template <typename T>
void avx2_emu_mask_compressstoreu(void *base_addr,
typename ymm_vector<T>::opmask_t k,
typename ymm_vector<T>::reg_t reg)
typename avx2_vector<T>::opmask_t k,
typename avx2_vector<T>::reg_t reg)
{
using vtype = ymm_vector<T>;
using vtype = avx2_vector<T>;

T *leftStore = (T *)base_addr;

Expand All @@ -145,10 +142,10 @@ void avx2_emu_mask_compressstoreu(void *base_addr,
template <typename T>
int32_t avx2_double_compressstore32(void *left_addr,
void *right_addr,
typename ymm_vector<T>::opmask_t k,
typename ymm_vector<T>::reg_t reg)
typename avx2_vector<T>::opmask_t k,
typename avx2_vector<T>::reg_t reg)
{
using vtype = ymm_vector<T>;
using vtype = avx2_vector<T>;

T *leftStore = (T *)left_addr;
T *rightStore = (T *)right_addr;
Expand All @@ -168,27 +165,25 @@ int32_t avx2_double_compressstore32(void *left_addr,
}

template <typename T>
typename ymm_vector<T>::reg_t avx2_emu_max(typename ymm_vector<T>::reg_t x,
typename ymm_vector<T>::reg_t y)
typename avx2_vector<T>::reg_t avx2_emu_max(typename avx2_vector<T>::reg_t x,
typename avx2_vector<T>::reg_t y)
{
using vtype = ymm_vector<T>;
using vtype = avx2_vector<T>;
typename vtype::opmask_t nlt = vtype::ge(x, y);
return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(y),
_mm256_castsi256_pd(x),
_mm256_castsi256_pd(nlt)));
}

template <typename T>
typename ymm_vector<T>::reg_t avx2_emu_min(typename ymm_vector<T>::reg_t x,
typename ymm_vector<T>::reg_t y)
typename avx2_vector<T>::reg_t avx2_emu_min(typename avx2_vector<T>::reg_t x,
typename avx2_vector<T>::reg_t y)
{
using vtype = ymm_vector<T>;
using vtype = avx2_vector<T>;
typename vtype::opmask_t nlt = vtype::ge(x, y);
return _mm256_castpd_si256(_mm256_blendv_pd(_mm256_castsi256_pd(x),
_mm256_castsi256_pd(y),
_mm256_castsi256_pd(nlt)));
}
} // namespace avx2
} // namespace xss

#endif
1 change: 0 additions & 1 deletion src/avx512-16bit-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#define AVX512_16BIT_COMMON

#include "xss-common-qsort.h"
#include "xss-network-qsort.hpp"

/*
* Constants used in sorting 32 elements in a ZMM registers. Based on Bitonic
Expand Down
1 change: 0 additions & 1 deletion src/avx512-32bit-qsort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#define AVX512_QSORT_32BIT

#include "xss-common-qsort.h"
#include "xss-network-qsort.hpp"

/*
* Constants used in sorting 16 elements in a ZMM registers. Based on Bitonic
Expand Down
18 changes: 4 additions & 14 deletions src/xss-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,8 @@
#include "xss-pivot-selection.hpp"
#include "xss-network-qsort.hpp"

namespace xss{
namespace avx2{
template <typename type>
struct ymm_vector;

inline arrsize_t replace_nan_with_inf(float *arr, int64_t arrsize);
}
}

// key-value sort routines
template <typename T1, typename T2>
void avx512_qsort_kv(T1 *keys, T2 *indexes, int64_t arrsize);
struct avx2_vector;

template <typename T>
bool is_a_nan(T elem)
Expand Down Expand Up @@ -614,12 +604,12 @@ X86_SIMD_SORT_INLINE void avx512_qsort(T *arr, arrsize_t arrsize)
template <typename T>
void avx2_qsort(T *arr, arrsize_t arrsize)
{
using vtype = xss::avx2::ymm_vector<T>;
using vtype = avx2_vector<T>;
if (arrsize > 1) {
/* std::is_floating_point_v<_Float16> == False, unless c++-23*/
if constexpr (std::is_floating_point_v<T>) {
arrsize_t nan_count
= xss::avx2::replace_nan_with_inf(arr, arrsize);
= replace_nan_with_inf<vtype>(arr, arrsize);
qsort_<vtype, T>(
arr, 0, arrsize - 1, 2 * (int64_t)log2(arrsize));
replace_inf_with_nan(arr, arrsize, nan_count);
Expand Down Expand Up @@ -661,7 +651,7 @@ void avx2_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan = false)
}
UNUSED(hasnan);
if (indx_last_elem >= k) {
qselect_<xss::avx2::ymm_vector<T>, T>(
qselect_<avx2_vector<T>, T>(
arr, k, 0, indx_last_elem, 2 * (int64_t)log2(indx_last_elem));
}
}
Expand Down

0 comments on commit 15d6025

Please sign in to comment.