diff --git a/cmake/gen/avxvnni_microkernels.cmake b/cmake/gen/avxvnni_microkernels.cmake index 2b6d5189046..2b3b4324144 100644 --- a/cmake/gen/avxvnni_microkernels.cmake +++ b/cmake/gen/avxvnni_microkernels.cmake @@ -132,6 +132,8 @@ SET(NON_PROD_AVXVNNI_MICROKERNEL_SRCS src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c + src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni-prfm.c + src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avxvnni.c diff --git a/gen/avxvnni_microkernels.bzl b/gen/avxvnni_microkernels.bzl index b0ff2945c40..6017023ecd3 100644 --- a/gen/avxvnni_microkernels.bzl +++ b/gen/avxvnni_microkernels.bzl @@ -129,6 +129,8 @@ NON_PROD_AVXVNNI_MICROKERNEL_SRCS = [ "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-7x8c8-minmax-avxvnni.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni-prfm.c", "src/qd8-f32-qc8w-igemm/gen/qd8-f32-qc8w-igemm-8x8c8-minmax-avxvnni.c", + "src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni-prfm.c", + "src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni-prfm.c", "src/qs8-packw/gen/qs8-packw-x16c8-gemm-goi-avxvnni.c", "src/qs8-qc8w-gemm/gen/qs8-qc8w-gemm-1x8c8-minmax-fp32-avxvnni.c", diff --git a/scripts/generate-x8-packw.sh b/scripts/generate-x8-packw.sh index 2b943bb867d..81f6da717b0 100755 --- a/scripts/generate-x8-packw.sh +++ b/scripts/generate-x8-packw.sh @@ -33,6 +33,8 @@ tools/xngen src/x8-packw/kr-scalar.c.in -D NR=16 -D KR=8 -D TYPE=int8_t -D IZP=1 ### AVXVNNI micro-kernels ### C8 packing +tools/xngen src/x8-packw/kr-c4-avxvnni.c.in -D NR=8 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni.c & +tools/xngen src/x8-packw/kr-c4-avxvnni.c.in -D NR=8 -D KR=4 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni-prfm.c & tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=0 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni.c & tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=0 -D AVX=2 -D PREFETCH=1 -o src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-avxvnni-prfm.c & tools/xngen src/x8-packw/kr-avxvnni.c.in -D NR=8 -D KR=8 -D DATATYPE=QS8 -D TYPE=int8_t -D IZP=128 -D AVX=2 -D PREFETCH=0 -o src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-avxvnni.c & diff --git a/src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni-prfm.c b/src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni-prfm.c new file mode 100644 index 00000000000..d943badb4f0 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni-prfm.c @@ -0,0 +1,410 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-c4-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +#include "xnnpack/prefetch.h" + + +XNN_INLINE static uint32_t safe_load_u32(const void* src, size_t k) { + uint32_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < k; ++i) { + value |= (uint32_t) bytes[i] << (i * 8); + } + return value; +} + +void xnn_qs8_packw_gemm_goi_ukernel_x8c4__avxvnni_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 4); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (; n >= 8; n -= 8) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + xnn_prefetch_to_l1((const int8_t*) w0 + 0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w0 + 128); + xnn_prefetch_to_l1((const int8_t*) w0 + 192); + xnn_prefetch_to_l1((const int8_t*) w1 + 0); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w1 + 128); + xnn_prefetch_to_l1((const int8_t*) w1 + 192); + xnn_prefetch_to_l1((const int8_t*) w2 + 0); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w2 + 128); + xnn_prefetch_to_l1((const int8_t*) w2 + 192); + xnn_prefetch_to_l1((const int8_t*) w3 + 0); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w3 + 128); + xnn_prefetch_to_l1((const int8_t*) w3 + 192); + xnn_prefetch_to_l1((const int8_t*) w4 + 0); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w4 + 128); + xnn_prefetch_to_l1((const int8_t*) w4 + 192); + xnn_prefetch_to_l1((const int8_t*) w5 + 0); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w5 + 128); + xnn_prefetch_to_l1((const int8_t*) w5 + 192); + xnn_prefetch_to_l1((const int8_t*) w6 + 0); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w6 + 128); + xnn_prefetch_to_l1((const int8_t*) w6 + 192); + xnn_prefetch_to_l1((const int8_t*) w7 + 0); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + xnn_prefetch_to_l1((const int8_t*) w7 + 128); + xnn_prefetch_to_l1((const int8_t*) w7 + 192); + + __m256i vacc0 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_01234567 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_01234567 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_01234567 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_01234567 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_01234567 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_01234567 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_01234567 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_01234567 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_0145 = _mm256_unpacklo_epi32(v0_01234567, v1_01234567); + const __m256i v01_2367 = _mm256_unpackhi_epi32(v0_01234567, v1_01234567); + const __m256i v23_0145 = _mm256_unpacklo_epi32(v2_01234567, v3_01234567); + const __m256i v23_2367 = _mm256_unpackhi_epi32(v2_01234567, v3_01234567); + const __m256i v45_0145 = _mm256_unpacklo_epi32(v4_01234567, v5_01234567); + const __m256i v45_2367 = _mm256_unpackhi_epi32(v4_01234567, v5_01234567); + const __m256i v67_0145 = _mm256_unpacklo_epi32(v6_01234567, v7_01234567); + const __m256i v67_2367 = _mm256_unpackhi_epi32(v6_01234567, v7_01234567); + + const __m256i v02_02 = _mm256_unpacklo_epi64(v01_0145, v23_0145); + const __m256i v02_13 = _mm256_unpackhi_epi64(v01_0145, v23_0145); + const __m256i v13_02 = _mm256_unpacklo_epi64(v01_2367, v23_2367); + const __m256i v13_13 = _mm256_unpackhi_epi64(v01_2367, v23_2367); + const __m256i v46_02 = _mm256_unpacklo_epi64(v45_0145, v67_0145); + const __m256i v46_13 = _mm256_unpackhi_epi64(v45_0145, v67_0145); + const __m256i v57_02 = _mm256_unpacklo_epi64(v45_2367, v67_2367); + const __m256i v57_13 = _mm256_unpackhi_epi64(v45_2367, v67_2367); + + const __m256i v04_0 = _mm256_permute2f128_si256(v02_02, v46_02, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v04_1 = _mm256_permute2f128_si256(v02_02, v46_02, _MM_SHUFFLE(0, 3, 0, 1)); + const __m256i v15_0 = _mm256_permute2f128_si256(v02_13, v46_13, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v15_1 = _mm256_permute2f128_si256(v02_13, v46_13, _MM_SHUFFLE(0, 3, 0, 1)); + const __m256i v26_0 = _mm256_permute2f128_si256(v13_02, v57_02, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v26_1 = _mm256_permute2f128_si256(v13_02, v57_02, _MM_SHUFFLE(0, 3, 0, 1)); + const __m256i v37_0 = _mm256_permute2f128_si256(v13_13, v57_13, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v37_1 = _mm256_permute2f128_si256(v13_13, v57_13, _MM_SHUFFLE(0, 3, 0, 1)); + + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v04_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v15_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v26_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v37_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v04_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v15_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v26_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v37_1); + + _mm256_storeu_si256((__m256i *)&out[0], v04_0); + _mm256_storeu_si256((__m256i *)&out[32], v15_0); + _mm256_storeu_si256((__m256i *)&out[64], v26_0); + _mm256_storeu_si256((__m256i *)&out[96], v37_0); + _mm256_storeu_si256((__m256i *)&out[128], v04_1); + _mm256_storeu_si256((__m256i *)&out[160], v15_1); + _mm256_storeu_si256((__m256i *)&out[192], v26_1); + _mm256_storeu_si256((__m256i *)&out[224], v37_1); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x4 + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + out += 32; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + v0 = _mm256_and_si256(v0, vmask); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + out += 32; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + out += nb * sizeof(int32_t); + } + out += (8 - n) * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + xnn_prefetch_to_l1((const int8_t*) w0); + xnn_prefetch_to_l1((const int8_t*) w0 + 64); + xnn_prefetch_to_l1((const int8_t*) w1); + xnn_prefetch_to_l1((const int8_t*) w1 + 64); + xnn_prefetch_to_l1((const int8_t*) w2); + xnn_prefetch_to_l1((const int8_t*) w2 + 64); + xnn_prefetch_to_l1((const int8_t*) w3); + xnn_prefetch_to_l1((const int8_t*) w3 + 64); + xnn_prefetch_to_l1((const int8_t*) w4); + xnn_prefetch_to_l1((const int8_t*) w4 + 64); + xnn_prefetch_to_l1((const int8_t*) w5); + xnn_prefetch_to_l1((const int8_t*) w5 + 64); + xnn_prefetch_to_l1((const int8_t*) w6); + xnn_prefetch_to_l1((const int8_t*) w6 + 64); + xnn_prefetch_to_l1((const int8_t*) w7); + xnn_prefetch_to_l1((const int8_t*) w7 + 64); + + __m256i vacc0 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x4 + size_t k = kc; + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + xnn_prefetch_to_l1((const int8_t*) w0 + 448); + xnn_prefetch_to_l1((const int8_t*) w1 + 448); + xnn_prefetch_to_l1((const int8_t*) w2 + 448); + xnn_prefetch_to_l1((const int8_t*) w3 + 448); + xnn_prefetch_to_l1((const int8_t*) w4 + 448); + xnn_prefetch_to_l1((const int8_t*) w5 + 448); + xnn_prefetch_to_l1((const int8_t*) w6 + 448); + xnn_prefetch_to_l1((const int8_t*) w7 + 448); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + out += 32; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + v0 = _mm256_and_si256(v0, vmask); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + out += 32; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni.c b/src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni.c new file mode 100644 index 00000000000..9cc5207a757 --- /dev/null +++ b/src/qs8-packw/gen/qs8-packw-x8c4-gemm-goi-avxvnni.c @@ -0,0 +1,337 @@ +// Auto-generated file. Do not edit! +// Template: src/x8-packw/kr-c4-avxvnni.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" + + +XNN_INLINE static uint32_t safe_load_u32(const void* src, size_t k) { + uint32_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < k; ++i) { + value |= (uint32_t) bytes[i] << (i * 8); + } + return value; +} + +void xnn_qs8_packw_gemm_goi_ukernel_x8c4__avxvnni( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const int8_t* weights, + const int32_t* bias, + const void* scale, + int8_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); + assert(kr == 4); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + int8_t* out = (int8_t*) packed_weights; + const int32_t* b = (const int32_t*) bias; + + const __m256i vone = _mm256_set1_epi8(1); + const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + 0): 0); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of 8 + const int8_t* w0 = (const int8_t*) weights; + size_t n = nc; + for (; n >= 8; n -= 8) { + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + const __m256i vb0 = _mm256_loadu_si256((const __m256i*) (b + 0)); + _mm256_storeu_si256((__m256i*) (out + 0), vb0); + b += 8; + } else { + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + } + out += 8 * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + const int8_t* w2 = w1 + kc; + const int8_t* w3 = w2 + kc; + const int8_t* w4 = w3 + kc; + const int8_t* w5 = w4 + kc; + const int8_t* w6 = w5 + kc; + const int8_t* w7 = w6 + kc; + + __m256i vacc0 = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of 8x32 + for (; k >= 32; k -= 32) { + const __m256i v0_01234567 = _mm256_loadu_si256((const __m256i*) w0); + const __m256i v1_01234567 = _mm256_loadu_si256((const __m256i*) w1); + const __m256i v2_01234567 = _mm256_loadu_si256((const __m256i*) w2); + const __m256i v3_01234567 = _mm256_loadu_si256((const __m256i*) w3); + const __m256i v4_01234567 = _mm256_loadu_si256((const __m256i*) w4); + const __m256i v5_01234567 = _mm256_loadu_si256((const __m256i*) w5); + const __m256i v6_01234567 = _mm256_loadu_si256((const __m256i*) w6); + const __m256i v7_01234567 = _mm256_loadu_si256((const __m256i*) w7); + + const __m256i v01_0145 = _mm256_unpacklo_epi32(v0_01234567, v1_01234567); + const __m256i v01_2367 = _mm256_unpackhi_epi32(v0_01234567, v1_01234567); + const __m256i v23_0145 = _mm256_unpacklo_epi32(v2_01234567, v3_01234567); + const __m256i v23_2367 = _mm256_unpackhi_epi32(v2_01234567, v3_01234567); + const __m256i v45_0145 = _mm256_unpacklo_epi32(v4_01234567, v5_01234567); + const __m256i v45_2367 = _mm256_unpackhi_epi32(v4_01234567, v5_01234567); + const __m256i v67_0145 = _mm256_unpacklo_epi32(v6_01234567, v7_01234567); + const __m256i v67_2367 = _mm256_unpackhi_epi32(v6_01234567, v7_01234567); + + const __m256i v02_02 = _mm256_unpacklo_epi64(v01_0145, v23_0145); + const __m256i v02_13 = _mm256_unpackhi_epi64(v01_0145, v23_0145); + const __m256i v13_02 = _mm256_unpacklo_epi64(v01_2367, v23_2367); + const __m256i v13_13 = _mm256_unpackhi_epi64(v01_2367, v23_2367); + const __m256i v46_02 = _mm256_unpacklo_epi64(v45_0145, v67_0145); + const __m256i v46_13 = _mm256_unpackhi_epi64(v45_0145, v67_0145); + const __m256i v57_02 = _mm256_unpacklo_epi64(v45_2367, v67_2367); + const __m256i v57_13 = _mm256_unpackhi_epi64(v45_2367, v67_2367); + + const __m256i v04_0 = _mm256_permute2f128_si256(v02_02, v46_02, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v04_1 = _mm256_permute2f128_si256(v02_02, v46_02, _MM_SHUFFLE(0, 3, 0, 1)); + const __m256i v15_0 = _mm256_permute2f128_si256(v02_13, v46_13, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v15_1 = _mm256_permute2f128_si256(v02_13, v46_13, _MM_SHUFFLE(0, 3, 0, 1)); + const __m256i v26_0 = _mm256_permute2f128_si256(v13_02, v57_02, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v26_1 = _mm256_permute2f128_si256(v13_02, v57_02, _MM_SHUFFLE(0, 3, 0, 1)); + const __m256i v37_0 = _mm256_permute2f128_si256(v13_13, v57_13, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v37_1 = _mm256_permute2f128_si256(v13_13, v57_13, _MM_SHUFFLE(0, 3, 0, 1)); + + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v04_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v15_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v26_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v37_0); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v04_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v15_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v26_1); + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v37_1); + + _mm256_storeu_si256((__m256i *)&out[0], v04_0); + _mm256_storeu_si256((__m256i *)&out[32], v15_0); + _mm256_storeu_si256((__m256i *)&out[64], v26_0); + _mm256_storeu_si256((__m256i *)&out[96], v37_0); + _mm256_storeu_si256((__m256i *)&out[128], v04_1); + _mm256_storeu_si256((__m256i *)&out[160], v15_1); + _mm256_storeu_si256((__m256i *)&out[192], v26_1); + _mm256_storeu_si256((__m256i *)&out[224], v37_1); + + w0 += 32; + w1 += 32; + w2 += 32; + w3 += 32; + w4 += 32; + w5 += 32; + w6 += 32; + w7 += 32; + out += 256; + } + + // KC main loop multiple of 8x4 + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + out += 32; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + v0 = _mm256_and_si256(v0, vmask); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + out += 32; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + w0 = w7; + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= 7); + + int32_t* packed_b = (int32_t*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((int32_t*) out) = *b++; + out += sizeof(int32_t); + } while (--nb != 0); + } else { + size_t nb = n; + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + out += nb * sizeof(int32_t); + } + out += (8 - n) * sizeof(int32_t); + + const int8_t* w1 = w0 + kc; + if XNN_UNPREDICTABLE(n < 2) { + w1 = w0; + } + const int8_t* w2 = w1 + kc; + if XNN_UNPREDICTABLE(n <= 2) { + w2 = w1; + } + const int8_t* w3 = w2 + kc; + if XNN_UNPREDICTABLE(n < 4) { + w3 = w2; + } + const int8_t* w4 = w3 + kc; + if XNN_UNPREDICTABLE(n <= 4) { + w4 = w3; + } + const int8_t* w5 = w4 + kc; + if XNN_UNPREDICTABLE(n < 6) { + w5 = w4; + } + const int8_t* w6 = w5 + kc; + if XNN_UNPREDICTABLE(n <= 6) { + w6 = w5; + } + const int8_t* w7 = w6 + kc; + if XNN_UNPREDICTABLE(n < 8) { + w7 = w6; + } + + __m256i vacc0 = _mm256_setzero_si256(); + + // KC main loop multiple of 8x4 + size_t k = kc; + for (; k >= 4; k -= 4) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + w0 += 4; + w1 += 4; + w2 += 4; + w3 += 4; + w4 += 4; + w5 += 4; + w6 += 4; + w7 += 4; + out += 32; + } + + // KC remainder of 1..3 + if (k != 0) { + assert(k >= 1 && k <= 3); + + const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); + + __m256i v0 = _mm256_set1_epi32((int32_t) safe_load_u32(w0, k)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w1, k)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w2, k)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w3, k)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w4, k)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w5, k)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w6, k)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w7, k)), 0x80); + v0 = _mm256_and_si256(v0, vmask); + + w0 += k; + w1 += k; + w2 += k; + w3 += k; + w4 += k; + w5 += k; + w6 += k; + w7 += k; + + vacc0 = _mm256_dpbusd_avx_epi32(vacc0, vone, v0); + + _mm256_storeu_si256((__m256i *)&out[0], v0); + + out += 32; + } + + __m256i vksum0 = _mm256_mullo_epi32(vacc0, vzeropoint); + __m256i vpack0 = _mm256_loadu_si256((const __m256i*) (packed_b + 0)); + vpack0 = _mm256_sub_epi32(vpack0, vksum0); + _mm256_storeu_si256((__m256i *) (packed_b + 0), vpack0); + out = (int8_t*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/qs8-packw/qs8-packw.h b/src/qs8-packw/qs8-packw.h index 24fc85abf05..c1c96d30bc3 100644 --- a/src/qs8-packw/qs8-packw.h +++ b/src/qs8-packw/qs8-packw.h @@ -16,6 +16,8 @@ XNN_QS8_UKERNEL(0, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__scalar, 8, 8, 1, XNN_QS8_UKERNEL(0, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x16c8__scalar, 16, 8, 1, 8, 1, 128) #if XNN_ENABLE_AVXVNNI && (XNN_ARCH_X86_64 || XNN_ARCH_X86) +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x8c4__avxvnni, 8, 4, 1, 4, 1, 0) +XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x8c4__avxvnni_prfm, 8, 4, 1, 4, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni, 8, 8, 1, 8, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_packw_gemm_goi_ukernel_x8c8__avxvnni_prfm, 8, 8, 1, 8, 1, 0) XNN_QS8_UKERNEL(xnn_arch_x86_avxvnni, xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__avxvnni, 8, 8, 1, 8, 1, 128) diff --git a/src/x8-packw/kr-c4-avxvnni.c.in b/src/x8-packw/kr-c4-avxvnni.c.in new file mode 100644 index 00000000000..0f7f8665ecc --- /dev/null +++ b/src/x8-packw/kr-c4-avxvnni.c.in @@ -0,0 +1,321 @@ +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +$assert NR == 8 +$assert KR == 4 +$assert DATATYPE in ["QS8", "X8"] +$assert TYPE in ["int8_t"] +$assert IZP in [0, 128] + +#include +#include +#include + +#include + +#include "xnnpack/packw.h" +#include "xnnpack/unaligned.h" +$if PREFETCH: + #include "xnnpack/prefetch.h" + + +$BTYPE = {"QS8": "int32_t", "X8": "uint32_t"}[DATATYPE] +$WTYPE = "int8_t" +$if DATATYPE in ["QS8"]: + $_MM256_DPBUSD_EPI32 = "_mm256_dpbusd_avx_epi32" if AVX == 2 else "_mm256_dpbusd_epi32" + $ISA = "avxvnni" if AVX == 2 else "avx256vnni" +$else: + $ISA = "avx2" if AVX == 2 else "avx256skx" +XNN_INLINE static uint32_t safe_load_u32(const void* src, size_t k) { + uint32_t value = 0; + const uint8_t* bytes = (const uint8_t*)src; + for (size_t i = 0; i < k; ++i) { + value |= (uint32_t) bytes[i] << (i * 8); + } + return value; +} + +void xnn_${DATATYPE.lower()}${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_prfm" if PREFETCH else ""}( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + const ${WTYPE}* weights, + const ${BTYPE}* bias, + const void* scale, + ${WTYPE}* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == ${NR}); + assert(kr == ${KR}); + assert(sr == 1); + assert(weights != NULL); + assert(packed_weights != NULL); + + ${TYPE}* out = (${TYPE}*) packed_weights; + const ${BTYPE}* b = (const ${BTYPE}*) bias; + + $if DATATYPE in ["QS8"]: + const __m256i vone = _mm256_set1_epi8(1); + const uint32_t izp = (uint32_t) (params ? (((const struct xnn_qs8_packw_params*) params)->input_zero_point + ${IZP}): ${IZP}); + __m256i vzeropoint = _mm256_set1_epi32((int32_t) izp); + + do { + // NC main loop multiple of ${NR} + const ${TYPE}* w0 = (const ${TYPE}*) weights; + size_t n = nc; + for (; n >= ${NR}; n -= ${NR}) { + $if DATATYPE in ["QS8"]: + ${BTYPE}* packed_b = (${BTYPE}*) out; + if XNN_LIKELY(b != NULL) { + $for N in range(0, NR, 8): + const __m256i vb${N} = _mm256_loadu_si256((const __m256i*) (b + ${N})); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), vb${N}); + b += ${NR}; + } else { + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i*) (out + ${N*4}), _mm256_setzero_si256()); + } + out += ${NR} * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + $if PREFETCH: + $for N in range(0, NR): + $for OFFSET in range(0, 256, 64): + xnn_prefetch_to_l1((const int8_t*) w${N} + ${OFFSET}); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vacc${N} = _mm256_setzero_si256(); + + size_t k = kc; + // KC main loop multiple of ${NR}x${8 * KR} + for (; k >= ${8 * KR}; k -= ${8 * KR}) { + $for N in range(NR): + const __m256i v${N}_01234567 = _mm256_loadu_si256((const __m256i*) w${N}); + + $for N in range(0, NR, 2): + const __m256i v${N}${N+1}_0145 = _mm256_unpacklo_epi32(v${N}_01234567, v${N+1}_01234567); + const __m256i v${N}${N+1}_2367 = _mm256_unpackhi_epi32(v${N}_01234567, v${N+1}_01234567); + + $for N in range(0, NR, 4): + const __m256i v${N}${N+2}_02 = _mm256_unpacklo_epi64(v${N}${N+1}_0145, v${N+2}${N+3}_0145); + const __m256i v${N}${N+2}_13 = _mm256_unpackhi_epi64(v${N}${N+1}_0145, v${N+2}${N+3}_0145); + const __m256i v${N+1}${N+3}_02 = _mm256_unpacklo_epi64(v${N}${N+1}_2367, v${N+2}${N+3}_2367); + const __m256i v${N+1}${N+3}_13 = _mm256_unpackhi_epi64(v${N}${N+1}_2367, v${N+2}${N+3}_2367); + + $for N in range(0, NR // 4): + $for I in range(0, 2): + $C = N*2+I + const __m256i v${C}${C+4}_0 = _mm256_permute2f128_si256(v${N}${N+2}_${I}${I+2}, v${N+4}${N+6}_${I}${I+2}, _MM_SHUFFLE(0, 2, 0, 0)); + const __m256i v${C}${C+4}_1 = _mm256_permute2f128_si256(v${N}${N+2}_${I}${I+2}, v${N+4}${N+6}_${I}${I+2}, _MM_SHUFFLE(0, 3, 0, 1)); + + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + $for I in range(0, 2): + $for J in range(0, 4): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${J}${J+4}_${I}); + + $for I in range(0, 2): + $for N in range(0, KR): + _mm256_storeu_si256((__m256i *)&out[${(I*KR + N)*8*KR}], v${N}${N+4}_${I}); + + $for N in range(NR): + w${N} += ${8 * KR}; + out += ${8*NR*KR}; + } + + // KC main loop multiple of ${NR}x${KR} + for (; k >= ${KR}; k -= ${KR}) { + __m256i v0 = _mm256_set1_epi32((int32_t) unaligned_load_u32(w0)); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w1)), 0x02); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w2)), 0x04); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w3)), 0x08); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w4)), 0x10); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w5)), 0x20); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w6)), 0x40); + v0 = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) unaligned_load_u32(w7)), 0x80); + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + assert(k >= 1 && k <= ${KR-1}); + + const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); + + $for N in range(0, NR, 8): + __m256i v${N} = _mm256_set1_epi32((int32_t) safe_load_u32(w${N}, k)); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+1}, k)), 0x02); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+2}, k)), 0x04); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+3}, k)), 0x08); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+4}, k)), 0x10); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+5}, k)), 0x20); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+6}, k)), 0x40); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+7}, k)), 0x80); + v${N} = _mm256_and_si256(v${N}, vmask); + + $for N in range(NR): + w${N} += k; + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + out += ${NR*KR}; + } + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_mullo_epi32(vacc${N}, vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + w0 = w${NR-1}; + } + + // NC remainder (1..${NR-1}) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1 && n <= ${NR-1}); + + $if DATATYPE in ["QS8"]: + ${BTYPE}* packed_b = (${BTYPE}*) out; + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *((${BTYPE}*) out) = *b++; + out += sizeof(${BTYPE}); + } while (--nb != 0); + } else { + size_t nb = n; + _mm256_storeu_si256((__m256i*) (out + 0), _mm256_setzero_si256()); + out += nb * sizeof(${BTYPE}); + } + out += (${NR} - n) * sizeof(${BTYPE}); + + $for N in range(1, NR): + const ${TYPE}* w${N} = w${N-1} + kc; + $if N % 2 == 0: + if XNN_UNPREDICTABLE(n <= ${N}) { + w${N} = w${N-1}; + } + $else: + if XNN_UNPREDICTABLE(n < ${N+1}) { + w${N} = w${N-1}; + } + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N}); + xnn_prefetch_to_l1((const int8_t*) w${N} + 64); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vacc${N} = _mm256_setzero_si256(); + + // KC main loop multiple of ${NR}x${KR} + size_t k = kc; + for (; k >= ${KR}; k -= ${KR}) { + $for N in range(0, NR, 8): + __m256i v${N} = _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N})); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+1})), 0x02); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+2})), 0x04); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+3})), 0x08); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+4})), 0x10); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+5})), 0x20); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+6})), 0x40); + v${N} = _mm256_blend_epi32(v${N}, _mm256_set1_epi32((int32_t) unaligned_load_u32(w${N+7})), 0x80); + $if PREFETCH: + $for N in range(0, NR): + xnn_prefetch_to_l1((const int8_t*) w${N} + 448); + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + $for N in range(NR): + w${N} += ${KR}; + out += ${NR*KR}; + } + + // KC remainder of 1..${KR-1} + if (k != 0) { + assert(k >= 1 && k <= ${KR-1}); + + const __m256i vmask = _mm256_set1_epi32((1u << (k * sizeof(int8_t) * 8)) - 1); + + $for N in range(0, NR, 8): + __m256i v${N} = _mm256_set1_epi32((int32_t) safe_load_u32(w${N}, k)); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+1}, k)), 0x02); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+2}, k)), 0x04); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+3}, k)), 0x08); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+4}, k)), 0x10); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+5}, k)), 0x20); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+6}, k)), 0x40); + v${N} = _mm256_blend_epi32(v0, _mm256_set1_epi32((int32_t) safe_load_u32(w${N+7}, k)), 0x80); + v${N} = _mm256_and_si256(v${N}, vmask); + + $for N in range(NR): + w${N} += k; + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + vacc${N} = ${_MM256_DPBUSD_EPI32}(vacc${N}, vone, v${N}); + + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *)&out[${N * KR}], v${N}); + + out += ${NR*KR}; + } + + $if DATATYPE in ["QS8"]: + $for N in range(0, NR, 8): + __m256i vksum${N} = _mm256_mullo_epi32(vacc${N}, vzeropoint); + $for N in range(0, NR, 8): + __m256i vpack${N} = _mm256_loadu_si256((const __m256i*) (packed_b + ${N})); + $for N in range(0, NR, 8): + vpack${N} = _mm256_sub_epi32(vpack${N}, vksum${N}); + $for N in range(0, NR, 8): + _mm256_storeu_si256((__m256i *) (packed_b + ${N}), vpack${N}); + out = (${TYPE}*) ((uintptr_t) out + extra_bytes); + } + + weights += nc * kc; + } while (--g != 0); +}