From aecc95c0cabc6604642e7bc4a8c9e5cb5233ebc4 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 11 Dec 2024 18:55:21 +0100 Subject: [PATCH] Fix AVX2 implementation of iq4_nl_r4 (#137) Co-authored-by: Iwan Kawrakow --- ggml/src/iqk/iqk_mul_mat.cpp | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 4316373a00ad9..f0e9d61d47d5f 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2405,7 +2405,8 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data Q8 q8(info); auto m4 = _mm256_set1_epi8(0xf); auto m1 = _mm256_set1_epi16(1); - auto values = load_iq4nl_values_256(); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, values128); int nb = n / QK4_NL; GGML_ASSERT(nb%4 == 0); __m256 acc[nrc_y] = {}; @@ -2416,32 +2417,29 @@ static void mul_mat_iq4_nl_r4_q8_1(int n, const void * vx, size_t bx, const Data for (int k = 0; k < 4; ++k) { auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4[4*ib4+k].d)); auto scales = _mm256_set_m128(scales128, scales128); - auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-64.f)); auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+0); auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[4*ib4+k].qs+1); auto q1 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); auto q2 = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); auto q3 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); auto q4 = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); + auto s1 = _mm256_sign_epi8(q1, q1); + auto s2 = _mm256_sign_epi8(q2, q2); + auto s3 = _mm256_sign_epi8(q3, q3); + auto s4 = _mm256_sign_epi8(q4, q4); + for (int iy = 0; iy < nrc_y; ++iy) { auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k); - auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q1, _mm256_shuffle_epi32(y, 0x00))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q2, _mm256_shuffle_epi32(y, 0x55)))); - auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(q3, _mm256_shuffle_epi32(y, 0xaa))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(q4, _mm256_shuffle_epi32(y, 0xff)))); + auto sumi1 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), q1))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), q2)))); + auto sumi2 = _mm256_add_epi32(_mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), q3))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), q4)))); auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[iy]); - acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[iy]); - //acc[2*iy+0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), acc[2*iy+0]); - //acc[2*iy+1] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k+4])), acc[2*iy+1]); } } } for (int iy = 0; iy < nrc_y; ++iy) { - //auto sum256 = _mm256_add_ps(acc[2*iy+0], acc[2*iy+1]); - //acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_ps(); - //auto sum = _mm_add_ps(_mm256_castps256_ps128(sum256), _mm256_extractf128_ps(sum256, 1)); - //info.store(ix, iy, sum); auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); info.store(ix, iy, sum); acc[iy] = _mm256_setzero_ps();