qc4w avx pack fuse xor with load
- Eliminate the xor instruction in the 4-bit sum by doing it as part of the load

PiperOrigin-RevId: 702891126
fbarchard authored and xnnpack-bot committed Dec 5, 2024
1 parent bb695f2 commit 0953222
Showing 3 changed files with 23 additions and 37 deletions.
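The change folds the kernel-zero-point XOR into each weight load. With optimizations enabled, a compiler can then typically fold the load into the XOR's memory operand (vpxor ymm, ymm, [mem]), so the standalone load instruction disappears instead of a vmovdqu being followed by a vpxor after the unpack/permute stage. A minimal sketch of the two shapes, outside of XNNPACK, with hypothetical helper names and buffers:

#include <immintrin.h>
#include <stdint.h>

// Old shape: plain load now, standalone XOR after the unpack/permute stage.
static __m256i load_then_xor(const uint8_t* w, __m256i vkernel_zero_point) {
  __m256i v = _mm256_loadu_si256((const __m256i*) w);
  // ... 64-bit unpacks and 128-bit permutes happen in between ...
  return _mm256_xor_si256(v, vkernel_zero_point);  // uint4 -> int4
}

// New shape: the XOR consumes the load directly, so the load can become the
// memory operand of vpxor and no separate load instruction is needed.
static __m256i xor_fused_with_load(const uint8_t* w, __m256i vkernel_zero_point) {
  return _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w),
                          vkernel_zero_point);  // uint4 -> int4
}

The fold is legal even though the load is unaligned, because VEX-encoded instructions such as vpxor generally accept unaligned memory operands.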
24 changes: 8 additions & 16 deletions src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avx256vnni.c
@@ -128,14 +128,14 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
size_t k = kc;
// KC main loop multiple of 8x32
for (; k >= 32; k -= 32) {
- const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0);
- const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1);
- const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2);
- const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3);
- const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4);
- const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5);
- const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6);
- const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7);
+ const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4
+ const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4
+ const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4
+ const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4
+ const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4
+ const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4
+ const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4
+ const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4

const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123);
const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123);
@@ -156,7 +156,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
__m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1));
__m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1));

- v0_0 = _mm256_xor_si256(v0_0, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_0 = _mm256_slli_epi32(v0_0, 4); // isolate lower int4
const __m256i vh0_0 = _mm256_and_si256(v0_0, vmask); // isolate upper int4
const __m256i vl0_0 = _mm256_and_si256(vt0_0, vmask);
@@ -170,7 +169,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
const __m256i v23x0_0 = _mm256_or_si256(v2x0_0, v3x0_0);
const __m256i vt010_0 = _mm256_srli_epi32(v01x0_0, 4); // first plane 0-7
v0_0 = _mm256_or_si256(vt010_0, v23x0_0); // + second plane 8-F
- v0_1 = _mm256_xor_si256(v0_1, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_1 = _mm256_slli_epi32(v0_1, 4); // isolate lower int4
const __m256i vh0_1 = _mm256_and_si256(v0_1, vmask); // isolate upper int4
const __m256i vl0_1 = _mm256_and_si256(vt0_1, vmask);
@@ -184,7 +182,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
const __m256i v23x0_1 = _mm256_or_si256(v2x0_1, v3x0_1);
const __m256i vt010_1 = _mm256_srli_epi32(v01x0_1, 4); // first plane 0-7
v0_1 = _mm256_or_si256(vt010_1, v23x0_1); // + second plane 8-F
- v0_2 = _mm256_xor_si256(v0_2, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_2 = _mm256_slli_epi32(v0_2, 4); // isolate lower int4
const __m256i vh0_2 = _mm256_and_si256(v0_2, vmask); // isolate upper int4
const __m256i vl0_2 = _mm256_and_si256(vt0_2, vmask);
@@ -198,7 +195,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
const __m256i v23x0_2 = _mm256_or_si256(v2x0_2, v3x0_2);
const __m256i vt010_2 = _mm256_srli_epi32(v01x0_2, 4); // first plane 0-7
v0_2 = _mm256_or_si256(vt010_2, v23x0_2); // + second plane 8-F
- v0_3 = _mm256_xor_si256(v0_3, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_3 = _mm256_slli_epi32(v0_3, 4); // isolate lower int4
const __m256i vh0_3 = _mm256_and_si256(v0_3, vmask); // isolate upper int4
const __m256i vl0_3 = _mm256_and_si256(vt0_3, vmask);
@@ -212,7 +208,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
const __m256i v23x0_3 = _mm256_or_si256(v2x0_3, v3x0_3);
const __m256i vt010_3 = _mm256_srli_epi32(v01x0_3, 4); // first plane 0-7
v0_3 = _mm256_or_si256(vt010_3, v23x0_3); // + second plane 8-F
- v4_0 = _mm256_xor_si256(v4_0, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_0 = _mm256_slli_epi32(v4_0, 4); // isolate lower int4
const __m256i vh4_0 = _mm256_and_si256(v4_0, vmask); // isolate upper int4
const __m256i vl4_0 = _mm256_and_si256(vt4_0, vmask);
@@ -226,7 +221,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
const __m256i v23x4_0 = _mm256_or_si256(v2x4_0, v3x4_0);
const __m256i vt014_0 = _mm256_srli_epi32(v01x4_0, 4); // first plane 0-7
v4_0 = _mm256_or_si256(vt014_0, v23x4_0); // + second plane 8-F
- v4_1 = _mm256_xor_si256(v4_1, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_1 = _mm256_slli_epi32(v4_1, 4); // isolate lower int4
const __m256i vh4_1 = _mm256_and_si256(v4_1, vmask); // isolate upper int4
const __m256i vl4_1 = _mm256_and_si256(vt4_1, vmask);
@@ -240,7 +234,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
const __m256i v23x4_1 = _mm256_or_si256(v2x4_1, v3x4_1);
const __m256i vt014_1 = _mm256_srli_epi32(v01x4_1, 4); // first plane 0-7
v4_1 = _mm256_or_si256(vt014_1, v23x4_1); // + second plane 8-F
- v4_2 = _mm256_xor_si256(v4_2, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_2 = _mm256_slli_epi32(v4_2, 4); // isolate lower int4
const __m256i vh4_2 = _mm256_and_si256(v4_2, vmask); // isolate upper int4
const __m256i vl4_2 = _mm256_and_si256(vt4_2, vmask);
@@ -254,7 +247,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avx256vnni(
const __m256i v23x4_2 = _mm256_or_si256(v2x4_2, v3x4_2);
const __m256i vt014_2 = _mm256_srli_epi32(v01x4_2, 4); // first plane 0-7
v4_2 = _mm256_or_si256(vt014_2, v23x4_2); // + second plane 8-F
- v4_3 = _mm256_xor_si256(v4_3, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_3 = _mm256_slli_epi32(v4_3, 4); // isolate lower int4
const __m256i vh4_3 = _mm256_and_si256(v4_3, vmask); // isolate upper int4
const __m256i vl4_3 = _mm256_and_si256(vt4_3, vmask);
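The "uint4 -> int4" comments refer to the usual trick for 4-bit weights: assuming the kernel zero point is 8 and is packed into both nibble positions of every byte (0x88), XORing flips the sign bit of each nibble and turns the excess-8 unsigned encoding into two's-complement int4 in place. The actual vkernel_zero_point constant is defined above the hunks shown here, so its exact value is an assumption. A scalar check of the mapping for a single nibble:

#include <assert.h>
#include <stdint.h>

int main(void) {
  for (uint8_t u = 0; u < 16; u++) {
    uint8_t x = u ^ 0x8;                        // flip the nibble's sign bit
    int s = (x >= 8) ? (int) x - 16 : (int) x;  // read the result as a two's-complement int4
    assert(s == (int) u - 8);                   // identical to subtracting the zero point of 8
  }
  return 0;
}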
24 changes: 8 additions & 16 deletions src/qs8-qc4w-packw/gen/qs8-qc4w-packw-x8c8-gemm-goi-avxvnni.c
@@ -128,14 +128,14 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
size_t k = kc;
// KC main loop multiple of 8x32
for (; k >= 32; k -= 32) {
- const __m256i v0_0123 = _mm256_loadu_si256((const __m256i*) w0);
- const __m256i v1_0123 = _mm256_loadu_si256((const __m256i*) w1);
- const __m256i v2_0123 = _mm256_loadu_si256((const __m256i*) w2);
- const __m256i v3_0123 = _mm256_loadu_si256((const __m256i*) w3);
- const __m256i v4_0123 = _mm256_loadu_si256((const __m256i*) w4);
- const __m256i v5_0123 = _mm256_loadu_si256((const __m256i*) w5);
- const __m256i v6_0123 = _mm256_loadu_si256((const __m256i*) w6);
- const __m256i v7_0123 = _mm256_loadu_si256((const __m256i*) w7);
+ const __m256i v0_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w0), vkernel_zero_point); // uint4 -> int4
+ const __m256i v1_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w1), vkernel_zero_point); // uint4 -> int4
+ const __m256i v2_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w2), vkernel_zero_point); // uint4 -> int4
+ const __m256i v3_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w3), vkernel_zero_point); // uint4 -> int4
+ const __m256i v4_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w4), vkernel_zero_point); // uint4 -> int4
+ const __m256i v5_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w5), vkernel_zero_point); // uint4 -> int4
+ const __m256i v6_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w6), vkernel_zero_point); // uint4 -> int4
+ const __m256i v7_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w7), vkernel_zero_point); // uint4 -> int4

const __m256i v01_02 = _mm256_unpacklo_epi64(v0_0123, v1_0123);
const __m256i v01_13 = _mm256_unpackhi_epi64(v0_0123, v1_0123);
@@ -156,7 +156,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
__m256i v4_2 = _mm256_permute2f128_si256(v45_02, v67_02, _MM_SHUFFLE(0, 3, 0, 1));
__m256i v4_3 = _mm256_permute2f128_si256(v45_13, v67_13, _MM_SHUFFLE(0, 3, 0, 1));

- v0_0 = _mm256_xor_si256(v0_0, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_0 = _mm256_slli_epi32(v0_0, 4); // isolate lower int4
const __m256i vh0_0 = _mm256_and_si256(v0_0, vmask); // isolate upper int4
const __m256i vl0_0 = _mm256_and_si256(vt0_0, vmask);
@@ -170,7 +169,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
const __m256i v23x0_0 = _mm256_or_si256(v2x0_0, v3x0_0);
const __m256i vt010_0 = _mm256_srli_epi32(v01x0_0, 4); // first plane 0-7
v0_0 = _mm256_or_si256(vt010_0, v23x0_0); // + second plane 8-F
- v0_1 = _mm256_xor_si256(v0_1, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_1 = _mm256_slli_epi32(v0_1, 4); // isolate lower int4
const __m256i vh0_1 = _mm256_and_si256(v0_1, vmask); // isolate upper int4
const __m256i vl0_1 = _mm256_and_si256(vt0_1, vmask);
@@ -184,7 +182,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
const __m256i v23x0_1 = _mm256_or_si256(v2x0_1, v3x0_1);
const __m256i vt010_1 = _mm256_srli_epi32(v01x0_1, 4); // first plane 0-7
v0_1 = _mm256_or_si256(vt010_1, v23x0_1); // + second plane 8-F
- v0_2 = _mm256_xor_si256(v0_2, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_2 = _mm256_slli_epi32(v0_2, 4); // isolate lower int4
const __m256i vh0_2 = _mm256_and_si256(v0_2, vmask); // isolate upper int4
const __m256i vl0_2 = _mm256_and_si256(vt0_2, vmask);
@@ -198,7 +195,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
const __m256i v23x0_2 = _mm256_or_si256(v2x0_2, v3x0_2);
const __m256i vt010_2 = _mm256_srli_epi32(v01x0_2, 4); // first plane 0-7
v0_2 = _mm256_or_si256(vt010_2, v23x0_2); // + second plane 8-F
- v0_3 = _mm256_xor_si256(v0_3, vkernel_zero_point); // uint4 -> int4
const __m256i vt0_3 = _mm256_slli_epi32(v0_3, 4); // isolate lower int4
const __m256i vh0_3 = _mm256_and_si256(v0_3, vmask); // isolate upper int4
const __m256i vl0_3 = _mm256_and_si256(vt0_3, vmask);
@@ -212,7 +208,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
const __m256i v23x0_3 = _mm256_or_si256(v2x0_3, v3x0_3);
const __m256i vt010_3 = _mm256_srli_epi32(v01x0_3, 4); // first plane 0-7
v0_3 = _mm256_or_si256(vt010_3, v23x0_3); // + second plane 8-F
- v4_0 = _mm256_xor_si256(v4_0, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_0 = _mm256_slli_epi32(v4_0, 4); // isolate lower int4
const __m256i vh4_0 = _mm256_and_si256(v4_0, vmask); // isolate upper int4
const __m256i vl4_0 = _mm256_and_si256(vt4_0, vmask);
@@ -226,7 +221,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
const __m256i v23x4_0 = _mm256_or_si256(v2x4_0, v3x4_0);
const __m256i vt014_0 = _mm256_srli_epi32(v01x4_0, 4); // first plane 0-7
v4_0 = _mm256_or_si256(vt014_0, v23x4_0); // + second plane 8-F
- v4_1 = _mm256_xor_si256(v4_1, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_1 = _mm256_slli_epi32(v4_1, 4); // isolate lower int4
const __m256i vh4_1 = _mm256_and_si256(v4_1, vmask); // isolate upper int4
const __m256i vl4_1 = _mm256_and_si256(vt4_1, vmask);
@@ -240,7 +234,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
const __m256i v23x4_1 = _mm256_or_si256(v2x4_1, v3x4_1);
const __m256i vt014_1 = _mm256_srli_epi32(v01x4_1, 4); // first plane 0-7
v4_1 = _mm256_or_si256(vt014_1, v23x4_1); // + second plane 8-F
- v4_2 = _mm256_xor_si256(v4_2, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_2 = _mm256_slli_epi32(v4_2, 4); // isolate lower int4
const __m256i vh4_2 = _mm256_and_si256(v4_2, vmask); // isolate upper int4
const __m256i vl4_2 = _mm256_and_si256(vt4_2, vmask);
@@ -254,7 +247,6 @@ void xnn_qs8_qc4w_packw_gemm_goi_ukernel_x8c8__avxvnni(
const __m256i v23x4_2 = _mm256_or_si256(v2x4_2, v3x4_2);
const __m256i vt014_2 = _mm256_srli_epi32(v01x4_2, 4); // first plane 0-7
v4_2 = _mm256_or_si256(vt014_2, v23x4_2); // + second plane 8-F
- v4_3 = _mm256_xor_si256(v4_3, vkernel_zero_point); // uint4 -> int4
const __m256i vt4_3 = _mm256_slli_epi32(v4_3, 4); // isolate lower int4
const __m256i vh4_3 = _mm256_and_si256(v4_3, vmask); // isolate upper int4
const __m256i vl4_3 = _mm256_and_si256(vt4_3, vmask);
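Both generated files repeat the same nibble-isolation pattern after the XOR: shift left by 4 so the low nibble of each byte moves into the high position, then mask both copies so one register keeps the original high nibbles and the other holds the low nibbles in the high slot, ready for the srli/or steps that recombine the two 4-bit planes. A scalar, single-byte view of that step, assuming vmask holds 0xF0 in every byte (which is what the "isolate upper/lower int4" comments imply; the constant itself is defined outside these hunks):

#include <stdint.h>
#include <stdio.h>

int main(void) {
  uint8_t v  = 0xA5;               // hypothetical packed byte: high nibble 0xA, low nibble 0x5
  uint8_t vt = (uint8_t) (v << 4); // low nibble moved into the high position (vt0_0 in the kernel)
  uint8_t vh = v  & 0xF0;          // high nibble, kept in place              -> 0xA0
  uint8_t vl = vt & 0xF0;          // low nibble, now in the high slot        -> 0x50
  printf("vh=0x%02X vl=0x%02X\n", vh, vl);
  return 0;
}

The kernels do the shift with _mm256_slli_epi32, so bits also cross byte boundaries within each 32-bit lane, but the 0xF0 mask discards those bits and the per-byte picture above is what survives.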
12 changes: 7 additions & 5 deletions src/x8-packw/kr-avxvnni.c.in
@@ -50,9 +50,9 @@ $if DATATYPE in ["QS8", "QS4"]:
$ISA = "avx2" if VARIANT == "MADD" else "avxvnni" if AVX == 2 else "avx256vnni"
$else:
$ISA = "avx2" if AVX == 2 else "avx256skx"
- $NAMETYPE = "qs8_to_qu8" if IZP == 128 else {"QS8": "qs8", "QS4": "qs8_qc4w", "X8": "x8"}[DATATYPE]
+ $DATATYPE_SPEC = "qs8_to_qu8" if IZP == 128 else {"QS8": "qs8", "QS4": "qs8_qc4w", "X8": "x8"}[DATATYPE]

- void xnn_${NAMETYPE}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VARIANT == "MADD" else ""}${"_prfm" if PREFETCH else ""}(
+ void xnn_${DATATYPE_SPEC}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VARIANT == "MADD" else ""}${"_prfm" if PREFETCH else ""}(
size_t g,
size_t nc,
size_t kc,
@@ -79,7 +79,7 @@ void xnn_${NAMETYPE}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VAR
$if DATATYPE == "QS4":
// Use scalar pack if not an even block size
if (kc & 1) {
- xnn_${NAMETYPE}_packw_gemm_goi_ukernel_x${NR}c${KR}__scalar(
+ xnn_${DATATYPE_SPEC}_packw_gemm_goi_ukernel_x${NR}c${KR}__scalar(
g, nc, kc, nr, kr, sr,
weights, bias, scale, packed_weights, extra_bytes, params);
return;
@@ -162,7 +162,10 @@ void xnn_${NAMETYPE}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VAR
// KC main loop multiple of ${NR}x${4 * KR}
for (; k >= ${4 * KR}; k -= ${4 * KR}) {
$for N in range(NR):
- const __m256i v${N}_0123 = _mm256_loadu_si256((const __m256i*) w${N});
+ $if DATATYPE in ["QS4"]:
+   const __m256i v${N}_0123 = _mm256_xor_si256(_mm256_loadu_si256((const __m256i*) w${N}), vkernel_zero_point); // uint4 -> int4
+ $else:
+   const __m256i v${N}_0123 = _mm256_loadu_si256((const __m256i*) w${N});

$for N in range(0, NR, 2):
const __m256i v${N}${N+1}_02 = _mm256_unpacklo_epi64(v${N}_0123, v${N+1}_0123);
@@ -183,7 +186,6 @@ void xnn_${NAMETYPE}_packw_gemm_goi_ukernel_x${NR}c${KR}__${ISA}${"_madd" if VAR
$elif DATATYPE in ["QS4"]:
$for N in range(0, NR, 4):
$for I in range(0, 4):
- v${N}_${I} = _mm256_xor_si256(v${N}_${I}, vkernel_zero_point); // uint4 -> int4
const __m256i vt${N}_${I} = _mm256_slli_epi32(v${N}_${I}, 4); // isolate lower int4
const __m256i vh${N}_${I} = _mm256_and_si256(v${N}_${I}, vmask); // isolate upper int4
const __m256i vl${N}_${I} = _mm256_and_si256(vt${N}_${I}, vmask);
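Moving the XOR ahead of the unpack/permute stage (both in the template and in the generated files) is only valid because vkernel_zero_point broadcasts the same byte into every lane, so the byte-wise XOR commutes with the 64-bit unpacks and 128-bit permutes it used to follow. A quick standalone check of that equivalence with an arbitrary buffer and a hypothetical 0x88 zero point (an illustration, not an XNNPACK test; build with -mavx2):

#include <assert.h>
#include <immintrin.h>
#include <stdint.h>
#include <string.h>

int main(void) {
  uint8_t w0[32], w1[32];
  for (int i = 0; i < 32; i++) {
    w0[i] = (uint8_t) (i * 37);      // arbitrary packed nibbles
    w1[i] = (uint8_t) (i * 59 + 3);
  }

  const __m256i vzp = _mm256_set1_epi8((char) 0x88);  // hypothetical zero point, uniform in every byte
  const __m256i a = _mm256_loadu_si256((const __m256i*) w0);
  const __m256i b = _mm256_loadu_si256((const __m256i*) w1);

  // Old order: unpack first, XOR afterwards.
  const __m256i old_way = _mm256_xor_si256(_mm256_unpacklo_epi64(a, b), vzp);
  // New order: XOR at load time, unpack afterwards.
  const __m256i new_way = _mm256_unpacklo_epi64(_mm256_xor_si256(a, vzp),
                                                _mm256_xor_si256(b, vzp));

  uint8_t r_old[32], r_new[32];
  _mm256_storeu_si256((__m256i*) r_old, old_way);
  _mm256_storeu_si256((__m256i*) r_new, new_way);
  assert(memcmp(r_old, r_new, 32) == 0);  // identical because vzp is the same byte everywhere
  return 0;
}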
