diff --git a/kleidiai-examples/llama_cpp/0001-Use-KleidiAI-Int4-Matmul-micro-kernels-in-llama.cpp.patch b/kleidiai-examples/llama_cpp/0001-Use-KleidiAI-Int4-Matmul-micro-kernels-in-llama.cpp.patch index 5ae7354..e6dfd87 100644 --- a/kleidiai-examples/llama_cpp/0001-Use-KleidiAI-Int4-Matmul-micro-kernels-in-llama.cpp.patch +++ b/kleidiai-examples/llama_cpp/0001-Use-KleidiAI-Int4-Matmul-micro-kernels-in-llama.cpp.patch @@ -1,6 +1,6 @@ -From 25ba8dfa43e2b4b101b890c88464b638427d3d42 Mon Sep 17 00:00:00 2001 +From 8d4bc83e2144cbbe5e634a53ac07a2c6a709b9c0 Mon Sep 17 00:00:00 2001 From: Charles Xu -Date: Wed, 17 Jul 2024 13:28:18 +0200 +Date: Wed, 21 Aug 2024 07:31:51 +0200 Subject: [PATCH] Use KleidiAI Int4 Matmul micro-kernels in llama.cpp - Update CMake file to fetch the Int4 micro-kernels from the KleidiAI @@ -21,7 +21,7 @@ Signed-off-by: Charles Xu create mode 100644 ggml-kleidiai.h diff --git a/CMakeLists.txt b/CMakeLists.txt -index 08481334..07f8f601 100644 +index 08481334..6aed4fc6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -548,6 +548,57 @@ if (LLAMA_VULKAN) @@ -32,9 +32,9 @@ index 08481334..07f8f601 100644 + + # Fetch KleidiAI sources: + include(FetchContent) -+ set(KLEIDIAI_COMMIT_SHA "187d9aacddfb678c09f0831b18f87401b1b353c3") ++ set(KLEIDIAI_COMMIT_SHA "cb27bbe4cd47bb15d8236df3250ff105ef64e65b") + set(KLEIDIAI_DOWNLOAD_URL "https://gitlab.arm.com/kleidi/kleidiai/-/archive/${KLEIDIAI_COMMIT_SHA}/kleidiai-${KLEIDIAI_COMMIT_SHA}.tar.gz") -+ set(KLEIDIAI_ARCHIVE_MD5 "4a1eee013cb20464b534cb01212d19c9") ++ set(KLEIDIAI_ARCHIVE_MD5 "f4fa5d1070d9f0ab96f5c021d292dde3") + + if (POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) @@ -66,7 +66,7 @@ index 08481334..07f8f601 100644 + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) + + list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c) -+ list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0.c) ++ list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) + list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c) + list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c) + @@ -123,7 +123,7 @@ index bd367c42..ed4ce0ae 100644 if (this_size > max_size) { diff --git a/ggml-kleidiai.cpp b/ggml-kleidiai.cpp new file mode 100644 -index 00000000..257a0d4c +index 00000000..9129ea99 --- /dev/null +++ b/ggml-kleidiai.cpp @@ -0,0 +1,675 @@ @@ -176,7 +176,7 @@ index 00000000..257a0d4c +// KleidiAI micro-kernels +#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" -+#include "kai_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0.h" ++#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" +#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h" + @@ -473,7 +473,7 @@ index 00000000..257a0d4c + v.nr = ukernel->get_nr(); + v.kr = ukernel->get_kr(); + v.sr = ukernel->get_sr(); -+ v.packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0(n, k, v.nr, v.kr, k_q4_0_block_size /* 32 */); ++ v.packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n, k, v.nr, v.kr, k_q4_0_block_size /* 32 */); + + return v; +} @@ -638,11 +638,11 @@ index 00000000..257a0d4c + // Temporary memory for the computation. + uint8_t *reshaped_data = (uint8_t*)malloc(reshaped_data_sz); + -+ struct kai_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0_params params; ++ struct kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + -+ kai_run_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0( ++ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0( + 1, n, k, // Dimensions + rhs_packing_params.nr, // Nr + rhs_packing_params.kr, // Kr