Skip to content

Commit

Permalink
Add additional comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Srihari-mcw committed May 23, 2024
1 parent d0af2a1 commit b9a5d91
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ggml-quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -4876,6 +4876,7 @@ void ggml_vec_dot_q4_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const
__m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta)));
__m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta)));

// Computes product of delta values from four corresponding blocks
__m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd));
d = _mm256_permute2f128_ps(d ,d, 0);

Expand Down Expand Up @@ -6407,6 +6408,7 @@ void ggml_vec_dot_q8_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const
__m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta)));
__m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta)));

// Computes product of delta values from four corresponding blocks
__m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd));
d = _mm256_permute2f128_ps(d ,d, 0);

Expand Down
5 changes: 4 additions & 1 deletion sgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ class tinyBLAS_Q0_B16_AVX {
}

#if defined(__AVX512BF16__)
// Templated functions for gemm of dimensions 4xN
template <int RN>
NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / 4;
Expand All @@ -1006,6 +1007,7 @@ class tinyBLAS_Q0_B16_AVX {
__m256i avec3 = load(A + lda * (ii + 3) + l);
for (int64_t j = 0; j < RN; ++j) {
__m128bh db = m128bh(_mm_set1_epi16(B[ldb * (jj + j) + l].d));
// Computation of product of delta values for four blocks
__m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db));
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
Expand Down Expand Up @@ -1057,7 +1059,8 @@ class tinyBLAS_Q0_B16_AVX {
__m256i bvec3 = load(B + ldb * (jj + 3) + l);
for (int64_t i = 0; i < RM; ++i) {
__m128bh da = m128bh(_mm_set1_epi16((A[lda * (ii + i) + l].d)));
__m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db));
// Computation of product of delta values for four blocks
__m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db));
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
Expand Down

0 comments on commit b9a5d91

Please sign in to comment.