From 7bdbcd67fbb6fabe180b2df43a65ff0cd96eb46e Mon Sep 17 00:00:00 2001 From: Dillon Sharlet Date: Fri, 20 Dec 2024 08:19:18 -0800 Subject: [PATCH] Fix bug in AVX512 f16-gemm PiperOrigin-RevId: 708328572 --- src/f16-gemm/avx512fp16-broadcast.c.in | 4 ++-- .../f16-gemm-1x64-minmax-avx512fp16-broadcast.c | 2 +- .../f16-gemm-4x64-minmax-avx512fp16-broadcast.c | 8 ++++---- .../f16-gemm-5x64-minmax-avx512fp16-broadcast.c | 10 +++++----- .../f16-gemm-6x64-minmax-avx512fp16-broadcast.c | 12 ++++++------ .../f16-gemm-7x64-minmax-avx512fp16-broadcast.c | 14 +++++++------- .../f16-gemm-8x64-minmax-avx512fp16-broadcast.c | 16 ++++++++-------- 7 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/f16-gemm/avx512fp16-broadcast.c.in b/src/f16-gemm/avx512fp16-broadcast.c.in index ca13573f66c..3f2e318510f 100644 --- a/src/f16-gemm/avx512fp16-broadcast.c.in +++ b/src/f16-gemm/avx512fp16-broadcast.c.in @@ -100,7 +100,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${NR}__avx512fp16_broadcast( $for M in RANGE_MR: _mm512_storeu_ph(c${M}, vacc${M}x0); $for N in range(32, NR, 32): - _mm512_storeu_ph((uint16_t*) c${M} + ${N//32}, vacc${M}x${N//32}); + _mm512_storeu_ph((uint16_t*) c${M} + ${N}, vacc${M}x${N//32}); a${M} = (const uint16_t*) ((uintptr_t) a${M} - kc); c${M} = (uint16_t*) ((uintptr_t) c${M} + cn_stride); @@ -113,7 +113,7 @@ void xnn_${DATATYPE_SPEC}_gemm_minmax_ukernel_${MR}x${NR}__avx512fp16_broadcast( $for M in RANGE_MR: _mm512_storeu_ph(c${M}, vacc${M}x0); $for N in range(32, 1 << LOG2N, 32): - _mm512_storeu_ph((uint16_t*) c${M} + ${N//32}, vacc${M}x${N//32}); + _mm512_storeu_ph((uint16_t*) c${M} + ${N}, vacc${M}x${N//32}); $for M in RANGE_MR: $for N in range(0, NR - (1 << LOG2N), 32): diff --git a/src/f16-gemm/gen/f16-gemm-1x64-minmax-avx512fp16-broadcast.c b/src/f16-gemm/gen/f16-gemm-1x64-minmax-avx512fp16-broadcast.c index 5c98ffe6a71..ec1fa91993b 100644 --- a/src/f16-gemm/gen/f16-gemm-1x64-minmax-avx512fp16-broadcast.c +++ b/src/f16-gemm/gen/f16-gemm-1x64-minmax-avx512fp16-broadcast.c @@ -69,7 +69,7 @@ void xnn_f16_gemm_minmax_ukernel_1x64__avx512fp16_broadcast( if XNN_LIKELY(nc >= 64) { _mm512_storeu_ph(c0, vacc0x0); - _mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1); + _mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1); a0 = (const uint16_t*) ((uintptr_t) a0 - kc); c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); diff --git a/src/f16-gemm/gen/f16-gemm-4x64-minmax-avx512fp16-broadcast.c b/src/f16-gemm/gen/f16-gemm-4x64-minmax-avx512fp16-broadcast.c index 7c9c50cdb48..ef91d3d93fa 100644 --- a/src/f16-gemm/gen/f16-gemm-4x64-minmax-avx512fp16-broadcast.c +++ b/src/f16-gemm/gen/f16-gemm-4x64-minmax-avx512fp16-broadcast.c @@ -117,19 +117,19 @@ void xnn_f16_gemm_minmax_ukernel_4x64__avx512fp16_broadcast( if XNN_LIKELY(nc >= 64) { _mm512_storeu_ph(c0, vacc0x0); - _mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1); + _mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1); a0 = (const uint16_t*) ((uintptr_t) a0 - kc); c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); _mm512_storeu_ph(c1, vacc1x0); - _mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1); + _mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1); a1 = (const uint16_t*) ((uintptr_t) a1 - kc); c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); _mm512_storeu_ph(c2, vacc2x0); - _mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1); + _mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1); a2 = (const uint16_t*) ((uintptr_t) a2 - kc); c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); _mm512_storeu_ph(c3, vacc3x0); - _mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1); + _mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1); a3 = (const uint16_t*) ((uintptr_t) a3 - kc); c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); diff --git a/src/f16-gemm/gen/f16-gemm-5x64-minmax-avx512fp16-broadcast.c b/src/f16-gemm/gen/f16-gemm-5x64-minmax-avx512fp16-broadcast.c index f4903024f23..c97402bdb19 100644 --- a/src/f16-gemm/gen/f16-gemm-5x64-minmax-avx512fp16-broadcast.c +++ b/src/f16-gemm/gen/f16-gemm-5x64-minmax-avx512fp16-broadcast.c @@ -133,23 +133,23 @@ void xnn_f16_gemm_minmax_ukernel_5x64__avx512fp16_broadcast( if XNN_LIKELY(nc >= 64) { _mm512_storeu_ph(c0, vacc0x0); - _mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1); + _mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1); a0 = (const uint16_t*) ((uintptr_t) a0 - kc); c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); _mm512_storeu_ph(c1, vacc1x0); - _mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1); + _mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1); a1 = (const uint16_t*) ((uintptr_t) a1 - kc); c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); _mm512_storeu_ph(c2, vacc2x0); - _mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1); + _mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1); a2 = (const uint16_t*) ((uintptr_t) a2 - kc); c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); _mm512_storeu_ph(c3, vacc3x0); - _mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1); + _mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1); a3 = (const uint16_t*) ((uintptr_t) a3 - kc); c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); _mm512_storeu_ph(c4, vacc4x0); - _mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1); + _mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1); a4 = (const uint16_t*) ((uintptr_t) a4 - kc); c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); diff --git a/src/f16-gemm/gen/f16-gemm-6x64-minmax-avx512fp16-broadcast.c b/src/f16-gemm/gen/f16-gemm-6x64-minmax-avx512fp16-broadcast.c index 261f9c78c13..bbd9172fec7 100644 --- a/src/f16-gemm/gen/f16-gemm-6x64-minmax-avx512fp16-broadcast.c +++ b/src/f16-gemm/gen/f16-gemm-6x64-minmax-avx512fp16-broadcast.c @@ -149,27 +149,27 @@ void xnn_f16_gemm_minmax_ukernel_6x64__avx512fp16_broadcast( if XNN_LIKELY(nc >= 64) { _mm512_storeu_ph(c0, vacc0x0); - _mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1); + _mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1); a0 = (const uint16_t*) ((uintptr_t) a0 - kc); c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); _mm512_storeu_ph(c1, vacc1x0); - _mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1); + _mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1); a1 = (const uint16_t*) ((uintptr_t) a1 - kc); c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); _mm512_storeu_ph(c2, vacc2x0); - _mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1); + _mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1); a2 = (const uint16_t*) ((uintptr_t) a2 - kc); c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); _mm512_storeu_ph(c3, vacc3x0); - _mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1); + _mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1); a3 = (const uint16_t*) ((uintptr_t) a3 - kc); c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); _mm512_storeu_ph(c4, vacc4x0); - _mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1); + _mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1); a4 = (const uint16_t*) ((uintptr_t) a4 - kc); c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); _mm512_storeu_ph(c5, vacc5x0); - _mm512_storeu_ph((uint16_t*) c5 + 1, vacc5x1); + _mm512_storeu_ph((uint16_t*) c5 + 32, vacc5x1); a5 = (const uint16_t*) ((uintptr_t) a5 - kc); c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride); diff --git a/src/f16-gemm/gen/f16-gemm-7x64-minmax-avx512fp16-broadcast.c b/src/f16-gemm/gen/f16-gemm-7x64-minmax-avx512fp16-broadcast.c index c3e2195bc10..f5933a4afc7 100644 --- a/src/f16-gemm/gen/f16-gemm-7x64-minmax-avx512fp16-broadcast.c +++ b/src/f16-gemm/gen/f16-gemm-7x64-minmax-avx512fp16-broadcast.c @@ -165,31 +165,31 @@ void xnn_f16_gemm_minmax_ukernel_7x64__avx512fp16_broadcast( if XNN_LIKELY(nc >= 64) { _mm512_storeu_ph(c0, vacc0x0); - _mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1); + _mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1); a0 = (const uint16_t*) ((uintptr_t) a0 - kc); c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); _mm512_storeu_ph(c1, vacc1x0); - _mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1); + _mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1); a1 = (const uint16_t*) ((uintptr_t) a1 - kc); c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); _mm512_storeu_ph(c2, vacc2x0); - _mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1); + _mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1); a2 = (const uint16_t*) ((uintptr_t) a2 - kc); c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); _mm512_storeu_ph(c3, vacc3x0); - _mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1); + _mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1); a3 = (const uint16_t*) ((uintptr_t) a3 - kc); c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); _mm512_storeu_ph(c4, vacc4x0); - _mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1); + _mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1); a4 = (const uint16_t*) ((uintptr_t) a4 - kc); c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); _mm512_storeu_ph(c5, vacc5x0); - _mm512_storeu_ph((uint16_t*) c5 + 1, vacc5x1); + _mm512_storeu_ph((uint16_t*) c5 + 32, vacc5x1); a5 = (const uint16_t*) ((uintptr_t) a5 - kc); c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride); _mm512_storeu_ph(c6, vacc6x0); - _mm512_storeu_ph((uint16_t*) c6 + 1, vacc6x1); + _mm512_storeu_ph((uint16_t*) c6 + 32, vacc6x1); a6 = (const uint16_t*) ((uintptr_t) a6 - kc); c6 = (uint16_t*) ((uintptr_t) c6 + cn_stride); diff --git a/src/f16-gemm/gen/f16-gemm-8x64-minmax-avx512fp16-broadcast.c b/src/f16-gemm/gen/f16-gemm-8x64-minmax-avx512fp16-broadcast.c index 57b4c8b7111..352a224bdb9 100644 --- a/src/f16-gemm/gen/f16-gemm-8x64-minmax-avx512fp16-broadcast.c +++ b/src/f16-gemm/gen/f16-gemm-8x64-minmax-avx512fp16-broadcast.c @@ -181,35 +181,35 @@ void xnn_f16_gemm_minmax_ukernel_8x64__avx512fp16_broadcast( if XNN_LIKELY(nc >= 64) { _mm512_storeu_ph(c0, vacc0x0); - _mm512_storeu_ph((uint16_t*) c0 + 1, vacc0x1); + _mm512_storeu_ph((uint16_t*) c0 + 32, vacc0x1); a0 = (const uint16_t*) ((uintptr_t) a0 - kc); c0 = (uint16_t*) ((uintptr_t) c0 + cn_stride); _mm512_storeu_ph(c1, vacc1x0); - _mm512_storeu_ph((uint16_t*) c1 + 1, vacc1x1); + _mm512_storeu_ph((uint16_t*) c1 + 32, vacc1x1); a1 = (const uint16_t*) ((uintptr_t) a1 - kc); c1 = (uint16_t*) ((uintptr_t) c1 + cn_stride); _mm512_storeu_ph(c2, vacc2x0); - _mm512_storeu_ph((uint16_t*) c2 + 1, vacc2x1); + _mm512_storeu_ph((uint16_t*) c2 + 32, vacc2x1); a2 = (const uint16_t*) ((uintptr_t) a2 - kc); c2 = (uint16_t*) ((uintptr_t) c2 + cn_stride); _mm512_storeu_ph(c3, vacc3x0); - _mm512_storeu_ph((uint16_t*) c3 + 1, vacc3x1); + _mm512_storeu_ph((uint16_t*) c3 + 32, vacc3x1); a3 = (const uint16_t*) ((uintptr_t) a3 - kc); c3 = (uint16_t*) ((uintptr_t) c3 + cn_stride); _mm512_storeu_ph(c4, vacc4x0); - _mm512_storeu_ph((uint16_t*) c4 + 1, vacc4x1); + _mm512_storeu_ph((uint16_t*) c4 + 32, vacc4x1); a4 = (const uint16_t*) ((uintptr_t) a4 - kc); c4 = (uint16_t*) ((uintptr_t) c4 + cn_stride); _mm512_storeu_ph(c5, vacc5x0); - _mm512_storeu_ph((uint16_t*) c5 + 1, vacc5x1); + _mm512_storeu_ph((uint16_t*) c5 + 32, vacc5x1); a5 = (const uint16_t*) ((uintptr_t) a5 - kc); c5 = (uint16_t*) ((uintptr_t) c5 + cn_stride); _mm512_storeu_ph(c6, vacc6x0); - _mm512_storeu_ph((uint16_t*) c6 + 1, vacc6x1); + _mm512_storeu_ph((uint16_t*) c6 + 32, vacc6x1); a6 = (const uint16_t*) ((uintptr_t) a6 - kc); c6 = (uint16_t*) ((uintptr_t) c6 + cn_stride); _mm512_storeu_ph(c7, vacc7x0); - _mm512_storeu_ph((uint16_t*) c7 + 1, vacc7x1); + _mm512_storeu_ph((uint16_t*) c7 + 32, vacc7x1); a7 = (const uint16_t*) ((uintptr_t) a7 - kc); c7 = (uint16_t*) ((uintptr_t) c7 + cn_stride);