fix c8 packw overread and unexpected load rescheduling for revec
yolanda15 committed Dec 3, 2024
1 parent 51a0103 commit 1a9d2fb
Showing 3 changed files with 108 additions and 102 deletions.
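
What changed: in each KC-remainder path, the old code loaded a full 8 bytes per row with `wasm_v128_load64_splat` and then masked off everything beyond `k`, so with only `k` (1..7) bytes of a row left the load itself could read past the end of the weights buffer; the new `safe_v128_load64_splat` helper reads exactly `k` bytes and assembles them into the 64-bit lane, which also makes `vmask` unnecessary. In addition, the loads of `packed_b` are hoisted above the remainder block so they are not rescheduled unexpectedly by the revec pass named in the commit title. The standalone scalar sketch below is illustrative only — `old_lane_value` and `new_lane_value` are hypothetical names and it models just one 64-bit lane, not the SIMD kernel — but it shows that the two approaches agree on the low `k` bytes while only the old one has to touch bytes past `w[k-1]`.

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Old behaviour (modelled): fetch a full 8 bytes starting at w, then mask off
// the bytes beyond k.  In the kernel this 8-byte fetch is the overread when a
// row has fewer than 8 bytes left (here the demo buffer is 8 bytes, so the
// read stays in bounds).
uint64_t old_lane_value(const uint8_t* w, size_t k) {
  uint64_t value;
  memcpy(&value, w, sizeof(uint64_t));
  const uint64_t mask = ~UINT64_C(0) >> ((8 - k) * 8);
  return value & mask;
}

// New behaviour (mirrors safe_v128_load64_splat): read exactly k bytes and
// assemble them little-endian, never touching w[k] or beyond.
uint64_t new_lane_value(const uint8_t* w, size_t k) {
  assert(k >= 1 && k <= sizeof(uint64_t));
  uint64_t value = (uint64_t) w[0];
  for (size_t i = 1; i < k; ++i) {
    value |= (uint64_t) w[i] << (i * 8);
  }
  return value;
}

int main(void) {
  const uint8_t row[8] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88};
  for (size_t k = 1; k <= 7; ++k) {
    printf("k=%zu old=%016llx new=%016llx\n", k,
           (unsigned long long) old_lane_value(row, k),
           (unsigned long long) new_lane_value(row, k));
  }
  return 0;
}
```

Both columns print the same value for every `k`; the difference is only that the masked variant has to fetch `w[k..7]` before discarding them, which is exactly the overread this commit removes.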
84 changes: 42 additions & 42 deletions src/qs8-packw/gen/qs8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c
@@ -14,6 +14,16 @@

#include "xnnpack/packw.h"

XNN_INLINE static v128_t safe_v128_load64_splat(const void* address, size_t n) {
assert(n >= 1 && n <= sizeof(uint64_t));
const uint8_t* bytes = (const uint8_t*) address;
uint64_t value = (uint64_t) bytes[0];
for (size_t i = 1; i < n; ++i) {
value |= (uint64_t) bytes[i] << (i * 8);
}

return wasm_u64x2_splat(value);
}

void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
size_t g,
@@ -164,28 +174,26 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
out += 64;
}

// Load earlier to avoid unexpected rescheduling.
v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

// KC remainder 1..KR-1
if (k != 0) {
assert(k >= 1 && k <= 7);

const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

const v128_t v0 = wasm_v128_load64_splat(w0);
const v128_t v1 = wasm_v128_load64_splat(w1);
v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
v01 = wasm_v128_and(v01, vmask);
const v128_t v2 = wasm_v128_load64_splat(w2);
const v128_t v3 = wasm_v128_load64_splat(w3);
v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
v23 = wasm_v128_and(v23, vmask);
const v128_t v4 = wasm_v128_load64_splat(w4);
const v128_t v5 = wasm_v128_load64_splat(w5);
v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
v45 = wasm_v128_and(v45, vmask);
const v128_t v6 = wasm_v128_load64_splat(w6);
const v128_t v7 = wasm_v128_load64_splat(w7);
v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);
v67 = wasm_v128_and(v67, vmask);
const v128_t v0 = safe_v128_load64_splat(w0, k);
const v128_t v1 = safe_v128_load64_splat(w1, k);
const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
const v128_t v2 = safe_v128_load64_splat(w2, k);
const v128_t v3 = safe_v128_load64_splat(w3, k);
const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
const v128_t v4 = safe_v128_load64_splat(w4, k);
const v128_t v5 = safe_v128_load64_splat(w5, k);
const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
const v128_t v6 = safe_v128_load64_splat(w6, k);
const v128_t v7 = safe_v128_load64_splat(w7, k);
const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);

vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
@@ -214,9 +222,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

@@ -315,28 +320,26 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
out += 64;
}

// Load earlier to avoid unexpected rescheduling.
v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

// KC remainder of 1..7
if (k != 0) {
assert(k >= 1 && k <= 7);

const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

const v128_t v0 = wasm_v128_load64_splat(w0);
const v128_t v1 = wasm_v128_load64_splat(w1);
v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
v01 = wasm_v128_and(v01, vmask);
const v128_t v2 = wasm_v128_load64_splat(w2);
const v128_t v3 = wasm_v128_load64_splat(w3);
v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
v23 = wasm_v128_and(v23, vmask);
const v128_t v4 = wasm_v128_load64_splat(w4);
const v128_t v5 = wasm_v128_load64_splat(w5);
v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
v45 = wasm_v128_and(v45, vmask);
const v128_t v6 = wasm_v128_load64_splat(w6);
const v128_t v7 = wasm_v128_load64_splat(w7);
v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);
v67 = wasm_v128_and(v67, vmask);
const v128_t v0 = safe_v128_load64_splat(w0, k);
const v128_t v1 = safe_v128_load64_splat(w1, k);
const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
const v128_t v2 = safe_v128_load64_splat(w2, k);
const v128_t v3 = safe_v128_load64_splat(w3, k);
const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
const v128_t v4 = safe_v128_load64_splat(w4, k);
const v128_t v5 = safe_v128_load64_splat(w5, k);
const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
const v128_t v6 = safe_v128_load64_splat(w6, k);
const v128_t v7 = safe_v128_load64_splat(w7, k);
const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);

vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
@@ -357,9 +360,6 @@ void xnn_qs8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

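
A note on the second change in this file (mirrored in the two files below): the `vpack0123`/`vpack4567` loads of `packed_b` used to sit between the KC-remainder block and the subtracting stores, and are now issued before the remainder block, as the new comment says, so a later pass has no room to move them somewhere unexpected. The sketch below only illustrates the reordering; `subtract_ksums_before`/`subtract_ksums_after` are made-up names and the remainder body is elided.

```c
#include <stddef.h>
#include <stdint.h>
#include <wasm_simd128.h>

// Old placement: packed_b is loaded after the conditional remainder block,
// immediately before the stores.
void subtract_ksums_before(int32_t* packed_b, size_t k,
                           v128_t vksum0123, v128_t vksum4567) {
  if (k != 0) {
    /* ... fold the last 1..7 bytes of each row into the running sums ... */
  }
  v128_t vpack0123 = wasm_v128_load(packed_b);
  v128_t vpack4567 = wasm_v128_load(packed_b + 4);
  wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
  wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));
}

// New placement (this commit): packed_b is loaded before the remainder block,
// pinning the loads to a fixed spot ahead of revectorization.
void subtract_ksums_after(int32_t* packed_b, size_t k,
                          v128_t vksum0123, v128_t vksum4567) {
  v128_t vpack0123 = wasm_v128_load(packed_b);
  v128_t vpack4567 = wasm_v128_load(packed_b + 4);
  if (k != 0) {
    /* ... fold the last 1..7 bytes of each row into the running sums ... */
  }
  wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
  wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));
}
```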
84 changes: 42 additions & 42 deletions src/qs8-qu8-packw/gen/qs8-qu8-packw-x8c8-gemm-goi-wasmrelaxedsimd.c
@@ -14,6 +14,16 @@

#include "xnnpack/packw.h"

XNN_INLINE static v128_t safe_v128_load64_splat(const void* address, size_t n) {
assert(n >= 1 && n <= sizeof(uint64_t));
const uint8_t* bytes = (const uint8_t*) address;
uint64_t value = (uint64_t) bytes[0];
for (size_t i = 1; i < n; ++i) {
value |= (uint64_t) bytes[i] << (i * 8);
}

return wasm_u64x2_splat(value);
}

void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
size_t g,
@@ -164,28 +174,26 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
out += 64;
}

// Load earlier to avoid unexpected rescheduling.
v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

// KC remainder 1..KR-1
if (k != 0) {
assert(k >= 1 && k <= 7);

const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

const v128_t v0 = wasm_v128_load64_splat(w0);
const v128_t v1 = wasm_v128_load64_splat(w1);
v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
v01 = wasm_v128_and(v01, vmask);
const v128_t v2 = wasm_v128_load64_splat(w2);
const v128_t v3 = wasm_v128_load64_splat(w3);
v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
v23 = wasm_v128_and(v23, vmask);
const v128_t v4 = wasm_v128_load64_splat(w4);
const v128_t v5 = wasm_v128_load64_splat(w5);
v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
v45 = wasm_v128_and(v45, vmask);
const v128_t v6 = wasm_v128_load64_splat(w6);
const v128_t v7 = wasm_v128_load64_splat(w7);
v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);
v67 = wasm_v128_and(v67, vmask);
const v128_t v0 = safe_v128_load64_splat(w0, k);
const v128_t v1 = safe_v128_load64_splat(w1, k);
const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
const v128_t v2 = safe_v128_load64_splat(w2, k);
const v128_t v3 = safe_v128_load64_splat(w3, k);
const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
const v128_t v4 = safe_v128_load64_splat(w4, k);
const v128_t v5 = safe_v128_load64_splat(w5, k);
const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
const v128_t v6 = safe_v128_load64_splat(w6, k);
const v128_t v7 = safe_v128_load64_splat(w7, k);
const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);

vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
@@ -214,9 +222,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

@@ -315,28 +320,26 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
out += 64;
}

// Load earlier to avoid unexpected rescheduling.
v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

// KC remainder of 1..7
if (k != 0) {
assert(k >= 1 && k <= 7);

const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (8 - k) * sizeof(int8_t) * 8);

const v128_t v0 = wasm_v128_load64_splat(w0);
const v128_t v1 = wasm_v128_load64_splat(w1);
v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
v01 = wasm_v128_and(v01, vmask);
const v128_t v2 = wasm_v128_load64_splat(w2);
const v128_t v3 = wasm_v128_load64_splat(w3);
v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
v23 = wasm_v128_and(v23, vmask);
const v128_t v4 = wasm_v128_load64_splat(w4);
const v128_t v5 = wasm_v128_load64_splat(w5);
v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
v45 = wasm_v128_and(v45, vmask);
const v128_t v6 = wasm_v128_load64_splat(w6);
const v128_t v7 = wasm_v128_load64_splat(w7);
v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);
v67 = wasm_v128_and(v67, vmask);
const v128_t v0 = safe_v128_load64_splat(w0, k);
const v128_t v1 = safe_v128_load64_splat(w1, k);
const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
const v128_t v2 = safe_v128_load64_splat(w2, k);
const v128_t v3 = safe_v128_load64_splat(w3, k);
const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
const v128_t v4 = safe_v128_load64_splat(w4, k);
const v128_t v5 = safe_v128_load64_splat(w5, k);
const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
const v128_t v6 = safe_v128_load64_splat(w6, k);
const v128_t v7 = safe_v128_load64_splat(w7, k);
const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);

vacc01 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v01, vone, vacc01);
vacc23 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v23, vone, vacc23);
@@ -357,9 +360,6 @@ void xnn_qs8_to_qu8_packw_gemm_goi_ukernel_x8c8__wasmrelaxedsimd(
vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

42 changes: 24 additions & 18 deletions src/x8-packw/kr-wasmdot.c.in
@@ -14,6 +14,16 @@ $assert IZP in [0, 128]

#include "xnnpack/packw.h"

XNN_INLINE static v128_t safe_v128_load64_splat(const void* address, size_t n) {
assert(n >= 1 && n <= sizeof(uint64_t));
const uint8_t* bytes = (const uint8_t*) address;
uint64_t value = (uint64_t) bytes[0];
for (size_t i = 1; i < n; ++i) {
value |= (uint64_t) bytes[i] << (i * 8);
}

return wasm_u64x2_splat(value);
}

$ABC = "012345678"
$BTYPE = {"int8_t": "uint32_t"}[TYPE]
@@ -115,17 +125,18 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K
out += ${NR*KR};
}

// Load earlier to avoid unexpected rescheduling.
v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

// KC remainder 1..KR-1
if (k != 0) {
assert(k >= 1 && k <= ${KR-1});

const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8);

$for N in range(0, NR, 2):
const v128_t v${N} = wasm_v128_load64_splat(w${N});
const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1});
v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);
v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask);
const v128_t v${N} = safe_v128_load64_splat(w${N}, k);
const v128_t v${N+1} = safe_v128_load64_splat(w${N+1}, k);
const v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);

$for N in range(0, NR, 2):
vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});
@@ -144,9 +155,6 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K
vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

@@ -207,17 +215,18 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K
out += ${NR*KR};
}

// Load earlier to avoid unexpected rescheduling.
v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

// KC remainder of 1..${KR-1}
if (k != 0) {
assert(k >= 1 && k <= ${KR-1});

const v128_t vmask = wasm_u64x2_shr(wasm_i32x4_splat(-1), (${KR} - k) * sizeof(${WTYPE}) * 8);

$for N in range(0, NR, 2):
const v128_t v${N} = wasm_v128_load64_splat(w${N});
const v128_t v${N+1} = wasm_v128_load64_splat(w${N+1});
v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);
v${ABC[N:N+2]} = wasm_v128_and(v${ABC[N:N+2]}, vmask);
const v128_t v${N} = safe_v128_load64_splat(w${N}, k);
const v128_t v${N+1} = safe_v128_load64_splat(w${N+1}, k);
const v128_t v${ABC[N:N+2]} = wasm_i64x2_shuffle(v${N}, v${N+1}, 0, 3);

$for N in range(0, NR, 2):
vacc${ABC[N:N+2]} = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(v${ABC[N:N+2]}, vone, vacc${ABC[N:N+2]});
@@ -234,9 +243,6 @@ void xnn_qs8${"_to_qu8" if IZP == 128 else ""}_packw_gemm_goi_ukernel_x${NR}c${K
vksum0123 = wasm_i32x4_mul(vksum0123, vzeropoint);
vksum4567 = wasm_i32x4_mul(vksum4567, vzeropoint);

v128_t vpack0123 = wasm_v128_load(packed_b);
v128_t vpack4567 = wasm_v128_load(packed_b + 4);

wasm_v128_store(packed_b, wasm_i32x4_sub(vpack0123, vksum0123));
wasm_v128_store(packed_b + 4, wasm_i32x4_sub(vpack4567, vksum4567));

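
This template is what the two generated kernels above are produced from; XNNPACK's code generator expands the `$for` loops and `${...}` substitutions. For the x8c8 case shown above (NR=8, KR=8, ABC="012345678"), the remainder-path `$for N in range(0, NR, 2)` block expands to the row-pair loads that appear verbatim in both generated files:

```c
const v128_t v0 = safe_v128_load64_splat(w0, k);
const v128_t v1 = safe_v128_load64_splat(w1, k);
const v128_t v01 = wasm_i64x2_shuffle(v0, v1, 0, 3);
const v128_t v2 = safe_v128_load64_splat(w2, k);
const v128_t v3 = safe_v128_load64_splat(w3, k);
const v128_t v23 = wasm_i64x2_shuffle(v2, v3, 0, 3);
const v128_t v4 = safe_v128_load64_splat(w4, k);
const v128_t v5 = safe_v128_load64_splat(w5, k);
const v128_t v45 = wasm_i64x2_shuffle(v4, v5, 0, 3);
const v128_t v6 = safe_v128_load64_splat(w6, k);
const v128_t v7 = safe_v128_load64_splat(w7, k);
const v128_t v67 = wasm_i64x2_shuffle(v6, v7, 0, 3);
```

Each `wasm_i64x2_shuffle(vN, vN+1, 0, 3)` keeps lane 0 of the first splat and lane 1 of the second, packing two rows' remainder bytes into one 128-bit vector before the relaxed dot-product accumulation.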
