diff --git a/src/amalgam/gen/avx512vnni.c b/src/amalgam/gen/avx512vnni.c
index e69de29bb2d..1423aa71f79 100644
--- a/src/amalgam/gen/avx512vnni.c
+++ b/src/amalgam/gen/avx512vnni.c
@@ -0,0 +1,537 @@
+// Copyright 2021 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/igemm.h>
+#include <xnnpack/math.h>
+#include <xnnpack/unaligned.h>
+
+
+void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const int8_t* restrict a,
+    size_t a_stride,
+    const void* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
+    const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(int8_t) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  kc = round_up_po2(kc, 4 * sizeof(int8_t));
+  const int8_t* a0 = a;
+  float* c0 = c;
+
+  const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128);
+  const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
+  const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
+  const __m512i vsign_mask = _mm512_load_si512(params->avx512vnni.sign_mask);
+  do {
+    const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
+    __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0);
+    w = (const int32_t*) w + 16;
+
+    size_t k = kc;
+    do {
+      __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0));
+
+      a0 += 4;
+
+      va0x0123 = _mm512_xor_epi32(va0x0123, vsign_mask);
+
+      const __m512i vb0123456789ABCDEF = _mm512_load_si512(w);
+
+      vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
+
+      w = (const int8_t*) w + 64;
+      k -= 4 * sizeof(int8_t);
+    } while (k != 0);
+
+    __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale));
+
+    const __m512 vfilter_output_scale0123456789ABCDEF = _mm512_load_ps((const float*) w);
+    const __m512 vbias0123456789ABCDEF = _mm512_load_ps((const float*) w + 16);
+    w = (const float*) w + 32;
+
+    vscaled0x0123456789ABCDEF = _mm512_fmadd_ps(vscaled0x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_max_ps(vscaled0x0123456789ABCDEF, voutput_min);
+
+    vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      _mm512_storeu_ps(c0, vscaled0x0123456789ABCDEF);
+
+      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      nc -= 16;
+    } else {
+      // Prepare mask for valid 32-bit elements (depends on nc).
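+      // Only the low nc bits are set, e.g. nc == 3 yields mask 0x0007, so the
+      // masked store below writes just the first nc output floats of the row.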
+      const __mmask16 vmask = _cvtu32_mask16((UINT32_C(1) << nc) - 1);
+      _mm512_mask_storeu_ps(c0, vmask, vscaled0x0123456789ABCDEF);
+      nc = 0;
+    }
+  } while (nc != 0);
+}
+
+void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    const int8_t* restrict a,
+    size_t a_stride,
+    const void* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
+    const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(int8_t) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  kc = round_up_po2(kc, 4 * sizeof(int8_t));
+  const int8_t* a0 = a;
+  float* c0 = c;
+  const int8_t* a1 = (const int8_t*) ((uintptr_t) a0 + a_stride);
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    a1 = a0;
+    c1 = c0;
+  }
+  const int8_t* a2 = (const int8_t*) ((uintptr_t) a1 + a_stride);
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    a2 = a1;
+    c2 = c1;
+  }
+  const int8_t* a3 = (const int8_t*) ((uintptr_t) a2 + a_stride);
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    a3 = a2;
+    c3 = c2;
+  }
+
+  const __m512i vinput_zero_point0 = _mm512_set1_epi32((int) quantization_params[0].zero_point + 128);
+  const __m512i vinput_zero_point1 = _mm512_set1_epi32((int) quantization_params[1].zero_point + 128);
+  const __m512i vinput_zero_point2 = _mm512_set1_epi32((int) quantization_params[2].zero_point + 128);
+  const __m512i vinput_zero_point3 = _mm512_set1_epi32((int) quantization_params[3].zero_point + 128);
+  const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
+  const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
+  const __m512i vsign_mask = _mm512_load_si512(params->avx512vnni.sign_mask);
+  do {
+    const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
+    __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point0);
+    __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point1);
+    __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point2);
+    __m512i vacc3x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point3);
+    w = (const int32_t*) w + 16;
+
+    size_t k = kc;
+    do {
+      __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0));
+      __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1));
+      __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2));
+      __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3));
+
+      a0 += 4;
+      a1 += 4;
+      a2 += 4;
+      a3 += 4;
+
+      va0x0123 = _mm512_xor_epi32(va0x0123, vsign_mask);
+      va1x0123 = _mm512_xor_epi32(va1x0123, vsign_mask);
+      va2x0123 = _mm512_xor_epi32(va2x0123, vsign_mask);
+      va3x0123 = _mm512_xor_epi32(va3x0123, vsign_mask);
+
+      const __m512i vb0123456789ABCDEF = _mm512_load_si512(w);
+
+      vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEF);
+      vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEF);
+      vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEF);
+      vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEF);
+
+      w = (const int8_t*) w + 64;
+      k -= 4 * sizeof(int8_t);
+    } while (k != 0);
+
+    __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
+    __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF);
+    __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF);
+    __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, _mm512_set1_ps(quantization_params[0].inv_scale));
+    vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, _mm512_set1_ps(quantization_params[1].inv_scale));
+    vscaled2x0123456789ABCDEF = _mm512_mul_ps(vscaled2x0123456789ABCDEF, _mm512_set1_ps(quantization_params[2].inv_scale));
+    vscaled3x0123456789ABCDEF = _mm512_mul_ps(vscaled3x0123456789ABCDEF, _mm512_set1_ps(quantization_params[3].inv_scale));
+
+    const __m512 vfilter_output_scale0123456789ABCDEF = _mm512_load_ps((const float*) w);
+    const __m512 vbias0123456789ABCDEF = _mm512_load_ps((const float*) w + 16);
+    w = (const float*) w + 32;
+
+    vscaled0x0123456789ABCDEF = _mm512_fmadd_ps(vscaled0x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+    vscaled1x0123456789ABCDEF = _mm512_fmadd_ps(vscaled1x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+    vscaled2x0123456789ABCDEF = _mm512_fmadd_ps(vscaled2x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+    vscaled3x0123456789ABCDEF = _mm512_fmadd_ps(vscaled3x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_max_ps(vscaled0x0123456789ABCDEF, voutput_min);
+    vscaled1x0123456789ABCDEF = _mm512_max_ps(vscaled1x0123456789ABCDEF, voutput_min);
+    vscaled2x0123456789ABCDEF = _mm512_max_ps(vscaled2x0123456789ABCDEF, voutput_min);
+    vscaled3x0123456789ABCDEF = _mm512_max_ps(vscaled3x0123456789ABCDEF, voutput_min);
+
+    vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max);
+    vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max);
+    vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max);
+    vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      _mm512_storeu_ps(c3, vscaled3x0123456789ABCDEF);
+      _mm512_storeu_ps(c2, vscaled2x0123456789ABCDEF);
+      _mm512_storeu_ps(c1, vscaled1x0123456789ABCDEF);
+      _mm512_storeu_ps(c0, vscaled0x0123456789ABCDEF);
+
+      a0 = (const int8_t*) ((uintptr_t) a0 - kc);
+      a1 = (const int8_t*) ((uintptr_t) a1 - kc);
+      a2 = (const int8_t*) ((uintptr_t) a2 - kc);
+      a3 = (const int8_t*) ((uintptr_t) a3 - kc);
+
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+
+      nc -= 16;
+    } else {
+      // Prepare mask for valid 32-bit elements (depends on nc).
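+      // nc counts the remaining output columns, so the same mask applies to
+      // every row: each masked store below writes the first nc floats.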
+      const __mmask16 vmask = _cvtu32_mask16((UINT32_C(1) << nc) - 1);
+      _mm512_mask_storeu_ps(c3, vmask, vscaled3x0123456789ABCDEF);
+      _mm512_mask_storeu_ps(c2, vmask, vscaled2x0123456789ABCDEF);
+      _mm512_mask_storeu_ps(c1, vmask, vscaled1x0123456789ABCDEF);
+      _mm512_mask_storeu_ps(c0, vmask, vscaled0x0123456789ABCDEF);
+      nc = 0;
+    }
+  } while (nc != 0);
+}
+
+void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const int8_t** restrict a,
+    const void* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const int8_t* zero,
+    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
+    const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
+{
+  assert(mr != 0);
+  assert(mr <= 1);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(int8_t) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  kc = round_up_po2(kc, 4 * sizeof(int8_t));
+  float* c0 = c;
+
+  const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128);
+  const __m512 vinput_scale = _mm512_set1_ps(quantization_params->inv_scale);
+  const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
+  const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
+  const __m512i vsign_mask = _mm512_load_si512(params->avx512vnni.sign_mask);
+  do {
+    const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
+    __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point);
+    w = (const int32_t*) w + 16;
+
+    size_t p = ks;
+    do {
+      const int8_t* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+      }
+      a += 1;
+
+      size_t k = kc;
+      while (k >= 8 * sizeof(int8_t)) {
+        __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4;
+
+        __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4;
+
+        va0x0123 = _mm512_xor_epi32(va0x0123, vsign_mask);
+        va0x4567 = _mm512_xor_epi32(va0x4567, vsign_mask);
+
+        const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w);
+        w = (const int8_t*) w + 64;
+        const __m512i vb0123456789ABCDEFx4567 = _mm512_load_si512(w);
+        w = (const int8_t*) w + 64;
+
+        vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
+        vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x4567, vb0123456789ABCDEFx4567);
+
+        k -= 8 * sizeof(int8_t);
+      }
+      if (k != 0) {
+        __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4;
+
+        va0x0123 = _mm512_xor_epi32(va0x0123, vsign_mask);
+
+        const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w);
+        w = (const int8_t*) w + 64;
+
+        vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
+      }
+      p -= 1 * sizeof(void*);
+    } while (p != 0);
+
+    __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vinput_scale);
+
+    const __m512 vfilter_output_scale0123456789ABCDEF = _mm512_load_ps((const float*) w);
+    const __m512 vbias0123456789ABCDEF = _mm512_load_ps((const float*) w + 16);
+    w = (const float*) w + 32;
+
+    vscaled0x0123456789ABCDEF = _mm512_fmadd_ps(vscaled0x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_max_ps(vscaled0x0123456789ABCDEF, voutput_min);
+
+    vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      _mm512_storeu_ps(c0, vscaled0x0123456789ABCDEF);
+
+      a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+
+      nc -= 16;
+    } else {
+      // Prepare mask for valid 32-bit elements (depends on nc).
+      const __mmask16 vmask = _cvtu32_mask16((UINT32_C(1) << nc) - 1);
+      _mm512_mask_storeu_ps(c0, vmask, vscaled0x0123456789ABCDEF);
+      nc = 0;
+    }
+  } while (nc != 0);
+}
+
+
+void xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni(
+    size_t mr,
+    size_t nc,
+    size_t kc,
+    size_t ks,
+    const int8_t** restrict a,
+    const void* restrict w,
+    float* restrict c,
+    size_t cm_stride,
+    size_t cn_stride,
+    size_t a_offset,
+    const int8_t* zero,
+    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)],
+    const struct xnn_qd8_quantization_params quantization_params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
+{
+  assert(mr != 0);
+  assert(mr <= 4);
+  assert(nc != 0);
+  assert(kc != 0);
+  assert(kc % sizeof(int8_t) == 0);
+  assert(a != NULL);
+  assert(w != NULL);
+  assert(c != NULL);
+
+  kc = round_up_po2(kc, 4 * sizeof(int8_t));
+  float* c0 = c;
+  float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
+  if XNN_UNPREDICTABLE(mr < 2) {
+    c1 = c0;
+  }
+  float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
+  if XNN_UNPREDICTABLE(mr <= 2) {
+    c2 = c1;
+  }
+  float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
+  if XNN_UNPREDICTABLE(mr != 4) {
+    c3 = c2;
+  }
+
+  const __m512i vinput_zero_point = _mm512_set1_epi32((int) quantization_params->zero_point + 128);
+  const __m512 vinput_scale = _mm512_set1_ps(quantization_params->inv_scale);
+  const __m512 voutput_min = _mm512_set1_ps(params->avx512vnni.min);
+  const __m512 voutput_max = _mm512_set1_ps(params->avx512vnni.max);
+  const __m512i vsign_mask = _mm512_load_si512(params->avx512vnni.sign_mask);
+  do {
+    const __m512i vksum0123456789ABCDEF = _mm512_load_epi32(w);
+    __m512i vacc0x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point);
+    __m512i vacc1x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point);
+    __m512i vacc2x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point);
+    __m512i vacc3x0123456789ABCDEF = _mm512_mullo_epi32(vksum0123456789ABCDEF, vinput_zero_point);
+    w = (const int32_t*) w + 16;
+
+    size_t p = ks;
+    do {
+      const int8_t* restrict a0 = a[0];
+      if XNN_UNPREDICTABLE(a0 != zero) {
+        a0 = (const int8_t*) ((uintptr_t) a0 + a_offset);
+      }
+      const int8_t* restrict a1 = a[1];
+      if XNN_UNPREDICTABLE(a1 != zero) {
+        a1 = (const int8_t*) ((uintptr_t) a1 + a_offset);
+      }
+      const int8_t* restrict a2 = a[2];
+      if XNN_UNPREDICTABLE(a2 != zero) {
+        a2 = (const int8_t*) ((uintptr_t) a2 + a_offset);
+      }
+      const int8_t* restrict a3 = a[3];
+      if XNN_UNPREDICTABLE(a3 != zero) {
+        a3 = (const int8_t*) ((uintptr_t) a3 + a_offset);
+      }
+      a += 4;
+
+      size_t k = kc;
+      while (k >= 8 * sizeof(int8_t)) {
+        __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4;
+        __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4;
+        __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4;
+        __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4;
+
+        __m512i va0x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4;
+        __m512i va1x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4;
+        __m512i va2x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4;
+        __m512i va3x4567 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4;
+
+        va0x0123 = _mm512_xor_epi32(va0x0123, vsign_mask);
+        va0x4567 = _mm512_xor_epi32(va0x4567, vsign_mask);
+        va1x0123 = _mm512_xor_epi32(va1x0123, vsign_mask);
+        va1x4567 = _mm512_xor_epi32(va1x4567, vsign_mask);
+        va2x0123 = _mm512_xor_epi32(va2x0123, vsign_mask);
+        va2x4567 = _mm512_xor_epi32(va2x4567, vsign_mask);
+        va3x0123 = _mm512_xor_epi32(va3x0123, vsign_mask);
+        va3x4567 = _mm512_xor_epi32(va3x4567, vsign_mask);
+
+        const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w);
+        w = (const int8_t*) w + 64;
+        const __m512i vb0123456789ABCDEFx4567 = _mm512_load_si512(w);
+        w = (const int8_t*) w + 64;
+
+        vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
+        vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123);
+        vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEFx0123);
+        vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEFx0123);
+        vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x4567, vb0123456789ABCDEFx4567);
+        vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x4567, vb0123456789ABCDEFx4567);
+        vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x4567, vb0123456789ABCDEFx4567);
+        vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x4567, vb0123456789ABCDEFx4567);
+
+        k -= 8 * sizeof(int8_t);
+      }
+      if (k != 0) {
+        __m512i va0x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a0)); a0 += 4;
+        __m512i va1x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a1)); a1 += 4;
+        __m512i va2x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a2)); a2 += 4;
+        __m512i va3x0123 = _mm512_set1_epi32((int) unaligned_load_u32(a3)); a3 += 4;
+
+        va0x0123 = _mm512_xor_epi32(va0x0123, vsign_mask);
+        va1x0123 = _mm512_xor_epi32(va1x0123, vsign_mask);
+        va2x0123 = _mm512_xor_epi32(va2x0123, vsign_mask);
+        va3x0123 = _mm512_xor_epi32(va3x0123, vsign_mask);
+
+        const __m512i vb0123456789ABCDEFx0123 = _mm512_load_si512(w);
+        w = (const int8_t*) w + 64;
+
+        vacc0x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc0x0123456789ABCDEF, va0x0123, vb0123456789ABCDEFx0123);
+        vacc1x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc1x0123456789ABCDEF, va1x0123, vb0123456789ABCDEFx0123);
+        vacc2x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc2x0123456789ABCDEF, va2x0123, vb0123456789ABCDEFx0123);
+        vacc3x0123456789ABCDEF = _mm512_dpbusd_epi32(vacc3x0123456789ABCDEF, va3x0123, vb0123456789ABCDEFx0123);
+      }
+      p -= 4 * sizeof(void*);
+    } while (p != 0);
+
+    __m512 vscaled0x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc0x0123456789ABCDEF);
+    __m512 vscaled1x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc1x0123456789ABCDEF);
+    __m512 vscaled2x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc2x0123456789ABCDEF);
+    __m512 vscaled3x0123456789ABCDEF = _mm512_cvtepi32_ps(vacc3x0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_mul_ps(vscaled0x0123456789ABCDEF, vinput_scale);
+    vscaled1x0123456789ABCDEF = _mm512_mul_ps(vscaled1x0123456789ABCDEF, vinput_scale);
+    vscaled2x0123456789ABCDEF = _mm512_mul_ps(vscaled2x0123456789ABCDEF, vinput_scale);
+    vscaled3x0123456789ABCDEF = _mm512_mul_ps(vscaled3x0123456789ABCDEF, vinput_scale);
+
+    const __m512 vfilter_output_scale0123456789ABCDEF = _mm512_load_ps((const float*) w);
+    const __m512 vbias0123456789ABCDEF = _mm512_load_ps((const float*) w + 16);
+    w = (const float*) w + 32;
+
+    vscaled0x0123456789ABCDEF = _mm512_fmadd_ps(vscaled0x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+    vscaled1x0123456789ABCDEF = _mm512_fmadd_ps(vscaled1x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+    vscaled2x0123456789ABCDEF = _mm512_fmadd_ps(vscaled2x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+    vscaled3x0123456789ABCDEF = _mm512_fmadd_ps(vscaled3x0123456789ABCDEF, vfilter_output_scale0123456789ABCDEF, vbias0123456789ABCDEF);
+
+    vscaled0x0123456789ABCDEF = _mm512_max_ps(vscaled0x0123456789ABCDEF, voutput_min);
+    vscaled1x0123456789ABCDEF = _mm512_max_ps(vscaled1x0123456789ABCDEF, voutput_min);
+    vscaled2x0123456789ABCDEF = _mm512_max_ps(vscaled2x0123456789ABCDEF, voutput_min);
+    vscaled3x0123456789ABCDEF = _mm512_max_ps(vscaled3x0123456789ABCDEF, voutput_min);
+
+    vscaled0x0123456789ABCDEF = _mm512_min_ps(vscaled0x0123456789ABCDEF, voutput_max);
+    vscaled1x0123456789ABCDEF = _mm512_min_ps(vscaled1x0123456789ABCDEF, voutput_max);
+    vscaled2x0123456789ABCDEF = _mm512_min_ps(vscaled2x0123456789ABCDEF, voutput_max);
+    vscaled3x0123456789ABCDEF = _mm512_min_ps(vscaled3x0123456789ABCDEF, voutput_max);
+
+    if (nc >= 16) {
+      _mm512_storeu_ps(c3, vscaled3x0123456789ABCDEF);
+      _mm512_storeu_ps(c2, vscaled2x0123456789ABCDEF);
+      _mm512_storeu_ps(c1, vscaled1x0123456789ABCDEF);
+      _mm512_storeu_ps(c0, vscaled0x0123456789ABCDEF);
+
+      a = (const int8_t**restrict) ((uintptr_t) a - ks);
+
+      c0 = (float*) ((uintptr_t) c0 + cn_stride);
+      c1 = (float*) ((uintptr_t) c1 + cn_stride);
+      c2 = (float*) ((uintptr_t) c2 + cn_stride);
+      c3 = (float*) ((uintptr_t) c3 + cn_stride);
+
+      nc -= 16;
+    } else {
+      // Prepare mask for valid 32-bit elements (depends on nc).
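+      // As in the GEMM kernels above, only the low nc mask bits are set, so
+      // each of the four partial stores below writes the first nc floats of its row.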
+      const __mmask16 vmask = _cvtu32_mask16((UINT32_C(1) << nc) - 1);
+      _mm512_mask_storeu_ps(c3, vmask, vscaled3x0123456789ABCDEF);
+      _mm512_mask_storeu_ps(c2, vmask, vscaled2x0123456789ABCDEF);
+      _mm512_mask_storeu_ps(c1, vmask, vscaled1x0123456789ABCDEF);
+      _mm512_mask_storeu_ps(c0, vmask, vscaled0x0123456789ABCDEF);
+      nc = 0;
+    }
+  } while (nc != 0);
+}
+
diff --git a/src/configs/gemm-config.c b/src/configs/gemm-config.c
index a0a078ca19a..c139ef00115 100644
--- a/src/configs/gemm-config.c
+++ b/src/configs/gemm-config.c
@@ -1968,7 +1968,18 @@ static void init_qd8_f32_qc8w_gemm_config(void) {
   #elif XNN_ARCH_X86 || XNN_ARCH_X86_64
     const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
     assert(hardware_config != NULL);
-    if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) {
+    if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512vnni) {
+      qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__avx512vnni);
+      qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__avx512vnni);
+      qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c4__avx512vnni);
+      qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_1x16c4__avx512vnni);
+      qd8_f32_qc8w_gemm_config.init.f32 = xnn_init_f32_minmax_scalar_params;
+      qd8_f32_qc8w_gemm_config.pack_gemm_gio = (xnn_packw_gemm_gio_ukernel_fn) xnn_pack_qs8_gemm_gio_w;
+      qd8_f32_qc8w_gemm_config.pack_gemm_goi = (xnn_packw_gemm_goi_ukernel_fn) xnn_pack_qs8_gemm_goi_w;
+      qd8_f32_qc8w_gemm_config.mr = 4;
+      qd8_f32_qc8w_gemm_config.nr = 16;
+      qd8_f32_qc8w_gemm_config.log2_kr = 2;
+    } else if (!XNN_PLATFORM_MOBILE && hardware_config->use_x86_avx512skx) {
       qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c8__avx512skx);
       qd8_f32_qc8w_gemm_config.minmax.dqgemm[XNN_MR_TO_INDEX(1)] = xnn_init_hmp_dqgemm_ukernel((xnn_dqgemm_ukernel_fn) xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c8__avx512skx);
       qd8_f32_qc8w_gemm_config.minmax.dqigemm[XNN_MR_TO_INDEX(4)] = xnn_init_hmp_dqigemm_ukernel((xnn_dqigemm_ukernel_fn) xnn_qd8_f32_qc8w_igemm_minmax_ukernel_4x16c8__avx512skx);