Skip to content

Commit

Permalink
Fix bug in AVX512 f16-gemm
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 708328572
  • Loading branch information
dsharletg authored and xnnpack-bot committed Dec 20, 2024
1 parent 08185b7 commit 7bdbcd6
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 33 deletions.
4 changes: 2 additions & 2 deletions src/f16-gemm/avx512fp16-broadcast.c.in
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${NR}__avx512fp16_broadcast(
$for M in RANGE_MR:
_mm512_storeu_ph(c${M}, vacc${M}x0);
$for N in range(32, NR, 32):
_mm512_storeu_ph((uint16_t*) c${M} + ${N//32}, vacc${M}x${N//32});
_mm512_storeu_ph((uint16_t*) c${M} + ${N}, vacc${M}x${N//32});
a${M} = (const uint16_t*) ((uintptr_t) a${M} - kc);
c${M} = (uint16_t*) ((uintptr_t) c${M} + cn_stride);

Expand All @@ -113,7 +113,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${NR}__avx512fp16_broadcast(
$for M in RANGE_MR:
_mm512_storeu_ph(c${M}, vacc${M}x0);
$for N in range(32, 1 << LOG2N, 32):
_mm512_storeu_ph((uint16_t*) c${M} + ${N//32}, vacc${M}x${N//32});
_mm512_storeu_ph((uint16_t*) c${M} + ${N}, vacc${M}x${N//32});

$for M in RANGE_MR:
$for N in range(0, NR - (1 << LOG2N), 32):
Expand Down
2 changes: 1 addition & 1 deletion src/f16-gemm/gen/f16-gemm-1x64-minmax-avx512fp16-broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void xnn_f16_gemm_minmax_ukernel_1x64__avx512fp16_broadcast(

if XNN_LIKELY(nc >= 64) {
_mm512_storeu_ph(c0, vacc0x0);
_mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1);
_mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1);
a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);

Expand Down
8 changes: 4 additions & 4 deletions src/f16-gemm/gen/f16-gemm-4x64-minmax-avx512fp16-broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,19 @@ void xnn_f16_gemm_minmax_ukernel_4x64__avx512fp16_broadcast(

if XNN_LIKELY(nc >= 64) {
_mm512_storeu_ph(c0, vacc0x0);
_mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1);
_mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1);
a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
_mm512_storeu_ph(c1, vacc1x0);
_mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1);
_mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1);
a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
_mm512_storeu_ph(c2, vacc2x0);
_mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1);
_mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1);
a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
_mm512_storeu_ph(c3, vacc3x0);
_mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1);
_mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1);
a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);

Expand Down
10 changes: 5 additions & 5 deletions src/f16-gemm/gen/f16-gemm-5x64-minmax-avx512fp16-broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,23 +133,23 @@ void xnn_f16_gemm_minmax_ukernel_5x64__avx512fp16_broadcast(

if XNN_LIKELY(nc >= 64) {
_mm512_storeu_ph(c0, vacc0x0);
_mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1);
_mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1);
a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
_mm512_storeu_ph(c1, vacc1x0);
_mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1);
_mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1);
a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
_mm512_storeu_ph(c2, vacc2x0);
_mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1);
_mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1);
a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
_mm512_storeu_ph(c3, vacc3x0);
_mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1);
_mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1);
a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
_mm512_storeu_ph(c4, vacc4x0);
_mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1);
_mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1);
a4 = (const uint16_t*) ((uintptr_t) a4 - kc);
c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride);

Expand Down
12 changes: 6 additions & 6 deletions src/f16-gemm/gen/f16-gemm-6x64-minmax-avx512fp16-broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -149,27 +149,27 @@ void xnn_f16_gemm_minmax_ukernel_6x64__avx512fp16_broadcast(

if XNN_LIKELY(nc >= 64) {
_mm512_storeu_ph(c0, vacc0x0);
_mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1);
_mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1);
a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
_mm512_storeu_ph(c1, vacc1x0);
_mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1);
_mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1);
a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
_mm512_storeu_ph(c2, vacc2x0);
_mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1);
_mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1);
a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
_mm512_storeu_ph(c3, vacc3x0);
_mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1);
_mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1);
a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
_mm512_storeu_ph(c4, vacc4x0);
_mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1);
_mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1);
a4 = (const uint16_t*) ((uintptr_t) a4 - kc);
c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride);
_mm512_storeu_ph(c5, vacc5x0);
_mm512_storeu_ph((uint16_t*) c5 + 1, vacc5x1);
_mm512_storeu_ph((uint16_t*) c5 + 32, vacc5x1);
a5 = (const uint16_t*) ((uintptr_t) a5 - kc);
c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride);

