From af19c28e2354ba1d405c21b6e6f197e29f76fd61 Mon Sep 17 00:00:00 2001 From: syzygy1 <3028851+syzygy1@users.noreply.github.com> Date: Mon, 14 Sep 2020 01:22:28 +0200 Subject: [PATCH] Cfish 12 --- src/Makefile | 103 ++++---- src/evaluate.h | 2 + src/misc.c | 2 +- src/misc.h | 12 + src/nnue.c | 658 ++++++++++++++++++++++++++++++++++++------------- 5 files changed, 558 insertions(+), 219 deletions(-) diff --git a/src/Makefile b/src/Makefile index 42861333..b1f085f9 100644 --- a/src/Makefile +++ b/src/Makefile @@ -102,6 +102,8 @@ vnni = no neon = no ARCH = auto native = no +embed = no +STRIP = strip ### 2.2 Architecture specific @@ -386,6 +388,10 @@ ifeq ($(COMP),gcc) LDFLAGS += -m$(bits) endif + ifeq ($(arch),$(filter $(arch),armv7)) + LDFLAGS += -latomic + endif + ifneq ($(KERNEL),Darwin) LDFLAGS += -Wl,--no-as-needed endif @@ -393,9 +399,6 @@ ifeq ($(COMP),gcc) ifneq ($(KERNEL),$(filter $(KERNEL),Linux Darwin Haiku)) CFLAGS += -Wno-pedantic-ms-format endif - - gccversion = $(shell $(CC) --version) - gccisclang = $(findstring clang,$(gccversion)) endif ifeq ($(COMP),mingw) @@ -424,15 +427,21 @@ endif ifeq ($(COMP),icc) comp=icc CC=icc - CFLAGS += -diag-disable 1476,10120 -Wcheck -Wabi -Wdeprecated -strict-ansi + CFLAGS += -diag-disable 344,711,2259,2330 -Wcheck -Wabi -Wdeprecated -strict-ansi endif ifeq ($(COMP),clang) comp=clang CC=clang CFLAGS += -pedantic -Wextra -Wshadow -# -Wno-missing-braces -Wno-missing-field-initializers -Wno-unknown-attributes - ifeq ($(ARCH),$(filter $(ARCH),armv7 armv8)) + + ifneq ($(KERNEL),Darwin) + ifneq ($(KERNEL),OpenBSD) + LDFLAGS += -latomic + endif + endif + + ifeq ($(arch),$(filter $(arch),armv7 armv8)) ifeq ($(OS),Android) CFLAGS += -m$(bits) LDFLAGS += -m$(bits) @@ -441,11 +450,6 @@ ifeq ($(COMP),clang) CFLAGS += -m$(bits) LDFLAGS += -m$(bits) endif - - ifeq ($(KERNEL),Darwin) - CFLAGS += - DEPENDFLAGS += - endif endif ifneq ($(COMP),mingw) @@ -454,17 +458,6 @@ ifeq ($(KERNEL),$(filter $(KERNEL),Linux Darwin Haiku)) endif endif -ifeq ($(comp),icc) - profile_make = icc-profile-make - profile_use = icc-profile-use -else ifeq ($(comp),clang) - profile_make = clang-profile-make - profile_use = clang-profile-use -else - profile_make = gcc-profile-make - profile_use = gcc-profile-use -endif - ifeq ($(KERNEL),Darwin) CFLAGS += -arch $(arch) -mmacosx-version-min=10.14 LDFLAGS += -arch $(arch) -mmacosx-version-min=10.14 @@ -475,18 +468,17 @@ endif # binutils. Currently we don't know how to make PGO builds with the NDK yet. ifeq ($(COMP),ndk) CFLAGS += -fPIE + comp=clang ifeq ($(arch),armv7) - comp=armv7a-linux-androideabi16-clang CC=armv7a-linux-androideabi16-clang - CCFLAGS += -mthumb -march=armv7-a -mfloat-abi=softfp -mfpu=neon + CFLAGS += -mthumb -march=armv7-a -mfloat-abi=softfp -mfpu=neon STRIP=arm-linux-androideabi-strip endif ifeq ($(arch),armv8) - comp=aarch64-linux-android21-clang CC=aarch64-linux-android21-clang STRIP=aarch64-linux-android-strip endif - LDFLAGS += -pie -lm -latomic + LDFLAGS += -static -latomic -z muldefs endif ### Allow overwriting CC from command line @@ -494,13 +486,36 @@ ifdef COMPCC CC=$(COMPCC) endif +ifeq ($(comp),icc) + profile_make = icc-profile-make + profile_use = icc-profile-use +else ifeq ($(comp),clang) + profile_make = clang-profile-make + profile_use = clang-profile-use +else + profile_make = gcc-profile-make + profile_use = gcc-profile-use +endif + +### Sometimes gcc is really clang +ifeq ($(COMP),gcc) + gccversion = $(shell $(CC) --version) + gccisclang = $(findstring clang,$(gccversion)) + ifneq ($(gccisclang),) + profile_make = clang-profile-make + profile_use = clang-profile-use + endif +endif + ### On mingw use Windows threads, otherwise POSIX ifneq ($(comp),mingw) # On Android Bionic's C library comes with its own pthread implementation bundled in ifneq ($(OS),Android) # Haiku has pthreads in its libroot, so only link it in on other platforms ifneq ($(KERNEL),Haiku) - LDFLAGS += -lpthread + ifneq ($(COMP),ndk) + LDFLAGS += -lpthread + endif endif endif endif @@ -544,19 +559,14 @@ ifeq ($(bits),64) CFLAGS += -DIS_64BIT endif -### 3.5 prefetch -ifeq ($(prefetch),yes) - ifeq ($(sse),yes) - CFLAGS += -msse - DEPENDFLAGS += -msse - endif -else +### 3.5 prefetch and sse +ifeq ($(prefetch),no) CFLAGS += -DNO_PREFETCH endif ### 3.6 popcnt ifeq ($(popcnt),yes) - ifeq ($(arch),$(filter $(arch),ppc64 armv8)) + ifeq ($(arch),$(filter $(arch),ppc64 armv7 armv8 arm64)) CFLAGS += -DUSE_POPCNT else ifeq ($(comp),icc) @@ -618,6 +628,10 @@ ifeq ($(sse2),yes) endif endif +ifeq ($(sse),yes) + CFLAGS += -msse -DUSE_SSE +endif + ifeq ($(mmx),yes) CFLAGS += -DUSE_MMX ifeq ($(comp),$(filter $(comp),gcc clang mingw)) @@ -660,6 +674,9 @@ endif ifeq ($(nnue),yes) CFLAGS += -DNNUE OBJS += nnue.o + ifeq ($(embed),yes) + CFLAGS += -DNNUE_EMBEDDED + endif endif ### 3.9 Link Time Optimization @@ -668,10 +685,7 @@ endif ifeq ($(lto),yes) ifeq ($(optimize),yes) ifeq ($(debug),no) - ifeq ($(COMP),ndk) - CFLAGS += -fltho=thin - LDFLAGS += $(CFLAGS) - else ifeq ($(comp),clang) + ifeq ($(comp),clang) CFLAGS += -flto=thin ifneq ($(findstring MINGW,$(KERNEL)),) CFLAGS += -fuse-ld=lld @@ -711,7 +725,7 @@ endif ### breaks Android 4.0 and earlier. ifeq ($(arch),armv7) CFLAGS += -fPIE - LDFLAGS += -fPIE -pie + LDFLAGS += -fPIE endif @@ -813,7 +827,7 @@ profile-build: net config-sanity objclean profileclean pgo: profile-build strip: - strip $(EXE) + $(STRIP) $(EXE) install: -mkdir -p -m 755 $(BINDIR) @@ -824,7 +838,7 @@ clean: objclean profileclean @rm -f .depend core net: - $(eval nnuenet := $(shell grep EvalFile ucioption.c | grep OPT_TYPE_STRING | sed 's/.*\(nn-[a-z0-9]\{12\}.nnue\).*/\1/')) + $(eval nnuenet := $(shell grep DefaultEvalFile evaluate.h | sed 's/.*\(nn-[a-z0-9]\{12\}.nnue\).*/\1/')) @echo "Default net: $(nnuenet)" $(eval nnuedownloadurl := https://tests.stockfishchess.org/api/nn/$(nnuenet)) $(eval curl_or_wget := $(shell if hash curl 2>/dev/null; then echo "curl -skL"; elif hash wget 2>/dev/null; then echo "wget -qO-"; fi)) @@ -874,6 +888,7 @@ config-sanity: @echo "vnni: '$(vnni)'" @echo "neon: '$(neon)'" @echo "native: '$(native)'" + @echo "embed: '$(embed)'" @echo "" @echo "Flags:" @echo "CC: $(CC)" @@ -902,7 +917,9 @@ config-sanity: @test "$(vnni)" = "yes" || test "$(vnni)" = "no" @test "$(neon)" = "yes" || test "$(neon)" = "no" @test "$(native)" = "yes" || test "$(native)" = "no" - @test "$(comp)" = "gcc" || test "$(comp)" = "icc" || test "$(comp)" = "mingw" || test "$(comp)" = "clang" + @test "$(embed)" = "yes" || test "$(embed)" = "no" + @test "$(comp)" = "gcc" || test "$(comp)" = "icc" || test "$(comp)" = "mingw" || test "$(comp)" = "clang" \ + || test "$(comp)" = "armv7a-linux-androideabi16-clang" || test "$(comp)" = "aarch64-linux-android21-clang" $(EXE): $(OBJS) $(CC) -o $@ $(OBJS) $(LDFLAGS) diff --git a/src/evaluate.h b/src/evaluate.h index 2e5a97a4..aebb7966 100644 --- a/src/evaluate.h +++ b/src/evaluate.h @@ -3,6 +3,8 @@ #include "types.h" +#define DefaultEvalFile "nn-308d71810dff.nnue" + enum { Tempo = 28 }; #ifdef NNUE diff --git a/src/misc.c b/src/misc.c index 10a56026..ad469b64 100644 --- a/src/misc.c +++ b/src/misc.c @@ -34,7 +34,7 @@ // Version number. If Version is left empty, then compile date in the format // DD-MM-YY and show in engine_info. -char Version[] = ""; +char Version[] = "12"; #ifndef _WIN32 pthread_mutex_t ioMutex = PTHREAD_MUTEX_INITIALIZER; diff --git a/src/misc.h b/src/misc.h index 0eb41de8..5fe7c326 100644 --- a/src/misc.h +++ b/src/misc.h @@ -167,4 +167,16 @@ INLINE uint16_t read_le_u16(void *p) return from_le_u16(*(uint16_t *)p); } +INLINE uint32_t readu_le_u32(const void *p) +{ + const uint8_t *q = p; + return q[0] | (q[1] << 8) | (q[2] << 16) | (q[3] << 24); +} + +INLINE uint16_t readu_le_u16(const void *p) +{ + const uint8_t *q = p; + return q[0] | (q[1] << 8); +} + #endif diff --git a/src/nnue.c b/src/nnue.c index 4a31c5e6..01c9d384 100644 --- a/src/nnue.c +++ b/src/nnue.c @@ -17,6 +17,9 @@ #elif defined(USE_SSE2) #include +#elif defined(USE_SSE) +#include + #elif defined(USE_MMX) #include @@ -30,6 +33,11 @@ #include "position.h" #include "uci.h" +#ifdef NNUE_EMBEDDED +#include "incbin.h" +INCBIN(Network, DefaultEvalFile); +#endif + // Old gcc on Windows is unable to provide a 32-byte aligned stack. // We need to hack around this when using AVX2 and AVX512. #if defined(__GNUC__ ) && (__GNUC__ < 9) && defined(_WIN32) \ @@ -69,7 +77,6 @@ enum { }; enum { - kMaxActiveDimensions = 30, kHalfDimensions = 256, FtInDims = 64 * PS_END, // 64 * 641 FtOutDims = kHalfDimensions * 2 @@ -80,17 +87,31 @@ enum { #undef USE_MMX #endif -// For certain architectures we transpose the weights matrix and make use -// of the sparseness of the vectors. Only SSE2 for now. -#if defined(USE_SSE2) // && !defined(USE_AVX2) +#if defined(USE_AVX512) +//#define TRANSPOSE + +#elif defined(USE_AVX2) #define TRANSPOSE -#define USE_MASK -#endif -#if !defined(USE_MMX) && !defined(USE_SSE2) && !defined(USE_NEON) +#elif defined(USE_SSE2) +#define TRANSPOSE + +#elif defined(USE_MMX) +//#define TRANSPOSE + +#elif defined(USE_NEON) +//#define TRANSPOSE + +#else /* fallback code */ #define TRANSPOSE #endif +#ifdef TRANSPOSE +#if defined(USE_SSE) +#define USE_MASK +#endif +#endif + static_assert(kHalfDimensions % 256 == 0, "kHalfDimensions should be a multiple of 256"); #ifdef USE_AVX512 @@ -118,11 +139,15 @@ typedef __m64 vec_t; #define vec_sub_16(a,b) _mm_sub_pi16(a,b) #elif USE_NEON +#define SIMD_WIDTH 128 typedef int8x8_t vec_t; // unused +#else +#define SIMD_WIDTH 8 // dummy + #endif -// NUM_REGS is used only in transform() +// NUM_REGS is used only in refresh/update_accumulator() #if defined(USE_AVX512) #define NUM_REGS 8 // only 8 are needed @@ -140,44 +165,41 @@ typedef int8x8_t vec_t; // unused #ifndef TRANSPOSE #if defined(USE_MMX) || (defined(USE_SSE2) && !defined(USE_SSSE3)) -typedef int16_t clipped_t; //SSE2 and MMX have no int8 multiply. +typedef int16_t clipped_t; // SSE2 and MMX have no int8 multiply. typedef int16_t weight_t; #else -typedef uint8_t clipped_t; +typedef int8_t clipped_t; typedef int8_t weight_t; #endif +typedef uint8_t mask_t; // dummy + #else /* TRANSPOSE */ -typedef uint8_t clipped_t; -#if defined(USE_MMX) || (defined(USE_SSE2) && !defined(USE_SSSE3)) +typedef int8_t clipped_t; +#if defined(USE_MMX) || (defined(USE_SSE2) && !defined(USE_AVX2)) typedef int16_t weight_t; #else typedef int8_t weight_t; #endif -#if defined(USE_AVX2) +#if defined(USE_AVX512) +typedef __mmask64 mask_t; +#elif defined(USE_AVX2) typedef uint32_t mask_t; -#else +#elif defined(USE_SSE2) typedef uint16_t mask_t; +#elif defined(USE_MMX) +typedef uint8_t mask_t; +#else +typedef uint8_t mask_t; // dummy #endif #endif -#define LOOP_4(f) f(0);f(1);f(2);f(3) -#define LOOP_8(f) LOOP_4(f); f(4);f(5);f(6);f(7) -#define LOOP_16(f) LOOP_8(f); f(8);f(9);f(10);f(11);f(12);f(13);f(14);f(15) - -static uint32_t read_uint32_t(FILE *F) -{ - uint32_t v; - fread(&v, 4, 1, F); - return from_le_u32(v); -} - typedef struct { size_t size; - unsigned values[kMaxActiveDimensions]; + unsigned values[30]; } IndexList; INLINE Square orient(Color c, Square s) @@ -255,7 +277,7 @@ static alignas(64) weight_t output_weights [1 * 32]; static alignas(64) int32_t hidden1_biases[32]; static alignas(64) int32_t hidden2_biases[32]; -static int32_t output_biases [1]; +static int32_t output_biases[1]; INLINE void affine_propagate(clipped_t *input, int32_t *output, unsigned inDims, unsigned outDims, int32_t *biases, weight_t *weights) @@ -263,34 +285,34 @@ INLINE void affine_propagate(clipped_t *input, int32_t *output, unsigned inDims, assert(inDims % 32 == 0); #if defined(USE_AVX512) - const unsigned numChunks = inDims / 64; + const unsigned numChunks = (inDims * 8) / SIMD_WIDTH; + __m512i *inVec = (__m512i *)input; #if !defined(USE_VNNI) const __m512i kOnes = _mm512_set1_epi16(1); #endif - __m512i *inVec = (__m512i *)input; #elif defined(USE_AVX2) - const unsigned numChunks = inDims / 32; + const unsigned numChunks = (inDims * 8) / SIMD_WIDTH; __m256i *inVec = (__m256i *)input; #if !defined(USE_VNNI) const __m256i kOnes = _mm256_set1_epi16(1); #endif -#elif defined(USE_SSSE3) - const unsigned numChunks = inDims / 32; - const __m128i kOnes = _mm_set1_epi16(1); +#elif defined(USE_SSSE3) && !defined(TRANSPOSE) + const unsigned numChunks = (inDims * 8) / SIMD_WIDTH; __m128i *inVec = (__m128i *)input; + const __m128i kOnes = _mm_set1_epi16(1); #elif defined(USE_SSE2) - const unsigned numChunks = inDims / 16; + const unsigned numChunks = (inDims * 16) / SIMD_WIDTH; __m128i *inVec = (__m128i *)input; #elif defined(USE_MMX) - const unsigned numChunks = inDims / 8; + const unsigned numChunks = (inDims * 16) / SIMD_WIDTH; __m64 *inVec = (__m64 *)input; #elif defined(USE_NEON) - const unsigned numChunks = inDims / 16; + const unsigned numChunks = (inDims * 8) / SIMD_WIDTH; int8x8_t *inVec = (int8x8_t *)input; #endif @@ -342,10 +364,10 @@ INLINE void affine_propagate(clipped_t *input, int32_t *output, unsigned inDims, sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB)); output[i] = _mm_cvtsi128_si32(sum128) + biases[i]; -#elif defined(USE_SSSE3) +#elif defined(USE_SSSE3) && !defined(TRANSPOSE) __m128i sum = _mm_setzero_si128(); __m128i *row = (__m128i *)&weights[offset]; - for (unsigned j = 0; j < numChunks; j++) { + for (unsigned j = 0; j < numChunks / 2; j++) { __m128i product0 = _mm_maddubs_epi16(inVec[2 * j], row[2 * j]); product0 = _mm_madd_epi16(product0, kOnes); sum = _mm_add_epi32(sum, product0); @@ -360,7 +382,7 @@ INLINE void affine_propagate(clipped_t *input, int32_t *output, unsigned inDims, #elif defined(USE_SSE2) __m128i sum = _mm_setzero_si128(), sum1 = sum; __m128i *row = (__m128i *)&weights[offset]; - for (unsigned j = 0; j < numChunks; j++) { + for (unsigned j = 0; j < numChunks / 2; j++) { __m128i product0 = _mm_madd_epi16(inVec[2 * j], row[2 * j]); sum = _mm_add_epi32(sum, product0); __m128i product1 = _mm_madd_epi16(inVec[2 * j + 1], row[2 * j + 1]); @@ -375,8 +397,8 @@ INLINE void affine_propagate(clipped_t *input, int32_t *output, unsigned inDims, // adding 1 or 4 numbers per loop is slower, 2 seems optimal __m64 s0 = _mm_setzero_si64(), s1 = s0; __m64 *row = (__m64 *)&weights[offset]; - for (unsigned j = 0; j < numChunks; j++) { - s0 = _mm_add_pi32(s0, _mm_madd_pi16(row[2 * j + 0], inVec[2 * j + 0])); + for (unsigned j = 0; j < numChunks / 2; j++) { + s0 = _mm_add_pi32(s0, _mm_madd_pi16(row[2 * j], inVec[2 * j])); s1 = _mm_add_pi32(s1, _mm_madd_pi16(row[2 * j + 1], inVec[2 * j + 1])); } __m64 sum = _mm_add_pi32(s0, s1); @@ -387,13 +409,14 @@ INLINE void affine_propagate(clipped_t *input, int32_t *output, unsigned inDims, int32x4_t sum = {biases[i]}; int8x8_t *row = (int8x8_t *)&weights[offset]; for (unsigned j = 0; j < numChunks; j++) { - int16x8_t product = vmull_s8(inVec[j * 2], row[j * 2]); - product = vmlal_s8(product, inVec[j * 2 + 1], row[j * 2 + 1]); + int16x8_t product = vmull_s8(inVec[2 * j], row[2 * j]); + product = vmlal_s8(product, inVec[2 * j + 1], row[2 * j + 1]); sum = vpadalq_s16(sum, product); } output[i] = sum[0] + sum[1] + sum[2] + sum[3]; #else + (void)numChunks; int32_t sum = biases[i]; for (unsigned j = 0; j < inDims; j++) sum += weights[offset + j] * input[j]; @@ -407,9 +430,20 @@ INLINE void affine_propagate(clipped_t *input, int32_t *output, unsigned inDims, INLINE void clip_propagate(int32_t *input, clipped_t *output, unsigned numDims) { - assert(numDims % 32 == 0); + assert(numDims == 32); -#if defined(USE_AVX2) +#if defined(USE_AVX512) + (void)numDims; + const __m512i kZero = _mm512_setzero_si512(); + const __m512i kOffsets = _mm512_set_epi32(0,0,0,0,0,0,0,0,13,9,5,1,12,8,4,0); + __m512i *in = (__m512i *)input; + __m256i *out = (__m256i *)output; + __m512i words = _mm512_srai_epi16(_mm512_packs_epi32(in[0], in[1]), SHIFT); + out[0] = _mm256_max_epi8(_mm512_castsi512_si256( + _mm512_permutexvar_epi32(kOffsets, _mm512_packs_epi16(words, kZero))), + _mm256_setzero_si256()); + +#elif defined(USE_AVX2) const unsigned numChunks = numDims / 32; const __m256i kZero = _mm256_setzero_si256(); const __m256i kOffsets = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0); @@ -498,82 +532,90 @@ INLINE void clip_propagate(int32_t *input, clipped_t *output, unsigned numDims) static_assert(FtOutDims % 64 == 0, "FtOutDims not a multiple of 64"); INLINE bool next_idx(unsigned *idx, unsigned *offset, uint64_t *v, - uint64_t *mask, unsigned inDims) + mask_t *mask, unsigned inDims) { while (*v == 0) { *offset += 64; if (*offset >= inDims) return false; - *v = mask[*offset / 64]; + memcpy(v, (char *)mask + (*offset / 8), 8); } *idx = *offset + __builtin_ctzll(*v); *v &= *v - 1; return true; } -#ifdef USE_AVX2 -INLINE void affine_txfm(uint8_t *input, void *output, unsigned inDims, +#if defined(USE_AVX2) +INLINE void affine_txfm(int8_t *input, void *output, unsigned inDims, unsigned outDims, const int32_t *biases, weight_t *weights, - uint64_t *inMask, mask_t *outMask, - const bool pack8_and_calc_mask) + mask_t *inMask, mask_t *outMask, const bool pack8_and_calc_mask) { assert(outDims == 32); (void)outDims; const __m256i kZero = _mm256_setzero_si256(); -#define TMP(j) __m256i out_##j = ((__m256i *)biases)[j]; - LOOP_4(TMP); -#undef TMP + __m256i out_0 = ((__m256i *)biases)[0]; + __m256i out_1 = ((__m256i *)biases)[1]; + __m256i out_2 = ((__m256i *)biases)[2]; + __m256i out_3 = ((__m256i *)biases)[3]; __m256i first, second; - uint64_t v = inMask[0]; + uint64_t v; unsigned idx; + memcpy(&v, inMask, 8); for (unsigned offset = 0; offset < inDims;) { if (!next_idx(&idx, &offset, &v, inMask, inDims)) break; first = ((__m256i *)weights)[idx]; - uint16_t factor = input[idx]; + uint16_t factor = (uint8_t)input[idx]; if (next_idx(&idx, &offset, &v, inMask, inDims)) { second = ((__m256i *)weights)[idx]; factor |= input[idx] << 8; } else { second = kZero; } - __m256i mul = _mm256_set1_epi16(factor), prod; + __m256i mul = _mm256_set1_epi16(factor), prod, signs; prod = _mm256_maddubs_epi16(mul, _mm256_unpacklo_epi8(first, second)); - out_0 = _mm256_add_epi32(out_0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(prod))); - out_1 = _mm256_add_epi32(out_1, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(_mm256_permute4x64_epi64(prod, 0xE)))); + signs = _mm256_cmpgt_epi16(kZero, prod); + out_0 = _mm256_add_epi32(out_0, _mm256_unpacklo_epi16(prod, signs)); + out_1 = _mm256_add_epi32(out_1, _mm256_unpackhi_epi16(prod, signs)); prod = _mm256_maddubs_epi16(mul, _mm256_unpackhi_epi8(first, second)); - out_2 = _mm256_add_epi32(out_2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(prod))); - out_3 = _mm256_add_epi32(out_3, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(_mm256_permute4x64_epi64(prod, 0xE)))); + signs = _mm256_cmpgt_epi16(kZero, prod); + out_2 = _mm256_add_epi32(out_2, _mm256_unpacklo_epi16(prod, signs)); + out_3 = _mm256_add_epi32(out_3, _mm256_unpackhi_epi16(prod, signs)); } - __m256i out_in16_0 = _mm256_srai_epi16(_mm256_packs_epi32(out_0, out_1), SHIFT); - __m256i out_in16_1 = _mm256_srai_epi16(_mm256_packs_epi32(out_2, out_3), SHIFT); + __m256i out16_0 = _mm256_srai_epi16(_mm256_packs_epi32(out_0, out_1), SHIFT); + __m256i out16_1 = _mm256_srai_epi16(_mm256_packs_epi32(out_2, out_3), SHIFT); __m256i *outVec = (__m256i *)output; if (pack8_and_calc_mask) { - outVec[0] = _mm256_packs_epi16(out_in16_0, out_in16_1); + outVec[0] = _mm256_packs_epi16(out16_0, out16_1); outMask[0] = _mm256_movemask_epi8(_mm256_cmpgt_epi8(outVec[0], kZero)); } else { - outVec[0] = _mm256_max_epi8(_mm256_packs_epi16(out_in16_0, out_in16_1), kZero); + outVec[0] = _mm256_max_epi8(_mm256_packs_epi16(out16_0, out16_1), kZero); } } -#elif USE_SSSE3 -INLINE void affine_txfm(uint8_t *input, void *output, unsigned inDims, +#elif AVOID_USE_SSSE3 +INLINE void affine_txfm(int8_t *input, void *output, unsigned inDims, unsigned outDims, const int32_t *biases, weight_t *weights, - uint64_t *inMask, mask_t *outMask, - const bool pack8_and_calc_mask) + mask_t *inMask, mask_t *outMask, const bool pack8_and_calc_mask) { assert(outDims == 32); const __m128i kZeros[2] = { 0 }; -#define TMP(j) __m128i out_##j = ((__m128i *)biases)[j]; - LOOP_8(TMP); -#undef TMP + __m128i out_0 = ((__m128i *)biases)[0]; + __m128i out_1 = ((__m128i *)biases)[1]; + __m128i out_2 = ((__m128i *)biases)[2]; + __m128i out_3 = ((__m128i *)biases)[3]; + __m128i out_4 = ((__m128i *)biases)[4]; + __m128i out_5 = ((__m128i *)biases)[5]; + __m128i out_6 = ((__m128i *)biases)[6]; + __m128i out_7 = ((__m128i *)biases)[7]; const __m128i *first, *second; - uint64_t v = inMask[0]; + uint64_t v; unsigned idx; + memcpy(&v, inMask, 8); for (unsigned offset = 0; offset < inDims;) { if (!next_idx(&idx, &offset, &v, inMask, inDims)) break; @@ -616,53 +658,57 @@ INLINE void affine_txfm(uint8_t *input, void *output, unsigned inDims, #endif } - __m128i out_in16_0 = _mm_srai_epi16(_mm_packs_epi32(out_0, out_1), SHIFT); - __m128i out_in16_1 = _mm_srai_epi16(_mm_packs_epi32(out_2, out_3), SHIFT); - __m128i out_in16_2 = _mm_srai_epi16(_mm_packs_epi32(out_4, out_5), SHIFT); - __m128i out_in16_3 = _mm_srai_epi16(_mm_packs_epi32(out_6, out_7), SHIFT); + __m128i out16_0 = _mm_srai_epi16(_mm_packs_epi32(out_0, out_1), SHIFT); + __m128i out16_1 = _mm_srai_epi16(_mm_packs_epi32(out_2, out_3), SHIFT); + __m128i out16_2 = _mm_srai_epi16(_mm_packs_epi32(out_4, out_5), SHIFT); + __m128i out16_3 = _mm_srai_epi16(_mm_packs_epi32(out_6, out_7), SHIFT); __m128i *outVec = (__m128i *)output; if (pack8_and_calc_mask) { - outVec[0] = _mm_packs_epi16(out_in16_0, out_in16_1); + outVec[0] = _mm_packs_epi16(out16_0, out16_1); outMask[0] = _mm_movemask_epi8(_mm_cmpgt_epi8(outVec[0], kZeros[0])); - outVec[1] = _mm_packs_epi16(out_in16_2, out_in16_3); + outVec[1] = _mm_packs_epi16(out16_2, out16_3); outMask[1] = _mm_movemask_epi8(_mm_cmpgt_epi8(outVec[1], kZeros[0])); } else { #if defined(USE_SSE41) - outVec[0] = _mm_max_epi8(_mm_packs_epi16(out_in16_0, out_in16_1), kZeros[0]); - outVec[1] = _mm_max_epi8(_mm_packs_epi16(out_in16_2, out_in16_3), kZeros[0]); + outVec[0] = _mm_max_epi8(_mm_packs_epi16(out16_0, out16_1), kZeros[0]); + outVec[1] = _mm_max_epi8(_mm_packs_epi16(out16_2, out16_3), kZeros[0]); #else const __m128i k0x80s = _mm_set1_epi8(-128); - outVec[0] = _mm_subs_epi8(_mm_adds_epi8(_mm_packs_epi16(out_in16_0, out_in16_1), k0x80s), k0x80s); - outVec[1] = _mm_subs_epi8(_mm_adds_epi8(_mm_packs_epi16(out_in16_2, out_in16_3), k0x80s), k0x80s); + outVec[0] = _mm_subs_epi8(_mm_adds_epi8(_mm_packs_epi16(out16_0, out16_1), k0x80s), k0x80s); + outVec[1] = _mm_subs_epi8(_mm_adds_epi8(_mm_packs_epi16(out16_2, out16_3), k0x80s), k0x80s); #endif } } #elif defined(USE_SSE2) INLINE void affine_txfm(clipped_t *input, void *output, unsigned inDims, unsigned outDims, const int32_t *biases, weight_t *weights, - uint64_t *inMask, mask_t *outMask, - const bool pack8_and_calc_mask) + mask_t *inMask, mask_t *outMask, const bool pack8_and_calc_mask) { assert(outDims == 32); const __m128i kZeros[4] = { 0 }; - __m128i *inVec = (__m128i *)input; -#define TMP(j) __m128i out_##j = ((__m128i *)biases)[j] - LOOP_8(TMP); -#undef TMP + __m128i out_0 = ((__m128i *)biases)[0]; + __m128i out_1 = ((__m128i *)biases)[1]; + __m128i out_2 = ((__m128i *)biases)[2]; + __m128i out_3 = ((__m128i *)biases)[3]; + __m128i out_4 = ((__m128i *)biases)[4]; + __m128i out_5 = ((__m128i *)biases)[5]; + __m128i out_6 = ((__m128i *)biases)[6]; + __m128i out_7 = ((__m128i *)biases)[7]; const __m128i *first, *second; - uint64_t v = inMask[0]; + uint64_t v; unsigned idx; + memcpy(&v, inMask, 8); for (unsigned offset = 0; offset < inDims;) { if (!next_idx(&idx, &offset, &v, inMask, inDims)) break; first = (__m128i *)&weights[outDims * idx]; - uint32_t factor = ((uint8_t *)inVec)[idx]; + uint32_t factor = input[idx]; if (next_idx(&idx, &offset, &v, inMask, inDims)) { second = (__m128i *)&weights[outDims * idx]; - factor |= ((uint8_t *)inVec)[idx] << 16; + factor |= input[idx] << 16; } else { second = kZeros; } @@ -677,30 +723,215 @@ INLINE void affine_txfm(clipped_t *input, void *output, unsigned inDims, out_7 = _mm_add_epi32(out_7, _mm_madd_epi16(mul, _mm_unpackhi_epi16(first[3],second[3]))); } - __m128i out_in16_0 = _mm_srai_epi16(_mm_packs_epi32(out_0, out_1), SHIFT); - __m128i out_in16_1 = _mm_srai_epi16(_mm_packs_epi32(out_2, out_3), SHIFT); - __m128i out_in16_2 = _mm_srai_epi16(_mm_packs_epi32(out_4, out_5), SHIFT); - __m128i out_in16_3 = _mm_srai_epi16(_mm_packs_epi32(out_6, out_7), SHIFT); + __m128i out16_0 = _mm_srai_epi16(_mm_packs_epi32(out_0, out_1), SHIFT); + __m128i out16_1 = _mm_srai_epi16(_mm_packs_epi32(out_2, out_3), SHIFT); + __m128i out16_2 = _mm_srai_epi16(_mm_packs_epi32(out_4, out_5), SHIFT); + __m128i out16_3 = _mm_srai_epi16(_mm_packs_epi32(out_6, out_7), SHIFT); __m128i *outVec = (__m128i *)output; if (pack8_and_calc_mask) { - outVec[0] = _mm_packs_epi16(out_in16_0, out_in16_1); + outVec[0] = _mm_packs_epi16(out16_0, out16_1); outMask[0] = _mm_movemask_epi8(_mm_cmpgt_epi8(outVec[0], kZeros[0])); - outVec[1] = _mm_packs_epi16(out_in16_2, out_in16_3); + outVec[1] = _mm_packs_epi16(out16_2, out16_3); outMask[1] = _mm_movemask_epi8(_mm_cmpgt_epi8(outVec[1], kZeros[0])); } else { +#if defined(USE_SSE41) + const __m128i kx07f = _mm_set1_epi16(127); + outVec[0] = _mm_min_epi16(_mm_max_epi16(out16_0, kZeros[0]), kx07f); + outVec[1] = _mm_min_epi16(_mm_max_epi16(out16_1, kZeros[0]), kx07f); + outVec[2] = _mm_min_epi16(_mm_max_epi16(out16_2, kZeros[0]), kx07f); + outVec[3] = _mm_min_epi16(_mm_max_epi16(out16_3, kZeros[0]), kx07f); +#else const __m128i k0x7f80 = _mm_set1_epi16(0x7f80); const __m128i k0x0080 = _mm_set1_epi16(0x0080); const __m128i k0x8000 = _mm_set1_epi16(-0x8000); -#define TMP(j) outVec[j] = _mm_subs_epu16(_mm_add_epi16(_mm_adds_epi16(out_in16_##j, k0x7f80), k0x0080), k0x8000); - LOOP_4(TMP); -#undef TMP + outVec[0] = _mm_subs_epu16(_mm_add_epi16(_mm_adds_epi16(out16_0, k0x7f80), k0x0080), k0x8000); + outVec[1] = _mm_subs_epu16(_mm_add_epi16(_mm_adds_epi16(out16_1, k0x7f80), k0x0080), k0x8000); + outVec[2] = _mm_subs_epu16(_mm_add_epi16(_mm_adds_epi16(out16_2, k0x7f80), k0x0080), k0x8000); + outVec[3] = _mm_subs_epu16(_mm_add_epi16(_mm_adds_epi16(out16_3, k0x7f80), k0x0080), k0x8000); +#endif + } +} +#elif defined(USE_MMX) && defined(USE_SSE) +INLINE void affine_txfm(clipped_t *input, void *output, unsigned inDims, + unsigned outDims, const int32_t *biases, weight_t *weights, + mask_t *inMask, mask_t *outMask, const bool pack8_and_calc_mask) +{ + assert(outDims == 32); + + const __m64 kZeros[8] = { 0 }; + __m64 out_0 = ((__m64 *)biases)[0]; + __m64 out_1 = ((__m64 *)biases)[1]; + __m64 out_2 = ((__m64 *)biases)[2]; + __m64 out_3 = ((__m64 *)biases)[3]; + __m64 out_4 = ((__m64 *)biases)[4]; + __m64 out_5 = ((__m64 *)biases)[5]; + __m64 out_6 = ((__m64 *)biases)[6]; + __m64 out_7 = ((__m64 *)biases)[7]; + __m64 out_8 = ((__m64 *)biases)[8]; + __m64 out_9 = ((__m64 *)biases)[9]; + __m64 out_10 = ((__m64 *)biases)[10]; + __m64 out_11 = ((__m64 *)biases)[11]; + __m64 out_12 = ((__m64 *)biases)[12]; + __m64 out_13 = ((__m64 *)biases)[13]; + __m64 out_14 = ((__m64 *)biases)[14]; + __m64 out_15 = ((__m64 *)biases)[15]; + const __m64 *first, *second; + uint64_t v; + unsigned idx; + + memcpy(&v, inMask, 8); + for (unsigned offset = 0; offset < inDims;) { + if (!next_idx(&idx, &offset, &v, inMask, inDims)) + break; + first = (__m64 *)&weights[outDims * idx]; + uint32_t factor = input[idx]; + if (next_idx(&idx, &offset, &v, inMask, inDims)) { + second = (__m64 *)&weights[outDims * idx]; + factor |= input[idx] << 16; + } else { + second = kZeros; + } + __m64 mul = _mm_set1_pi32(factor); + out_0 = _mm_add_pi32(out_0, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[0],second[0]))); + out_1 = _mm_add_pi32(out_1, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[0],second[0]))); + out_2 = _mm_add_pi32(out_2, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[1],second[1]))); + out_3 = _mm_add_pi32(out_3, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[1],second[1]))); + out_4 = _mm_add_pi32(out_4, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[2],second[2]))); + out_5 = _mm_add_pi32(out_5, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[2],second[2]))); + out_6 = _mm_add_pi32(out_6, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[3],second[3]))); + out_7 = _mm_add_pi32(out_7, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[3],second[3]))); + out_8 = _mm_add_pi32(out_8, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[4],second[4]))); + out_9 = _mm_add_pi32(out_9, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[4],second[4]))); + out_10 = _mm_add_pi32(out_10, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[5],second[5]))); + out_11 = _mm_add_pi32(out_11, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[5],second[5]))); + out_12 = _mm_add_pi32(out_12, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[6],second[6]))); + out_13 = _mm_add_pi32(out_13, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[6],second[6]))); + out_14 = _mm_add_pi32(out_14, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[7],second[7]))); + out_15 = _mm_add_pi32(out_15, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[7],second[7]))); + } + + __m64 out16_0 = _mm_srai_pi16(_mm_packs_pi32(out_0, out_1), SHIFT); + __m64 out16_1 = _mm_srai_pi16(_mm_packs_pi32(out_2, out_3), SHIFT); + __m64 out16_2 = _mm_srai_pi16(_mm_packs_pi32(out_4, out_5), SHIFT); + __m64 out16_3 = _mm_srai_pi16(_mm_packs_pi32(out_6, out_7), SHIFT); + __m64 out16_4 = _mm_srai_pi16(_mm_packs_pi32(out_8, out_9), SHIFT); + __m64 out16_5 = _mm_srai_pi16(_mm_packs_pi32(out_10, out_11), SHIFT); + __m64 out16_6 = _mm_srai_pi16(_mm_packs_pi32(out_12, out_13), SHIFT); + __m64 out16_7 = _mm_srai_pi16(_mm_packs_pi32(out_14, out_15), SHIFT); + + __m64 *outVec = (__m64 *)output; + if (pack8_and_calc_mask) { + outVec[0] = _mm_packs_pi16(out16_0, out16_1); + outMask[0] = _mm_movemask_pi8(_mm_cmpgt_pi8(outVec[0], kZeros[0])); + outVec[1] = _mm_packs_pi16(out16_2, out16_3); + outMask[1] = _mm_movemask_pi8(_mm_cmpgt_pi8(outVec[1], kZeros[0])); + outVec[2] = _mm_packs_pi16(out16_4, out16_5); + outMask[2] = _mm_movemask_pi8(_mm_cmpgt_pi8(outVec[2], kZeros[0])); + outVec[3] = _mm_packs_pi16(out16_6, out16_7); + outMask[3] = _mm_movemask_pi8(_mm_cmpgt_pi8(outVec[3], kZeros[0])); + } else { + const __m64 kx07f = _mm_set1_pi16(127); + outVec[0] = _mm_min_pi16(_mm_max_pi16(out16_0, kZeros[0]), kx07f); + outVec[1] = _mm_min_pi16(_mm_max_pi16(out16_1, kZeros[0]), kx07f); + outVec[2] = _mm_min_pi16(_mm_max_pi16(out16_2, kZeros[0]), kx07f); + outVec[3] = _mm_min_pi16(_mm_max_pi16(out16_3, kZeros[0]), kx07f); + outVec[4] = _mm_min_pi16(_mm_max_pi16(out16_4, kZeros[0]), kx07f); + outVec[5] = _mm_min_pi16(_mm_max_pi16(out16_5, kZeros[0]), kx07f); + outVec[6] = _mm_min_pi16(_mm_max_pi16(out16_6, kZeros[0]), kx07f); + outVec[7] = _mm_min_pi16(_mm_max_pi16(out16_7, kZeros[0]), kx07f); + } +} +#elif defined(USE_MMX) +INLINE void affine_txfm(clipped_t *input, void *output, unsigned inDims, + unsigned outDims, const int32_t *biases, weight_t *weights, + mask_t *inMask, mask_t *outMask, const bool pack8_and_calc_mask) +{ + assert(outDims == 32); + + (void)inMask; (void)outMask; + const __m64 kZeros[8] = { 0 }; + __m64 out_0 = ((__m64 *)biases)[0]; + __m64 out_1 = ((__m64 *)biases)[1]; + __m64 out_2 = ((__m64 *)biases)[2]; + __m64 out_3 = ((__m64 *)biases)[3]; + __m64 out_4 = ((__m64 *)biases)[4]; + __m64 out_5 = ((__m64 *)biases)[5]; + __m64 out_6 = ((__m64 *)biases)[6]; + __m64 out_7 = ((__m64 *)biases)[7]; + __m64 out_8 = ((__m64 *)biases)[8]; + __m64 out_9 = ((__m64 *)biases)[9]; + __m64 out_10 = ((__m64 *)biases)[10]; + __m64 out_11 = ((__m64 *)biases)[11]; + __m64 out_12 = ((__m64 *)biases)[12]; + __m64 out_13 = ((__m64 *)biases)[13]; + __m64 out_14 = ((__m64 *)biases)[14]; + __m64 out_15 = ((__m64 *)biases)[15]; + const __m64 *first, *second; + + for (unsigned idx = 0; idx < inDims; idx++) { + if (input[idx] <= 0) + continue; + uint32_t factor = input[idx]; + first = (__m64 *)&weights[outDims * idx]; + while (++idx < inDims && input[idx] <= 0); + if (idx < inDims) { + second = (__m64 *)&weights[outDims * idx]; + factor |= input[idx] << 16; + } else + second = kZeros; + __m64 mul = _mm_set1_pi32(factor); + out_0 = _mm_add_pi32(out_0, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[0],second[0]))); + out_1 = _mm_add_pi32(out_1, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[0],second[0]))); + out_2 = _mm_add_pi32(out_2, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[1],second[1]))); + out_3 = _mm_add_pi32(out_3, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[1],second[1]))); + out_4 = _mm_add_pi32(out_4, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[2],second[2]))); + out_5 = _mm_add_pi32(out_5, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[2],second[2]))); + out_6 = _mm_add_pi32(out_6, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[3],second[3]))); + out_7 = _mm_add_pi32(out_7, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[3],second[3]))); + out_8 = _mm_add_pi32(out_8, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[4],second[4]))); + out_9 = _mm_add_pi32(out_9, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[4],second[4]))); + out_10 = _mm_add_pi32(out_10, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[5],second[5]))); + out_11 = _mm_add_pi32(out_11, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[5],second[5]))); + out_12 = _mm_add_pi32(out_12, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[6],second[6]))); + out_13 = _mm_add_pi32(out_13, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[6],second[6]))); + out_14 = _mm_add_pi32(out_14, _mm_madd_pi16(mul, _mm_unpacklo_pi16(first[7],second[7]))); + out_15 = _mm_add_pi32(out_15, _mm_madd_pi16(mul, _mm_unpackhi_pi16(first[7],second[7]))); + } + + __m64 out16_0 = _mm_srai_pi16(_mm_packs_pi32(out_0, out_1), SHIFT); + __m64 out16_1 = _mm_srai_pi16(_mm_packs_pi32(out_2, out_3), SHIFT); + __m64 out16_2 = _mm_srai_pi16(_mm_packs_pi32(out_4, out_5), SHIFT); + __m64 out16_3 = _mm_srai_pi16(_mm_packs_pi32(out_6, out_7), SHIFT); + __m64 out16_4 = _mm_srai_pi16(_mm_packs_pi32(out_8, out_9), SHIFT); + __m64 out16_5 = _mm_srai_pi16(_mm_packs_pi32(out_10, out_11), SHIFT); + __m64 out16_6 = _mm_srai_pi16(_mm_packs_pi32(out_12, out_13), SHIFT); + __m64 out16_7 = _mm_srai_pi16(_mm_packs_pi32(out_14, out_15), SHIFT); + + __m64 *outVec = (__m64 *)output; + if (pack8_and_calc_mask) { + outVec[0] = _mm_packs_pi16(out16_0, out16_1); + outVec[1] = _mm_packs_pi16(out16_2, out16_3); + outVec[2] = _mm_packs_pi16(out16_4, out16_5); + outVec[3] = _mm_packs_pi16(out16_6, out16_7); + } else { + const __m64 k0x7f80 = _mm_set1_pi16(0x7f80); + const __m64 k0x0080 = _mm_set1_pi16(0x0080); + const __m64 k0x8000 = _mm_set1_pi16(-0x8000); + outVec[0] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_0, k0x7f80), k0x0080), k0x8000); + outVec[1] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_1, k0x7f80), k0x0080), k0x8000); + outVec[2] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_2, k0x7f80), k0x0080), k0x8000); + outVec[3] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_3, k0x7f80), k0x0080), k0x8000); + outVec[4] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_4, k0x7f80), k0x0080), k0x8000); + outVec[5] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_5, k0x7f80), k0x0080), k0x8000); + outVec[6] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_6, k0x7f80), k0x0080), k0x8000); + outVec[7] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(out16_7, k0x7f80), k0x0080), k0x8000); } } #else /* generic fallback */ INLINE void affine_txfm(clipped_t *input, void *output, unsigned inDims, unsigned outDims, int32_t *biases, weight_t *weights, - uint64_t *inMask, mask_t *outMask, + mask_t *inMask, mask_t *outMask, const bool pack8_and_calc_mask) { (void)inMask; (void)outMask; (void)pack8_and_calc_mask; @@ -904,26 +1135,43 @@ INLINE void transform(const Position *pos, clipped_t *output, int16_t (*accumulation)[2][256] = &pos->st->accumulator.accumulation; (void)outMask; // avoid compiler warning -#if defined(USE_AVX2) - const unsigned numChunks = kHalfDimensions / 32; + // Number of vectors to read + const unsigned numChunks = (16 * kHalfDimensions) / SIMD_WIDTH; +#if defined(USE_AVX512) + const __m512i kZero = _mm512_setzero_si512(); + const __m512i kOffsets = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + +#elif defined(USE_AVX2) const __m256i kZero = _mm256_setzero_si256(); #elif defined(USE_SSE2) - const unsigned numChunks = kHalfDimensions / 16; #if defined(USE_SSE41) || defined(TRANSPOSE) const __m128i kZero = _mm_setzero_si128(); +#else +#if !defined(USE_SSSE3) + const __m128i k0x7f80 = _mm_set1_epi16(0x7f80); + const __m128i k0x0080 = _mm_set1_epi16(0x0080); + const __m128i k0x8000 = _mm_set1_epi16(-0x8000); #else const __m128i k0x80s = _mm_set1_epi8(-128); #endif +#endif #elif defined(USE_MMX) - const unsigned numChunks = kHalfDimensions / 4; +#ifndef TRANSPOSE +#ifdef USE_SSE + const __m64 k0x7f = _mm_set1_pi16(127); + const __m64 kZero = _mm_setzero_si64(); +#else const __m64 k0x7f80 = _mm_set1_pi16(0x7f80); const __m64 k0x0080 = _mm_set1_pi16(0x0080); const __m64 k0x8000 = _mm_set1_pi16(-0x8000); +#endif +#elif USE_SSE + const __m64 kZero = _mm_setzero_si64(); +#endif #elif defined(USE_NEON) - const unsigned numChunks = kHalfDimensions / 8; const int8x8_t kZero = {0}; #endif @@ -932,44 +1180,80 @@ INLINE void transform(const Position *pos, clipped_t *output, for (unsigned p = 0; p < 2; p++) { const unsigned offset = kHalfDimensions * p; -#if defined(USE_AVX2) +#if defined(USE_AVX512) + __m512i *out = (__m512i *)&output[offset]; + for (unsigned i = 0; i < numChunks / 2; i++) { + __m512i sum0 = ((__m512i *)(*accumulation)[perspectives[p]])[i * 2 + 0]; + __m512i sum1 = ((__m512i *)(*accumulation)[perspectives[p]])[i * 2 + 1]; + __m512i packed = _mm512_packs_epi16(sum0, sum1); +#ifndef TRANSPOSE + out[i] = _mm512_permutexvar_epi64(kOffsets, _mm512_max_epi8(packed, kZero)); +#else + out[i] = _mm512_permutexvar_epi64(kOffsets, packed); + *outMask++ = _mm512_cmpgt_epi8_mask(out[i], kZero); +#endif + } + +#elif defined(USE_AVX2) __m256i *out = (__m256i *)&output[offset]; - for (unsigned i = 0; i < numChunks; i++) { + for (unsigned i = 0; i < numChunks / 2; i++) { __m256i sum0 = ((__m256i *)(*accumulation)[perspectives[p]])[i * 2 + 0]; __m256i sum1 = ((__m256i *)(*accumulation)[perspectives[p]])[i * 2 + 1]; + __m256i packed = _mm256_packs_epi16(sum0, sum1); #ifndef TRANSPOSE - out[i] = _mm256_permute4x64_epi64(_mm256_max_epi8( - _mm256_packs_epi16(sum0, sum1), kZero), 0xd8); + out[i] = _mm256_permute4x64_epi64(_mm256_max_epi8(packed, kZero), 0xd8); #else - out[i] = _mm256_permute4x64_epi64(_mm256_packs_epi16(sum0, sum1), 0xd8); + out[i] = _mm256_permute4x64_epi64(packed, 0xd8); *outMask++ = _mm256_movemask_epi8(_mm256_cmpgt_epi8(out[i], kZero)); #endif } #elif defined(USE_SSE2) __m128i *out = (__m128i *)&output[offset]; - for (unsigned i = 0; i < numChunks; i++) { - __m128i sum0 = ((__m128i *)(*accumulation)[perspectives[p]])[i * 2 + 0]; +#if defined(TRANSPOSE) || defined(USE_SSSE3) + for (unsigned i = 0; i < numChunks / 2; i++) { + __m128i sum0 = ((__m128i *)(*accumulation)[perspectives[p]])[i * 2]; __m128i sum1 = ((__m128i *)(*accumulation)[perspectives[p]])[i * 2 + 1]; + __m128i packed = _mm_packs_epi16(sum0, sum1); #ifndef TRANSPOSE - __m128i packedbytes = _mm_packs_epi16(sum0, sum1); #if defined(USE_SSE41) - out[i] = _mm_max_epi8(packedbytes, kZero); + out[i] = _mm_max_epi8(packed, kZero); #else - out[i] = _mm_subs_epi8(_mm_adds_epi8(packedbytes, k0x80s), k0x80s); + out[i] = _mm_subs_epi8(_mm_adds_epi8(packed, k0x80s), k0x80s); #endif #else - out[i] = _mm_packs_epi16(sum0, sum1); - *outMask++ = _mm_movemask_epi8(_mm_cmpgt_epi8(out[i], kZero)); + out[i] = packed; + *outMask++ = _mm_movemask_epi8(_mm_cmpgt_epi8(packed, kZero)); #endif } +#else /* USE_SSE2 && !TRANSPOSE */ + for (unsigned i = 0; i < numChunks; i++) { + __m128i sum = ((__m128i *)(*accumulation)[perspectives[p]])[i]; + out[i] = _mm_subs_epu16(_mm_add_epi16(_mm_adds_epi16(sum, k0x7f80), k0x0080), k0x8000); + } +#endif #elif defined(USE_MMX) __m64 *out = (__m64 *)&output[offset]; +#ifdef TRANSPOSE + for (unsigned i = 0; i < numChunks / 2; i++) { + __m64 sum0 = ((__m64 *)(*accumulation)[perspectives[p]])[i * 2]; + __m64 sum1 = ((__m64 *)(*accumulation)[perspectives[p]])[i * 2 + 1]; + out[i] = _mm_packs_pi16(sum0, sum1); +#ifdef USE_MASK + *outMask++ = _mm_movemask_pi8(_mm_cmpgt_pi8(out[i], kZero)); +#endif + } +#else /* !TRANSPOSE */ for (unsigned i = 0; i < numChunks; i++) { __m64 sum = ((__m64 *)(*accumulation)[perspectives[p]])[i]; +#ifdef USE_SSE + out[i] = _mm_min_pi16(_mm_max_pi16(sum, kZero), k0x7f); +#else out[i] = _mm_subs_pu16(_mm_add_pi16(_mm_adds_pi16(sum, k0x7f80), k0x0080), k0x8000); +#endif } +#endif #elif defined(USE_NEON) int8x8_t *out = (int8x8_t *)&output[offset]; @@ -979,6 +1263,7 @@ INLINE void transform(const Position *pos, clipped_t *output, } #else + (void)numChunks; for (unsigned i = 0; i < kHalfDimensions; i++) { int16_t sum = (*accumulation)[perspectives[p]][i]; output[offset + i] = clamp(sum, 0, 127); @@ -998,7 +1283,7 @@ struct NetData { clipped_t hidden2_clipped[32]; #else clipped_t hidden1_out[32]; -#if defined(USE_SSE2) && !defined(USE_SSSE3) +#if (defined(USE_SSE2) || defined(USE_MMX)) && !defined(USE_AVX2) int16_t hidden2_out[32]; #else int8_t hidden2_out[32]; @@ -1038,19 +1323,13 @@ Value nnue_evaluate(const Position *pos) #else - // Use memcpy() from mask_t to uint64_t to prevent aliasing problems. - // The compiler will optimize away the actual memcpy() operation. - uint64_t input_mask2[FtOutDims / 64]; - memcpy(input_mask2, input_mask, FtOutDims / 8); affine_txfm(B(input), B(hidden1_out), FtOutDims, 32, - hidden1_biases, hidden1_weights, input_mask2, hidden1_mask, true); + hidden1_biases, hidden1_weights, input_mask, hidden1_mask, true); - uint64_t hidden1_mask2[1]; - memcpy(hidden1_mask2, hidden1_mask, 8); affine_txfm(B(hidden1_out), B(hidden2_out), 32, 32, - hidden2_biases, hidden2_weights, hidden1_mask2, NULL, false); + hidden2_biases, hidden2_weights, hidden1_mask, NULL, false); - affine_propagate((uint8_t *)B(hidden2_out), &out_value, 32, 1, output_biases, + affine_propagate((int8_t *)B(hidden2_out), &out_value, 32, 1, output_biases, output_weights); #endif @@ -1062,34 +1341,24 @@ Value nnue_evaluate(const Position *pos) return out_value / FV_SCALE; } -bool read_weights(weight_t *output_buf, unsigned width, unsigned height, - FILE *F) +const char *read_weights(weight_t *w, unsigned width, unsigned height, + const char *d) { - int8_t v; - - for (unsigned i = 0; i < height; i++) { + for (unsigned i = 0; i < height; i++) for (unsigned j = 0; j < width; j++) { - fread(&v, 1, 1, F); #ifndef TRANSPOSE - output_buf[i * width + j] = v; + w[i * width + j] = *d++; #else - output_buf[j * height + i] = v; + w[j * height + i] = *d++; #endif } - } - return true; + return d; } #if defined(TRANSPOSE) && defined(USE_AVX2) -void permute_weights_and_biases(int8_t *weights, int32_t *biases, - unsigned numDims) +void permute_biases(int32_t *biases) { - __m256i *w = (__m256i *)weights; - __m256i permutation = _mm256_set_epi32(7, 3, 5, 1, 6, 2, 4, 0); - for (unsigned i = 0; i < numDims; i++) - w[i] = _mm256_permutevar8x32_epi32(w[i], permutation); - __m128i *b = (__m128i *)biases; __m128i tmp[8]; tmp[0] = b[0]; @@ -1104,44 +1373,78 @@ void permute_weights_and_biases(int8_t *weights, int32_t *biases, } #endif -bool load_eval_file(const char *evalFile) +static const size_t TransformerStart = 3 * 4 + 177; +static const size_t NetworkStart = + TransformerStart + 4 + 2 * 256 + 2 * 256 * 64 * 641; + +bool verify_net(const void *evalData, size_t size) { - FILE *F = fopen(evalFile, "rb"); + if (size != 21022697) return false; - if (!F) return false; + const char *d = evalData; + if (readu_le_u32(d) != NnueVersion) return false; + if (readu_le_u32(d + 4) != 0x3e5aa6eeU) return false; + if (readu_le_u32(d + 8) != 177) return false; + if (readu_le_u32(d + TransformerStart) != 0x5d69d7b8) return false; + if (readu_le_u32(d + NetworkStart) != 0x63337156) return false; - // Read network header - uint32_t version = read_uint32_t(F); - uint32_t hash = read_uint32_t(F); - uint32_t len = read_uint32_t(F); - for (unsigned i = 0; i < len; i++) - fgetc(F); - if (version != NnueVersion) return false; - if (hash != 0x3e5aa6eeu) return false; + return true; +} + +void init_weights(const void *evalData) +{ + const char *d = (const char *)evalData + TransformerStart + 4; - // Read feature transformer - hash = read_uint32_t(F); - if (hash != 0x5d69d7b8) return false; - fread(ft_biases, sizeof(int16_t), kHalfDimensions, F); - fread(ft_weights, sizeof(int16_t), kHalfDimensions * FtInDims, F); + // Read transformer + for (unsigned i = 0; i < kHalfDimensions; i++, d += 2) + ft_biases[i] = readu_le_u16(d); + for (unsigned i = 0; i < kHalfDimensions * FtInDims; i++, d += 2) + ft_weights[i] = readu_le_u16(d); // Read network - hash = read_uint32_t(F); - if (hash != 0x63337156) return false; - fread(hidden1_biases, sizeof(int32_t), 32, F); - read_weights(hidden1_weights, 512, 32, F); - fread(hidden2_biases, sizeof(int32_t), 32, F); - read_weights(hidden2_weights, 32 , 32 , F); - fread(output_biases, sizeof(int32_t), 1 , F); - read_weights(output_weights, 32, 1 , F); + d += 4; + for (unsigned i = 0; i < 32; i++, d += 4) + hidden1_biases[i] = readu_le_u32(d); + d = read_weights(hidden1_weights, 512, 32, d); + for (unsigned i = 0; i < 32; i++, d += 4) + hidden2_biases[i] = readu_le_u32(d); + d = read_weights(hidden2_weights, 32, 32, d); + for (unsigned i = 0; i < 1; i++, d += 4) + output_biases[i] = readu_le_u32(d); + read_weights(output_weights, 32, 1, d); #if defined(TRANSPOSE) && defined(USE_AVX2) - permute_weights_and_biases(hidden1_weights, hidden1_biases, 512); - permute_weights_and_biases(hidden2_weights, hidden2_biases, 32); + permute_biases(hidden1_biases); + permute_biases(hidden2_biases); +#endif +} + +static bool load_eval_file(const char *evalFile) +{ + const void *evalData; + map_t mapping; + size_t size; + +#ifdef NNUE_EMBEDDED + if (strcmp(evalFile, DefaultEvalFile) == 0) { + evalData = gNetworkData; + mapping = 0; + size = gNetworkSize; + } else #endif + { + FD fd = open_file(evalFile); + if (fd == FD_ERR) return false; + evalData = map_file(fd, &mapping); + size = file_size(fd); + close_file(fd); + } - return true; -// return feof(F); + bool success = verify_net(evalData, size); + if (success) + init_weights(evalData); + if (mapping) unmap_file((void *)evalData, mapping); + return success; } static char *loadedFile = NULL; @@ -1165,9 +1468,14 @@ void nnue_init(void) } printf("info string ERROR: The network file %s was not loaded successfully.\n" +#ifdef NNUE_EMBEDDED + , evalFile +#else "info string ERROR: The default net can be downloaded from:\n" "info string ERROR: https://tests.stockfishchess.org/api/nn/%s\n", - evalFile, option_default_string_value(OPT_EVAL_FILE)); + evalFile, option_default_string_value(OPT_EVAL_FILE) +#endif + ); exit(EXIT_FAILURE); }