Skip to content

Commit

Permalink
Merge pull request #142 from chaxu01/feature/my-ML-examples
Browse files (browse the repository at this point in the history)
Update with latest KleidiAI release
  • Loading branch information
kshitij-sisodia-arm authored Aug 21, 2024
2 parents 307bed2 + f9360f1 commit 2321def
Showing 1 changed file with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-From 25ba8dfa43e2b4b101b890c88464b638427d3d42 Mon Sep 17 00:00:00 2001
+From 8d4bc83e2144cbbe5e634a53ac07a2c6a709b9c0 Mon Sep 17 00:00:00 2001
From: Charles Xu <[email protected]>
-Date: Wed, 17 Jul 2024 13:28:18 +0200
+Date: Wed, 21 Aug 2024 07:31:51 +0200
Subject: [PATCH] Use KleidiAI Int4 Matmul micro-kernels in llama.cpp

- Update CMake file to fetch the Int4 micro-kernels from the KleidiAI
Expand All @@ -21,7 +21,7 @@ Signed-off-by: Charles Xu <[email protected]>
create mode 100644 ggml-kleidiai.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
-index 08481334..07f8f601 100644
+index 08481334..6aed4fc6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -548,6 +548,57 @@ if (LLAMA_VULKAN)
Expand All @@ -32,9 +32,9 @@ index 08481334..07f8f601 100644
+
+ # Fetch KleidiAI sources:
+ include(FetchContent)
-+ set(KLEIDIAI_COMMIT_SHA "187d9aacddfb678c09f0831b18f87401b1b353c3")
++ set(KLEIDIAI_COMMIT_SHA "cb27bbe4cd47bb15d8236df3250ff105ef64e65b")
+ set(KLEIDIAI_DOWNLOAD_URL "https://gitlab.arm.com/kleidi/kleidiai/-/archive/${KLEIDIAI_COMMIT_SHA}/kleidiai-${KLEIDIAI_COMMIT_SHA}.tar.gz")
-+ set(KLEIDIAI_ARCHIVE_MD5 "4a1eee013cb20464b534cb01212d19c9")
++ set(KLEIDIAI_ARCHIVE_MD5 "f4fa5d1070d9f0ab96f5c021d292dde3")
+
+ if (POLICY CMP0135)
+ cmake_policy(SET CMP0135 NEW)
Expand Down Expand Up @@ -66,7 +66,7 @@ index 08481334..07f8f601 100644
+ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
+
+ list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
-+ list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0.c)
++ list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+ list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
+ list(APPEND GGML_SOURCES_KLEIDIAI ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.c)
+
Expand Down Expand Up @@ -123,7 +123,7 @@ index bd367c42..ed4ce0ae 100644
if (this_size > max_size) {
diff --git a/ggml-kleidiai.cpp b/ggml-kleidiai.cpp
new file mode 100644
-index 00000000..257a0d4c
+index 00000000..9129ea99
--- /dev/null
+++ b/ggml-kleidiai.cpp
@@ -0,0 +1,675 @@
Expand Down Expand Up @@ -176,7 +176,7 @@ index 00000000..257a0d4c
+// KleidiAI micro-kernels
+#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
-+#include "kai_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0.h"
++#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
+
Expand Down Expand Up @@ -473,7 +473,7 @@ index 00000000..257a0d4c
+ v.nr = ukernel->get_nr();
+ v.kr = ukernel->get_kr();
+ v.sr = ukernel->get_sr();
-+ v.packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0(n, k, v.nr, v.kr, k_q4_0_block_size /* 32 */);
++ v.packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n, k, v.nr, v.kr, k_q4_0_block_size /* 32 */);
+
+ return v;
+}
Expand Down Expand Up @@ -638,11 +638,11 @@ index 00000000..257a0d4c
+ // Temporary memory for the computation.
+ uint8_t *reshaped_data = (uint8_t*)malloc(reshaped_data_sz);
+
-+ struct kai_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0_params params;
++ struct kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0_params params;
+ params.lhs_zero_point = 1;
+ params.rhs_zero_point = 8;
+
-+ kai_run_rhs_pack_nxk_qsi4c32f16scalep_qsu4c32s16s0(
++ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
+ 1, n, k, // Dimensions
+ rhs_packing_params.nr, // Nr
+ rhs_packing_params.kr, // Kr
Expand Down

0 comments on commit 2321def

Please sign in to comment.