Skip to content

Commit

Permalink
Bump MNN KleidiAI ukernel to qai8_qsi4_sme2 ukernel
Browse files Browse the repository at this point in the history
Add logic to select ukernels based on modelInfo and CPUInfo in
mnn_kleidiai.cpp. Move some pack functions to mnn_kleidiai_util.cpp.

Add CPU feature detect in source/backend/cpu/CPURuntime.hpp.

Thread number will be forced to 1 when SME2 is enabled, for better
energy efficiency ratio.
  • Loading branch information
xhzheng1895 committed Nov 25, 2024
1 parent 707b8a4 commit 7d9fb4a
Show file tree
Hide file tree
Showing 18 changed files with 2,100 additions and 475 deletions.
21 changes: 17 additions & 4 deletions source/backend/cpu/CPURuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
// ref: https://cs.android.com/android/platform/superproject/+/master:bionic/libc/kernel/uapi/asm-arm64/asm/hwcap.h;drc=04da58f5b3bc40dbbafb4f8422aa2991479d9e1e;l=70
#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000)
#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000)
#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002)

#define CPUINFO_ARM_LINUX_FEATURE2_SVE2 UINT32_C(0x00000002)
#define CPUINFO_ARM_LINUX_FEATURE2_SME2 UINT64_C(0x0000002000000000)
#endif

#include <algorithm>
Expand Down Expand Up @@ -1279,13 +1281,18 @@ static void _getInfoApple(MNNCPUInfo* cpuinfo_isa) {
if (have_feature("hw.optional.arm.FEAT_I8MM")) {
cpuinfo_isa->i8mm = true;
}
if (have_feature("hw.optional.arm.FEAT_SME2")) {
cpuinfo_isa->sme2 = true;
}
}
#endif

#if defined(__linux__) && defined(__aarch64__)
static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) {
// Use AUX to get info for linux-aarch64
uint32_t isa_features = 0;
uint64_t isa_features2 = 0;

isa_features = (uint32_t)getauxval(AT_HWCAP);
if (isa_features & CPUINFO_ARM_LINUX_FEATURE_ASIMDDP) {
cpuinfo_isa->dot = true;
Expand All @@ -1297,10 +1304,14 @@ static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) {
if (isa_features & CPUINFO_ARM_LINUX_FEATURE_I8MM) {
cpuinfo_isa->i8mm = true;
}
isa_features = (uint32_t)getauxval(AT_HWCAP2);
if (isa_features & CPUINFO_ARM_LINUX_FEATURE_SVE2) {

isa_features2 = (uint64_t)getauxval(AT_HWCAP2);
if (isa_features & CPUINFO_ARM_LINUX_FEATURE2_SVE2) {
cpuinfo_isa->sve2 = true;
}
if (isa_features & CPUINFO_ARM_LINUX_FEATURE2_SME2) {
cpuinfo_isa->sme2 = true;
}
}
#endif

Expand Down Expand Up @@ -1351,6 +1362,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
cpuinfo_isa->fp16arith = false;
cpuinfo_isa->i8mm = false;
cpuinfo_isa->sve2 = false;
cpuinfo_isa->sme2 = false;
// android
/**Get CPU Info*/
#ifdef __linux__
Expand Down Expand Up @@ -1447,6 +1459,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
cpuinfo_isa->dot = true;
#endif

MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2);
MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d, sme2: %d\n",
cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2, cpuinfo_isa->sme2);
return;
}
1 change: 1 addition & 0 deletions source/backend/cpu/CPURuntime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct MNNCPUInfo {
bool dot;
bool i8mm;
bool sve2;
bool sme2;
std::vector<CPUGroup> groups;
int cpuNumber = 0;
};
Expand Down
10 changes: 10 additions & 0 deletions source/backend/cpu/arm/kleidiAI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS
endif()

list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp)
list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.cpp)
list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h)
list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.h)