Expand Down
14 changes: 7 additions & 7 deletions src/f16-gemm/gen/f16-gemm-7x64-minmax-avx512fp16-broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -165,31 +165,31 @@ void xnn_f16_gemm_minmax_ukernel_7x64__avx512fp16_broadcast(

if XNN_LIKELY(nc >= 64) {
_mm512_storeu_ph(c0, vacc0x0);
_mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1);
_mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1);
a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
_mm512_storeu_ph(c1, vacc1x0);
_mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1);
_mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1);
a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
_mm512_storeu_ph(c2, vacc2x0);
_mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1);
_mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1);
a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
_mm512_storeu_ph(c3, vacc3x0);
_mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1);
_mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1);
a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
_mm512_storeu_ph(c4, vacc4x0);
_mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1);
_mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1);
a4 = (const uint16_t*) ((uintptr_t) a4 - kc);
c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride);
_mm512_storeu_ph(c5, vacc5x0);
_mm512_storeu_ph((uint16_t*) c5 + 1, vacc5x1);
_mm512_storeu_ph((uint16_t*) c5 + 32, vacc5x1);
a5 = (const uint16_t*) ((uintptr_t) a5 - kc);
c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride);
_mm512_storeu_ph(c6, vacc6x0);
_mm512_storeu_ph((uint16_t*) c6 + 1, vacc6x1);
_mm512_storeu_ph((uint16_t*) c6 + 32, vacc6x1);
a6 = (const uint16_t*) ((uintptr_t) a6 - kc);
c6 = (uint16_t*) ((uintptr_t) c6 + cn_stride);

Expand Down
16 changes: 8 additions & 8 deletions src/f16-gemm/gen/f16-gemm-8x64-minmax-avx512fp16-broadcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -181,35 +181,35 @@ void xnn_f16_gemm_minmax_ukernel_8x64__avx512fp16_broadcast(

if XNN_LIKELY(nc >= 64) {
_mm512_storeu_ph(c0, vacc0x0);
_mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1);
_mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1);
a0 = (const uint16_t*) ((uintptr_t) a0 - kc);
c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride);
_mm512_storeu_ph(c1, vacc1x0);
_mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1);
_mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1);
a1 = (const uint16_t*) ((uintptr_t) a1 - kc);
c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride);
_mm512_storeu_ph(c2, vacc2x0);
_mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1);
_mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1);
a2 = (const uint16_t*) ((uintptr_t) a2 - kc);
c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride);
_mm512_storeu_ph(c3, vacc3x0);
_mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1);
_mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1);
a3 = (const uint16_t*) ((uintptr_t) a3 - kc);
c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride);
_mm512_storeu_ph(c4, vacc4x0);
_mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1);
_mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1);
a4 = (const uint16_t*) ((uintptr_t) a4 - kc);
c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride);
_mm512_storeu_ph(c5, vacc5x0);
_mm512_storeu_ph((uint16_t*) c5 + 1, vacc5x1);
_mm512_storeu_ph((uint16_t*) c5 + 32, vacc5x1);
a5 = (const uint16_t*) ((uintptr_t) a5 - kc);
c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride);
_mm512_storeu_ph(c6, vacc6x0);
_mm512_storeu_ph((uint16_t*) c6 + 1, vacc6x1);
_mm512_storeu_ph((uint16_t*) c6 + 32, vacc6x1);
a6 = (const uint16_t*) ((uintptr_t) a6 - kc);
c6 = (uint16_t*) ((uintptr_t) c6 + cn_stride);
_mm512_storeu_ph(c7, vacc7x0);
_mm512_storeu_ph((uint16_t*) c7 + 1, vacc7x1);
_mm512_storeu_ph((uint16_t*) c7 + 32, vacc7x1);
a7 = (const uint16_t*) ((uintptr_t) a7 - kc);
c7 = (uint16_t*) ((uintptr_t) c7 + cn_stride);

Expand Down

0 comments on commit 7bdbcd6

Please sign in to comment.