Skip to content

Commit

Permalink
Integration of Kleidi F16 SME kernels
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700627465
  • Loading branch information
alankelly authored and xnnpack-bot committed Dec 17, 2024
1 parent 0936bc5 commit 4d370eb
Show file tree
Hide file tree
Showing 32 changed files with 2,665 additions and 1,977 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ MICROKERNEL_DEFS = [
"src/x8-packq/x8-packq.h",
"src/x8-packw/x8-packw.h",
"src/x8-transposec/x8-transposec.h",
"src/x16-pack-lh/x16-pack-lh.h",
"src/x16-packw/x16-packw.h",
"src/x16-transposec/x16-transposec.h",
"src/x24-transposec/x24-transposec.h",
Expand Down
2 changes: 2 additions & 0 deletions cmake/gen/neonsme2_microkernels.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@


SET(PROD_NEONSME2_MICROKERNEL_SRCS
src/pf16-gemm/pf16-gemm-32x32-minmax-neonsme2.c
src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c
src/x16-pack-lh/x16-packlh-neonsme2.c
src/x32-pack-lh/x32-packlh-neonsme2.c)

SET(NON_PROD_NEONSME2_MICROKERNEL_SRCS)
Expand Down
2 changes: 2 additions & 0 deletions gen/neonsme2_microkernels.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ Auto-generated file. Do not edit!
"""

PROD_NEONSME2_MICROKERNEL_SRCS = [
"src/pf16-gemm/pf16-gemm-32x32-minmax-neonsme2.c",
"src/pf32-gemm/pf32-gemm-32x32-minmax-neonsme2.c",
"src/x16-pack-lh/x16-packlh-neonsme2.c",
"src/x32-pack-lh/x32-packlh-neonsme2.c",
]

Expand Down
2 changes: 2 additions & 0 deletions include/xnnpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ enum xnn_datatype {
/// Dynamically quantized 8-bit unsigned integer with per-batch quantization
/// parameters.
xnn_datatype_qduint8 = 15,
/// IEEE754 half-precision packed floating-point.
xnn_datatype_pfp16 = 16,
};

/// Define a tensor-type Value and add it to a Subgraph.
Expand Down
33 changes: 33 additions & 0 deletions src/configs/gemm-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ static struct xnn_gemm_config f32_gemm_config = {0};
static struct xnn_gemm_config f32_gemm_nr2_config = {0};
static struct xnn_gemm_config f32_qc4w_gemm_config = {0};
static struct xnn_gemm_config f32_qc8w_gemm_config = {0};
static struct xnn_gemm_config pf16_gemm_config = {0};
static struct xnn_gemm_config pf32_gemm_config = {0};
static struct xnn_gemm_config qd8_f16_qb4w_gemm_config = {0};
static struct xnn_gemm_config qd8_f16_qc4w_gemm_config = {0};
Expand All @@ -56,6 +57,7 @@ XNN_INIT_ONCE_GUARD(f32_gemm);
XNN_INIT_ONCE_GUARD(f32_gemm_nr2);
XNN_INIT_ONCE_GUARD(f32_qc4w_gemm);
XNN_INIT_ONCE_GUARD(f32_qc8w_gemm);
XNN_INIT_ONCE_GUARD(pf16_gemm);
XNN_INIT_ONCE_GUARD(pf32_gemm);
XNN_INIT_ONCE_GUARD(qd8_f16_qb4w_gemm);
XNN_INIT_ONCE_GUARD(qd8_f16_qc4w_gemm);
Expand Down Expand Up @@ -259,6 +261,28 @@ static void init_f16_gemm_config(void) {
const int kCoreCountThresholdForAdaptiveAvxOptimization = 4;
#endif

static void init_pf16_gemm_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
const struct xnn_hardware_config* hardware_config =
xnn_init_hardware_config();
assert(hardware_config != NULL);
if (XNN_ENABLE_ARM_SME2 && hardware_config->use_arm_sme2) {
#if XNN_ENABLE_ARM_SME2
const size_t mr = xnn_pf16_gemm_minmax_ukernel_32x32__neonsme2_get_mr();
const size_t nr = xnn_pf16_gemm_minmax_ukernel_32x32__neonsme2_get_nr();
pf16_gemm_config.minmax.gemm[XNN_MR_TO_INDEX(mr)] = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_fn) xnn_pf16_gemm_minmax_ukernel_32x32__neonsme2);
pf16_gemm_config.init.f16 = xnn_init_f16_minmax_scalar_params;
pf16_gemm_config.pack_weights_and_biases = xnn_pack_kai_f16_weights_and_biases;
pf16_gemm_config.packed_stride_weights_and_biases = xnn_packed_stride_kai_f16_weights_and_biases;
pf16_gemm_config.mr = mr;
pf16_gemm_config.mr_packed = mr;
pf16_gemm_config.nr = nr;
pf16_gemm_config.log2_kr = 1;
#endif // XNN_ENABLE_ARM_SME2
}
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

static void init_pf32_gemm_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
Expand Down Expand Up @@ -4123,6 +4147,15 @@ const struct xnn_gemm_config* xnn_init_f16_gemm_config() {
return &f16_gemm_config;
}

const struct xnn_gemm_config* xnn_init_pf16_gemm_config() {
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
if (hardware_config == NULL) {
return NULL;
}
XNN_INIT_ONCE(pf16_gemm);
return &pf16_gemm_config;
}

const struct xnn_gemm_config* xnn_init_pf32_gemm_config() {
if (xnn_init_hardware_config() == NULL) {
return NULL;
Expand Down
30 changes: 28 additions & 2 deletions src/configs/pack-lh-config.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
#include "xnnpack/microfnptr.h"
#include "xnnpack/pack-lh.h"

static struct xnn_pack_lh_config x16_pack_lh_config = {0};
static struct xnn_pack_lh_config x32_pack_lh_config = {0};

XNN_INIT_ONCE_GUARD(x16_pack_lh);
XNN_INIT_ONCE_GUARD(x32_pack_lh);

static void init_x32_pack_lh_config(void) {
Expand All @@ -24,8 +26,9 @@ static void init_x32_pack_lh_config(void) {
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
if (hardware_config->use_arm_sme2) {
x32_pack_lh_config.ukernel = (xnn_x32_pack_lh_ukernel_fn) xnn_x32_pack_lh_ukernel__neonsme2;
x32_pack_lh_config.size_fn = (xnn_x32_pack_lh_size_fn) xnn_x32_pack_lh_size__neonsme2;
x32_pack_lh_config.ukernel = (xnn_pack_lh_ukernel_fn) xnn_x32_pack_lh_ukernel__neonsme2;
x32_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x32_pack_lh_size__neonsme2;
x32_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x32_pack_lh_offset__neonsme2;
}
#endif // XNN_ENABLE_ARM_SME2
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
Expand All @@ -39,3 +42,26 @@ const struct xnn_pack_lh_config* xnn_init_x32_pack_lh_config() {
XNN_INIT_ONCE(x32_pack_lh);
return &x32_pack_lh_config;
}

static void init_x16_pack_lh_config(void) {
#if XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
#if XNN_ENABLE_ARM_SME2
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
assert(hardware_config != NULL);
if (hardware_config->use_arm_sme2) {
x16_pack_lh_config.ukernel = (xnn_pack_lh_ukernel_fn) xnn_x16_pack_lh_ukernel__neonsme2;
x16_pack_lh_config.size_fn = (xnn_pack_lh_size_fn) xnn_x16_pack_lh_size__neonsme2;
x16_pack_lh_config.offset_fn = (xnn_pack_lh_offset_fn) xnn_x16_pack_lh_offset__neonsme2;
}
#endif // XNN_ENABLE_ARM_SME2
#endif // XNN_ARCH_ARM64 && XNN_ENABLE_KLEIDIAI
}

const struct xnn_pack_lh_config* xnn_init_x16_pack_lh_config() {
const struct xnn_hardware_config* hardware_config = xnn_init_hardware_config();
if (hardware_config == NULL) {
return NULL;
}
XNN_INIT_ONCE(x16_pack_lh);
return &x16_pack_lh_config;
}
5 changes: 5 additions & 0 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ bool xnn_datatype_is_real(enum xnn_datatype t) {
case xnn_datatype_qduint8:
case xnn_datatype_qpint8:
case xnn_datatype_qbint4:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
return true;
}
Expand All @@ -48,6 +49,7 @@ bool xnn_datatype_is_integral(enum xnn_datatype t) {
case xnn_datatype_qduint8:
case xnn_datatype_qpint8:
case xnn_datatype_qbint4:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
return false;
case xnn_datatype_int32:
Expand Down Expand Up @@ -75,6 +77,7 @@ bool xnn_datatype_is_quantized(enum xnn_datatype t) {
case xnn_datatype_fp16:
case xnn_datatype_bf16:
case xnn_datatype_int32:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
return false;
}
Expand All @@ -99,6 +102,7 @@ size_t xnn_datatype_log2_size_bits(enum xnn_datatype t) {
return 3;
case xnn_datatype_fp16:
case xnn_datatype_bf16:
case xnn_datatype_pfp16:
return 4;
case xnn_datatype_qint32:
case xnn_datatype_qcint32:
Expand Down Expand Up @@ -130,6 +134,7 @@ bool xnn_datatype_is_byte_addressable(enum xnn_datatype t) {
case xnn_datatype_invalid:
case xnn_datatype_qcint4:
case xnn_datatype_qbint4:
case xnn_datatype_pfp16:
case xnn_datatype_pfp32:
case xnn_datatype_qpint8:
return false;
Expand Down
2 changes: 2 additions & 0 deletions src/enums/datatype-strings.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ const char* xnn_datatype_to_string(enum xnn_datatype type) {
return "FP16";
case xnn_datatype_bf16:
return "BF16";
case xnn_datatype_pfp16:
return "PFP16";
case xnn_datatype_pfp32:
return "PFP32";
case xnn_datatype_qint8:
Expand Down
14 changes: 7 additions & 7 deletions src/operator-run.c
Original file line number Diff line number Diff line change
Expand Up @@ -2242,17 +2242,17 @@ void xnn_compute_f32_qdu8_convert(
return xnn_compute_f32_qx8_convert(context, xnn_f32_qdu8_asymmetric_quantization_params, batch_index);
}

void xnn_compute_x32_pack_lh(
const struct x32_pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)],
void xnn_compute_pack_lh(
const struct pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)],
size_t m_idx_start, size_t tile) {
const float* lhs = (const float*)((const char*)context->lhs +
const void* lhs = (const void*)((const char*)context->lhs +
m_idx_start * context->lhs_stride);
const size_t offset = context->k * m_idx_start;
float* lhs_packed = context->lhs_packed + offset;
const size_t offset = context->packed_offset_fn(m_idx_start, context->k, context->mr, context->kr, context->sr);
void* lhs_packed = context->lhs_packed + offset;

context->pack_lh_ukernel(/*m=*/tile, context->k, context->mr, context->kr,
context->sr, 0, (const uint32_t*) lhs, context->lhs_stride,
(uint32_t*) lhs_packed);
context->sr, /*m_idx_start=*/0, lhs, context->lhs_stride,
lhs_packed);
}

void xnn_compute_f32_qp8_convert(
Expand Down
43 changes: 41 additions & 2 deletions src/operators/fully-connected-nc.c
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ static enum xnn_status create_fully_connected_nc(
return status;
}

enum xnn_status xnn_create_fully_connected_nc_f16(
enum xnn_status create_fully_connected_nc_f16(
size_t input_channels,
size_t output_channels,
size_t input_stride,
Expand All @@ -348,6 +348,7 @@ enum xnn_status xnn_create_fully_connected_nc_f16(
uint32_t flags,
xnn_code_cache_t code_cache,
xnn_weights_cache_t weights_cache,
const struct xnn_gemm_config* gemm_config,
xnn_operator_t* fully_connected_op_out)
{
if (isnan(output_min)) {
Expand Down Expand Up @@ -375,7 +376,6 @@ enum xnn_status xnn_create_fully_connected_nc_f16(
return xnn_status_invalid_parameter;
}

const struct xnn_gemm_config* gemm_config = xnn_init_f16_gemm_config();
if (gemm_config == NULL) {
xnn_log_error("failed to create %s operator: unsupported hardware configuration",
xnn_operator_type_to_string(xnn_operator_type_fully_connected_nc_f16));
Expand Down Expand Up @@ -418,6 +418,45 @@ enum xnn_status xnn_create_fully_connected_nc_f16(
fully_connected_op_out);
}

enum xnn_status xnn_create_fully_connected_nc_f16(
size_t input_channels,
size_t output_channels,
size_t input_stride,
size_t output_stride,
const void* kernel,
const void* bias,
float output_min,
float output_max,
uint32_t flags,
xnn_code_cache_t code_cache,
xnn_weights_cache_t weights_cache,
xnn_operator_t* fully_connected_op_out) {
const struct xnn_gemm_config* gemm_config = xnn_init_f16_gemm_config();
return create_fully_connected_nc_f16(input_channels, output_channels, input_stride,
output_stride, kernel, bias, output_min, output_max, flags, code_cache,
weights_cache, gemm_config,
fully_connected_op_out);
}

enum xnn_status xnn_create_fully_connected_nc_pf16(
size_t input_channels,
size_t output_channels,
size_t input_stride,
size_t output_stride,
const void* kernel,
const void* bias,
float output_min,
float output_max,
uint32_t flags,
xnn_code_cache_t code_cache,
xnn_weights_cache_t weights_cache,
xnn_operator_t* fully_connected_op_out) {
const struct xnn_gemm_config* gemm_config = xnn_init_pf16_gemm_config();
return create_fully_connected_nc_f16(input_channels, output_channels, input_stride,
output_stride, kernel, bias, output_min, output_max, flags, code_cache,
weights_cache, gemm_config, fully_connected_op_out);
}

enum xnn_status create_fully_connected_nc_qx8_f16_qc4w(
size_t input_channels,
size_t output_channels,
Expand Down
Loading

0 comments on commit 4d370eb

Please sign in to comment.