add_library(
MNN_KleidiAI
Expand All @@ -41,6 +43,7 @@ include_directories(
set(KLEIDIAI_FILES_SCALAR
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.c
)

set(KLEIDIAI_FILES_NEON_DOTPROD
Expand All @@ -51,13 +54,20 @@ set(KLEIDIAI_FILES_NEON_I8MM
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
)

set(KLEIDIAI_FILES_SME2_MOPA
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.c
)

# Selectively enable architecture features.
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SCALAR})
if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") AND NOT MSVC)
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD})
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_I8MM})
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SME2_MOPA})

set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a)
set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod)
set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm)
set_source_files_properties(${KLEIDIAI_FILES_SME2_MOPA} PROPERTIES COMPILE_OPTIONS -march=armv9.2-a+sme2)
endif()
67 changes: 24 additions & 43 deletions source/backend/cpu/arm/kleidiAI/kai/kai_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,21 @@ inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) {
/// @param[in] f16 The f16 value
///
/// @return the f32 value
inline static float kai_cast_f32_f16(uint16_t f16) {
#if defined(__ARM_NEON)
inline static float kai_cast_f32_f16(uint16_t f16) {
__fp16 f32 = 0;
memcpy(&f32, &f16, sizeof(uint16_t));
return (float)f32;
#endif
}
#endif

/// Converts a scalar bf16 value to f32
/// @param[in] bf16 The f16 value
///
/// @return the f32 value
inline static float kai_cast_f32_bf16(uint16_t bf16) {
const uint32_t i32 = (bf16 << 16);
float f32;
float f32 = 0;
memcpy(&f32, &i32, sizeof(i32));
return f32;
}
Expand All @@ -116,79 +116,60 @@ inline static uint16_t kai_cast_bf16_f32(float f32) {
/// @param[in] f32 The f32 value
///
/// @return the f16 value
inline static uint16_t kai_cast_f16_f32(float f32) {
#if defined(__ARM_NEON)
inline static uint16_t kai_cast_f16_f32(float f32) {
uint16_t f16 = 0;
__fp16 tmp = f32;
__fp16 tmp = (__fp16)f32;
memcpy(&f16, &tmp, sizeof(uint16_t));
return f16;
#endif
}
#endif

inline static size_t kai_roundup(size_t a, size_t b) {
return ((a + b - 1) / b) * b;
}

#ifdef __ARM_FEATURE_SVE

#ifdef __ARM_FEATURE_SVE2
/// Gets the SME vector length for 8-bit elements.
inline static uint64_t kai_get_sme_vector_length_u8(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntb %0\n"
".inst 0xd503467f // SMSTOP\n"
".inst 0x04bf5827 // rdsvl x7, #1\n"
"mov %0, x7\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

: /* no inputs */
: "x7");
return res;
}

/// Gets the SME vector length for 16-bit elements.
inline static uint64_t kai_get_sme_vector_length_u16(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cnth %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
return kai_get_sme_vector_length_u8() / 2;
}

/// Gets the SME vector length for 32-bit elements.
inline static uint64_t kai_get_sme_vector_length_u32(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntw %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
return kai_get_sme_vector_length_u8() / 4;
}

#endif // __ARM_FEATURE_SVE
#endif // __ARM_FEATURE_SVE2

/// Extends the sign bit of int 4-bit value (stored in int8_t variable)
/// @param[in] value The 4-bit int value
///
/// @return the int8_t value with sign extended
inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
return (value ^ 0x8) - 8;
// Make sure value holds correct int4 value
KAI_ASSERT(value <= 0xF);

return (value ^ 0x8) - 8; // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
}

/// Parameter struct for RHS matrix packing
struct kai_rhs_pack_qs4cxs1s0_param {
int8_t lhs_zero_point; /**< LHS Matrix quantization zero-point */
uint8_t rhs_zero_point; /**< RHS Matrix quantization zero-point */
};

#ifdef __cplusplus
}
#endif
Loading

0 comments on commit 7d9fb4a

Please sign in to comment.