diff --git a/llama.cpp/ggml-common.h b/llama.cpp/ggml-common.h
index 43c7978a09..b7a17ccf97 100644
--- a/llama.cpp/ggml-common.h
+++ b/llama.cpp/ggml-common.h
@@ -203,6 +203,18 @@ typedef struct {
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
 
+//[kawrakow] Need these two for performance on Arm
+typedef struct {
+    ggml_half d[8];
+    int8_t    qs[4*QK8_1];
+} block_q8_1_x4;
+static_assert(sizeof(block_q8_1_x4) == 4*sizeof(block_q8_1), "wrong q8_1_x4 block size/padding");
+typedef struct {
+    ggml_half d[4];
+    int8_t    qs[4*QK8_0];
+} block_q8_0_x4;
+static_assert(sizeof(block_q8_0_x4) == 4*sizeof(block_q8_0), "wrong q8_0_x4 block size/padding");
+
 //
 // Super-block quantization structures
 //
@@ -313,10 +325,11 @@ typedef struct {
 static_assert(sizeof(block_q6_K) == sizeof(ggml_half) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
 
 // This is only used for intermediate quantization and dot products
+// [kawrakow] Note: I have switched the order of bsums and qs. This results in some performance gain on Arm
 typedef struct {
     float   d;              // delta
-    int8_t  qs[QK_K];       // quants
     int16_t bsums[QK_K/16]; // sum of quants in groups of 16
+    int8_t  qs[QK_K];       // quants
 } block_q8_K;
 static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
 
diff --git a/llama.cpp/ggml-quants.inc b/llama.cpp/ggml-quants.inc
index e5a86a7358..4d8be61643 100644
--- a/llama.cpp/ggml-quants.inc
+++ b/llama.cpp/ggml-quants.inc
@@ -873,7 +873,11 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
     block_q8_0 * restrict y = vy;
 
 #if defined(__ARM_NEON)
+    // [kawrakow] When running on Arm, we change how the data is laid out for performance reasons
+    block_q8_0_x4 * y4 = (block_q8_0_x4 *)vy;
+    int nb4 = 4*(nb/4);
     for (int i = 0; i < nb; i++) {
+        int i4 = i/4, ir = i%4;
         float32x4_t srcv [8];
         float32x4_t asrcv[8];
         float32x4_t amaxv[8];
@@ -890,16 +894,29 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = GGML_FP32_TO_FP16(d);
+        // [kawrakow] When running on Arm, we change how the data is laid out for performance reasons
+        if (i < nb4) {
+            y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+        } else {
+            y[i].d = GGML_FP32_TO_FP16(d);
+        }
 
         for (int j = 0; j < 8; j++) {
             const float32x4_t v  = vmulq_n_f32(srcv[j], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
 
-            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
-            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
-            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
-            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+            // [kawrakow] When running on Arm, we change how the data is laid out for performance reasons
+            if (i < nb4) {
+                y4[i4].qs[32*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+                y4[i4].qs[32*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+                y4[i4].qs[32*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+                y4[i4].qs[32*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+            } else {
+                y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+                y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+                y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+                y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+            }
         }
     }
 #elif defined(__wasm_simd128__)
@@ -1192,7 +1209,11 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
     block_q8_1 * restrict y = vy;
 
 #if defined(__ARM_NEON)
+    // [kawrakow] When running on Arm, we change how the data is laid out for performance reasons
+    block_q8_1_x4 * restrict y4 = vy;
+    int nb4 = 4*(nb/4);
     for (int i = 0; i < nb; i++) {
+        int i4 = i/4, ir = i%4;
         float32x4_t srcv [8];
         float32x4_t asrcv[8];
         float32x4_t amaxv[8];
@@ -1209,7 +1230,12 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = GGML_FP32_TO_FP16(d);
+        // [kawrakow] When running on Arm, we change how the data is laid out for performance reasons
+        if (i < nb4) {
+            y4[i4].d[ir] = GGML_FP32_TO_FP16(d);
+        } else {
+            y[i].d = GGML_FP32_TO_FP16(d);
+        }
 
         int32x4_t accv = vdupq_n_s32(0);
 
@@ -1217,15 +1243,28 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
             const float32x4_t v  = vmulq_n_f32(srcv[j], id);
             const int32x4_t   vi = vcvtnq_s32_f32(v);
 
-            y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
-            y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
-            y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
-            y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+            // [kawrakow] When running on Arm, we change how the data is laid out for performance reasons
+            if (i < nb4) {
+                y4[i4].qs[QK8_1*ir + 4*j + 0] = vgetq_lane_s32(vi, 0);
+                y4[i4].qs[QK8_1*ir + 4*j + 1] = vgetq_lane_s32(vi, 1);
+                y4[i4].qs[QK8_1*ir + 4*j + 2] = vgetq_lane_s32(vi, 2);
+                y4[i4].qs[QK8_1*ir + 4*j + 3] = vgetq_lane_s32(vi, 3);
+            } else {
+                y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0);
+                y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1);
+                y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2);
+                y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
+            }
 
             accv = vaddq_s32(accv, vi);
         }
 
-        y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+        // [kawrakow] When running on Arm, we change how the data is laid out for performance reasons
+        if (i < nb4) {
+            y4[i4].d[ir+4] = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+        } else {
+            y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
+        }
     }
 #elif defined(__wasm_simd128__)
     for (int i = 0; i < nb; i++) {
diff --git a/llama.cpp/quantize/quantize.cpp b/llama.cpp/quantize/quantize.cpp
index 8659a94247..fbefc07961 100644
--- a/llama.cpp/quantize/quantize.cpp
+++ b/llama.cpp/quantize/quantize.cpp
@@ -65,10 +65,12 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES =
"quantize.imatrix static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count"; static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) { - std::string ftype_str; + std::string ftype_str; ftype_str.reserve(ftype_str_in.size()); + bool is_number = true; for (auto ch : ftype_str_in) { ftype_str.push_back(std::toupper(ch)); + if (!std::isdigit(ftype_str.back())) is_number = false; } for (auto & it : QUANT_OPTIONS) { if (it.name == ftype_str) { @@ -77,6 +79,9 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp return true; } } + // On my system (OS Ventura 13.2.1) calling std::stoi with invalid input leads to a crash (Segmentation fault 11) + // Hence the check above and the early return + if (!is_number) return false; try { int ftype_int = std::stoi(ftype_str); for (auto & it : QUANT_OPTIONS) { diff --git a/llamafile/BUILD.mk b/llamafile/BUILD.mk index 95ee2ec1ca..30f4eab040 100644 --- a/llamafile/BUILD.mk +++ b/llamafile/BUILD.mk @@ -91,6 +91,7 @@ o/$(MODE)/llamafile: \ o/$(MODE)/llamafile/sgemm.o: private CXXFLAGS += -Os o/$(MODE)/llamafile/iqk_mul_mat_amd_avx2.o: private TARGET_ARCH += -Xx86_64-mtune=skylake -Xx86_64-mavx2 -Xx86_64-mfma -Xx86_64-mf16c o/$(MODE)/llamafile/iqk_mul_mat_amd_zen4.o: private TARGET_ARCH += -Xx86_64-mtune=skylake -Xx86_64-mavx2 -Xx86_64-mfma -Xx86_64-mf16c -Xx86_64-mavx512f -Xx86_64-mavx512vl -Xx86_64-mavx512vnni -Xx86_64-mavx512bw -Xx86_64-mavx512dq +o/$(MODE)/llamafile/iqk_mul_mat_arm82.o: private TARGET_ARCH += -Xaarch64-march=armv8.2-a+dotprod+fp16 o/$(MODE)/llamafile/tinyblas_cpu_sgemm_amd_avx.o: private TARGET_ARCH += -Xx86_64-mtune=sandybridge -Xx86_64-mf16c o/$(MODE)/llamafile/tinyblas_cpu_mixmul_amd_avx.o: private TARGET_ARCH += -Xx86_64-mtune=sandybridge -Xx86_64-mf16c o/$(MODE)/llamafile/tinyblas_cpu_sgemm_amd_fma.o: private TARGET_ARCH += -Xx86_64-mtune=bdver2 -Xx86_64-mf16c -Xx86_64-mfma diff --git a/llamafile/iqk_mul_mat.inc b/llamafile/iqk_mul_mat.inc index d41ed8b984..a126f2c9ab 100644 --- a/llamafile/iqk_mul_mat.inc +++ b/llamafile/iqk_mul_mat.inc @@ -15,7 +15,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifdef __x86_64__ +#include +#if defined __x86_64__ || defined __aarch64__ #include "llama.cpp/ggml-impl.h" #include "llama.cpp/ggml-quants.h" @@ -38,11 +39,6 @@ #include #include -#if defined HAVE_FANCY_SIMD - #undef HAVE_FANCY_SIMD -#endif -#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) - #define HAVE_FANCY_SIMD #endif namespace { @@ -82,6 +78,40 @@ struct DataInfo { } }; +typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x); + +struct MulMat { + std::array funcs = {}; + //std::array funcs = {}; + inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { +#ifdef __aarch64__ + constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small) +#else + constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) +#endif + int n_step = (nrc_y - info.cur_y)/funcs.size(); + if (n_step > 0) { + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? 
k_x_step : nrc_x - ix; + for (int iy = 0; iy < n_step; ++iy) { + funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); + this_info.cur_y += funcs.size(); + } + } + info.cur_y += funcs.size() * n_step; + } + int n_left = nrc_y - info.cur_y; + if (n_left > 0) { + funcs[n_left-1](n, vx, bx, info, nrc_x); + } + } + static bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int Ny); +private: + template static void set_functions(MulMat& m); +}; + inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { const uint16_t * scales = (const uint16_t *)scales8; const uint32_t a0 = scales[0] | (scales[1] << 16); @@ -93,19 +123,71 @@ inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) { aux32[0] = a0 & 0x3f3f3f3f; } -static inline float hsum_float_4(__m128 x) { +} + +bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B, + float * C, long stride_C, int ith, int nth) { + + MulMat mm; + int row_size_q8; + if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) { + return false; + } + + auto row_size_qx = ggml_row_size((ggml_type)typeA, ne00); + + auto nrc_x = (Nx + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + + DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, (size_t)row_size_q8, 0, 1, nullptr, 0}; + + mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); + + return true; +} + +bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B, + float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { + const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; + assert(row_mapping != nullptr); + + MulMat mm; + int row_size_q8; + if (!MulMat::set_mul_mat(typeA, ne00, mm, row_size_q8, Ny)) { + return false; + } + int row_size_qx = ggml_row_size((ggml_type)typeA, ne00); + int nrc_x = (Nx + nth - 1)/nth; + int first_x = ith*nrc_x; + if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)}; + mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); + return true; +} + +#if defined __x86_64__ + +#if defined HAVE_FANCY_SIMD + #undef HAVE_FANCY_SIMD +#endif +#if defined(__AVX512F__) && defined(__AVX512VNNI__) && defined(__AVX512VL__) && defined(__AVX512BW__) && defined(__AVX512DQ__) + #define HAVE_FANCY_SIMD +#endif + +namespace { + +inline float hsum_float_4(__m128 x) { x = _mm_add_ps(x, _mm_movehl_ps(x, x)); x = _mm_add_ss(x, _mm_movehdup_ps(x)); return _mm_cvtss_f32(x); } -static inline float hsum_float_8(__m256 x) { +inline float hsum_float_8(__m256 x) { return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); } #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) -typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x); - template struct Q8 { @@ -748,7 +830,7 @@ struct DequantizerQ6K final : public BaseDequantizer { const __m256i mh = _mm256_set1_epi8(0x30); }; -static inline __m256i get_scale_shuffle_16(int i) { +inline __m256i get_scale_shuffle_16(int i) { static const uint8_t k_shuffle[128] = { 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 
7, 6, 7, 6, 7, 6, 7, @@ -1212,29 +1294,7 @@ void mul_mat_q8_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info } } -struct MulMat { - std::array funcs = {}; - inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) { - constexpr int k_x_step = 64; // This works best on my Ryzen-7950X (but differences to other tile size are small) - int n_step = (nrc_y - info.cur_y)/funcs.size(); - if (n_step > 0) { - for (int ix = 0; ix < nrc_x; ix += k_x_step) { - auto this_info = info; - this_info.s += ix; - int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; - for (int iy = 0; iy < n_step; ++iy) { - funcs.back()(n, (const void *)((const char *)vx + ix*bx), bx, this_info, this_nrc_x); - this_info.cur_y += funcs.size(); - } - } - info.cur_y += funcs.size() * n_step; - } - int n_left = nrc_y - info.cur_y; - if (n_left > 0) { - funcs[n_left-1](n, vx, bx, info, nrc_x); - } - } - template static void set_functions(MulMat& m) { +template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_qX_0_q8_0_T; m.funcs[1] = mul_mat_qX_0_q8_0_T; @@ -1289,10 +1349,9 @@ struct MulMat { } #endif } - } -}; +} -bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8) { +bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int) { row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); @@ -1322,22 +1381,22 @@ bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8) { MulMat::set_functions(mm); break; case GGML_TYPE_Q4_0: - assert (ne00 % Q4K_0 == 0); + assert (ne00 % QK4_0 == 0); MulMat::set_functions(mm); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; case GGML_TYPE_Q4_1: - assert (ne00 % Q4K_1 == 0); + assert (ne00 % QK4_1 == 0); MulMat::set_functions(mm); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); break; case GGML_TYPE_Q5_0: - assert (ne00 % Q5K_0 == 0); + assert (ne00 % QK5_0 == 0); MulMat::set_functions(mm); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); break; case GGML_TYPE_Q5_1: - assert (ne00 % Q5K_1 == 0); + assert (ne00 % QK5_1 == 0); MulMat::set_functions(mm); row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); break; @@ -1351,49 +1410,1085 @@ bool set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8) { } // namespace -// -// ============================== Matrix multiplications -// -bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, const void * B, - float * C, long stride_C, int ith, int nth) { +#else // __aarch64__ - MulMat mm; - int row_size_q8; - if (!set_mul_mat(typeA, ne00, mm, row_size_q8)) { - return false; +namespace { + +template struct Q8 { + + constexpr static int nrc_y = nrc; + + Q8(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); } - auto row_size_qx = ggml_row_size((ggml_type)typeA, ne00); + inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); } + inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); } + inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); } + inline int16x8_t load_bsums8(int iy, int i) const { + auto q8s = vld1q_s16_x2(y[iy][i].bsums); + return vpaddq_s16(q8s.val[0], q8s.val[1]); + } + inline float scale(int iy, int i) const { return y[iy][i].d; } - auto nrc_x = (Nx + nth - 1)/nth; - auto first_x = ith*nrc_x; - if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; + const block_q8 * 
y[nrc_y]; +}; - DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, (size_t)row_size_q8, 0, 1, nullptr, 0}; +template +inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, + const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) { + auto mzero = vdupq_n_s32(0); + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1 + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2 + auto p12 = vpaddq_s32(p1, p2); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1 + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2 + auto p34 = vpaddq_s32(p3, p4); + + auto pall = vpaddq_s32(p12, p34); + sumi = vmlaq_s32(sumi, scales.val[j], pall); +} - mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); +template +inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8, + const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) { + + auto mzero = vdupq_n_s32(0); + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1, + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4, + auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3 + sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5, + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]), + ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7, + auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7 + sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34); +} - return true; +template +inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums8(iy, i); + int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s)); + int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s)); + float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2)); + acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); + } +} +template +inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0])); + int32x4_t b2 = 
vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0])); + int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1])); + int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1])); + float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4))); + acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i))); + } } -bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, const void * B, - float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) { - const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping; - assert(row_mapping != nullptr); +struct Scales8 { + uint32_t utmp[4]; + const uint8_t * sc8 = (const uint8_t *)utmp; + template + inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) { + make_q4_scales(x.scales, utmp); + int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8)); + accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin)); + + uint8x8_t scales8 = vld1_u8(sc8); + uint16x8_t scales16 = vmovl_u8(scales8); + int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))), + vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))}; + return scales; + } +}; - MulMat mm; - int row_size_q8; - if (!set_mul_mat(typeA, ne00, mm, row_size_q8)) { - return false; +struct Q4bits { + const uint8x16_t m4b = vdupq_n_u8(0xf); + uint8x16x4_t b1, b2; + inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const { + b.val[0] = vandq_u8(val[0], m4b); + b.val[2] = vshrq_n_u8(val[0], 4); + b.val[1] = vandq_u8(val[1], m4b); + b.val[3] = vshrq_n_u8(val[1], 4); + } + inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const { + b.val[0] = vandq_u8(val[0], m4b); + b.val[1] = vshrq_n_u8(val[0], 4); + b.val[2] = vandq_u8(val[1], m4b); + b.val[3] = vshrq_n_u8(val[1], 4); + } + inline void prepare(const uint8_t * qs) { + auto q4bits = vld1q_u8_x2(qs); + prepare4(b1, q4bits.val); + q4bits = vld1q_u8_x2(qs+32); + prepare4(b2, q4bits.val); + } + inline void prepare_v2(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + prepare4(b1, q4bits.val+0); + prepare4(b2, q4bits.val+2); + } + inline void prepare64(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + b1.val[0] = vandq_u8(q4bits.val[0], m4b); + b1.val[1] = vandq_u8(q4bits.val[1], m4b); + b1.val[2] = vandq_u8(q4bits.val[2], m4b); + b1.val[3] = vandq_u8(q4bits.val[3], m4b); + b2.val[0] = vshrq_n_u8(q4bits.val[0], 4); + b2.val[1] = vshrq_n_u8(q4bits.val[1], 4); + b2.val[2] = vshrq_n_u8(q4bits.val[2], 4); + b2.val[3] = vshrq_n_u8(q4bits.val[3], 4); + } + inline void prepare16(const uint8_t * qs) { + auto q4bits = vld1q_u8_x2(qs); + prepare4_16(b1, q4bits.val); + q4bits = vld1q_u8_x2(qs+32); + prepare4_16(b2, q4bits.val); + } + inline void prepare16_v2(const uint8_t * qs) { + auto q4bits = vld1q_u8_x4(qs); + prepare4_16(b1, q4bits.val+0); + prepare4_16(b2, q4bits.val+2); } - int row_size_qx = ggml_row_size((ggml_type)typeA, ne00); - int nrc_x = (Nx + nth - 1)/nth; - int first_x = ith*nrc_x; - if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; - DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), (size_t)row_size_q8, 0, ne11, row_mapping, nb2/sizeof(float)}; - mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); - return true; +}; + +struct Q2bits { + const uint8x16_t m4b = vdupq_n_u8(0x03); + uint8x16x4_t b1, b2; + inline void prepare(const uint8_t * qs) { + auto q2bits = 
vld1q_u8_x2(qs); + b1.val[0] = vandq_u8(q2bits.val[0], m4b); + b1.val[1] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b1.val[2] = vandq_u8(q2bits.val[0], m4b); + b1.val[3] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b2.val[0] = vandq_u8(q2bits.val[0], m4b); + b2.val[1] = vandq_u8(q2bits.val[1], m4b); + + q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2); + q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2); + b2.val[2] = vandq_u8(q2bits.val[0], m4b); + b2.val[3] = vandq_u8(q2bits.val[1], m4b); + } +}; + +template +struct BaseDequantizer { + BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {} + inline void new_row(int ix) { x = (const block_q *)((const char *)vx + ix*bx); } + const void * vx; + const block_q * x; + const size_t bx; + const int nrc; +}; + +struct DequantizerQ4K final : public BaseDequantizer { + DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return s8.process_scales_mins(x[i], q8, i, acc); + } + inline void prepare(int i, int j) { + if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); + else bits.prepare(x[i].qs+64*j); + } + + Q4bits bits; + Scales8 s8; + + float d; +}; + +struct HighBit5 { + const uint8x16_t mhb = vdupq_n_u8(0x10); + uint8x16x2_t bits; + inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { + b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb)); + b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb)); + b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb)); + b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb)); + + b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); + b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); + b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); + b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); + + if (do_shift) { + bits.val[0] = vshrq_n_u8(bits.val[0], 4); + bits.val[1] = vshrq_n_u8(bits.val[1], 4); + } + } +}; + +struct HighBit3 { + const uint8x16_t mhb = vdupq_n_u8(0x04); + uint8x16x2_t bits; + inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) { + b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb)); + b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb)); + b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb)); + b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb)); + + b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb)); + b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb)); + b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb)); + b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb)); + + if (do_shift) { + bits.val[0] = vshrq_n_u8(bits.val[0], 4); + bits.val[1] = vshrq_n_u8(bits.val[1], 4); + } + } +}; + +struct DequantizerQ5K final : public BaseDequantizer { + DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + 
constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + h.bits = vld1q_u8_x2(x[i].qh); + return s8.process_scales_mins(x[i], q8, i, acc); + } + inline void prepare(int i, int j) { + if (nrc == 1) bits.prepare_v2(x[i].qs+64*j); + else bits.prepare(x[i].qs+64*j); + h.apply(bits.b1, bits.b2, j == 0); + } + + Q4bits bits; + HighBit5 h; + Scales8 s8; + + uint8x16x2_t hbits; + + float d; +}; + +inline int32x4x4_t make_wider(const int16x8x2_t& scales16) { + int32x4x4_t scales = { + vmovl_s16(vget_low_s16 (scales16.val[0])), + vmovl_s16(vget_high_s16(scales16.val[0])), + vmovl_s16(vget_low_s16 (scales16.val[1])), + vmovl_s16(vget_high_s16(scales16.val[1])), + }; + return scales; +} + +template +inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) { + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(scales8)); + scales16.val[1] = vmovl_s8(vget_high_s8(scales8)); + accum_mins_16(scales16, q8, acc, i, c); + return make_wider(scales16); +} + +struct DequantizerQ6K final : public BaseDequantizer { + DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d); + } + inline void prepare(int i, int j) { + + auto hbits = vld1q_u8_x2(x[i].qh + 32*j); + + bits.prepare64(x[i].ql+64*j); + bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb)); + bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb)); + bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb)); + bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb)); + + bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb)); + bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb)); + bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb)); + bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb)); + + } + + Q4bits bits; + + const uint8x16_t mhb = vdupq_n_u8(0x30); + + float d; +}; + +struct DequantizerQ3K final : public BaseDequantizer { + DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + h.bits = vld1q_u8_x2(x[i].hmask); + const uint16_t * sc16 = (const uint16_t *)x[i].scales; + uint32_t aux0 = sc16[0] | (sc16[1] << 16); + uint32_t aux1 = sc16[2] | (sc16[3] << 16); + uint32_t aux2 = sc16[4] | (sc16[5] << 16); + aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030); + aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); + aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); + aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); + return process_scales_mins_16(vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32)), q8, acc, i, -4.f*d); + } + + inline void prepare(int i, int j) { + 
bits.prepare(x[i].qs+32*j); + h.apply(bits.b1, bits.b2, j == 0); + } + + uint32_t aux32[4]; + + Q2bits bits; + + const uint8x16_t mhb = vdupq_n_u8(0x04); + HighBit3 h; + + float d; +}; + +struct DequantizerQ2K final : public BaseDequantizer { + DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return true; } + + template + inline void process_scales(int i, const Q8& q8, float32x4_t * acc) { + d = GGML_FP16_TO_FP32(x[i].d); + auto scales_and_mins = vld1q_u8(x[i].scales); + auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4)); + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(mins8)); + scales16.val[1] = vmovl_s8(vget_high_s8(mins8)); + accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin)); + + scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf)); + } + + template + inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) { + process_scales(i, q8, acc); + int16x8x2_t scales16; + scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8))); + scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8))); + return make_wider(scales16); + } + + template + inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) { + auto m1 = vdupq_n_u8(1); + auto shuffle = vdupq_n_u8(8*j); + bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8b_1 = q8.load_quants(iy, i, 4*j+0); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]), + vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]); + + auto q8b_2 = q8.load_quants(iy, i, 4*j+1); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]), + vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]); + + auto q8b_3 = q8.load_quants(iy, i, 4*j+2); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]), + vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]); + + auto q8b_4 = q8.load_quants(iy, i, 4*j+3); + sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]), + vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]); + } + } + + inline void prepare(int i, int j) { + bits.prepare(x[i].qs+32*j); + } + + uint32_t aux32[4]; + + uint8x16_t scales8; + + Q2bits bits; + + float d; +}; + +struct DequantizerIQ4XS final : public BaseDequantizer { + + static int8x16_t load_values() { + static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + return vld1q_s8(iq4nl_values); + } + + 
DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); } + + template + inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) { + (void)q8; + (void)acc; + d = GGML_FP16_TO_FP32(x[i].d); + const uint16_t scales_h = x[i].scales_h; + const uint16_t * scales_l = (const uint16_t *)x[i].scales_l; + aux32[0] = scales_l[0] | (scales_l[1] << 16); + aux32[1] = aux32[0] >> 4; + // scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7 + uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf)); + uint16_t * aux16 = (uint16_t *)aux32; + aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2; + // sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7 + uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30)); + int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32)); + // shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7 + scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff)); + int16x8_t scales16 = vmovl_s8(scales8); + int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))}; + return scales; + } + inline void prepare(int i, int j) { + bits.prepare16(x[i].qs+64*j); + //if (nrc == 1) { + // bits.prepare16_v2(x[i].qs+64*j); + //} else { + // bits.prepare16(x[i].qs+64*j); + //} + for (int k = 0; k < 4; ++k) { + bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k])); + bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k])); + } + } + + Q4bits bits; + const int8x16_t values; + uint32_t aux32[2]; + + constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602}; + + float d; +}; + +template +static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n % QK_K == 0); + const int nb = n / QK_K; + + Q8 q8(info); + + Dequantizer deq(vx, bx, nrc_y); + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb; ++i) { + + int32x4_t sumi[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0); + + if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) { + deq.process_scales(i, q8, acc); + deq.prepare(i, 0); + deq.compute(q8, i, 0, sumi); + deq.prepare(i, 1); + deq.compute(q8, i, 1, sumi); + } else { + if constexpr (Dequantizer::num_blocks() == 8) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else if constexpr (Dequantizer::num_blocks() == 16) { + auto scales = deq.new_block(i, q8, acc); + deq.prepare(i, 0); + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]); + deq.prepare(i, 1); + for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]); + } + else { + GGML_ASSERT(false); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), 
vdupq_n_f32(deq.d*q8.scale(iy, i))); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} + +// =========================================== Legacy quants + +template +inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) { + for (int k = 0; k < 4; ++k) aux[k] = x[k].d; + return vld1_f16((const float16_t *)aux); +} + +template +inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) { + if constexpr (std::is_same_v) { + for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; } + } else { + for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; } + } + return vld1q_f16((const float16_t *)aux); +} + +struct Q4LegacyBits { + template + inline void prepare(const Block * x) { + for (int i = 0; i < 4; ++i) { + auto q4bits = vld1q_u8(x[i].qs); + b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); + b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); + } + } + inline void prepare1(const uint8_t * qs, int8x16_t * q) const { + auto q4bits = vld1q_u8(qs); + q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); + q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); + } + inline void prepare1(const uint8_t * qs) { + prepare1(qs, b); + } + const uint8x16_t m4b = vdupq_n_u8(0xf); + int8x16_t b[8]; +}; + +// One would think this commented out version would do better than the one below +// because it offers more opportunities to execute instructions in parallel. +// Instead, it runs significantly slower. Why? If the compiler is running out of vector registers +// cannot it just do the sequential version below on its own? +//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { +// const auto q8b_1 = vld1q_s8_x2(qs + 0); +// auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]); +// const auto q8b_2 = vld1q_s8_x2(qs + 32); +// auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]); +// auto p1234 = vpaddq_s32(p12, p34); +// const auto q8b_3 = vld1q_s8_x2(qs + 64); +// auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]); +// const auto q8b_4 = vld1q_s8_x2(qs + 96); +// auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]); +// return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); +//} + +inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { + auto q8b = vld1q_s8_x2(qs + 0); + auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 32); + auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]); + auto p1234 = vpaddq_s32(p12, p34); + q8b = vld1q_s8_x2(qs + 64); + auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 96); + auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]); + return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); +} + +template struct Q80 { + + constexpr static int nrc_y = nrc; + + Q80(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy); + } + + inline const int8_t * quant_data(int iy, int i) const { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; + return y4->qs; + } + + inline float16x4_t load_scales(int iy, int i) const { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; + return vld1_f16((const float16_t 
*)y4->d); + } + + template + inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const { + auto qx_scales = deq.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + sc16[iy] = vmul_f16(qx_scales, q8_scales); + } + } + + template + inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { + deq.prepare1(i); + float d = GGML_FP16_TO_FP32(deq.x[i].d); + for (int iy = 0; iy < nrc; ++iy) { + auto q8b = vld1q_s8_x2(y[iy][i].qs); + auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); + acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); + } + } + + const block_q8_0 * y[nrc_y]; +}; + +template struct Q81 { + + constexpr static int nrc_y = nrc; + + Q81(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy); + } + + inline const int8_t * quant_data(int iy, int i) const { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; + return y4->qs; + } + + inline float16x8_t load_scales(int iy, int i) const { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; + return vld1q_f16((const float16_t *)y4->d); + } + + template + inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const { + auto qx_scales = deq.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales)); + acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m)); + sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales)); + } + } + + template + inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { + deq.prepare1(i); + float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m); + for (int iy = 0; iy < nrc; ++iy) { + auto q8b = vld1q_s8_x2(y[iy][i].qs); + auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); + acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); + acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s))); + } + } + + const block_q8_1 * y[nrc_y]; +}; + +template +struct BaseLegacyDequantizer { + + BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {} + + inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); } + + Q4LegacyBits bits; + + const void * vx; + const block_q * x; + size_t bx; +}; + +struct DequantizerQ40 final : public BaseLegacyDequantizer { + + DequantizerQ40(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + q[0] = vaddq_s8(q[0], m8); + q[1] = vaddq_s8(q[1], m8); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + const int8x16_t m8 = vdupq_n_s8(-8); + //ggml_half aux[4]; +}; + +struct DequantizerQ41 : public BaseLegacyDequantizer { + + DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.prepare1(x[i].qs); + } + + inline float16x8_t new_block(int i) { + uint32_t aux32[4]; + const uint32_t * s32 = (const uint32_t 
*)&x[4*i].d; + for (int k = 0; k < 4; ++k) { + aux32[k] = *s32; s32 += sizeof(block_q4_1)/4; + bits.prepare1(x[4*i+k].qs, bits.b + 2*k); + } + return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); + } + // Leaving this commented out attempt to be reminded that I already tried this. + // It has basically the same performance as the version above. + //inline float16x8_t new_block(int i) { + // uint32x4_t scales = {}; + // const block_q4_1 * xi = x + 4*i; + // const uint32_t * s32 = (const uint32_t *)&xi->d; + // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[0].qs, bits.b + 0); + // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[1].qs, bits.b + 2); + // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[2].qs, bits.b + 4); + // scales = vsetq_lane_u32(*s32, scales, 3); + // bits.prepare1(xi[3].qs, bits.b + 6); + // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle))); + //} + + const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; +}; + +struct HighBit5Legacy { + inline uint8x16_t to_bytes(const uint8_t * qh) const { + uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); + return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask)); + } + inline uint8x16_t to_negated_bytes(const uint8_t * qh) const { + uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); + return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0)); + } + const uint64x2_t mask = vdupq_n_u64(0x8040201008040201); + const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); +}; + +struct DequantizerQ50 final : public BaseLegacyDequantizer { + + DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh = x[i].qh; + q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0)))); + q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2)))); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + HighBit5Legacy hbits; + + const uint8x16_t mh = vdupq_n_u8(0xf0); + +}; + +struct DequantizerQ80 final : public BaseLegacyDequantizer { + + DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.b[0] = vld1q_s8(x[i].qs); + bits.b[1] = vld1q_s8(x[i].qs+16); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs); + bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16); + } + return vld1_f16((const float16_t *)aux); + } + +}; + +struct DequantizerQ51 final : public BaseLegacyDequantizer { + + DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh = x[i].qh; + q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0)))); + q[1] = 
vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2)))); + } + inline void prepare1(int i) { + bits.prepare1(x[i].qs, bits.b); + } + + inline float16x8_t new_block(int i) { + uint32_t aux32[4]; + const uint32_t * s32 = (const uint32_t *)&x[4*i].d; + for (int k = 0; k < 4; ++k) { + aux32[k] = *s32; s32 += sizeof(block_q5_1)/4; + prepare1(4*i+k, bits.b + 2*k); + } + return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); + } + + HighBit5Legacy hbits; + + const uint8x16_t mh = vdupq_n_u8(0x10); + const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; + +}; + +template +inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i)); + auto scale = vcvt_f32_f16(sc16[iy]); + acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall)); + } +} + +template +inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) { + const int nb = n / QK4_1; + + float16x4_t sc16[Q8::nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[Q8::nrc_y]; + for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb/4; ++i) { + q8.process_scales(i, deq, sc16, acc); + sum_4(i, deq, q8, sc16, acc); + } + for (int i = 4*(nb/4); i < nb; ++i) { + q8.process_1_block(i, deq, acc); + } + + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} + +template +inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) { + const int nb = n / QK4_1; + + float16x4_t sc16[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq1.new_row(ix); + deq2.new_row(ix); + + float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) }; + + for (int i = 0; i < nb/8; ++i) { + q8.process_scales(2*i+0, deq1, sc16+0, acc+0); + q8.process_scales(2*i+1, deq2, sc16+1, acc+1); + sum_4(2*i+0, deq1, q8, sc16+0, acc+0); + sum_4(2*i+1, deq2, q8, sc16+1, acc+1); + } + for (int i = 2*(nb/8); i < nb/4; ++i) { + q8.process_scales(i, deq1, sc16, acc); + sum_4(i, deq1, q8, sc16, acc); + } + for (int i = 4*(nb/4); i < nb; ++i) { + q8.process_1_block(i, deq1, acc); + } + + info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1]))); + } +} + +template +static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + Q81 q8(info); + if constexpr (nrc_y == 1) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); + } else { + Dequantizer deq(vx, bx); + mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); + } +} + +template +static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + Q80 q8(info); + if constexpr (nrc_y == 1) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); + } else { + Dequantizer deq(vx, bx); + mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x); + } +} + +template +static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + Q81<1> q8(info); + mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x); +} + +template +static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + Dequantizer deq1(vx, bx), deq2(vx, bx); + Q80<1> q8(info); + 
mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x); +} + +template void MulMat::set_functions(MulMat& m) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + m.funcs[0] = mul_mat_qX_0_q8_0; + m.funcs[1] = mul_mat_qX_0_q8_0; + m.funcs[2] = mul_mat_qX_0_q8_0; + m.funcs[3] = mul_mat_qX_0_q8_0; + m.funcs[4] = mul_mat_qX_0_q8_0; + m.funcs[5] = mul_mat_qX_0_q8_0; + m.funcs[6] = mul_mat_qX_0_q8_0; + m.funcs[7] = mul_mat_qX_0_q8_0; + } + else if constexpr (std::is_same_v || std::is_same_v) { + m.funcs[0] = mul_mat_qX_1_q8_1; + m.funcs[1] = mul_mat_qX_1_q8_1; + m.funcs[2] = mul_mat_qX_1_q8_1; + m.funcs[3] = mul_mat_qX_1_q8_1; + m.funcs[4] = mul_mat_qX_1_q8_1; + m.funcs[5] = mul_mat_qX_1_q8_1; + m.funcs[6] = mul_mat_qX_1_q8_1; + m.funcs[7] = mul_mat_qX_1_q8_1; + } + else { + m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; + m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; + m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; + m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; + m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; + m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; + m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; + m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; + } +} + +bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /*Ny*/) { + row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00); + + switch (typeA) { + case GGML_TYPE_Q2_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q3_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q4_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q5_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q6_K: + MulMat::set_functions(m); + break; + case GGML_TYPE_IQ4_XS: + MulMat::set_functions(m); + break; + case GGML_TYPE_Q4_0: + MulMat::set_functions(m); + row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + break; + case GGML_TYPE_Q4_1: + MulMat::set_functions(m); + row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); + break; + case GGML_TYPE_Q5_0: + MulMat::set_functions(m); + row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + break; + case GGML_TYPE_Q5_1: + MulMat::set_functions(m); + row_size_q8 = ggml_row_size(GGML_TYPE_Q8_1, ne00); + break; + case GGML_TYPE_Q8_0: + MulMat::set_functions(m); + row_size_q8 = ggml_row_size(GGML_TYPE_Q8_0, ne00); + break; + default: + return false; + } + return true; +} + } -#endif // __x86_64__ +#endif // __x86_64__ or __aarch64__ diff --git a/llamafile/iqk_mul_mat_arm82.cpp b/llamafile/iqk_mul_mat_arm82.cpp new file mode 100644 index 0000000000..592eb20525 --- /dev/null +++ b/llamafile/iqk_mul_mat_arm82.cpp @@ -0,0 +1,5 @@ +#ifdef __aarch64__ +#define iqk_mul_mat iqk_mul_mat_arm82 +#define iqk_mul_mat_moe iqk_mul_mat_moe_arm82 +#include "iqk_mul_mat.inc" +#endif // __aarch64__ diff --git a/llamafile/sgemm.cpp b/llamafile/sgemm.cpp index d7324c1f92..648e751f2c 100644 --- a/llamafile/sgemm.cpp +++ b/llamafile/sgemm.cpp @@ -88,6 +88,7 @@ static const struct GemmFuncs { // e.g. 
Apple M1, Raspberry Pi 5
         sgemm = llamafile_sgemm_arm82;
         mixmul = llamafile_mixmul_arm82;
+        iqk_mixmul = iqk_mul_mat_moe_arm82;
     } else {
         // ARM64 baseline ISA
         sgemm = llamafile_sgemm_arm80;
diff --git a/llamafile/sgemm.h b/llamafile/sgemm.h
index 9e1864c87a..33fcd60332 100644
--- a/llamafile/sgemm.h
+++ b/llamafile/sgemm.h
@@ -9,11 +9,14 @@ struct ggml_compute_params;
 
 bool iqk_mul_mat(long, long, long, int, const void *, const void *, float *, long, int, int);
 bool iqk_mul_mat_zen4(long, long, long, int, const void *, const void *, float *, long, int, int);
+bool iqk_mul_mat_arm82(long, long, long, int, const void *, const void *, float *, long, int, int);
 
 bool iqk_mul_mat_moe(long, long, long, int, int, const void *, const void *,
                      float *, long, long, const void *, int, int);
 bool iqk_mul_mat_moe_zen4(long, long, long, int, int, const void *, const void *,
                           float *, long, long, const void *, int, int);
+bool iqk_mul_mat_moe_arm82(long, long, long, int, int, const void *, const void *,
+                           float *, long, long, const void *, int, int);
 bool iqk_mul_mat_moe_unsupported(long, long, long, int, int, const void *, const void *,
                                  float *, long, long, const void *, int, int);
diff --git a/llamafile/tinyblas_cpu_sgemm.inc b/llamafile/tinyblas_cpu_sgemm.inc
index 7555e79932..db01abc42e 100644
--- a/llamafile/tinyblas_cpu_sgemm.inc
+++ b/llamafile/tinyblas_cpu_sgemm.inc
@@ -322,7 +322,8 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
     assert(nth > 0);
     assert(ith < nth);
 
-#if defined(__x86_64__) && QK_K == 256
+#if QK_K == 256
+#if defined(__x86_64__)
     if (X86_CHECK(AVX2) && X86_CHECK(FMA)) {
         if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
             if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float *)C, ldc, ith, nth)) {
@@ -336,6 +337,19 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
             }
         }
     }
+#elif defined __aarch64__ && defined __ARM_FEATURE_DOTPROD && !defined _MSC_VER
+    if (Btype == GGML_TYPE_Q8_K && Ctype == GGML_TYPE_F32) {
+        if (iqk_mul_mat(m, n, k * QK_K, Atype, A, B, (float *)C, ldc, ith, nth)) {
+            return true;
+        }
+    }
+    if ((Btype == GGML_TYPE_Q8_0 || Btype == GGML_TYPE_Q8_1) && Ctype == GGML_TYPE_F32) {
+        assert(QK8_0 == 32 && QK8_1 == 32 && QK4_0 == 32 && QK4_1 == 32 && QK5_0 == 32 && QK5_1 == 32);
+        if (iqk_mul_mat(m, n, k * QK8_0, Atype, A, B, (float *)C, ldc, ith, nth)) {
+            return true;
+        }
+    }
+#endif
 #endif
 
     switch (Ctype) {
diff --git a/llamafile/tinyblas_cpu_sgemm_arm82.cpp b/llamafile/tinyblas_cpu_sgemm_arm82.cpp
index 81e77ce21c..5cfe23a453 100644
--- a/llamafile/tinyblas_cpu_sgemm_arm82.cpp
+++ b/llamafile/tinyblas_cpu_sgemm_arm82.cpp
@@ -1,4 +1,5 @@
 #ifdef __aarch64__
 #define llamafile_sgemm llamafile_sgemm_arm82
+#define iqk_mul_mat iqk_mul_mat_arm82
 #include "tinyblas_cpu_sgemm.inc"
 #endif // __aarch64__
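
For reviewers, a minimal standalone sketch (not part of the patch) of the index mapping that the new block_q8_0_x4 layout and the quantize_row_q8_0 changes rely on: block i of a row lands in group i4 = i/4 at slot ir = i%4, with its scale at d[ir] and its 32 quants at qs[32*ir .. 32*ir+31]. It assumes QK8_0 == 32 and substitutes uint16_t for ggml_half, since only the index arithmetic matters here.

// Sketch of the block_q8_0_x4 repacking used on the Arm path above.
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define QK8_0 32

typedef uint16_t half_t;  // stand-in for ggml_half in this sketch

typedef struct { half_t d;    int8_t qs[QK8_0];   } blk_q8_0;    // mirrors block_q8_0
typedef struct { half_t d[4]; int8_t qs[4*QK8_0]; } blk_q8_0_x4; // mirrors block_q8_0_x4

// Repack 4 consecutive plain blocks into one interleaved group:
// scales first (d[0..3]), then the 4x32 quants back to back.
static void repack_x4(const blk_q8_0 src[4], blk_q8_0_x4 *dst) {
    for (int ir = 0; ir < 4; ++ir) {
        dst->d[ir] = src[ir].d;
        memcpy(dst->qs + QK8_0*ir, src[ir].qs, QK8_0);
    }
}

int main(void) {
    blk_q8_0 src[4];
    for (int ir = 0; ir < 4; ++ir) {
        src[ir].d = (half_t)(ir + 1);
        for (int j = 0; j < QK8_0; ++j) src[ir].qs[j] = (int8_t)(32*ir + j - 64);
    }
    blk_q8_0_x4 dst;
    repack_x4(src, &dst);
    // Block ir's quants sit contiguously at qs[32*ir + j], exactly where
    // quantize_row_q8_0 writes y4[i4].qs[32*ir + 4*j + k] above.
    for (int ir = 0; ir < 4; ++ir)
        for (int j = 0; j < QK8_0; ++j)
            assert(dst.qs[QK8_0*ir + j] == src[ir].qs[j]);
    return 0;
}

Grouping four blocks this way is what lets the Arm kernels in iqk_mul_mat.inc (for example Q80::load_scales) fetch all four scales with a single vld1_f16 and stream 128 quant bytes contiguously instead of hopping over per-block headers.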