diff --git a/source/backend/cpu/CPURuntime.cpp b/source/backend/cpu/CPURuntime.cpp
index 17a653f52..7dbc28170 100644
--- a/source/backend/cpu/CPURuntime.cpp
+++ b/source/backend/cpu/CPURuntime.cpp
@@ -28,7 +28,9 @@
 // ref: https://cs.android.com/android/platform/superproject/+/master:bionic/libc/kernel/uapi/asm-arm64/asm/hwcap.h;drc=04da58f5b3bc40dbbafb4f8422aa2991479d9e1e;l=70
 #define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000)
 #define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000)
-#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002)
+
+#define CPUINFO_ARM_LINUX_FEATURE2_SVE2 UINT64_C(0x0000000000000002)
+#define CPUINFO_ARM_LINUX_FEATURE2_SME2 UINT64_C(0x0000002000000000)
 #endif
 #include
@@ -1279,6 +1281,9 @@ static void _getInfoApple(MNNCPUInfo* cpuinfo_isa) {
     if (have_feature("hw.optional.arm.FEAT_I8MM")) {
         cpuinfo_isa->i8mm = true;
     }
+    if (have_feature("hw.optional.arm.FEAT_SME2")) {
+        cpuinfo_isa->sme2 = true;
+    }
 }
 #endif
@@ -1286,6 +1291,8 @@ static void _getInfoApple(MNNCPUInfo* cpuinfo_isa) {
 static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) {
     // Use AUX to get info for linux-aarch64
     uint32_t isa_features = 0;
+    uint64_t isa_features2 = 0;
+
     isa_features = (uint32_t)getauxval(AT_HWCAP);
     if (isa_features & CPUINFO_ARM_LINUX_FEATURE_ASIMDDP) {
         cpuinfo_isa->dot = true;
@@ -1297,10 +1304,14 @@ static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) {
     if (isa_features & CPUINFO_ARM_LINUX_FEATURE_I8MM) {
         cpuinfo_isa->i8mm = true;
     }
-    isa_features = (uint32_t)getauxval(AT_HWCAP2);
-    if (isa_features & CPUINFO_ARM_LINUX_FEATURE_SVE2) {
+
+    isa_features2 = (uint64_t)getauxval(AT_HWCAP2);
+    if (isa_features2 & CPUINFO_ARM_LINUX_FEATURE2_SVE2) {
         cpuinfo_isa->sve2 = true;
     }
+    if (isa_features2 & CPUINFO_ARM_LINUX_FEATURE2_SME2) {
+        cpuinfo_isa->sme2 = true;
+    }
 }
 #endif
@@ -1351,6 +1362,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
     cpuinfo_isa->fp16arith = false;
     cpuinfo_isa->i8mm = false;
     cpuinfo_isa->sve2 = false;
+    cpuinfo_isa->sme2 = false;
     // android
     /**Get CPU Info*/
 #ifdef __linux__
@@ -1447,6 +1459,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
     cpuinfo_isa->dot = true;
 #endif
-    MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2);
+    MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d, sme2: %d\n",
+              cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2, cpuinfo_isa->sme2);
     return;
 }
diff --git a/source/backend/cpu/CPURuntime.hpp b/source/backend/cpu/CPURuntime.hpp
index 7155e023b..a71142d50 100644
--- a/source/backend/cpu/CPURuntime.hpp
+++ b/source/backend/cpu/CPURuntime.hpp
@@ -21,6 +21,7 @@ struct MNNCPUInfo {
     bool dot;
     bool i8mm;
     bool sve2;
+    bool sme2;
     std::vector<CPUGroup> groups;
     int cpuNumber = 0;
 };
diff --git a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt
index f12e27c19..6adfa8f8a 100644
--- a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt
+++ b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt
@@ -20,7 +20,9 @@ if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS
 endif()
 list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp)
+list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.cpp)
 list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h)
+list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.h)
 add_library(
     MNN_KleidiAI
@@ -41,6 +43,7 @@ include_directories(
set(KLEIDIAI_FILES_SCALAR ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.c ) set(KLEIDIAI_FILES_NEON_DOTPROD @@ -51,13 +54,20 @@ set(KLEIDIAI_FILES_NEON_I8MM ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c ) +set(KLEIDIAI_FILES_SME2_MOPA + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.c +) + # Selectively enable architecture features. target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SCALAR}) if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") AND NOT MSVC) target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD}) target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) + target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SME2_MOPA}) set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a) set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod) set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm) + set_source_files_properties(${KLEIDIAI_FILES_SME2_MOPA} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+sve2) endif() \ No newline at end of file diff --git a/source/backend/cpu/arm/kleidiAI/kai/kai_common.h b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h index 9569e5468..6697d978e 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/kai_common.h +++ b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h @@ -78,13 +78,13 @@ inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) { /// @param[in] f16 The f16 value /// /// @return the f32 value -inline static float kai_cast_f32_f16(uint16_t f16) { #if defined(__ARM_NEON) +inline static float kai_cast_f32_f16(uint16_t f16) { __fp16 f32 = 0; memcpy(&f32, &f16, sizeof(uint16_t)); return (float)f32; -#endif } +#endif /// Converts a scalar bf16 value to f32 /// @param[in] bf16 The f16 value @@ -92,7 +92,7 @@ inline static float kai_cast_f32_f16(uint16_t f16) { /// @return the f32 value inline static float kai_cast_f32_bf16(uint16_t bf16) { const uint32_t i32 = (bf16 << 16); - float f32; + float f32 = 0; memcpy(&f32, &i32, sizeof(i32)); return f32; } @@ -116,79 +116,60 @@ inline static uint16_t kai_cast_bf16_f32(float f32) { /// @param[in] f32 The f32 value /// /// @return the f16 value -inline static uint16_t kai_cast_f16_f32(float f32) { #if defined(__ARM_NEON) +inline static uint16_t kai_cast_f16_f32(float f32) { uint16_t f16 = 0; - __fp16 tmp = f32; + __fp16 tmp = (__fp16)f32; memcpy(&f16, &tmp, sizeof(uint16_t)); return f16; -#endif } +#endif inline static size_t kai_roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } -#ifdef __ARM_FEATURE_SVE - +#ifdef __ARM_FEATURE_SVE2 /// Gets the SME vector length for 8-bit elements. 
 inline static uint64_t kai_get_sme_vector_length_u8(void) {
     uint64_t res = 0;
     __asm__ __volatile__(
-        ".inst 0xd503477f // SMSTART ZA\n"
-        "cntb %0\n"
-        ".inst 0xd503467f // SMSTOP\n"
+        ".inst 0x04bf5827 // rdsvl x7, #1\n"
+        "mov %0, x7\n"
         : "=r"(res)
-        :
-        : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
-          "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
-
+        : /* no inputs */
+        : "x7");
     return res;
 }

 /// Gets the SME vector length for 16-bit elements.
 inline static uint64_t kai_get_sme_vector_length_u16(void) {
-    uint64_t res = 0;
-
-    __asm__ __volatile__(
-        ".inst 0xd503477f // SMSTART ZA\n"
-        "cnth %0\n"
-        ".inst 0xd503467f // SMSTOP\n"
-        : "=r"(res)
-        :
-        : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
-          "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
-
-    return res;
+    return kai_get_sme_vector_length_u8() / 2;
 }

 /// Gets the SME vector length for 32-bit elements.
 inline static uint64_t kai_get_sme_vector_length_u32(void) {
-    uint64_t res = 0;
-
-    __asm__ __volatile__(
-        ".inst 0xd503477f // SMSTART ZA\n"
-        "cntw %0\n"
-        ".inst 0xd503467f // SMSTOP\n"
-        : "=r"(res)
-        :
-        : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
-          "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
-
-    return res;
+    return kai_get_sme_vector_length_u8() / 4;
 }
-
-#endif // __ARM_FEATURE_SVE
+#endif // __ARM_FEATURE_SVE2

 /// Extends the sign bit of int 4-bit value (stored in int8_t variable)
 /// @param[in] value The 4-bit int value
 ///
 /// @return the int8_t value with sign extended
 inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
-    return (value ^ 0x8) - 8;
+    // Make sure value holds a valid int4 value
+    KAI_ASSERT(value <= 0xF);
+
+    return (value ^ 0x8) - 8;  // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
 }

+/// Parameter struct for RHS matrix packing
+struct kai_rhs_pack_qs4cxs1s0_param {
+    int8_t lhs_zero_point;  /**< LHS Matrix quantization zero-point */
+    uint8_t rhs_zero_point; /**< RHS Matrix quantization zero-point */
+};
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
new file mode 100644
index 000000000..d9756ed31
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
@@ -0,0 +1,273 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
+#error This file must be compiled for AArch64, FEAT_SVE2.
+#else  // Architectural feature check
+
+#include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_mr = 1;  // multiple of vector length
+static const size_t kai_nr = 4;  // multiple of vector length
+static const size_t kai_kr = 4;
+static const size_t kai_sr = 1;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+static const size_t kai_num_bytes_bias_rhs = sizeof(float);
+
+/**
+ * LUT to be indexed by an i4, resulting in its value in i8 (e.g. -2 = 1110 -> 1111 1110).
+ **/
+static const int8_t lut[64] = {0,  0, 0, 0, 1,  0, 0, 0, 2,  0, 0, 0, 3,  0, 0, 0, 4,  0, 0, 0, 5,  0,
+                               0,  0, 6, 0, 0,  0, 7, 0, 0,  0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
+                               -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
+
+inline static size_t kai_k_roundedup(size_t k) {
+    // Round up k to be a multiple of 32.
+    return kai_roundup(k, 32);
+}
+
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
+    return kai_mr * kai_get_sme_vector_length_u32();
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
+    return kai_nr * kai_get_sme_vector_length_u32();
+}
+
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
+    return kai_mr * kai_get_sme_vector_length_u32();
+}
+
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
+    return kai_nr * kai_get_sme_vector_length_u32();
+}
+
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k) {
+    KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
+
+    const size_t k_internal = kai_k_roundedup(k);
+
+    return m_idx * (k_internal + kai_num_bytes_offset_lhs + kai_num_bytes_multiplier_lhs);
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k) {
+    KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
+
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return n_idx * ((k_internal / 2) + kai_num_bytes_sum_rhs + kai_num_bytes_multiplier_rhs + kai_num_bytes_bias_rhs);
+}
+
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
+    size_t m_idx, size_t n_idx, size_t dst_stride) {
+    KAI_ASSERT((m_idx % kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
+    KAI_ASSERT((n_idx % kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa()) == 0);
+
+    return (n_idx * sizeof(float) + m_idx * dst_stride);
+}
+
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n) {
+    return m * n * sizeof(float);
+}
+
+void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
+    size_t m_in, size_t n_in, size_t k_in, const void* restrict lhs_packed, const void* restrict rhs_packed,
+    float* restrict dst,  // NOLINT(readability-non-const-parameter)
+    size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) {
+    KAI_ASSERT(dst_stride_col == sizeof(float));
+    KAI_ASSERT(dst_stride_row == n_in * sizeof(float));
+    KAI_ASSERT(n_in > 0);
+    KAI_ASSERT(m_in > 0);
+
+    // Constants
+    uint64_t mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
+    uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
+    uint64_t lhs_stride =
+        kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(mr, k_in);
+    uint64_t rhs_stride =
+        kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(nr, k_in);
+    uint64_t m_blk = (uint64_t)kai_k_roundedup(k_in) * mr;
+    uint64_t n_blk = (uint64_t)kai_k_roundedup(k_in) * nr;
+    uint64_t dst_inc = mr * n_in;
+    float scalar_bounds[2] = {scalar_min, scalar_max};
+
+    /* ---------------------------------------------------
+       Register allocation
+            x7:  u32 vector length (svls)
+            x8:  RHS base address (rhs_base)
+            x9:  Destination base address (dst_base)
+            x10: LHS pointer (lhs_ptr)
+            x11: RHS pointer (rhs_ptr)
+            x12: Remaining M elements (m_rem)
+            x13: Remaining N elements (n_rem)
+            x14: k exit condition (k_cond) / ZA tile index (l_idx)
+            x15: LHS scaling factor pointer (lhs_sf_ptr)
+            x16: ZA tile exit condition (l_cnd)
+            x17: Destination pointer (dst_ptr)
+            x19: Destination outer address (dst_o)
+            x20: LHS base address (lhs_base)
+       --------------------------------------------------- */
+    __asm__ volatile(
+        " .inst 0xd503477f //smstart \n"
+        " mov x19, %[dst] \n"
+        " mov x20, %[lhs] \n"
+        " mov x7, %[lut] \n"
+        " .inst 0xe11f80e0 //ldr zt0, [x7] \n"
+        " cntw x7 \n"
+        " ptrue p2.b \n"
+        " ld1rw {z30.s}, p2/Z, [%[scalar_bounds]] \n"
+        " ld1rw {z31.s}, p2/Z, [%[scalar_bounds], #4] \n"
+
+        // M loop head
+        " mov x12, %[m] \n"
+        " .inst 0x25ac17e0 //whilelt p0.s, xzr, x12 \n"
+        "1: \n"
+        " mov x8, %[rhs] \n"
+        " mov x9, x19 \n"
+        " mov x13, %[n] \n"
+        " cmp x7, x12 \n"
+        " csel x16, x7, x12, lt \n"
+        " lsl x16, x16, #2 \n"
+
+        // N loop head
+        " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n"
+        "2: \n"
+        " mov x10, x20 \n"
+        " mov x11, x8 \n"
+        " mov x17, x9 \n"
+        " .inst 0x25ad67f1 //whilelt pn9.s, xzr, x13, vlx4 \n"
+
+        // K loop
+        " .inst 0xc00800ff //zero {za} \n"
+        " add x14, x10, %[m_blk] \n"
+        "3: \n"
+        " .inst 0xa540a144 //ld1w { z4.s }, p0/z, [x10] \n"
+        " .inst 0x042a502a //addvl x10, x10, #1 \n"
+        " .inst 0xa0402160 //ld1h { z0.h-z1.h }, pn8/z, [x11] \n"
+        " .inst 0x042b504b //addvl x11, x11, #2 \n"
+        " .inst 0xc08a4008 //luti4 { z8.b - z9.b }, zt0, z0[0] \n"
+        " .inst 0xc08a402a //luti4 { z10.b - z11.b }, zt0, z1[0] \n"
+        " .inst 0xa0884880 //smopa za0.s, p2/m, p2/m, z4.b, z8.b \n"
+        " .inst 0xa0894881 //smopa za1.s, p2/m, p2/m, z4.b, z9.b \n"
+        " .inst 0xa08a4882 //smopa za2.s, p2/m, p2/m, z4.b, z10.b\n"
+        " .inst 0xa08b4883 //smopa za3.s, p2/m, p2/m, z4.b, z11.b\n"
+        " cmp x10, x14 \n"
+        " b.lt 3b \n"
+
+        // RHS row sum, scale factor & bias
+        " .inst 0xa040c560 //ld1w { z0.s-z3.s }, pn9/z, [x11] \n"
+        " .inst 0xa041c564 //ld1w { z4.s-z7.s }, pn9/z, [x11, #4, mul vl] \n"
+        " .inst 0xa042c568 //ld1w { z8.s-z11.s }, pn9/z, [x11, #8, mul vl]\n"
+        " .inst 0x042b518b //addvl x11, x11, #12 \n"
+        " .inst 0xc132e000 //scvtf { z0.s-z3.s }, { z0.s-z3.s }\n"
+
+        // Store loop
+        " mov x14, #0 \n"
+        " addvl x15, x10, #1 \n"
+        "4: \n"
+        // Load LHS Row-offset & SF
+        " ld1rw {z16.s}, p2/z, [x10] \n"
\n" + " ld1rw {z17.s}, p2/z, [x15] \n" + " add x10, x10, #4 \n" + " add x15, x15, #4 \n" + " scvtf z16.s, p2/m, z16.s \n" + + // offset x Row-sum + " fmul z24.s, z16.s, z0.s \n" + " fmul z25.s, z16.s, z1.s \n" + " fmul z26.s, z16.s, z2.s \n" + " fmul z27.s, z16.s, z3.s \n" + + // Scaling factors + " fmul z20.s, z17.s, z4.s \n" + " fmul z21.s, z17.s, z5.s \n" + " fmul z22.s, z17.s, z6.s \n" + " fmul z23.s, z17.s, z7.s \n" + + // Result = offset x Row-sum x SFs + " fmul z24.s, z24.s, z20.s \n" + " fmul z25.s, z25.s, z21.s \n" + " fmul z26.s, z26.s, z22.s \n" + " fmul z27.s, z27.s, z23.s \n" + + // Load inner accumulation & convert + " .inst 0xc006440c //mova { z12.b-z15.b }, za0h.b[w14, 0:3]\n" + " .inst 0xc132e18c //scvtf { z12.s-z15.s }, { z12.s-z15.s } \n" + + // Result += iacc x SF + " fmla z24.s, p2/m, z20.s, z12.s \n" + " fmla z25.s, p2/m, z21.s, z13.s \n" + " fmla z26.s, p2/m, z22.s, z14.s \n" + " fmla z27.s, p2/m, z23.s, z15.s \n" + + // Add the bias + " fadd z24.s, p2/m, z24.s, z8.s \n" + " fadd z25.s, p2/m, z25.s, z9.s \n" + " fadd z26.s, p2/m, z26.s, z10.s \n" + " fadd z27.s, p2/m, z27.s, z11.s \n" + + // CLAMP and store + " .inst 0xc1bfcbd8 //fclamp { z24.s-z27.s }, z30.s, z31.s\n" + " .inst 0xa060c638 //st1w { z24.s-z27.s }, pn9, [x17] \n" + + " add x17, x17, %[n], lsl #2 \n" + " add x14, x14, #4 \n" + " cmp x14, x16 \n" + " b.lt 4b \n" + + // N loop tail + " add x8, x8, %[rhs_stride] \n" + " .inst 0x04295089 // ddvl x9, x9, #4 \n" + " sub x13, x13, %[nr] \n" + " .inst 0x256d47f0 //whilelt pn8.h, xzr, x13, vlx2 \n" + " b.mi 2b \n" + + // M loop tail + " add x20, x20, %[lhs_stride] \n" + " add x19, x19, %[dst_inc], lsl #2 \n" + " sub x12, x12, %[mr] \n" + " whilelt p0.s, xzr, x12 \n" + " b.mi 1b \n" + + "5: \n" + " .inst 0xd503467f //smstop \n" + : + : [m] "r"(m_in), [n] "r"(n_in), [k] "r"(k_in), [lhs_stride] "r"(lhs_stride), [rhs_stride] "r"(rhs_stride), + [mr] "r"(mr), [nr] "r"(nr), [lut] "r"(lut), [m_blk] "r"(m_blk), [n_blk] "r"(n_blk), [lhs] "r"(lhs_packed), + [rhs] "r"(rhs_packed), [dst_inc] "r"(dst_inc), [scalar_bounds] "r"(scalar_bounds), [dst] "r"(dst) + : "x7", "x8", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x19", "x20", "p0", "p2", "p8", + "p9", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", + "z16", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z30", "z31", +#ifdef __ARM_STATE_ZA + "za", +#endif +#ifdef __ARM_STATE_ZT0 + "zt0", +#endif + "cc", "memory"); +} + +#endif // Architectural feature check diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h new file mode 100644 index 000000000..f770d4f3e --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h @@ -0,0 +1,123 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0 to pack the RHS matrix + +/// 
-------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the mr value, which must be used to pack the LHS matrix with +/// the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the nr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the kr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the sr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t n_idx, size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the destination(DST) offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes for the destination (DST) matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. 
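+///
+/// A worked example (illustrative values, not part of this API): for m = 8 and
+/// n = 32, the DST buffer must hold m * n * sizeof(float) = 1024 bytes.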
+///
+/// @return the destination (DST) matrix size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(size_t m, size_t n);
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dxp) and packed
+/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsi4cxp) and packed.
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension between the LHS and RHS matrix.
+/// @param[in] lhs_packed The LHS packed matrix.
+///                       When the activations are dynamically quantized, you can obtain this matrix
+///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
+///                       both the dynamic quantization to 8-bit and activation packing in a single step.
+/// @param[in] rhs_packed The RHS packed matrix, which is obtained by calling @ref
+///                       kai_run_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0
+/// @param[out] dst The DST matrix.
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. It must be sizeof(float).
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
+void kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa(
+    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.c
new file mode 100644
index 000000000..874869a07
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.c
@@ -0,0 +1,313 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if !defined(__aarch64__) || !defined(__ARM_FEATURE_SVE2)
+#error This file must be compiled for AArch64, FEAT_SVE2.
+#else  // Architectural feature check
+
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.h"
+
+#include <stddef.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_m_step = 1;
+static const size_t kai_n_step = 1;
+static const size_t kai_nr = 4;  // nr is SVL dependent
+static const size_t kai_mr = 1;
+static const size_t kai_kr = 4;
+static const size_t kai_sr = 1;
+
+// Scaling factors
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+// q8_1 zero point
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+// Per-row sum of the quantized weights, used to apply the activation zero point quickly
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+// Bias
+static const size_t kai_num_bytes_bias_rhs = sizeof(int32_t);
+
+inline static size_t kai_k_roundedup(size_t k) {
+    // Round up k to be a multiple of 32.
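+    // e.g. k = 70 rounds up to 96, so every packed row spans whole
+    // 32-element quantization blocks.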
+    return kai_roundup(k, 32);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 32) == 0);
+
+    return kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot() *
+        (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 32) == 0);
+
+    return kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot() *
+        ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias_rhs);
+}
+
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void) {
+    return kai_m_step;
+}
+
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void) {
+    return kai_nr * kai_get_sme_vector_length_u32();
+}
+
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void) {
+    return kai_n_step * kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot();
+}
+
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void) {
+    // For GEMV, mr must be 1 so the packed activations are read consecutively
+    return kai_mr;
+}
+
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void) {
+    return kai_kr;
+}
+
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void) {
+    return kai_sr;
+}
+
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k) {
+    KAI_ASSERT((m_idx % kai_m_step) == 0);
+
+    return (m_idx / kai_m_step) * kai_lhs_packed_stride(k);
+}
+
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k) {
+    KAI_ASSERT((n_idx % kai_n_step) == 0);
+
+    return (n_idx / kai_n_step) * kai_rhs_packed_stride(k);
+}
+
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(
+    size_t m_idx, size_t n_idx, size_t dst_stride) {
+    KAI_ASSERT((m_idx % kai_m_step) == 0);
+    KAI_ASSERT((n_idx % kai_n_step) == 0);
+
+    return (n_idx * sizeof(float)) + (m_idx * dst_stride);
+}
+
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(size_t m, size_t n) {
+    return m * n * sizeof(float);
+}
+
+/**
+ * LUT to be indexed by an i4, resulting in its value in i8 (e.g. -2 = 1110 -> 1111 1110).
+ **/
+static const int8_t lut[64] = {0,  0, 0, 0, 1,  0, 0, 0, 2,  0, 0, 0, 3,  0, 0, 0, 4,  0, 0, 0, 5,  0,
+                               0,  0, 6, 0, 0,  0, 7, 0, 0,  0, -8, 0, 0, 0, -7, 0, 0, 0, -6, 0, 0, 0,
+                               -5, 0, 0, 0, -4, 0, 0, 0, -3, 0, 0, 0, -2, 0, 0, 0, -1, 0, 0, 0};
+
+/**
+ *
+ * Optimized for GEMV (matrix-vector multiplication => m == 1).
+ * Does a matmul for compatibility reasons, but should not be used that way.
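+ * The LHS must therefore be packed with mr = 1 (see
+ * kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot) so
+ * that the activation vector is read consecutively.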
+ * + **/ +void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, + float* dst, // NOLINT(readability-non-const-parameter) + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0 || n == 0 || k == 0) { + return; + } + + // Do function calls and calculations first to not overwrite registers we will use + uint64_t k_internal = kai_k_roundedup(k); + uint64_t A_vector_increment = kai_lhs_packed_stride(k); + uint64_t W_row_stride = kai_rhs_packed_stride(k); + uint64_t nr = kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(); + + uint64_t W_row_bytes = nr * k_internal / 2; + uint64_t W_row_bytes_net = nr * k_internal / 2; + uint64_t A_matrix_end_ptr = ((uint64_t)lhs_packed) + (m * A_vector_increment); + + /* + * x11: zero = 0 // MUST BE x8-x11 + * x15: n initialized as n + * x19: nr initialized as nr + * x20: lut_ptr initialized as lut + * x21: A_vector_ptr initialized as lhs_packed + * x22: n_idx + * x23: k_idx + * x24: W_k_block_ptr + * x25: W_row_values_end_ptr + * x26: W_row_ptr + * x27: dst_ptr + * x28: tmp_1 + */ + + __asm__ volatile( + + // Setup + " .inst 0xd503477f // smstart \n" + " mov x11, #0 \n" + " mov x15, %[n] \n" + " mov x19, %[nr] \n" + " mov x21, %[lhs_packed] \n" + " mov x20, %[lut] \n" + " .inst 0xe11f8280 // ldr zt0, [x20] \n" + " ptrue p0.b \n" + " .inst 0x25207810 // ptrue pn8.b \n" + // predicate to load nr words for the W sums and scaling factors (should be exactly all true) + " .inst 0x25b36571 // whilelt pn9.s, x11, x19, vlx4 \n" + " dup z30.s, %w[scalar_min] \n" + " dup z31.s, %w[scalar_max] \n" + + // Activation matrix row loop + "1: \n" + // Reset weight matrix ptr + " mov x26, %[rhs_packed] \n" + // Reset dst_ptr to dst of next GEMV result + " mov x27, %[dst_vector_ptr] \n" + // Reset n index + " mov x22, #0 \n" + // whilelt pn12.s, x22, %[n], vlx4 + " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" + + // Weight matrix row loop (transposed so theoretical columns) + "2: \n" + + // Reset weights block ptr to start of row + " mov x24, x26 \n" + " add x25, x26, %[W_row_bytes_net] \n" + " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" + " addvl x28, x24, #4 \n" + " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" + " mov x23, #0 \n" + " whilelt p1.b, x23, %[k_internal] \n" + // Zero for sdot accumulation in inner loop + " .inst 0xc00800ff // zero {za} \n" + + // before k loop + "3: \n" + + // Load A + " ld1rqb { z0.b }, p1/z , [x21, x23] \n" + + // Load w + " .inst 0xa0408b10 // ld1b { z16.b - z19.b }, pn10/z, [x24] \n" + " .inst 0xa0418f14 // ld1b {z20.b-z23.b}, pn11/z, [x24,#0x4, mul vl]\n" + + // Weight i4 to i8 and sdot + // k block + 0 + " .inst 0xc08a4218 // luti4 { z24.b, z25.b }, zt0, z16[0] \n" + " .inst 0xc08a423a // luti4 { z26.b, z27.b }, zt0, z17[0] \n" + " .inst 0xc150f320 // sdot za.s[w11,0, vgx4], {z24.b-z27.b}, z0.b[0]\n" + // k block + 1 + " .inst 0xc08a4244 // luti4 { z4.b, z5.b }, zt0, z18[0] \n" + " .inst 0xc08a4266 // luti4 { z6.b, z7.b }, zt0, z19[0] \n" + " .inst 0xc150f4a0 // sdot za.s[w11,0, vgx4], {z4.b-z7.b}, z0.b[1] \n" + // k block + 2 + " .inst 0xc08a4288 // luti4 { z8.b, z9.b }, zt0, z20[0] \n" + " .inst 0xc08a42aa // luti4 { z10.b, z11.b }, zt0, z21[0] \n" + " .inst 0xc150f920 // sdot za.s[w11,0, vgx4], {z8.b-z11.b}, z0.b[2] \n" + // k block + 3 + " .inst 0xc08a42cc // luti4 { z12.b, z13.b }, zt0, z22[0] \n" + 
" .inst 0xc08a42ee // luti4 { z14.b, z15.b }, zt0, z23[0] \n" + " .inst 0xc150fda0 // sdot za.s[w11,0, vgx4], {z12.b-z15.b}, z0.b[3]\n" + + // End K block loop + " addvl x24, x24, #8 \n" + " .inst 0x25396712 // whilelt pn10.b, x24, x25, vlx4 \n" + " addvl x28, x24, #4 \n" + " .inst 0x25396793 // whilelt pn11.b, x28, x25, vlx4 \n" + " add x23, x23, #16 \n" + " whilelt p1.b, x23, %[k_internal] \n" + " b.first 3b \n" + + // Finish of accumulators with scaling factors and zero points + + // Load A zero point + " add x28, x21, %[k_internal] \n" + " ld1rw { z2.s }, p0/z , [x28] \n" + // Load A scaling factor + " ld1rw { z3.s }, p0/z , [x28, #4] \n" + // Load W sums + " add x28, x26, %[W_row_bytes] \n" + " .inst 0xa040c794 // ld1w { z20.s - z23.s }, pn9/z, [x28] \n" + // Load W scaling factors + " .inst 0xa041c798 // ld1w {z24.s-z27.s}, pn9/z, [x28, #0x4, mul vl]\n" + // Load biases + " .inst 0xa042c78c // ld1w {z12.s-z15.s}, pn9/z, [x28, #0x8, mul vl]\n" + + // Get accumulated value out of ZA + " .inst 0xc0066c04 // mov { z4.d - z7.d }, za.d[w11, 0, vgx4] \n" + + // za contains a * w, which needs to be done + z * wsum -> smla + // zero point * W row sum + " mla z4.s, p0/m, z20.s, z2.s \n" + " mla z5.s, p0/m, z21.s, z2.s \n" + " mla z6.s, p0/m, z22.s, z2.s \n" + " mla z7.s, p0/m, z23.s, z2.s \n" + + // Convert to float + " .inst 0xc132e084 // scvtf { z4.s - z7.s }, { z4.s - z7.s } \n" + + // A scaling factor * W scaling factor + " fmul z24.s, z24.s, z3.s \n" + " fmul z25.s, z25.s, z3.s \n" + " fmul z26.s, z26.s, z3.s \n" + " fmul z27.s, z27.s, z3.s \n" + + // Bias + combined scaling factor * combined accumulator + " fmla z12.s, p0/m, z24.s, z4.s \n" + " fmla z13.s, p0/m, z25.s, z5.s \n" + " fmla z14.s, p0/m, z26.s, z6.s \n" + " fmla z15.s, p0/m, z27.s, z7.s \n" + + // Clamp + " .inst 0xc1bfcbcc // fclamp { z12.s - z15.s }, z30.s, z31.s \n" + + // Store + " .inst 0xa036d36c // st1w {z12.s-z15.s}, pn12, [x27, x22, lsl #2] \n" + + // End W row loop + " add x26, x26, %[W_row_stride] \n" + // nr == svlb + " addvl x22, x22, #1 \n" + // whilelt pn12.s, x22, %[n], vlx4 + " .inst 0x25af66d4 // whilelt pn12.s, x22, x15, vlx4 \n" + " b.lt 2b \n" + + // End A row loop + " add %[dst_vector_ptr], %[dst_vector_ptr], %[dst_stride_row] \n" + " add x21, x21, %[A_vector_increment] \n" + " cmp x21, %[A_matrix_end_ptr] \n" + " b.lt 1b \n" + + " .inst 0xd503467f // smstop \n" + + : [dst_vector_ptr] "+r"(dst) + : [lut] "r"(lut), [m] "r"(m), [n] "r"(n), [k] "r"(k), [lhs_packed] "r"(lhs_packed), + [rhs_packed] "r"(rhs_packed), [dst_stride_row] "r"(dst_stride_row), [scalar_min] "r"(scalar_min), + [scalar_max] "r"(scalar_max), [k_internal] "r"(k_internal), [A_vector_increment] "r"(A_vector_increment), + [W_row_stride] "r"(W_row_stride), [nr] "r"(nr), [W_row_bytes] "r"(W_row_bytes), + [W_row_bytes_net] "r"(W_row_bytes_net), [A_matrix_end_ptr] "r"(A_matrix_end_ptr) + : "x11", "x15", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "p0", "p1", "p8", "p9", + "p10", "p11", "p12", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", + "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", + "z29", "z30", "z31", +#ifdef __ARM_STATE_ZA + "za", +#endif +#ifdef __ARM_STATE_ZT0 + "zt0", +#endif + "memory", "cc"); +} +#endif // Architectural feature check diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.h 
b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.h
new file mode 100644
index 000000000..59c26cd45
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.h
@@ -0,0 +1,128 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#ifndef __cplusplus
+#include <stdbool.h>
+#endif
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Micro-kernel dependencies
+///
+/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix
+/// -# kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0 to pack the RHS matrix
+
+/// --------------------------------------------------
+
+/// Gets the m step value.
+/// The micro-kernel can process any M values. However, the starting M index to
+/// be processed must be a multiple of m step.
+///
+/// @return the m step value
+size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void);
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @return the n step
+size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void);
+
+/// Gets the mr value, which must be used to pack the LHS matrix with
+/// the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel
+///
+/// @return the mr value
+size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void);
+
+/// Gets the nr value, which must be used to pack the RHS matrix with
+/// the @ref kai_run_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0 micro-kernel
+///
+/// @return the nr value
+size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void);
+
+/// Gets the kr value, which must be used to pack the RHS matrix with
+/// the @ref kai_run_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0 micro-kernel
+///
+/// @return the kr value
+size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void);
+
+/// Gets the sr value, which must be used to pack the RHS matrix with
+/// the @ref kai_run_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0 micro-kernel
+///
+/// @return the sr value
+size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(void);
+
+/// Gets the offset in bytes for the packed LHS matrix,
+/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values.
+///
+/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel.
+///
+/// @param[in] m_idx Row index in the LHS matrix (not packed).
+/// @param[in] k Total number of columns in the LHS matrix (not packed).
+///
+/// @return the offset in bytes to the packed LHS matrix
+size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(size_t m_idx, size_t k);
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K).
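+///
+/// For example (illustrative only): n_idx = 0 returns offset 0, and each
+/// further n step of rows advances the result by one packed RHS row stride.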
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(size_t n_idx, size_t k);
+
+/// Gets the offset in bytes for the DST matrix
+///
+/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of m step.
+/// @param[in] n_idx Column index in the DST matrix. It must be a multiple of n step.
+/// @param[in] dst_stride The number of bytes in each row of the DST matrix
+///
+/// @return the DST offset in bytes
+size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(
+    size_t m_idx, size_t n_idx, size_t dst_stride);
+
+/// Gets the size in bytes for the destination matrix.
+///
+/// @param[in] m Number of rows in the destination (DST) matrix.
+/// @param[in] n Number of columns in the destination (DST) matrix.
+///
+/// @return the destination size in bytes
+size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(size_t m, size_t n);
+
+/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation.
+///
+/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed
+/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed.
+/// Output tile: (rows x cols) = 1 x 4VL
+/// Instruction used: SME2 sdot
+///
+/// @param[in] m The number of output rows written.
+/// @param[in] n The number of output columns written.
+/// @param[in] k The number of channels. The common dimension of LHS & RHS.
+/// @param[in] lhs_packed The LHS matrix packed.
+///                       When the activations are dynamically quantized, you can obtain this matrix
+///                       by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs
+///                       both the dynamic quantization to 8-bit and activation packing in a single step.
+/// @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref
+///                       kai_run_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0
+/// @param[out] dst Result of the vector-by-matrix multiplication
+/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix.
+/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float)
+/// @param[in] scalar_min Min value used to clamp the final result.
+/// @param[in] scalar_max Max value used to clamp the final result.
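+///
+/// A minimal GEMV call sketch (buffer setup omitted; the packing routines are
+/// the dependencies listed at the top of this header, and -FLT_MAX/FLT_MAX
+/// disable clamping):
+///
+///   kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(
+///       1, n, k, lhs_packed, rhs_packed, dst,
+///       n * sizeof(float), sizeof(float), -FLT_MAX, FLT_MAX);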
+void kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot(
+    size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row,
+    size_t dst_stride_col, float scalar_min, float scalar_max);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.c
new file mode 100644
index 000000000..817e7a450
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.c
@@ -0,0 +1,151 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#include "kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_bias = sizeof(float);
+
+inline static size_t kai_k_roundedup(size_t k) {
+    // Round up k to be a multiple of 32.
+    size_t kai_k_multiple_of = 32;
+    return kai_roundup(k, kai_k_multiple_of);
+}
+
+size_t kai_get_n_step_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t nr) {
+    return nr;
+}
+
+size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) {
+    return n_idx * rhs_stride;
+}
+
+size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) {
+    KAI_UNUSED(kr);
+    KAI_UNUSED(sr);
+
+    const size_t k_internal = kai_k_roundedup(k);
+
+    // multiple of 2 because 2 elements in a byte
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
+}
+
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(
+    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) {
+    KAI_ASSERT((n_idx % nr) == 0);
+
+    return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(k, nr, kr, sr);
+}
+
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) {
+    const size_t num_rows = kai_roundup(n, nr) / nr;
+
+    return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(k, nr, kr, sr);
+}
+
+void kai_run_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(
+    size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
+    const float* scale, void* rhs_packed, size_t extra_bytes,
+    const struct kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0_params* params) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % kr) == 0);
+    KAI_ASSERT(num_groups == 1);
+    KAI_ASSERT(extra_bytes == 0);
+    KAI_ASSERT((kr % sr) == 0);
+    KAI_ASSERT(rhs != NULL);
+    KAI_ASSERT(scale != NULL);
+    KAI_ASSERT(rhs_packed != NULL);
+    KAI_ASSERT(params != NULL);
+    KAI_ASSERT(params->lhs_zero_point == 1);
+    KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8);
+
+    // Note: The input matrix (rhs) is expected with:
+    // "k" columns and "n" rows (NxK)
+
+    const int32_t rhs_zero_point = params->rhs_zero_point;
+    const size_t rhs_stride = kai_roundup(k, 2) / 2;
+    const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(k, nr, kr, sr);
+    const size_t dst_nr_block_size = nr * kr * sizeof(uint8_t) / 2;
+
+    // Iterate over n src rows in blocks of nr rows
+    for (size_t row_idx = 0; row_idx < n; row_idx += nr) {
+        int8_t* const dst_row = (int8_t*)rhs_packed + ((row_idx / nr) * rhs_packed_stride);
+
+        int32_t* const sums = (int32_t*)(dst_row + (nr * (k_internal / 2)));
+        float* const scaling_factors = (float*)((uint8_t*)sums + (nr * kai_num_bytes_sum_rhs));
+        // Biases are stored after the scaling factors
+        float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs));
+
+        // Initialize sums to 0
+        memset(sums, 0, nr * kai_num_bytes_sum_rhs);
+
+        // Copy the scaling factors and bias
+        size_t rows_left = n - row_idx;
+        if (rows_left >= nr) {
+            memcpy(scaling_factors, &scale[row_idx], nr * kai_num_bytes_multiplier_rhs);
+            memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias);
+        } else {
+            // Fill remaining values
+            memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs);
+            memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias);
+            // Set leftover to 0
+            memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs);
+            memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias);
+        }
+
+        // Iterate over rows in the nr row block
+        for (size_t nr_block_idx = 0; nr_block_idx < nr; ++nr_block_idx) {
+            const uint8_t* const src_row = rhs + ((row_idx + nr_block_idx) * rhs_stride);
+            // Go to the first kr block for this row in the nr block
+            int8_t* dst_kr_block = dst_row + (nr_block_idx * kr / 2);
+
+            int32_t sum = 0;
+
+            // Iterate over k src columns in blocks of kr columns
+            for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) {
+                // Iterate over columns in the kr block
+                // kr checked to be a multiple of 2 (because 2 values per byte)
+                for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) {
+                    // We pad dst with 0s if the rounded k or n values have been exceeded
+                    if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) {
+                        dst_kr_block[kr_block_idx / 2] = 0;
+                        continue;
+                    }
+
+                    // Load the 2 u4 values from source
+                    const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2];
+
+                    // Extract i8 values from the 2 u4 values
+                    const int32_t first_value = (dst_byte & 0xF) - rhs_zero_point;
+                    const int32_t second_value = col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point;
+
+                    // Add the two sign-extended i4 values to the row sum
+                    sum += first_value + second_value;
+
+                    // Truncate i8 to i4 and write to dst
+                    // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+                    dst_kr_block[kr_block_idx / 2] = (second_value << 4) | (first_value & 0xF);
+                    // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+                }
+
+                // Go to the next kr block for this row in the nr rows
+                dst_kr_block += dst_nr_block_size;
+            }
+
+            // Save sum
+            sums[nr_block_idx] = sum;
+        }
+    }
+}
diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.h
new file mode 100644
index 000000000..dc1c986ba
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.h
@@ -0,0 +1,99 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0_params kai_rhs_pack_qs4cxs1s0_param
+
+/// Gets the n step value.
+/// The micro-kernel can process any N values. However, the starting N index to
+/// be processed must be a multiple of n step.
+///
+/// @param[in] nr The number of columns written by the matmul micro-kernel
+///
+/// @return the n step value
+size_t kai_get_n_step_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t nr);
+
+/// Gets the offset in bytes for the RHS matrix (not packed), which holds
+/// the int4 values in a N x K matrix, where N is number of rows and K is the number of columns.
+/// Two int4 values are stored in one byte. The lower order part of the byte (low) holds
+/// the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1).
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] rhs_stride The number of bytes in each row of the RHS matrix (not packed)
+///
+/// @return the offset in bytes to the RHS matrix (not packed)
+size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t n_idx, size_t rhs_stride);
+
+/// Gets the row stride in bytes to the packed RHS matrix
+///
+/// @param[in] k In the RHS matrix (not packed), K is the number of columns.
+/// @param[in] nr The number of columns written by the matmul micro-kernel.
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the stride in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr);
+
+/// Gets the offset in bytes for the packed RHS matrix,
+/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values.
+///
+/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K)
+/// @param[in] nr The number of columns written by the matmul micro-kernel
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
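+///
+/// For example (illustrative only): offsets advance in whole packed-row strides, so n_idx = nr
+/// yields exactly kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(k, nr, kr, sr).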
+///
+/// @return the offset in bytes to the packed RHS matrix
+size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(
+    size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr);
+
+/// @brief Gets the size in bytes for the packed RHS matrix
+///
+/// @param[in] n The number of rows in the RHS matrix (not packed)
+/// @param[in] k The number of columns in the RHS matrix (not packed).
+/// @param[in] nr The number of columns written by the matmul micro-kernel
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the packed RHS matrix size in bytes
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr);
+
+/// Runs the micro-kernel to pack the RHS matrix.
+///
+/// @note The int4 values are stored in a N x K matrix, where N is number of rows and K is the number of columns.
+///       Two int4 values are stored in one byte. The lower order part of the byte (low) holds
+///       the first nibble (K-index + 0). The higher order of the byte holds the second nibble (K-index + 1).
+///
+/// @param[in] num_groups The number of groups. It must be 1.
+/// @param[in] n The number of columns of the output matrix (N).
+/// @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value.
+/// @param[in] nr The number of N columns to interleave on the same output row.
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+/// @param[in] rhs The RHS matrix containing the 4-bit values.
+///                Size in bytes is expected to be greater than or equal to (n * k) / 2.
+/// @param[in] bias The biases.
+/// @param[in] scale The scale for each output channel.
+/// @param[out] rhs_packed The packed RHS matrix.
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in] params Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(
+    size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias,
+    const float* scale, void* rhs_packed, size_t extra_bytes,
+    const struct kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0_params* params);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp
index dc1f9169f..146066147 100644
--- a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp
+++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp
@@ -10,349 +10,209 @@
 using namespace MNN;

-KleidiAI *KleidiAI::instance = NULL;
-
-inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) {
-    // Since we pack a float and int32 value at the end of the row,
-    // we must make sure that k is a multiple of 4 for memory alignment.
- size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); - return kai_roundup(k, kr_sr_roundedup4); -} - -static void packQsi4cxps16s0Qs4cxs0s1( - size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, - const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { - KAI_ASSERT(num_groups == 1); - KAI_ASSERT(extra_bytes == 0); - KAI_ASSERT((kr % sr) == 0); - KAI_ASSERT(rhs != NULL); - KAI_ASSERT(scale != NULL); - KAI_ASSERT(rhs_packed != NULL); - KAI_ASSERT(params != NULL); - KAI_ASSERT(params->rhs_zero_point == 8); - KAI_ASSERT(params->lhs_zero_point == 1); - - const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); - const size_t k_internal = kai_k_roundedup(k, kr, sr); - const size_t dst_num_rows = kai_roundup(n, nr) / nr; - const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); - const size_t block_length_in_bytes = kr / sr; - const size_t k_interleaved_v = 16U; - const size_t rhs_stride = kai_roundup(k, 2) / 2; - - for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { - uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; - - int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); - - // Initialize to zero the RHS reduction sums - memset(sums, 0, nr * sizeof(int32_t)); - - for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { - const size_t block_idx = dst_byte_idx / block_length_in_bytes; - const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; - const size_t super_block_idx = block_idx / nr; - const size_t nr_idx = block_idx % nr; - - const size_t k_adjustment = - ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; - const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; - const size_t k1_idx = k0_idx + k_interleaved_v; - const size_t n0_idx = dst_row_idx * nr + nr_idx; - - // Clamp the index to avoid out-of-bound reads - const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - - const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; - const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; - - uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; - uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; - - if (k0_idx < k) { - byte0 = rhs[src_addr_byte0]; - } - - if (k1_idx < k) { - byte1 = rhs[src_addr_byte1]; - } - - // The following operations where we extract the values from the bytes - // can be also written in the following and less efficient manner: - /* - uint8_t src_x0_lo = 0; - uint8_t src_x0_hi = 0; - - if ((k0_idx % 2) == 0) { - src_x0_lo = (byte0 & 0x0F); - } else { - src_x0_lo = (byte0 >> 4); - } - - if ((k1_idx % 2) == 0) { - src_x0_hi = (byte1 & 0x0F); - } else { - src_x0_hi = (byte1 >> 4); - } - */ - const size_t shift_right_x0 = ((k0_idx + 1) % 2) * 4; - const size_t shift_right_x1 = ((k1_idx + 1) % 2) * 4; - - const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; - const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; - - sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; - - const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); - - *dst_row = dst_qs0 ^ 0x88; - dst_row += sizeof(uint8_t); - } - - // Adjust the reduction sums - for (size_t i = 0; i < nr; ++i) { - sums[i] = 
sums[i] * 16;
-            dst_row += sizeof(int32_t);
-        }
-
-        // Adjust the scales
-        for (size_t i = 0; i < nr; ++i) {
-            // Clamp the row index to avoid out-of-bound reads
-            const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1);
-            *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F;
-            dst_row += sizeof(float);
-        }
-
-        // Set the bias
-        if (bias == NULL) {
-            memset(dst_row, 0, nr * sizeof(float));
+#define FLT16_MAX 65504.0f
+#define FLT16_MIN -65504.0f
+
+KleidiAI *KleidiAI::mKaiInstance = NULL;
+bool KleidiAI::mKaiInitialized = false;
+KleidiAI::modelInfo KleidiAI::mModelInfo;
+KleidiAI::accelType KleidiAI::mAccelType = KleidiAI::accelType::NOT_SUPPORT;
+KleidiAI::CPUInfo KleidiAI::mCPUInfo;
+KleidiAI::kleidiaiInfo KleidiAI::mKleidiaiInfo;
+
+const std::map<KleidiAI::modelInfo, KleidiAI::accelType> KleidiAI::mModelInfoMap = {
+    /*qi4, asym, fp16, blkSize*/
+    {KleidiAI::modelInfo(true, false, false, 0), KleidiAI::accelType::QI4_SYM_FP32_CHNLQT},
+    //TODO: KleidiAI support.
+    // {KleidiAI::modelInfo(true, true, false, 0), KleidiAI::accelType::QI4_ASYM_FP32_CHNLQT},
+    // {KleidiAI::modelInfo(true, true, false, -1), KleidiAI::accelType::QI4_ASYM_FP32_BLKQT},
+    // {KleidiAI::modelInfo(true, true, true, 0), KleidiAI::accelType::QI4_ASYM_FP16_CHNLQT},
+    // {KleidiAI::modelInfo(true, true, true, -1), KleidiAI::accelType::QI4_ASYM_FP16_BLKQT},
+    // {KleidiAI::modelInfo(true, false, false, -1), KleidiAI::accelType::QI4_SYM_FP32_BLKQT},
+    // {KleidiAI::modelInfo(true, false, true, 0), KleidiAI::accelType::QI4_SYM_FP16_CHNLQT},
+    // {KleidiAI::modelInfo(true, false, true, -1), KleidiAI::accelType::QI4_SYM_FP16_BLKQT},
+};
+
+//Get instance.
+KleidiAI& KleidiAI::getInstance(const modelInfo& modelInfo, const MNNCPUInfo& gCPUInfo) {
+    if(!mKaiInstance) {
+        //Set mKaiInitialized and construct.
+        mKaiInstance = new KleidiAI;
+        mKaiInitialized = true;
+
+        //Set model info.
+        mModelInfo = modelInfo;
+        //Set mAccelType.
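+        //The lookup below maps the model's quantization traits to an
+        //accelType; any combination missing from mModelInfoMap (see the
+        //commented-out TODO entries above) resolves to NOT_SUPPORT, which
+        //makes canAccelerate() return false later on.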
+ auto it = mModelInfoMap.find(mModelInfo); + if(it != mModelInfoMap.end()) { + mAccelType = it->second; } else { - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - ((float*)dst_row)[i] = bias[src_row_idx]; - } + mAccelType = accelType::NOT_SUPPORT; } - } -} - -static void packQs4cxs16s0Qsi8cx(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, - const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { - KAI_ASSERT(num_groups == 1); - KAI_ASSERT(extra_bytes == 0); - KAI_ASSERT((kr % sr) == 0); - KAI_ASSERT(rhs != NULL); - KAI_ASSERT(scale != NULL); - KAI_ASSERT(rhs_packed != NULL); - KAI_ASSERT(params != NULL); - KAI_ASSERT(params->rhs_zero_point == 8); - KAI_ASSERT(params->lhs_zero_point == 1); - - const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); - const size_t k_internal = kai_k_roundedup(k, kr, sr); - const size_t dst_num_rows = kai_roundup(n, nr) / nr; - const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); - const size_t block_length_in_bytes = kr / sr; - const size_t k_interleaved_v = 16U; - const size_t rhs_stride = kai_roundup(k, 2); - - for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { - uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; - - int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); - - // Initialize to zero the RHS reduction sums - memset(sums, 0, nr * sizeof(int32_t)); - - for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { - const size_t block_idx = dst_byte_idx / block_length_in_bytes; - const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; - const size_t super_block_idx = block_idx / nr; - const size_t nr_idx = block_idx % nr; - - const size_t k_adjustment = - ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; - const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; - const size_t k1_idx = k0_idx + k_interleaved_v; - const size_t n0_idx = dst_row_idx * nr + nr_idx; - - // Clamp the index to avoid out-of-bound reads - const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); - - const size_t src_addr_byte0 = k0_idx + n0_valid_idx * rhs_stride; - const size_t src_addr_byte1 = k1_idx + n0_valid_idx * rhs_stride; - - int8_t byte0 = 0; - int8_t byte1 = 0; - - if (k0_idx < k) { - byte0 = rhs[src_addr_byte0]; - } - - if (k1_idx < k) { - byte1 = rhs[src_addr_byte1]; - } - - sums[nr_idx] += (int32_t)byte0 + (int32_t)byte1; - - const uint8_t dst_qs0 = (byte0 + rhs_zero_point) | ((byte1 + rhs_zero_point) << 4); - - *dst_row = dst_qs0 ^ 0x88; - dst_row += sizeof(uint8_t); - } - - // Adjust the reduction sums - for (size_t i = 0; i < nr; ++i) { - sums[i] = sums[i] * 16; - dst_row += sizeof(int32_t); - } - - // Adjust the scales - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; - dst_row += sizeof(float); - } - - // Set the bias - if (bias == NULL) { - memset(dst_row, 0, nr * sizeof(float)); + mModelInfo.print(); + + //Set CPU info + mCPUInfo = gCPUInfo; + + 
if(canAccelerate()) { + MNN_PRINT("KleidiAI is running!\n"); + //Init Kleidi info related to model type. + mKleidiaiInfo.init(mCPUInfo.mSme2); } else { - for (size_t i = 0; i < nr; ++i) { - // Clamp the row index to avoid out-of-bound reads - const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); - ((float*)dst_row)[i] = bias[src_row_idx]; - } + MNN_PRINT("KleidiAI cannot accelerate!\n"); } } + return *mKaiInstance; } -void KleidiAI::packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize) { - if(rowNum == 1) { - return; +KleidiAI& KleidiAI::getInstance() { + if(!mKaiInstance) { + MNN_ASSERT(0); //Should never happen. } + return *mKaiInstance; +} - const size_t tmp_size = rowNum * rowSize * sizeof(float); - uint8_t *tmpBuffer = new uint8_t[tmp_size]; - memcpy(tmpBuffer, data, tmp_size); - - const float *src = (const float *)tmpBuffer; - float *dst = (float *)data; - - size_t blockNum = rowSize / 4; - size_t blockSize = 4 * sizeof(float); - - for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { - const float *rowSrc = src + blockIndex * 4; - for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { - memcpy(dst, rowSrc, blockSize); - dst += 4; - rowSrc += rowSize; - } +//Lhs +size_t KleidiAI::getLhsQuantedPackedSize(size_t m, size_t k) { + switch(mAccelType) { + case accelType::QI4_SYM_FP32_CHNLQT: + return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, getMr(m), getKr(), getSr()); + default: + MNN_ASSERT(0); + return 0; } - - delete[] tmpBuffer; } -void KleidiAI::packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize) { - if(rowNum == 1) { - return; +size_t KleidiAI::getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k) { + if(mIdx == 0) { + return 0; } - const size_t tmp_size = rowNum * rowSize * sizeof(float); - uint8_t *tmpBuffer = new uint8_t[tmp_size]; - memcpy(tmpBuffer, data, tmp_size); - - const float *src = (const float *)tmpBuffer; - float *dst = (float *)data; - - size_t blockNum = rowSize / 4; - size_t blockSize = 4 * sizeof(float); - - for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { - const float *rowSrc = src + blockIndex * 4 * rowNum; - float *block_dst = dst + blockIndex * 4; - for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { - memcpy(block_dst, rowSrc, blockSize); - block_dst += rowSize; - rowSrc += 4; - } + switch(mAccelType) { + case accelType::QI4_SYM_FP32_CHNLQT: + return kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(mIdx, k, getMr(m), getKr(), getSr()); + default: + MNN_ASSERT(0); + return 0; } - - delete[] tmpBuffer; } -//Set info -void KleidiAI::setEnable(bool enable) { - mKaiInfo.kaiEnable = enable; - if(canAccelerate()) { - MNN_PRINT("\nKleidiAI is running!\n"); +void KleidiAI::runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked) { + void (*pack)(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, + const float* lhs, size_t lhs_stride, void* lhs_packed) = NULL; + + switch(mAccelType) { + case accelType::QI4_SYM_FP32_CHNLQT: + pack = kai_run_lhs_quant_pack_qai8dxp_f32; + break; + default: + MNN_ASSERT(0); + break; } -} -void KleidiAI::setModelAsymmetric(bool bAsymmetric) { - mKaiInfo.asymmetric = bAsymmetric; - if(canAccelerate()) { - MNN_PRINT("\nKleidiAI is running!\n"); + if(pack) { + pack(m, k, mr, getKr(), getSr(), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked); } } -//Lhs -size_t KleidiAI::getLhsQuantedPackedSize(size_t m, size_t k) { - return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, getMr(m), 
getKr(), getSr()); -} - -size_t KleidiAI::getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k) { - return mIdx == 0 ? 0 : kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(mIdx, k, getMr(m), getKr(), getSr()); -} - -void KleidiAI::runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked) { - kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, getKr(), getSr(), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked); -} - //Rhs size_t KleidiAI::getRhsPackedSize(size_t n, size_t k) { - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, getNr(), getKr(), getSr()); + switch(mAccelType) { + case accelType::QI4_SYM_FP32_CHNLQT: + if(mCPUInfo.mSme2) { + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(n, k, getNr(), getKr(), getSr()); + } else { + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, getNr(), getKr(), getSr()); + } + default: + MNN_ASSERT(0); + return 0; + } } size_t KleidiAI::getRhsPackedOffset(size_t nIdx, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(nIdx, k, getNr(), getKr(), getSr()); + if(nIdx == 0) { + return 0; + } + + switch(mAccelType) { + case accelType::QI4_SYM_FP32_CHNLQT: + if(mCPUInfo.mSme2) { + return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(nIdx, k, getNr(), getKr(), getSr()); + } else { + return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(nIdx, k, getNr(), getKr(), getSr()); + } + default: + MNN_ASSERT(0); + return 0; + } } -void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4) { - struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - if(!packedInt4) { - packQs4cxs16s0Qsi8cx(1, n, k, getNr(), getKr(), getSr(), - (const uint8_t *)rhs, - (const float *)bias, (const float *)scale, - rhsPacked, - 0, ¶ms); - } else { - packQsi4cxps16s0Qs4cxs0s1(1, n, k, getNr(), getKr(), getSr(), - (const uint8_t *)rhs, - (const float *)bias, (const float *)scale, - rhsPacked, - 0, ¶ms); +void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void* zeroPoint, const void* bias, void* rhsPacked, bool packedQ4) { + void (*packSymChnl)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct KleidiAIUtil::rhsPackParamCommon* params) = NULL; + + switch(mAccelType) { + case accelType::QI4_SYM_FP32_CHNLQT: + { + if(mCPUInfo.mSme2) { + if(packedQ4) { + packSymChnl = KleidiAIUtil::packQsi4cxpoQsu4cxs0s1; + } else { + MNN_ASSERT(0); + } + } else { + if(packedQ4) { + packSymChnl = KleidiAIUtil::packQsi4cxps16s0Qs4cxs0s1; + } else { + packSymChnl = KleidiAIUtil::packQsi4cxps16s0Qs4cx; + } + } + break; + } + default: + MNN_ASSERT(0); + } + + KleidiAIUtil::rhsPackParamCommon paramCommon; + if(packSymChnl) { + packSymChnl(1, n, k, getNr(), getKr(), getSr(), + (const uint8_t *)rhs, (const float *)bias, (const float *)scale, + rhsPacked, 0, ¶mCommon); } } //Matmul void KleidiAI::runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst) { - if(m == 1) { //dotprod - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k, - (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, - dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); - } else { //i8mm - 
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k,
-                                                        (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst,
-                                                        dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+    void (*runChnlQuantMatmul)(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst,
+                               size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) = NULL;
+
+    const float scalar_max = mModelInfo.mFp16 ? FLT16_MAX : FLT_MAX;
+    const float scalar_min = -scalar_max;
+
+    switch(mAccelType) {
+    case accelType::QI4_SYM_FP32_CHNLQT:
+        if(m == 1) {
+            if(mCPUInfo.mSme2) {
+                runChnlQuantMatmul = kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot;
+            } else {
+                runChnlQuantMatmul = kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod;
+            }
+        } else {
+            if(mCPUInfo.mSme2) {
+                runChnlQuantMatmul = kai_run_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa;
+            } else {
+                runChnlQuantMatmul = kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm;
+            }
+        }
+        break;
+    default:
+        MNN_ASSERT(0);
+    }
+
+    if(runChnlQuantMatmul) {
+        runChnlQuantMatmul(m, n, k, (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, dst_stride, sizeof(float), scalar_min, scalar_max);
+    }
 }
diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h
index 38cdce230..2d6b6dc7a 100644
--- a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h
+++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h
@@ -7,75 +7,170 @@
 #pragma once
 #include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "kai_lhs_quant_pack_qai8dxp_f32.h"
-#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h"
-#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
-#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h"
-
-#include "kai_common.h"
+#include "core/Backend.hpp"
+#include "core/Execution.hpp"
+#include "core/TensorUtils.hpp"
+#include "core/ConvolutionCommon.hpp"
+#include "backend/cpu/CPUBackend.hpp"
+#include "backend/cpu/CPURuntime.hpp"
+#include "backend/cpu/compute/CommonOptFunction.h"
+
+#include "mnn_kleidiai_util.h"
 
 namespace MNN {
     class KleidiAI {
     public:
-        static KleidiAI &getInstance(bool bAsymmetric, bool acthalf, bool blockwise) {
-            if(!instance) {
-                instance = new KleidiAI(bAsymmetric, acthalf, blockwise);
+        //Define some necessary data structures.
+        enum class accelType {
+            /*
+            QI4: Int4 quantized;
+            ASYM/SYM: Asymmetric/symmetric quantization;
+            CHNLQT/BLKQT: Per-channel quantized/per-block quantized;
+            FP16: FP16 output.
+            */
+            QI4_ASYM_FP32_CHNLQT,
+            QI4_ASYM_FP32_BLKQT,
+            QI4_ASYM_FP16_CHNLQT,
+            QI4_ASYM_FP16_BLKQT,
+            QI4_SYM_FP32_CHNLQT,
+            QI4_SYM_FP32_BLKQT,
+            QI4_SYM_FP16_CHNLQT,
+            QI4_SYM_FP16_BLKQT,
+            NOT_SUPPORT
+        };
+
+        typedef struct modelInfo {
+            bool mQi4;         //Int4 quant.
+            bool mAsymmetric;  //Asymmetric quantized model.
+            bool mFp16;        //fp16 or fp32.
+            size_t mBlockSize; //0: Per channel quant; others: Per block quant.
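+            //modelInfo serves as the key type of mModelInfoMap, so the
+            //operator< below must define a strict weak ordering. Its last
+            //field only distinguishes per-channel (mBlockSize == 0) from
+            //per-block quantization: two models with different non-zero
+            //block sizes compare as equivalent keys.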
+
+            modelInfo(bool qi4 = false, bool asymmetric = false, bool fp16 = false, size_t blockSize = 0) {
+                mQi4 = qi4;
+                mAsymmetric = asymmetric;
+                mFp16 = fp16;
+                mBlockSize = blockSize;
             }
-            return *instance;
-        }
 
-        static KleidiAI &getInstance() {
-            if(!instance) {
-                instance = new KleidiAI;
+            bool operator<(const modelInfo& rhs) const {
+                if(mQi4 != rhs.mQi4) {
+                    return mQi4 < rhs.mQi4;
+                }
+
+                if(mAsymmetric != rhs.mAsymmetric) {
+                    return mAsymmetric < rhs.mAsymmetric;
+                }
+
+                if(mFp16 != rhs.mFp16) {
+                    return mFp16 < rhs.mFp16;
+                }
+
+                bool lhsPerChannel = mBlockSize == 0 ? true : false;
+                bool rhsPerChannel = rhs.mBlockSize == 0 ? true : false;
+                return lhsPerChannel < rhsPerChannel;
             }
-            return *instance;
-        }
 
-        ~KleidiAI() {}
+            bool support() const {
+                return mQi4 == true && mBlockSize % 32 == 0;
+            }
+
+            void print() const {
+                MNN_PRINT("\nKleidiAI loaded model info: qi4 = %s, asymmetric = %s, fp16 = %s, blockSize = %ld\n",
+                          mQi4 ? "TRUE" : "FALSE",
+                          mAsymmetric ? "TRUE" : "FALSE",
+                          mFp16 ? "TRUE" : "FALSE",
+                          mBlockSize);
+            }
+        } modelInfo;
+
+        typedef struct CPUInfo {
+            bool mDot = false;
+            bool mI8mm = false;
+            bool mSme2 = false;
+
+            void operator=(const MNNCPUInfo& MNNInfo) {
+                mDot = MNNInfo.dot;
+                mI8mm = MNNInfo.i8mm;
+                mSme2 = MNNInfo.sme2;
+            }
+
+            bool support() const {
+                return mDot && (mI8mm || mSme2);
+            }
+        } CPUInfo;
+
+        typedef struct kleidiaiInfo {
+            size_t mKaiMstepGemv;
+            size_t mKaiMstepGemm;
+            size_t mKaiNStep;
+
+            size_t mKaiMrGemv;
+            size_t mKaiMrGemm;
+            size_t mKaiNr;
+            size_t mKaiKr;
+            size_t mKaiSr;
+
+            kleidiaiInfo() {
+                mKaiMstepGemv = 0;
+                mKaiMstepGemm = 0;
+                mKaiNStep = 0;
+
+                mKaiMrGemv = 0;
+                mKaiMrGemm = 0;
+                mKaiNr = 0;
+                mKaiKr = 0;
+                mKaiSr = 0;
+            }
+
+            void init(bool sme2) {
+                if(sme2) {
+                    mKaiMstepGemv = 1;
+                    mKaiMstepGemm = kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
+                    mKaiNStep = kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
+
+                    mKaiMrGemv = 1;
+                    mKaiMrGemm = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
+                    mKaiNr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa();
+                    mKaiKr = 4;
+                    mKaiSr = 1;
+                } else {
+                    mKaiMstepGemv = 1;
+                    mKaiMstepGemm = 8;
+                    mKaiNStep = 4;
+
+                    mKaiMrGemv = 1;
+                    mKaiMrGemm = 4;
+                    mKaiNr = 4;
+                    mKaiKr = 16;
+                    mKaiSr = 2;
+                }
+            }
+        } kleidiaiInfo;
 
-        typedef struct KaiInfo {
-            bool kaiEnable = false;
-            bool asymmetric = false; //Asymmetric quantized model.
-            bool acthalf = false; // activation half precision.
-            bool blockwise = false; // weight quant using block wise.
-            bool dot = false; //CPU support sdot.
-            bool i8mm = false; //CPU support i8mm.
-        } KaiInfo;
+        //Public static members.
+        static bool mKaiInitialized;
+        static accelType mAccelType;
+        static CPUInfo mCPUInfo;
+        static modelInfo mModelInfo;
+        static kleidiaiInfo mKleidiaiInfo;
+        static const std::map<modelInfo, accelType> mModelInfoMap;
 
-        //Kai util
-        void packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize);
-        void packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize);
+        //Get instance.
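+        //Illustrative call sequence (it mirrors ConvInt8TiledExecutor.cpp in
+        //this patch): the first call binds the model and CPU traits and picks
+        //an accelType; afterwards, callers use the parameterless overload.
+        //
+        //    KleidiAI::modelInfo info(true /*qi4*/, false /*asym*/, false /*fp16*/, 0 /*per-channel*/);
+        //    KleidiAI& kai = KleidiAI::getInstance(info, *MNNGetCPUInfo());
+        //    if (KleidiAI::canAccelerate()) { /* from now on: KleidiAI::getInstance() */ }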
+ static KleidiAI &getInstance(const modelInfo& modelInfo, const MNNCPUInfo& gCPUInfo); + static KleidiAI &getInstance(); - //Set info - void setEnable(bool enable); - void setModelAsymmetric(bool bAsymmetric); + ~KleidiAI() {} //Check - bool canAccelerate() { - return (mKaiInfo.kaiEnable && mKaiInfo.dot && mKaiInfo.i8mm && - !mKaiInfo.asymmetric && !mKaiInfo.acthalf && !mKaiInfo.blockwise); - } + static bool canAccelerate() { return mKaiInitialized && mAccelType != accelType::NOT_SUPPORT && mCPUInfo.support() && mModelInfo.support(); } //Get info - size_t getMr(size_t m = 1) { return (m == 1) ? mKaiMrDotprod : mKaiMrI8mm; } - size_t getNr() { return mKaiNr; } - size_t getKr() { return mKaiKr; } - size_t getSr() { return mKaiSr; } - size_t getMStep(size_t m = 1) { return (m == 1) ? mKaiMstepDotprod : mKaiMstepI8mm; } - size_t getNStep() { return mKaiNStep; } + static size_t getMr(size_t m = 1) { return (m == 1) ? mKleidiaiInfo.mKaiMrGemv : mKleidiaiInfo.mKaiMrGemm; } + static size_t getNr() { return mKleidiaiInfo.mKaiNr; } + static size_t getKr() { return mKleidiaiInfo.mKaiKr; } + static size_t getSr() { return mKleidiaiInfo.mKaiSr; } + static size_t getMStep(size_t m = 1) { return (m == 1) ? mKleidiaiInfo.mKaiMstepGemv : mKleidiaiInfo.mKaiMstepGemm; } + static size_t getNStep() { return mKleidiaiInfo.mKaiNStep; } size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) { return kai_roundup((totalVec + totalThread - 1) / totalThread, minStep); } //Lhs @@ -86,40 +181,17 @@ namespace MNN { //Rhs size_t getRhsPackedSize(size_t n, size_t k); size_t getRhsPackedOffset(size_t nIdx, size_t k); - void runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4 = false); + void runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void* zeroPoint, const void *bias, void* rhsPacked, bool packedQ4); //Dst - size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n) { return (nIdx * sizeof(float)) + mIdx * (n * sizeof(float)); } + size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n, size_t elementSize) { return (nIdx * elementSize) + mIdx * (n * elementSize); } //Matmul void runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst); private: - KleidiAI(bool bAsymmetric = false, bool acthalf = false, bool blockwise = false) { - const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); - mKaiInfo.dot = gCPUInfo.dot; - mKaiInfo.i8mm = gCPUInfo.i8mm; - mKaiInfo.kaiEnable = true; - mKaiInfo.asymmetric = bAsymmetric; - mKaiInfo.acthalf = acthalf; - mKaiInfo.blockwise = blockwise; - - if(canAccelerate()) { - MNN_PRINT("\nKleidiAI is running!\n"); - } - } - - static KleidiAI *instance; - KaiInfo mKaiInfo; - - const size_t mKaiMstepDotprod = 1; - const size_t mKaiMstepI8mm = 8; - const size_t mKaiNStep = 4; + KleidiAI() {} - const size_t mKaiMrDotprod = 1; - const size_t mKaiMrI8mm = 4; - const size_t mKaiNr = 4; - const size_t mKaiKr = 16; - const size_t mKaiSr = 2; + static KleidiAI *mKaiInstance; }; } \ No newline at end of file diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai_util.cpp b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai_util.cpp new file mode 100644 index 000000000..d703bd169 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai_util.cpp @@ -0,0 +1,478 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include "mnn_kleidiai_util.h" + +using 
namespace MNN; + +static const size_t kai_num_bytes_adder_rhs = 4; //sizeof(int32_t) or sizeof(float) +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % bl) == 0); + return k / bl; +} + +inline static size_t kai_num_bytes_per_block(size_t bl) { + return (bl / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_adder_rhs; +} + +inline static size_t kai_rhs_packed_stride_q4c32p(size_t k, size_t nr, size_t kr, size_t bl) { + KAI_ASSUME((k % 2) == 0); + KAI_ASSUME((k % kr) == 0); + KAI_ASSUME((k % bl) == 0); + KAI_ASSUME((bl % kr) == 0); + + const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl); + const size_t num_bytes_per_block = kai_num_bytes_per_block(bl); + + return nr * (num_bytes_per_block * num_blocks_per_row + kai_num_bytes_bias); +} + +inline static size_t kai_rhs_packed_stride(size_t k, size_t kr, size_t nr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + // multiple of 2 because 2 elements in a byte + KAI_ASSERT((k_internal % 2) == 0); + + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_adder_rhs + kai_num_bytes_bias); +} + +void KleidiAIUtil::packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize) { + if(rowNum == 1) { + return; + } + + const size_t tmpSize = rowNum * rowSize * sizeof(float); + uint8_t *tmpBuffer = new uint8_t[tmpSize]; + memcpy(tmpBuffer, data, tmpSize); + + const float *src = (const float *)tmpBuffer; + float *dst = (float *)data; + + size_t blockNum = rowSize / 4; + size_t blockSize = 4 * sizeof(float); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const float *rowSrc = src + blockIndex * 4; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(dst, rowSrc, blockSize); + dst += 4; + rowSrc += rowSize; + } + } + + delete[] tmpBuffer; +} + +void KleidiAIUtil::packNCHWToNC4HW4(__fp16* data, size_t rowNum, size_t rowSize) { + if(rowNum == 1) { + return; + } + + const size_t tmpSize = rowNum * rowSize * sizeof(__fp16); + uint8_t *tmpBuffer = new uint8_t[tmpSize]; + memcpy(tmpBuffer, data, tmpSize); + + const __fp16 *src = (const __fp16 *)tmpBuffer; + __fp16 *dst = (__fp16 *)data; + + size_t blockNum = rowSize / 8; + size_t blockSize = 8 * sizeof(__fp16); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const __fp16 *rowSrc = src + blockIndex * 8; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(dst, rowSrc, blockSize); + dst += 8; + rowSrc += rowSize; + } + } + + delete[] tmpBuffer; +} + +void KleidiAIUtil::packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize) { + if(rowNum == 1) { + return; + } + + const size_t tmpSize = rowNum * rowSize * sizeof(float); + uint8_t *tmpBuffer = new uint8_t[tmpSize]; + memcpy(tmpBuffer, data, tmpSize); + + const float *src = (const float *)tmpBuffer; + float *dst = (float *)data; + + size_t blockNum = rowSize / 4; + size_t blockSize = 4 * sizeof(float); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const float *rowSrc = src + blockIndex * 4 * rowNum; + float 
*block_dst = dst + blockIndex * 4; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(block_dst, rowSrc, blockSize); + block_dst += rowSize; + rowSrc += 4; + } + } + + delete[] tmpBuffer; +} + +void KleidiAIUtil::packNC4HW4ToNCHW(__fp16* data, size_t rowNum, size_t rowSize) { + if(rowNum == 1) { + return; + } + + const size_t tmpSize = rowNum * rowSize * sizeof(__fp16); + uint8_t *tmpBuffer = new uint8_t[tmpSize]; + memcpy(tmpBuffer, data, tmpSize); + + const __fp16 *src = (const __fp16 *)tmpBuffer; + __fp16 *dst = (__fp16 *)data; + + size_t blockNum = rowSize / 8; + size_t blockSize = 8 * sizeof(__fp16); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const __fp16 *rowSrc = src + blockIndex * 8 * rowNum; + __fp16 *block_dst = dst + blockIndex * 8; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(block_dst, rowSrc, blockSize); + block_dst += rowSize; + rowSrc += 8; + } + } + + delete[] tmpBuffer; +} + +// Rhs pack functions for matmul_clamp_f32_qai8dxp_qsi4cxp. +void KleidiAIUtil::packQsi4cxps16s0Qs4cxs0s1( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params = (kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params *)paramsCommon; + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | 
rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = ((k0_idx + 1) % 2) * 4; + const size_t shift_right_x1 = ((k1_idx + 1) % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} + +void KleidiAIUtil::packQsi4cxps16s0Qs4cx( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params = (kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params *)paramsCommon; + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_rhs_packed_stride(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2); + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + 
super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = k0_idx + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = k1_idx + n0_valid_idx * rhs_stride; + + int8_t byte0 = 0; + int8_t byte1 = 0; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + sums[nr_idx] += (int32_t)byte0 + (int32_t)byte1; + + const uint8_t dst_qs0 = (byte0 + rhs_zero_point) | ((byte1 + rhs_zero_point) << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} + +void KleidiAIUtil::packQsi4cxpoQsu4cxs0s1( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon) { + const size_t k_internal = kai_roundup(k, 32); + + KAI_ASSERT((k_internal % kr) == 0); + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + + const struct kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0_params *params = (kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0_params *)paramsCommon; + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->lhs_zero_point == 1); + KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); + + // Note: The input matrix (rhs) is expected with: + // "k" columns and "n" rows (NxK) + + const int32_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0(k, nr, kr, sr); + const size_t dst_nr_block_size = nr * kr * sizeof(uint8_t) / 2; + + // Iterate over n src rows in blocks of nr rows + for (size_t row_idx = 0; row_idx < n; row_idx += nr) { + int8_t* const dst_row = (int8_t*)rhs_packed + ((row_idx / nr) * rhs_packed_stride); + + int32_t* const sums = (int32_t*)(dst_row + (nr * (k_internal / 2))); + float32_t* const scaling_factors = (float32_t*)((uint8_t*)sums + (nr * kai_num_bytes_adder_rhs)); + // Update destination row pointer + float* const biases = (float*)((uint8_t*)scaling_factors + (nr * kai_num_bytes_multiplier_rhs)); + + // initialize sums to 0 + memset(sums, 0, nr * kai_num_bytes_adder_rhs); + + // Copy the scaling factors and bias + size_t rows_left = n - row_idx; + if (rows_left >= nr) { + memcpy(scaling_factors, &scale[row_idx], nr * 
kai_num_bytes_multiplier_rhs); + memcpy(biases, &bias[row_idx], nr * kai_num_bytes_bias); + } else { + // Fill remaining values + memcpy(scaling_factors, &scale[row_idx], rows_left * kai_num_bytes_multiplier_rhs); + memcpy(biases, &bias[row_idx], rows_left * kai_num_bytes_bias); + // Set leftover to 0 + memset(&scaling_factors[rows_left], 0, (nr - rows_left) * kai_num_bytes_multiplier_rhs); + memset(&biases[rows_left], 0, (nr - rows_left) * kai_num_bytes_bias); + } + + // Iterate over rows in the nr row block + for (size_t nr_block_idx = 0; nr_block_idx < nr; ++nr_block_idx) { + const uint8_t* const src_row = rhs + ((row_idx + nr_block_idx) * rhs_stride); + // Go to the first kr block for this row in the nr block + int8_t* dst_kr_block = dst_row + (nr_block_idx * kr / 2); + + int32_t sum = 0; + + // Iterate over k src columns in blocks of kr columns + for (size_t col_idx = 0; col_idx < k_internal; col_idx += kr) { + // Iterate over columns in the kr block + // Kr checked to be multiple of 2 (because 2 values per byte) + for (size_t kr_block_idx = 0; kr_block_idx < kr; kr_block_idx += 2) { + // We pad dst with 0s if the rounded k or n values have been exceeded + if (row_idx + nr_block_idx >= n || col_idx + kr_block_idx >= k) { + dst_kr_block[kr_block_idx / 2] = 0; + continue; + } + + // Load the 2 u4 values from source + const uint8_t dst_byte = src_row[(col_idx + kr_block_idx) / 2]; + + // extract i8 values from the 2 u4 values + const int32_t second_value = (dst_byte & 0xF) - rhs_zero_point; + const int32_t first_value = col_idx + kr_block_idx + 1 >= k ? 0 : (dst_byte >> 4) - rhs_zero_point; + + // Add the i4 value to the row sum + sum += first_value + second_value; + + // Truncate i8 to i4 and write to dst + // NOLINTBEGIN(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + dst_kr_block[kr_block_idx / 2] = (second_value << 4) | (first_value & 0xF); + // NOLINTEND(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + } + + // Go to the next kr block for this row in the nr rows + dst_kr_block += dst_nr_block_size; + } + + // save sum + sums[nr_block_idx] = sum; + } + } +} \ No newline at end of file diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai_util.h b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai_util.h new file mode 100644 index 000000000..6fee53bdc --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai_util.h @@ -0,0 +1,80 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include + +#include "kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" +#include "kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.h" +#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.h" + +#include "kai_common.h" + +namespace MNN { + class KleidiAIUtil { + public: + struct rhsPackParamCommon { + int8_t mLhsZeroPoint = 1; + uint8_t mRhsZeroPoint = 8; + }; + + static void packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize); + static void packNCHWToNC4HW4(__fp16* data, size_t rowNum, size_t rowSize); + static void packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize); + static void packNC4HW4ToNCHW(__fp16* data, size_t rowNum, size_t 
rowSize);
+
+        /// Rhs pack functions for matmul_clamp_f32_qai8dxp_qsi4cxp.
+        static void packQsi4cxps16s0Qs4cxs0s1(
+            size_t num_groups,   //
+            size_t n,            //
+            size_t k,            //
+            size_t nr,           //
+            size_t kr,           //
+            size_t sr,           //
+            const uint8_t* rhs,  //
+            const float* bias,   //
+            const float* scale,  //
+            void* rhs_packed,    //
+            size_t extra_bytes,  //
+            const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon);  //
+
+        static void packQsi4cxps16s0Qs4cx(
+            size_t num_groups,   //
+            size_t n,            //
+            size_t k,            //
+            size_t nr,           //
+            size_t kr,           //
+            size_t sr,           //
+            const uint8_t* rhs,  //
+            const float* bias,   //
+            const float* scale,  //
+            void* rhs_packed,    //
+            size_t extra_bytes,  //
+            const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon);  //
+
+        static void packQsi4cxpoQsu4cxs0s1(
+            size_t num_groups,   //
+            size_t n,            //
+            size_t k,            //
+            size_t nr,           //
+            size_t kr,           //
+            size_t sr,           //
+            const uint8_t* rhs,  //
+            const float* bias,   //
+            const float* scale,  //
+            void* rhs_packed,    //
+            size_t extra_bytes,  //
+            const struct KleidiAIUtil::rhsPackParamCommon* paramsCommon);  //
+    };
+}
diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp
index feb0fd345..d39c2395a 100644
--- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp
+++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp
@@ -268,35 +268,60 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O
     bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic);
 
 #ifdef MNN_KLEIDIAI_ENABLED
-    bool half_act = gcore->bytes == 2;
-    int biasSize = mResourceInt8->mOriginBias->size();
-    int alphaSize = mResourceInt8->mOriginScale->size();
-    bool blockwise = (biasSize * 2) != alphaSize;
-    KleidiAI kai = KleidiAI::getInstance(quanCommon->asymmetric, half_act, blockwise);
-    if(quanCommon->canUseInt4 && kai.canAccelerate()) {
-        int n = oc;
-        int k = ic;
-        int packedWeightSize = kai.getRhsPackedSize(n, k);
-
-        //Alloc packed weight tensor.
-        mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize}));
-        bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC);
-
-        if (!success) {
-            MNN_ERROR("Out of static memory!\n");
-            return;
+    if(quanCommon->canUseInt4) {
+        if(!KleidiAI::mKaiInitialized) {
+            KleidiAI::modelInfo info(true,                                 //qi4
+                                     quanCommon->asymmetric,               //asymmetric
+                                     gcore->bytes == 2 ? true : false,     //fp16
+                                     mBlockNum == 1 ? 0 : ic / mBlockNum); //blockSize
+            KleidiAI::getInstance(info, *MNNGetCPUInfo());
         }
 
-        //Run rhs pack.
-        kai.runRhsPack(n, k, (uint8_t*)quanCommon->weight.get(),
-                       mResourceInt8->mOriginScale->host(),
-                       mResourceInt8->mOriginBias->host(),
-                       mResourceInt8->mWeightInt8->host(),
-                       directReadInt4weight);
+        if(KleidiAI::canAccelerate()) {
+            KleidiAI& kai = KleidiAI::getInstance();
 
-        return;
-    }
+            int n = oc;
+            int k = ic;
+            int packedWeightSize = kai.getRhsPackedSize(n, k);
+
+            //Alloc packed weight tensor.
+            mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize}));
+            bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC);
+
+            if (!success) {
+                MNN_ERROR("Out of static memory!\n");
+                return;
+            }
+
+            size_t paraNum = blockNum * ROUND_UP(oc, pack);
+            float *scalePtr = mResourceInt8->mOriginScale->host<float>();
+            float *zeroPtr = mResourceInt8->mOriginScale->host<float>() + paraNum;
+            float *biasPtr = mResourceInt8->mOriginBias->host<float>();
+            {
+                //Reload some parameters to fit ukernels' layout.
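+                //For asymmetric models, MNN stores alpha as interleaved
+                //(zeroPoint, scale) pairs per output channel; they are split
+                //into the separate zeroPtr/scalePtr arrays that the KleidiAI
+                //rhs-pack kernels consume. Symmetric block-quantized models
+                //only need their scales copied through.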
+                auto quanInfoPtr = quanCommon->alpha.get();
+                if(kai.mModelInfo.mAsymmetric) {
+                    for(int i = 0; i < paraNum; i++) {
+                        zeroPtr[i] = quanInfoPtr[i * 2];
+                        scalePtr[i] = quanInfoPtr[i * 2 + 1];
+                    }
+                } else {
+                    if(kai.mModelInfo.mBlockSize != 0) {
+                        memcpy(scalePtr, (uint8_t*)quanInfoPtr, paraNum * sizeof(float));
+                    }
+                }
+            }
+            //Run rhs pack.
+            auto weightPackedData = mResourceInt8->mWeightInt8->host();
+            kai.runRhsPack(n, k, (uint8_t*)quanCommon->weight.get(),
+                           (const void*)scalePtr, (const void*)zeroPtr,
+                           (const void*)biasPtr,
+                           weightPackedData,
+                           directReadInt4weight);
+            return;
+        }
+    }
 #endif
 
     if (quanCommon->canUseInt4 && directReadInt4weight) {
@@ -533,8 +558,9 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input
 
 #ifdef MNN_KLEIDIAI_ENABLED
-    KleidiAI& kai = KleidiAI::getInstance();
-    if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) {
+    if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && KleidiAI::canAccelerate()) {
+        KleidiAI& kai = KleidiAI::getInstance();
+
         int batch = inputs[0]->batch();
         int channel = inputs[0]->channel();
@@ -725,7 +751,9 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu
 
 #ifdef MNN_KLEIDIAI_ENABLED
     KleidiAI& kai = KleidiAI::getInstance();
-    if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) {
+    if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && KleidiAI::canAccelerate()) {
+        KleidiAI& kai = KleidiAI::getInstance();
+
         const size_t m = input->batch();    //lhs vector number.
         const size_t n = output->channel(); //rhs vector number.
         const size_t k = input->channel();  //vector size.
@@ -738,8 +766,14 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu
         int threadNum = static_cast<CPUBackend*>(backend())->threadNumber();
         int threadNeed, vecPerThread;
 
+        size_t elementSize = kai.mModelInfo.mFp16 ? sizeof(__fp16) : sizeof(float);
+
 #if !KAI_CONV_NCHW_IN_OUT
-        kai.packNC4HW4ToNCHW((float *)lhs, m, k);
+        if(kai.mModelInfo.mFp16) {
+            KleidiAIUtil::packNC4HW4ToNCHW((__fp16 *)lhs, m, k);
+        } else {
+            KleidiAIUtil::packNC4HW4ToNCHW((float *)lhs, m, k);
+        }
 #endif
 
         //Dynamic quant pack lhs.
@@ -748,7 +782,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu
         } else {
             vecPerThread = kai.getVecNumPerThread(m, threadNum, kai.getMr(m));
             threadNeed = m % vecPerThread == 0 ? m / vecPerThread : (m / vecPerThread + 1);
-            size_t srcStride = vecPerThread * k * sizeof(float);
+            size_t srcStride = vecPerThread * k * elementSize;
 
             auto BatchDynamicQuant = [=, &kai](int tId) {
                 auto threadSrc = lhs + tId * srcStride;
@@ -764,14 +798,19 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu
         }
 
         //Run matmul.
+        if(kai.mCPUInfo.mSme2) {
+            //SME2 prefers running on a single thread to obtain a better performance/power consumption ratio.
+            threadNum = 1;
+        }
+
         vecPerThread = kai.getVecNumPerThread(n, threadNum, kai.getNStep());
         threadNeed = n % vecPerThread == 0 ? n / vecPerThread : (n / vecPerThread + 1);
 
         auto ThreadFunction = [=, &kai](int tId) {
             auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * vecPerThread, k);
-            auto threadDst = dst + kai.getDstOffset(0, tId * vecPerThread, n);
+            auto threadDst = dst + kai.getDstOffset(0, tId * vecPerThread, n, elementSize);
 
             int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread; //The last thread may handle fewer than vecPerThread.
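+            //Each thread owns a contiguous slice of vecPerThread output
+            //columns: its packed-RHS pointer and destination pointer are both
+            //offset by nIdx = tId * vecPerThread, while the destination row
+            //stride below stays the full output width in bytes.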
- kai.runMatmul(m, vecNum, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); + kai.runMatmul(m, vecNum, k, lhsPacked, threadRhsPacked, n * elementSize, threadDst); }; MNN_CONCURRENCY_BEGIN(tId, threadNeed) { @@ -780,7 +819,11 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu MNN_CONCURRENCY_END(); #if !KAI_CONV_NCHW_IN_OUT - kai.packNCHWToNC4HW4((float *)dst, m, n); + if(kai.mModelInfo.mFp16) { + KleidiAIUtil::packNCHWToNC4HW4((__fp16 *)dst, m, n); + } else { + KleidiAIUtil::packNCHWToNC4HW4((float *)dst, m, n); + } #endif return NO_ERROR; diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 442b3184a..b972a6aff 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -22,9 +22,9 @@ #ifdef MNN_KLEIDIAI_ENABLED #include "../backend/cpu/arm/kleidiAI/mnn_kleidiai.h" /** - * Set DenseConvInt8TiledExecutor's input/output tensor format: - * KAI_CONV_NCHW_IN_OUT = 1: format will be NCHW, skip pack/unpack functions. - * KAI_CONV_NCHW_IN_OUT = 0: format will be NC4HW4, need pack/unpack functions to fit kleidiAI ukernel. + * Set DenseConvInt8TiledExecutor's input/output tensor format: + * 1: format will be NCHW, skip pack/unpack functions. + * 0: format will be NC4HW4, need pack/unpack functions to fit kleidiAI ukernel. **/ #define KAI_CONV_NCHW_IN_OUT 1 #endif diff --git a/source/geometry/GeometryConvUtils.cpp b/source/geometry/GeometryConvUtils.cpp index 21670bd24..59f2a0ca9 100644 --- a/source/geometry/GeometryConvUtils.cpp +++ b/source/geometry/GeometryConvUtils.cpp @@ -248,7 +248,7 @@ std::shared_ptr GeometryConvUtils::im2Col(Tensor* im2Col, Tensor* input, } bool GeometryConvUtils::computeSingle(const Op* op, const std::vector& inputs, const std::vector& outputs, GeometryComputer::Context& context, CommandBuffer& res) { #if KAI_CONV_NCHW_IN_OUT - if(KleidiAI::getInstance().canAccelerate()) { + if(KleidiAI::canAccelerate()) { std::shared_ptr cmd(new Command); cmd->op = op; cmd->inputs = std::move(inputs); diff --git a/source/shape/ShapeTensorConvert.cpp b/source/shape/ShapeTensorConvert.cpp index 899b9410b..baf1bd757 100644 --- a/source/shape/ShapeTensorConvert.cpp +++ b/source/shape/ShapeTensorConvert.cpp @@ -24,7 +24,7 @@ class TensorConvertSizeComputer : public SizeComputer { } auto destFmt = info->dest(); #if KAI_CONV_NCHW_IN_OUT - if(KleidiAI::getInstance().canAccelerate()) { + if(KleidiAI::canAccelerate()) { destFmt = MNN_DATA_FORMAT_NCHW; } #endif
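
/*
 * Editor's sketch: the end-to-end call order this patch wires up in
 * DenseConvInt8TiledExecutor, under the assumption that the KleidiAI instance
 * was already created and KleidiAI::canAccelerate() holds. All member
 * functions below are declared in mnn_kleidiai.h of this patch; the local
 * names, the single-threaded flow, and packedQ4 = true are illustrative.
 */
#include <vector>
#include "mnn_kleidiai.h"

void sketchInt4Matmul(size_t m, size_t n, size_t k,
                      const float* input, const uint8_t* weightInt4,
                      const float* scale, const float* zeroPoint,
                      const float* bias, float* output) {
    MNN::KleidiAI& kai = MNN::KleidiAI::getInstance();

    // Once, at load time: pack the int4 RHS together with scales and bias.
    std::vector<uint8_t> rhsPacked(kai.getRhsPackedSize(n, k));
    kai.runRhsPack(n, k, weightInt4, scale, zeroPoint, bias, rhsPacked.data(), true);

    // Per inference: dynamically quantize and pack the fp32 LHS.
    std::vector<uint8_t> lhsPacked(kai.getLhsQuantedPackedSize(m, k));
    kai.runLhsQuantPack(m, k, kai.getMr(m), input, lhsPacked.data());

    // Run the dotprod/i8mm/SME2 micro-kernel selected by mAccelType.
    kai.runMatmul(m, n, k, lhsPacked.data(), rhsPacked.data(), n * sizeof(float), output);
}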