Skip to content

Commit

Permalink
int8 SME integration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707895715
  • Loading branch information
alankelly authored and xnnpack-bot committed Dec 20, 2024
1 parent 5aa14b1 commit f61ea7d
Show file tree
Hide file tree
Showing 36 changed files with 3,276 additions and 1,984 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,11 @@ MICROKERNEL_DEFS = [
"src/xx-fill/xx-fill.h",
"src/xx-pad/xx-pad.h",
"src/xx-transposev/xx-transposev.h",
"src/x8-pack-lh/x8-pack-lh.h",
"src/x8-packq/x8-packq.h",
"src/x8-packw/x8-packw.h",
"src/x8-transposec/x8-transposec.h",
"src/x16-pack-lh/x16-pack-lh.h",
"src/x16-packw/x16-packw.h",
"src/x16-transposec/x16-transposec.h",
"src/x24-transposec/x24-transposec.h",
Expand Down
4 changes: 4 additions & 0 deletions cmake/gen/neonsme2_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@


SET(PROD_NEONSME2_MICROKERNEL_SRCS
src/pf16-gemm/pf16-gemm-32x32-minmax-neonsme2.c
src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c
src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c
src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32-minmax-neonsme2.c
src/x8-pack-lh/x8-packlh-neonsme2.c
src/x16-pack-lh/x16-packlh-neonsme2.c
src/x32-pack-lh/x32-packlh-neonsme2.c)

SET(NON_PROD_NEONSME2_MICROKERNEL_SRCS)
Expand Down
4 changes: 4 additions & 0 deletions gen/neonsme2_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ Auto-generated file. Do not edit!
"""

PROD_NEONSME2_MICROKERNEL_SRCS = [
"src/pf16-gemm/pf16-gemm-32x32-minmax-neonsme2.c",
"src/pf32-gemm/pf32-gemm-1x32-minmax-neonsme2.c",
"src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c",
"src/pqs8-qc8w-gemm/pqs8-qc8w-gemm-32x32-minmax-neonsme2.c",
"src/x8-pack-lh/x8-packlh-neonsme2.c",
"src/x16-pack-lh/x16-packlh-neonsme2.c",
"src/x32-pack-lh/x32-packlh-neonsme2.c",
]

Expand Down
5 changes: 5 additions & 0 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,11 @@ enum xnn_datatype {
/// Dynamically quantized 8-bit unsigned integer with per-batch quantization
/// parameters.
xnn_datatype_qduint8 = 15,
/// IEEE754 half-precision packed floating-point.
xnn_datatype_pfp16 = 16,
/// Packed quantized 8-bit signed integer with shared per-Value quantization
/// parameters.
xnn_datatype_pqint8 = 17,
};

/// Define a tensor-type Value and add it to a Subgraph.
Expand Down
67 changes: 67 additions & 0 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ static struct xnn_gemm_config f32_gemm_config = {0};
static struct xnn_gemm_config f32_gemm_nr2_config = {0};
static struct xnn_gemm_config f32_qc4w_gemm_config = {0};
static struct xnn_gemm_config f32_qc8w_gemm_config = {0};
static struct xnn_gemm_config pf16_gemm_config = {0};
static struct xnn_gemm_config pf32_gemm_config = {0};
static struct xnn_gemm_config pqs8_qc8w_gemm_config = {0};
static struct xnn_gemm_config qd8_f16_qb4w_gemm_config = {0};
static struct xnn_gemm_config qd8_f16_qc4w_gemm_config = {0};
static struct xnn_gemm_config qd8_f16_qc8w_gemm_config = {0};
Expand All @@ -56,7 +58,9 @@ XNN_INIT_ONCE_GUARD(f32_gemm);
XNN_INIT_ONCE_GUARD(f32_gemm_nr2);
XNN_INIT_ONCE_GUARD(f32_qc4w_gemm);
XNN_INIT_ONCE_GUARD(f32_qc8w_gemm);
XNN_INIT_ONCE_GUARD(pf16_gemm);
XNN_INIT_ONCE_GUARD(pf32_gemm);
XNN_INIT_ONCE_GUARD(pqs8_qc8w_gemm);
XNN_INIT_ONCE_GUARD(qd8_f16_qb4w_gemm);
XNN_INIT_ONCE_GUARD(qd8_f16_qc4w_gemm);
XNN_INIT_ONCE_GUARD(qd8_f16_qc8w_gemm);
Expand Down Expand Up @@ -259,6 +263,28 @@ static void init_f16_gemm_config(void) {
const int kCoreCountThresholdForAdaptiveAvxOptimization = 4;
#endif

// Initializes the pf16 (packed half-precision) GEMM config. Only populated
// on AArch64 builds with KleidiAI when the CPU supports SME2; on all other
// builds/CPUs the config is left zero-initialized.
static void init_pf16_gemm_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
  const struct xnn_hardware_config* hardware_config =
      xnn_init_hardware_config();
  assert(hardware_config != NULL);
  (void) hardware_config;  // May be unused (consistent with init_pqs8_qc8w_gemm_config).
  if (XNN_ENABLE_ARM_SME2 && hardware_config->use_arm_sme2) {
    #if XNN_ENABLE_ARM_SME2
    // The SME2 kernel selects its own tile shape; query mr/nr from the kernel.
    const size_t mr = xnn_pf16_gemm_minmax_ukernel_32x32__neonsme2_get_mr();
    const size_t nr = xnn_pf16_gemm_minmax_ukernel_32x32__neonsme2_get_nr();
    pf16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(mr)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_pf16_gemm_minmax_ukernel_32x32__neonsme2);
    pf16_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params;
    pf16_gemm_config.pack_weights_and_biases = xnn_pack_kai_f16_weights_and_biases;
    pf16_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_f16_weights_and_biases;
    pf16_gemm_config.mr = mr;
    pf16_gemm_config.mr_packed = mr;
    pf16_gemm_config.nr = nr;
    pf16_gemm_config.log2_kr = 1;
    #endif  // XNN_ENABLE_ARM_SME2
  }
  // The dispatch table only holds XNN_MAX_MR entries; mirrors the bound check
  // in init_pqs8_qc8w_gemm_config.
  assert(pf16_gemm_config.mr <= XNN_MAX_MR);
#endif  // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

static void init_pf32_gemm_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
Expand All @@ -282,6 +308,30 @@ static void init_pf32_gemm_config(void) {
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

// Sets up the packed qs8/qc8w GEMM config when running on an AArch64 CPU
// with SME2 support via KleidiAI; otherwise the config stays zeroed.
static void init_pqs8_qc8w_gemm_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
  const struct xnn_hardware_config* hardware_config =
      xnn_init_hardware_config();
  assert(hardware_config != NULL);
  (void) hardware_config;  // May be unused.
  if (XNN_ENABLE_ARM_SME2 && hardware_config->use_arm_sme2) {
    #if XNN_ENABLE_ARM_SME2
    // The SME2 kernel chooses its own tile shape; ask it for mr/nr.
    const size_t sme2_mr = xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32__neonsme2_get_mr();
    const size_t sme2_nr = xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32__neonsme2_get_nr();
    pqs8_qc8w_gemm_config.mr = sme2_mr;
    pqs8_qc8w_gemm_config.mr_packed = sme2_mr;
    pqs8_qc8w_gemm_config.nr = sme2_nr;
    pqs8_qc8w_gemm_config.log2_kr = 2;
    pqs8_qc8w_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(sme2_mr)] =
        xnn_init_hmp_gemm_ukernel(
            (xnn_gemm_ukernel_fn) xnn_pqs8_qc8w_gemm_minmax_ukernel_32x32__neonsme2);
    pqs8_qc8w_gemm_config.init.qs8_qc8w =
        xnn_init_qs8_qc8w_conv_minmax_fp32_scalar_params;
    pqs8_qc8w_gemm_config.pack_weights_and_biases =
        xnn_pack_kai_qs8_qc8w_weights_and_biases_sme2;
    pqs8_qc8w_gemm_config.packed_stride_weights_and_biases =
        xnn_packed_stride_kai_qs8_qc8w_weights_and_biases_sme2;
    #endif  // XNN_ENABLE_ARM_SME2
  }
  // The dispatch table only holds XNN_MAX_MR entries.
  assert(pqs8_qc8w_gemm_config.mr <= XNN_MAX_MR);
#endif  // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

static void init_f32_gemm_config(void) {
#if XNN_ARCH_ARM
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
Expand Down Expand Up @@ -4124,6 +4174,15 @@ const struct xnn_gemm_config* xnn_init_f16_gemm_config() {
return &f16_gemm_config;
}

// Returns the lazily-initialized pf16 GEMM config, or NULL if hardware
// detection failed. Initialization runs at most once via XNN_INIT_ONCE.
const struct xnn_gemm_config* xnn_init_pf16_gemm_config() {
  // Only success/failure of hardware-config init matters here; call it
  // directly, matching xnn_init_pf32_gemm_config and
  // xnn_init_pqs8_qc8w_gemm_config.
  if (xnn_init_hardware_config() == NULL) {
    return NULL;
  }
  XNN_INIT_ONCE(pf16_gemm);
  return &pf16_gemm_config;
}

const struct xnn_gemm_config* xnn_init_pf32_gemm_config() {
if (xnn_init_hardware_config() == NULL) {
return NULL;
Expand All @@ -4132,6 +4191,14 @@ const struct xnn_gemm_config* xnn_init_pf32_gemm_config() {
return &pf32_gemm_config;
}

// Returns the lazily-initialized packed qs8/qc8w GEMM config, or NULL when
// the hardware config cannot be initialized.
const struct xnn_gemm_config* xnn_init_pqs8_qc8w_gemm_config() {
  const struct xnn_hardware_config* hardware_config =
      xnn_init_hardware_config();
  if (hardware_config == NULL) {
    return NULL;
  }
  XNN_INIT_ONCE(pqs8_qc8w_gemm);
  return &pqs8_qc8w_gemm_config;
}

const struct xnn_gemm_config* xnn_init_f32_gemm_config() {
if (xnn_init_hardware_config() == NULL) {
return NULL;
Expand Down
55 changes: 53 additions & 2 deletions src/configs/pack-lh-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
#include "xnnpack/microfnptr.h"
#include "xnnpack/pack-lh.h"

static struct xnn_pack_lh_config x8_pack_lh_config = {0};
static struct xnn_pack_lh_config x16_pack_lh_config = {0};
static struct xnn_pack_lh_config x32_pack_lh_config = {0};

XNN_INIT_ONCE_GUARD(x8_pack_lh);
XNN_INIT_ONCE_GUARD(x16_pack_lh);
XNN_INIT_ONCE_GUARD(x32_pack_lh);

static void init_x32_pack_lh_config(void) {
Expand All @@ -24,8 +28,9 @@ static void init_x32_pack_lh_config(void) {
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
if (hardware_config->use_arm_sme2) {
x32_pack_lh_config.ukernel = (xnn_x32_pack_lh_ukernel_fn) xnn_x32_pack_lh_ukernel__neonsme2;
x32_pack_lh_config.size_fn = (xnn_x32_pack_lh_size_fn) xnn_x32_pack_lh_size__neonsme2;
x32_pack_lh_config.ukernel = (xnn_pack_lh_ukernel_fn) xnn_x32_pack_lh_ukernel__neonsme2;
x32_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x32_pack_lh_size__neonsme2;
x32_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x32_pack_lh_offset__neonsme2;
}
#endif // XNN_ENABLE_ARM_SME2
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
Expand All @@ -39,3 +44,49 @@ const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config() {
XNN_INIT_ONCE(x32_pack_lh);
return &x32_pack_lh_config;
}

// Populates the x16 pack-lh config with the SME2 kernels when the CPU
// supports them; on any other build or CPU the config stays zeroed.
static void init_x16_pack_lh_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
#if XNN_ENABLE_ARM_SME2
  const struct xnn_hardware_config* hw_config = xnn_init_hardware_config();
  assert(hw_config != NULL);
  if (!hw_config->use_arm_sme2) {
    return;  // No SME2 on this CPU: leave the config zeroed.
  }
  x16_pack_lh_config.ukernel =
      (xnn_pack_lh_ukernel_fn) xnn_x16_pack_lh_ukernel__neonsme2;
  x16_pack_lh_config.size_fn =
      (xnn_pack_lh_size_fn) xnn_x16_pack_lh_size__neonsme2;
  x16_pack_lh_config.offset_fn =
      (xnn_pack_lh_offset_fn) xnn_x16_pack_lh_offset__neonsme2;
#endif  // XNN_ENABLE_ARM_SME2
#endif  // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

// Returns the lazily-initialized x16 pack-lh config, or NULL when hardware
// detection fails.
const struct xnn_pack_lh_config* xnn_init_x16_pack_lh_config() {
  if (xnn_init_hardware_config() == NULL) {
    return NULL;
  }
  XNN_INIT_ONCE(x16_pack_lh);
  return &x16_pack_lh_config;
}

// Populates the x8 pack-lh config with the SME2 kernels when the CPU
// supports them; on any other build or CPU the config stays zeroed.
static void init_x8_pack_lh_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
#if XNN_ENABLE_ARM_SME2
  const struct xnn_hardware_config* hw_config = xnn_init_hardware_config();
  assert(hw_config != NULL);
  if (!hw_config->use_arm_sme2) {
    return;  // No SME2 on this CPU: leave the config zeroed.
  }
  x8_pack_lh_config.ukernel =
      (xnn_pack_lh_ukernel_fn) xnn_x8_pack_lh_ukernel__neonsme2;
  x8_pack_lh_config.size_fn =
      (xnn_pack_lh_size_fn) xnn_x8_pack_lh_size__neonsme2;
  x8_pack_lh_config.offset_fn =
      (xnn_pack_lh_offset_fn) xnn_x8_pack_lh_offset__neonsme2;
#endif  // XNN_ENABLE_ARM_SME2
#endif  // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

// Returns the lazily-initialized x8 pack-lh config, or NULL when hardware
// detection fails.
const struct xnn_pack_lh_config* xnn_init_x8_pack_lh_config() {
  if (xnn_init_hardware_config() == NULL) {
    return NULL;
  }
  XNN_INIT_ONCE(x8_pack_lh);
  return &x8_pack_lh_config;
}
10 changes: 10 additions & 0 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ bool xnn_datatype_is_real(enum xnn_datatype t) {
case xnn_datatype_fp16:
case xnn_datatype_bf16:
case xnn_datatype_qint8:
case xnn_datatype_pqint8:
case xnn_datatype_quint8:
case xnn_datatype_qint32:
case xnn_datatype_qcint8:
Expand All @@ -25,6 +26,7 @@ bool xnn_datatype_is_real(enum xnn_datatype t) {
case xnn_datatype_qduint8:
case xnn_datatype_qpint8:
case xnn_datatype_qbint4:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
return true;
}
Expand All @@ -39,6 +41,7 @@ bool xnn_datatype_is_integral(enum xnn_datatype t) {
case xnn_datatype_fp16:
case xnn_datatype_bf16:
case xnn_datatype_qint8:
case xnn_datatype_pqint8:
case xnn_datatype_quint8:
case xnn_datatype_qint32:
case xnn_datatype_qcint8:
Expand All @@ -48,6 +51,7 @@ bool xnn_datatype_is_integral(enum xnn_datatype t) {
case xnn_datatype_qduint8:
case xnn_datatype_qpint8:
case xnn_datatype_qbint4:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
return false;
case xnn_datatype_int32:
Expand All @@ -60,6 +64,7 @@ bool xnn_datatype_is_integral(enum xnn_datatype t) {
bool xnn_datatype_is_quantized(enum xnn_datatype t) {
switch (t) {
case xnn_datatype_qint8:
case xnn_datatype_pqint8:
case xnn_datatype_quint8:
case xnn_datatype_qint32:
case xnn_datatype_qcint8:
Expand All @@ -75,6 +80,7 @@ bool xnn_datatype_is_quantized(enum xnn_datatype t) {
case xnn_datatype_fp16:
case xnn_datatype_bf16:
case xnn_datatype_int32:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
return false;
}
Expand All @@ -91,6 +97,7 @@ size_t xnn_datatype_log2_size_bits(enum xnn_datatype t) {
case xnn_datatype_qbint4:
return 2;
case xnn_datatype_qint8:
case xnn_datatype_pqint8:
case xnn_datatype_quint8:
case xnn_datatype_qcint8:
case xnn_datatype_qdint8:
Expand All @@ -99,6 +106,7 @@ size_t xnn_datatype_log2_size_bits(enum xnn_datatype t) {
return 3;
case xnn_datatype_fp16:
case xnn_datatype_bf16:
case xnn_datatype_pfp16:
return 4;
case xnn_datatype_qint32:
case xnn_datatype_qcint32:
Expand Down Expand Up @@ -130,12 +138,14 @@ bool xnn_datatype_is_byte_addressable(enum xnn_datatype t) {
case xnn_datatype_invalid:
case xnn_datatype_qcint4:
case xnn_datatype_qbint4:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
case xnn_datatype_qpint8:
return false;
case xnn_datatype_fp16:
case xnn_datatype_bf16:
case xnn_datatype_qint8:
case xnn_datatype_pqint8:
case xnn_datatype_quint8:
case xnn_datatype_qint32:
case xnn_datatype_qcint8:
Expand Down
4 changes: 4 additions & 0 deletions src/enums/datatype-strings.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ const char* xnn_datatype_to_string(enum xnn_datatype type) {
return "FP16";
case xnn_datatype_bf16:
return "BF16";
case xnn_datatype_pfp16:
return "PFP16";
case xnn_datatype_pfp32:
return "PFP32";
case xnn_datatype_qint8:
return "QINT8";
case xnn_datatype_pqint8:
return "PQINT8";
case xnn_datatype_quint8:
return "QUINT8";
case xnn_datatype_qint32:
Expand Down
14 changes: 7 additions & 7 deletions src/operator-run.c
Original file line number Diff line number Diff line change
Expand Up @@ -2273,17 +2273,17 @@ void xnn_compute_f32_qdu8_convert(
return xnn_compute_f32_qx8_convert(context, xnn_f32_qdu8_asymmetric_quantization_params, batch_index);
}

void xnn_compute_x32_pack_lh(
const struct x32_pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)],
void xnn_compute_pack_lh(
const struct pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t m_idx_start, size_t tile) {
const float* lhs = (const float*)((const char*)context->lhs +
const void* lhs = (const void*)((const char*)context->lhs +
m_idx_start * context->lhs_stride);
const size_t offset = context->k * m_idx_start;
float* lhs_packed = context->lhs_packed + offset;
const size_t offset = context->packed_offset_fn(m_idx_start, context->k, context->mr, context->kr, context->sr);
void* lhs_packed = context->lhs_packed + offset;

context->pack_lh_ukernel(/*m=*/tile, context->k, context->mr, context->kr,
context->sr, 0, (const uint32_t*) lhs, context->lhs_stride,
(uint32_t*) lhs_packed);
context->sr, /*m_idx_start=*/0, lhs, context->lhs_stride,
lhs_packed);
}

void xnn_compute_f32_qp8_convert(
Expand Down
Loading

0 comments on commit f61ea7d

Please sign in to comment.