QD8-f32-qc4w SSE/AVX microkernels use while loop counting down
- Preparation for unrolled loop for QC4W

PiperOrigin-RevId: 574548736
fbarchard authored and xnnpack-bot committed Oct 18, 2023
1 parent f9bd912 commit b5b3a4f
Showing 202 changed files with 789 additions and 807 deletions.
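Every hunk below makes the same two-line change: the K-dimension counter now starts at kc and counts down, and the loop guard tests whether a full 8-byte block remains rather than comparing against kc. The two forms iterate the same number of times because these c8 kernels take kc as a multiple of 8 (there is no tail path after the loop); the payoff is that the guard compares against an immediate instead of the loop-invariant kc, and on exit k holds the bytes of K remaining, which the planned unrolled QC4W loop can dispatch on. In the QU8 kernels the declaration of k also moves below the zero-point constants so it sits directly before the loop. A minimal standalone sketch of the two loop shapes, with hypothetical names (block_loop_up, block_loop_down) and a scalar sum standing in for the SIMD block:

#include <stddef.h>
#include <stdint.h>

/* Old shape: count up from 0 toward kc. */
static int32_t block_loop_up(const int8_t* a, size_t kc) {
  int32_t acc = 0;
  size_t k = 0;
  while (k < kc) {
    for (size_t i = 0; i < 8; i++) acc += a[i];  /* stand-in for the SIMD block */
    a += 8;
    k += 8 * sizeof(int8_t);
  }
  return acc;
}

/* New shape: count down from kc; the guard asks "is a full block left?". */
static int32_t block_loop_down(const int8_t* a, size_t kc) {
  int32_t acc = 0;
  size_t k = kc;
  while (k >= 8 * sizeof(int8_t)) {
    for (size_t i = 0; i < 8; i++) acc += a[i];  /* stand-in for the SIMD block */
    a += 8;
    k -= 8 * sizeof(int8_t);
  }
  /* On exit k holds the remainder (k < 8 here), ready for a tail path. */
  return acc;
}

For kc a multiple of 8 the two functions return the same value; they differ only in what is left in k when the loop ends.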
60 changes: 30 additions & 30 deletions src/amalgam/gen/avx.c
@@ -6402,8 +6402,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__avx_ld128(
__m128i vacc0x3 = _mm_blend_epi16(vinit0, vzero, 0x3F);
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -6430,7 +6430,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_1x4c8__avx_ld128(
vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));

w = (const int8_t*) w + 16;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -6554,8 +6554,8 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__avx_ld128(
__m128i vacc3x3 = _mm_blend_epi16(vinit3, vzero, 0x3F);
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -6603,7 +6603,7 @@ void xnn_qd8_f32_qc4w_gemm_minmax_ukernel_4x4c8__avx_ld128(
vacc3x3 = _mm_add_epi32(vacc3x3, _mm_madd_epi16(vxa3, vxb3));

w = (const int8_t*) w + 16;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -6744,8 +6744,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__avx_ld128(
__m128i vacc0x3 = _mm_blend_epi16(vinit0, vzero, 0x3F);
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -6766,7 +6766,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x4c8__avx_ld128(
vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));

w = (const int8_t*) w + 32;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -6864,8 +6864,8 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__avx_ld128(
__m128i vacc1x3 = _mm_blend_epi16(vinit1, vzero, 0x3F);
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -6893,7 +6893,7 @@ void xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x4c8__avx_ld128(
vacc1x3 = _mm_add_epi32(vacc1x3, _mm_madd_epi16(vxa1, vxb3));

w = (const int8_t*) w + 32;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -8832,8 +8832,8 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128(
__m128i vacc0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -8854,7 +8854,7 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128(
vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));

w = (const int8_t*) w + 32;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -8945,8 +8945,8 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128(
__m128i vacc1x3 = vacc0x3;
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -8974,7 +8974,7 @@ void xnn_qs8_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128(
vacc1x3 = _mm_add_epi32(vacc1x3, _mm_madd_epi16(vxa1, vxb3));

w = (const int8_t*) w + 32;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -10945,8 +10945,8 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128(
__m128i vacc0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -10967,7 +10967,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128(
vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));

w = (const int8_t*) w + 32;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -11059,8 +11059,8 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128(
__m128i vacc1x3 = vacc0x3;
w = (const int32_t*) w + 4;

- size_t k = 0;
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(int8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepi8_epi16(va0);
a0 += 8;
@@ -11088,7 +11088,7 @@ void xnn_qs8_qc8w_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128(
vacc1x3 = _mm_add_epi32(vacc1x3, _mm_madd_epi16(vxa1, vxb3));

w = (const int8_t*) w + 32;
- k += 8 * sizeof(int8_t);
+ k -= 8 * sizeof(int8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -13706,10 +13706,10 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128(
__m128i vacc0x3 = _mm_cvtsi32_si128(((const int*) w)[3]);
w = (const int32_t*) w + 4;

- size_t k = 0;
const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->fp32_sse2.kernel_zero_point);
const __m128i vzero = _mm_setzero_si128();
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(uint8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepu8_epi16(va0);
a0 += 8;
@@ -13728,7 +13728,7 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_1x4c8__avx_ld128(
vacc0x3 = _mm_add_epi32(vacc0x3, _mm_madd_epi16(vxa0, vxb3));

w = (const uint8_t*) w + 32;
- k += 8 * sizeof(uint8_t);
+ k -= 8 * sizeof(uint8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
@@ -13818,10 +13818,10 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128(
__m128i vacc1x3 = vacc0x3;
w = (const int32_t*) w + 4;

- size_t k = 0;
const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->fp32_sse2.kernel_zero_point);
const __m128i vzero = _mm_setzero_si128();
- while (k < kc) {
+ size_t k = kc;
+ while (k >= 8 * sizeof(uint8_t)) {
const __m128i va0 = _mm_loadl_epi64((const __m128i*) a0);
const __m128i vxa0 = _mm_cvtepu8_epi16(va0);
a0 += 8;
@@ -13847,7 +13847,7 @@ void xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__avx_ld128(
vacc1x3 = _mm_add_epi32(vacc1x3, _mm_madd_epi16(vxa1, vxb3));

w = (const uint8_t*) w + 32;
- k += 8 * sizeof(uint8_t);
+ k -= 8 * sizeof(uint8_t);
}

const __m128i vacc0x01 = _mm_hadd_epi32(vacc0x0, vacc0x1);
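Per the commit message, this counting-down form is preparation for an unrolled QC4W loop: once k means "bytes of K remaining", an unrolled variant can peel bigger blocks first and fall through to a single-block tail. A hypothetical sketch of that follow-up shape (an assumption about the stated plan, not code from this commit), continuing the stand-in style above:

/* Hypothetical 2x-unrolled shape enabled by the count-down counter. */
static int32_t block_loop_unrolled(const int8_t* a, size_t kc) {
  int32_t acc = 0;
  size_t k = kc;
  while (k >= 16 * sizeof(int8_t)) {   /* main loop: two 8-byte blocks per trip */
    for (size_t i = 0; i < 16; i++) acc += a[i];
    a += 16;
    k -= 16 * sizeof(int8_t);
  }
  if (k >= 8 * sizeof(int8_t)) {       /* at most one full block remains */
    for (size_t i = 0; i < 8; i++) acc += a[i];
    a += 8;
    k -= 8 * sizeof(int8_t);
  }
  return acc;
}

With the count-up form the same decisions would need comparisons like k + 16 <= kc; counting down keeps every guard a compare against an immediate.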