Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IQ4_NL sgemm + Q4_0 AVX optimization #9422

Merged
merged 4 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 33 additions & 36 deletions ggml/src/ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,12 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )

return _mm_packus_epi16( bytes1, bytes2);
}

static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
const __m128i ax = _mm_sign_epi8(x, x);
const __m128i sy = _mm_sign_epi8(y, x);
return _mm_maddubs_epi16(ax, sy);
}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
Expand Down Expand Up @@ -4206,37 +4212,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r

sumf = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();

// Main loop
for (; ib < nb; ++ib) {
// Compute combined scale for the block
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );

const __m128i lowMask = _mm_set1_epi8(0xF);
const __m128i off = _mm_set1_epi8(8);

const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);

__m128i bx_0 = _mm_and_si128(lowMask, tmp);
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);

bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
const __m128i mone = _mm_set1_epi16(1);

// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);

// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
_mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
_mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
}

sumf = hsum_float_8(acc);
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined(__SSSE3__)
// set constants
const __m128i lowMask = _mm_set1_epi8(0xF);
Expand Down Expand Up @@ -11819,15 +11825,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
#endif
}


#if defined(__AVX__)
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
const __m128i ax = _mm_sign_epi8(x, x);
const __m128i sy = _mm_sign_epi8(y, x);
return _mm_maddubs_epi16(ax, sy);
}
#endif

#if defined(__AVX2__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
Expand Down
38 changes: 38 additions & 0 deletions ggml/src/llamafile/sgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
}
#endif // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// CONSTANTS

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

Expand Down Expand Up @@ -933,6 +941,20 @@ class tinyBLAS_Q0_AVX {
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
}

inline __m256i load(const block_iq4_nl *b) {
return MM256_SET_M128I(load1(b), load0(b));
}

inline __m128i load0(const block_iq4_nl *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
}

inline __m128i load1(const block_iq4_nl *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
}

inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
Expand Down Expand Up @@ -1159,6 +1181,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
#endif
}

case GGML_TYPE_IQ4_NL: {
if (Btype != GGML_TYPE_Q8_0)
return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
k, (const block_iq4_nl *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n);
return true;
#else
return false;
#endif
}

default:
return false;
}
Expand Down
Loading